extra networks UI

rework of hypernets: rather than via settings, hypernets are added directly to prompt as <hypernet:name:weight>
This commit is contained in:
AUTOMATIC
2023-01-21 08:36:07 +03:00
parent e33cace2c2
commit 40ff6db532
25 changed files with 767 additions and 216 deletions

View File

@@ -25,7 +25,6 @@ from statistics import stdev, mean
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
add_layer_norm=False, activate_output=False, dropout_structure=None):
super().__init__()
self.multiplier = 1.0
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
state_dict[to] = x
def forward(self, x):
return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
return x + self.linear(x) * (self.multiplier if not self.training else 1)
def trainables(self):
layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
return layer_structure
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork:
for param in layer.parameters():
param.requires_grad = mode
def to(self, device):
for k, layers in self.layers.items():
for layer in layers:
layer.to(device)
return self
def set_multiplier(self, multiplier):
for k, layers in self.layers.items():
for layer in layers:
layer.multiplier = multiplier
return self
def eval(self):
for k, layers in self.layers.items():
for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork:
self.optimizer_state_dict = None
if self.optimizer_state_dict:
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
print("Loaded existing optimizer from checkpoint")
print(f"Optimizer name is {self.optimizer_name}")
if shared.opts.print_hypernet_extra:
print("Loaded existing optimizer from checkpoint")
print(f"Optimizer name is {self.optimizer_name}")
else:
self.optimizer_name = "AdamW"
print("No saved optimizer exists in checkpoint")
if shared.opts.print_hypernet_extra:
print("No saved optimizer exists in checkpoint")
for size, sd in state_dict.items():
if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
# Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
try:
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
def load_hypernetwork(name):
path = shared.hypernetworks.get(name, None)
except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
else:
if shared.loaded_hypernetwork is not None:
print("Unloading hypernetwork")
if path is None:
return None
shared.loaded_hypernetwork = None
hypernetwork = Hypernetwork()
try:
hypernetwork.load(path)
except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
return hypernetwork
def load_hypernetworks(names, multipliers=None):
already_loaded = {}
for hypernetwork in shared.loaded_hypernetworks:
if hypernetwork.name in names:
already_loaded[hypernetwork.name] = hypernetwork
shared.loaded_hypernetworks.clear()
for i, name in enumerate(names):
hypernetwork = already_loaded.get(name, None)
if hypernetwork is None:
hypernetwork = load_hypernetwork(name)
if hypernetwork is None:
continue
hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
shared.loaded_hypernetworks.append(hypernetwork)
def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
return applicable[0]
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
if hypernetwork_layers is None:
return context, context
return context_k, context_v
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
context_k = hypernetwork_layers[0](context_k)
context_v = hypernetwork_layers[1](context_v)
return context_k, context_v
def apply_hypernetworks(hypernetworks, context, layer=None):
context_k = context
context_v = context
for hypernetwork in hypernetworks:
context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
return context_k, context_v
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
k = self.to_k(context_k)
v = self.to_v(context_v)
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
template_file = template_file.path
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
shared.loaded_hypernetwork.load(path)
hypernetwork = Hypernetwork()
hypernetwork.load(path)
shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
else:
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
initial_step = hypernetwork.step or 0