mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
add split attention layer optimization from https://github.com/basujindal/stable-diffusion/pull/117
This commit is contained in:
@@ -3,8 +3,43 @@ import sys
|
||||
import traceback
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import einsum
|
||||
|
||||
from modules.shared import opts, device
|
||||
from modules.shared import opts, device, cmd_opts
|
||||
|
||||
from ldm.util import default
|
||||
from einops import rearrange
|
||||
import ldm.modules.attention
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
||||
for i in range(0, q.shape[0], 2):
|
||||
end = i + 2
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
|
||||
s2 = s1.softmax(dim=-1)
|
||||
del s1
|
||||
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
|
||||
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
|
||||
del r1
|
||||
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
class StableDiffusionModelHijack:
|
||||
@@ -67,6 +102,9 @@ class StableDiffusionModelHijack:
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
|
||||
|
||||
if cmd_opts.opt_split_attention:
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
|
||||
def __init__(self, wrapped, hijack):
|
||||
@@ -205,4 +243,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
model_hijack = StableDiffusionModelHijack()
|
||||
|
Reference in New Issue
Block a user