From a067f2109d75485967c16083046403e1b51d1ccb Mon Sep 17 00:00:00 2001 From: guoty Date: Tue, 24 Sep 2024 01:50:43 +0800 Subject: [PATCH] Fix batch dimension --- distvae/modules/patch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distvae/modules/patch_utils.py b/distvae/modules/patch_utils.py index ecc9c01..376348a 100644 --- a/distvae/modules/patch_utils.py +++ b/distvae/modules/patch_utils.py @@ -34,7 +34,7 @@ def forward(self, patch_hidden_state): ) patch_hidden_state_list = [ torch.empty( - [1, patch_hidden_state.shape[1], patch_height_list[i].item(), patch_hidden_state.shape[-1]], + [patch_hidden_state.shape[0], patch_hidden_state.shape[1], patch_height_list[i].item(), patch_hidden_state.shape[-1]], dtype=patch_hidden_state.dtype, device=f"cuda:{self.rank}" ) for i in range(self.world_size)