mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
@@ -1,7 +1,14 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import einsum
|
||||
|
||||
try:
|
||||
import xformers.ops
|
||||
import functorch
|
||||
xformers._is_functorch_available = True
|
||||
shared.xformers_available = True
|
||||
except:
|
||||
print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
|
||||
continue
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
|
||||
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
hypernetwork = shared.selected_hypernetwork()
|
||||
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
|
||||
if hypernetwork_layers is not None:
|
||||
k_in = self.to_k(hypernetwork_layers[0](context))
|
||||
v_in = self.to_v(hypernetwork_layers[1](context))
|
||||
else:
|
||||
k_in = self.to_k(context)
|
||||
v_in = self.to_v(context)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
||||
del q_in, k_in, v_in
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
|
||||
h3 += x
|
||||
|
||||
return h3
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q1 = self.q(h_).contiguous()
|
||||
k1 = self.k(h_).contiguous()
|
||||
v = self.v(h_).contiguous()
|
||||
out = xformers.ops.memory_efficient_attention(q1, k1, v)
|
||||
out = self.proj_out(out)
|
||||
return x+out
|
||||
|
Reference in New Issue
Block a user