SD3 lora support

This commit is contained in:
AUTOMATIC1111
2024-07-15 08:31:55 +03:00
parent b2453d280a
commit 7e5cdaab4b
6 changed files with 106 additions and 24 deletions

View File

@@ -1,6 +1,7 @@
import torch
import lyco_helpers
import modules.models.sd3.mmdit
import network
from modules import devices
@@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
return NetworkModuleLora(net, weights)
if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
w = weights.w.copy()
weights.w.clear()
weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
return NetworkModuleLora(net, weights)
return None
@@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
if weight is None and none_ok:
return None
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear: