make it possible for scripts to add cross attention optimizations

add UI selection for cross attention optimization
This commit is contained in:
AUTOMATIC
2023-05-18 22:48:28 +03:00
parent 2e006fa500
commit 2582a0fd3b
7 changed files with 226 additions and 49 deletions

View File

@@ -9,10 +9,139 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
from modules import shared, errors, devices
from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
class SdOptimization:
def __init__(self, name, label=None, cmd_opt=None):
self.name = name
self.label = label
self.cmd_opt = cmd_opt
def title(self):
if self.label is None:
return self.name
return f"{self.name} - {self.label}"
def is_available(self):
return True
def priority(self):
return 0
def apply(self):
pass
def undo(self):
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
class SdOptimizationXformers(SdOptimization):
def __init__(self):
super().__init__("xformers", cmd_opt="xformers")
def is_available(self):
return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
def priority(self):
return 100
def apply(self):
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
class SdOptimizationSdpNoMem(SdOptimization):
def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
super().__init__(name, label, cmd_opt)
def is_available(self):
return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
def priority(self):
return 90
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
class SdOptimizationSdp(SdOptimizationSdpNoMem):
def __init__(self):
super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
def priority(self):
return 80
def apply(self):
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
class SdOptimizationSubQuad(SdOptimization):
def __init__(self):
super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
def priority(self):
return 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
class SdOptimizationV1(SdOptimization):
def __init__(self):
super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
def priority(self):
return 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
class SdOptimizationInvokeAI(SdOptimization):
def __init__(self):
super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
def priority(self):
return 1000 if not torch.cuda.is_available() else 10
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
class SdOptimizationDoggettx(SdOptimization):
def __init__(self):
super().__init__("Doggettx", cmd_opt="opt_split_attention")
def priority(self):
return 20
def apply(self):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
def list_optimizers(res):
res.extend([
SdOptimizationXformers(),
SdOptimizationSdpNoMem(),
SdOptimizationSdp(),
SdOptimizationSubQuad(),
SdOptimizationV1(),
SdOptimizationInvokeAI(),
SdOptimizationDoggettx(),
])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
return efficient_dot_product_attention(
return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,