Add safetensors support to LDSR

This commit is contained in:
wywywywy
2022-12-10 18:57:18 +00:00
parent 685f9631b5
commit 8bcdd50461
2 changed files with 14 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import os
import gc
import time
import warnings
@@ -8,6 +9,7 @@ import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
@@ -28,8 +30,12 @@ class LDSR:
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
_, extension = os.path.splitext(self.modelPath)
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
else:
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)