added progressbar

added an option to disable progressbar
added interrupt support to DDIM/PLMS
This commit is contained in:
AUTOMATIC
2022-09-06 02:09:01 +03:00
parent b6763fb884
commit a243bc7859
11 changed files with 170 additions and 9 deletions

View File

@@ -1,10 +1,12 @@
from collections import namedtuple
import ldm.models.diffusion.ddim
import torch
import tqdm
import k_diffusion.sampling
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -29,8 +31,8 @@ samplers_data_k_diffusion = [
samplers = [
*samplers_data_k_diffusion,
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []),
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
@@ -43,6 +45,23 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
state.sampling_step = 0
for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
@@ -102,13 +121,18 @@ class CFGDenoiser(torch.nn.Module):
return denoised
def extended_trange(*args, **kwargs):
for x in tqdm.trange(*args, desc=state.job, **kwargs):
def extended_trange(count, *args, **kwargs):
state.sampling_steps = count
state.sampling_step = 0
for x in tqdm.trange(count, *args, desc=state.job, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
class KDiffusionSampler:
def __init__(self, funcname, sd_model):