mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Working UniPC (for batch size 1)
This commit is contained in:
@@ -7,19 +7,27 @@ import torch
|
||||
|
||||
from modules.shared import state
|
||||
from modules import sd_samplers_common, prompt_parser, shared
|
||||
import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||
]
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||
self.orig_p_sample_ddim = None
|
||||
if self.is_plms:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms
|
||||
elif self.is_ddim:
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
@@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler:
|
||||
return self.last_latent
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning)
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res)
|
||||
|
||||
return res
|
||||
|
||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
@@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
x = img_orig * self.mask + self.nmask * x
|
||||
|
||||
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
|
||||
# Note that they need to be lists because it just concatenates them later.
|
||||
@@ -84,7 +101,13 @@ class VanillaStableDiffusionSampler:
|
||||
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
return x, ts, cond, unconditional_conditioning
|
||||
|
||||
def after_sample(self, x, ts, cond, uncond, res):
|
||||
if self.is_unipc:
|
||||
# unipc model_fn returns (pred_x0)
|
||||
# p_sample_ddim returns (x_prev, pred_x0)
|
||||
res = (None, res[0])
|
||||
|
||||
if self.mask is not None:
|
||||
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
|
||||
@@ -97,7 +120,7 @@ class VanillaStableDiffusionSampler:
|
||||
state.sampling_step = self.step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
return res
|
||||
return x, ts, cond, uncond, res
|
||||
|
||||
def initialize(self, p):
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
@@ -107,12 +130,14 @@ class VanillaStableDiffusionSampler:
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
if self.is_unipc:
|
||||
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
|
||||
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
if valid_step == math.floor(valid_step):
|
||||
return int(valid_step) + 1
|
||||
|
Reference in New Issue
Block a user