Distributed layers #1270
base: main
Conversation
I kind of like this design. It's all quite simple and easy to follow, and we have a lot of control over how to shard the model (as in ml-explore/mlx-examples#890). We could possibly find a way to reduce the code needed for adding a new custom linear-like layer, but the simplicity is nice and I wouldn't want to give that up.
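For reference, the kind of sharding discussed here (splitting a model's linear layers across ranks, as in the mlx-examples PR mentioned above) usually follows the tensor-parallel pattern sketched below. This is a from-scratch illustration using only plain `nn.Linear` and `mx.distributed.all_sum`; `ShardedMLP` is a made-up name and not one of the layers added in this PR.

```python
import mlx.core as mx
import mlx.nn as nn


class ShardedMLP(nn.Module):
    """Illustrative tensor-parallel MLP: column-shard the up projection,
    row-shard the down projection, then all-sum the partial outputs."""

    def __init__(self, dims: int, hidden_dims: int, group=None):
        super().__init__()
        self.group = group or mx.distributed.init()
        n = self.group.size()
        assert hidden_dims % n == 0, "hidden_dims must divide evenly across ranks"
        # Each rank holds only 1/n of the hidden dimension.
        self.up = nn.Linear(dims, hidden_dims // n)
        # No bias on the down projection: after the all-sum it would be added n times.
        self.down = nn.Linear(hidden_dims // n, dims, bias=False)

    def __call__(self, x):
        y = self.down(nn.gelu(self.up(x)))
        # Sum the partial results from every rank to recover the full output.
        return mx.distributed.all_sum(y, group=self.group)
```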
I am marking this ready for review. The main things that are new since I started the branch:
- Exposing [...]
- The sharding functions now also take a [...]
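The argument name is cut off above; assuming it is a distributed group (the usual case), a sharding helper might look roughly like the sketch below. `shard_linear` and its `kind` parameter are hypothetical names for illustration, not necessarily the API in this PR.

```python
import mlx.core as mx
import mlx.nn as nn


def shard_linear(linear: nn.Linear, kind: str, group: mx.distributed.Group) -> nn.Linear:
    """Hypothetical helper: slice an existing Linear's weights for this rank.

    kind == "all-to-sharded": split the output features (column parallel).
    kind == "sharded-to-all": split the input features (row parallel).
    """
    n, r = group.size(), group.rank()
    out_dims, in_dims = linear.weight.shape
    has_bias = "bias" in linear

    if kind == "all-to-sharded":
        step = out_dims // n
        sharded = nn.Linear(in_dims, step, bias=has_bias)
        sharded.weight = linear.weight[r * step : (r + 1) * step]
        if has_bias:
            sharded.bias = linear.bias[r * step : (r + 1) * step]
    else:  # "sharded-to-all"
        step = in_dims // n
        # Bias omitted: after an all-sum of partial outputs it must be added only once.
        sharded = nn.Linear(step, out_dims, bias=False)
        sharded.weight = linear.weight[:, r * step : (r + 1) * step]
    return sharded
```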
Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are [...]
Simply changing linear layers to the above results in a model that works out of the box with distributed inference and training.
I am starting it as a draft so that we can iterate a bit on the design. The main negative aspect of the above design is that we have yet another linear layer to think about when implementing, for instance, LoRA and friends or weird new quantizations. Perhaps it would be better to make the above layers wrap an internal linear layer, so that model surgery that swaps linear layers would still work out of the box.
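A rough sketch of that "internal linear layer" alternative, under the assumption that the wrapper simply forwards to a regular `nn.Linear` holding this rank's slice of the input features; `ShardedToAllWrapper` is a hypothetical name, not this PR's API.

```python
import mlx.core as mx
import mlx.nn as nn


class ShardedToAllWrapper(nn.Module):
    """Illustrative wrapper design: the sharded layer owns a plain nn.Linear,
    so passes that search for and swap nn.Linear modules (LoRA, quantization)
    keep working without knowing about sharding."""

    def __init__(self, input_dims: int, output_dims: int, group=None):
        super().__init__()
        self.group = group or mx.distributed.init()
        n = self.group.size()
        # Inner layer is a regular Linear over this rank's slice of the inputs.
        self.linear = nn.Linear(input_dims // n, output_dims, bias=False)

    def __call__(self, x):
        # Each rank multiplies against its slice; the partial outputs are summed.
        return mx.distributed.all_sum(self.linear(x), group=self.group)
```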