mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Merge branch 'master' into hn-activation
This commit is contained in:
@@ -5,6 +5,7 @@ import html
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
import modules.textual_inversion.dataset
|
||||
import torch
|
||||
@@ -15,10 +16,12 @@ from modules import devices, processing, sd_models, shared
|
||||
from modules.textual_inversion import textual_inversion
|
||||
from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
||||
from torch import einsum
|
||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from statistics import stdev, mean
|
||||
|
||||
|
||||
class HypernetworkModule(torch.nn.Module):
|
||||
multiplier = 1.0
|
||||
activation_dict = {
|
||||
@@ -26,9 +29,12 @@ class HypernetworkModule(torch.nn.Module):
|
||||
"leakyrelu": torch.nn.LeakyReLU,
|
||||
"elu": torch.nn.ELU,
|
||||
"swish": torch.nn.Hardswish,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"sigmoid": torch.nn.Sigmoid,
|
||||
}
|
||||
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
|
||||
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
|
||||
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False, activate_output=False):
|
||||
super().__init__()
|
||||
|
||||
assert layer_structure is not None, "layer_structure must not be None"
|
||||
@@ -65,9 +71,24 @@ class HypernetworkModule(torch.nn.Module):
|
||||
else:
|
||||
for layer in self.linear:
|
||||
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
|
||||
layer.weight.data.normal_(mean=0.0, std=0.01)
|
||||
layer.bias.data.zero_()
|
||||
|
||||
w, b = layer.weight.data, layer.bias.data
|
||||
if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
|
||||
normal_(w, mean=0.0, std=0.01)
|
||||
normal_(b, mean=0.0, std=0.005)
|
||||
elif weight_init == 'XavierUniform':
|
||||
xavier_uniform_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'XavierNormal':
|
||||
xavier_normal_(w)
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingUniform':
|
||||
kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
elif weight_init == 'KaimingNormal':
|
||||
kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
|
||||
zeros_(b)
|
||||
else:
|
||||
raise KeyError(f"Key {weight_init} is not defined as initialization!")
|
||||
self.to(devices.device)
|
||||
|
||||
def fix_old_state_dict(self, state_dict):
|
||||
@@ -105,7 +126,7 @@ class Hypernetwork:
|
||||
filename = None
|
||||
name = None
|
||||
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False, activate_output=False):
|
||||
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False)
|
||||
self.filename = None
|
||||
self.name = name
|
||||
self.layers = {}
|
||||
@@ -114,14 +135,15 @@ class Hypernetwork:
|
||||
self.sd_checkpoint_name = None
|
||||
self.layer_structure = layer_structure
|
||||
self.activation_func = activation_func
|
||||
self.weight_init = weight_init
|
||||
self.add_layer_norm = add_layer_norm
|
||||
self.use_dropout = use_dropout
|
||||
self.activate_output = activate_output
|
||||
|
||||
for size in enable_sizes or []:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
)
|
||||
|
||||
def weights(self):
|
||||
@@ -145,6 +167,7 @@ class Hypernetwork:
|
||||
state_dict['layer_structure'] = self.layer_structure
|
||||
state_dict['activation_func'] = self.activation_func
|
||||
state_dict['is_layer_norm'] = self.add_layer_norm
|
||||
state_dict['weight_initialization'] = self.weight_init
|
||||
state_dict['use_dropout'] = self.use_dropout
|
||||
state_dict['sd_checkpoint'] = self.sd_checkpoint
|
||||
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
|
||||
@@ -160,16 +183,22 @@ class Hypernetwork:
|
||||
state_dict = torch.load(filename, map_location='cpu')
|
||||
|
||||
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
|
||||
print(self.layer_structure)
|
||||
self.activation_func = state_dict.get('activation_func', None)
|
||||
print(f"Activation function is {self.activation_func}")
|
||||
self.weight_init = state_dict.get('weight_initialization', 'Normal')
|
||||
print(f"Weight initialization is {self.weight_init}")
|
||||
self.add_layer_norm = state_dict.get('is_layer_norm', False)
|
||||
self.use_dropout = state_dict.get('use_dropout', False)
|
||||
print(f"Layer norm is set to {self.add_layer_norm}")
|
||||
self.use_dropout = state_dict.get('use_dropout', False
|
||||
print(f"Dropout usage is set to {self.use_dropout}" )
|
||||
self.activate_output = state_dict.get('activate_output', True)
|
||||
|
||||
for size, sd in state_dict.items():
|
||||
if type(size) == int:
|
||||
self.layers[size] = (
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout, self.activate_output),
|
||||
)
|
||||
|
||||
self.name = state_dict.get('name', self.name)
|
||||
|
Reference in New Issue
Block a user