Implementation for sgm_uniform branch

This commit is contained in:
Kohaku-Blueleaf
2024-03-19 20:05:54 +08:00
parent c4a00affc5
commit a6b5a513f9
3 changed files with 21 additions and 2 deletions

View File

@@ -3,6 +3,7 @@ import inspect
import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser
from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401
from modules.sd_samplers_custom_schedulers import sgm_uniform
from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
@@ -62,7 +63,8 @@ k_diffusion_scheduler = {
'Automatic': None,
'karras': k_diffusion.sampling.get_sigmas_karras,
'exponential': k_diffusion.sampling.get_sigmas_exponential,
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential,
'sgm_uniform' : sgm_uniform,
}
@@ -121,6 +123,11 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho:
sigmas_kwargs['rho'] = opts.rho
p.extra_generation_params["Schedule rho"] = opts.rho
if opts.k_sched_type == 'sgm_uniform':
# Ensure the "step" will be target step + 1
steps += 1 if not discard_next_to_last_sigma else 0
sigmas_kwargs['inner_model'] = self.model_wrap
sigmas_kwargs.pop('rho', None)
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':