mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
ruff manual fixes
This commit is contained in:
@@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
||||
beta_schedule="linear",
|
||||
loss_type="l2",
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
ignore_keys=None,
|
||||
load_only_unet=False,
|
||||
monitor="val/loss",
|
||||
use_ema=True,
|
||||
@@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||
|
||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||
if self.use_ema and not load_ema:
|
||||
@@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
||||
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||
ignore_keys = ignore_keys or []
|
||||
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
if "state_dict" in list(sd.keys()):
|
||||
sd = sd["state_dict"]
|
||||
@@ -473,7 +475,7 @@ class LatentDiffusion(DDPM):
|
||||
conditioning_key = None
|
||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
||||
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
@@ -1433,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
||||
# TODO: move all layout-specific hacks to this class
|
||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
||||
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||
|
||||
def log_images(self, batch, N=8, *args, **kwargs):
|
||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
||||
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||
|
||||
key = 'train' if self.training else 'validation'
|
||||
dset = self.trainer.datamodule.datasets[key]
|
||||
|
Reference in New Issue
Block a user