Add general forward method for all modules.

Kohaku-Blueleaf
2024-01-05 16:32:19 +08:00
parent a06dab8d7a
commit 18ca987c92
2 changed files with 39 additions and 7 deletions


@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     self.network_current_names = wanted_names


-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """

     if len(loaded_networks) == 0:
-        return original_forward(module, input)
+        return original_forward(org_module, input)

     input = devices.cond_cast_unet(input)

-    network_restore_weights_from_backup(module)
-    network_reset_cached_weight(module)
+    network_restore_weights_from_backup(org_module)
+    network_reset_cached_weight(org_module)

-    y = original_forward(module, input)
+    y = original_forward(org_module, input)

-    network_layer_name = getattr(module, 'network_layer_name', None)
+    network_layer_name = getattr(org_module, 'network_layer_name', None)
     for lora in loaded_networks:
         module = lora.modules.get(network_layer_name, None)
         if module is None:
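
The rename matters because the loop body rebinds the name `module` to the per-layer LoRA module (`module = lora.modules.get(network_layer_name, None)`), which previously shadowed the wrapped layer passed in as the first parameter; with `org_module`, the original layer stays addressable throughout the forward. A minimal sketch of the shadowing hazard, using hypothetical stand-ins (FakeLoRA, "layer0") rather than the real webui objects:

class FakeLoRA:
    """Hypothetical container mapping layer names to LoRA modules."""
    def __init__(self, modules):
        self.modules = modules


def forward_shadowed(module, loras):
    # `module` starts out as the original layer...
    original = module
    for lora in loras:
        module = lora.modules.get("layer0", None)  # rebinding shadows the parameter
    # ...but after the loop it points at a LoRA module (or None) instead.
    return module is original


def forward_renamed(org_module, loras):
    # With the parameter renamed, rebinding `module` cannot clobber the layer.
    for lora in loras:
        module = lora.modules.get("layer0", None)
        _ = module  # the LoRA module would be applied here
    return org_module  # still the original layer


if __name__ == "__main__":
    layer = object()
    loras = [FakeLoRA({"layer0": object()})]
    print(forward_shadowed(layer, loras))           # False: parameter was shadowed
    print(forward_renamed(layer, loras) is layer)   # True: original layer preserved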