mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 03:10:21 +00:00
SDXL support
This commit is contained in:
@@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
class SdOptimization:
|
||||
@@ -39,6 +43,9 @@ class SdOptimization:
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
class SdOptimizationXformers(SdOptimization):
|
||||
name = "xformers"
|
||||
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdpNoMem(SdOptimization):
|
||||
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSubQuad(SdOptimization):
|
||||
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationV1(SdOptimization):
|
||||
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
|
||||
cmd_opt = "opt_split_attention_v1"
|
||||
priority = 10
|
||||
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
|
||||
|
||||
class SdOptimizationInvokeAI(SdOptimization):
|
||||
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
|
||||
|
||||
class SdOptimizationDoggettx(SdOptimization):
|
||||
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
|
||||
|
||||
def list_optimizers(res):
|
||||
@@ -155,7 +173,7 @@ def get_available_vram():
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
|
||||
def einsum_op_compvis(q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
|
||||
def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_mps_v1(q, k, v):
|
||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v)
|
||||
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
|
||||
slice_size -= 1
|
||||
return einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
|
||||
def einsum_op_mps_v2(q, k, v):
|
||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
|
||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
@@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
|
||||
def einsum_op_cuda(q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def einsum_op(q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return einsum_op_cuda(q, k, v)
|
||||
@@ -328,7 +354,8 @@ def einsum_op(q, k, v):
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
@@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||
|
||||
h = self.heads
|
||||
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = q.shape
|
||||
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
|
||||
return None
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return scaled_dot_product_attention_forward(self, x, context, mask)
|
||||
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
try:
|
||||
h_ = x
|
||||
@@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sdp_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
||||
|
||||
def sdp_no_mem_attnblock_forward(self, x):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return sdp_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sub_quad_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
Reference in New Issue
Block a user