mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
getting SD2.1 to run on SDXL repo
This commit is contained in:
@@ -144,7 +144,12 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
if isinstance(conds, dict):
|
||||
cond = {k: v[i] for k, v in conds.items()}
|
||||
else:
|
||||
cond = conds[i]
|
||||
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
@@ -214,20 +219,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self["crossattn"].shape
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
is_dict = isinstance(param, dict)
|
||||
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, entry in enumerate(cond_schedule):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
if is_dict:
|
||||
for k, param in cond_schedule[target_index].cond.items():
|
||||
res[k][i] = param
|
||||
else:
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def stack_conds(tensors):
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(tensors)
|
||||
|
||||
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
@@ -249,16 +291,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||
return conds_list, stacked
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
|
Reference in New Issue
Block a user