This commit is contained in:
MalumaDev
2022-10-14 10:56:41 +02:00
parent fdecb63685
commit bb57f30c2d
9 changed files with 172 additions and 38 deletions

View File

@@ -9,11 +9,14 @@ from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
from modules.shared import opts, device, cmd_opts, aesthetic_embeddings
from modules.sd_hijack_optimizations import invokeAI_mps_available
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
from transformers import CLIPVisionModel, CLIPModel
import torch.optim as optim
import copy
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -109,13 +112,29 @@ class StableDiffusionModelHijack:
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
def slerp(low, high, val):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm*high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.clipModel = CLIPModel.from_pretrained(
self.wrapped.transformer.name_or_path
)
del self.clipModel.vision_model
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
# self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
self.image_embs_name = None
self.image_embs = None
self.load_image_embs(None)
self.token_mults = {}
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
@@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None,
aesthetic_slerp=True):
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
def load_image_embs(self, image_embs_name):
if image_embs_name is None or len(image_embs_name) == 0:
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id
@@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
if len(text[
0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
with torch.enable_grad():
model = copy.deepcopy(self.clipModel).to(device)
model.requires_grad_(True)
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
for i in range(self.aesthetic_steps):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ self.image_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
zn = model.text_model.final_layer_norm(zn)
else:
zn = zn.last_hidden_state
model.cpu()
del model
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1