lora extension rework to include other types of networks

This commit is contained in:
AUTOMATIC1111
2023-07-16 23:13:55 +03:00
parent 7d26c479ee
commit b75b004fe6
10 changed files with 777 additions and 589 deletions

View File

@@ -0,0 +1,15 @@
import torch
def make_weight_cp(t, wa, wb):
temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
return torch.einsum('i j k l, i r -> r j k l', temp, wa)
def rebuild_conventional(up, down, shape, dyn_dim=None):
up = up.reshape(up.size(0), -1)
down = down.reshape(down.size(0), -1)
if dyn_dim is not None:
up = up[:, :dyn_dim]
down = down[:dyn_dim, :]
return (up @ down).reshape(shape)