Re-implement universal model loading

This commit is contained in:
d8ahazard
2022-09-26 09:29:50 -05:00
parent bfb7f15d46
commit 740070ea9c
12 changed files with 449 additions and 134 deletions

View File

@@ -1,21 +1,39 @@
import contextlib
import os
import sys
import traceback
import cv2
import os
import contextlib
import numpy as np
from PIL import Image
import torch
import modules.images
from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.images
from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
model_dir = "SwinIR"
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
model_name = "SwinIR x4"
model_path = os.path.join(models_path, model_dir)
cmd_path = ""
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
def load_model(filename, scale=4):
def load_model(path, scale=4):
global model_path
global model_name
if "http" in path:
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
model = net(
upscale=scale,
in_chans=3,
@@ -37,19 +55,29 @@ def load_model(filename, scale=4):
return model
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
def setup_model(dirname):
global model_path
global model_name
global cmd_path
if not os.path.exists(model_path):
os.makedirs(model_path)
cmd_path = dirname
model_file = ""
try:
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
if extension != ".pt" and extension != ".pth":
continue
if len(models) != 0:
model_file = models[0]
name = modelloader.friendly_name(model_file)
else:
# Add the "default" model if none are found.
model_file = model_url
name = model_name
try:
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
except Exception:
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
except Exception:
print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def upscale(
@@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
class UpscalerSwin(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.model = load_model(filename)
self.filename = filename
def do_upscale(self, img):
model = self.model.to(device)
model = load_model(self.filename)
if model is None:
return img
model = model.to(device)
img = upscale(img, model)
return img
try:
torch.cuda.empty_cache()
except:
pass
return img