argument to disable memory efficient attention for sdp

Author: Pam
Date: 2023-03-10 12:19:36 +05:00
Parent: fec0a89511
Commit: 37acba2633
3 changed files with 13 additions and 3 deletions


@@ -388,6 +388,10 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
     hidden_states = self.to_out[1](hidden_states)
     return hidden_states
 
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
+        return scaled_dot_product_attention_forward(self, x, context, mask)
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
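
For context, the added wrapper relies on PyTorch's torch.backends.cuda.sdp_kernel context manager (introduced with the scaled dot product attention API in PyTorch 2.0) to restrict which SDP backends may be selected: inside the block, only the flash and math kernels are candidates, and the memory-efficient kernel is excluded. Below is a minimal standalone sketch of that technique, not code from this repo; it assumes a CUDA build of PyTorch >= 2.0 with a CUDA device available, and the tensor shapes are placeholders.

import torch
import torch.nn.functional as F

# Placeholder query/key/value tensors: (batch, heads, tokens, head_dim).
q = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)

# Any scaled_dot_product_attention call inside this block may dispatch to
# the flash or math backends, but never to the memory-efficient backend.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

Because the new function keeps the same signature as scaled_dot_product_attention_forward and simply delegates to it under the restricted-backend context, it can be swapped in wherever the SDP optimization is selected; the only behavioral difference is which kernel PyTorch is allowed to pick.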