mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-08 05:12:35 +00:00
Add safetensors support to LDSR
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
@@ -8,6 +9,7 @@ import torchvision
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
import safetensors.torch
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config, ismap
|
||||
@@ -28,8 +30,12 @@ class LDSR:
|
||||
model: torch.nn.Module = cached_ldsr_model
|
||||
else:
|
||||
print(f"Loading model from {self.modelPath}")
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"]
|
||||
_, extension = os.path.splitext(self.modelPath)
|
||||
if extension.lower() == ".safetensors":
|
||||
pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
|
||||
else:
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
|
||||
config = OmegaConf.load(self.yamlPath)
|
||||
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
|
||||
model: torch.nn.Module = instantiate_from_config(config.model)
|
||||
|
Reference in New Issue
Block a user