Merge branch 'dev' into feature/restore-progress

This commit is contained in:
AUTOMATIC1111
2023-04-29 22:13:40 +03:00
committed by GitHub
63 changed files with 1387 additions and 327 deletions

View File

@@ -5,6 +5,7 @@ import importlib
import signal
import re
import warnings
import json
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -20,6 +21,9 @@ startup_timer = timer.Timer()
import torch
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")
import gradio
@@ -37,7 +41,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -67,11 +71,51 @@ else:
server_name = "0.0.0.0" if cmd_opts.listen else None
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def check_versions():
if shared.cmd_opts.skip_version_check:
return
expected_torch_version = "1.13.1"
expected_torch_version = "2.0.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
errors.print_error_explanation(f"""
@@ -84,7 +128,7 @@ there are reports of issues with training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
expected_xformers_version = "0.0.16rc425"
expected_xformers_version = "0.0.17"
if shared.xformers_available:
import xformers
@@ -99,12 +143,27 @@ Use --skip-version-check commandline argument to disable this check.
def initialize():
fix_asyncio_event_loop_policy()
check_versions()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
@@ -126,9 +185,6 @@ def initialize():
modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
@@ -150,6 +206,7 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
@@ -212,6 +269,8 @@ def wait_on_server(demo=None):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
break
@@ -260,6 +319,7 @@ def webui():
server_port=cmd_opts.port,
ssl_keyfile=cmd_opts.tls_keyfile,
ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
inbrowser=cmd_opts.autolaunch,
@@ -302,6 +362,19 @@ def webui():
extensions.list_extensions()
startup_timer.record("list extensions")
config_state_file = shared.opts.restore_config_state_file
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
localization.list_localizations(cmd_opts.localizations_dir)
modelloader.forbid_loaded_nonbuiltin_upscalers()