mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
Re-implement universal model loading
This commit is contained in:
@@ -1,21 +1,39 @@
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import cv2
|
||||
import os
|
||||
import contextlib
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import modules.images
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_arch import SwinIR as net
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.images
|
||||
from modules import modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
|
||||
model_dir = "SwinIR"
|
||||
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||
model_name = "SwinIR x4"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = ""
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
def load_model(filename, scale=4):
|
||||
def load_model(path, scale=4):
|
||||
global model_path
|
||||
global model_name
|
||||
if "http" in path:
|
||||
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
|
||||
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
@@ -37,19 +55,29 @@ def load_model(filename, scale=4):
|
||||
return model
|
||||
|
||||
|
||||
def load_models(dirname):
|
||||
for file in os.listdir(dirname):
|
||||
path = os.path.join(dirname, file)
|
||||
model_name, extension = os.path.splitext(file)
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
global cmd_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
model_file = ""
|
||||
try:
|
||||
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
|
||||
|
||||
if extension != ".pt" and extension != ".pth":
|
||||
continue
|
||||
if len(models) != 0:
|
||||
model_file = models[0]
|
||||
name = modelloader.friendly_name(model_file)
|
||||
else:
|
||||
# Add the "default" model if none are found.
|
||||
model_file = model_url
|
||||
name = model_name
|
||||
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def upscale(
|
||||
@@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
class UpscalerSwin(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.model = load_model(filename)
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = self.model.to(device)
|
||||
model = load_model(self.filename)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
img = upscale(img, model)
|
||||
return img
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return img
|
Reference in New Issue
Block a user