mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them)
This commit is contained in:
@@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
def __init__(self, model, sampler):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.mask = None
|
||||
@@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.sampler = sampler
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
@@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
@@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
if self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
@@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
@@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module):
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
preview = self.sampler.last_latent
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
|
||||
else:
|
||||
preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)
|
||||
|
||||
sd_samplers_common.store_latent(preview)
|
||||
|
||||
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
|
||||
cfg_after_cfg_callback(after_cfg_callback_params)
|
||||
@@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.step += 1
|
||||
return denoised
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
Reference in New Issue
Block a user