rework hypertile into a built-in extension

This commit is contained in:
AUTOMATIC1111
2023-11-26 10:51:45 +03:00
parent 3a9bf4ac10
commit d2e0c1ca13
5 changed files with 183 additions and 151 deletions

View File

@@ -1,10 +1,13 @@
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warn : The patch works well only if the input image has a width and height that are multiples of 128
Author : @tfernd Github : https://github.com/tfernd/HyperTile
Warn: The patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
"""
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Callable
from typing_extensions import Literal
@@ -18,6 +21,19 @@ import random
from einops import rearrange
@dataclass
class HypertileParams:
depth = 0
layer_name = ""
tile_size: int = 0
swap_size: int = 0
aspect_ratio: float = 1.0
forward = None
enabled = False
# TODO add SD-XL layers
DEPTH_LAYERS = {
0: [
@@ -176,6 +192,7 @@ DEPTH_LAYERS_XL = {
RNG_INSTANCE = random.Random()
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
"""
Returns a random divisor of value that
@@ -193,10 +210,13 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
return ns[idx]
def set_hypertile_seed(seed: int) -> None:
RNG_INSTANCE.seed(seed)
def largest_tile_size_available(width:int, height:int) -> int:
@functools.cache
def largest_tile_size_available(width: int, height: int) -> int:
"""
Calculates the largest tile size available for a given width and height
Tile size is always a power of 2
@@ -207,6 +227,7 @@ def largest_tile_size_available(width:int, height:int) -> int:
largest_tile_size_available *= 2
return largest_tile_size_available
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
@@ -219,6 +240,7 @@ def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
return closest_pair
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
@@ -240,132 +262,87 @@ def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
w = int(w_candidate)
return h, w
@contextmanager
def split_attention(
layer: nn.Module,
/,
aspect_ratio: float, # width/height
tile_size: int = 128, # 128 for VAE
swap_size: int = 1, # 1 for VAE
*,
disable: bool = False,
max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
scale_depth: bool = True, # scale the tile-size depending on the depth
is_sdxl: bool = False, # is the model SD-XL
):
# Hijacks AttnBlock from ldm and Attention from diffusers
if disable:
logging.info(f"Attention for {layer.__class__.__qualname__} not splitted")
yield
return
def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
latent_tile_size = max(128, tile_size) // 8
@wraps(params.forward)
def wrapper(*args, **kwargs):
if not params.enabled:
return params.forward(*args, **kwargs)
def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
@wraps(forward)
def wrapper(*args, **kwargs):
x = args[0]
latent_tile_size = max(128, params.tile_size) // 8
x = args[0]
# VAE
if x.ndim == 4:
b, c, h, w = x.shape
# VAE
if x.ndim == 4:
b, c, h, w = x.shape
nh = random_divisor(h, latent_tile_size, swap_size)
nw = random_divisor(w, latent_tile_size, swap_size)
nh = random_divisor(h, latent_tile_size, params.swap_size)
nw = random_divisor(w, latent_tile_size, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
out = forward(x, *args[1:], **kwargs)
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
# U-Net
else:
hw: int = x.size(1)
h, w = find_hw_candidates(hw, aspect_ratio)
assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
factor = 2**depth if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)
module._split_sizes_hypertile.append((nh, nw)) # type: ignore
if nh * nw > 1:
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
return wrapper
# Handle hijacking the forward method and recovering afterwards
try:
if is_sdxl:
layers = DEPTH_LAYERS_XL
# U-Net
else:
layers = DEPTH_LAYERS
for depth in range(max_depth + 1):
for layer_name, module in layer.named_modules():
hw: int = x.size(1)
h, w = find_hw_candidates(hw, params.aspect_ratio)
assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
factor = 2 ** params.depth if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
nw = random_divisor(w, latent_tile_size * factor, params.swap_size)
if nh * nw > 1:
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = params.forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
return wrapper
def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
if hypertile_layers is None:
if not enable:
return
hypertile_layers = {}
layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS
for depth in range(4):
for layer_name, module in model.named_modules():
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
# print input shape for debugging
logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
# hijack
module._original_forward_hypertile = module.forward
module.forward = self_attn_forward(module.forward, depth, layer_name, module)
module._split_sizes_hypertile = []
yield
finally:
for layer_name, module in layer.named_modules():
# remove hijack
if hasattr(module, "_original_forward_hypertile"):
if module._split_sizes_hypertile:
logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})")
# recover
module.forward = module._original_forward_hypertile
del module._original_forward_hypertile
del module._split_sizes_hypertile
params = HypertileParams()
module.__webui_hypertile_params = params
params.forward = module.forward
params.depth = depth
params.layer_name = layer_name
module.forward = self_attn_forward(params)
def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
"""
Returns context manager for VAE
"""
enabled = opts.hypertile_split_vae_attn
swap_size = opts.hypertile_swap_size_vae
max_depth = opts.hypertile_max_depth_vae
tile_size_max = opts.hypertile_max_tile_vae
return split_attention(
model,
aspect_ratio=aspect_ratio,
tile_size=min(tile_size, tile_size_max),
swap_size=swap_size,
disable=not enabled,
max_depth=max_depth,
is_sdxl=False,
)
hypertile_layers[layer_name] = 1
def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
"""
Returns context manager for U-Net
"""
enabled = opts.hypertile_split_unet_attn
swap_size = opts.hypertile_swap_size_unet
max_depth = opts.hypertile_max_depth_unet
tile_size_max = opts.hypertile_max_tile_unet
return split_attention(
model,
aspect_ratio=aspect_ratio,
tile_size=min(tile_size, tile_size_max),
swap_size=swap_size,
disable=not enabled,
max_depth=max_depth,
is_sdxl=is_sdxl,
)
model.__webui_hypertile_layers = hypertile_layers
aspect_ratio = width / height
tile_size = min(largest_tile_size_available(width, height), tile_size_max)
for layer_name, module in model.named_modules():
if layer_name in hypertile_layers:
params = module.__webui_hypertile_params
params.tile_size = tile_size
params.swap_size = swap_size
params.aspect_ratio = aspect_ratio
params.enabled = enable and params.depth <= max_depth