mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
textual inversion support for SDXL
This commit is contained in:
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
|
||||
|
||||
|
||||
class EmbeddingsWithFixes(torch.nn.Module):
|
||||
def __init__(self, wrapped, embeddings):
|
||||
def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
|
||||
super().__init__()
|
||||
self.wrapped = wrapped
|
||||
self.embeddings = embeddings
|
||||
self.textual_inversion_key = textual_inversion_key
|
||||
|
||||
def forward(self, input_ids):
|
||||
batch_fixes = self.embeddings.fixes
|
||||
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
vecs = []
|
||||
for fixes, tensor in zip(batch_fixes, inputs_embeds):
|
||||
for offset, embedding in fixes:
|
||||
emb = devices.cond_cast_unet(embedding.vec)
|
||||
vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
|
||||
emb = devices.cond_cast_unet(vec)
|
||||
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
|
||||
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
|
||||
|
||||
|
Reference in New Issue
Block a user