Fix logspam and live previews

This commit is contained in:
space-nuko
2023-02-10 04:47:08 -08:00
parent 1253199889
commit 21880eb9e5
3 changed files with 41 additions and 31 deletions

View File

@@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
return x, ts, cond, unconditional_conditioning
def after_sample(self, x, ts, cond, uncond, res):
if self.is_unipc:
# unipc model_fn returns (pred_x0)
# p_sample_ddim returns (x_prev, pred_x0)
res = (None, res[0])
def update_step(self, last_latent):
if self.mask is not None:
self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
self.last_latent = res[1]
self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
state.sampling_step = self.step
shared.total_tqdm.update()
def after_sample(self, x, ts, cond, uncond, res):
if not self.is_unipc:
self.update_step(res[1])
return x, ts, cond, uncond, res
def unipc_after_update(self, x, model_x):
self.update_step(x)
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None