Gradient accumulation, autocast fix, new latent sampling method, etc

This commit is contained in:
flamelaw
2022-11-20 12:35:26 +09:00
parent 47a44c7e42
commit bd68e35de3
7 changed files with 408 additions and 273 deletions

View File

@@ -0,0 +1,10 @@
from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
return checkpoint(self._forward, x, context)
def AttentionBlock_forward(self, x):
return checkpoint(self._forward, x)
def ResBlock_forward(self, x, emb):
return checkpoint(self._forward, x, emb)