mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 19:22:32 +00:00
Merge branch 'master' into training-help-text
This commit is contained in:
@@ -22,40 +22,57 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
super().__init__()
|
||||
if layer_structure is not None:
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
else:
|
||||
layer_structure = parse_layer_structure(dim, state_dict)
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
|
||||
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
|
||||
|
||||
linears = []
|
||||
for i in range(len(layer_structure) - 1):
|
||||
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
|
||||
|
||||
if activation_func == "relu":
|
||||
linears.append(torch.nn.ReLU())
|
||||
elif activation_func == "leakyrelu":
|
||||
linears.append(torch.nn.LeakyReLU())
|
||||
elif activation_func == 'linear' or activation_func is None:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
|
||||
|
||||
if add_layer_norm:
|
||||
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
|
||||
|
||||
self.linear = torch.nn.Sequential(*linears)
|
||||
|
||||
if state_dict is not None:
|
||||
try:
|
||||
self.load_state_dict(state_dict)
|
||||
except RuntimeError:
|
||||
self.try_load_previous(state_dict)
|
||||
self.fix_old_state_dict(state_dict)
|
||||
self.load_state_dict(state_dict)
|
||||
else:
|
||||
for layer in self.linear:
|
||||
layer.weight.data.normal_(mean = 0.0, std = 0.01)
|
||||
layer.bias.data.zero_()
|
||||
if type(layer) == torch.nn.Linear:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
self.to(devices.device)
|
||||
|
||||
def try_load_previous(self, state_dict):
|
||||
states = self.state_dict()
|
||||
states['linear.0.bias'].copy_(state_dict['linear1.bias'])
|
||||
states['linear.0.weight'].copy_(state_dict['linear1.weight'])
|
||||
states['linear.1.bias'].copy_(state_dict['linear2.bias'])
|
||||
states['linear.1.weight'].copy_(state_dict['linear2.weight'])
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
changes = {
|
||||
'linear1.bias': 'linear.0.bias',
|
||||
'linear1.weight': 'linear.0.weight',
|
||||
'linear2.bias': 'linear.1.bias',
|
||||
'linear2.weight': 'linear.1.weight',
|
||||
}
|
||||
|
||||
for fr, to in changes.items():
|
||||
x = state_dict.get(fr, None)
|
||||
if x is None:
|
||||
continue
|
||||
|
||||
del state_dict[fr]
|
||||
state_dict[to] = x
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.linear(x) * self.multiplier
|
||||
@@ -63,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
|
||||
def trainables(self):
|
||||
layer_structure = []
|
||||
for layer in self.linear:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
if type(layer) == torch.nn.Linear:
|
||||
layer_structure += [layer.weight, layer.bias]
|
||||
return layer_structure
|
||||
|
||||
|
||||
@@ -71,23 +89,11 @@ def apply_strength(value=None):
|
||||
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
|
||||
|
||||
|
||||
def parse_layer_structure(dim, state_dict):
    """Infer the layer-size multiplier sequence from a module state dict.

    Looks at every ``"linear.<N>.weight"`` entry of *state_dict* (``<N>`` a
    non-negative integer), visits them in ascending ``<N>`` order and records
    ``len(weight) // dim`` for each one. The returned list always starts
    with ``1``.

    Unlike a plain ``0, 1, 2, ...`` probe, gaps in the numbering do not stop
    the scan: a ``torch.nn.Sequential`` that interleaves parameter-free
    modules (e.g. activations) with ``Linear`` layers produces
    non-consecutive indices. Entries whose weight reports fewer than two
    dimensions via an ``ndim`` attribute are skipped; presumably those are
    normalisation layers that do not change the feature width.

    Args:
        dim: Base feature width that the multipliers are relative to.
        state_dict: Mapping from parameter names to weights. Keys that are
            not strings or do not match the pattern are ignored.

    Returns:
        A list of integer multipliers, e.g. ``[1, 2, 1]``; ``[1]`` when no
        matching key is found.

    Raises:
        TypeError: If *state_dict* is not iterable (e.g. ``None``).
        ZeroDivisionError: If *dim* is ``0`` and a matching weight exists.
    """
    prefix = "linear."
    suffix = ".weight"

    indexed_weights = []
    for key in state_dict:
        if not isinstance(key, str):
            continue
        if not (key.startswith(prefix) and key.endswith(suffix)):
            continue

        index = key[len(prefix):-len(suffix)]
        if not (index.isascii() and index.isdigit()):
            continue

        indexed_weights.append((int(index), state_dict[key]))

    # Numeric (not lexicographic) order so that e.g. index 10 follows index 2.
    indexed_weights.sort(key=lambda pair: pair[0])

    layer_structure = [1]
    for _, weight in indexed_weights:
        ndim = getattr(weight, "ndim", None)
        if ndim is not None and ndim < 2:
            # 1-D weight: not a Linear projection, so it adds no layer size.
            continue
        layer_structure.append(len(weight) // dim)

    return layer_structure
|
||||
|
||||
|
||||
class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
@@ -96,11 +102,12 @@ class Hypernetwork:
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.activation_func = activation_func
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
@@ -123,6 +130,7 @@ class Hypernetwork:
|
||||
state_dict['name'] = self.name
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
|
||||
@@ -135,17 +143,19 @@ class Hypernetwork:
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||
HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
self.step = state_dict.get('step', 0)
|
||||
self.layer_structure = state_dict.get('layer_structure', None)
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
|
||||
self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
|
||||
|
||||
@@ -244,7 +254,11 @@ def stack_conds(conds):
|
||||
|
||||
return torch.stack(conds)
|
||||
|
||||
|
||||
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
|
||||
from modules import images
|
||||
|
||||
assert hypernetwork_name, 'hypernetwork not selected'
|
||||
|
||||
path = shared.hypernetworks.get(hypernetwork_name, None)
|
||||
@@ -287,6 +301,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
last_saved_file = "<none>"
|
||||
last_saved_image = "<none>"
|
||||
forced_filename = "<none>"
|
||||
|
||||
ititial_step = hypernetwork.step or 0
|
||||
if ititial_step > steps:
|
||||
@@ -334,7 +349,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
})
|
||||
|
||||
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
|
||||
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|
||||
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
|
||||
last_saved_image = os.path.join(images_dir, forced_filename)
|
||||
|
||||
optimizer.zero_grad()
|
||||
shared.sd_model.cond_stage_model.to(devices.device)
|
||||
@@ -370,7 +386,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||
|
||||
if image is not None:
|
||||
shared.state.current_image = image
|
||||
image.save(last_saved_image)
|
||||
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
|
||||
last_saved_image += f", prompt: {preview_text}"
|
||||
|
||||
shared.state.job_no = hypernetwork.step
|
||||
|
Reference in New Issue
Block a user