Merge branch 'dev' into multiple_loaded_models

This commit is contained in:
AUTOMATIC1111
2023-08-05 07:52:29 +03:00
45 changed files with 919 additions and 471 deletions

View File

@@ -66,8 +66,9 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -86,6 +87,7 @@ class CheckpointInfo:
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -106,14 +108,8 @@ def setup_model():
enable_midas_autodownload()
def checkpoint_tiles():
def convert(name):
return int(name) if name.isdigit() else name.lower()
def alphanumeric_key(key):
return [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def checkpoint_tiles(use_short=False):
return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -136,11 +132,14 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in sorted(model_list, key=str.lower):
for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -150,6 +149,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
if found:
return found[0]
return None
@@ -302,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False)
del state_dict
timer.record("apply weights to model")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
checkpoints_loaded[checkpoint_info] = state_dict
del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)