From 022d835565f253841f7f9272ba320bb0cec4770d Mon Sep 17 00:00:00 2001 From: huchenlei Date: Wed, 15 May 2024 15:20:40 -0400 Subject: [PATCH 1/3] use_checkpoint = False --- configs/alt-diffusion-inference.yaml | 2 +- configs/alt-diffusion-m18-inference.yaml | 2 +- configs/instruct-pix2pix.yaml | 2 +- configs/sd_xl_inpaint.yaml | 2 +- configs/v1-inference.yaml | 2 +- configs/v1-inpainting-inference.yaml | 2 +- modules/sd_hijack_checkpoint.py | 9 ++++++--- modules/sd_models_config.py | 2 +- 8 files changed, 13 insertions(+), 10 deletions(-) diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml index cfbee72d7..4944ab5c8 100644 --- a/configs/alt-diffusion-inference.yaml +++ b/configs/alt-diffusion-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/alt-diffusion-m18-inference.yaml b/configs/alt-diffusion-m18-inference.yaml index 41a031d55..c60dca8c7 100644 --- a/configs/alt-diffusion-m18-inference.yaml +++ b/configs/alt-diffusion-m18-inference.yaml @@ -41,7 +41,7 @@ model: use_linear_in_transformer: True transformer_depth: 1 context_dim: 1024 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml index 4e896879d..564e50ae2 100644 --- a/configs/instruct-pix2pix.yaml +++ b/configs/instruct-pix2pix.yaml @@ -45,7 +45,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/sd_xl_inpaint.yaml b/configs/sd_xl_inpaint.yaml index 3bad37218..f40f45e33 100644 --- a/configs/sd_xl_inpaint.yaml +++ b/configs/sd_xl_inpaint.yaml @@ -21,7 +21,7 @@ model: params: adm_in_channels: 2816 num_classes: sequential - use_checkpoint: True + use_checkpoint: False in_channels: 9 out_channels: 4 model_channels: 320 diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml index d4effe569..25c4d9ed0 100644 --- a/configs/v1-inference.yaml +++ b/configs/v1-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml index f9eec37d2..68c199f99 100644 --- a/configs/v1-inpainting-inference.yaml +++ b/configs/v1-inpainting-inference.yaml @@ -40,7 +40,7 @@ model: use_spatial_transformer: True transformer_depth: 1 context_dim: 768 - use_checkpoint: True + use_checkpoint: False legacy: False first_stage_config: diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 2604d969f..b2f05bbdc 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -4,16 +4,19 @@ import ldm.modules.attention import ldm.modules.diffusionmodules.openaimodel +# Setting flag=False so that torch skips checking parameters. +# parameters checking is expensive in frequent operations. + def BasicTransformerBlock_forward(self, x, context=None): - return checkpoint(self._forward, x, context) + return checkpoint(self._forward, x, context, flag=False) def AttentionBlock_forward(self, x): - return checkpoint(self._forward, x) + return checkpoint(self._forward, x, flag=False) def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) + return checkpoint(self._forward, x, emb, flag=False) stored = [] diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index b38137eb5..9cec4f13d 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -35,7 +35,7 @@ def is_using_v_parameterization_for_sd2(state_dict): with sd_disable_initialization.DisableInitialization(): unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( - use_checkpoint=True, + use_checkpoint=False, use_fp16=False, image_size=32, in_channels=4, From 58eec83a546b8d61500c7b801cb0bdbe7650f6a6 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 16 May 2024 16:39:02 -0400 Subject: [PATCH 2/3] Fully prevent use_checkpoint --- modules/sd_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index ff245b7a6..a33fa7c33 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -551,6 +551,11 @@ def repair_config(sd_config): karlo_path = os.path.join(paths.models_path, 'karlo') sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + # Do not use checkpoint for inference. + # This helps prevent extra performance overhead on checking parameters. + # The perf overhead is about 100ms/it on 4090. + sd_config.model.params.network_config.params.use_checkpoint = False + def rescale_zero_terminal_snr_abar(alphas_cumprod): alphas_bar_sqrt = alphas_cumprod.sqrt() From 47f1d42a7e77259e2e7418ae8f941718c55cfd25 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 16 May 2024 20:06:04 -0400 Subject: [PATCH 3/3] Fix for SD15 models --- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index a33fa7c33..cda142bdd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -553,8 +553,11 @@ def repair_config(sd_config): # Do not use checkpoint for inference. # This helps prevent extra performance overhead on checking parameters. - # The perf overhead is about 100ms/it on 4090. - sd_config.model.params.network_config.params.use_checkpoint = False + # The perf overhead is about 100ms/it on 4090 for SDXL. + if hasattr(sd_config.model.params, "network_config"): + sd_config.model.params.network_config.params.use_checkpoint = False + if hasattr(sd_config.model.params, "unet_config"): + sd_config.model.params.unet_config.params.use_checkpoint = False def rescale_zero_terminal_snr_abar(alphas_cumprod):