Skip to content

Commit 6376fec

Browse files
committed
Add Laplace scheduler that samples more around mid-range noise levels (around log SNR=0), increasing performance (lower FID) with faster convergence speed, and robust to resolution and objective. Reference: https://arxiv.org/pdf/2407.03297.
1 parent dcf836c commit 6376fec

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

Diff for: src/diffusers/schedulers/scheduling_ddpm.py

+9
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ def betas_for_alpha_bar(
7272

7373
def alpha_bar_fn(t):
7474
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
75+
76+
elif alpha_transform_type == "laplace":
77+
78+
def alpha_bar_fn(t):
79+
lmb = - 0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
80+
snr = math.exp(lmb)
81+
return math.sqrt(snr / (1 + snr))
7582

7683
elif alpha_transform_type == "exp":
7784

@@ -206,6 +213,8 @@ def __init__(
206213
elif beta_schedule == "squaredcos_cap_v2":
207214
# Glide cosine schedule
208215
self.betas = betas_for_alpha_bar(num_train_timesteps)
216+
elif beta_schedule == "laplace":
217+
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
209218
elif beta_schedule == "sigmoid":
210219
# GeoDiff sigmoid schedule
211220
betas = torch.linspace(-6, 6, num_train_timesteps)

0 commit comments

Comments
 (0)