mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
prompt editing
This commit is contained in:
@@ -7,6 +7,7 @@ from PIL import Image
|
||||
import k_diffusion.sampling
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
from modules import prompt_parser
|
||||
|
||||
from modules.shared import opts, cmd_opts, state
|
||||
import modules.shared as shared
|
||||
@@ -53,20 +54,6 @@ def store_latent(decoded):
|
||||
shared.state.current_image = sample_to_image(decoded)
|
||||
|
||||
|
||||
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
|
||||
if sampler_wrapper.mask is not None:
|
||||
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
|
||||
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
|
||||
|
||||
res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
|
||||
|
||||
if sampler_wrapper.mask is not None:
|
||||
store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
|
||||
else:
|
||||
store_latent(res[1])
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def extended_tdqm(sequence, *args, desc=None, **kwargs):
|
||||
state.sampling_steps = len(sequence)
|
||||
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
|
||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
||||
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
|
||||
|
||||
if self.mask is not None:
|
||||
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
|
||||
x_dec = img_orig * self.mask + self.nmask * x_dec
|
||||
|
||||
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
|
||||
|
||||
if self.mask is not None:
|
||||
store_latent(self.init_latent * self.mask + self.nmask * res[1])
|
||||
else:
|
||||
store_latent(res[1])
|
||||
|
||||
self.step += 1
|
||||
return res
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
|
||||
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
|
||||
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
|
||||
|
||||
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
|
||||
self.sampler.p_sample_ddim = self.p_sample_ddim_hook
|
||||
self.mask = p.mask
|
||||
self.nmask = p.nmask
|
||||
self.init_latent = p.init_latent
|
||||
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
|
||||
def sample(self, p, x, conditioning, unconditional_conditioning):
|
||||
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
|
||||
if hasattr(self.sampler, fieldname):
|
||||
setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
|
||||
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.step = 0
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
|
||||
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
|
||||
if self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
|
||||
return denoised
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user