mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-09 13:49:48 +00:00
Verify architecture for loaded Spandrel models
This commit is contained in:
@@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
expected_architecture='CodeFormer',
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
|
@@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler):
|
||||
return modelloader.load_spandrel_model(
|
||||
filename,
|
||||
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
|
||||
expected_architecture='ESRGAN',
|
||||
)
|
||||
|
||||
|
||||
|
@@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||
net = modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
expected_architecture='GFPGAN',
|
||||
).model
|
||||
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||
return net
|
||||
|
@@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler):
|
||||
return modelloader.load_spandrel_model(
|
||||
path,
|
||||
device=devices.device_esrgan, # TODO: should probably be device_hat
|
||||
expected_architecture='HAT',
|
||||
)
|
||||
|
@@ -6,6 +6,8 @@ import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
@@ -183,9 +185,18 @@ def load_upscalers():
|
||||
)
|
||||
|
||||
|
||||
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
|
||||
def load_spandrel_model(
|
||||
path: str,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
half: bool = False,
|
||||
dtype: str | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
):
|
||||
import spandrel
|
||||
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||
if expected_architecture and model.architecture != expected_architecture:
|
||||
raise TypeError(f"Model {path} is not a {expected_architecture} model")
|
||||
if half:
|
||||
model = model.model.half()
|
||||
if dtype:
|
||||
|
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
info.local_data_path,
|
||||
device=self.device,
|
||||
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
|
||||
expected_architecture="RealESRGAN",
|
||||
)
|
||||
return upscale_with_model(
|
||||
mod,
|
||||
|
Reference in New Issue
Block a user