mirror of https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
extra networks UI
rework of hypernets: rather than via settings, hypernets are added directly to the prompt as <hypernet:name:weight>
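
The practical effect of the new syntax is that a prompt can carry its own hypernetwork activations instead of relying on a global setting. As a rough sketch of the idea (the regex and the parse_hypernet_tokens helper below are illustrative stand-ins, not the actual extra-networks parser used by the webui):

    import re

    # Illustrative pattern: matches <hypernet:name:weight>, weight optional.
    re_hypernet = re.compile(r"<hypernet:([^:>]+)(?::([\d.]+))?>")

    def parse_hypernet_tokens(prompt):
        names, multipliers = [], []
        for name, weight in re_hypernet.findall(prompt):
            names.append(name)
            multipliers.append(float(weight) if weight else 1.0)
        # Strip the tokens so they never reach the text encoder.
        return re_hypernet.sub("", prompt), names, multipliers

    prompt, names, multipliers = parse_hypernet_tokens(
        "a castle on a hill <hypernet:anime_style:0.6>")
    # names == ["anime_style"], multipliers == [0.6]

The (names, multipliers) pair is exactly what the new load_hypernetworks() in the diff below consumes.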
modules/hypernetworks/hypernetwork.py
@@ -25,7 +25,6 @@ from statistics import stdev, mean
 optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
-    multiplier = 1.0
     activation_dict = {
         "linear": torch.nn.Identity,
         "relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
                  add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
+        self.multiplier = 1.0
+
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
             state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+        return x + self.linear(x) * (self.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
         return layer_structure
 
 
-def apply_strength(value=None):
-    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
 #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
 def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
     if layer_structure is None:
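
Taken together, the hunks above replace the class-level HypernetworkModule.multiplier (one global strength shared by every instance, adjusted through apply_strength() and the sd_hypernetwork_strength setting) with a per-instance self.multiplier, so each loaded network can carry its own strength. A toy illustration of the Python attribute semantics this relies on, using stand-in classes rather than the real module:

    class SharedStrength:
        multiplier = 1.0           # class attribute: one value for all instances

    class OwnStrength:
        def __init__(self):
            self.multiplier = 1.0  # instance attribute: each object owns its value

    a, b = OwnStrength(), OwnStrength()
    a.multiplier = 0.6
    print(a.multiplier, b.multiplier)  # 0.6 1.0 -- independent strengths
    # Setting SharedStrength.multiplier = 0.6 would change it for every instance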
@@ -192,6 +190,20 @@ class Hypernetwork:
             for param in layer.parameters():
                 param.requires_grad = mode
 
+    def to(self, device):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.to(device)
+
+        return self
+
+    def set_multiplier(self, multiplier):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.multiplier = multiplier
+
+        return self
+
     def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
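
Note that both new methods return self, so callers can chain them. A hypothetical usage sketch (the checkpoint path is made up):

    hypernetwork = Hypernetwork()
    hypernetwork.load("models/hypernetworks/example.pt")  # hypothetical path
    hypernetwork.set_multiplier(0.75).to("cuda")          # chainable: both return self

This is what lets load_hypernetworks() further down assign a per-network strength right after lookup.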
@@ -269,11 +281,13 @@ class Hypernetwork:
             self.optimizer_state_dict = None
         if self.optimizer_state_dict:
             self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-            print("Loaded existing optimizer from checkpoint")
-            print(f"Optimizer name is {self.optimizer_name}")
+            if shared.opts.print_hypernet_extra:
+                print("Loaded existing optimizer from checkpoint")
+                print(f"Optimizer name is {self.optimizer_name}")
         else:
             self.optimizer_name = "AdamW"
-            print("No saved optimizer exists in checkpoint")
+            if shared.opts.print_hypernet_extra:
+                print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
     return res
 
 
-def load_hypernetwork(filename):
-    path = shared.hypernetworks.get(filename, None)
-    # Prevent any file named "None.pt" from being loaded.
-    if path is not None and filename != "None":
-        print(f"Loading hypernetwork {filename}")
-        try:
-            shared.loaded_hypernetwork = Hypernetwork()
-            shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+    path = shared.hypernetworks.get(name, None)
 
-        except Exception:
-            print(f"Error loading hypernetwork {path}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-    else:
-        if shared.loaded_hypernetwork is not None:
-            print("Unloading hypernetwork")
+    if path is None:
+        return None
 
-        shared.loaded_hypernetwork = None
+    hypernetwork = Hypernetwork()
+
+    try:
+        hypernetwork.load(path)
+    except Exception:
+        print(f"Error loading hypernetwork {path}", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        return None
+
+    return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+    already_loaded = {}
+
+    for hypernetwork in shared.loaded_hypernetworks:
+        if hypernetwork.name in names:
+            already_loaded[hypernetwork.name] = hypernetwork
+
+    shared.loaded_hypernetworks.clear()
+
+    for i, name in enumerate(names):
+        hypernetwork = already_loaded.get(name, None)
+        if hypernetwork is None:
+            hypernetwork = load_hypernetwork(name)
+
+        if hypernetwork is None:
+            continue
+
+        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+        shared.loaded_hypernetworks.append(hypernetwork)
 
 
 def find_closest_hypernetwork_name(search: str):
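
The new load_hypernetworks() is effectively a cache-aware reload: networks already resident in shared.loaded_hypernetworks are kept if their names are requested again, everything else is dropped, and multipliers are reapplied positionally. A usage sketch (the network names are placeholders):

    # Activate two networks at different strengths.
    load_hypernetworks(["anime_style", "detail_boost"], [0.6, 1.0])

    # A later prompt that keeps one of them: "anime_style" is reused from
    # memory (only its multiplier is reset), "detail_boost" is unloaded.
    load_hypernetworks(["anime_style"], [0.8])

    # With no multipliers given, every requested network defaults to 1.0.
    load_hypernetworks(["anime_style"])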
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
     return applicable[0]
 
 
-def apply_hypernetwork(hypernetwork, context, layer=None):
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
 
     if hypernetwork_layers is None:
-        return context, context
+        return context_k, context_v
 
     if layer is not None:
         layer.hyper_k = hypernetwork_layers[0]
         layer.hyper_v = hypernetwork_layers[1]
 
-    context_k = hypernetwork_layers[0](context)
-    context_v = hypernetwork_layers[1](context)
+    context_k = hypernetwork_layers[0](context_k)
+    context_v = hypernetwork_layers[1](context_v)
     return context_k, context_v
+
+
+def apply_hypernetworks(hypernetworks, context, layer=None):
+    context_k = context
+    context_v = context
+    for hypernetwork in hypernetworks:
+        context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)
+
+    return context_k, context_v
 
 
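Chaining is the point of apply_hypernetworks(): each network's projection output feeds the next one, so with networks A and B active the key context becomes B_k(A_k(context)) rather than a sum or average. A small shape-level sketch with dummy residual layers standing in for hypernetwork_layers[0] and [1] (dimensions are illustrative):

    import torch

    class DummyHypernetLayer(torch.nn.Module):
        # Stands in for one half of a hypernetwork pair; the real ones are
        # configurable residual MLPs (see HypernetworkModule.forward above).
        def __init__(self, dim):
            super().__init__()
            self.linear = torch.nn.Linear(dim, dim)

        def forward(self, x):
            return x + self.linear(x)

    dim = 768                                 # e.g. a CLIP context width
    context = torch.randn(1, 77, dim)
    a_k, b_k = DummyHypernetLayer(dim), DummyHypernetLayer(dim)

    context_k = b_k(a_k(context))             # composition, as in the loop above
    print(context_k.shape)                    # torch.Size([1, 77, 768])
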
@@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
     q = self.to_q(x)
     context = default(context, x)
 
-    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
+    context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
     k = self.to_k(context_k)
     v = self.to_v(context_v)
 
@@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     template_file = template_file.path
 
     path = shared.hypernetworks.get(hypernetwork_name, None)
-    shared.loaded_hypernetwork = Hypernetwork()
-    shared.loaded_hypernetwork.load(path)
+    hypernetwork = Hypernetwork()
+    hypernetwork.load(path)
+    shared.loaded_hypernetworks = [hypernetwork]
 
     shared.state.job = "train-hypernetwork"
     shared.state.textinfo = "Initializing hypernetwork training..."
@@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
     else:
         images_dir = None
 
-    hypernetwork = shared.loaded_hypernetwork
     checkpoint = sd_models.select_checkpoint()
 
     initial_step = hypernetwork.step or 0