mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
integrate the new samplers PR
This commit is contained in:
@@ -13,46 +13,46 @@ from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
|
||||
('Euler', 'sample_euler', ['k_euler']),
|
||||
('LMS', 'sample_lms', ['k_lms']),
|
||||
('Heun', 'sample_heun', ['k_heun']),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
('LMS', 'sample_lms', ['k_lms'], {}),
|
||||
('Heun', 'sample_heun', ['k_heun'], {}),
|
||||
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
|
||||
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
|
||||
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
|
||||
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
|
||||
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
|
||||
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
|
||||
]
|
||||
|
||||
if opts.show_karras_scheduler_variants:
|
||||
k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
|
||||
k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
|
||||
k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
|
||||
samplers_k_diffusion_ka = [
|
||||
('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
|
||||
('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
|
||||
('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
|
||||
]
|
||||
samplers_k_diffusion.extend(samplers_k_diffusion_ka)
|
||||
|
||||
samplers_data_k_diffusion = [
|
||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
|
||||
for label, funcname, aliases in samplers_k_diffusion
|
||||
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
|
||||
for label, funcname, aliases, options in samplers_k_diffusion
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
all_samplers = [
|
||||
*samplers_data_k_diffusion,
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
|
||||
|
||||
def create_sampler_with_index(list_of_configs, index, model):
|
||||
config = list_of_configs[index]
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
|
||||
@@ -130,6 +130,7 @@ class VanillaStableDiffusionSampler:
|
||||
self.step = 0
|
||||
self.eta = None
|
||||
self.default_eta = 0.0
|
||||
self.config = None
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return 0
|
||||
@@ -291,6 +292,7 @@ class KDiffusionSampler:
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.default_eta = 1.0
|
||||
self.config = None
|
||||
|
||||
def callback_state(self, d):
|
||||
store_latent(d["denoised"])
|
||||
@@ -355,11 +357,12 @@ class KDiffusionSampler:
|
||||
steps = steps or p.steps
|
||||
|
||||
if p.sampler_noise_scheduler_override:
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.funcname.endswith('ka'):
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
sigmas = p.sampler_noise_scheduler_override(steps)
|
||||
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
|
||||
sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
|
||||
else:
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
sigmas = self.model_wrap.get_sigmas(steps)
|
||||
|
||||
x = x * sigmas[0]
|
||||
|
||||
extra_params_kwargs = self.initialize(p)
|
||||
|
Reference in New Issue
Block a user