Mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git, synced 2025-08-04 11:12:35 +00:00
store patches for Lora in a specialized module
@@ -7,17 +7,14 @@ from fastapi import FastAPI
 import network
 import networks
 import lora  # noqa:F401
+import lora_patches
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
 
 
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+    networks.originals.undo()
 
 
 def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
     extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
-    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
-    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
-    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
-    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
-    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
-    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
-    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
-    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
-    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
-    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
 
 script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
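
The ten hasattr() guards and twenty direct assignments removed above all did the same thing: stash the original torch.nn method, then overwrite it. Moving that bookkeeping into a LoraPatches object lets unload() become a single undo() call. The repository's actual lora_patches.py is not part of this diff; the following is a minimal self-contained sketch of the pattern, where patch(), undo(), the _originals registry, and lora_Linear_forward are illustrative names, not the webui's verified API:

import torch

# Registry of original attributes, keyed by (owner, attribute name).
# patch() swaps in a replacement and remembers the original; undo()
# restores it. This mirrors what a modules.patches-style helper could
# look like; names and signatures here are assumptions.
_originals = {}


def patch(obj, field, replacement):
    key = (obj, field)
    _originals[key] = getattr(obj, field)   # remember the original
    setattr(obj, field, replacement)
    return _originals[key]


def undo(obj, field):
    setattr(obj, field, _originals.pop((obj, field)))  # put the original back


def lora_Linear_forward(self, x):
    # Stand-in for networks.network_Linear_forward: a real implementation
    # would apply LoRA weight deltas before delegating to the original.
    return _originals[(torch.nn.Linear, 'forward')](self, x)


class LoraPatches:
    # One object owns every Lora patch, so applying and undoing them
    # cannot drift apart the way separate module-level statements can.
    def __init__(self):
        self.Linear_forward = patch(torch.nn.Linear, 'forward', lora_Linear_forward)
        # ...repeat for Conv2d, GroupNorm, LayerNorm, MultiheadAttention,
        # each for both .forward and ._load_from_state_dict

    def undo(self):
        undo(torch.nn.Linear, 'forward')


originals = LoraPatches()                     # apply all patches at import time
y = torch.nn.Linear(4, 4)(torch.randn(2, 4))  # goes through the patched forward
originals.undo()                              # what unload() now delegates to

The design benefit is visible in the diff itself: the old unload() restored Linear, Conv2d, and MultiheadAttention but never GroupNorm or LayerNorm, even though both were patched at module level. With every patch recorded in one object, undo() cannot fall out of sync with the patch list that way.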