added support for hypernetworks (???)

This commit is contained in:
AUTOMATIC
2022-10-07 10:17:52 +03:00
parent 2995107fa2
commit bad7cb29ce
4 changed files with 88 additions and 3 deletions

View File

@@ -5,6 +5,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
from modules import shared
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
@@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
hypernetwork = shared.selected_hypernetwork()
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is not None:
k_in = self.to_k(hypernetwork_layers[0](context))
v_in = self.to_v(hypernetwork_layers[1](context))
else:
k_in = self.to_k(context)
v_in = self.to_v(context)
k_in *= self.scale
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))