mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Depth2img model support
This commit is contained in:
@@ -7,6 +7,9 @@ import torch
|
||||
import re
|
||||
import safetensors.torch
|
||||
from omegaconf import OmegaConf
|
||||
from os import mkdir
|
||||
from urllib import request
|
||||
import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
@@ -36,6 +39,7 @@ def setup_model():
|
||||
os.makedirs(model_path)
|
||||
|
||||
list_models()
|
||||
enable_midas_autodownload()
|
||||
|
||||
|
||||
def checkpoint_tiles():
|
||||
@@ -227,6 +231,48 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
|
||||
sd_vae.load_vae(model, vae_file)
|
||||
|
||||
|
||||
def enable_midas_autodownload():
|
||||
"""
|
||||
Gives the ldm.modules.midas.api.load_model function automatic downloading.
|
||||
|
||||
When the 512-depth-ema model, and other future models like it, is loaded,
|
||||
it calls midas.api.load_model to load the associated midas depth model.
|
||||
This function applies a wrapper to download the model to the correct
|
||||
location automatically.
|
||||
"""
|
||||
|
||||
midas_path = os.path.join(models_path, 'midas')
|
||||
|
||||
# stable-diffusion-stability-ai hard-codes the midas model path to
|
||||
# a location that differs from where other scripts using this model look.
|
||||
# HACK: Overriding the path here.
|
||||
for k, v in midas.api.ISL_PATHS.items():
|
||||
file_name = os.path.basename(v)
|
||||
midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
|
||||
|
||||
midas_urls = {
|
||||
"dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
|
||||
"dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
|
||||
"midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
|
||||
"midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
|
||||
}
|
||||
|
||||
midas.api.load_model_inner = midas.api.load_model
|
||||
|
||||
def load_model_wrapper(model_type):
|
||||
path = midas.api.ISL_PATHS[model_type]
|
||||
if not os.path.exists(path):
|
||||
if not os.path.exists(midas_path):
|
||||
mkdir(midas_path)
|
||||
|
||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||
request.urlretrieve(midas_urls[model_type], path)
|
||||
print(f"{model_type} downloaded")
|
||||
|
||||
return midas.api.load_model_inner(model_type)
|
||||
|
||||
midas.api.load_model = load_model_wrapper
|
||||
|
||||
def load_model(checkpoint_info=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
|
Reference in New Issue
Block a user