mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
new implementation for attention/emphasis
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
import numpy as np
|
||||
from torch import einsum
|
||||
|
||||
from modules import prompt_parser
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
|
||||
from ldm.util import default
|
||||
@@ -211,6 +212,7 @@ class StableDiffusionModelHijack:
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
@@ -236,7 +238,7 @@ class StableDiffusionModelHijack:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
|
||||
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
@@ -275,6 +277,7 @@ class StableDiffusionModelHijack:
|
||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||
return remade_batch_tokens[0], token_count, max_length
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
@@ -300,7 +303,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def process_text(self, text):
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
maxlen = self.wrapped.max_length
|
||||
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
possible_matches = self.hijack.ids_lookup.get(token, None)
|
||||
|
||||
if possible_matches is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
else:
|
||||
found = False
|
||||
for ids, word in possible_matches:
|
||||
if tokens[i:i + len(ids)] == ids:
|
||||
emb_len = int(self.hijack.word_embeddings[word].shape[0])
|
||||
fixes.append((len(remade_tokens), word))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
i += len(ids) - 1
|
||||
found = True
|
||||
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
|
||||
break
|
||||
|
||||
if not found:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
maxlen = self.wrapped.max_length
|
||||
@@ -376,12 +464,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
if opts.use_old_emphasis_implementation:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
|
||||
self.hijack.fixes = hijack_fixes
|
||||
self.hijack.comments = hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens)
|
||||
|
Reference in New Issue
Block a user