manual fixes for ruff

This commit is contained in:
AUTOMATIC
2023-05-10 08:25:25 +03:00
parent 762265eab5
commit 96d6ca4199
22 changed files with 129 additions and 129 deletions

View File

@@ -479,7 +479,7 @@ class LatentDiffusion(DDPM):
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
@@ -891,16 +891,6 @@ class LatentDiffusion(DDPM):
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@@ -1171,8 +1161,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
return img, intermediates
@torch.no_grad()
@@ -1219,8 +1211,10 @@ class LatentDiffusion(DDPM):
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if callback:
callback(i)
if img_callback:
img_callback(img, i)
if return_intermediates:
return img, intermediates
@@ -1337,7 +1331,7 @@ class LatentDiffusion(DDPM):
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
h, w = z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.

View File

@@ -54,7 +54,8 @@ class UniPCSampler(object):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")