make it possible to use hypernetworks without opt split attention

This commit is contained in:
AUTOMATIC
2022-10-07 16:39:51 +03:00
parent 97bc0b9504
commit f7c787eb7c
2 changed files with 38 additions and 10 deletions

View File

@@ -4,7 +4,12 @@ import sys
import traceback
import torch
from modules import devices
from ldm.util import default
from modules import devices, shared
import torch
from torch import einsum
from einops import rearrange, repeat
class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
return res
def apply(self, x, context=None, mask=None, original=None):
# Cross-attention forward pass: projects `x` to queries and `context` to
# keys/values (optionally passing `context` through the selected hypernetwork's
# layers first), then computes scaled dot-product attention per head.
# NOTE(review): this span is a diff rendering whose indentation and +/- markers
# were lost, so lines from the previous implementation appear interleaved with
# the new ones. Confirm nesting and statement order against the real file.
def attention_CrossAttention_forward(self, x, context=None, mask=None):
# Head count; used below to fold heads into / out of the batch axis.
h = self.heads
# NOTE(review): the next six lines read `CrossAttention.hypernetwork` and
# `CrossAttention.noise_cond`, neither of which is defined in the visible
# source; they look like the pre-change implementation that the
# `shared.selected_hypernetwork()` lookup below replaces — TODO confirm.
if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
if context.shape[1] == 77 and CrossAttention.noise_cond:
context = context + (torch.randn_like(context) * 0.1)
h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
k = self.to_k(h_k(context))
v = self.to_v(h_v(context))
q = self.to_q(x)
# presumably falls back to `x` when `context` is None (self-attention case);
# `default` comes from ldm.util — verify its exact semantics there.
context = default(context, x)
# The selected hypernetwork may be None; its `layers` mapping is looked up by
# the size of context's third axis, and a missing key yields None.
hypernetwork = shared.selected_hypernetwork()
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is not None:
# Entry 0 transforms the context fed to the key projection, entry 1 the
# context fed to the value projection.
k = self.to_k(hypernetwork_layers[0](context))
v = self.to_v(hypernetwork_layers[1](context))
else:
# No matching hypernetwork layers: plain key/value projections.
k = self.to_k(context)
v = self.to_v(context)
# Split the last axis into (heads, per-head dim) and merge heads into the
# batch axis, as spelled out by the rearrange pattern.
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
# Query/key dot products over the per-head dim, scaled by `self.scale`.
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if mask is not None:
# Flatten the mask's trailing axes, repeat it once per head, and overwrite
# similarities where the mask is false with the most negative finite value
# so they receive ~0 weight after softmax.
# assumes `mask` is a boolean tensor (it is inverted with `~`) — TODO confirm
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
# Softmax is taken over the last axis of `sim` (the `j` / key index).
attn = sim.softmax(dim=-1)
# Attention-weighted sum of values, then heads are unfolded from the batch
# axis back into the feature axis before the output projection.
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)