Upscaler.load_model: don't return None, just use exceptions

This commit is contained in:
Aarni Koskela
2023-05-29 10:38:51 +03:00
parent e3a973a68d
commit bf67a5dcf4
5 changed files with 52 additions and 64 deletions

View File

@@ -1,4 +1,3 @@
import os.path
import sys
import PIL.Image
@@ -8,7 +7,7 @@ from tqdm import tqdm
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
from scunet_model_arch import SCUNet
from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
try:
model = self.load_model(selected_file)
except Exception as e:
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
return img
device = devices.get_device_for('scunet')
@@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for _, v in model.named_parameters():