Autofix Ruff W (not W605) (mostly whitespace)

This commit is contained in:
Aarni Koskela
2023-05-11 18:28:15 +03:00
parent 431bc5a297
commit 49a55b410b
39 changed files with 167 additions and 166 deletions

View File

@@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
# self attention
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
@@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256,
connect_list=('32', '64', '128', '256'),
fix_modules=('quantize', 'generator')):
@@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = {
'16': 512,
'32': 256,
@@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
@@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
x = block(x)
if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1])
if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
# logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat
return out, logits, lq_feat

View File

@@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
@@ -210,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
@@ -270,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
@@ -315,24 +315,24 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
@@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight
)
self.generator = Generator(
self.nf,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
@@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!')
def forward(self, x):
return self.main(x)
return self.main(x)