Lora cache in memory

This commit is contained in:
AUTOMATIC1111
2023-08-09 16:54:49 +03:00
parent 7ba8f11688
commit eed963e972
2 changed files with 22 additions and 3 deletions

View File

@@ -195,6 +195,15 @@ def load_network(name, network_on_disk):
return net
def purge_networks_from_memory():
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
name = next(iter(networks_in_memory))
networks_in_memory.pop(name, None)
devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -212,15 +221,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
failed_to_load_networks = []
for i, name in enumerate(names):
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
net = already_loaded.get(name, None)
network_on_disk = networks_on_disk[i]
if network_on_disk is not None:
if net is None:
net = networks_in_memory.get(name)
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
try:
net = load_network(name, network_on_disk)
networks_in_memory.pop(name, None)
networks_in_memory[name] = net
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
@@ -242,6 +255,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
@@ -462,6 +477,7 @@ def infotext_pasted(infotext, params):
available_networks = {}
available_network_aliases = {}
loaded_networks = []
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}