mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
support for SD3: infinite prompt length, token counting
This commit is contained in:
@@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
|
||||
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||
have unlimited prompt length and assign weights to tokens in prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped, hijack):
|
||||
class TextConditionalModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.wrapped = wrapped
|
||||
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||
depending on model."""
|
||||
|
||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||
self.hijack = sd_hijack.model_hijack
|
||||
self.chunk_length = 75
|
||||
|
||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
self.legacy_ucg_val = None
|
||||
self.is_trainable = False
|
||||
self.input_key = 'txt'
|
||||
self.return_pooled = False
|
||||
|
||||
self.comma_token = None
|
||||
self.id_start = None
|
||||
self.id_end = None
|
||||
self.id_pad = None
|
||||
|
||||
def empty_chunk(self):
|
||||
"""creates an empty PromptChunk and returns it"""
|
||||
@@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
"""
|
||||
|
||||
if opts.use_old_emphasis_implementation:
|
||||
import modules.sd_hijack_clip_old
|
||||
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||
|
||||
batch_chunks, token_count = self.process_texts(texts)
|
||||
|
||||
used_embeddings = {}
|
||||
@@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
|
||||
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
|
||||
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
if self.return_pooled:
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
return torch.hstack(zs)
|
||||
@@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
|
||||
"""A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
|
||||
have unlimited prompt length and assign weights to tokens in prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
|
||||
self.hijack = hijack
|
||||
|
||||
self.wrapped = wrapped
|
||||
"""Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
|
||||
depending on model."""
|
||||
|
||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
|
||||
|
||||
self.legacy_ucg_val = None # for sgm codebase
|
||||
|
||||
def forward(self, texts):
|
||||
if opts.use_old_emphasis_implementation:
|
||||
import modules.sd_hijack_clip_old
|
||||
return modules.sd_hijack_clip_old.forward_old(self, texts)
|
||||
|
||||
return super().forward(texts)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
Reference in New Issue
Block a user