This commit is contained in:
discus0434
2022-10-20 00:10:45 +00:00
parent 634acdd954
commit 6f98e89486
3 changed files with 45 additions and 32 deletions

View File

@@ -22,16 +22,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
assert layer_structure is not None, "layer_structure mut not be None"
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu":
linears.append(torch.nn.ReLU())
if activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -42,8 +46,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
if not "ReLU" in layer.__str__():
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
@@ -69,7 +74,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
layer_structure += [layer.weight, layer.bias]
if not "ReLU" in layer.__str__():
layer_structure += [layer.weight, layer.bias]
return layer_structure
@@ -81,7 +87,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None
self.name = name
self.layers = {}
@@ -90,11 +96,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
)
def weights(self):
@@ -117,6 +124,7 @@ class Hypernetwork:
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -131,12 +139,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
)
self.name = state_dict.get('name', self.name)