mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-03 19:02:27 +00:00
custom unet support
This commit is contained in:
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
||||
from types import MethodType
|
||||
|
||||
import modules.textual_inversion.textual_inversion
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
|
||||
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
|
||||
from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
@@ -43,7 +43,7 @@ def list_optimizers():
|
||||
optimizers.extend(new_optimizers)
|
||||
|
||||
|
||||
def apply_optimizations():
|
||||
def apply_optimizations(option=None):
|
||||
global current_optimizer
|
||||
|
||||
undo_optimizations()
|
||||
@@ -60,7 +60,7 @@ def apply_optimizations():
|
||||
current_optimizer.undo()
|
||||
current_optimizer = None
|
||||
|
||||
selection = shared.opts.cross_attention_optimization
|
||||
selection = option or shared.opts.cross_attention_optimization
|
||||
if selection == "Automatic" and len(optimizers) > 0:
|
||||
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
|
||||
else:
|
||||
@@ -72,12 +72,13 @@ def apply_optimizations():
|
||||
matching_optimizer = optimizers[0]
|
||||
|
||||
if matching_optimizer is not None:
|
||||
print(f"Applying optimization: {matching_optimizer.name}... ", end='')
|
||||
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
|
||||
matching_optimizer.apply()
|
||||
print("done.")
|
||||
current_optimizer = matching_optimizer
|
||||
return current_optimizer.name
|
||||
else:
|
||||
print("Disabling attention optimization")
|
||||
return ''
|
||||
|
||||
|
||||
@@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
|
||||
def __init__(self):
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def apply_optimizations(self):
|
||||
def apply_optimizations(self, option=None):
|
||||
try:
|
||||
self.optimization_method = apply_optimizations()
|
||||
self.optimization_method = apply_optimizations(option)
|
||||
except Exception as e:
|
||||
errors.display(e, "applying cross attention optimization")
|
||||
undo_optimizations()
|
||||
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack:
|
||||
|
||||
self.layers = flatten(m)
|
||||
|
||||
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
|
||||
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
|
||||
|
||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
|
||||
|
||||
def undo_hijack(self, m):
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
|
||||
self.layers = None
|
||||
self.clip = None
|
||||
|
||||
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
|
||||
|
||||
def apply_circular(self, enable):
|
||||
if self.circular_enabled == enable:
|
||||
return
|
||||
|
Reference in New Issue
Block a user