mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
added progressbar
added an option to disable progressbar added interrupt support to DDIM/PLMS
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
from collections import namedtuple
|
||||
|
||||
import ldm.models.diffusion.ddim
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import k_diffusion.sampling
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@@ -29,8 +31,8 @@ samplers_data_k_diffusion = [
|
||||
|
||||
samplers = [
|
||||
*samplers_data_k_diffusion,
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []),
|
||||
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
|
||||
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
|
||||
]
|
||||
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|
||||
|
||||
@@ -43,6 +45,23 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
|
||||
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
||||
|
||||
|
||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
||||
state.sampling_steps = len(sequence)
|
||||
state.sampling_step = 0
|
||||
|
||||
for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs):
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
yield x
|
||||
|
||||
state.sampling_step += 1
|
||||
|
||||
|
||||
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
||||
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
|
||||
|
||||
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
@@ -102,13 +121,18 @@ class CFGDenoiser(torch.nn.Module):
|
||||
return denoised
|
||||
|
||||
|
||||
def extended_trange(*args, **kwargs):
|
||||
for x in tqdm.trange(*args, desc=state.job, **kwargs):
|
||||
def extended_trange(count, *args, **kwargs):
|
||||
state.sampling_steps = count
|
||||
state.sampling_step = 0
|
||||
|
||||
for x in tqdm.trange(count, *args, desc=state.job, **kwargs):
|
||||
if state.interrupted:
|
||||
break
|
||||
|
||||
yield x
|
||||
|
||||
state.sampling_step += 1
|
||||
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
|
Reference in New Issue
Block a user