mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
do not replace entire unet for the resolution hack
This commit is contained in:
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
|
||||
return x + out
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
||||
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
if h.shape[-2:] != hs[-1].shape[-2:]:
|
||||
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
|
Reference in New Issue
Block a user