mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-05 03:32:37 +00:00
Autofix Ruff W (not W605) (mostly whitespace)
This commit is contained in:
@@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module):
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
@@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module):
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=('32', '64', '128', '256'),
|
||||
fix_modules=('quantize', 'generator')):
|
||||
@@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
@@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
@@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
||||
return out, logits, lq_feat
|
||||
|
Reference in New Issue
Block a user