support specifying te and unet weights separately

update lora code
support full module
This commit is contained in:
AUTOMATIC1111
2023-07-17 09:00:47 +03:00
parent 46466f09d0
commit 238adeaffb
10 changed files with 151 additions and 78 deletions

View File

@@ -1,5 +1,6 @@
import torch
import lyco_helpers
import network
from modules import devices
@@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.up = self.create_module(weights.w["lora_up.weight"])
self.down = self.create_module(weights.w["lora_down.weight"])
self.alpha = weights.w["alpha"] if "alpha" in weights.w else None
self.up_model = self.create_module(weights.w, "lora_up.weight")
self.down_model = self.create_module(weights.w, "lora_down.weight")
self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)
self.dim = weights.w["lora_down.weight"].shape[0]
def create_module(self, weights, key, none_ok=False):
weight = weights.get(key)
def create_module(self, weight, none_ok=False):
if weight is None and none_ok:
return None
if type(self.sd_module) == torch.nn.Linear:
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
elif is_conv and key == "lora_down.weight" or key == "dyn_up":
if len(weight.shape) == 2:
weight = weight.reshape(weight.shape[0], -1, 1, 1)
if weight.shape[2] != 1 or weight.shape[3] != 1:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
else:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif is_conv and key == "lora_mid.weight":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
elif is_conv and key == "lora_up.weight" or key == "dyn_down":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
return None
raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
with torch.no_grad():
if weight.shape != module.weight.shape:
weight = weight.reshape(module.weight.shape)
module.weight.copy_(weight)
module.to(device=devices.cpu, dtype=devices.dtype)
@@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule):
return module
def calc_updown(self, target):
up = self.up.weight.to(target.device, dtype=target.dtype)
down = self.down.weight.to(target.device, dtype=target.dtype)
def calc_updown(self, orig_weight):
up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
updown = up @ down
if len(down.shape) == 4:
output_shape += down.shape[2:]
updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)
updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
return updown
return self.finalize_updown(updown, orig_weight, output_shape)
def forward(self, x, y):
self.up.to(device=devices.device)
self.down.to(device=devices.device)
self.up_model.to(device=devices.device)
self.down_model.to(device=devices.device)
return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()