initial refiner support

This commit is contained in:
AUTOMATIC1111
2023-08-06 17:01:07 +03:00
parent 57e8a11d17
commit f1975b0213
6 changed files with 76 additions and 9 deletions

View File

@@ -19,7 +19,8 @@ samplers_data_compvis = [
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.p = None
self.sampler = constructor(shared.sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
@@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler:
self.nmask = None
self.init_latent = None
self.sampler_noises = None
self.steps = None
self.step = 0
self.stop_at = None
self.eta = None
@@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler:
return 0
def launch_sampling(self, steps, func):
self.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler:
return res
def update_inner_model(self):
self.sampler.model = shared.sd_model
def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
sd_samplers_common.apply_refiner(self)
if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException
@@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler:
self.update_step(x)
def initialize(self, p):
self.p = p
if self.is_ddim:
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
else: