[Feature] Implement tiled VAE encoding/decoding for Wan model. #11414
@@ -677,42 +677,7 @@ def __init__(
        attn_scales: List[float] = [],
        temperal_downsample: List[bool] = [False, True, True],
        dropout: float = 0.0,
        latents_mean: List[float] = [

Review comment: Indeed, it might not be used in the diffusers codebase at the moment, but it is being used downstream in a few repositories (example). Removing this will break downstream, so we should keep it anyway. What you could do instead to reduce LOC is wrap these two parameters in a non-format block and condense each list onto a single line (see the sketch after this hunk).

            -0.7571,
            -0.7089,
            -0.9113,
            0.1075,
            -0.1745,
            0.9653,
            -0.1517,
            1.5508,
            0.4134,
            -0.0715,
            0.5517,
            -0.3632,
            -0.1922,
            -0.9497,
            0.2503,
            -0.2921,
        ],
        latents_std: List[float] = [
            2.8184,
            1.4541,
            2.3275,
            2.6558,
            1.2196,
            1.7708,
            2.6052,
            2.0743,
            3.2687,
            2.1526,
            2.8652,
            1.5579,
            1.6382,
            1.1253,
            2.8251,
            1.9160,
        ],
        spatial_compression_ratio: int = 8,

Review comment: Let's not add a new parameter now, because it will lead to an unnecessary config warning.

    ) -> None:
        super().__init__()

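A minimal sketch of the condensed form the reviewer is suggesting, using the default values already present in the original signature. The exact formatting is a judgment call; wrapping the two defaults in a `# fmt: off` block keeps the formatter from re-expanding the lists:

```python
# Hypothetical condensed defaults inside the __init__ signature (values copied from the original code).
# fmt: off
latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
                             0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
                            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
# fmt: on
```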
@@ -730,6 +695,58 @@ def __init__(
            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
        )

        self.spatial_compression_ratio = spatial_compression_ratio

Review comment: Instead, let's set this attribute based on the init parameters; the same logic as used in the pipeline can be applied here. Let's also add temporal_compression_ratio, computed as 2 raised to the power of the number of True values in temperal_downsample (see the sketch below).

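A hedged sketch of that derivation, assuming the same convention the Wan pipeline uses (this is a suggestion matching the review comment, not the final code from the PR):

```python
temperal_downsample = [False, True, True]  # default from the signature above

# Derive both ratios from the init parameters instead of accepting a new config argument.
spatial_compression_ratio = 2 ** len(temperal_downsample)   # 2 ** 3 = 8
temporal_compression_ratio = 2 ** sum(temperal_downsample)  # 2 ** 2 = 4 (each True counts as one halving)
print(spatial_compression_ratio, temporal_compression_ratio)  # 8 4
```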
        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
        # intermediate tiles together, the memory requirement can be lowered.
        self.use_tiling = False

Review comment: Let's also add use_slicing (a sketch of where the flag and its toggles could live follows).

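A sketch of the requested slicing support, following the pattern other diffusers VAEs use (the batch is split and each slice is processed separately). The class name below is only an illustration of where the pieces would slot in and is not part of this diff:

```python
class _SlicingMixinSketch:
    """Illustration only: where the use_slicing flag and its toggles would live on the VAE class."""

    def __init__(self) -> None:
        self.use_slicing = False  # off by default, like use_tiling

    def enable_slicing(self) -> None:
        r"""Enable sliced VAE encoding/decoding: split the input batch and process one slice at a time."""
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""Disable sliced VAE encoding/decoding and process the whole batch in one forward pass."""
        self.use_slicing = False
```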
        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192

    def enable_tiling(
        self,
        tile_sample_min_height: Optional[int] = None,
        tile_sample_min_width: Optional[int] = None,
        tile_sample_stride_height: Optional[float] = None,
        tile_sample_stride_width: Optional[float] = None,
    ) -> None:
        r"""
        Enable tiled VAE encoding/decoding. When this option is enabled, the VAE will split the input tensor into
        tiles to compute encoding and decoding in several steps. This is useful for saving a large amount of memory
        and allows processing larger videos.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The stride between two consecutive vertical tiles. This is to ensure that there are no tiling
                artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        """
        self.use_tiling = True
        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_tiling = False

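For context, this is roughly how a user would turn the feature on once the PR lands. The checkpoint id is an assumption (the Wan 2.1 T2V 1.3B diffusers repo) and the prompt/resolution are only illustrative:

```python
import torch

from diffusers import AutoencoderKLWan, WanPipeline

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"  # assumed checkpoint id, for illustration only
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
vae.enable_tiling()  # defaults: 256x256 tiles with a 192-pixel stride, i.e. 64 pixels of overlap

pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
video = pipe("a corgi surfing a wave", num_frames=81, height=480, width=832).frames[0]
```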
    def clear_cache(self):
        def _count_conv3d(model):
            count = 0

@@ -785,7 +802,11 @@ def encode(
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        h = self._encode(x)
        _, _, _, height, width = x.shape

Review comment: Let's make sure to support use_slicing here as well (see the sketch after this hunk).

        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            h = self.tiled_encode(x)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)
        if not return_dict:
            return (posterior,)

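A sketch of how slicing could be wired into this branch, following the pattern used by other diffusers VAEs. This is a fragment that would replace the branch shown above, not code from the PR:

```python
# Sketch only: when slicing is enabled, encode each batch element in its own forward pass
# and concatenate the resulting moments; otherwise fall back to the tiling/non-tiling branch.
if self.use_slicing and x.shape[0] > 1:
    encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
    h = torch.cat(encoded_slices)
elif self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
    h = self.tiled_encode(x)
else:
    h = self._encode(x)
```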
@@ -826,12 +847,170 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        decoded = self._decode(z).sample
        _, _, _, height, width = z.shape
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

        if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
            decoded = self.tiled_decode(z).sample
        else:
            decoded = self._decode(z).sample
        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
        for x in range(blend_extent):
            b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
                x / blend_extent
            )
        return b
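These three helpers implement the same linear crossfade along different axes: within the overlap region, the previous tile's weight ramps down from 1 to 0 while the current tile's ramps up from 0 to 1. A tiny standalone illustration of those weights:

```python
import torch

def crossfade_weights(blend_extent: int) -> torch.Tensor:
    # Weight given to the *current* tile at each position inside the overlap;
    # the previous tile receives (1 - w), so the seam fades linearly from one tile to the other.
    return torch.arange(blend_extent) / blend_extent

print(crossfade_weights(4))  # tensor([0.0000, 0.2500, 0.5000, 0.7500])
```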

    def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                self.clear_cache()
                time = []
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

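To make the loop structure concrete, here is a small, model-free arithmetic check of the spatial tile grid and the temporal chunking used above, assuming the default 256-pixel tiles with a 192-pixel stride and a 480x640 input with 9 frames:

```python
height, width, num_frames = 480, 640, 9
tile, stride = 256, 192  # default tile_sample_min_* and tile_sample_stride_*

# Spatial grid: tiles start every `stride` pixels and extend `tile` pixels (clamped at the border).
rows = [(i, min(i + tile, height)) for i in range(0, height, stride)]
cols = [(j, min(j + tile, width)) for j in range(0, width, stride)]
print(rows)  # [(0, 256), (192, 448), (384, 480)]
print(cols)  # [(0, 256), (192, 448), (384, 576), (576, 640)]

# Temporal chunking: the first frame is encoded alone, then frames are processed in groups of 4.
chunks = [(0, 1)] + [(1 + 4 * k, 1 + 4 * (k + 1)) for k in range((num_frames - 1) // 4)]
print(chunks)  # [(0, 1), (1, 5), (5, 9)] -> 1 + (9 - 1) // 4 = 3 passes per tile
```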
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

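Note that `tiled_decode` walks the latent grid with latent-space strides but crops each decoded tile in pixel space, so the two coordinate systems are related only through the compression ratio. A quick arithmetic sketch with the default settings, assuming a spatial compression ratio of 8:

```python
spatial_compression_ratio = 8
tile_sample_min, tile_sample_stride = 256, 192  # pixel-space tile size and stride

tile_latent_min = tile_sample_min // spatial_compression_ratio        # 32 latent rows/cols per tile
tile_latent_stride = tile_sample_stride // spatial_compression_ratio  # 24 latent rows/cols between tile origins
overlap_px = tile_sample_min - tile_sample_stride                     # 64 pixels blended between neighbouring tiles

print(tile_latent_min, tile_latent_stride, overlap_px)  # 32 24 64
```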
    def forward(
        self,
        sample: torch.Tensor,
@@ -15,6 +15,8 @@

import unittest

import torch

from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device

@@ -44,9 +46,16 @@ def dummy_input(self):
        num_frames = 9
        num_channels = 3
        sizes = (16, 16)

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        return {"sample": image}

    @property
    def dummy_input_tiling(self):
        batch_size = 2
        num_frames = 9
        num_channels = 3
        sizes = (640, 480)

Review comment: Add another input because the (16, 16) tensor is too small for tiling operations.

Review comment: Let's try to reduce the size as much as possible, because these tests should not cause unexpected slowdowns in the CI. While enabling tiling, you can set a different tile width/height and stride than the default 256 and 192. (128, 128) would be good, with tile size being … (a sketch follows after this block).

        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
        return {"sample": image}

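A hedged sketch of that suggestion; the concrete sizes below (a 128x128 input, 96-pixel tiles with a 64-pixel stride) are hypothetical and only illustrate shrinking both the input and the tile configuration so the CI test stays fast:

```python
@property
def dummy_input_tiling(self):
    batch_size = 2
    num_frames = 9
    num_channels = 3
    sizes = (128, 128)  # much smaller than (640, 480), but still larger than the tile size below

    image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
    return {"sample": image}

# ...and inside the tiling test, pass explicit (hypothetical) tile settings instead of the defaults:
model.enable_tiling(
    tile_sample_min_height=96,
    tile_sample_min_width=96,
    tile_sample_stride_height=64,
    tile_sample_stride_width=64,
)
```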
    @property

@@ -62,6 +71,42 @@ def prepare_init_args_and_inputs_for_common(self):
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def prepare_init_args_and_inputs_for_tiling(self):
        init_dict = self.get_autoencoder_kl_wan_config()
        inputs_dict = self.dummy_input_tiling
        return init_dict, inputs_dict

    def test_enable_disable_tiling(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})

        torch.manual_seed(0)
        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        torch.manual_seed(0)
        model.enable_tiling()
        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]

        self.assertLess(
            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
            0.5,

Review comment: On my machine, this value is approximately 0.404, and IIRC the average absolute value of these arrays is less than 0.01, which makes me reasonably confident that the implementation is correct.

"VAE tiling should not affect the inference results", | ||
) | ||
|
||
torch.manual_seed(0) | ||
model.disable_tiling() | ||
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0] | ||
|
||
self.assertEqual( | ||
output_without_tiling.detach().cpu().numpy().all(), | ||
output_without_tiling_2.detach().cpu().numpy().all(), | ||
"Without tiling outputs should match with the outputs when tiling is manually disabled.", | ||
) | ||
|
||
    @unittest.skip("Gradient checkpointing has not been implemented yet")
    def test_gradient_checkpointing_is_applied(self):
        pass

Review comment: These function parameters are not being used. They have been removed in this patch but can be added back at any time if needed.