improve memory footprint of torch.testing.assert_close #96131
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/96131
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 1 Pending (as of commit 8e1f3aa)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -1008,28 +1013,10 @@ def _compare_regular_values_close(
            )
        else:
            msg = make_tensor_mismatch_msg(
                actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier
In the error case, we created the `mismatches = ~matches` tensor here and turned it back into a `matches` inside `make_tensor_mismatch_msg`. With a minor refactoring, we no longer need to invert `matches` and can use it directly.
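A minimal sketch of the idea (hypothetical simplification, not the exact PR code): previously the caller materialized an inverted copy of the mask and the helper immediately inverted it back, so two temporary bool tensors were allocated; passing `matches` through directly avoids both.

```python
import torch

# Before (roughly): caller passes ~matches, helper inverts it right back.
def make_tensor_mismatch_msg_old(actual, expected, mismatches):
    matches = ~mismatches  # second temporary bool tensor
    total = mismatches.numel()
    return f"Mismatched elements: {total - int(matches.sum())} / {total}"

# After: consume the matches mask directly; no inversion on either side.
def make_tensor_mismatch_msg_new(actual, expected, matches):
    total = matches.numel()
    return f"Mismatched elements: {total - int(matches.sum())} / {total}"
```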
@@ -991,7 +997,6 @@ def _compare_regular_values_close(
        identifier: Optional[Union[str, Callable[[str], str]]] = None,
    ) -> None:
        """Checks if the values of two tensors are close up to a desired tolerance."""
-       actual, expected = self._promote_for_comparison(actual, expected)
We unconditionally upcasted here in the past, since that was needed for `isclose`. This is no longer the case, so we can just drop that.
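For instance (an illustrative check, assuming a reasonably recent PyTorch build), `torch.isclose` accepts integer inputs directly:

```python
import torch

# torch.isclose handles integer dtypes without manual promotion,
# so no upcast is needed before the closeness check itself.
a = torch.tensor([1, 2, 3], dtype=torch.int8)
b = torch.tensor([1, 2, 4], dtype=torch.int8)
print(torch.isclose(a, b))  # tensor([ True,  True, False])
```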
if not actual.dtype.is_floating_point and not actual.dtype.is_complex:
    # TODO: Instead of always upcasting to int64, it would be sufficient to cast
    #  to the next higher dtype to avoid overflow
    actual_flat = actual_flat.to(torch.int64)
    expected_flat = expected_flat.to(torch.int64)
However, we still need to upcast in the error case, since we want to display the absolute diff, and that is not supported for `torch.bool` and might overflow for other integer dtypes.
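Both failure modes are easy to demonstrate (illustrative snippet, not code from the PR):

```python
import torch

# Subtraction is not defined for bool tensors at all...
try:
    torch.tensor([True]) - torch.tensor([False])
except RuntimeError as exc:
    print(exc)  # PyTorch suggests using ^ or logical_xor() instead

# ...and narrow integer dtypes can wrap around, producing a bogus diff.
a = torch.tensor([127], dtype=torch.int8)
b = torch.tensor([-128], dtype=torch.int8)
print(a - b)                                  # tensor([-1], dtype=torch.int8): overflowed
print(a.to(torch.int64) - b.to(torch.int64))  # tensor([255]): the true difference
```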
actual_flat = actual.flatten()
expected_flat = expected.flatten()
Drive-by renaming: `a` and `b` were only used in the beginning and should be `actual` and `expected` now.
Redo of #90172 out of stack. [ghstack-poisoned]
LGTM! Thanks, @pmeier!
I have an OT feature request, so feel free to ignore it.
# Ensure that only mismatches are used for the max_abs_diff computation
abs_diff[matches_flat] = 0
max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)

-   rel_diff = abs_diff / torch.abs(b_flat)
+   rel_diff = abs_diff / torch.abs(expected_flat)
A slight OT suggestion: could we have a better normalization factor here (say, `(torch.abs(actual) + torch.abs(expected)) / 2`) for the case where `expected` contains zeros (having zeros is typical, say when comparing the indices of sparse tensors)? Atm, the mismatch messages from `assert_close` depend on the order of inputs, for example:
>>> torch.testing.assert_close(torch.tensor([1, 0]), torch.tensor([1, 1]))
<snip>
Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: 1.0 at index (1,)
>>> torch.testing.assert_close(torch.tensor([1, 1]), torch.tensor([1, 0]))
<snip>
Mismatched elements: 1 / 2 (50.0%)
Greatest absolute difference: 1 at index (1,)
Greatest relative difference: inf at index (1,)
(btw, reporting relative differences for non-float tensors is often pointless as well).
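For illustration, a sketch of what the suggested symmetric normalization could look like (hypothetical helper, not part of `assert_close`):

```python
import torch

def symmetric_rel_diff(actual, expected):
    # Dividing by the mean magnitude keeps the result finite when one side
    # is zero and makes the value independent of the argument order.
    abs_diff = (actual - expected).abs().to(torch.float64)
    norm = (actual.abs() + expected.abs()).to(torch.float64) / 2
    return abs_diff / norm

a, b = torch.tensor([1, 0]), torch.tensor([1, 1])
print(symmetric_rel_diff(a, b))  # tensor([0., 2.])
print(symmetric_rel_diff(b, a))  # tensor([0., 2.]), same either way
```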
> Atm, the mismatch messages from `assert_close` depend on the order of inputs, for example:
It's not just the messages, it is the actual op. Internally, we rely on `torch.isclose` and that is already asymmetric. It defines closeness as `abs(actual - expected) <= atol + rtol * abs(expected)`. Believe me when I say, we (the `torch.testing` team) wanted to change that, but there is just too much inertia. PyTorch is not an outlier here; `numpy` (and virtually every other array library) is doing the same.
Python's `math` module does the more sensible thing, defining closeness as `abs(actual - expected) <= max(atol, rtol * max(abs(actual), abs(expected)))`. You can read more about this whole issue in PEP 485.
At some point we tried to get this behavior specified by the Array API, but couldn't gain enough traction. See data-apis/array-api#170.
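To make the asymmetry concrete (an illustrative comparison, not code from this PR):

```python
import math
import torch

# torch.isclose scales the tolerance by the magnitude of the *second* argument,
# so swapping the inputs can flip the verdict near the tolerance boundary.
print(torch.isclose(torch.tensor(100.0), torch.tensor(95.0), rtol=0.05, atol=0.0))  # tensor(False): 5 > 0.05 * 95
print(torch.isclose(torch.tensor(95.0), torch.tensor(100.0), rtol=0.05, atol=0.0))  # tensor(True): 5 <= 0.05 * 100

# math.isclose uses the larger magnitude of the two inputs, so it is symmetric.
print(math.isclose(100.0, 95.0, rel_tol=0.05))  # True
print(math.isclose(95.0, 100.0, rel_tol=0.05))  # True
```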
> (btw, reporting relative differences for non-float tensors is often pointless as well).
Doesn't that somewhat contradict the use case you gave earlier?

> for the case where `expected` contains zeros (having zeros is typical, say when comparing the indices of sparse tensors)
Rebased
@pytorchbot merge -r

@pytorchbot successfully started a rebase job. Check the current status here
Successfully rebased
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / macos-12-py3-arm64-mps / test (default, 1, 1). (Details for Dev Infra team: raised by workflow job.)
test/test_mps.py (outdated)
@@ -10070,7 +10070,7 @@ def test_mps_compat(self):
    # If this test is successful, that means that all operations in the comparison logic are supported natively on
    # the MPS backend. Please remove this test as well as the compatibility logic in
    # torch.testing._comparison.TensorLikePair._equalize_attributes
-   actual = torch.tensor(1.0, device="mps")
+   actual = torch.zeros(2, 3, 4, 5, device="mps")
@kulinseth I've increased the shape to 4 dimensions here, because otherwise this test would pass although `torch.testing.assert_close` is not ready. See #95538 for details.
Sounds good
@kulinseth It seems the test is still passing: https://hud.pytorch.org/pr/96131#12075596190. Does that mean the behavior was fixed? Otherwise, could you send me a patch that consistently makes this test fail? I don't have access to an MPS machine and don't want to waste CI resources by pushing multiple times just for this one test.
@kulinseth any update on this?
Sorry for the delay @pmeier. I think we have support up to 4 dims; if we increase the dimensions to 5, then the test starts failing.
Argh, my bad. Let me fix that.
@kulinseth Test fails now, but unfortunately the error is not recoverable and thus we never hit the xfail. Not sure what to do with this. I'll remove the test to unblock and leave a comment in #95538. LMK if you want to handle it differently.
@pytorchbot merge -r viable/strict

@pytorchbot successfully started a rebase job. Check the current status here
Successfully rebased |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: GraphQL query failed (fragment PRCheckSuites on CheckSuiteConnection ...). (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Redo of #90172 out of stack.