Add/modify CFG callbacks

Required by self-attn guidance extension
https://github.com/ashen-sensored/sd_webui_SAG
This commit is contained in:
catboxanon
2023-05-14 01:49:41 +00:00
parent e8eea1bb7a
commit 3078001439
2 changed files with 42 additions and 1 deletions

View File

@@ -8,6 +8,7 @@ from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -160,7 +161,7 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
devices.test_for_nans(x_out, "unet")
@@ -180,6 +181,11 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
if after_cfg_callback_params.output_altered:
denoised = after_cfg_callback_params.x
self.step += 1
return denoised