support bfloat16
akihironitta committed Jul 29, 2024
1 parent 2556c97 commit 1a71007
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion pyg_lib/ops/__init__.py
@@ -187,7 +187,12 @@ def _(inputs, ptr, other):
     assert ptr.dim() == 1
     assert other.dim() == 3
     assert ptr.size() == (other.size(0) + 1, )
-    return torch.empty(inputs.size(0), other.size(2), device=inputs.device)
+    return torch.empty(
+        inputs.size(0),
+        other.size(2),
+        device=inputs.device,
+        dtype=inputs.dtype,
+    )
 
 
 def sampled_add(
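For context, the function above is the fake (meta) implementation registered for pyg_lib's segment_matmul operator: it only describes the output's metadata (shape, device, dtype) without running the real kernel. Before this commit it omitted dtype=, so torch.empty defaulted to float32 and FakeTensor tracing reported the wrong dtype for bfloat16 inputs. Below is a minimal, self-contained sketch of the same pattern using PyTorch 2.4's custom-op API; the op mylib::scale and its body are made up for illustration and are not pyg_lib code.

import torch
from torch.library import custom_op, register_fake, opcheck

# Hypothetical toy op; "mylib::scale" is a made-up name used only to
# show why a fake implementation must propagate dtype.
@custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

@register_fake("mylib::scale")
def _(x):
    # Returning torch.empty(x.shape, device=x.device) without dtype= would
    # silently default to torch.float32, and opcheck would then report a
    # dtype mismatch between real and fake outputs for bfloat16 inputs.
    return torch.empty(x.shape, device=x.device, dtype=x.dtype)

x = torch.randn(4, dtype=torch.bfloat16)
opcheck(torch.ops.mylib.scale.default, (x, ))  # passes with dtype propagated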
7 changes: 6 additions & 1 deletion test/ops/test_matmul.py
@@ -43,11 +43,16 @@ def test_segment_matmul_autograd(dtype, device):
     pytest.param(torch.float32, id='float32'),
     pytest.param(torch.bfloat16, id='bfloat16'),
 ])
-@pytest.mark.parametrize('requires_grad', [False, True])
+@pytest.mark.parametrize('requires_grad', [
+    pytest.param(False, id='requires_grad_False'),
+    pytest.param(True, id='requires_grad_True'),
+])
 @pytest.mark.skipif(not _WITH_PT24, reason='PyTorch 2.4.0 is required')
 def test_segment_matmul_opcheck(device, dtype, requires_grad):
     if requires_grad:
         pytest.skip('TODO: Support requires_grad=True')
+    if device.type == 'cuda' and dtype == torch.bfloat16:
+        pytest.skip('CUDA does not support bfloat16')
 
     from torch.library import opcheck
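The body of test_segment_matmul_opcheck is truncated by the diff view; it calls torch.library.opcheck, which runs the real kernel and the fake implementation side by side and verifies that the schema, autograd registration, and output metadata (shape, device, and, with this commit, dtype) agree. A hedged sketch of such a call, assuming the operator is exposed as torch.ops.pyg.segment_matmul with the (inputs, ptr, other) signature seen in the first file; the concrete shapes are illustrative only.

import torch
from torch.library import opcheck

device = torch.device('cpu')  # stand-ins for the pytest fixtures above
dtype = torch.bfloat16

# Assumed layout: inputs is [num_rows, k], other is [num_segments, k, n],
# and ptr holds num_segments + 1 increasing row offsets into inputs,
# consistent with the assert ptr.size() == (other.size(0) + 1, ) above.
inputs = torch.randn(8, 16, device=device, dtype=dtype)
ptr = torch.tensor([0, 5, 8], device=device)
other = torch.randn(2, 16, 32, device=device, dtype=dtype)

opcheck(torch.ops.pyg.segment_matmul.default, (inputs, ptr, other))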
