remove duplicated code

This commit is contained in:
AUTOMATIC1111
2023-10-14 12:14:56 +03:00
parent 891ccb767c
commit a8cbe50c9f
2 changed files with 42 additions and 63 deletions

View File

@@ -15,7 +15,7 @@ import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules.textual_inversion.textual_inversion import Embedding
import modules.textual_inversion.textual_inversion as textual_inversion
from lora_logger import logger
@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):
embeddings = {}
for emb_name, data in bundle_embeddings.items():
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
embedding = Embedding(vec, emb_name)
embedding.vectors = vectors
embedding.shape = shape
embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
embedding.loaded = None
embeddings[emb_name] = embedding