Add Pad conds v0 option

This commit is contained in:
AUTOMATIC1111
2024-01-27 22:30:12 +03:00
parent e717eaff86
commit 757dda9ade
6 changed files with 78 additions and 19 deletions

View File

@@ -53,6 +53,7 @@ class CFGDenoiser(torch.nn.Module):
self.step = 0
self.image_cfg_scale = None
self.padded_cond_uncond = False
self.padded_cond_uncond_v0 = False
self.sampler = sampler
self.model_wrap = None
self.p = None
@@ -91,6 +92,62 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['cond'] = c
self.sampler.sampler_extra_args['uncond'] = uc
def pad_cond_uncond(self, cond, uncond):
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (cond.shape[1] - cond.shape[1]) // empty.shape[1]
if num_repeats < 0:
cond = pad_cond(cond, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
return cond, uncond
def pad_cond_uncond_v0(self, cond, uncond):
"""
Pads the 'uncond' tensor to match the shape of the 'cond' tensor.
If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.
If 'uncond' is a tensor, it is padded directly.
If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'
is repeated to match the number of columns in 'cond'.
If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated
to match the number of columns in 'cond'.
Args:
cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.
uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.
Returns:
tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.
Note:
This is the padding that was always used in DDIM before version 1.6.0
"""
is_dict_cond = isinstance(uncond, dict)
uncond_vec = uncond['crossattn'] if is_dict_cond else uncond
if uncond_vec.shape[1] < cond.shape[1]:
last_vector = uncond_vec[:, -1:]
last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])
uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])
self.padded_cond_uncond_v0 = True
elif uncond_vec.shape[1] > cond.shape[1]:
uncond_vec = uncond_vec[:, :cond.shape[1]]
self.padded_cond_uncond_v0 = True
if is_dict_cond:
uncond['crossattn'] = uncond_vec
else:
uncond = uncond_vec
return cond, uncond
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -162,16 +219,11 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = sigma_in[:-batch_size]
self.padded_cond_uncond = False
self.padded_cond_uncond_v0 = False
if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:
empty = shared.sd_model.cond_stage_model_empty_prompt
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0:
tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True
elif num_repeats > 0:
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
tensor, uncond = self.pad_cond_uncond(tensor, uncond)
elif shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:
tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model: