mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
scaled dot product attention
This commit is contained in:
@@ -346,6 +346,48 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
|
||||
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
|
||||
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
|
||||
k_in = self.to_k(context_k)
|
||||
v_in = self.to_v(context_v)
|
||||
|
||||
head_dim = inner_dim // h
|
||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||
|
||||
del q_in, k_in, v_in
|
||||
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
Reference in New Issue
Block a user