Merge branch 'ae'

This commit is contained in:
AUTOMATIC
2022-10-21 13:34:48 +03:00
9 changed files with 326 additions and 21 deletions

View File

@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
undo_optimizations()
@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
@@ -333,19 +333,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
z = shared.aesthetic_clip(z, remade_batch_tokens)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -385,8 +385,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)