Merge branch 'dev' into multiple_loaded_models

This commit is contained in:
AUTOMATIC1111
2023-08-01 16:55:55 +03:00
15 changed files with 254 additions and 149 deletions

View File

@@ -14,7 +14,8 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -32,6 +33,8 @@ class CheckpointInfo:
self.filename = filename
abspath = os.path.abspath(filename)
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
@@ -42,6 +45,19 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
def read_metadata():
metadata = read_metadata_from_safetensors(filename)
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
return metadata
self.metadata = {}
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading metadata for {filename}")
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -54,15 +70,6 @@ class CheckpointInfo:
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
self.metadata = {}
_, ext = os.path.splitext(self.filename)
if ext.lower() == ".safetensors":
try:
self.metadata = read_metadata_from_safetensors(filename)
except Exception as e:
errors.display(e, f"reading checkpoint metadata: {filename}")
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
@@ -78,7 +85,7 @@ class CheckpointInfo:
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
checkpoints_list.pop(self.title)
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.register()