From ed7821d50a82efcce81abc069be49a55cdb7ed4a Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Sun, 23 Feb 2025 23:51:14 +0530
Subject: [PATCH] fix: compute device correctly

Signed-off-by: Mehant Kammakomati
Signed-off-by: Yu Chin Fabian Lim
Signed-off-by: Mehant Kammakomati
---
 .../src/fms_acceleration_moe/framework_plugin_scattermoe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
index 2ad26ed2..3a4a3c27 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py
@@ -69,7 +69,8 @@ def augmentation(
         rank, world_size = 0, 1
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
-            rank = torch.distributed.get_rank()
+            # we do not need to use the fallback as this is wrapped in an `is_initialized` block
+            rank = torch.distributed.get_node_local_rank()
 
         if not hasattr(model.config, "name_or_path") or not model.config.name_or_path:
             raise ValueError(
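
Reviewer note on the change above: `torch.distributed.get_rank()` returns the
process's global rank across all nodes, while
`torch.distributed.get_node_local_rank()` (available in PyTorch 2.4+) returns
its rank within the current node, which is the correct index for selecting a
local accelerator device. A minimal sketch of the difference, assuming a
multi-node NCCL job launched with torchrun (the ranks in the comments are
illustrative, not from this patch):

    # sketch only: global vs. node-local rank for device selection
    import torch
    import torch.distributed as dist

    dist.init_process_group("nccl")

    global_rank = dist.get_rank()            # e.g. 0..15 on 2 nodes x 8 GPUs
    local_rank = dist.get_node_local_rank()  # e.g. 0..7 on each node

    # Each node only exposes device indices 0..(gpus_per_node - 1), so using
    # the global rank would index out of range on every node after the first.
    torch.cuda.set_device(local_rank)

The in-diff comment refers to `get_node_local_rank`'s optional `fallback_rank`
argument, which is only needed when the LOCAL_RANK environment variable may be
unset; launchers such as torchrun set it for every worker process.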