change hash to sha256

This commit is contained in:
AUTOMATIC
2023-01-14 09:56:59 +03:00
parent 82725f0ac4
commit a95f135308
9 changed files with 159 additions and 51 deletions

View File

@@ -14,17 +14,56 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoint_alisases = {}
checkpoints_loaded = collections.OrderedDict()
class CheckpointInfo:
def __init__(self, filename):
self.filename = filename
abspath = os.path.abspath(filename)
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(filename)
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
self.title = name
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]']
self.shorthash = None
self.sha256 = None
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, self.title)
self.shorthash = self.sha256[0:10]
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256]
self.register()
return self.shorthash
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -43,10 +82,14 @@ def setup_model():
enable_midas_autodownload()
def checkpoint_tiles():
convert = lambda name: int(name) if name.isdigit() else name.lower()
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def checkpoint_tiles():
def convert(name):
return int(name) if name.isdigit() else name.lower()
def alphanumeric_key(key):
return [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
def find_checkpoint_config(info):
@@ -62,48 +105,38 @@ def find_checkpoint_config(info):
def list_models():
checkpoints_list.clear()
checkpoint_alisases.clear()
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)
if name.startswith("\\") or name.startswith("/"):
name = name[1:]
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.data['sd_model_checkpoint'] = title
checkpoint_info = CheckpointInfo(cmd_ckpt)
checkpoint_info.register()
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
def get_closet_checkpoint_match(searchString):
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
if len(applicable) > 0:
return applicable[0]
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_alisases.get(search_string, None)
if checkpoint_info is not None:
return
found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
if found:
return found[0]
return None
def model_hash(filename):
"""old hash that only looks at a small part of the file and is prone to collisions"""
try:
with open(filename, "rb") as file:
import hashlib
@@ -119,7 +152,7 @@ def model_hash(filename):
def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoints_list.get(model_checkpoint, None)
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None:
return checkpoint_info
@@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
return sd
def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.load_state_dict(checkpoints_loaded[checkpoint_info])
else:
# load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
sd = read_state_dict(checkpoint_file)
sd = read_state_dict(checkpoint_info.filename)
model.load_state_dict(sd, strict=False)
del sd
@@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
checkpoints_loaded.popitem(last=False) # LRU
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_model_checkpoint = checkpoint_info.filename
model.sd_checkpoint_info = checkpoint_info
model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
sd_vae.load_vae(model, vae_file)