initial refiner support

This commit is contained in:
AUTOMATIC1111
2023-08-06 17:01:07 +03:00
parent 57e8a11d17
commit f1975b0213
6 changed files with 76 additions and 9 deletions

View File

@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -127,3 +127,20 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
def apply_refiner(sampler):
completed_ratio = sampler.step / sampler.steps
if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:
refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
if refiner_checkpoint_info is None:
raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
devices.torch_gc()
sampler.update_inner_model()
sampler.p.setup_conds()