AUTOMATIC1111
2024-06-16 08:13:23 +03:00
parent 5b2a60b8e2
commit 79de09c3df
2 changed files with 16 additions and 13 deletions


@@ -1,6 +1,7 @@
 ### This file contains impls for underlying related models (CLIP, T5, etc)
-import torch, math
+import torch
+import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
@@ -14,7 +15,7 @@ def attention(q, k, v, heads, mask=None):
     """Convenience wrapper around a basic attention operation"""
     b, _, dim_head = q.shape
     dim_head //= heads
-    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
+    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]
     out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
     return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
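
Note on the hunk above: the list comprehension is behaviorally identical to the map/lambda it replaces; both split the channel dimension into attention heads before scaled_dot_product_attention. A minimal standalone sketch of that reshape, using made-up toy shapes and only torch (not code from this commit):

import torch

b, seq, heads, dim_head = 2, 4, 8, 16
q = torch.randn(b, seq, heads * dim_head)                     # (batch, tokens, inner_dim)

# Same reshape as in attention(): split channels into heads, move heads before tokens.
q_heads = q.view(b, -1, heads, dim_head).transpose(1, 2)      # (batch, heads, tokens, dim_head)
print(q_heads.shape)                                          # torch.Size([2, 8, 4, 16])

# The inverse, as applied to the attention output before returning.
q_back = q_heads.transpose(1, 2).reshape(b, -1, heads * dim_head)
print(torch.equal(q, q_back))                                 # True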
@@ -89,8 +90,8 @@ class CLIPEncoder(torch.nn.Module):
             if intermediate_output < 0:
                 intermediate_output = len(self.layers) + intermediate_output
         intermediate = None
-        for i, l in enumerate(self.layers):
-            x = l(x, mask)
+        for i, layer in enumerate(self.layers):
+            x = layer(x, mask)
             if i == intermediate_output:
                 intermediate = x.clone()
         return x, intermediate
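
Aside on the loop renamed above: intermediate_output selects one hidden state to return alongside the final one, and a negative value counts back from the last layer (the usual "clip skip" convention). A rough, self-contained sketch of that control flow, with stand-in Linear layers rather than the repo's CLIP layers:

import torch

layers = torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(12))

def forward(x, intermediate_output=None):
    if intermediate_output is not None and intermediate_output < 0:
        intermediate_output = len(layers) + intermediate_output   # e.g. -2 -> 10
    intermediate = None
    for i, layer in enumerate(layers):
        x = layer(x)
        if i == intermediate_output:
            intermediate = x.clone()   # snapshot before the remaining layers run
    return x, intermediate

final, penultimate = forward(torch.randn(1, 8), intermediate_output=-2)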
@@ -215,7 +216,7 @@ class SD3Tokenizer:
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        tokens = list(map(lambda a: a[0], token_weight_pairs[0]))
+        tokens = [a[0] for a in token_weight_pairs[0]]
         out, pooled = self([tokens])
         if pooled is not None:
             first_pooled = pooled[0:1].cpu()
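
For context on the comprehension above: token_weight_pairs appears to hold batches of (token, weight) tuples, and only the token ids are kept here. A tiny illustration with invented values (the CLIP start/end ids 49406/49407 are real, the prompt tokens are made up):

token_weight_pairs = [[(49406, 1.0), (320, 1.0), (1125, 1.2), (49407, 1.0)]]
tokens = [a[0] for a in token_weight_pairs[0]]
print(tokens)   # [49406, 320, 1125, 49407]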
@@ -229,7 +230,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
     LAYERS = ["last", "pooled", "hidden"]
     def __init__(self, device="cpu", max_length=77, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, return_projected_pooled=True):
+                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):
         super().__init__()
         assert layer in self.LAYERS
         self.transformer = model_class(textmodel_json_config, dtype, device)
@@ -240,7 +241,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             param.requires_grad = False
         self.layer = layer
         self.layer_idx = None
-        self.special_tokens = special_tokens
+        self.special_tokens = special_tokens if special_tokens is not None else {"start": 49406, "end": 49407, "pad": 49407}
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
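
The two hunks above are the standard fix for a mutable default argument: Python evaluates default values once, at function definition time, so a dict default is shared by every call, while a None sentinel rebuilds the value per call. A small standalone demonstration with generic names (not code from this file):

# Risky: the default dict is created once and shared across calls.
def risky(tokens={}):
    tokens["n"] = tokens.get("n", 0) + 1
    return tokens

print(risky())   # {'n': 1}
print(risky())   # {'n': 2}  <- state leaked from the previous call

# The sentinel pattern used in the diff: default to None, build the value inside.
def safe(tokens=None):
    tokens = tokens if tokens is not None else {}
    tokens["n"] = tokens.get("n", 0) + 1
    return tokens

print(safe())    # {'n': 1}
print(safe())    # {'n': 1}

Here the special_tokens dict is only read, so the rewrite is defensive rather than a bug fix, but it removes the shared-state hazard.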
@@ -465,8 +466,8 @@ class T5Stack(torch.nn.Module):
         intermediate = None
         x = self.embed_tokens(input_ids)
         past_bias = None
-        for i, l in enumerate(self.block):
-            x, past_bias = l(x, past_bias)
+        for i, layer in enumerate(self.block):
+            x, past_bias = layer(x, past_bias)
             if i == intermediate_output:
                 intermediate = x.clone()
         x = self.final_layer_norm(x)