SDXL support

This commit is contained in:
AUTOMATIC1111
2023-07-12 23:52:43 +03:00
parent af081211ee
commit da464a3fb3
16 changed files with 241 additions and 44 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import re
from collections import namedtuple
from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
def get_learned_conditioning(model, prompts, steps):
class SdConditioning(list):
"""
A list with prompts for stable diffusion's conditioner model.
Can also specify width and height of created image - SDXL needs it.
"""
def __init__(self, prompts, width=None, height=None):
super().__init__()
self.extend(prompts)
self.width = width or getattr(prompts, 'width', None)
self.height = height or getattr(prompts, 'height', None)
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one.
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
def get_multicond_prompt_list(prompts):
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
res_indexes = []
prompt_flat_list = []
prompt_indexes = {}
prompt_flat_list = SdConditioning(prompts)
prompt_flat_list.clear()
for prompt in prompts:
subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ class MulticondLearnedConditioning:
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator.