Add extra norm module into built-in lora ext

refer to LyCORIS 1.9.0.dev6
add new option and module for training norm layer
(Which is reported to be good for style)
This commit is contained in:
Kohaku-Blueleaf
2023-08-13 02:27:39 +08:00
parent b2080756fc
commit bd4da4474b
4 changed files with 105 additions and 11 deletions

View File

@@ -7,6 +7,7 @@ import network_hada
import network_ia3
import network_lokr
import network_full
import network_norm
import torch
from typing import Union
@@ -19,6 +20,7 @@ module_types = [
network_ia3.ModuleTypeIa3(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),
]
@@ -31,6 +33,8 @@ suffix_conversion = {
"resnets": {
"conv1": "in_layers_2",
"conv2": "out_layers_3",
"norm1": "in_layers_0",
"norm2": "out_layers_0",
"time_emb_proj": "emb_layers_1",
"conv_shortcut": "skip_connection",
}
@@ -258,20 +262,25 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
bias_backup = getattr(self, "network_bias_backup", None)
if weights_backup is None:
if weights_backup is None and bias_backup is None:
return
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
if bias_backup is not None:
self.bias.copy_(bias_backup)
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of networks to the weights of torch layer self.
If weights already have this particular set of networks applied, does nothing.
@@ -294,6 +303,11 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_weights_backup = weights_backup
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and getattr(self, 'bias', None) is not None:
bias_backup = self.bias.to(devices.cpu, copy=True)
self.network_bias_backup = bias_backup
if current_names != wanted_names:
network_restore_weights_from_backup(self)
@@ -301,13 +315,15 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
module = net.modules.get(network_layer_name, None)
if module is not None and hasattr(self, 'weight'):
with torch.no_grad():
updown = module.calc_updown(self.weight)
updown, ex_bias = module.calc_updown(self.weight)
if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
self.weight += updown
if getattr(self, 'bias', None) is not None:
self.bias += ex_bias
continue
module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -397,6 +413,36 @@ def network_Conv2d_load_state_dict(self, *args, **kwargs):
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
network_apply_weights(self)
return torch.nn.GroupNorm_forward_before_network(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
network_apply_weights(self)
return torch.nn.LayerNorm_forward_before_network(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)