Merge pull request #15205 from AUTOMATIC1111/callback_order

Callback order
This commit is contained in:
AUTOMATIC1111
2024-03-16 09:45:41 +03:00
committed by GitHub
7 changed files with 361 additions and 121 deletions

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import configparser
import dataclasses
import os
import threading
import re
@@ -22,6 +23,13 @@ def active():
return [x for x in extensions if x.enabled]
@dataclasses.dataclass
class CallbackOrderInfo:
name: str
before: list
after: list
class ExtensionMetadata:
filename = "metadata.ini"
config: configparser.ConfigParser
@@ -65,6 +73,22 @@ class ExtensionMetadata:
# both "," and " " are accepted as separator
return [x for x in re.split(r"[,\s]+", text.strip()) if x]
def list_callback_order_instructions(self):
for section in self.config.sections():
if not section.startswith("callbacks/"):
continue
callback_name = section[10:]
if not callback_name.startswith(self.canonical_name):
errors.report(f"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}")
continue
before = self.parse_list(self.config.get(section, 'Before', fallback=''))
after = self.parse_list(self.config.get(section, 'After', fallback=''))
yield CallbackOrderInfo(callback_name, before, after)
class Extension:
lock = threading.Lock()
@@ -188,6 +212,7 @@ class Extension:
def list_extensions():
extensions.clear()
extension_paths.clear()
if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
@@ -222,6 +247,7 @@ def list_extensions():
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
extensions.append(extension)
extension_paths[extension.path] = extension
loaded_extensions[canonical_name] = extension
# check for requirements
@@ -240,4 +266,19 @@ def list_extensions():
continue
def find_extension(filename):
parentdir = os.path.dirname(os.path.realpath(filename))
while parentdir != filename:
extension = extension_paths.get(parentdir)
if extension is not None:
return extension
filename = parentdir
parentdir = os.path.dirname(filename)
return None
extensions: list[Extension] = []
extension_paths: dict[str, Extension] = {}