fix: Allow for multi node training for accelerated moe #129
Conversation
@kmehant I understand the fix, but can you update the description for record-keeping purposes?
@fabianlim
LGTM but one suggestion
@@ -65,7 +66,7 @@ def augmentation(
     rank, world_size = 0, 1
     if torch.distributed.is_initialized():
         world_size = torch.distributed.get_world_size()
-        rank = torch.distributed.get_rank()
+        rank = int(os.environ["LOCAL_RANK"])
Can we make it consistent and follow the new style?
-rank = int(os.environ["LOCAL_RANK"])
+# we do not need to use the fallback as this is wrapped in an `is_initialized` block
+rank = torch.distributed.get_node_local_rank()
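For context, `torch.distributed.get_node_local_rank()` (available in recent PyTorch releases) resolves the node-local rank from the `LOCAL_RANK` environment variable that launchers such as `torchrun` export, with an optional fallback for runs outside a launcher. A minimal pure-Python sketch of that env-var lookup (an illustration of the semantics, not the actual PyTorch implementation; the `node_local_rank` helper name is hypothetical):

```python
import os

def node_local_rank(fallback_rank=None):
    # Launchers like torchrun export LOCAL_RANK for every worker
    # process on a node; prefer it when present.
    if "LOCAL_RANK" in os.environ:
        return int(os.environ["LOCAL_RANK"])
    # Outside a launcher, fall back to an explicitly supplied rank.
    if fallback_rank is not None:
        return int(fallback_rank)
    raise RuntimeError("LOCAL_RANK is not set and no fallback_rank was given")

# Simulate a worker launched as the 4th process on its node.
os.environ["LOCAL_RANK"] = "3"
print(node_local_rank())  # 3
```

In the patched code above the call sits inside an `is_initialized()` block, which is why the review comment notes that no fallback is needed there.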
@fabianlim Have included this suggestion, thanks.
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
@fabianlim Requesting your merge.
The current implementation uses the global rank of the process to compute the device index, which does not work in a multi-node setting: device indices are local to each node (they restart from zero on every node) rather than continuous across nodes, so a global rank can exceed the number of devices on a node. We therefore need to use the local rank instead.
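To make the failure mode concrete, here is a small sketch under a hypothetical topology of 2 nodes with 8 GPUs each (the helper names and topology are assumptions for illustration, not part of the patch):

```python
# Hypothetical topology: 2 nodes, 8 GPUs per node. Device indices on
# every node run 0..7, so they are NOT continuous across nodes.
GPUS_PER_NODE = 8

def device_index_from_global_rank(global_rank):
    # Buggy scheme: use the global rank directly as the device index.
    return global_rank

def device_index_from_local_rank(global_rank):
    # Fixed scheme: derive the node-local rank, which a launcher would
    # normally expose to each worker via LOCAL_RANK.
    return global_rank % GPUS_PER_NODE

# A process on the second node has global rank 11:
print(device_index_from_global_rank(11))  # 11 -> invalid, node only has cuda:0..cuda:7
print(device_index_from_local_rank(11))   # 3  -> valid, cuda:3 on that node
```

On the first node the two schemes coincide (global rank equals local rank), which is why the bug only surfaces once training spans more than one node.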