Implement Weight Normalization, addressing issue #1888 (#1921)

Open

wants to merge 5 commits into main

Conversation

@cavit99 cavit99 commented Mar 4, 2025

Proposed changes

This PR implements weight normalization for MLX, addressing issue #1888. Weight normalization is a reparameterization technique that decouples the magnitude of a weight tensor from its direction, making optimization more efficient by improving the conditioning of the optimization problem. It is particularly important for audio processing, among other applications.
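Concretely, the reparameterization expresses a weight w as w = g * v / ||v||, where v carries the direction and g the per-output magnitude. A minimal sketch in terms of existing MLX ops (the dedicated mx.weight_norm op added by this PR is intended to compute the same thing with optimized kernels):

import mlx.core as mx

v = mx.random.normal((64, 3, 3))  # Direction tensor
norm_v = mx.linalg.norm(v.reshape(64, -1), axis=1).reshape(64, 1, 1)
g = norm_v  # Magnitude, typically initialized to the norm of the original weight

w = g * v / norm_v  # Same direction as v, per-row norm equal to g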

Key Features

  • Core C++ implementation of mx.weight_norm with optimized paths for different dimensions
  • Python module weight_norm.py with user-friendly API and layer wrappers
  • Proper handling of MLX's channel ordering differences from PyTorch
  • Workaround for the linalg::norm 2-axes limitation
  • Convenience classes for common layer types (Linear, Conv1d, Conv2d)
  • Comprehensive test suite validating mathematical properties and cross-framework compatibility

Implementation Details

Core C++ Implementation

The core weight_norm operation is implemented with three different paths based on the number of axes to normalize over:

  1. Direct path for 1-2 axes using optimized linalg::norm kernels
  2. Reshape-based approach for >2 axes (see the Python sketch after this list), which:
    • Identifies dimensions to keep vs. normalize
    • Handles special cases: normalizing all dims, keeping one dim, keeping multiple dims
    • Reshapes appropriately to leverage the optimized 2D norm kernel
    • Reshapes results back for broadcasting
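For illustration, a rough Python equivalent of the >2-axis path might look like the following. This is a hypothetical sketch, not the actual C++ code; the helper name is illustrative, and it assumes the norm is computed with the row-wise 2D kernel after flattening:

import math
import mlx.core as mx

def weight_norm_reshape_sketch(v, g, axes):
    # Split dimensions into those we keep and those we normalize over
    keep = [ax for ax in range(v.ndim) if ax not in axes]
    if not keep:
        # Normalizing over all dimensions: a single scalar norm suffices
        norm = mx.linalg.norm(v.reshape(-1)).reshape([1] * v.ndim)
    else:
        # Move kept dimensions to the front, flatten the rest,
        # and compute a row-wise norm with the optimized 2D kernel
        perm = keep + list(axes)
        rows = math.prod(v.shape[ax] for ax in keep)
        flat = mx.transpose(v, perm).reshape(rows, -1)
        norm = mx.linalg.norm(flat, axis=1)
        # Reshape back so the norm broadcasts against v
        norm = norm.reshape([v.shape[ax] if ax in keep else 1 for ax in range(v.ndim)])
    return g * v / norm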

Python Layer

The Python implementation:

  • Provides a weight_norm function that wraps MLX modules
  • Handles dimension ordering differences for different layer types
  • Computes initial g parameter as the norm of the original weight
  • Overrides the module's forward pass to apply weight normalization on-the-fly
  • Includes convenience classes (WeightNormLinear, WeightNormConv1d, WeightNormConv2d)
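A simplified sketch of this wrapper pattern (the class name and initialization details here are illustrative and may not match the PR's exact code):

import mlx.core as mx
import mlx.nn as nn

class WeightNormWrapperSketch(nn.Module):
    def __init__(self, module, axes):
        super().__init__()
        self.module = module
        self.axes = axes
        # Split the original weight into a direction (weight_v) and a magnitude (weight_g),
        # with weight_g initialized to the norm of the original weight
        w = module.weight
        self.weight_v = w
        self.weight_g = mx.linalg.norm(w.reshape(w.shape[0], -1), axis=1).reshape(
            (w.shape[0],) + (1,) * (w.ndim - 1)
        )

    def __call__(self, x):
        # Recompute the normalized weight on the fly, then run the wrapped module
        self.module.weight = mx.weight_norm(self.weight_v, self.weight_g, axes=self.axes)
        return self.module(x)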

Testing and Verification

Testing and verification follow a comprehensive three-part approach:

1. Mathematical Property Tests

  • Verify that the normalized weights have the correct norm (equals g)
  • Confirm that the direction of normalized weights matches v
  • Validate that changing g correctly scales the weight norms
  • Test edge cases like normalizing over all dimensions
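For example, the first two properties can be checked along these lines (a sketch using the core API shown in the usage examples below):

import mlx.core as mx

v = mx.random.normal((64, 3, 3))
g = mx.random.uniform(low=0.5, high=1.5, shape=(64, 1, 1))

w = mx.weight_norm(v, g, axes=[1, 2])

# The norm of each output row equals the corresponding entry of g
row_norms = mx.linalg.norm(w.reshape(64, -1), axis=1)
assert mx.allclose(row_norms, g.reshape(64), atol=1e-5)

# The direction matches v: w / ||w|| == v / ||v||
v_norms = mx.linalg.norm(v.reshape(64, -1), axis=1).reshape(64, 1, 1)
assert mx.allclose(w / row_norms.reshape(64, 1, 1), v / v_norms, atol=1e-5)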

2. Cross-Framework Verification

  • Compare against PyTorch's weight normalization
  • Test both independent implementations and direct weight transfer
  • Document expected differences between frameworks and how to achieve exact equivalence
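As a rough illustration of the weight-transfer check (a sketch, assuming mx.weight_norm reproduces PyTorch's result when given the same v and g; the PyTorch channels-first layout is kept throughout the comparison, so no transpose is needed here):

import numpy as np
import torch
import mlx.core as mx

# PyTorch reference: weight-normalized Conv1d (weight shape (out, in, k))
conv = torch.nn.utils.weight_norm(torch.nn.Conv1d(16, 32, kernel_size=3))
w_ref = conv.weight.detach().numpy()

# Transfer v and g directly and recompute with the new MLX op
v = mx.array(conv.weight_v.detach().numpy())
g = mx.array(conv.weight_g.detach().numpy())
w_mlx = mx.weight_norm(v, g, axes=[1, 2])

assert np.allclose(np.array(w_mlx), w_ref, atol=1e-5)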

3. Performance Benchmarks

Benchmarks on an Apple M3 Max show MLX outperforming PyTorch MPS:

  • Linear layers (1 axis): 4.90x-5.26x speedup
  • Conv1d layers (2 axes): 1.46x-2.05x speedup
  • Conv2d layers (3 axes): 1.50x-1.76x speedup

Usage Examples

Core API

import mlx.core as mx

# Create parameters
v = mx.random.normal((64, 3, 3))  # Direction tensor
g = mx.random.normal((64, 1, 1))  # Magnitude tensor

# Apply weight normalization
w = mx.weight_norm(v, g, axes=[1, 2])

Module API

import mlx.nn as nn
from mlx.nn.layers.weight_norm import weight_norm

# Apply to existing layer
linear = nn.Linear(10, 20)
linear_wn = weight_norm(linear)

# Use convenience class
conv1d_wn = nn.WeightNormConv1d(16, 32, kernel_size=3)
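Calling the wrapped layer then applies the normalization on the fly (illustrative; assumes the standard Linear forward signature):

import mlx.core as mx

x = mx.random.normal((4, 10))
y = linear_wn(x)  # weight is recomputed from weight_v and weight_g on each call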

Resolves #1888.

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

cavit99 added 3 commits March 4, 2025 01:30
Improve the weight normalization implementation by:

  • Use the optimized C++ mx.weight_norm() in WeightNormWrapper.call
  • Add comprehensive tests for WeightNormConv2d
  • Verify direct API usage matches module wrapper results
  • Test normalization over multiple axes and edge cases
  • Add specific test for GitHub issue ml-explore#1888

This change ensures maximum performance by leveraging the C++ implementation with its optimized handling of >2 axes normalization.
Blaizzy commented Mar 4, 2025

Thanks a lot @cavit99, this is great work!

One tiny nit:

  • Could we change the weight naming to weight_g and weight_v? It makes it easier to map from torch and to remember.

Blaizzy commented Mar 4, 2025

As far as I can tell from initial testing, this PR does address my issues.

[Screenshot of test results attached, 2025-03-04]

The only difference is that I preferred using torch channels-first ordering for loading and transposing the weight at run time, because Kokoro has a lot of transpose operations (~35) and I wanted to avoid bugs.

class WeightNormConv1D(nn.Module):
    """Conv1d with weight normalization"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        encode: bool = False,
    ):
        super().__init__()

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Initialize weight magnitude (g) and direction (v) vectors
        self.weight_g = mx.ones((out_channels, 1, 1))  # Scalar magnitude per output channel
        self.weight_v = mx.ones(
            (out_channels, kernel_size, in_channels)
        )  # Direction vectors

        self.bias = mx.zeros(in_channels if encode else out_channels) if bias else None

    def __call__(self, x, conv):

        weight = weight_norm(self.weight_v, self.weight_g, dim=0)

        if self.bias is not None:
            bias = self.bias.reshape(1, 1, -1)
        try:
            ...
            # Input is channels last, need to transpose weight
            return apply_conv(x, weight.T)
        except Exception as e:
            print(f"Error: {e}")
            print(f"x.shape: {x.shape}, weight.shape: {weight.shape}")
            raise e

cavit99 commented Mar 4, 2025

Thanks a lot @cavit99, this is great work!

One tiny nit:

  • Could we change the weight naming to weight_g and weight_v? It makes it easier to map from torch and to remember.

Agreed from my side, so I pushed that change to the PR. Thank you!

Blaizzy commented Mar 4, 2025

Perfect! 🤩

Now we wait for @awni :)

cavit99 commented Mar 4, 2025

He's gonna look and say "meh, maybe if you stick it in normalization.py", isn't he.
