
Distributed layers #1270

Open

angeloskath wants to merge 9 commits into main from distributed-layers
Conversation

angeloskath
Member

Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are:

  • float16/bfloat16 reductions for MPI
  • AllToShardedLinear and its quantized sibling
  • ShardedToAllLinear and its quantized sibling

Simply changing the linear layers to the above results in a model that works out of the box with distributed inference and training.
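
As an illustration, here is a minimal sketch of such a swap for a two-layer MLP. It assumes the new layers mirror nn.Linear's constructor with an optional distributed group argument; treat the exact signatures as placeholders rather than the final API.

```python
import mlx.core as mx
import mlx.nn as nn

# Sketch only: assumes AllToShardedLinear / ShardedToAllLinear take
# (input_dims, output_dims) like nn.Linear plus an optional `group`.
group = mx.distributed.init()

class ShardedMLP(nn.Module):
    def __init__(self, dims: int, hidden: int):
        super().__init__()
        # Replicated input -> hidden activations sharded across ranks.
        self.up = nn.AllToShardedLinear(dims, hidden, group=group)
        # Sharded hidden activations -> output all-reduced on every rank.
        self.down = nn.ShardedToAllLinear(hidden, dims, group=group)

    def __call__(self, x):
        return self.down(nn.relu(self.up(x)))
```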

I am starting it as a draft so that we can iterate a bit on the design. The downside of the above design is that we have yet more linear layers to think about when implementing LoRA and friends, or new quantizations, for instance. Perhaps it would be better to build these layers around an internal linear layer so that model surgery which swaps linear layers would still work out of the box.

@awni
Member

awni commented Jul 17, 2024

I kind of like this design. I like that it's all quite simple and easy to follow, and we have a lot of control over how to shard the model (as in ml-explore/mlx-examples#890). We could possibly find a way to reduce the code needed for adding a new custom linear-like layer... but the simplicity is nice, and I wouldn't want to give that up.

@angeloskath force-pushed the distributed-layers branch 2 times, most recently from 061d214 to b32ce2c on August 29, 2024 08:20
@angeloskath force-pushed the distributed-layers branch 2 times, most recently from ab26116 to 3d431c0 on September 6, 2024 18:03
@awni mentioned this pull request on Sep 16, 2024
@angeloskath force-pushed the distributed-layers branch 5 times, most recently from 2298954 to 1697581 on November 5, 2024 19:35
@awni force-pushed the distributed-layers branch 3 times, most recently from 31ba022 to 60e7e02 on January 18, 2025 14:06
@awni force-pushed the distributed-layers branch 2 times, most recently from 07b5bd5 to 794eb42 on February 6, 2025 15:36
@angeloskath force-pushed the distributed-layers branch 3 times, most recently from 517eb95 to a323642 on March 4, 2025 21:32
@angeloskath
Member Author

I am marking this ready for review. The main things that are new since I started the branch:

Exposing mx.contiguous. This ensures both that the array is contiguous and that it occupies at most x.size() * x.itemsize() + 16384 bytes. The main consequence is that a slice which is already contiguous, but views a larger buffer, is still going to be copied.
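
A short sketch of the intended effect (the copy behavior is as described above; size and itemsize below are the Python array properties corresponding to the expression given):

```python
import mlx.core as mx

x = mx.zeros((4096, 4096))   # ~64 MiB of float32
row = x[0]                   # contiguous slice, but it may keep the full buffer alive
row = mx.contiguous(row)     # copied so it occupies at most
                             # row.size * row.itemsize + 16384 bytes
mx.eval(row)
```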

shard_linear convenience function and shard_inplace. The first one simply creates the appropriate linear layer, quantized or not. The second one actually shards the parameters in place, which allows us to shard any layer and apply the collective operations as we see fit. It is used, for instance, to shard the single-stream transformer blocks in FLUX while performing only one communication (ml-explore/mlx-examples#1325).
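
A hedged sketch of how the two helpers might be used; the sharding-direction strings, keyword names, and import location below are assumptions based on this description, not the confirmed API:

```python
import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()

class Attention(nn.Module):
    def __init__(self, dims: int):
        super().__init__()
        self.q_proj = nn.Linear(dims, dims)
        self.o_proj = nn.Linear(dims, dims)

attn = Attention(1024)

# Swap each linear for the matching sharded (or quantized sharded) layer ...
attn.q_proj = nn.shard_linear(attn.q_proj, "all-to-sharded", group=group)
attn.o_proj = nn.shard_linear(attn.o_proj, "sharded-to-all", group=group)

# ... or keep the layer types and only shard the parameters in place, applying
# the collectives (e.g. a single mx.distributed.all_sum) manually where we want
# just one communication per block:
# nn.shard_inplace(attn.q_proj, "all-to-sharded", group=group)
# nn.shard_inplace(attn.o_proj, "sharded-to-all", group=group)
```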

The sharding functions now also take a groups argument. This assumes the linear layer is a fused one and splits it according to the groups argument (evenly or percentage-wise). I think the argument name may need improving here.
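
For example (again a sketch, with the keyword name taken from the paragraph above), a fused QKV projection could be split into three equal shards:

```python
import mlx.nn as nn

# A fused QKV projection stored as a single linear layer.
qkv_proj = nn.Linear(1024, 3 * 1024)

# Shard it as three equal sub-matrices so Q, K and V are each partitioned
# correctly across ranks; a list of fractions could express an uneven split.
qkv_proj = nn.shard_linear(qkv_proj, "all-to-sharded", groups=3)
```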

@angeloskath marked this pull request as ready for review on March 6, 2025 23:29