mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Add general forward method for all modules.
This commit is contained in:
@@ -3,6 +3,10 @@ import os
|
||||
from collections import namedtuple
|
||||
import enum
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modules import sd_models, cache, errors, hashes, shared
|
||||
|
||||
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
|
||||
@@ -115,6 +119,29 @@ class NetworkModule:
|
||||
if hasattr(self.sd_module, 'weight'):
|
||||
self.shape = self.sd_module.weight.shape
|
||||
|
||||
self.ops = None
|
||||
self.extra_kwargs = {}
|
||||
if isinstance(self.sd_module, nn.Conv2d):
|
||||
self.ops = F.conv2d
|
||||
self.extra_kwargs = {
|
||||
'stride': self.sd_module.stride,
|
||||
'padding': self.sd_module.padding
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.Linear):
|
||||
self.ops = F.linear
|
||||
elif isinstance(self.sd_module, nn.LayerNorm):
|
||||
self.ops = F.layer_norm
|
||||
self.extra_kwargs = {
|
||||
'normalized_shape': self.sd_module.normalized_shape,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
elif isinstance(self.sd_module, nn.GroupNorm):
|
||||
self.ops = F.group_norm
|
||||
self.extra_kwargs = {
|
||||
'num_groups': self.sd_module.num_groups,
|
||||
'eps': self.sd_module.eps
|
||||
}
|
||||
|
||||
self.dim = None
|
||||
self.bias = weights.w.get("bias")
|
||||
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
|
||||
@@ -155,5 +182,10 @@ class NetworkModule:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, x, y):
|
||||
raise NotImplementedError()
|
||||
"""A general forward implementation for all modules"""
|
||||
if self.ops is None:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
updown, ex_bias = self.calc_updown(self.sd_module.weight)
|
||||
return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user