changes for inpainting for #35

support for --medvram
attempt to support share
This commit is contained in:
AUTOMATIC
2022-09-01 11:41:42 +03:00
parent 3e4103541c
commit e1648fc1d1
2 changed files with 76 additions and 53 deletions

110
webui.py
View File

@@ -6,7 +6,10 @@ script_path = os.path.dirname(os.path.realpath(__file__))
sd_path = os.path.dirname(script_path)
# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [(sd_path, 'ldm', 'Stable Diffusion'), ('../../taming-transformers', 'taming', 'Taming Transformers')]
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
('../../taming-transformers', 'taming', 'Taming Transformers')
]
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
@@ -38,15 +41,10 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -65,14 +63,21 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--lowvram", action='store_true', help="enamble stable diffusion model optimizations for low vram")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = not (cmd_opts.lowvram or cmd_opts.medvram)
if not cmd_opts.share:
# fix gradio phoning home
gradio.utils.version_check = lambda: None
gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
@@ -294,21 +299,25 @@ def setup_for_low_vram(sd_model):
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
diff_model = sd_model.model.diffusion_model
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
if cmd_opts.medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
# the third remaining model is still too big for 4GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
def create_random_tensors(shape, seeds):
@@ -860,7 +869,7 @@ class VanillaStableDiffusionSampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
# existing code fail with cetin step counts, like 9
# existing code fails with cetin step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
except Exception:
@@ -887,13 +896,26 @@ class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
return uncond + (cond - uncond) * cond_scale
if batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else:
uncond = self.inner_model(x, sigma, cond=uncond)
cond = self.inner_model(x, sigma, cond=cond)
denoised = uncond + (cond - uncond) * cond_scale
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
return denoised
class KDiffusionSampler:
@@ -910,19 +932,13 @@ class KDiffusionSampler:
xi = x + noise
if p.mask is not None:
if p.inpainting_fill == 2:
xi = xi * p.mask + noise * p.nmask
elif p.inpainting_fill == 3:
xi = xi * p.mask
sigma_sched = sigmas[p.steps - t_enc - 1:]
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * p.nmask + p.init_latent * p.mask
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=mask_cb if p.mask is not None else None)
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
@@ -932,7 +948,7 @@ class KDiffusionSampler:
return samples_ddim
Processed = namedtuple('Processed', ['images','seed', 'info'])
Processed = namedtuple('Processed', ['images', 'seed', 'info'])
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -1315,7 +1331,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask
mask = self.image_mask.convert('L')
@@ -1383,6 +1398,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
def sample(self, x, conditioning, unconditional_conditioning):
if self.mask is not None:
if self.inpainting_fill == 2:
x = x * self.mask + create_random_tensors(x.shape[1:], [self.seed + x + 1 for x in range(x.shape[0])]) * self.nmask
elif self.inpainting_fill == 3:
x = x * self.mask
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
if self.mask is not None:
@@ -1805,10 +1827,10 @@ sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
if not cmd_opts.lowvram:
sd_model = sd_model.to(device)
else:
if cmd_opts.lowvram or cmd_opts.medvram:
setup_for_low_vram(sd_model)
else:
sd_model = sd_model.to(device)
model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)
@@ -1855,5 +1877,5 @@ def inject_gradio_html(javascript):
inject_gradio_html(javascript)
demo.queue(concurrency_count=1)
demo.launch()
demo.launch(share=cmd_opts.share)