mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
split samplers into one more files for k-diffusion
This commit is contained in:
@@ -2,18 +2,12 @@ from collections import deque
|
||||
import torch
|
||||
import inspect
|
||||
import k_diffusion.sampling
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
|
||||
|
||||
from modules.shared import opts, state
|
||||
import modules.shared as shared
|
||||
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
|
||||
|
||||
# imports for functions that previously were here and are used by other modules
|
||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
||||
|
||||
|
||||
samplers_k_diffusion = [
|
||||
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
|
||||
('Euler', 'sample_euler', ['k_euler'], {}),
|
||||
@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
|
||||
if hasattr(k_diffusion.sampling, funcname)
|
||||
]
|
||||
|
||||
all_samplers = [
|
||||
*samplers_data_k_diffusion,
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: sd_samplers_compvis.VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
]
|
||||
all_samplers_map = {x.name: x for x in all_samplers}
|
||||
|
||||
samplers = []
|
||||
samplers_for_img2img = []
|
||||
samplers_map = {}
|
||||
|
||||
|
||||
def create_sampler(name, model):
|
||||
if name is not None:
|
||||
config = all_samplers_map.get(name, None)
|
||||
else:
|
||||
config = all_samplers[0]
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
def set_samplers():
|
||||
global samplers, samplers_for_img2img
|
||||
|
||||
hidden = set(opts.hide_samplers)
|
||||
hidden_img2img = set(opts.hide_samplers + ['PLMS'])
|
||||
|
||||
samplers = [x for x in all_samplers if x.name not in hidden]
|
||||
samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img]
|
||||
|
||||
samplers_map.clear()
|
||||
for sampler in all_samplers:
|
||||
samplers_map[sampler.name.lower()] = sampler.name
|
||||
for alias in sampler.aliases:
|
||||
samplers_map[alias.lower()] = sampler.name
|
||||
|
||||
|
||||
set_samplers()
|
||||
|
||||
sampler_extra_params = {
|
||||
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
|
||||
@@ -92,6 +42,13 @@ sampler_extra_params = {
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
|
Reference in New Issue
Block a user