mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Re-implement universal model loading
This commit is contained in:
@@ -5,15 +5,35 @@ import traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
from modules import shared
|
||||
from modules.shared import opts
|
||||
from modules.devices import has_mps
|
||||
import modules.images
|
||||
from modules import shared
|
||||
from modules import shared, modelloader
|
||||
from modules.devices import has_mps
|
||||
from modules.paths import models_path
|
||||
from modules.shared import opts
|
||||
|
||||
model_dir = "ESRGAN"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
model_name = "ESRGAN_x4.pth"
|
||||
|
||||
|
||||
def load_model(filename):
|
||||
def load_model(path: str, name: str):
|
||||
global model_path
|
||||
global model_url
|
||||
global model_dir
|
||||
global model_name
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (model_dir, filename))
|
||||
return None
|
||||
print("Loading %s from %s" % (model_dir, filename))
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
@@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
|
||||
class UpscalerESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(shared.device)
|
||||
model = load_model(self.filename, self.name)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
|
||||
if extension != '.pt' and extension != '.pth':
|
||||
continue
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
|
||||
if len(model_paths) == 0:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
|
||||
for file in model_paths:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
|
||||
except Exception:
|
||||
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
|
||||
print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
Reference in New Issue
Block a user