add dropout

This commit is contained in:
discus0434
2022-10-22 11:07:00 +00:00
parent 6a02841fff
commit 0e8ca8e7af
3 changed files with 72 additions and 53 deletions

View File

@@ -1,47 +1,60 @@
import csv
import datetime
import glob
import html
import os
import sys
import traceback
import tqdm
import csv
import torch
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
"swish": torch.nn.Hardswish}
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
activation_dict = {
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
}
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = []
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# if skip_first_layer because first parameters potentially contain negative values
# if i < 1: continue
if activation_func in HypernetworkModule.activation_dict:
linears.append(HypernetworkModule.activation_dict[activation_func]())
# Add an activation func
if activation_func == "linear":
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
raise NotImplementedError(
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
)
# Add dropout
if use_dropout:
linears.append(torch.nn.Dropout(p=0.3))
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
@@ -93,7 +106,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None
self.name = name
self.layers = {}
@@ -101,13 +114,14 @@ class Hypernetwork:
self.sd_checkpoint = None
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
def weights(self):
@@ -129,8 +143,9 @@ class Hypernetwork:
state_dict['step'] = self.step
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
@@ -144,8 +159,9 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items():
if type(size) == int: