mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
ruff manual fixes
This commit is contained in:
@@ -24,7 +24,7 @@ class VQModel(pl.LightningModule):
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
@@ -62,7 +62,7 @@ class VQModel(pl.LightningModule):
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@@ -81,11 +81,11 @@ class VQModel(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
def init_from_ckpt(self, path, ignore_keys=None):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
for ik in ignore_keys or []:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
@@ -270,7 +270,7 @@ class VQModel(pl.LightningModule):
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
|
Reference in New Issue
Block a user