mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
Verify architecture for loaded Spandrel models
This commit is contained in:
@@ -6,6 +6,8 @@ import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
|
||||
from modules.paths import script_path, models_path
|
||||
@@ -183,9 +185,18 @@ def load_upscalers():
|
||||
)
|
||||
|
||||
|
||||
def load_spandrel_model(path, *, device, half: bool = False, dtype=None):
|
||||
def load_spandrel_model(
|
||||
path: str,
|
||||
*,
|
||||
device: str | torch.device | None,
|
||||
half: bool = False,
|
||||
dtype: str | None = None,
|
||||
expected_architecture: str | None = None,
|
||||
):
|
||||
import spandrel
|
||||
model = spandrel.ModelLoader(device=device).load_from_file(path)
|
||||
if expected_architecture and model.architecture != expected_architecture:
|
||||
raise TypeError(f"Model {path} is not a {expected_architecture} model")
|
||||
if half:
|
||||
model = model.model.half()
|
||||
if dtype:
|
||||
|
Reference in New Issue
Block a user