Autofix Ruff W (not W605) (mostly whitespace)

This commit is contained in:
Aarni Koskela
2023-05-11 18:28:15 +03:00
parent 431bc5a297
commit 49a55b410b
39 changed files with 167 additions and 166 deletions

View File

@@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
@@ -210,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
@@ -270,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
@@ -315,24 +315,24 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
@@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight
)
self.generator = Generator(
self.nf,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
@@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x)
return self.main(x)