Adding support for the Muon Optimizer #1914
base: main
Conversation
Used a basic 2-layer MLP with a dummy dataset:

[figure: training results on the dummy dataset]

More trainings will come!

LLM SFT finetuning:

[figures: results with Muon and with Adam]
That is definitely interesting but I think https://github.com/stockeh/mlx-optimizers may be a more suitable repository. Wdyt?
@Goekdeniz-Guelmez 🔥 Perfect timing! The Muon optimizer just dropped, and now it’s already in MLX!!!
@Goekdeniz-Guelmez @angeloskath yes, we have Muon already: https://github.com/stockeh/mlx-optimizers/blob/main/mlx_optimizers/muon.py, though I do believe Keller Jordan has made some minor updates since.
@stockeh I didn't know the optimizer repo existed :D. But yeah, there are some differences with the new one. The new version maintains the same mathematical principles but extends support to higher-dimensional tensors like conv filters through reshaping rather than using a separate optimizer. It also improves efficiency with a streamlined Newton-Schulz iteration formula and applies weight decay earlier in the optimization process. The code now handles non-2D parameters more consistently and uses generalized transpose and normalization logic, so it works with tensors of any dimensionality.
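To make the reshaping trick concrete, here is a minimal sketch of the orthogonalization step, assuming Keller Jordan's published quintic coefficients; the helper names (`newton_schulz`, `orthogonalize_update`) are illustrative, not the functions in this PR:

```python
import mlx.core as mx

def newton_schulz(G: mx.array, steps: int = 5, eps: float = 1e-7) -> mx.array:
    """Approximately orthogonalize a 2-D update with a quintic Newton-Schulz
    iteration (coefficients from Keller Jordan's Muon write-up)."""
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.astype(mx.float32)
    transposed = X.shape[0] > X.shape[1]
    if transposed:                        # iterate in the wide orientation
        X = X.T
    X = X / (mx.linalg.norm(X) + eps)     # keep the spectral norm <= 1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X

def orthogonalize_update(update: mx.array) -> mx.array:
    """Fold >2-D tensors (e.g. conv filters) into matrices by flattening the
    trailing dims, orthogonalize, then restore the original shape."""
    if update.ndim == 2:
        return newton_schulz(update)
    flat = update.reshape(update.shape[0], -1)
    return newton_schulz(flat).reshape(update.shape)
```

The reshaping keeps conv filters in the same Muon path instead of routing them to a separate optimizer, which is the behavior described above.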
Hi @stockeh, we recently worked on Muon and released the Moonlight model, see https://github.com/MoonshotAI/Moonlight/tree/master. We had some empirical observations for Muon to scale (which we did not see in the current implementation), and hope you do not mind me sharing them here:
The implementation is easy; see an example here: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L197-L203 These suggestions were empirically helpful for over-training, as observed during our pretraining of Moonlight. What are you guys' opinions?
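Roughly, as I understand the linked example, the adjustments amount to scaling each orthogonalized update so its RMS is comparable to a typical AdamW update, plus decoupled weight decay. A minimal sketch; the 0.2 constant and the function name are assumptions taken from the Moonlight write-up, not code from this PR:

```python
import mlx.core as mx

def moonlight_style_update(param: mx.array, ortho_update: mx.array,
                           lr: float, weight_decay: float = 0.1) -> mx.array:
    # Scale so the update RMS roughly matches AdamW's; the 0.2 factor is the
    # constant reported in the Moonlight write-up (an assumption here).
    fan_out, fan_in = ortho_update.shape
    rms_match = 0.2 * max(fan_out, fan_in) ** 0.5
    # Decoupled (AdamW-style) weight decay applied to the parameter directly.
    decayed = param * (1 - lr * weight_decay)
    return decayed - lr * rms_match * ortho_update
```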
@toothacher17 Wow! The Moonlight team just popped in with actual scaling tips! 🚀 Love seeing them share those crucial details about weight decay and matrix shape adjustments. This is what makes open source so awesome - experts freely sharing knowledge that turns theory into production-ready code. MLX bringing ML minds together at its finest!
@lin72h @toothacher17 I agree, I did not see that coming, but it's very welcome.
Thanks! Our team, Moonshot AI, believes that Muon is scalable, and we did some empirical experiments! In case you guys are interested, we have a tech report discussing it: https://arxiv.org/abs/2502.16982
@toothacher17 Just read the Moonshot paper - it's on par with the K1 paper and even more innovative than DeepSeek's work! You folks at Moonshot haven't gotten the attention you deserve - your work just got overshadowed by DeepSeek's timing. The ideas in Moonlight are truly incredible. Open-sourcing this level of innovation is something to be genuinely proud of.
@lin72h
@toothacher17 Keep up the awesome work! Moonshot rocks! |
@toothacher17 I appreciate you sharing your insights! I found the paper to be quite informative. I think most of these changes are easy enough to implement.
Can you say more about that?
I guess that's because, for now, AdamW is chained with Muon to handle those non-matrix parameters, e.g. embedding, lm head, and rmsnorm gamma. In the future, there might be a chance to get rid of AdamW and only use Muon purely, for example: https://github.com/modula-systems/modula It's not proven at large scale yet, but it might be promising.
@awni tldr: I don't think anything has to change with mlx, specifically, but I may change mlx-optimizers' Muon class to not include AdamW and simplify the delegation logic with a separate optim. I originally said this when thinking about how we pass params to the optimizer, e.g., in KellerJordan/Muon:

```python
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
adamw_params = ([p for p in model.body.parameters() if p.ndim < 2]
                + [*model.head.parameters(), *model.embed.parameters()])
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]
...
# in the training step
for opt in optimizers:
    opt.step()
```

Moonlight's implementation differs in that their custom Muon class accepts both sets of parameters. But I thought about this some more and think it's easier as a general approach to just define multiple optimizers as we've discussed in this thread, i.e.:

```python
from functools import partial

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

def split_grads(grads):
    grads = tree_flatten(grads)
    weights = [(k, v) for k, v in grads if v.ndim == 2]
    biases = [(k, v) for k, v in grads if v.ndim == 1]
    weights = tree_unflatten(weights)
    biases = tree_unflatten(biases)
    return weights, biases

@partial(mx.compile, inputs=state, outputs=state)
def step(X, T):
    train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
    loss, grads = train_step_fn(X, T)
    weights, biases = split_grads(grads)
    self.optimizers[0].update(self.model, weights)
    self.optimizers[1].update(self.model, biases)
    return loss
```

This would just require a bit of a refactor and description for using Muon in mlx-optimizers, should the optims be separate.
Thanks for the detailed explanation, that makes sense!
Proposed changes
First contribution to the MLX repo. Add the Muon optimizer to MLX's optimizer suite. Muon (MomentUm Orthogonalized by Newton-schulz) is a novel optimizer that combines momentum-based SGD with orthogonalization of parameter updates via Newton-Schulz iterations. This optimizer has shown promising results for training neural networks, particularly for convolutional and transformer architectures.
The implementation follows the approach described in https://kellerjordan.github.io/posts/muon/, adapting it to the MLX framework. The optimizer performs a standard SGD-momentum update, followed by an orthogonalization step that replaces each 2D parameter's update with the nearest orthogonal matrix using an efficient Newton-Schulz iteration.
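For illustration, a minimal sketch of how the optimizer could be used once merged, assuming a constructor that mirrors the other MLX optimizers (the exact `Muon` signature here is an assumption):

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# A toy 2-layer MLP, mirroring the dummy setup mentioned above.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
mx.eval(model.parameters())

# Assumed constructor: the merged class may expose different arguments.
opt = optim.Muon(learning_rate=0.02, momentum=0.95)

def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

X = mx.random.normal((8, 32))
y = mx.random.randint(0, 10, (8,))
loss, grads = loss_and_grad(model, X, y)

opt.update(model, grads)          # momentum + Newton-Schulz orthogonalized step
mx.eval(model.parameters(), opt.state)
```

Non-2D parameters (embeddings, norms, biases) would still typically be handled by AdamW, as discussed in the thread above.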
Key features of this implementation:
Checklist

- I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes