Fix bugs for torch.nn.MultiheadAttention

Kohaku-Blueleaf
2024-03-09 12:31:32 +08:00
parent 12bcacf413
commit 851c3d51ed
2 changed files with 13 additions and 4 deletions


@@ -117,6 +117,12 @@ class NetworkModule:
         if hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
+        elif isinstance(self.sd_module, nn.MultiheadAttention):
+            # For now, only self-attn use Pytorch's MHA
+            # So assume all qkvo proj have same shape
+            self.shape = self.sd_module.out_proj.weight.shape
+        else:
+            self.shape = None
 
         self.ops = None
         self.extra_kwargs = {}
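
Context for the first hunk: nn.MultiheadAttention exposes no plain `.weight` attribute, so the hasattr() branch falls through for it; the q/k/v projections are packed into in_proj_weight, and out_proj.weight is a square matrix whose shape matches each individual projection for self-attention. A minimal sketch of that behaviour (embed_dim and num_heads are illustrative values, not taken from the webui):

import torch.nn as nn

# Illustrative sizes only; not taken from the webui.
mha = nn.MultiheadAttention(embed_dim=320, num_heads=8)

# No plain `.weight` attribute, so the hasattr() check in the hunk above is False.
print(hasattr(mha, 'weight'))        # False

# The q/k/v projections are packed into one (3*embed_dim, embed_dim) tensor,
# while the output projection is an ordinary square Linear weight, so for
# self-attention each per-projection shape equals out_proj.weight.shape.
print(mha.in_proj_weight.shape)      # torch.Size([960, 320])
print(mha.out_proj.weight.shape)     # torch.Size([320, 320])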
@@ -146,7 +152,7 @@ class NetworkModule:
         self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
         self.scale = weights.w["scale"].item() if "scale" in weights.w else None
 
-        self.dora_scale = weights.w["dora_scale"] if "dora_scale" in weights.w else None
+        self.dora_scale = weights.w.get("dora_scale", None)
         self.dora_mean_dim = tuple(i for i in range(len(self.shape)) if i != 1)
 
     def multiplier(self):
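
The second hunk tidies the dora_scale lookup and sits next to the line the shape fix protects: dora_mean_dim calls len(self.shape), which would raise a TypeError whenever an MHA module previously left self.shape as None. A small sketch, treating weights.w as a plain dict of tensors (names and values here are illustrative, not the webui API):

import torch

# Stand-in for weights.w; in the webui it maps weight names to tensors.
w = {"alpha": torch.tensor(4.0)}

# The two lookups below are equivalent; .get() avoids the double lookup.
dora_scale = w["dora_scale"] if "dora_scale" in w else None
dora_scale = w.get("dora_scale", None)

# Mean over every dim except dim 1, as in dora_mean_dim above.
shape = (320, 320)                   # e.g. out_proj.weight.shape of an MHA module
dora_mean_dim = tuple(i for i in range(len(shape)) if i != 1)
print(dora_mean_dim)                 # (0,)

# Had shape stayed None (the pre-fix situation for MHA modules), len(shape)
# would raise: TypeError: object of type 'NoneType' has no len()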