mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
initial refiner support
This commit is contained in:
@@ -19,7 +19,8 @@ samplers_data_compvis = [
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.p = None
|
||||
self.sampler = constructor(shared.sd_model)
|
||||
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
|
||||
@@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler:
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.sampler_noises = None
|
||||
self.steps = None
|
||||
self.step = 0
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
@@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler:
|
||||
return 0
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
self.steps = steps
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
@@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
return res
|
||||
|
||||
def update_inner_model(self):
|
||||
self.sampler.model = shared.sd_model
|
||||
|
||||
def before_sample(self, x, ts, cond, unconditional_conditioning):
|
||||
if state.interrupted or state.skipped:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
sd_samplers_common.apply_refiner(self)
|
||||
|
||||
if self.stop_at is not None and self.step > self.stop_at:
|
||||
raise sd_samplers_common.InterruptedException
|
||||
|
||||
@@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler:
|
||||
self.update_step(x)
|
||||
|
||||
def initialize(self, p):
|
||||
self.p = p
|
||||
|
||||
if self.is_ddim:
|
||||
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
|
||||
else:
|
||||
|
Reference in New Issue
Block a user