Skip to content

Commit

Permalink
check scale.ndim before applying t/transpose (#1339)
Browse files Browse the repository at this point in the history
* check `scale.ndim` before applying `t`/`transpose`

because (a) `scale` could be 0D/1D, in which case `transpose` is not
applicable, and (b) the args and kwargs of `torch.ops.aten.transpose.int`
would supply `dim0` and `dim1`, causing dim canonicalization to fail.
e.g. [`torch._prims_common.canonicalize_dims`](https://github.com/pytorch/pytorch/blob/07906f2/torch/_prims_common/__init__.py#L704)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* add test of `.t()` and `.transpose(0, 1)`

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

* change cond to transpose scale

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>

---------

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar authored Dec 2, 2024
1 parent 2e36daa commit 65b885f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
19 changes: 19 additions & 0 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,25 @@ def test_copy_(self):
fp8_b.copy_(fp8_a)
torch.testing.assert_close(fp8_a._data, fp8_b._data)

def test_transpose(self):
    """Check that ``.t()`` and ``.transpose(0, 1)`` agree on a Float8 tensor.

    Covers tensorwise (``axiswise_dim=None``) and axiswise (dim 0 and -1)
    scaling: both transpose paths must yield identical data and scale.
    """
    src = torch.rand((16, 16), dtype=torch.bfloat16)
    for axiswise_dim in (None, 0, -1):
        scale = tensor_to_scale(src, e4m3_dtype)
        # Build two identical float8 tensors so each can go through a
        # different transpose entry point.
        via_transpose = hp_tensor_and_scale_to_float8(
            src, scale, e4m3_dtype, axiswise_dim=axiswise_dim
        ).transpose(0, 1)
        via_t = hp_tensor_and_scale_to_float8(
            src, scale, e4m3_dtype, axiswise_dim=axiswise_dim
        ).t()

        torch.testing.assert_close(
            (via_transpose._data, via_transpose._scale),
            (via_t._data, via_t._scale),
        )

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("axiswise_dim", [0, -1])
def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
Expand Down
5 changes: 4 additions & 1 deletion torchao/float8/float8_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ def float8_desugar_data_and_scale_op(aten_op, args, kwargs=None):
)
def float8_transpose(aten_op, args, kwargs=None):
    # Dispatch handler for transpose-like aten ops on a Float8 tensor:
    # applies the op to the underlying data and, when meaningful, to the
    # scale. (NOTE(review): function is truncated in this view — the
    # construction of the resulting Float8 tensor continues below.)
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    # Only transpose the scale when it has 2+ dims. A 0-D (tensorwise) or
    # 1-D scale is left untouched: transposing it is meaningless, and for
    # aten.transpose.int the forwarded dim0/dim1 arguments would make dim
    # canonicalization fail on a low-rank tensor
    # (see torch._prims_common.canonicalize_dims).
    if args[0]._scale.ndim > 1:
        new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
    else:
        new_scale = args[0]._scale

    if aten_op == aten.transpose.int:
        _assert_tensorwise_scale(aten_op, args[0]._scale)
Expand Down

0 comments on commit 65b885f

Please sign in to comment.