mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-08-04 11:12:35 +00:00
Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right
This commit is contained in:
@@ -11,23 +11,40 @@ from modules import images, shared, torch_utils
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def upscale_without_tiling(model, img: Image.Image):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:
|
||||
img = np.array(img.convert("RGB"))
|
||||
img = img[:, :, ::-1] # flip RGB to BGR
|
||||
img = np.transpose(img, (2, 0, 1)) # HWC to CHW
|
||||
img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1]
|
||||
return torch.from_numpy(img)
|
||||
|
||||
|
||||
def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:
|
||||
if tensor.ndim == 4:
|
||||
# If we're given a tensor with a batch dimension, squeeze it out
|
||||
# (but only if it's a batch of size 1).
|
||||
if tensor.shape[0] != 1:
|
||||
raise ValueError(f"{tensor.shape} does not describe a BCHW tensor")
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor"
|
||||
# TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?
|
||||
arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp
|
||||
arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale
|
||||
arr = arr.astype(np.uint8)
|
||||
arr = arr[:, :, ::-1] # flip BGR to RGB
|
||||
return Image.fromarray(arr, "RGB")
|
||||
|
||||
|
||||
def upscale_pil_patch(model, img: Image.Image) -> Image.Image:
|
||||
"""
|
||||
Upscale a given PIL image using the given model.
|
||||
"""
|
||||
param = torch_utils.get_param(model)
|
||||
img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
return Image.fromarray(output, 'RGB')
|
||||
tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension
|
||||
tensor = tensor.to(device=param.device, dtype=param.dtype)
|
||||
return torch_bgr_to_pil_image(model(tensor))
|
||||
|
||||
|
||||
def upscale_with_model(
|
||||
@@ -40,7 +57,7 @@ def upscale_with_model(
|
||||
) -> Image.Image:
|
||||
if tile_size <= 0:
|
||||
logger.debug("Upscaling %s without tiling", img)
|
||||
output = upscale_without_tiling(model, img)
|
||||
output = upscale_pil_patch(model, img)
|
||||
logger.debug("=> %s", output)
|
||||
return output
|
||||
|
||||
@@ -52,7 +69,7 @@ def upscale_with_model(
|
||||
newrow = []
|
||||
for x, w, tile in row:
|
||||
logger.debug("Tile (%d, %d) %s...", x, y, tile)
|
||||
output = upscale_without_tiling(model, tile)
|
||||
output = upscale_pil_patch(model, tile)
|
||||
scale_factor = output.width // tile.width
|
||||
logger.debug("=> %s (scale factor %s)", output, scale_factor)
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
@@ -71,19 +88,22 @@ def upscale_with_model(
|
||||
|
||||
|
||||
def tiled_upscale_2(
|
||||
img,
|
||||
img: torch.Tensor,
|
||||
model,
|
||||
*,
|
||||
tile_size: int,
|
||||
tile_overlap: int,
|
||||
scale: int,
|
||||
device,
|
||||
desc="Tiled upscale",
|
||||
):
|
||||
# Alternative implementation of `upscale_with_model` originally used by
|
||||
# SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
|
||||
# weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
|
||||
# Pillow space without weighting.
|
||||
|
||||
# Grab the device the model is on, and use it.
|
||||
device = torch_utils.get_param(model).device
|
||||
|
||||
b, c, h, w = img.size()
|
||||
tile_size = min(tile_size, h, w)
|
||||
|
||||
@@ -100,7 +120,8 @@ def tiled_upscale_2(
|
||||
h * scale,
|
||||
w * scale,
|
||||
device=device,
|
||||
).type_as(img)
|
||||
dtype=img.dtype,
|
||||
)
|
||||
weights = torch.zeros_like(result)
|
||||
logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
|
||||
with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:
|
||||
@@ -112,11 +133,13 @@ def tiled_upscale_2(
|
||||
if shared.state.interrupted or shared.state.skipped:
|
||||
break
|
||||
|
||||
# Only move this patch to the device if it's not already there.
|
||||
in_patch = img[
|
||||
...,
|
||||
h_idx : h_idx + tile_size,
|
||||
w_idx : w_idx + tile_size,
|
||||
]
|
||||
].to(device=device)
|
||||
|
||||
out_patch = model(in_patch)
|
||||
|
||||
result[
|
||||
@@ -138,3 +161,29 @@ def tiled_upscale_2(
|
||||
output = result.div_(weights)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def upscale_2(
|
||||
img: Image.Image,
|
||||
model,
|
||||
*,
|
||||
tile_size: int,
|
||||
tile_overlap: int,
|
||||
scale: int,
|
||||
desc: str,
|
||||
):
|
||||
"""
|
||||
Convenience wrapper around `tiled_upscale_2` that handles PIL images.
|
||||
"""
|
||||
tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension
|
||||
|
||||
with torch.no_grad():
|
||||
output = tiled_upscale_2(
|
||||
tensor,
|
||||
model,
|
||||
tile_size=tile_size,
|
||||
tile_overlap=tile_overlap,
|
||||
scale=scale,
|
||||
desc=desc,
|
||||
)
|
||||
return torch_bgr_to_pil_image(output)
|
||||
|
Reference in New Issue
Block a user