wip incorrect OFT implementation

This commit is contained in:
v0xie
2023-10-17 23:35:50 -07:00
parent 861cbd5636
commit ec718f76b5
2 changed files with 87 additions and 0 deletions

View File

@@ -11,6 +11,7 @@ import network_ia3
import network_lokr
import network_full
import network_norm
import network_oft
import torch
from typing import Union
@@ -28,6 +29,7 @@ module_types = [
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
network_glora.ModuleTypeGLora(),
network_oft.ModuleTypeOFT(),
]
@@ -183,6 +185,9 @@ def load_network(name, network_on_disk):
elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
elif sd_module is None and "oft_unet" in key_network_without_network_parts:
key = key_network_without_network_parts.replace("oft_unet", "diffusion_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
# some SD1 Loras also have correct compvis keys
if sd_module is None: