
fix: Allow for multi node training for accelerated moe #129

Merged
merged 1 commit into foundation-model-stack:main from mn-sharedmoe-final on Feb 27, 2025

Conversation

@kmehant (Collaborator) commented Feb 23, 2025

The current implementation uses the global rank of the process to build the device index, which does not work in a multi-node setting. We therefore need to use the local rank, since devices are not indexed continuously across nodes.
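
For context, a minimal sketch of the device-index issue (illustrative only; the helper name pick_device and the 2-node / 8-GPU figures are assumptions, not code from this repository). With torchrun, LOCAL_RANK is the rank of the process within its node, while torch.distributed.get_rank() is the global rank across all nodes, so only the former is a valid CUDA device index on every node:

    # Illustrative sketch: choosing a device index in a multi-node run.
    import os
    import torch
    import torch.distributed as dist

    def pick_device() -> torch.device:
        if not dist.is_initialized():
            return torch.device("cuda", 0)

        global_rank = dist.get_rank()               # e.g. 0..15 over 2 nodes x 8 GPUs
        local_rank = int(os.environ["LOCAL_RANK"])  # 0..7 on every node

        # Indexing devices by global_rank breaks on the second node, where
        # ranks 8..15 have no matching cuda:8..cuda:15 devices; local_rank
        # is always a valid device index on its node.
        return torch.device("cuda", local_rank)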

@kmehant changed the title from "Allow for multi node training for accelerated moe" to "fix: Allow for multi node training for accelerated moe" on Feb 23, 2025
@kmehant marked this pull request as ready for review on February 23, 2025 19:14
@kmehant requested a review from fabianlim as a code owner on February 23, 2025 19:14
@fabianlim (Contributor) commented:

@kmehant I understand the fix, but can you update the description for record-keeping purposes?

@fabianlim requested a review from willmj on February 24, 2025 02:55
@kmehant (Collaborator, Author) commented Feb 24, 2025

#129 (comment)

@fabianlim Apologies for missing that, I have added it.

@fabianlim (Contributor) left a review comment:

LGTM, but one suggestion.

@@ -65,7 +66,7 @@ def augmentation(
     rank, world_size = 0, 1
     if torch.distributed.is_initialized():
         world_size = torch.distributed.get_world_size()
-        rank = torch.distributed.get_rank()
+        rank = int(os.environ["LOCAL_RANK"])
@fabianlim (Contributor) commented on the change above:

Can we make it consistent and follow the new style?

Suggested change:

-        rank = int(os.environ["LOCAL_RANK"])
+        # we do not need to use the fallback as this is wrapped in an `is_initialized` block
+        rank = torch.distributed.get_node_local_rank()

@kmehant (Collaborator, Author) commented Feb 24, 2025:

@fabianlim I have included this suggestion, thanks.
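
For reference, a hedged sketch of how the block reads with the suggestion applied (paraphrased from the diff and suggestion above, not the exact file contents); torch.distributed.get_node_local_rank() is available in recent PyTorch releases and returns the rank of the process within its node:

    import torch

    rank, world_size = 0, 1
    if torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        # per the review suggestion: no fallback argument is needed because
        # this branch only runs inside the `is_initialized` check
        rank = torch.distributed.get_node_local_rank()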

@kmehant force-pushed the mn-sharedmoe-final branch 3 times, most recently from 1bb2f8c to 548b710 on February 24, 2025 06:43

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
@kmehant (Collaborator, Author) commented Feb 24, 2025

@fabianlim requesting your merge.

@fabianlim (Contributor) commented:

@kmehant let's have @willmj look at it first.

@fabianlim merged commit 791bdd9 into foundation-model-stack:main on Feb 27, 2025
7 checks passed