improve memory footprint of torch.testing.assert_close
ghstack-source-id: ec7cd022806cea09dfd1cd4e1e91477d4d5dedf4
Pull Request resolved: #96131
pmeier committed Mar 29, 2023
1 parent 4ae4c6f commit 3c3216a
Showing 2 changed files with 37 additions and 53 deletions.
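Background sketch (not part of the commit page): before this change, TensorLikePair._compare_regular_values_close routed both tensors through _promote_for_comparison (removed below), upcasting float32 to float64, complex64 to complex128, and bool/integer dtypes to int64 before calling torch.isclose, so every comparison temporarily held full-size promoted copies of both operands. The snippet below is a rough illustration with made-up sizes, not code from the commit.

import torch

# Hypothetical inputs; the shapes are made up purely to illustrate the footprint.
actual = torch.zeros(16, 1024, 1024, dtype=torch.float32)   # 64 MiB
expected = torch.zeros_like(actual)                          # 64 MiB

def mib(t: torch.Tensor) -> float:
    return t.numel() * t.element_size() / 2**20

# Old path: torch.isclose ran on temporary float64 copies of both operands.
print(mib(actual) + mib(expected))                # 128.0 MiB of inputs ...
print(2 * mib(actual.to(torch.float64)))          # ... plus 256.0 MiB of promoted copies

# New path: torch.isclose runs directly on the native dtype; casting now only
# happens on the failure path while building the mismatch message.
matches = torch.isclose(actual, expected, rtol=1.3e-6, atol=1e-5)  # bool mask, 16 MiB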
13 changes: 0 additions & 13 deletions test/test_mps.py
@@ -10082,19 +10082,6 @@ def test_assert_close(self):
# with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
# torch.testing.assert_close(a, nan)

-@unittest.expectedFailure
-def test_mps_compat(self):
-    # If this test is successful, that means that all operations in the comparison logic are supported natively on
-    # the MPS backend. Please remove this test as well as the compatibility logic in
-    # torch.testing._comparison.TensorLikePair._equalize_attributes
-    actual = torch.tensor(1.0, device="mps")
-    expected = actual.clone()
-
-    # We can't use assert_close or TensorLikePair.compare() directly, since that would hit the compatibility logic
-    # in torch.testing._comparison.TensorLikePair._equalize_attributes that we want to circumvent here
-    pair = TensorLikePair(actual, expected)
-    pair._compare_values(actual, expected)
-
def test_double_error(self):
with self.assertRaisesRegex(TypeError, "the MPS framework doesn't support float64"):
a = torch.ones(2, dtype=torch.float64, device="mps")
77 changes: 37 additions & 40 deletions torch/testing/_comparison.py
@@ -205,8 +205,8 @@ def make_diff_msg(


def make_scalar_mismatch_msg(
-actual: Union[int, float, complex],
-expected: Union[int, float, complex],
+actual: Union[bool, int, float, complex],
+expected: Union[bool, int, float, complex],
*,
rtol: float,
atol: float,
@@ -215,8 +215,8 @@ def make_scalar_mismatch_msg(
"""Makes a mismatch error message for scalars.
Args:
-actual (Union[int, float, complex]): Actual scalar.
-expected (Union[int, float, complex]): Expected scalar.
+actual (Union[bool, int, float, complex]): Actual scalar.
+expected (Union[bool, int, float, complex]): Expected scalar.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
@@ -239,7 +239,7 @@ def make_scalar_mismatch_msg(
def make_tensor_mismatch_msg(
actual: torch.Tensor,
expected: torch.Tensor,
-mismatches: torch.Tensor,
+matches: torch.Tensor,
*,
rtol: float,
atol: float,
@@ -250,8 +250,8 @@
Args:
actual (torch.Tensor): Actual tensor.
expected (torch.Tensor): Expected tensor.
-mismatches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
-    location of mismatches.
+matches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
+    location of matches.
rtol (float): Relative tolerance.
atol (float): Absolute tolerance.
identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
@@ -260,34 +260,40 @@
"""

def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
-if not mismatches.shape:
+if not matches.shape:
return ()

inverse_index = []
-for size in mismatches.shape[::-1]:
+for size in matches.shape[::-1]:
div, mod = divmod(flat_index, size)
flat_index = div
inverse_index.append(mod)

return tuple(inverse_index[::-1])

-number_of_elements = mismatches.numel()
-total_mismatches = torch.sum(mismatches).item()
+number_of_elements = matches.numel()
+total_mismatches = number_of_elements - int(torch.sum(matches))
extra = (
f"Mismatched elements: {total_mismatches} / {number_of_elements} "
f"({total_mismatches / number_of_elements:.1%})"
)

-a_flat = actual.flatten()
-b_flat = expected.flatten()
-matches_flat = ~mismatches.flatten()
+actual_flat = actual.flatten()
+expected_flat = expected.flatten()
+matches_flat = matches.flatten()

-abs_diff = torch.abs(a_flat - b_flat)
+if not actual.dtype.is_floating_point and not actual.dtype.is_complex:
+    # TODO: Instead of always upcasting to int64, it would be sufficient to cast to the next higher dtype to avoid
+    # overflow
+    actual_flat = actual_flat.to(torch.int64)
+    expected_flat = expected_flat.to(torch.int64)
+
+abs_diff = torch.abs(actual_flat - expected_flat)
# Ensure that only mismatches are used for the max_abs_diff computation
abs_diff[matches_flat] = 0
max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)

-rel_diff = abs_diff / torch.abs(b_flat)
+rel_diff = abs_diff / torch.abs(expected_flat)
# Ensure that only mismatches are used for the max_rel_diff computation
rel_diff[matches_flat] = 0
max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
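Aside (not from the commit): the int64 cast added in the hunk above guards against plain integer arithmetic wrapping around when the flattened tensors are subtracted for the error message. A tiny illustration with made-up values:

import torch

actual = torch.tensor([0], dtype=torch.uint8)
expected = torch.tensor([200], dtype=torch.uint8)

# Without the cast, the subtraction wraps around in uint8 and reports a bogus difference.
print(torch.abs(actual - expected))                                   # tensor([56], dtype=torch.uint8)

# Casting to int64 first, as the new code does for integral dtypes, gives the true difference.
print(torch.abs(actual.to(torch.int64) - expected.to(torch.int64)))   # tensor([200])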
@@ -948,14 +954,24 @@ def _compare_sparse_compressed_values(
),
)

+# Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formats can be `torch.int32` _or_
+# `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it
+# can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will
+# fail.
+actual_compressed_indices = compressed_indices_method(actual)
+expected_compressed_indices = compressed_indices_method(expected)
+indices_dtype = torch.promote_types(
+    actual_compressed_indices.dtype, expected_compressed_indices.dtype
+)
+
self._compare_regular_values_equal(
-compressed_indices_method(actual),
-compressed_indices_method(expected),
+actual_compressed_indices.to(indices_dtype),
+expected_compressed_indices.to(indices_dtype),
identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
)
self._compare_regular_values_equal(
-plain_indices_method(actual),
-plain_indices_method(expected),
+plain_indices_method(actual).to(indices_dtype),
+plain_indices_method(expected).to(indices_dtype),
identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
)
self._compare_regular_values_close(
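Aside (not from the commit): the compressed and plain indices may legitimately be int32 in one tensor and int64 in the other, and torch.promote_types picks the common dtype that both sides are converted to before the equality check. A minimal sketch with made-up CSR tensors:

import torch

# Two CSR tensors holding the same data but with different index dtypes.
a = torch.eye(3).to_sparse_csr()                      # int64 indices by default
b = torch.sparse_csr_tensor(
    a.crow_indices().to(torch.int32),
    a.col_indices().to(torch.int32),
    a.values(),
    size=a.shape,
)

indices_dtype = torch.promote_types(a.crow_indices().dtype, b.crow_indices().dtype)
print(indices_dtype)  # torch.int64

# Comparing a.crow_indices().to(indices_dtype) against b.crow_indices().to(indices_dtype)
# sees matching dtypes, so the equality check no longer fails on the dtype alone.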
@@ -991,7 +1007,6 @@ def _compare_regular_values_close(
identifier: Optional[Union[str, Callable[[str], str]]] = None,
) -> None:
"""Checks if the values of two tensors are close up to a desired tolerance."""
-actual, expected = self._promote_for_comparison(actual, expected)
matches = torch.isclose(
actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
)
@@ -1008,28 +1023,10 @@
)
else:
msg = make_tensor_mismatch_msg(
-actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier
+actual, expected, matches, rtol=rtol, atol=atol, identifier=identifier
)
self._fail(AssertionError, msg)

-def _promote_for_comparison(
-    self, actual: torch.Tensor, expected: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Promotes the inputs to the comparison dtype based on the input dtype.
-    Returns:
-        Inputs promoted to the highest precision dtype of the same dtype category. :class:`torch.bool` is treated
-        as integral dtype.
-    """
-    # This is called after self._equalize_attributes() and thus `actual` and `expected` already have the same dtype.
-    if actual.dtype.is_complex:
-        dtype = torch.complex128
-    elif actual.dtype.is_floating_point:
-        dtype = torch.float64
-    else:
-        dtype = torch.int64
-    return actual.to(dtype), expected.to(dtype)
-
def extra_repr(self) -> Sequence[str]:
return (
"rtol",

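One more footprint detail visible in the hunks above: the old code inverted the boolean mask twice (~matches at the call site and ~mismatches.flatten() inside the message builder), and each inversion allocates another full-size bool tensor. Passing matches straight through avoids both. Illustrative numbers only:

import torch

# Hypothetical mask from a comparison of ~16.8M elements.
matches = torch.ones(16, 1024, 1024, dtype=torch.bool)

# Each inversion materializes another bool tensor of the same size (~16 MiB here).
inverted = ~matches
print(inverted.numel() * inverted.element_size() / 2**20)  # 16.0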