mx cleanup [2/x]: refactor mx gemm #1593
Merged
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1593
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
vkuzo added a commit that referenced this pull request on Jan 21, 2025
Summary:

Refactors the MX gemm emulation code to properly emulate the memory layout constraints we expect from the future mx-enabled hardware, where we expect:

* the first argument to the mx gemm is required to be in row-major memory format
* the second argument to the mx gemm is required to be in col-major memory format

Note that two morally unrelated issues were uncovered with this refactor:

1. When autocast is on, compile no longer matches eager numerics. Since the "before this PR" state isn't really representative of the world, I'm treating this as a newly uncovered issue, and we can fix it in a future PR.
2. Our transpose logic for fp4 packed two elements per byte doesn't work for tensors of shape (M, 1), because we currently rely on the `is_contiguous()` function to see if our tensor was transposed. We could work around this, but are punting until it becomes important. I expect most tensors in real-world usage with MX to not hit this case.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8d6ca220c947880f5221b284149f5563745a517d
ghstack-comment-id: 2605962974
Pull Request resolved: #1593
drisspg approved these changes on Jan 24, 2025
vkuzo added a commit that referenced this pull request on Jan 24, 2025
Labels: CLA Signed, topic: not user facing
Summary:

Refactors the MX gemm emulation code to properly emulate the memory layout constraints we expect from the future mx-enabled hardware, where we expect:

* the first argument to the mx gemm is required to be in row-major memory format
* the second argument to the mx gemm is required to be in col-major memory format

Fixes #1501 and #932
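For illustration, here is a minimal sketch (not the torchao implementation) of an emulated gemm wrapper that asserts this layout contract; the function name and the plain float tensors standing in for MX operands are assumptions made for the example:

```python
import torch

# Hypothetical wrapper illustrating the layout contract described above:
# the first operand must be row-major, the second col-major.
def emulated_mx_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a: (M, K) row-major, i.e. contiguous
    assert a.is_contiguous(), "first mx gemm argument must be row-major"
    # b: (K, N) col-major, i.e. its transpose is contiguous
    assert b.t().is_contiguous(), "second mx gemm argument must be col-major"
    # a plain float matmul stands in for the actual MX emulation
    return torch.mm(a, b)

a = torch.randn(16, 32)                      # row-major by construction
b = torch.randn(32, 8).t().contiguous().t()  # materialize a col-major (32, 8) tensor
out = emulated_mx_gemm(a, b)                 # shape (16, 8)
```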
Note that two morally unrelated issues were uncovered with this refactor:

1. When autocast is on, compile no longer matches eager numerics. Since the "before this PR" state isn't really representative of the world, I'm treating this as a newly uncovered issue, and we can fix it in a future PR.
2. Our transpose logic for fp4 packed two elements per byte doesn't work for tensors of shape (M, 1), because we currently rely on the `is_contiguous()` function to see if our tensor was transposed. We could work around this, but are punting until it becomes important; I expect most tensors in real-world usage with MX to not hit this case. A small sketch illustrating this limitation follows the list.
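As a tiny sketch of why `is_contiguous()` gives no signal for (M, 1) tensors, consider the following, where a plain uint8 tensor stands in for the packed fp4 data:

```python
import torch

M = 4
packed = torch.arange(M, dtype=torch.uint8).reshape(M, 1)

# Both the (M, 1) tensor and its (1, M) transpose report as contiguous,
# so a transpose cannot be detected from contiguity alone.
print(packed.is_contiguous())      # True
print(packed.t().is_contiguous())  # True

# In the general case the check does work: a transposed (M, 2) tensor is
# not contiguous, which is the signal the current packing logic relies on.
packed_2col = torch.arange(2 * M, dtype=torch.uint8).reshape(M, 2)
print(packed_2col.t().is_contiguous())  # False
```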
Test Plan:
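```
pytest test/prototype/mx_formats/ -s -x
```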
Reviewers:
Subscribers:
Tasks:
Tags: