Merge pull request #1851 from C43H66N12O12S2/flash

xformers attention
Authored by AUTOMATIC1111 on 2022-10-08 16:29:59 +03:00; committed by GitHub.
6 changed files with 55 additions and 6 deletions


@@ -1,7 +1,14 @@
import math
import torch
from torch import einsum

try:
    import xformers.ops
    import functorch
    xformers._is_functorch_available = True
    shared.xformers_available = True
except Exception:
    # xformers is optional; fall back to the split-attention path when it is missing
    print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')

from ldm.util import default
from einops import rearrange
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)

def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.selected_hypernetwork()
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        k_in = self.to_k(hypernetwork_layers[0](context))
        v_in = self.to_v(hypernetwork_layers[1](context))
    else:
        k_in = self.to_k(context)
        v_in = self.to_v(context)

    # xformers expects (batch, seq_len, heads, dim_per_head) tensors
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
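
For reference, here is a minimal standalone sketch of the xformers.ops.memory_efficient_attention call that xformers_attention_forward above relies on, using dummy CUDA tensors in the (batch, seq_len, heads, dim_per_head) layout produced by the rearrange step. The tensor sizes and dtype are illustrative assumptions, not values taken from this commit.

import torch
import xformers.ops

# Illustrative sizes only: batch=2, 4096 tokens, 8 heads, 40 dims per head.
q = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 8, 40, device='cuda', dtype=torch.float16)

# Same call as in xformers_attention_forward: memory-efficient attention
# with no attention bias; the output keeps the query's shape.
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
print(out.shape)  # torch.Size([2, 4096, 8, 40])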
@@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
    h3 += x

    return h3

def xformers_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q1 = self.q(h_).contiguous()
    k1 = self.k(h_).contiguous()
    v = self.v(h_).contiguous()
    out = xformers.ops.memory_efficient_attention(q1, k1, v)
    out = self.proj_out(out)
    return x + out
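
One caveat worth noting: in the LDM AttnBlock, self.q, self.k and self.v are convolutions, so q1, k1 and v come out as (batch, channels, height, width) feature maps, while memory_efficient_attention works on flattened sequence tensors. The sketch below shows one way such maps could be flattened before the call and restored afterwards; it is an illustration under that assumption, not code from this commit.

# Hypothetical reshape wrapper, not part of this commit: flatten (b, c, h, w)
# feature maps to (b, h*w, c) for xformers, then restore the spatial layout.
import xformers.ops
from einops import rearrange

def attnblock_memory_efficient_attention(q, k, v):
    b, c, h, w = q.shape
    q_, k_, v_ = map(lambda t: rearrange(t, 'b c h w -> b (h w) c').contiguous(), (q, k, v))
    out = xformers.ops.memory_efficient_attention(q_, k_, v_)
    return rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)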