
Enable 2D sharding #17

Merged
merged 6 commits into llama2-google-next-training on Aug 1, 2023

Conversation

@alanwaketan alanwaketan commented Aug 1, 2023

Summary:
This pull request adds 2D SPMD sharding. It shards both weights and activations. Here is the sharding strategy.

Let's say we have a 2D mesh (data, model) and data x model == num_devices:

  1. input (data, None, model)
  2. embedding (model, data)
  3. attn QKV (data, model)
  4. attn O (model, data)
  5. mlp gate, up (model, data)
  6. mlp down (data, model)
  7. activation (data, None, model)

Currently, you can specify the model dimension with the new --spmd_2d_sharding option; the data dimension will then be auto-calculated (see the sketch below).

TODO: maybe we should add another option to control whether the activations/inputs are sharded at all, or sharded differently.
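
For reference, here is a minimal sketch of how the strategy above maps onto the xs.Mesh / xs.mark_sharding calls used in this PR's diff. It is illustrative only: the helper name, its arguments, and the parameter-name matching are assumptions rather than the actual change, and the input/activation sharding (items 1 and 7) is handled separately in the training script.

import numpy as np
import torch_xla.experimental.xla_sharding as xs

def apply_2d_weight_sharding(model, num_devices, model_axis_size):
    # model_axis_size plays the role of --spmd_2d_sharding; the data axis is
    # auto-calculated so that data * model == num_devices.
    mod = model_axis_size
    data = num_devices // mod
    device_ids = np.arange(num_devices)

    # Two views of the same device grid: (data, model) and (model, data).
    data_model_mesh = xs.Mesh(device_ids, (data, mod))
    model_data_mesh = xs.Mesh(device_ids, (mod, data))

    # The partition spec range(len(param.shape)) maps tensor dim i to mesh axis i.
    for name, param in model.named_parameters():
        if 'embed_tokens' in name:                                      # 2. embedding (model, data)
            xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
        elif 'q_proj' in name or 'k_proj' in name or 'v_proj' in name:  # 3. attn QKV (data, model)
            xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
        elif 'o_proj' in name:                                          # 4. attn O (model, data)
            xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
        elif 'gate_proj' in name or 'up_proj' in name:                  # 5. mlp gate, up (model, data)
            xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
        elif 'down_proj' in name:                                       # 6. mlp down (data, model)
            xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))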

@jonb377 jonb377 left a comment

LGTM, nice one Jiewen!

Comment on lines 540 to 541
data_model_mesh = xs.Mesh(device_ids, (data, mod))
model_data_mesh = xs.Mesh(device_ids, (mod, data))
@jonb377 jonb377 Aug 1, 2023

Can you try with HybridMesh? It should provide some performance gain, but I haven't actually benchmarked the difference. This applies both here and in modeling_llama.py.

@khatwanimohit may have some benchmarked differences on the simple shardings.py script
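
A hedged sketch of what the suggested switch might look like, assuming torch_xla's xs.HybridMesh and its ici_mesh_shape / dcn_mesh_shape / axis_names parameters; the axis names are illustrative, and on a single slice only the ICI shape is needed:

import torch_xla.experimental.xla_sharding as xs

# HybridMesh orders device IDs according to the physical topology, so mesh
# axes map onto ICI-connected neighbors where possible.
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod), axis_names=('data', 'model'))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data), axis_names=('model', 'data'))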

@alanwaketan alanwaketan (Collaborator, Author)

Let me do that. I always forget.

@alanwaketan alanwaketan (Collaborator, Author)

Fixed that too.

Comment on lines 554 to 557
elif 'gate_proj' in name or 'up_proj' in name:
    xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
elif 'down_proj' in name:
    xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
@jonb377 jonb377

Just for my understanding: I noticed that HF shards gate_proj and up_proj on the 0th dim and down_proj on the 1st dim, but here you're sharding gate and up on the data_model mesh, which places the model axis on dim 1.

Is this just a difference in 1- and 2-D sharding?
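
To make the dimension mapping concrete, a hedged aside: the weight shapes below are the standard HF LLaMA nn.Linear layouts of (out_features, in_features), and dim i goes to mesh axis i via the range(len(param.shape)) spec above.

# gate_proj / up_proj weight: (intermediate_size, hidden_size)
#   model_data_mesh (model, data) -> model axis on dim 0 (HF's 1D split dim)
#   data_model_mesh (data, model) -> model axis on dim 1
# down_proj weight: (hidden_size, intermediate_size)
#   data_model_mesh (data, model) -> model axis on dim 1 (HF's 1D split dim)
#   model_data_mesh (model, data) -> model axis on dim 0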

@alanwaketan alanwaketan (Collaborator, Author)

That's a good catch. I don't know; let me dig into it. I'm following the slides attached at the top of the spreadsheet.

@jonb377 jonb377

No worries! I was just curious; using the sharding from the slides makes sense.

@alanwaketan alanwaketan (Collaborator, Author)

Yeah, you are right. I have corrected the error.

@alanwaketan alanwaketan (Collaborator, Author)

Thanks, Jon, for approving the pull request.

@alanwaketan alanwaketan merged commit 813af25 into llama2-google-next-training Aug 1, 2023
alanwaketan added a commit that referenced this pull request Aug 2, 2023
Summary:
This pull request fixes a bug in #17 where 2D sharding for activations and inputs was not guarded.

Test Plan:
N/A.
alanwaketan added a commit that referenced this pull request Oct 27, 2023
alanwaketan added a commit that referenced this pull request Oct 27, 2023
yeounoh pushed a commit that referenced this pull request Mar 19, 2024
yeounoh pushed a commit that referenced this pull request Mar 19, 2024
vanbasten23 pushed a commit that referenced this pull request May 21, 2024
vanbasten23 pushed a commit that referenced this pull request May 21, 2024