
Adding support for the Muon Optimizer #1914

Open
wants to merge 5 commits into main

Conversation


@Goekdeniz-Guelmez Goekdeniz-Guelmez commented Feb 28, 2025

Proposed changes

First contribution to the MLX repo. This adds the Muon optimizer to MLX's optimizer suite. Muon (MomentUm Orthogonalized by Newton-Schulz) is a novel optimizer that combines momentum-based SGD with orthogonalization of the parameter updates via Newton-Schulz iterations. It has shown promising results for training neural networks, particularly convolutional and transformer architectures.
The implementation follows the approach described in https://kellerjordan.github.io/posts/muon/ , adapting it to MLX's framework. The optimizer performs standard SGD-momentum updates, followed by an orthogonalization step that replaces each 2D parameter's update with the nearest orthogonal matrix using an efficient Newton-Schulz iteration.
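For reference, here is a minimal MLX sketch of the Newton-Schulz orthogonalization step, following the formulation in Keller Jordan's post (an illustration, not necessarily the exact code in this PR):

import mlx.core as mx

def zeropower_via_newtonschulz5(G, steps=5, eps=1e-7):
    # Quintic Newton-Schulz iteration that drives a 2D update towards the
    # nearest (semi-)orthogonal matrix; runs in bfloat16 for speed.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.astype(mx.bfloat16)
    transpose = G.shape[0] > G.shape[1]
    if transpose:
        X = X.T
    X = X / (mx.linalg.norm(X) + eps)  # keep the spectral norm at or below ~1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transpose:
        X = X.T
    return X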
Key features of this implementation:

  • Support for standard optimizer features (learning rate, momentum, weight decay, Nesterov)
  • Efficient Newton-Schulz orthogonalization that works with bfloat16
  • Special handling for parameters of different dimensions
  • Appropriate scaling for non-square matrices

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Goekdeniz-Guelmez
Author

Used a basic 2-layer MLP with a dummy dataset (a minimal sketch of the setup follows the logs below):

Training with Muon optimizer...
Epoch 1/3, Batch 0/78, Loss: 2.3055
Epoch 1/3, Batch 10/78, Loss: 2.3316
Epoch 1/3, Batch 20/78, Loss: 2.3054
Epoch 1/3, Batch 30/78, Loss: 2.3088
Epoch 1/3, Batch 40/78, Loss: 2.2991
Epoch 1/3, Batch 50/78, Loss: 2.2983
Epoch 1/3, Batch 60/78, Loss: 2.3218
Epoch 1/3, Batch 70/78, Loss: 2.2940
Epoch 1/3 completed in 0.58s - Train Loss: 2.3070, Val Loss: 2.3065, Val Acc: 0.0917
Epoch 2/3, Batch 0/78, Loss: 2.2920
Epoch 2/3, Batch 10/78, Loss: 2.2957
Epoch 2/3, Batch 20/78, Loss: 2.2779
Epoch 2/3, Batch 30/78, Loss: 2.2840
Epoch 2/3, Batch 40/78, Loss: 2.2765
Epoch 2/3, Batch 50/78, Loss: 2.2732
Epoch 2/3, Batch 60/78, Loss: 2.2992
Epoch 2/3, Batch 70/78, Loss: 2.2644
Epoch 2/3 completed in 0.06s - Train Loss: 2.2824, Val Loss: 2.3078, Val Acc: 0.0990
Epoch 3/3, Batch 0/78, Loss: 2.2648
Epoch 3/3, Batch 10/78, Loss: 2.2709
Epoch 3/3, Batch 20/78, Loss: 2.2467
Epoch 3/3, Batch 30/78, Loss: 2.2562
Epoch 3/3, Batch 40/78, Loss: 2.2492
Epoch 3/3, Batch 50/78, Loss: 2.2410
Epoch 3/3, Batch 60/78, Loss: 2.2740
Epoch 3/3, Batch 70/78, Loss: 2.2283
Epoch 3/3 completed in 0.06s - Train Loss: 2.2541, Val Loss: 2.3098, Val Acc: 0.0969

Training with standard SGD optimizer for comparison...
Epoch 1/3, Batch 0/78, Loss: 2.3028
Epoch 1/3, Batch 10/78, Loss: 2.3079
Epoch 1/3, Batch 20/78, Loss: 2.3219
Epoch 1/3, Batch 30/78, Loss: 2.3094
Epoch 1/3, Batch 40/78, Loss: 2.3017
Epoch 1/3, Batch 50/78, Loss: 2.3161
Epoch 1/3, Batch 60/78, Loss: 2.3095
Epoch 1/3, Batch 70/78, Loss: 2.2873
Epoch 1/3 completed in 0.03s - Train Loss: 2.3081, Val Loss: 2.3074, Val Acc: 0.0969
Epoch 2/3, Batch 0/78, Loss: 2.2914
Epoch 2/3, Batch 10/78, Loss: 2.2927
Epoch 2/3, Batch 20/78, Loss: 2.3017
Epoch 2/3, Batch 30/78, Loss: 2.2921
Epoch 2/3, Batch 40/78, Loss: 2.2866
Epoch 2/3, Batch 50/78, Loss: 2.3123
Epoch 2/3, Batch 60/78, Loss: 2.3063
Epoch 2/3, Batch 70/78, Loss: 2.2799
Epoch 2/3 completed in 0.03s - Train Loss: 2.2974, Val Loss: 2.3079, Val Acc: 0.0896
Epoch 3/3, Batch 0/78, Loss: 2.2833
Epoch 3/3, Batch 10/78, Loss: 2.2801
Epoch 3/3, Batch 20/78, Loss: 2.2918
Epoch 3/3, Batch 30/78, Loss: 2.2843
Epoch 3/3, Batch 40/78, Loss: 2.2869
Epoch 3/3, Batch 50/78, Loss: 2.2981
Epoch 3/3, Batch 60/78, Loss: 2.2954
Epoch 3/3, Batch 70/78, Loss: 2.2703
Epoch 3/3 completed in 0.03s - Train Loss: 2.2884, Val Loss: 2.3082, Val Acc: 0.1177

More training runs will come!
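For reference, a minimal sketch of the kind of setup used for the comparison above. It assumes this PR exposes the optimizer as optim.Muon with MLX's usual learning_rate/momentum arguments; the model and data here are illustrative stand-ins:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

class MLP(nn.Module):
    def __init__(self, in_dims=64, hidden=128, out_dims=10):
        super().__init__()
        self.l1 = nn.Linear(in_dims, hidden)
        self.l2 = nn.Linear(hidden, out_dims)

    def __call__(self, x):
        return self.l2(nn.relu(self.l1(x)))

def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

model = MLP()
optimizer = optim.Muon(learning_rate=0.02, momentum=0.95)  # hypothetical signature
loss_and_grad = nn.value_and_grad(model, loss_fn)

X = mx.random.normal((32, 64))        # dummy batch of features
y = mx.random.randint(0, 10, (32,))   # dummy class labels
loss, grads = loss_and_grad(model, X, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)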

@Goekdeniz-Guelmez
Author

LLM SFT Finetuning

python -m mlx_lm.lora \
--model Qwen/Qwen2.5-1.5B-Instruct \
--train \
--data /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/data_smoll \ <- samantha
--fine-tune-type dora \
--num-layers 4 \
--batch-size 1 \
--iters 100 \
--val-batches 1 \
--steps-per-report 1 \
--steps-per-eval 50 \
--adapter-path /Users/gokdenizgulmez/Library/Mobile\ Documents/com\~apple\~CloudDocs/Datastes/MLX/test_Muon \
--save-every 10 \
--max-seq-length 4096 \
--grad-checkpoint

Muon

Iter 1: Val loss 1.638, Val took 1.710s
Iter 1: Train loss 1.466, Learning Rate 1.000e-05, It/sec 0.958, Tokens/sec 619.948, Trained Tokens 647, Peak mem 5.860 GB
Iter 2: Train loss 1.189, Learning Rate 1.000e-05, It/sec 0.173, Tokens/sec 564.635, Trained Tokens 3902, Peak mem 12.744 GB
Iter 3: Train loss 1.399, Learning Rate 1.000e-05, It/sec 0.204, Tokens/sec 356.355, Trained Tokens 5652, Peak mem 12.744 GB
Iter 4: Train loss 2.229, Learning Rate 1.000e-05, It/sec 0.933, Tokens/sec 323.692, Trained Tokens 5999, Peak mem 12.744 GB
Iter 5: Train loss 1.512, Learning Rate 1.000e-05, It/sec 0.340, Tokens/sec 621.293, Trained Tokens 7825, Peak mem 12.744 GB
Iter 6: Train loss 1.446, Learning Rate 1.000e-05, It/sec 0.582, Tokens/sec 617.335, Trained Tokens 8886, Peak mem 12.744 GB
Iter 7: Train loss 1.597, Learning Rate 1.000e-05, It/sec 0.823, Tokens/sec 627.471, Trained Tokens 9648, Peak mem 12.744 GB
Iter 8: Train loss 1.793, Learning Rate 1.000e-05, It/sec 0.558, Tokens/sec 523.831, Trained Tokens 10587, Peak mem 12.744 GB
Iter 9: Train loss 1.763, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 658.091, Trained Tokens 11605, Peak mem 12.744 GB
Iter 10: Train loss 1.022, Learning Rate 1.000e-05, It/sec 0.367, Tokens/sec 520.598, Trained Tokens 13025, Peak mem 12.744 GB
Iter 10: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000010_adapters.safetensors.
Iter 11: Train loss 1.156, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 619.401, Trained Tokens 14346, Peak mem 12.744 GB
Iter 12: Train loss 1.381, Learning Rate 1.000e-05, It/sec 0.494, Tokens/sec 614.444, Trained Tokens 15590, Peak mem 12.744 GB
Iter 13: Train loss 1.725, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 578.744, Trained Tokens 17956, Peak mem 12.744 GB
Iter 14: Train loss 1.447, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 556.406, Trained Tokens 20224, Peak mem 12.744 GB
Iter 15: Train loss 1.443, Learning Rate 1.000e-05, It/sec 0.295, Tokens/sec 581.241, Trained Tokens 22194, Peak mem 12.744 GB
Iter 16: Train loss 1.397, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 610.621, Trained Tokens 23793, Peak mem 12.744 GB
Iter 17: Train loss 1.550, Learning Rate 1.000e-05, It/sec 0.551, Tokens/sec 602.464, Trained Tokens 24886, Peak mem 12.744 GB
Iter 18: Train loss 0.884, Learning Rate 1.000e-05, It/sec 0.574, Tokens/sec 611.241, Trained Tokens 25951, Peak mem 12.744 GB
Iter 19: Train loss 1.424, Learning Rate 1.000e-05, It/sec 0.254, Tokens/sec 341.668, Trained Tokens 27296, Peak mem 12.744 GB
Iter 20: Train loss 1.713, Learning Rate 1.000e-05, It/sec 0.422, Tokens/sec 606.567, Trained Tokens 28735, Peak mem 12.744 GB
Iter 20: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000020_adapters.safetensors.
Iter 21: Train loss 1.322, Learning Rate 1.000e-05, It/sec 0.844, Tokens/sec 615.596, Trained Tokens 29464, Peak mem 12.744 GB
Iter 22: Train loss 1.685, Learning Rate 1.000e-05, It/sec 0.490, Tokens/sec 597.661, Trained Tokens 30683, Peak mem 12.744 GB
Iter 23: Train loss 1.298, Learning Rate 1.000e-05, It/sec 0.380, Tokens/sec 572.848, Trained Tokens 32190, Peak mem 12.744 GB
Iter 24: Train loss 1.707, Learning Rate 1.000e-05, It/sec 0.769, Tokens/sec 598.655, Trained Tokens 32968, Peak mem 12.744 GB
Iter 25: Train loss 1.942, Learning Rate 1.000e-05, It/sec 1.001, Tokens/sec 618.561, Trained Tokens 33586, Peak mem 12.744 GB
Iter 26: Train loss 1.394, Learning Rate 1.000e-05, It/sec 0.232, Tokens/sec 438.454, Trained Tokens 35475, Peak mem 12.744 GB
Iter 27: Train loss 0.959, Learning Rate 1.000e-05, It/sec 0.698, Tokens/sec 606.886, Trained Tokens 36344, Peak mem 12.744 GB
Iter 28: Train loss 1.813, Learning Rate 1.000e-05, It/sec 1.252, Tokens/sec 629.799, Trained Tokens 36847, Peak mem 12.744 GB
Iter 29: Train loss 1.326, Learning Rate 1.000e-05, It/sec 0.453, Tokens/sec 582.481, Trained Tokens 38134, Peak mem 12.744 GB
Iter 30: Train loss 1.428, Learning Rate 1.000e-05, It/sec 0.336, Tokens/sec 585.171, Trained Tokens 39877, Peak mem 12.744 GB
Iter 30: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000030_adapters.safetensors.
Iter 31: Train loss 1.194, Learning Rate 1.000e-05, It/sec 0.268, Tokens/sec 576.286, Trained Tokens 42028, Peak mem 12.744 GB
Iter 32: Train loss 1.635, Learning Rate 1.000e-05, It/sec 0.487, Tokens/sec 611.771, Trained Tokens 43283, Peak mem 12.744 GB
Iter 33: Train loss 1.634, Learning Rate 1.000e-05, It/sec 0.488, Tokens/sec 610.358, Trained Tokens 44534, Peak mem 12.744 GB
Iter 34: Train loss 1.293, Learning Rate 1.000e-05, It/sec 0.841, Tokens/sec 606.378, Trained Tokens 45255, Peak mem 12.744 GB
Iter 35: Train loss 1.499, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 611.155, Trained Tokens 46588, Peak mem 12.744 GB
Iter 36: Train loss 1.660, Learning Rate 1.000e-05, It/sec 0.776, Tokens/sec 607.610, Trained Tokens 47371, Peak mem 12.744 GB
Iter 37: Train loss 1.599, Learning Rate 1.000e-05, It/sec 0.780, Tokens/sec 600.431, Trained Tokens 48141, Peak mem 12.744 GB
Iter 38: Train loss 1.995, Learning Rate 1.000e-05, It/sec 0.657, Tokens/sec 621.528, Trained Tokens 49087, Peak mem 12.744 GB
Iter 39: Train loss 1.799, Learning Rate 1.000e-05, It/sec 0.617, Tokens/sec 608.044, Trained Tokens 50072, Peak mem 12.744 GB
Iter 40: Train loss 1.822, Learning Rate 1.000e-05, It/sec 0.583, Tokens/sec 603.529, Trained Tokens 51108, Peak mem 12.744 GB
Iter 40: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000040_adapters.safetensors.
Iter 41: Train loss 1.490, Learning Rate 1.000e-05, It/sec 1.013, Tokens/sec 602.438, Trained Tokens 51703, Peak mem 12.744 GB
Iter 42: Train loss 1.045, Learning Rate 1.000e-05, It/sec 0.542, Tokens/sec 608.561, Trained Tokens 52826, Peak mem 12.744 GB
Iter 43: Train loss 1.379, Learning Rate 1.000e-05, It/sec 0.416, Tokens/sec 597.864, Trained Tokens 54264, Peak mem 12.744 GB
Iter 44: Train loss 1.652, Learning Rate 1.000e-05, It/sec 0.625, Tokens/sec 601.502, Trained Tokens 55227, Peak mem 12.744 GB
Iter 45: Train loss 1.688, Learning Rate 1.000e-05, It/sec 0.910, Tokens/sec 630.532, Trained Tokens 55920, Peak mem 12.744 GB
Iter 46: Train loss 1.800, Learning Rate 1.000e-05, It/sec 0.769, Tokens/sec 620.766, Trained Tokens 56727, Peak mem 12.744 GB
Iter 47: Train loss 1.776, Learning Rate 1.000e-05, It/sec 0.699, Tokens/sec 620.070, Trained Tokens 57614, Peak mem 12.744 GB
Iter 48: Train loss 1.485, Learning Rate 1.000e-05, It/sec 0.656, Tokens/sec 627.084, Trained Tokens 58570, Peak mem 12.744 GB
Iter 49: Train loss 1.485, Learning Rate 1.000e-05, It/sec 0.253, Tokens/sec 570.277, Trained Tokens 60827, Peak mem 12.744 GB
Iter 50: Val loss 1.750, Val took 1.551s
Iter 50: Train loss 1.714, Learning Rate 1.000e-05, It/sec 0.512, Tokens/sec 617.394, Trained Tokens 62032, Peak mem 12.744 GB
Iter 50: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000050_adapters.safetensors.
Iter 51: Train loss 1.254, Learning Rate 1.000e-05, It/sec 0.657, Tokens/sec 619.473, Trained Tokens 62975, Peak mem 12.744 GB
Iter 52: Train loss 1.144, Learning Rate 1.000e-05, It/sec 0.438, Tokens/sec 600.079, Trained Tokens 64344, Peak mem 12.744 GB
Iter 53: Train loss 1.054, Learning Rate 1.000e-05, It/sec 0.624, Tokens/sec 599.977, Trained Tokens 65306, Peak mem 12.744 GB
Iter 54: Train loss 1.468, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 600.392, Trained Tokens 66598, Peak mem 12.744 GB
Iter 55: Train loss 1.737, Learning Rate 1.000e-05, It/sec 0.706, Tokens/sec 623.741, Trained Tokens 67482, Peak mem 12.744 GB
Iter 56: Train loss 1.339, Learning Rate 1.000e-05, It/sec 0.542, Tokens/sec 616.532, Trained Tokens 68619, Peak mem 12.744 GB
Iter 57: Train loss 1.531, Learning Rate 1.000e-05, It/sec 0.295, Tokens/sec 581.258, Trained Tokens 70588, Peak mem 12.744 GB
Iter 58: Train loss 1.428, Learning Rate 1.000e-05, It/sec 0.462, Tokens/sec 592.028, Trained Tokens 71869, Peak mem 12.744 GB
Iter 59: Train loss 1.068, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 593.931, Trained Tokens 73369, Peak mem 12.744 GB
Iter 60: Train loss 1.690, Learning Rate 1.000e-05, It/sec 1.116, Tokens/sec 617.204, Trained Tokens 73922, Peak mem 12.744 GB
Iter 60: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000060_adapters.safetensors.
Iter 61: Train loss 1.314, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 568.104, Trained Tokens 76239, Peak mem 12.744 GB
Iter 62: Train loss 1.633, Learning Rate 1.000e-05, It/sec 0.415, Tokens/sec 362.680, Trained Tokens 77112, Peak mem 12.744 GB
Iter 63: Train loss 1.392, Learning Rate 1.000e-05, It/sec 0.322, Tokens/sec 581.527, Trained Tokens 78919, Peak mem 12.744 GB
Iter 64: Train loss 2.108, Learning Rate 1.000e-05, It/sec 1.422, Tokens/sec 605.856, Trained Tokens 79345, Peak mem 12.744 GB
Iter 65: Train loss 1.301, Learning Rate 1.000e-05, It/sec 0.381, Tokens/sec 594.039, Trained Tokens 80904, Peak mem 12.744 GB
Iter 66: Train loss 1.689, Learning Rate 1.000e-05, It/sec 1.009, Tokens/sec 592.073, Trained Tokens 81491, Peak mem 12.744 GB
Iter 67: Train loss 1.520, Learning Rate 1.000e-05, It/sec 0.394, Tokens/sec 596.168, Trained Tokens 83006, Peak mem 12.744 GB
Iter 68: Train loss 1.473, Learning Rate 1.000e-05, It/sec 0.320, Tokens/sec 581.864, Trained Tokens 84824, Peak mem 12.744 GB
Iter 69: Train loss 1.679, Learning Rate 1.000e-05, It/sec 0.461, Tokens/sec 583.405, Trained Tokens 86090, Peak mem 12.744 GB
Iter 70: Train loss 1.591, Learning Rate 1.000e-05, It/sec 0.608, Tokens/sec 597.838, Trained Tokens 87073, Peak mem 12.744 GB
Iter 70: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000070_adapters.safetensors.
Iter 71: Train loss 1.537, Learning Rate 1.000e-05, It/sec 0.300, Tokens/sec 577.528, Trained Tokens 88996, Peak mem 12.744 GB
Iter 72: Train loss 1.283, Learning Rate 1.000e-05, It/sec 0.383, Tokens/sec 599.126, Trained Tokens 90562, Peak mem 12.744 GB
Iter 73: Train loss 2.318, Learning Rate 1.000e-05, It/sec 1.264, Tokens/sec 578.981, Trained Tokens 91020, Peak mem 12.744 GB
Iter 74: Train loss 1.272, Learning Rate 1.000e-05, It/sec 0.375, Tokens/sec 426.068, Trained Tokens 92157, Peak mem 12.744 GB
Iter 75: Train loss 1.732, Learning Rate 1.000e-05, It/sec 0.362, Tokens/sec 588.029, Trained Tokens 93783, Peak mem 12.744 GB
Iter 76: Train loss 1.517, Learning Rate 1.000e-05, It/sec 0.303, Tokens/sec 540.497, Trained Tokens 95568, Peak mem 12.744 GB
Iter 77: Train loss 1.440, Learning Rate 1.000e-05, It/sec 0.745, Tokens/sec 614.929, Trained Tokens 96393, Peak mem 12.744 GB
Iter 78: Train loss 1.558, Learning Rate 1.000e-05, It/sec 0.600, Tokens/sec 613.459, Trained Tokens 97416, Peak mem 12.744 GB
Iter 79: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.354, Tokens/sec 490.311, Trained Tokens 98801, Peak mem 12.744 GB
Iter 80: Train loss 1.417, Learning Rate 1.000e-05, It/sec 0.825, Tokens/sec 599.015, Trained Tokens 99527, Peak mem 12.744 GB
Iter 80: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000080_adapters.safetensors.
Iter 81: Train loss 1.021, Learning Rate 1.000e-05, It/sec 0.367, Tokens/sec 352.983, Trained Tokens 100488, Peak mem 12.744 GB
Iter 82: Train loss 1.551, Learning Rate 1.000e-05, It/sec 0.293, Tokens/sec 576.772, Trained Tokens 102455, Peak mem 12.744 GB
Iter 83: Train loss 1.587, Learning Rate 1.000e-05, It/sec 0.505, Tokens/sec 607.992, Trained Tokens 103659, Peak mem 12.744 GB
Iter 84: Train loss 1.675, Learning Rate 1.000e-05, It/sec 0.917, Tokens/sec 589.006, Trained Tokens 104301, Peak mem 12.744 GB
Iter 85: Train loss 1.498, Learning Rate 1.000e-05, It/sec 0.612, Tokens/sec 602.992, Trained Tokens 105287, Peak mem 12.744 GB
Iter 86: Train loss 1.832, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 611.999, Trained Tokens 106235, Peak mem 12.744 GB
Iter 87: Train loss 1.497, Learning Rate 1.000e-05, It/sec 0.457, Tokens/sec 610.228, Trained Tokens 107570, Peak mem 12.744 GB
Iter 88: Train loss 1.059, Learning Rate 1.000e-05, It/sec 0.310, Tokens/sec 437.320, Trained Tokens 108980, Peak mem 12.744 GB
Iter 89: Train loss 1.317, Learning Rate 1.000e-05, It/sec 0.565, Tokens/sec 608.608, Trained Tokens 110057, Peak mem 12.744 GB
Iter 90: Train loss 1.514, Learning Rate 1.000e-05, It/sec 0.255, Tokens/sec 437.049, Trained Tokens 111774, Peak mem 12.744 GB
Iter 90: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000090_adapters.safetensors.
Iter 91: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.547, Tokens/sec 601.086, Trained Tokens 112873, Peak mem 12.744 GB
Iter 92: Train loss 1.635, Learning Rate 1.000e-05, It/sec 0.456, Tokens/sec 441.178, Trained Tokens 113840, Peak mem 12.744 GB
Iter 93: Train loss 1.270, Learning Rate 1.000e-05, It/sec 0.567, Tokens/sec 605.699, Trained Tokens 114909, Peak mem 12.744 GB
Iter 94: Train loss 1.615, Learning Rate 1.000e-05, It/sec 0.289, Tokens/sec 487.284, Trained Tokens 116595, Peak mem 12.744 GB
Iter 95: Train loss 1.383, Learning Rate 1.000e-05, It/sec 0.254, Tokens/sec 397.945, Trained Tokens 118162, Peak mem 12.744 GB
Iter 96: Train loss 1.099, Learning Rate 1.000e-05, It/sec 0.545, Tokens/sec 620.016, Trained Tokens 119299, Peak mem 12.744 GB
Iter 97: Train loss 1.660, Learning Rate 1.000e-05, It/sec 0.465, Tokens/sec 597.845, Trained Tokens 120585, Peak mem 12.744 GB
Iter 98: Train loss 1.412, Learning Rate 1.000e-05, It/sec 0.402, Tokens/sec 593.111, Trained Tokens 122061, Peak mem 12.744 GB
Iter 99: Train loss 1.752, Learning Rate 1.000e-05, It/sec 0.441, Tokens/sec 596.284, Trained Tokens 123412, Peak mem 12.744 GB
Iter 100: Val loss 1.590, Val took 1.321s

Adam

Iter 1: Val loss 1.638, Val took 1.709s
Iter 1: Train loss 1.466, Learning Rate 1.000e-05, It/sec 0.972, Tokens/sec 629.107, Trained Tokens 647, Peak mem 5.860 GB
Iter 2: Train loss 1.189, Learning Rate 1.000e-05, It/sec 0.177, Tokens/sec 575.080, Trained Tokens 3902, Peak mem 12.744 GB
Iter 3: Train loss 1.398, Learning Rate 1.000e-05, It/sec 0.356, Tokens/sec 623.304, Trained Tokens 5652, Peak mem 12.744 GB
Iter 4: Train loss 2.225, Learning Rate 1.000e-05, It/sec 1.811, Tokens/sec 628.255, Trained Tokens 5999, Peak mem 12.744 GB
Iter 5: Train loss 1.511, Learning Rate 1.000e-05, It/sec 0.319, Tokens/sec 582.413, Trained Tokens 7825, Peak mem 12.744 GB
Iter 6: Train loss 1.445, Learning Rate 1.000e-05, It/sec 0.613, Tokens/sec 650.911, Trained Tokens 8886, Peak mem 12.744 GB
Iter 7: Train loss 1.593, Learning Rate 1.000e-05, It/sec 0.881, Tokens/sec 671.453, Trained Tokens 9648, Peak mem 12.744 GB
Iter 8: Train loss 1.786, Learning Rate 1.000e-05, It/sec 0.601, Tokens/sec 564.607, Trained Tokens 10587, Peak mem 12.744 GB
Iter 9: Train loss 1.756, Learning Rate 1.000e-05, It/sec 0.645, Tokens/sec 656.109, Trained Tokens 11605, Peak mem 12.744 GB
Iter 10: Train loss 1.019, Learning Rate 1.000e-05, It/sec 0.445, Tokens/sec 632.010, Trained Tokens 13025, Peak mem 12.744 GB
Iter 10: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000010_adapters.safetensors.
Iter 11: Train loss 1.154, Learning Rate 1.000e-05, It/sec 0.481, Tokens/sec 635.520, Trained Tokens 14346, Peak mem 12.744 GB
Iter 12: Train loss 1.377, Learning Rate 1.000e-05, It/sec 0.444, Tokens/sec 552.467, Trained Tokens 15590, Peak mem 12.744 GB
Iter 13: Train loss 1.722, Learning Rate 1.000e-05, It/sec 0.256, Tokens/sec 604.657, Trained Tokens 17956, Peak mem 12.744 GB
Iter 14: Train loss 1.442, Learning Rate 1.000e-05, It/sec 0.266, Tokens/sec 602.850, Trained Tokens 20224, Peak mem 12.744 GB
Iter 15: Train loss 1.437, Learning Rate 1.000e-05, It/sec 0.312, Tokens/sec 614.792, Trained Tokens 22194, Peak mem 12.744 GB
Iter 16: Train loss 1.387, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 633.014, Trained Tokens 23793, Peak mem 12.744 GB
Iter 17: Train loss 1.538, Learning Rate 1.000e-05, It/sec 0.573, Tokens/sec 625.796, Trained Tokens 24886, Peak mem 12.744 GB
Iter 18: Train loss 0.874, Learning Rate 1.000e-05, It/sec 0.596, Tokens/sec 635.225, Trained Tokens 25951, Peak mem 12.744 GB
Iter 19: Train loss 1.410, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 596.280, Trained Tokens 27296, Peak mem 12.744 GB
Iter 20: Train loss 1.700, Learning Rate 1.000e-05, It/sec 0.423, Tokens/sec 609.029, Trained Tokens 28735, Peak mem 12.744 GB
Iter 20: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000020_adapters.safetensors.
Iter 21: Train loss 1.297, Learning Rate 1.000e-05, It/sec 0.853, Tokens/sec 622.047, Trained Tokens 29464, Peak mem 12.744 GB
Iter 22: Train loss 1.664, Learning Rate 1.000e-05, It/sec 0.492, Tokens/sec 599.321, Trained Tokens 30683, Peak mem 12.744 GB
Iter 23: Train loss 1.282, Learning Rate 1.000e-05, It/sec 0.374, Tokens/sec 563.105, Trained Tokens 32190, Peak mem 12.744 GB
Iter 24: Train loss 1.677, Learning Rate 1.000e-05, It/sec 0.788, Tokens/sec 613.112, Trained Tokens 32968, Peak mem 12.744 GB
Iter 25: Train loss 1.898, Learning Rate 1.000e-05, It/sec 1.017, Tokens/sec 628.653, Trained Tokens 33586, Peak mem 12.744 GB
Iter 26: Train loss 1.379, Learning Rate 1.000e-05, It/sec 0.311, Tokens/sec 587.401, Trained Tokens 35475, Peak mem 12.744 GB
Iter 27: Train loss 0.933, Learning Rate 1.000e-05, It/sec 0.712, Tokens/sec 619.114, Trained Tokens 36344, Peak mem 12.744 GB
Iter 28: Train loss 1.784, Learning Rate 1.000e-05, It/sec 1.268, Tokens/sec 637.579, Trained Tokens 36847, Peak mem 12.744 GB
Iter 29: Train loss 1.300, Learning Rate 1.000e-05, It/sec 0.370, Tokens/sec 475.946, Trained Tokens 38134, Peak mem 12.744 GB
Iter 30: Train loss 1.412, Learning Rate 1.000e-05, It/sec 0.338, Tokens/sec 588.921, Trained Tokens 39877, Peak mem 12.744 GB
Iter 30: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000030_adapters.safetensors.
Iter 31: Train loss 1.183, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 528.035, Trained Tokens 42028, Peak mem 12.744 GB
Iter 32: Train loss 1.611, Learning Rate 1.000e-05, It/sec 0.370, Tokens/sec 464.962, Trained Tokens 43283, Peak mem 12.744 GB
Iter 33: Train loss 1.605, Learning Rate 1.000e-05, It/sec 0.491, Tokens/sec 614.706, Trained Tokens 44534, Peak mem 12.744 GB
Iter 34: Train loss 1.249, Learning Rate 1.000e-05, It/sec 0.845, Tokens/sec 609.598, Trained Tokens 45255, Peak mem 12.744 GB
Iter 35: Train loss 1.471, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 610.400, Trained Tokens 46588, Peak mem 12.744 GB
Iter 36: Train loss 1.621, Learning Rate 1.000e-05, It/sec 0.742, Tokens/sec 580.700, Trained Tokens 47371, Peak mem 12.744 GB
Iter 37: Train loss 1.530, Learning Rate 1.000e-05, It/sec 0.684, Tokens/sec 526.483, Trained Tokens 48141, Peak mem 12.744 GB
Iter 38: Train loss 1.958, Learning Rate 1.000e-05, It/sec 0.659, Tokens/sec 623.280, Trained Tokens 49087, Peak mem 12.744 GB
Iter 39: Train loss 1.742, Learning Rate 1.000e-05, It/sec 0.519, Tokens/sec 511.139, Trained Tokens 50072, Peak mem 12.744 GB
Iter 40: Train loss 1.761, Learning Rate 1.000e-05, It/sec 0.489, Tokens/sec 506.599, Trained Tokens 51108, Peak mem 12.744 GB
Iter 40: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000040_adapters.safetensors.
Iter 41: Train loss 1.426, Learning Rate 1.000e-05, It/sec 0.716, Tokens/sec 425.871, Trained Tokens 51703, Peak mem 12.744 GB
Iter 42: Train loss 1.010, Learning Rate 1.000e-05, It/sec 0.415, Tokens/sec 466.172, Trained Tokens 52826, Peak mem 12.744 GB
Iter 43: Train loss 1.331, Learning Rate 1.000e-05, It/sec 0.393, Tokens/sec 564.757, Trained Tokens 54264, Peak mem 12.744 GB
Iter 44: Train loss 1.583, Learning Rate 1.000e-05, It/sec 0.621, Tokens/sec 597.810, Trained Tokens 55227, Peak mem 12.744 GB
Iter 45: Train loss 1.569, Learning Rate 1.000e-05, It/sec 0.358, Tokens/sec 248.046, Trained Tokens 55920, Peak mem 12.744 GB
Iter 46: Train loss 1.694, Learning Rate 1.000e-05, It/sec 0.694, Tokens/sec 559.991, Trained Tokens 56727, Peak mem 12.744 GB
Iter 47: Train loss 1.692, Learning Rate 1.000e-05, It/sec 0.636, Tokens/sec 564.235, Trained Tokens 57614, Peak mem 12.744 GB
Iter 48: Train loss 1.399, Learning Rate 1.000e-05, It/sec 0.646, Tokens/sec 617.509, Trained Tokens 58570, Peak mem 12.744 GB
Iter 49: Train loss 1.451, Learning Rate 1.000e-05, It/sec 0.242, Tokens/sec 547.274, Trained Tokens 60827, Peak mem 12.744 GB
Iter 50: Val loss 1.704, Val took 2.831s
Iter 50: Train loss 1.662, Learning Rate 1.000e-05, It/sec 0.505, Tokens/sec 608.072, Trained Tokens 62032, Peak mem 12.744 GB
Iter 50: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000050_adapters.safetensors.
Iter 51: Train loss 1.195, Learning Rate 1.000e-05, It/sec 0.655, Tokens/sec 617.362, Trained Tokens 62975, Peak mem 12.744 GB
Iter 52: Train loss 1.104, Learning Rate 1.000e-05, It/sec 0.438, Tokens/sec 599.338, Trained Tokens 64344, Peak mem 12.744 GB
Iter 53: Train loss 0.998, Learning Rate 1.000e-05, It/sec 0.603, Tokens/sec 580.070, Trained Tokens 65306, Peak mem 12.744 GB
Iter 54: Train loss 1.410, Learning Rate 1.000e-05, It/sec 0.296, Tokens/sec 382.980, Trained Tokens 66598, Peak mem 12.744 GB
Iter 55: Train loss 1.631, Learning Rate 1.000e-05, It/sec 0.673, Tokens/sec 594.911, Trained Tokens 67482, Peak mem 12.744 GB
Iter 56: Train loss 1.277, Learning Rate 1.000e-05, It/sec 0.469, Tokens/sec 532.733, Trained Tokens 68619, Peak mem 12.744 GB
Iter 57: Train loss 1.476, Learning Rate 1.000e-05, It/sec 0.241, Tokens/sec 473.725, Trained Tokens 70588, Peak mem 12.744 GB
Iter 58: Train loss 1.381, Learning Rate 1.000e-05, It/sec 0.458, Tokens/sec 586.469, Trained Tokens 71869, Peak mem 12.744 GB
Iter 59: Train loss 1.023, Learning Rate 1.000e-05, It/sec 0.396, Tokens/sec 594.025, Trained Tokens 73369, Peak mem 12.744 GB
Iter 60: Train loss 1.503, Learning Rate 1.000e-05, It/sec 1.126, Tokens/sec 622.941, Trained Tokens 73922, Peak mem 12.744 GB
Iter 60: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000060_adapters.safetensors.
Iter 61: Train loss 1.274, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 566.824, Trained Tokens 76239, Peak mem 12.744 GB
Iter 62: Train loss 1.542, Learning Rate 1.000e-05, It/sec 0.704, Tokens/sec 614.720, Trained Tokens 77112, Peak mem 12.744 GB
Iter 63: Train loss 1.327, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 574.726, Trained Tokens 78919, Peak mem 12.744 GB
Iter 64: Train loss 1.936, Learning Rate 1.000e-05, It/sec 1.453, Tokens/sec 618.989, Trained Tokens 79345, Peak mem 12.744 GB
Iter 65: Train loss 1.261, Learning Rate 1.000e-05, It/sec 0.381, Tokens/sec 593.783, Trained Tokens 80904, Peak mem 12.744 GB
Iter 66: Train loss 1.509, Learning Rate 1.000e-05, It/sec 1.016, Tokens/sec 596.523, Trained Tokens 81491, Peak mem 12.744 GB
Iter 67: Train loss 1.472, Learning Rate 1.000e-05, It/sec 0.394, Tokens/sec 596.880, Trained Tokens 83006, Peak mem 12.744 GB
Iter 68: Train loss 1.413, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 577.550, Trained Tokens 84824, Peak mem 12.744 GB
Iter 69: Train loss 1.605, Learning Rate 1.000e-05, It/sec 0.473, Tokens/sec 599.301, Trained Tokens 86090, Peak mem 12.744 GB
Iter 70: Train loss 1.450, Learning Rate 1.000e-05, It/sec 0.610, Tokens/sec 599.262, Trained Tokens 87073, Peak mem 12.744 GB
Iter 70: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000070_adapters.safetensors.
Iter 71: Train loss 1.467, Learning Rate 1.000e-05, It/sec 0.298, Tokens/sec 573.052, Trained Tokens 88996, Peak mem 12.744 GB
Iter 72: Train loss 1.228, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 598.298, Trained Tokens 90562, Peak mem 12.744 GB
Iter 73: Train loss 2.151, Learning Rate 1.000e-05, It/sec 1.282, Tokens/sec 587.172, Trained Tokens 91020, Peak mem 12.744 GB
Iter 74: Train loss 1.185, Learning Rate 1.000e-05, It/sec 0.384, Tokens/sec 436.397, Trained Tokens 92157, Peak mem 12.744 GB
Iter 75: Train loss 1.664, Learning Rate 1.000e-05, It/sec 0.361, Tokens/sec 586.634, Trained Tokens 93783, Peak mem 12.744 GB
Iter 76: Train loss 1.446, Learning Rate 1.000e-05, It/sec 0.300, Tokens/sec 536.035, Trained Tokens 95568, Peak mem 12.744 GB
Iter 77: Train loss 1.337, Learning Rate 1.000e-05, It/sec 0.748, Tokens/sec 617.211, Trained Tokens 96393, Peak mem 12.744 GB
Iter 78: Train loss 1.480, Learning Rate 1.000e-05, It/sec 0.600, Tokens/sec 614.257, Trained Tokens 97416, Peak mem 12.744 GB
Iter 79: Train loss 1.411, Learning Rate 1.000e-05, It/sec 0.359, Tokens/sec 496.979, Trained Tokens 98801, Peak mem 12.744 GB
Iter 80: Train loss 1.302, Learning Rate 1.000e-05, It/sec 0.844, Tokens/sec 612.869, Trained Tokens 99527, Peak mem 12.744 GB
Iter 80: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000080_adapters.safetensors.
Iter 81: Train loss 0.899, Learning Rate 1.000e-05, It/sec 0.491, Tokens/sec 472.231, Trained Tokens 100488, Peak mem 12.744 GB
Iter 82: Train loss 1.505, Learning Rate 1.000e-05, It/sec 0.180, Tokens/sec 354.323, Trained Tokens 102455, Peak mem 12.744 GB
Iter 83: Train loss 1.514, Learning Rate 1.000e-05, It/sec 0.506, Tokens/sec 609.009, Trained Tokens 103659, Peak mem 12.744 GB
Iter 84: Train loss 1.533, Learning Rate 1.000e-05, It/sec 0.835, Tokens/sec 536.039, Trained Tokens 104301, Peak mem 12.744 GB
Iter 85: Train loss 1.409, Learning Rate 1.000e-05, It/sec 0.612, Tokens/sec 603.270, Trained Tokens 105287, Peak mem 12.744 GB
Iter 86: Train loss 1.699, Learning Rate 1.000e-05, It/sec 0.654, Tokens/sec 620.038, Trained Tokens 106235, Peak mem 12.744 GB
Iter 87: Train loss 1.415, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 591.791, Trained Tokens 107570, Peak mem 12.744 GB
Iter 88: Train loss 0.965, Learning Rate 1.000e-05, It/sec 0.368, Tokens/sec 518.771, Trained Tokens 108980, Peak mem 12.744 GB
Iter 89: Train loss 1.227, Learning Rate 1.000e-05, It/sec 0.338, Tokens/sec 363.499, Trained Tokens 110057, Peak mem 12.744 GB
Iter 90: Train loss 1.449, Learning Rate 1.000e-05, It/sec 0.347, Tokens/sec 595.435, Trained Tokens 111774, Peak mem 12.744 GB
Iter 90: Saved adapter weights to /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/adapters.safetensors and /Users/gokdenizgulmez/Library/Mobile Documents/com~apple~CloudDocs/Datastes/MLX/test_Muon/0000090_adapters.safetensors.
Iter 91: Train loss 1.358, Learning Rate 1.000e-05, It/sec 0.544, Tokens/sec 597.848, Trained Tokens 112873, Peak mem 12.744 GB
Iter 92: Train loss 1.535, Learning Rate 1.000e-05, It/sec 0.510, Tokens/sec 493.189, Trained Tokens 113840, Peak mem 12.744 GB
Iter 93: Train loss 1.172, Learning Rate 1.000e-05, It/sec 0.569, Tokens/sec 608.483, Trained Tokens 114909, Peak mem 12.744 GB
Iter 94: Train loss 1.540, Learning Rate 1.000e-05, It/sec 0.350, Tokens/sec 590.328, Trained Tokens 116595, Peak mem 12.744 GB
Iter 95: Train loss 1.309, Learning Rate 1.000e-05, It/sec 0.382, Tokens/sec 599.283, Trained Tokens 118162, Peak mem 12.744 GB
Iter 96: Train loss 1.005, Learning Rate 1.000e-05, It/sec 0.543, Tokens/sec 617.826, Trained Tokens 119299, Peak mem 12.744 GB
Iter 97: Train loss 1.548, Learning Rate 1.000e-05, It/sec 0.466, Tokens/sec 599.246, Trained Tokens 120585, Peak mem 12.744 GB
Iter 98: Train loss 1.339, Learning Rate 1.000e-05, It/sec 0.401, Tokens/sec 592.229, Trained Tokens 122061, Peak mem 12.744 GB
Iter 99: Train loss 1.658, Learning Rate 1.000e-05, It/sec 0.443, Tokens/sec 598.112, Trained Tokens 123412, Peak mem 12.744 GB
Iter 100: Val loss 1.505, Val took 1.323s
Iter 100: Train loss 1.380, Learning Rate 1.000e-05, It/sec 0.363, Tokens/sec 446.054, Trained Tokens 124641, Peak mem 12.744 GB

@Goekdeniz-Guelmez Goekdeniz-Guelmez changed the title from "initial commit with workong optmimizer" to "Adding support for the Muon Optimizer" on Feb 28, 2025
@Goekdeniz-Guelmez Goekdeniz-Guelmez marked this pull request as ready for review February 28, 2025 23:33
@angeloskath
Member

That is definitely interesting but I think https://github.com/stockeh/mlx-optimizers may be a more suitable repository. Wdyt?

@lin72h

lin72h commented Mar 1, 2025

@Goekdeniz-Guelmez 🔥 Perfect timing! The Muon optimizer just dropped, and now it’s already in MLX!!!
pure optimization wizardry.

@stockeh

stockeh commented Mar 1, 2025

@Goekdeniz-Guelmez @angeloskath yes, we have Muon already:

https://github.com/stockeh/mlx-optimizers/blob/main/mlx_optimizers/muon.py

though I do believe Keller Jordan has made some minor updates since.

@Goekdeniz-Guelmez
Author

@stockeh I didn't know the optimizer repo existed :D. But yeah, there are some differences with the new version. It keeps the same mathematical principles but extends support to higher-dimensional tensors like conv filters through reshaping rather than using a separate optimizer. It also improves efficiency with a streamlined Newton-Schulz iteration formula and applies weight decay earlier in the optimization process. The code now handles non-2D parameters more consistently and uses generalized transpose and normalization logic, so it works with tensors of any dimensionality. A rough sketch of that reshaping path is below.
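This reuses the zeropower_via_newtonschulz5 helper sketched in the PR description above; the names and the exact rescaling are illustrative, not the final implementation:

import mlx.core as mx

def orthogonalized_update(update, ns_steps=5):
    # Vectors and scalars are left to a different rule (e.g. plain momentum or AdamW).
    if update.ndim < 2:
        return update
    shape = update.shape
    if update.ndim > 2:
        # Fold conv filters and other >2D tensors into a matrix before Newton-Schulz.
        update = update.reshape(shape[0], -1)
    out = zeropower_via_newtonschulz5(update, steps=ns_steps)
    # Rescale non-square matrices so the update magnitude stays consistent.
    out = out * max(1.0, update.shape[0] / update.shape[1]) ** 0.5
    return out.reshape(shape)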

@toothacher17

toothacher17 commented Mar 1, 2025

hi, @stockeh

We recently worked on Muon and released the Moonlight model (see https://github.com/MoonshotAI/Moonlight/tree/master). We made some empirical observations about what Muon needs in order to scale (which we did not see in the current implementation), and hope you don't mind me sharing them here:

  1. introducing weight decay, otherwise your weight RMS might grow too large when over-training;
  2. adjusting the update RMS based on the matrix shape, otherwise your model weights will not have a consistent update RMS. This line (https://github.com/stockeh/mlx-optimizers/blob/main/mlx_optimizers/muon.py#L104) might be dangerous because it makes a strong assumption that only holds in the nanoGPT setting.

The implementation is easy; see an example here: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py#L197-L203
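As a rough illustration of those two adjustments (a hedged sketch with illustrative names, not the Moonlight code itself; ortho_update stands for the Newton-Schulz-orthogonalized momentum of a 2D weight param):

def adjusted_muon_update(param, ortho_update, lr, weight_decay):
    # 1. decoupled weight decay, applied directly to the weight as in AdamW
    param = param * (1 - lr * weight_decay)
    # 2. scale the update so its RMS is consistent across matrix shapes
    #    (roughly matching AdamW's typical update RMS of ~0.2)
    rms_scale = 0.2 * (max(param.shape) ** 0.5)
    return param - lr * rms_scale * ortho_update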

These suggestions were empirically helpful for over-training, as we observed during our pretraining of Moonlight. What do you guys think?

@lin72h

lin72h commented Mar 1, 2025

@toothacher17 Wow! The Moonlight team just popped in with actual scaling tips! 🚀 Love seeing them share those crucial details about weight decay and matrix shape adjustments. This is what makes open source so awesome - experts freely sharing knowledge that turns theory into production-ready code. MLX bringing ML minds together at its finest!

@Goekdeniz-Guelmez
Author

@lin72h @toothacher17 I agree, I did not see that coming, but it's very welcome.

@toothacher17

@lin72h @Goekdeniz-Guelmez

Thanks! Our team, Moonshot AI, believed that Muon is scalable, and did some empirical experiments! In case you guys are interested, we have a tech report discussing it: https://arxiv.org/abs/2502.16982

@lin72h

lin72h commented Mar 1, 2025

@toothacher17 Just read the Moonshot paper - on par with the K1 paper, and even more innovative than DeepSeek's work! You folks at Moonshot haven't gotten the attention you deserve - your work just got overshadowed by DeepSeek's timing. The ideas in Moonlight are truly incredible. Open-sourcing this level of innovation is something to be genuinely proud of.
Chinese pride!

@toothacher17

@lin72h
Thanks a lot for the kind words! DeepSeek is truly doing a great job! I personally very much admire their contributions to pushing forward the progress of open source and AGI. We are a humble team and we'll keep working hard to deliver and publish good stuff!

@lin72h

lin72h commented Mar 1, 2025

@toothacher17 Keep up the awesome work! Moonshot rocks!

@stockeh

stockeh commented Mar 1, 2025

@toothacher17 I appreciate you sharing your insights! I found the paper to be quite informative.

I think most of these changes are easy enough to implement in mlx-optimizers. Although I do wish there were an easier way to delegate which parameters we want to use with Muon and which with AdamW (for example), considering the difference in how torch and mlx optimizers are initialized. That said, I'm happy to add weight decay and scaling, as you have, on top of what's in there now!

@awni
Member

awni commented Mar 1, 2025

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

@toothacher17

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

I guess that's because, for now, AdamW is chained with Muon to handle the non-matrix parameters, e.g. the embedding, lm head, and rmsnorm gamma. In the future, there might be a chance to get rid of AdamW and use Muon purely, for example: https://github.com/modula-systems/modula. It's not proven at large scale yet, but it might be promising.

@stockeh

stockeh commented Mar 3, 2025

I do wish there were an easier way to delegate the parameters we'd want to use with Muon and others with AdamW

Can you say more about that?

@awni tldr: I don't think anything has to change in mlx specifically, but I may change mlx-optimizers' Muon class to not include AdamW and simplify the delegation logic with a separate optimizer.

I originally said this when thinking about how we pass params to the optimizer, e.g., in KellerJordan/Muon:

# PyTorch-style: 2D+ weights in the model body go to Muon, everything else to AdamW
muon_params = [p for p in model.body.parameters() if p.ndim >= 2]
adamw_params = ([p for p in model.body.parameters() if p.ndim < 2]
                + [*model.head.parameters(), *model.embed.parameters()])
optimizers = [Muon(muon_params, lr=0.02, momentum=0.95),
              torch.optim.AdamW(adamw_params, lr=3e-4, betas=(0.90, 0.95), weight_decay=0.01)]
...
# in the training step
for opt in optimizers:
    opt.step()

Moonlight's implementation differs in that their custom Muon class accepts both muon_params and adamw_params, so there is only one optimizer. This kind of logic is a bit more challenging to generalize in mlx if we wanted more custom rules, e.g., a set of layer names plus dimensionality.

But I thought about this some more, and I think the easier general approach is to just define multiple optimizers, as we've discussed here, i.e.,

# Assumes the usual MLX imports and that this lives inside a trainer class;
# self.model, self.eval_fn, self.optimizers, and state are illustrative names.
from functools import partial

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten

def split_grads(grads):
    # Route 2D gradients (weights) to one optimizer and 1D gradients (biases) to the other.
    grads = tree_flatten(grads)
    weights = [(k, v) for k, v in grads if v.ndim == 2]
    biases = [(k, v) for k, v in grads if v.ndim == 1]
    weights = tree_unflatten(weights)
    biases = tree_unflatten(biases)
    return weights, biases

@partial(mx.compile, inputs=state, outputs=state)
def step(X, T):
    train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
    loss, grads = train_step_fn(X, T)
    weights, biases = split_grads(grads)
    self.optimizers[0].update(self.model, weights)  # e.g. Muon on the 2D weights
    self.optimizers[1].update(self.model, biases)   # e.g. AdamW on biases/others
    return loss

This would just require a bit of a refactor and a description of how to use Muon in mlx-optimizers, should the optimizers be separate.

@awni
Member

awni commented Mar 3, 2025

Thanks for the detailed explanation, that makes sense!
