From 2643c74ef61f124948b06c9f8b692238bffe5c55 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 04:36:40 -0800 Subject: [PATCH 1/4] autoencoder_dc tiling --- .../models/autoencoders/autoencoder_dc.py | 107 +++++++++++++++++- 1 file changed, 101 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 109e37c23e1b..89465d704cb0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -479,12 +479,15 @@ def __init__( self.use_tiling = False # The minimal tile height and width for spatial tiling to be used - self.tile_sample_min_height = 512 - self.tile_sample_min_width = 512 + self.tile_sample_min_height = 1024 + self.tile_sample_min_width = 1024 # The minimal distance between two spatial tiles - self.tile_sample_stride_height = 448 - self.tile_sample_stride_width = 448 + self.tile_sample_stride_height = 896 + self.tile_sample_stride_width = 896 + + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio def enable_tiling( self, @@ -515,6 +518,8 @@ def enable_tiling( self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio def disable_tiling(self) -> None: r""" @@ -606,11 +611,101 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor: - raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, x.shape[2], self.tile_sample_stride_height): + row = [] + for j in range(0, x.shape[3], self.tile_sample_stride_width): + tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0: + tile = F.pad(tile, (0, (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio, 0, (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio)) + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + + if not return_dict: + return (encoded,) + return EncoderOutput(latent=encoded) def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = z.shape + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + decoded = torch.cat(result_rows, dim=2) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor: encoded = self.encode(sample, return_dict=False)[0] From c0b1ca5300f964fcb5d91e856a8e7b041400087d Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:06:54 -0800 Subject: [PATCH 2/4] add tiling and slicing support in SANA pipelines --- src/diffusers/pipelines/sana/pipeline_sana.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index afc2f74c9e8f..8b318597c12d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -218,6 +218,35 @@ def __init__( ) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]], From d80dea51e46320d93bb17e779b26d320d60eb665 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:12:48 -0800 Subject: [PATCH 3/4] create variables for padding length because the line becomes too long --- src/diffusers/models/autoencoders/autoencoder_dc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 89465d704cb0..7f6aeef4d8fd 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -643,7 +643,9 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso for j in range(0, x.shape[3], self.tile_sample_stride_width): tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] if tile.shape[2] % self.spatial_compression_ratio != 0 or tile.shape[3] % self.spatial_compression_ratio != 0: - tile = F.pad(tile, (0, (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio, 0, (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio)) + pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio + pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio + tile = F.pad(tile, (0, pad_w, 0, pad_h)) tile = self.encoder(tile) row.append(tile) rows.append(row) From c3b9a8ef81553a78763482ad7b82e5ced4e489d9 Mon Sep 17 00:00:00 2001 From: Junyu Chen Date: Thu, 9 Jan 2025 19:35:23 -0800 Subject: [PATCH 4/4] add tiling and slicing support in pag SANA pipelines --- .../pipelines/pag/pipeline_pag_sana.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index f363a1a557bc..2cdc1c70cdcc 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -183,6 +183,35 @@ def __init__( pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), ) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]],