mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
Merge branch 'dev' into refiner
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from collections import namedtuple
|
||||
import inspect
|
||||
from collections import namedtuple, deque
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
|
||||
@@ -155,3 +157,137 @@ def apply_refiner(sampler):
|
||||
return True
|
||||
|
||||
|
||||
class TorchHijack:
|
||||
def __init__(self, sampler_noises):
|
||||
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
|
||||
# implementation.
|
||||
self.sampler_noises = deque(sampler_noises)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item == 'randn_like':
|
||||
return self.randn_like
|
||||
|
||||
if hasattr(torch, item):
|
||||
return getattr(torch, item)
|
||||
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||
|
||||
def randn_like(self, x):
|
||||
if self.sampler_noises:
|
||||
noise = self.sampler_noises.popleft()
|
||||
if noise.shape == x.shape:
|
||||
return noise
|
||||
|
||||
return devices.randn_like(x)
|
||||
|
||||
|
||||
class Sampler:
|
||||
def __init__(self, funcname):
|
||||
self.funcname = funcname
|
||||
self.func = funcname
|
||||
self.extra_params = []
|
||||
self.sampler_noises = None
|
||||
self.stop_at = None
|
||||
self.eta = None
|
||||
self.config = None # set by the function calling the constructor
|
||||
self.last_latent = None
|
||||
self.s_min_uncond = None
|
||||
self.s_churn = 0.0
|
||||
self.s_tmin = 0.0
|
||||
self.s_tmax = float('inf')
|
||||
self.s_noise = 1.0
|
||||
|
||||
self.eta_option_field = 'eta_ancestral'
|
||||
self.eta_infotext_field = 'Eta'
|
||||
|
||||
self.conditioning_key = shared.sd_model.model.conditioning_key
|
||||
|
||||
self.model_wrap = None
|
||||
self.model_wrap_cfg = None
|
||||
|
||||
def callback_state(self, d):
|
||||
step = d['i']
|
||||
|
||||
if self.stop_at is not None and step > self.stop_at:
|
||||
raise InterruptedException
|
||||
|
||||
state.sampling_step = step
|
||||
shared.total_tqdm.update()
|
||||
|
||||
def launch_sampling(self, steps, func):
|
||||
state.sampling_steps = steps
|
||||
state.sampling_step = 0
|
||||
|
||||
try:
|
||||
return func()
|
||||
except RecursionError:
|
||||
print(
|
||||
'Encountered RecursionError during sampling, returning last latent. '
|
||||
'rho >5 with a polyexponential scheduler may cause this error. '
|
||||
'You should try to use a smaller rho value instead.'
|
||||
)
|
||||
return self.last_latent
|
||||
except InterruptedException:
|
||||
return self.last_latent
|
||||
|
||||
def number_of_needed_noises(self, p):
|
||||
return p.steps
|
||||
|
||||
def initialize(self, p) -> dict:
|
||||
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
self.model_wrap_cfg.step = 0
|
||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
|
||||
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||
|
||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||
|
||||
extra_params_kwargs = {}
|
||||
for param_name in self.extra_params:
|
||||
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
|
||||
extra_params_kwargs[param_name] = getattr(p, param_name)
|
||||
|
||||
if 'eta' in inspect.signature(self.func).parameters:
|
||||
if self.eta != 1.0:
|
||||
p.extra_generation_params[self.eta_infotext_field] = self.eta
|
||||
|
||||
extra_params_kwargs['eta'] = self.eta
|
||||
|
||||
if len(self.extra_params) > 0:
|
||||
s_churn = getattr(opts, 's_churn', p.s_churn)
|
||||
s_tmin = getattr(opts, 's_tmin', p.s_tmin)
|
||||
s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
|
||||
s_noise = getattr(opts, 's_noise', p.s_noise)
|
||||
|
||||
if s_churn != self.s_churn:
|
||||
extra_params_kwargs['s_churn'] = s_churn
|
||||
p.s_churn = s_churn
|
||||
p.extra_generation_params['Sigma churn'] = s_churn
|
||||
if s_tmin != self.s_tmin:
|
||||
extra_params_kwargs['s_tmin'] = s_tmin
|
||||
p.s_tmin = s_tmin
|
||||
p.extra_generation_params['Sigma tmin'] = s_tmin
|
||||
if s_tmax != self.s_tmax:
|
||||
extra_params_kwargs['s_tmax'] = s_tmax
|
||||
p.s_tmax = s_tmax
|
||||
p.extra_generation_params['Sigma tmax'] = s_tmax
|
||||
if s_noise != self.s_noise:
|
||||
extra_params_kwargs['s_noise'] = s_noise
|
||||
p.s_noise = s_noise
|
||||
p.extra_generation_params['Sigma noise'] = s_noise
|
||||
|
||||
return extra_params_kwargs
|
||||
|
||||
def create_noise_sampler(self, x, sigmas, p):
|
||||
"""For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
|
||||
if shared.opts.no_dpmpp_sde_batch_determinism:
|
||||
return None
|
||||
|
||||
from k_diffusion.sampling import BrownianTreeNoiseSampler
|
||||
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
||||
current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
|
||||
return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user