mx cleanup [2/x]: refactor mx gemm #1593

Merged · merged 10 commits into main on Jan 24, 2025

Conversation

@vkuzo (Contributor) commented Jan 21, 2025

Summary:

Refactors the MX gemm emulation code to properly emulate the memory layout
constraints we expect from future MX-enabled hardware (a minimal sketch of
these constraints follows the list):

  • the first argument to the MX gemm is required to be in row-major
    memory format
  • the second argument to the MX gemm is required to be in col-major
    memory format
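
To make the constraint concrete, here is a minimal standalone sketch (not the
code in this PR; `emulated_mx_gemm` and the shapes are hypothetical) of
enforcing these layouts via contiguity checks before falling back to a
high-precision matmul for the emulation:

```
import torch

def emulated_mx_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: enforce the memory layouts we expect from
    # future MX-enabled hardware, then fall back to a regular matmul
    # for the emulated numerics.
    assert a.is_contiguous(), "first mx gemm argument must be row-major"
    assert b.t().is_contiguous(), "second mx gemm argument must be col-major"
    return a @ b

M, K, N = 128, 64, 256
a = torch.randn(M, K)         # row-major (M, K)
b = torch.randn(N, K).t()     # col-major view with shape (K, N)
out = emulated_mx_gemm(a, b)  # shape (M, N)
```

The real emulation additionally handles the MX scales and low-precision
dtypes; the sketch only shows the layout checks.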

Fixes #1501 and #932

Note that two essentially unrelated issues were uncovered by this
refactor:

  1. when autocast is on, compile no longer matches eager numerics.
    Since the "before this PR" state isn't really representative of the
    real world, I'm treating this as a newly uncovered issue, and we can
    fix it in a future PR.
  2. our transpose logic for fp4 data packed two elements per byte doesn't
    work for tensors of shape (M, 1), because we currently rely on
    is_contiguous() to detect whether a tensor was transposed (see the
    short illustration after this list). We could work around this, but
    I'm punting on it until it becomes important; I expect most tensors
    in real-world MX usage not to hit this case.
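
For context on item 2, a short standalone illustration (not code from this
PR) of why an is_contiguous() check cannot detect a transpose once one
dimension has size 1:

```
import torch

# For a non-degenerate shape, contiguity distinguishes the transposed layout:
x = torch.zeros(8, 2, dtype=torch.uint8)   # stand-in for packed fp4 bytes
print(x.is_contiguous())       # True
print(x.t().is_contiguous())   # False -> the transpose is detectable

# For shape (M, 1), both the tensor and its transpose are contiguous,
# so a contiguity check alone cannot tell whether it was transposed:
y = torch.zeros(8, 1, dtype=torch.uint8)
print(y.is_contiguous())       # True
print(y.t().is_contiguous())   # True -> the transpose is not detectable
```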

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

vkuzo added 3 commits January 15, 2025 12:35
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
pytorch-bot bot commented Jan 21, 2025

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1593

@facebook-github-bot added the CLA Signed label Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
@vkuzo added the topic: not user facing label Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
vkuzo added a commit that referenced this pull request Jan 21, 2025
@vkuzo requested review from andrewor14 and drisspg January 21, 2025 23:52
vkuzo added 2 commits January 24, 2025 13:21
[ghstack-poisoned]
[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Jan 24, 2025
@vkuzo changed the base branch from gh/vkuzo/17/head to main January 24, 2025 23:58
@vkuzo merged commit 6b472e5 into main Jan 24, 2025
23 of 24 checks passed
Labels
CLA Signed · topic: not user facing
Development

Successfully merging this pull request may close these issues.

Bug: MXLinear backward pass implementation
3 participants