add UI for reordering callbacks

This commit is contained in:
AUTOMATIC1111
2024-03-10 14:07:51 +03:00
parent 0411eced89
commit 7e5e67330b
5 changed files with 192 additions and 44 deletions

View File

@@ -8,7 +8,7 @@ from typing import Optional, Any
from fastapi import FastAPI
from gradio import Blocks
from modules import errors, timer, extensions
from modules import errors, timer, extensions, shared
def report_exception(c, job):
@@ -124,9 +124,10 @@ class ScriptCallback:
name: str = None
def add_callback(callbacks, fun, *, name=None, category='unknown'):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
def add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):
if filename is None:
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if stack else 'unknown file'
extension = extensions.find_extension(filename)
extension_name = extension.canonical_name if extension else 'base'
@@ -146,6 +147,43 @@ def add_callback(callbacks, fun, *, name=None, category='unknown'):
callbacks.append(ScriptCallback(filename, fun, unique_callback_name))
def sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):
callbacks = unordered_callbacks.copy()
if enable_user_sort:
for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):
index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)
if index is not None:
callbacks.insert(0, callbacks.pop(index))
return callbacks
def ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):
if unordered_callbacks is None:
unordered_callbacks = callback_map.get('callbacks_' + category, [])
if not enable_user_sort:
return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)
callbacks = ordered_callbacks_map.get(category)
if callbacks is not None and len(callbacks) == len(unordered_callbacks):
return callbacks
callbacks = sort_callbacks(category, unordered_callbacks)
ordered_callbacks_map[category] = callbacks
return callbacks
def enumerate_callbacks():
for category, callbacks in callback_map.items():
if category.startswith('callbacks_'):
category = category[10:]
yield category, callbacks
callback_map = dict(
callbacks_app_started=[],
callbacks_model_loaded=[],
@@ -170,14 +208,18 @@ callback_map = dict(
callbacks_before_token_counter=[],
)
ordered_callbacks_map = {}
def clear_callbacks():
for callback_list in callback_map.values():
callback_list.clear()
ordered_callbacks_map.clear()
def app_started_callback(demo: Optional[Blocks], app: FastAPI):
for c in callback_map['callbacks_app_started']:
for c in ordered_callbacks('app_started'):
try:
c.callback(demo, app)
timer.startup_timer.record(os.path.basename(c.script))
@@ -186,7 +228,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
def app_reload_callback():
for c in callback_map['callbacks_on_reload']:
for c in ordered_callbacks('on_reload'):
try:
c.callback()
except Exception:
@@ -194,7 +236,7 @@ def app_reload_callback():
def model_loaded_callback(sd_model):
for c in callback_map['callbacks_model_loaded']:
for c in ordered_callbacks('model_loaded'):
try:
c.callback(sd_model)
except Exception:
@@ -204,7 +246,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback():
res = []
for c in callback_map['callbacks_ui_tabs']:
for c in ordered_callbacks('ui_tabs'):
try:
res += c.callback() or []
except Exception:
@@ -214,7 +256,7 @@ def ui_tabs_callback():
def ui_train_tabs_callback(params: UiTrainTabParams):
for c in callback_map['callbacks_ui_train_tabs']:
for c in ordered_callbacks('ui_train_tabs'):
try:
c.callback(params)
except Exception:
@@ -222,7 +264,7 @@ def ui_train_tabs_callback(params: UiTrainTabParams):
def ui_settings_callback():
for c in callback_map['callbacks_ui_settings']:
for c in ordered_callbacks('ui_settings'):
try:
c.callback()
except Exception:
@@ -230,7 +272,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_before_image_saved']:
for c in ordered_callbacks('before_image_saved'):
try:
c.callback(params)
except Exception:
@@ -238,7 +280,7 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams):
for c in callback_map['callbacks_image_saved']:
for c in ordered_callbacks('image_saved'):
try:
c.callback(params)
except Exception:
@@ -246,7 +288,7 @@ def image_saved_callback(params: ImageSaveParams):
def extra_noise_callback(params: ExtraNoiseParams):
for c in callback_map['callbacks_extra_noise']:
for c in ordered_callbacks('extra_noise'):
try:
c.callback(params)
except Exception:
@@ -254,7 +296,7 @@ def extra_noise_callback(params: ExtraNoiseParams):
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
for c in ordered_callbacks('cfg_denoiser'):
try:
c.callback(params)
except Exception:
@@ -262,7 +304,7 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
def cfg_denoised_callback(params: CFGDenoisedParams):
for c in callback_map['callbacks_cfg_denoised']:
for c in ordered_callbacks('cfg_denoised'):
try:
c.callback(params)
except Exception:
@@ -270,7 +312,7 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
for c in callback_map['callbacks_cfg_after_cfg']:
for c in ordered_callbacks('cfg_after_cfg'):
try:
c.callback(params)
except Exception:
@@ -278,7 +320,7 @@ def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
def before_component_callback(component, **kwargs):
for c in callback_map['callbacks_before_component']:
for c in ordered_callbacks('before_component'):
try:
c.callback(component, **kwargs)
except Exception:
@@ -286,7 +328,7 @@ def before_component_callback(component, **kwargs):
def after_component_callback(component, **kwargs):
for c in callback_map['callbacks_after_component']:
for c in ordered_callbacks('after_component'):
try:
c.callback(component, **kwargs)
except Exception:
@@ -294,7 +336,7 @@ def after_component_callback(component, **kwargs):
def image_grid_callback(params: ImageGridLoopParams):
for c in callback_map['callbacks_image_grid']:
for c in ordered_callbacks('image_grid'):
try:
c.callback(params)
except Exception:
@@ -302,7 +344,7 @@ def image_grid_callback(params: ImageGridLoopParams):
def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
for c in callback_map['callbacks_infotext_pasted']:
for c in ordered_callbacks('infotext_pasted'):
try:
c.callback(infotext, params)
except Exception:
@@ -310,7 +352,7 @@ def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
def script_unloaded_callback():
for c in reversed(callback_map['callbacks_script_unloaded']):
for c in reversed(ordered_callbacks('script_unloaded')):
try:
c.callback()
except Exception:
@@ -318,7 +360,7 @@ def script_unloaded_callback():
def before_ui_callback():
for c in reversed(callback_map['callbacks_before_ui']):
for c in reversed(ordered_callbacks('before_ui')):
try:
c.callback()
except Exception:
@@ -328,7 +370,7 @@ def before_ui_callback():
def list_optimizers_callback():
res = []
for c in callback_map['callbacks_list_optimizers']:
for c in ordered_callbacks('list_optimizers'):
try:
c.callback(res)
except Exception:
@@ -340,7 +382,7 @@ def list_optimizers_callback():
def list_unets_callback():
res = []
for c in callback_map['callbacks_list_unets']:
for c in ordered_callbacks('list_unets'):
try:
c.callback(res)
except Exception:
@@ -350,7 +392,7 @@ def list_unets_callback():
def before_token_counter_callback(params: BeforeTokenCounterParams):
for c in callback_map['callbacks_before_token_counter']:
for c in ordered_callbacks('before_token_counter'):
try:
c.callback(params)
except Exception: