mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Merge pull request #9177 from devNegative-asm/master
(Optimization) Option to remove negative conditioning at low sigma values
This commit is contained in:
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
return denoised
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
@@ -116,6 +116,14 @@ class CFGDenoiser(torch.nn.Module):
|
||||
tensor = denoiser_params.text_cond
|
||||
uncond = denoiser_params.text_uncond
|
||||
|
||||
if self.step % 2 and s_min_uncond > 0 and not is_edit_model:
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
sigma_threshold = s_min_uncond
|
||||
if(torch.dot(sigma,sigma) < sigma.shape[0] * (sigma_threshold*sigma_threshold) ):
|
||||
uncond = torch.zeros([0,0,uncond.shape[2]])
|
||||
x_in=x_in[:x_in.shape[0]//2]
|
||||
sigma_in=sigma_in[:sigma_in.shape[0]//2]
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1]:
|
||||
if not is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
@@ -144,7 +152,8 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
if uncond.shape[0]:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||
cfg_denoised_callback(denoised_params)
|
||||
@@ -152,12 +161,15 @@ class CFGDenoiser(torch.nn.Module):
|
||||
devices.test_for_nans(x_out, "unet")
|
||||
|
||||
if opts.live_preview_content == "Prompt":
|
||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
||||
sd_samplers_common.store_latent(x_out[0:x_out.shape[0]-uncond.shape[0]])
|
||||
elif opts.live_preview_content == "Negative prompt":
|
||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||
|
||||
if not is_edit_model:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
if uncond.shape[0]:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
else:
|
||||
denoised = x_out
|
||||
else:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
|
||||
@@ -165,7 +177,6 @@ class CFGDenoiser(torch.nn.Module):
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
@@ -244,6 +255,7 @@ class KDiffusionSampler:
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
@@ -326,6 +338,7 @@ class KDiffusionSampler:
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}
|
||||
|
||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
@@ -359,7 +372,8 @@ class KDiffusionSampler:
|
||||
'cond': conditioning,
|
||||
'image_cond': image_conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': p.cfg_scale
|
||||
'cond_scale': p.cfg_scale,
|
||||
's_min_uncond': self.s_min_uncond
|
||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||
|
||||
return samples
|
||||
|
Reference in New Issue
Block a user