mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
added support for AND from https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
This commit is contained in:
@@ -97,10 +97,26 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
|
||||
|
||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts, steps):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
Input:
|
||||
(model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
|
||||
|
||||
Output:
|
||||
[
|
||||
[
|
||||
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
|
||||
],
|
||||
[
|
||||
ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
|
||||
ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
|
||||
]
|
||||
]
|
||||
"""
|
||||
res = []
|
||||
|
||||
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
|
||||
@@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
|
||||
return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
|
||||
return res
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||
param = c.schedules[0][0].cond
|
||||
res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
|
||||
for i, cond_schedule in enumerate(c.schedules):
|
||||
re_AND = re.compile(r"\bAND\b")
|
||||
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
|
||||
|
||||
|
||||
def get_multicond_prompt_list(prompts):
|
||||
res_indexes = []
|
||||
|
||||
prompt_flat_list = []
|
||||
prompt_indexes = {}
|
||||
|
||||
for prompt in prompts:
|
||||
subprompts = re_AND.split(prompt)
|
||||
|
||||
indexes = []
|
||||
for subprompt in subprompts:
|
||||
text, weight = re_weight.search(subprompt).groups()
|
||||
|
||||
weight = float(weight) if weight is not None else 1.0
|
||||
|
||||
index = prompt_indexes.get(text, None)
|
||||
if index is None:
|
||||
index = len(prompt_flat_list)
|
||||
prompt_flat_list.append(text)
|
||||
prompt_indexes[text] = index
|
||||
|
||||
indexes.append((index, weight))
|
||||
|
||||
res_indexes.append(indexes)
|
||||
|
||||
return res_indexes, prompt_flat_list, prompt_indexes
|
||||
|
||||
|
||||
class ComposableScheduledPromptConditioning:
|
||||
def __init__(self, schedules, weight=1.0):
|
||||
self.schedules: list[ScheduledPromptConditioning] = schedules
|
||||
self.weight: float = weight
|
||||
|
||||
|
||||
class MulticondLearnedConditioning:
|
||||
def __init__(self, shape, batch):
|
||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||
self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
|
||||
https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
|
||||
"""
|
||||
|
||||
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
|
||||
|
||||
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
|
||||
|
||||
res = []
|
||||
for indexes in res_indexes:
|
||||
res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
|
||||
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
||||
if current_step <= end_at:
|
||||
@@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
|
||||
return res
|
||||
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
tensors = []
|
||||
conds_list = []
|
||||
|
||||
for batch_no, composable_prompts in enumerate(c.batch):
|
||||
conds_for_batch = []
|
||||
|
||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
||||
target_index = 0
|
||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
||||
if current_step <= end_at:
|
||||
target_index = current
|
||||
break
|
||||
|
||||
conds_for_batch.append((len(tensors), composable_prompt.weight))
|
||||
tensors.append(composable_prompt.schedules[target_index].cond)
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
\\\(|
|
||||
\\\)|
|
||||
|
Reference in New Issue
Block a user