This commit is contained in:
AUTOMATIC1111
2024-06-16 08:13:23 +03:00
parent 5b2a60b8e2
commit 79de09c3df
2 changed files with 16 additions and 13 deletions

View File

@@ -1,6 +1,8 @@
### Impls of the SD3 core diffusion model and VAE
import torch, math, einops
import torch
import math
import einops
from modules.models.sd3.mmdit import MMDiT
from PIL import Image
@@ -214,7 +216,7 @@ class AttnBlock(torch.nn.Module):
k = self.k(hidden)
v = self.v(hidden)
b, c, h, w = q.shape
q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
q, k, v = [einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous() for x in (q, k, v)]
hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
hidden = self.proj_out(hidden)
@@ -259,7 +261,7 @@ class VAEEncoder(torch.nn.Module):
attn = torch.nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(num_res_blocks):
for _ in range(num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
block_in = block_out
down = torch.nn.Module()
@@ -318,7 +320,7 @@ class VAEDecoder(torch.nn.Module):
for i_level in reversed(range(self.num_resolutions)):
block = torch.nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
block_in = block_out
up = torch.nn.Module()