mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-09 05:39:47 +00:00
Add support Stable Diffusion 2.0
This commit is contained in:
@@ -9,7 +9,7 @@ sys.path.insert(0, script_path)
|
||||
|
||||
# search for directory of stable diffusion in following places
|
||||
sd_path = None
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
|
||||
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
|
||||
for possible_sd_path in possible_sd_paths:
|
||||
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
|
||||
sd_path = os.path.abspath(possible_sd_path)
|
||||
|
@@ -9,18 +9,29 @@ from torch.nn.functional import silu
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip
|
||||
|
||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
# new memory efficient cross attention blocks do not support hypernets and we already
|
||||
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
|
||||
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
|
||||
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
|
||||
|
||||
# silence new console spam from SD2
|
||||
ldm.modules.attention.print = lambda *args: None
|
||||
ldm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
|
||||
def apply_optimizations():
|
||||
undo_optimizations()
|
||||
@@ -49,16 +60,11 @@ def apply_optimizations():
|
||||
|
||||
|
||||
def undo_optimizations():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
fixes = None
|
||||
@@ -70,10 +76,13 @@ class StableDiffusionModelHijack:
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
|
||||
m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
|
||||
m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
self.clip = m.cond_stage_model
|
||||
|
||||
@@ -89,12 +98,15 @@ class StableDiffusionModelHijack:
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
|
||||
if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
|
||||
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
|
||||
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
|
||||
elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
|
||||
m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
self.apply_circular(False)
|
||||
self.layers = None
|
||||
@@ -114,262 +126,9 @@ class StableDiffusionModelHijack:
|
||||
|
||||
def tokenize(self, text):
|
||||
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
|
||||
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
|
||||
return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack: StableDiffusionModelHijack = hijack
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.token_mults = {}
|
||||
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def process_text_old(self, text):
|
||||
id_start = self.wrapped.tokenizer.bos_token_id
|
||||
id_end = self.wrapped.tokenizer.eos_token_id
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
overflowing_words = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(device)
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
|
301
modules/sd_hijack_clip.py
Normal file
301
modules/sd_hijack_clip.py
Normal file
@@ -0,0 +1,301 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from modules import prompt_parser, devices
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def get_target_prompt_token_count(token_count):
|
||||
return math.ceil(max(token_count, 1) / 75) * 75
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.hijack = hijack
|
||||
|
||||
def tokenize(self, texts):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
raise NotImplementedError
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_line(self, line, used_custom_terms, hijack_comments):
|
||||
if opts.enable_emphasis:
|
||||
parsed = prompt_parser.parse_prompt_attention(line)
|
||||
else:
|
||||
parsed = [[line, 1.0]]
|
||||
|
||||
tokenized = self.tokenize([text for text, _ in parsed])
|
||||
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
last_comma = -1
|
||||
|
||||
for tokens, (text, weight) in zip(tokenized, parsed):
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
if token == self.comma_token:
|
||||
last_comma = len(remade_tokens)
|
||||
elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
|
||||
last_comma += 1
|
||||
reloc_tokens = remade_tokens[last_comma:]
|
||||
reloc_mults = multipliers[last_comma:]
|
||||
|
||||
remade_tokens = remade_tokens[:last_comma]
|
||||
length = len(remade_tokens)
|
||||
|
||||
rem = int(math.ceil(length / 75)) * 75 - length
|
||||
remade_tokens += [self.id_end] * rem + reloc_tokens
|
||||
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
|
||||
|
||||
if embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(weight)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
iteration = len(remade_tokens) // 75
|
||||
if (len(remade_tokens) + emb_len) // 75 != iteration:
|
||||
rem = (75 * (iteration + 1) - len(remade_tokens))
|
||||
remade_tokens += [self.id_end] * rem
|
||||
multipliers += [1.0] * rem
|
||||
iteration += 1
|
||||
fixes.append((iteration, (len(remade_tokens) % 75, embedding)))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [weight] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
prompt_target_length = get_target_prompt_token_count(token_count)
|
||||
tokens_to_add = prompt_target_length - len(remade_tokens)
|
||||
|
||||
remade_tokens = remade_tokens + [self.id_end] * tokens_to_add
|
||||
multipliers = multipliers + [1.0] * tokens_to_add
|
||||
|
||||
return remade_tokens, fixes, multipliers, token_count
|
||||
|
||||
def process_text(self, texts):
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_multipliers = []
|
||||
for line in texts:
|
||||
if line in cache:
|
||||
remade_tokens, fixes, multipliers = cache[line]
|
||||
else:
|
||||
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
|
||||
token_count = max(current_token_count, token_count)
|
||||
|
||||
cache[line] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def process_text_old(self, texts):
|
||||
id_start = self.id_start
|
||||
id_end = self.id_end
|
||||
maxlen = self.wrapped.max_length # you get to stay at 77
|
||||
used_custom_terms = []
|
||||
remade_batch_tokens = []
|
||||
hijack_comments = []
|
||||
hijack_fixes = []
|
||||
token_count = 0
|
||||
|
||||
cache = {}
|
||||
batch_tokens = self.tokenize(texts)
|
||||
batch_multipliers = []
|
||||
for tokens in batch_tokens:
|
||||
tuple_tokens = tuple(tokens)
|
||||
|
||||
if tuple_tokens in cache:
|
||||
remade_tokens, fixes, multipliers = cache[tuple_tokens]
|
||||
else:
|
||||
fixes = []
|
||||
remade_tokens = []
|
||||
multipliers = []
|
||||
mult = 1.0
|
||||
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
|
||||
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
|
||||
|
||||
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
|
||||
if mult_change is not None:
|
||||
mult *= mult_change
|
||||
i += 1
|
||||
elif embedding is None:
|
||||
remade_tokens.append(token)
|
||||
multipliers.append(mult)
|
||||
i += 1
|
||||
else:
|
||||
emb_len = int(embedding.vec.shape[0])
|
||||
fixes.append((len(remade_tokens), embedding))
|
||||
remade_tokens += [0] * emb_len
|
||||
multipliers += [mult] * emb_len
|
||||
used_custom_terms.append((embedding.name, embedding.checksum()))
|
||||
i += embedding_length_in_tokens
|
||||
|
||||
if len(remade_tokens) > maxlen - 2:
|
||||
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
|
||||
ovf = remade_tokens[maxlen - 2:]
|
||||
overflowing_words = [vocab.get(int(x), "") for x in ovf]
|
||||
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
|
||||
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
|
||||
|
||||
token_count = len(remade_tokens)
|
||||
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
|
||||
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
|
||||
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
|
||||
|
||||
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
|
||||
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
|
||||
|
||||
remade_batch_tokens.append(remade_tokens)
|
||||
hijack_fixes.append(fixes)
|
||||
batch_multipliers.append(multipliers)
|
||||
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
|
||||
|
||||
def forward(self, text):
|
||||
use_old = opts.use_old_emphasis_implementation
|
||||
if use_old:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
|
||||
else:
|
||||
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
|
||||
|
||||
self.hijack.comments += hijack_comments
|
||||
|
||||
if len(used_custom_terms) > 0:
|
||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
||||
|
||||
if use_old:
|
||||
self.hijack.fixes = hijack_fixes
|
||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||
|
||||
z = None
|
||||
i = 0
|
||||
while max(map(len, remade_batch_tokens)) != 0:
|
||||
rem_tokens = [x[75:] for x in remade_batch_tokens]
|
||||
rem_multipliers = [x[75:] for x in batch_multipliers]
|
||||
|
||||
self.hijack.fixes = []
|
||||
for unfiltered in hijack_fixes:
|
||||
fixes = []
|
||||
for fix in unfiltered:
|
||||
if fix[0] == i:
|
||||
fixes.append(fix[1])
|
||||
self.hijack.fixes.append(fixes)
|
||||
|
||||
tokens = []
|
||||
multipliers = []
|
||||
for j in range(len(remade_batch_tokens)):
|
||||
if len(remade_batch_tokens[j]) > 0:
|
||||
tokens.append(remade_batch_tokens[j][:75])
|
||||
multipliers.append(batch_multipliers[j][:75])
|
||||
else:
|
||||
tokens.append([self.id_end] * 75)
|
||||
multipliers.append([1.0] * 75)
|
||||
|
||||
z1 = self.process_tokens(tokens, multipliers)
|
||||
z = z1 if z is None else torch.cat((z, z1), axis=-2)
|
||||
|
||||
remade_batch_tokens = rem_tokens
|
||||
batch_multipliers = rem_multipliers
|
||||
i += 1
|
||||
|
||||
return z
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
if not opts.use_old_emphasis_implementation:
|
||||
remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens]
|
||||
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
|
||||
|
||||
tokens = torch.asarray(remade_batch_tokens).to(devices.device)
|
||||
|
||||
if self.id_end != self.id_pad:
|
||||
for batch_pos in range(len(remade_batch_tokens)):
|
||||
index = remade_batch_tokens[batch_pos].index(self.id_end)
|
||||
tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
|
||||
|
||||
z = self.encode_with_transformers(tokens)
|
||||
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
|
||||
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
self.tokenizer = wrapped.tokenizer
|
||||
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
|
||||
|
||||
self.token_mults = {}
|
||||
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
|
||||
for text, ident in tokens_with_parens:
|
||||
mult = 1.0
|
||||
for c in text:
|
||||
if c == '[':
|
||||
mult /= 1.1
|
||||
if c == ']':
|
||||
mult *= 1.1
|
||||
if c == '(':
|
||||
mult *= 1.1
|
||||
if c == ')':
|
||||
mult /= 1.1
|
||||
|
||||
if mult != 1.0:
|
||||
self.token_mults[ident] = mult
|
||||
|
||||
self.id_start = self.wrapped.tokenizer.bos_token_id
|
||||
self.id_end = self.wrapped.tokenizer.eos_token_id
|
||||
self.id_pad = self.id_end
|
||||
|
||||
def tokenize(self, texts):
|
||||
tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
|
||||
|
||||
if opts.CLIP_stop_at_last_layers > 1:
|
||||
z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
|
||||
z = self.wrapped.transformer.text_model.final_layer_norm(z)
|
||||
else:
|
||||
z = outputs.last_hidden_state
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
embedding_layer = self.wrapped.transformer.text_model.embeddings
|
||||
ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
|
||||
return embedded
|
@@ -199,8 +199,8 @@ def sample_plms(self,
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
@@ -249,6 +249,8 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
@@ -321,12 +323,16 @@ def should_hijack_inpainting(checkpoint_info):
|
||||
|
||||
|
||||
def do_inpainting_hijack():
|
||||
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
# most of this stuff seems to no longer be needed because it is already included into SD2.0
|
||||
# LatentInpaintDiffusion remains because SD2.0's LatentInpaintDiffusion can't be loaded without specifying a checkpoint
|
||||
# p_sample_plms is needed because PLMS can't work with dicts as conditionings
|
||||
# this file should be cleaned up later if weverything tuens out to work fine
|
||||
|
||||
# ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
|
||||
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
|
||||
|
||||
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
|
||||
# ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
|
||||
|
||||
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
|
||||
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
|
||||
# ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
|
||||
|
37
modules/sd_hijack_open_clip.py
Normal file
37
modules/sd_hijack_open_clip.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import open_clip.tokenizer
|
||||
import torch
|
||||
|
||||
from modules import sd_hijack_clip, devices
|
||||
from modules.shared import opts
|
||||
|
||||
tokenizer = open_clip.tokenizer._tokenizer
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
# set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
|
||||
z = self.wrapped.encode_with_transformer(tokens)
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
|
||||
return embedded
|
@@ -127,7 +127,8 @@ class InterruptedException(BaseException):
|
||||
class VanillaStableDiffusionSampler:
|
||||
def __init__(self, constructor, sd_model):
|
||||
self.sampler = constructor(sd_model)
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
|
||||
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
|
||||
self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
@@ -218,7 +219,6 @@ class VanillaStableDiffusionSampler:
|
||||
self.mask = p.mask if hasattr(p, 'mask') else None
|
||||
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|
||||
|
||||
|
||||
def adjust_steps_if_invalid(self, p, num_steps):
|
||||
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
|
||||
valid_step = 999 / (1000 // num_steps)
|
||||
@@ -227,7 +227,6 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
return num_steps
|
||||
|
||||
|
||||
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
|
||||
steps, t_enc = setup_img2img_steps(p, steps)
|
||||
steps = self.adjust_steps_if_invalid(p, steps)
|
||||
@@ -260,9 +259,10 @@ class VanillaStableDiffusionSampler:
|
||||
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
|
||||
|
||||
# Wrap the conditioning models with additional image conditioning for inpainting model
|
||||
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
|
||||
if image_conditioning is not None:
|
||||
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
|
||||
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
|
||||
conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
|
||||
unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
|
||||
|
||||
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
|
||||
|
||||
@@ -350,7 +350,9 @@ class TorchHijack:
|
||||
|
||||
class KDiffusionSampler:
|
||||
def __init__(self, funcname, sd_model):
|
||||
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
|
||||
|
||||
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
|
||||
self.funcname = funcname
|
||||
self.func = getattr(k_diffusion.sampling, self.funcname)
|
||||
self.extra_params = sampler_extra_params.get(funcname, [])
|
||||
|
@@ -11,17 +11,15 @@ import tqdm
|
||||
import modules.artists
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
import modules.devices as devices
|
||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules import localization, sd_vae, extensions, script_loading
|
||||
from modules.paths import models_path, script_path, sd_path
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",)
|
||||
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
|
||||
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
|
||||
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
|
||||
@@ -121,10 +119,12 @@ xformers_available = False
|
||||
config_filename = cmd_opts.ui_settings_file
|
||||
|
||||
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
hypernetworks = {}
|
||||
loaded_hypernetwork = None
|
||||
|
||||
|
||||
def reload_hypernetworks():
|
||||
from modules.hypernetworks import hypernetwork
|
||||
global hypernetworks
|
||||
|
||||
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
|
||||
@@ -206,10 +206,11 @@ class State:
|
||||
if self.current_latent is None:
|
||||
return
|
||||
|
||||
import modules.sd_samplers
|
||||
if opts.show_progress_grid:
|
||||
self.current_image = sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
|
||||
else:
|
||||
self.current_image = sd_samplers.sample_to_image(self.current_latent)
|
||||
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
|
||||
|
||||
self.current_image_sampling_step = self.sampling_step
|
||||
|
||||
@@ -248,6 +249,21 @@ def options_section(section_identifier, options_dict):
|
||||
return options_dict
|
||||
|
||||
|
||||
def list_checkpoint_tiles():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.checkpoint_tiles()
|
||||
|
||||
|
||||
def refresh_checkpoints():
|
||||
import modules.sd_models
|
||||
return modules.sd_models.list_models()
|
||||
|
||||
|
||||
def list_samplers():
|
||||
import modules.sd_samplers
|
||||
return modules.sd_samplers.all_samplers
|
||||
|
||||
|
||||
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}
|
||||
|
||||
options_templates = {}
|
||||
@@ -333,7 +349,7 @@ options_templates.update(options_section(('training', "Training"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
|
||||
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
|
||||
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
|
||||
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
|
||||
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
|
||||
@@ -385,7 +401,7 @@ options_templates.update(options_section(('ui', "User interface"), {
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
|
||||
"hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}),
|
||||
"eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||
|
@@ -64,7 +64,8 @@ class EmbeddingDatabase:
|
||||
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
|
||||
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
||||
# TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
|
||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||
|
||||
first_id = ids[0]
|
||||
if first_id not in self.ids_lookup:
|
||||
@@ -155,13 +156,11 @@ class EmbeddingDatabase:
|
||||
|
||||
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
|
||||
cond_model = shared.sd_model.cond_stage_model
|
||||
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
||||
|
||||
with devices.autocast():
|
||||
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
|
||||
|
||||
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
||||
embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
|
||||
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
||||
|
||||
for i in range(num_vectors_per_token):
|
||||
|
@@ -478,9 +478,7 @@ def create_toprow(is_img2img):
|
||||
if is_img2img:
|
||||
with gr.Column(scale=1, elem_id="interrogate_col"):
|
||||
button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
with gr.Row():
|
||||
@@ -1004,11 +1002,10 @@ def create_ui(wrap_gradio_gpu_call):
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
img2img_deepbooru.click(
|
||||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
img2img_deepbooru.click(
|
||||
fn=interrogate_deepbooru,
|
||||
inputs=[init_img],
|
||||
outputs=[img2img_prompt],
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user