prompt editing

This commit is contained in:
AUTOMATIC
2022-09-15 13:10:16 +03:00
parent b28cf84c36
commit f2693bec08
3 changed files with 161 additions and 19 deletions

View File

@@ -7,6 +7,7 @@ from PIL import Image
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -53,20 +54,6 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
if sampler_wrapper.mask is not None:
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
if sampler_wrapper.mask is not None:
store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
else:
store_latent(res[1])
return res
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
self.mask = None
self.nmask = None
self.init_latent = None
self.step = 0
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
store_latent(self.init_latent * self.mask + self.nmask * res[1])
else:
store_latent(res[1])
self.step += 1
return res
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning):
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
self.mask = None
self.nmask = None
self.init_latent = None
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
return denoised