Enable 2D sharding #17
Conversation
LGTM, nice one Jiewen!
data_model_mesh = xs.Mesh(device_ids, (data, mod))
model_data_mesh = xs.Mesh(device_ids, (mod, data))
Can you try with HybridMesh? It should provide some performance gain, but I haven't actually benchmarked the difference. This applies both here and in modeling_llama.py. @khatwanimohit may have some benchmarked differences from the simple shardings.py script.
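For reference, a minimal sketch of what the suggested switch might look like, assuming the HybridMesh class from torch_xla's SPMD module; the import path and keyword names vary across torch_xla releases, so treat this as illustrative rather than the PR's implementation:

```python
import torch_xla.distributed.spmd as xs  # older releases: torch_xla.experimental.xla_sharding

data, mod = 4, 2  # hypothetical split of 8 devices into (data, model)

# HybridMesh arranges the mesh axes along the physical ICI topology, which is
# where the expected speedup over a plain logical xs.Mesh would come from.
data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, mod))
model_data_mesh = xs.HybridMesh(ici_mesh_shape=(mod, data))
```

Since HybridMesh subclasses Mesh, the existing xs.mark_sharding calls should work unchanged.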
Let me do that. I always forget.
Fixed that too.
elif 'gate_proj' in name or 'up_proj' in name:
    xs.mark_sharding(param, data_model_mesh, range(len(param.shape)))
elif 'down_proj' in name:
    xs.mark_sharding(param, model_data_mesh, range(len(param.shape)))
Just for my understanding: I noticed that HF shards gate_proj and up_proj on the 0th dim and down_proj on the 1st dim, but here you're sharding gate and up on the data_model mesh, which places the model axis on dim 1. Is this just a difference in 1D and 2D sharding?
That's a good catch. I don't know; let me dig into it. I'm following the slides attached at the top of the spreadsheet.
No worries! I was just curious; using the sharding from the slides makes sense.
Yea, you are right. I have corrected the error.
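For readers following along, here is a small illustrative sketch (not code from this PR) of the weight-layout convention the question above refers to. The hidden and intermediate sizes are assumed, LLaMA-7B-like values; the point is only that torch.nn.Linear stores its weight as (out_features, in_features), so the column-parallel split HF uses for gate_proj/up_proj shards dim 0, while the row-parallel split for down_proj shards dim 1:

```python
import torch.nn as nn

hidden, intermediate = 4096, 11008  # assumed sizes, for illustration only

gate_proj = nn.Linear(hidden, intermediate, bias=False)
down_proj = nn.Linear(intermediate, hidden, bias=False)

# nn.Linear weights are stored as (out_features, in_features):
print(gate_proj.weight.shape)  # torch.Size([11008, 4096]) -> model axis on dim 0
print(down_proj.weight.shape)  # torch.Size([4096, 11008]) -> model axis on dim 1

# So with a 2D (data, model) device mesh, putting the model axis on dim 0 of
# gate_proj/up_proj and on dim 1 of down_proj mirrors HF's 1D column-/row-
# parallel convention, which is what the thread above converged on.
```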
Thanks, Jon, for approving the pull request.
Summary: This pull request fixes a bug in #17 where it forgot to guard 2D sharding for activations and inputs. Test Plan: N/A.
Summary:
This pull request adds 2D SPMD sharding to the table. It will shard both weights and activations. Here is the sharding strategy.
Let's say we have a 2D mesh (data, model) and data x model == num_devices:
1. input (data, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data, None, model)
Currently you can specify the model dimension using a new option, --spmd_2d_sharding; the data dimension will then be auto-calculated.
TODO: maybe we should add another option to specify whether we should shard the activations/inputs at all, or shard them differently.
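To make the strategy above concrete, here is a minimal sketch against torch_xla's SPMD API (xs.Mesh and xs.mark_sharding). The device count (8), the tensor shapes, and the --spmd_2d_sharding value of 2 are illustrative assumptions rather than values from this repository, and the import paths vary between torch_xla releases:

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs  # older releases: torch_xla.experimental.xla_sharding

xr.use_spmd()                    # enable SPMD execution mode (recent torch_xla releases)

num_devices = 8                  # assumed; normally queried from the runtime
model = 2                        # value passed via --spmd_2d_sharding (assumed)
data = num_devices // model      # data dimension is auto-calculated

device_ids = np.arange(num_devices)
data_model_mesh = xs.Mesh(device_ids, (data, model))  # axes: (data, model)
model_data_mesh = xs.Mesh(device_ids, (model, data))  # axes: (model, data)

device = xm.xla_device()

# Weights: 2D tensors are sharded across both mesh axes.
embedding = torch.empty(32000, 4096, device=device)
xs.mark_sharding(embedding, model_data_mesh, (0, 1))         # embedding: (model, data)

qkv_weight = torch.empty(4096, 4096, device=device)
xs.mark_sharding(qkv_weight, data_model_mesh, (0, 1))        # attn QKV: (data, model)

down_weight = torch.empty(4096, 11008, device=device)
xs.mark_sharding(down_weight, data_model_mesh, (0, 1))       # mlp down: (data, model)

# Inputs/activations: 3D tensors keep the middle (sequence) dim replicated.
activation = torch.empty(16, 2048, 4096, device=device)
xs.mark_sharding(activation, data_model_mesh, (0, None, 1))  # (data, None, model)
```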