SD3 lora support

This commit is contained in:
AUTOMATIC1111
2024-07-15 08:31:55 +03:00
parent b2453d280a
commit 7e5cdaab4b
6 changed files with 106 additions and 24 deletions

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from modules import sd_models, cache, errors, hashes, shared
import modules.models.sd3.mmdit
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -114,7 +115,10 @@ class NetworkModule:
self.sd_key = weights.sd_key
self.sd_module = weights.sd_module
if hasattr(self.sd_module, 'weight'):
if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
s = self.sd_module.weight.shape
self.shape = (s[0] // 3, s[1])
elif hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
elif isinstance(self.sd_module, nn.MultiheadAttention):
# For now, only self-attn use Pytorch's MHA