option to keep multiple models in memory

AUTOMATIC1111
2023-08-01 00:24:48 +03:00
parent 6f0abbb71a
commit b235022c61
6 changed files with 135 additions and 35 deletions


@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()
@@ -437,6 +437,7 @@ class SdModelData:
        try:
            load_model()
        except Exception as e:
            errors.display(e, "loading stable diffusion model", full_traceback=True)
            print("", file=sys.stderr)
@@ -448,11 +449,24 @@ class SdModelData:
    def set_sd_model(self, v):
        self.sd_model = v

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()


def get_empty_cond(sd_model):
    from modules import extra_networks, processing

    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
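
The new loaded_sd_models list is kept in most-recently-used order: set_sd_model puts the active model at index 0 and removes any earlier occurrence rather than duplicating it. A standalone sketch of that ordering, using a stand-in class instead of a real checkpoint:

# Standalone sketch of the MRU ordering; FakeModel and Data are stand-ins, not webui code.
class FakeModel:
    def __init__(self, name):
        self.name = name


class Data:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []

    def set_sd_model(self, v):
        self.sd_model = v
        try:
            self.loaded_sd_models.remove(v)  # drop any older occurrence
        except ValueError:
            pass
        if v is not None:
            self.loaded_sd_models.insert(0, v)  # most recent goes first


a, b = FakeModel("a"), FakeModel("b")
data = Data()
data.set_sd_model(a)
data.set_sd_model(b)
data.set_sd_model(a)  # re-selecting a moves it to the front, no duplicate
assert [m.name for m in data.loaded_sd_models] == ["a", "b"]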
@@ -460,19 +474,43 @@ def get_empty_cond(sd_model):
        return sd_model.cond_stage_model([""])


def send_model_to_cpu(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)

    devices.torch_gc()


def send_model_to_device(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
    else:
        m.to(shared.device)


def send_model_to_trash(m):
    m.to(device="meta")
    devices.torch_gc()


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()
    timer = Timer()

    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
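
send_model_to_trash frees a checkpoint by moving it to PyTorch's meta device: the parameters are replaced by storage-less placeholder tensors, so the real CPU/GPU memory can be reclaimed. A minimal sketch with a small stand-in module (a real SD checkpoint holds gigabytes):

import torch

m = torch.nn.Linear(1024, 1024)  # stand-in for a loaded checkpoint
assert m.weight.device.type == "cpu"

m.to(device="meta")  # parameters become storage-less meta tensors
assert m.weight.device.type == "meta"

if torch.cuda.is_available():
    torch.cuda.empty_cache()  # roughly what devices.torch_gc() does for CUDA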
@@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)
    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)
@@ -525,7 +560,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("hijack")
sd_model.eval()
model_data.sd_model = sd_model
model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    return sd_model
def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    already_loaded = None
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            model_data.loaded_sd_models.pop()
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

    if shared.opts.sd_checkpoints_keep_in_cpu:
        send_model_to_cpu(sd_model)
        timer.record("send model to cpu")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded)
        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None
def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    from modules import devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model
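
The decision tree in reuse_model_from_already_loaded above has three outcomes: reuse a cached model, load into a free slot, or recycle the oldest entry. A standalone simulation with strings standing in for models and a hypothetical limit of 2 (names are illustrative only):

loaded = ["sdxl.safetensors", "v1-5.safetensors"]  # most recent first
limit = 2

def decide(requested):
    if requested in loaded:
        return f"reuse {requested}, already in memory"
    if len(loaded) < limit:
        return f"load {requested} into a new slot ({len(loaded) + 1} of {limit})"
    return f"recycle the oldest entry ({loaded[-1]}) to hold {requested}'s weights"

print(decide("v1-5.safetensors"))  # reuse v1-5.safetensors, already in memory
print(decide("other.safetensors")) # recycle the oldest entry (v1-5.safetensors) ...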
@@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None):
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return
            return sd_model

    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:
            sd_model.to(devices.cpu)
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None):
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
del sd_model
if sd_model is not None:
send_model_to_trash(sd_model)
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
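
Replacing `del sd_model` with send_model_to_trash matters because `del` only drops one local reference; any other holder of the object keeps the tensors alive, while moving the parameters to the meta device releases their storage regardless. A hedged illustration with a stand-in module:

import torch

m = torch.nn.Linear(8, 8)
alias = m  # a second reference, like one a cache or hijack might hold

del m      # weights stay alive: alias still points at them
assert alias.weight.device.type == "cpu"

alias.to(device="meta")  # storage released even though a reference remains
assert alias.weight.device.type == "meta"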
@@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
return sd_model
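
With the new sd_checkpoints_limit option raised above 1 (and optionally sd_checkpoints_keep_in_cpu, so the inactive checkpoint waits in RAM instead of VRAM), switching back to a recently used checkpoint should skip the full reload. A hedged usage sketch inside the webui's Python environment; the checkpoint names are placeholders and assumed to resolve to installed models:

from modules import sd_models, shared

shared.opts.sd_checkpoints_limit = 2           # keep up to two checkpoints loaded
shared.opts.sd_checkpoints_keep_in_cpu = True  # inactive checkpoint parks in RAM

info_a = sd_models.get_closet_checkpoint_match("modelA")  # placeholder names
info_b = sd_models.get_closet_checkpoint_match("modelB")

sd_models.reload_model_weights(info=info_a)  # first slot: full load
sd_models.reload_model_weights(info=info_b)  # second slot: full load
sd_models.reload_model_weights(info=info_a)  # served from loaded_sd_models, no reload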