Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
add new sampler DDIM CFG++
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
     return x
 
 
+@torch.no_grad()
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
+    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
+    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
+    """
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones((x.shape[0]))
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+
+        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+        last_noise_uncond = model.last_noise_uncond
+
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sigma_t = sigmas[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
+        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+        x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
+
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod