store patches for Lora in a specialized module

This commit is contained in:
AUTOMATIC1111
2023-08-15 19:23:27 +03:00
parent 7327be97aa
commit f01682ee01
4 changed files with 118 additions and 61 deletions

View File

@@ -0,0 +1,31 @@
import torch
import networks
from modules import patches
class LoraPatches:
def __init__(self):
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
def undo(self):
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')

View File

@@ -2,6 +2,7 @@ import logging
import os
import re
import lora_patches
import network
import network_lora
import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
def network_Linear_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.Linear_forward_before_network)
return network_forward(self, input, originals.Linear_forward)
network_apply_weights(self)
return torch.nn.Linear_forward_before_network(self, input)
return originals.Linear_forward(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
return originals.Linear_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
return network_forward(self, input, originals.Conv2d_forward)
network_apply_weights(self)
return torch.nn.Conv2d_forward_before_network(self, input)
return originals.Conv2d_forward(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
return originals.Conv2d_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
return network_forward(self, input, originals.GroupNorm_forward)
network_apply_weights(self)
return torch.nn.GroupNorm_forward_before_network(self, input)
return originals.GroupNorm_forward(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
return network_forward(self, input, originals.LayerNorm_forward)
network_apply_weights(self)
return torch.nn.LayerNorm_forward_before_network(self, input)
return originals.LayerNorm_forward(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
return originals.MultiheadAttention_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
if added:
params["Prompt"] += "\n" + "".join(added)
originals: lora_patches.LoraPatches = None
extra_network_lora = None
available_networks = {}

View File

@@ -7,17 +7,14 @@ from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
networks.originals.undo()
def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
if not hasattr(torch.nn, 'Linear_forward_before_network'):
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)