Add Birch-san's sub-quadratic attention implementation

brkirch
2022-12-27 08:50:55 -05:00
parent 4af3ca5393
commit d782a95967
6 changed files with 312 additions and 35 deletions


@@ -1,7 +1,7 @@
import math
import sys
import traceback
import importlib
import psutil
import torch
from torch import einsum
@@ -12,6 +12,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
print(traceback.format_exc(), file=sys.stderr)
def get_available_vram():
if shared.device.type == 'cuda':
stats = torch.cuda.memory_stats(shared.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
return mem_free_total
else:
return psutil.virtual_memory().available
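# Illustrative sketch (figures assumed, not part of this module) of the CUDA
# branch above: usable VRAM is the driver-reported free memory plus the slack
# sitting in PyTorch's reserved-but-inactive cache.
def _example_cuda_vram_estimate():
    gib = 1024 ** 3
    mem_free_cuda = 2 * gib                       # torch.cuda.mem_get_info()
    mem_reserved, mem_active = 5 * gib, 3 * gib   # torch.cuda.memory_stats()
    mem_free_torch = mem_reserved - mem_active    # reusable cached bytes
    return mem_free_cuda + mem_free_torch         # 4 GiB usable in total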
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
@@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
mem_free_total = get_available_vram()
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
@@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
def check_for_psutil():
try:
spec = importlib.util.find_spec('psutil')
return spec is not None
except ModuleNotFoundError:
return False
invokeAI_mps_available = check_for_psutil()
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
@@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
# -- End of code from https://github.com/invoke-ai/InvokeAI --
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
def sub_quad_attention_forward(self, x, context=None, mask=None):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
q = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, context_k, context_v, x
q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
out_proj, dropout = self.to_out
x = out_proj(x)
x = dropout(x)
return x
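# Illustrative sketch (shapes assumed, not part of the hijacked module) of the
# reshape above: (batch, tokens, heads*dim_head) -> (batch*heads, tokens, dim_head)
# before attention, and the inverse afterwards.
def _example_head_split_merge():
    import torch
    b, n, h, d = 2, 77, 8, 40
    x = torch.randn(b, n, h * d)
    q = x.unflatten(-1, (h, d)).transpose(1, 2).flatten(end_dim=1)      # (b*h, n, d)
    assert q.shape == (b * h, n, d)
    out = q.unflatten(0, (b, h)).transpose(1, 2).flatten(start_dim=2)   # (b, n, h*d)
    assert out.shape == (b, n, h * d)
    return out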
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True):
bytes_per_token = torch.finfo(q.dtype).bits//8
batch_x_heads, q_tokens, _ = q.shape
_, k_tokens, _ = k.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
if chunk_threshold_bytes is None:
chunk_threshold_bytes = available_vram
elif chunk_threshold_bytes == 0:
chunk_threshold_bytes = None
if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
elif kv_chunk_size_min == 0:
kv_chunk_size_min = None
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
q_chunk_size = q_tokens
kv_chunk_size = k_tokens
return efficient_dot_product_attention(
q,
k,
v,
query_chunk_size=q_chunk_size,
kv_chunk_size=kv_chunk_size,
kv_chunk_size_min=kv_chunk_size_min,
use_checkpoint=use_checkpoint,
)
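# Illustrative arithmetic (figures assumed) for the threshold check above: for
# float16 tensors of shape (batch*heads, tokens, dim), the full attention-score
# matmul materializes batch_x_heads * q_tokens * k_tokens elements at 2 bytes each.
def _example_chunking_decision(available_vram_bytes=4 * 1024 ** 3):
    batch_x_heads, q_tokens, k_tokens = 16, 4096, 4096
    bytes_per_token = 2   # torch.finfo(torch.float16).bits // 8
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    # 16 * 2 * 4096 * 4096 bytes = 512 MiB, well under a 4 GiB budget,
    # so the unchunked fast path would be taken
    return qk_matmul_size_bytes <= available_vram_bytes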
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x):
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
mem_free_total = get_available_vram()
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
@@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
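# Illustrative sketch (shapes assumed) of the spatial <-> token rearranges used
# above: feature maps (b, c, h, w) become (b, h*w, c) token sequences for
# attention, then are restored for the output projection.
def _example_attnblock_rearrange():
    import torch
    from einops import rearrange
    b, c, h, w = 1, 512, 64, 64
    t = torch.randn(b, c, h, w)
    tokens = rearrange(t, 'b c h w -> b (h w) c')              # (1, 4096, 512)
    restored = rearrange(tokens, 'b (h w) c -> b c h w', h=h)
    assert restored.shape == (b, c, h, w)
    return restored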