sd3 TI support

Author: AUTOMATIC1111
Date: 2024-07-07 16:36:53 +03:00
commit 11cfe0dd05, parent 1da4907927

3 changed files with 26 additions and 5 deletions


@@ -5,6 +5,8 @@ import math
 from torch import nn
 from transformers import CLIPTokenizer, T5TokenizerFast
 
+from modules import sd_hijack
+
 #################################################################################################
 ### Core/Utility
@@ -110,9 +112,9 @@ class CLIPEncoder(torch.nn.Module):
 class CLIPEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key="clip_l"):
         super().__init__()
-        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)
         self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
 
     def forward(self, input_tokens):
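
The replacement token_embedding comes from modules/sd_hijack.py (the new import above), presumably the third changed file in this commit; only the call site is visible here, so the constructor signature (vocab_size, embed_dim, dtype=, device=, textual_inversion_key=) is known but the body is not. A minimal sketch of what a textual-inversion-aware embedding table does, with a hypothetical loaded_embeddings registry standing in for the webui's embedding database:

import torch

class TextualInversionEmbeddings(torch.nn.Embedding):
    # Sketch only: the real class lives in modules/sd_hijack.py. The signature
    # matches the call site above; the forward logic is an assumption.
    def __init__(self, vocab_size, embed_dim, dtype=None, device=None, textual_inversion_key="clip_l"):
        super().__init__(vocab_size, embed_dim, dtype=dtype, device=device)
        self.textual_inversion_key = textual_inversion_key
        # Hypothetical registry: token id -> tensor, or token id -> dict of
        # per-encoder tensors for multi-encoder embedding files.
        self.loaded_embeddings = {}

    def forward(self, input_tokens):
        out = super().forward(input_tokens)  # ordinary vocabulary lookup
        for token_id, vectors in self.loaded_embeddings.items():
            # Multi-encoder TI files carry one tensor per text model; pick the
            # one matching this tower's key ("clip_l" or "clip_g").
            vec = vectors[self.textual_inversion_key] if isinstance(vectors, dict) else vectors
            mask = input_tokens == token_id
            if mask.any():
                out[mask] = vec.to(device=out.device, dtype=out.dtype)
        return out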
@@ -127,7 +129,7 @@ class CLIPTextModel_(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)
         self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)
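
The config_dict.get('textual_inversion_key', 'clip_l') fallback means existing CLIP-L configs need no change; only a config that sets the key explicitly, as CLIP-G does in the next file, selects a different tensor. The mechanism is plain dict lookup:

key_l = {}.get('textual_inversion_key', 'clip_l')
print(key_l)   # clip_l  (a config without the key keeps the default)

key_g = {'textual_inversion_key': 'clip_g'}.get('textual_inversion_key', 'clip_l')
print(key_g)   # clip_g  (CLIP-G opts in via its config)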


@@ -40,6 +40,7 @@ CLIPG_CONFIG = {
     "intermediate_size": 5120,
     "num_attention_heads": 20,
     "num_hidden_layers": 32,
+    "textual_inversion_key": "clip_g",
 }
 
 T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
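
With CLIP-G tagged "clip_g" and CLIP-L defaulting to "clip_l", one embedding file can carry a separate tensor per encoder, keyed by those names (the layout commonly used for SDXL embeddings). A hypothetical file using SD3's encoder widths, 768 for CLIP-L and 1280 for CLIP-G:

import torch

# Hypothetical two-encoder textual-inversion embedding: 4 vectors per encoder.
embedding = {
    "clip_l": torch.randn(4, 768),   # consumed by the CLIP-L tower
    "clip_g": torch.randn(4, 1280),  # consumed by the CLIP-G tower
}
torch.save(embedding, "my-style.pt")

# Each tower's embedding layer selects the tensor matching its
# textual_inversion_key, so both encoders apply the same learned concept.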
@@ -204,7 +205,10 @@ class SD3Cond(torch.nn.Module):
             self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
     def encode_embedding_init_text(self, init_text, nvpt):
-        return torch.tensor([[0]], device=devices.device)  # XXX
+        return self.model_lg.encode_embedding_init_text(init_text, nvpt)
+
+    def tokenize(self, texts):
+        return self.model_lg.tokenize(texts)
 
     def medvram_modules(self):
         return [self.clip_g, self.clip_l, self.t5xxl]
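
encode_embedding_init_text is what the textual-inversion "create embedding" path calls to seed a new embedding from init text; the old body returned a [[0]] placeholder (flagged # XXX), so it produced nothing useful for SD3 checkpoints. Both it and the new tokenize now delegate to model_lg, by its name the combined CLIP-L/CLIP-G wrapper. A hedged usage sketch; the access path mirrors how the webui exposes the conditioner, and the literals are illustrative:

from modules import shared

cond_model = shared.sd_model.cond_stage_model              # SD3Cond when an SD3 checkpoint is loaded
vectors = cond_model.encode_embedding_init_text("cat", 4)  # seed vectors from CLIP-L/G instead of [[0]]
tokens = cond_model.tokenize(["a photo of a cat"])         # token ids via the same CLIP-L/G wrapper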