[backend] 3/3 Triton 2 update #272
```diff
@@ -19,7 +19,7 @@
 def _create_blocksparse_tensor(
     device, block_size=32, Z=8, C=2, H=64, W=64, dtype=torch.float32
 ):
-    layout = torch.randint(2, (C, H // block_size, W // block_size))
+    layout = torch.randint(2, (C, H // block_size, W // block_size), device=device)
     layout[:, :, 0] = 1
     layout[:, 0, :] = 1
     values = torch.randn(Z, layout.sum(), block_size, block_size, device=device).to(
```
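The change above creates `layout` directly on the target device. As a side note (a minimal sketch, not from the PR), this has two effects: the layout can no longer collide with CUDA tensors in cross-device ops, and the random draw comes from the CUDA RNG stream rather than the CPU one:

```python
import torch

if torch.cuda.is_available():
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    cpu_layout = torch.randint(2, (2, 4, 4))                 # CPU RNG stream, CPU tensor
    gpu_layout = torch.randint(2, (2, 4, 4), device="cuda")  # CUDA RNG stream, CUDA tensor
    values = torch.randn(2, 4, 4, device="cuda")

    try:
        _ = values * cpu_layout  # mixing devices raises a RuntimeError
    except RuntimeError as err:
        print(err)

    _ = values * gpu_layout  # same device: works
```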
```diff
@@ -56,6 +56,29 @@ def _create_tensor(tensor_type, device, dtype, shape, sparsity):
     )


+def _seed():
+    torch.random.manual_seed(42)
+    torch.cuda.manual_seed_all(42)
+
+
+def _get_dtype_atol(tensor_type, device: str):
+    _seed()
+
+    if tensor_type == BlockSparseTensor and "cuda" in device:
+        # Upstream GPU blocksparse (Triton op) uses TF32 by default for all internal computations
+        # TF32 has the precision of fp16 but the range of fp32
+        # See https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
+        torch.backends.cuda.matmul.allow_tf32 = True
```
**Review thread:**

> @fmassa this seems to be a better fit following the switch to triton2, which internally moved all `tl.dot()` operations to tf32

> cc @ptillet, just swapping triton 1.1 for 2.dev meant that this test would not pass anymore, as we discussed

> SGTM wrt the tests!
```diff
+        torch.backends.cudnn.allow_tf32 = True
+        return torch.float32, 1e-1
```
**Review comment:**

> wow, that is quite some low precision...
```diff
+
+    # Force pytorch to keep its computations as float32 (will default to tf32 with recent cuda and ampere+ GPU)
+    torch.backends.cuda.matmul.allow_tf32 = False
```
**Review thread:**

> @fmassa this fixed issues that I was seeing with these unit tests on an Ampere GPU, which I presume stemmed from the fact that the sparse kernels were fp32 while pytorch defaulted to tf32

> oh wow, thanks for spotting this! One more instance where tf32 is being somewhat harmful. Maybe worth commenting on pytorch/pytorch#67384?

> It's a strange format: the range of fp32 but the precision of fp16. It's also kind of peculiar that it's really 19 bits but named tf32...
```diff
+    torch.backends.cudnn.allow_tf32 = False
+
+    return torch.float32, 1e-5
+
+
 @pytest.mark.parametrize("device", _devices)
 @pytest.mark.parametrize("func", [torch.add, torch.mul])
 def test_sparse_binary_ops(func, device):
```
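As a hedged illustration of the tolerance gap `_get_dtype_atol` encodes (assuming an Ampere or newer GPU; the behavior and numbers are indicative, not from the PR), toggling the same backend flag shifts matmul error by orders of magnitude:

```python
import torch

if torch.cuda.is_available():
    torch.manual_seed(42)
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    ref = (a.double() @ b.double()).float()  # float64 reference, cast back to fp32

    for allow_tf32 in (False, True):
        torch.backends.cuda.matmul.allow_tf32 = allow_tf32
        err = (a @ b - ref).abs().max().item()
        print(f"allow_tf32={allow_tf32}: max abs error = {err:.1e}")

    # TF32 keeps fp32's 8-bit exponent but only a 10-bit mantissa, so expect the
    # error to grow by orders of magnitude with the flag on - hence atol=1e-1
    # for the TF32 blocksparse path vs 1e-5 for strict fp32 above.
```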
```diff
@@ -83,6 +106,7 @@ def test_sparse_binary_ops(func, device):
 def test_masked_matmul(tensor_type, device):
     N, C, H, W, L = 8, 2, 64, 64, 32
     sparsity = 0.7
+    dtype, atol = _get_dtype_atol(tensor_type, device)

     shape0 = (N, C, H, W)
     shape1 = (N, C, H, L)
```
```diff
@@ -98,8 +122,8 @@ def test_masked_matmul(tensor_type, device):
     )
     mask = mask_sparse.to_dense()

-    a = torch.randn(shape1, device=device)
-    b = torch.randn(shape2, device=device)
+    a = torch.randn(shape1, device=device, dtype=dtype)
+    b = torch.randn(shape2, device=device, dtype=dtype)

     aa = a.clone()
     bb = b.clone()
```
```diff
@@ -119,24 +143,23 @@ def test_masked_matmul(tensor_type, device):
     res_dense = torch.where(mask, res_dense, torch.full_like(res_dense, float("-inf")))

     assert res.dtype == res_gt.dtype
-    assert torch.allclose(res_dense, res_gt, atol=5e-6)
+    assert torch.allclose(res_dense, res_gt, atol=atol)

     # try to workaround non-contiguous issues with triton for now
     res_gt.backward(torch.ones_like(res_gt))
     res.values().backward(torch.ones_like(res.values()))
-    # TODO: this is not passing for BlockSparse!!!
-    if tensor_type != BlockSparseTensor:
-        assert torch.allclose(a.grad, aa.grad, atol=5e-6)
-        assert torch.allclose(b.grad, bb.grad, atol=5e-6)
+
+    assert torch.allclose(a.grad, aa.grad, atol=atol)
+    assert torch.allclose(b.grad, bb.grad, atol=atol)


 @pytest.mark.parametrize("tensor_type", _tensor_types)
 @pytest.mark.parametrize("device", _devices)
 def test_bmm(tensor_type, device):
     N, C, H, W, L = 8, 2, 64, 64, 32
-    dtype = torch.float32
-    sparsity = 0.8
+    dtype, atol = _get_dtype_atol(tensor_type, device)
+
+    sparsity = 0.8
     shape0 = (N, C, H, W)
     shape1 = (N, C, W, L)
```
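The rewritten asserts attach the maximum absolute difference to the failure message, which makes CI failures diagnosable at a glance. A small self-contained sketch of the same pattern (the tensors here are illustrative, not from the tests):

```python
import torch

expected = torch.tensor([1.0, 2.0, 3.0])
actual = expected + 1e-6  # deviation well inside the tolerance

atol = 1e-5
assert torch.allclose(
    actual, expected, atol=atol
), f"{torch.max(torch.abs(actual - expected))} - tolerance: {atol}"
```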
```diff
@@ -153,7 +176,7 @@ def test_bmm(tensor_type, device):
     a_sparse.requires_grad_(True)
     a.requires_grad_(True)

-    b = torch.randn(shape1, device=device)
+    b = torch.randn(shape1, device=device, dtype=dtype)
     b2 = b.clone()

     b.requires_grad_(True)
```
```diff
@@ -163,23 +186,28 @@ def test_bmm(tensor_type, device):
     res = a_sparse @ b2

     assert res.dtype == res_gt.dtype
-    assert torch.allclose(res, res_gt, atol=1e-5)
+    assert torch.allclose(
+        res, res_gt, atol=atol
+    ), f"{torch.max(torch.abs(res - res_gt))} - tolerance: {atol}"

     res_gt.sum().backward()
     res.sum().backward()

     a_grad = a.grad.clone().detach()
     a_grad[~mask] = 0

-    assert torch.allclose(b.grad, b2.grad, atol=1e-5)
-    assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-5)
+    assert torch.allclose(b.grad, b2.grad, atol=atol)
+    assert torch.allclose(
+        a_grad, a_sparse.grad.to_dense(), atol=atol
+    ), f"{torch.max(torch.abs(a_grad - a_sparse.grad.to_dense()))}"


 @pytest.mark.parametrize("tensor_type", _tensor_types)
 @pytest.mark.parametrize("device", _devices)
 def test_sparse_softmax(tensor_type, device):
     N, C, H, W = 8, 2, 64, 64
-    dtype = torch.float32
+    dtype, atol = _get_dtype_atol(tensor_type, device)

     sparsity = 0.8

     shape0 = (N, C, H, W)
```
```diff
@@ -203,7 +231,9 @@ def test_sparse_softmax(tensor_type, device):
     res = res_sparse.to_dense()

     assert res.dtype == res_gt.dtype
-    assert torch.allclose(res, res_gt)
+    assert torch.allclose(
+        res, res_gt, atol=atol
+    ), f"{torch.max(torch.abs(res - res_gt))}"

     # WARNING: gradients are modified in-place!
     res_sparse.values().backward(torch.ones_like(res_sparse.values()))
```
```diff
@@ -212,7 +242,9 @@ def test_sparse_softmax(tensor_type, device):
     a_grad = a.grad.clone()
     a_grad[~mask] = 0

-    assert torch.allclose(a_grad, a_sparse.grad.to_dense(), atol=1e-6)
+    assert torch.allclose(
+        a_grad, a_sparse.grad.to_dense(), atol=atol
+    ), f"{torch.max(torch.abs(a_grad - a_sparse.grad.to_dense()))}"


 @pytest.mark.parametrize("tensor_type", _tensor_types)
```
**Review comment:**

> this was to remove some reproducibility issues between CircleCI and my machine...
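A brief sketch of what the added `_seed()` buys (illustrative, not from the PR): pinning both the CPU and CUDA RNG state makes the randomly generated layouts and values identical across machines, so a tolerance tuned locally also holds on CI:

```python
import torch

def _seed() -> None:
    # mirror of the helper added in the diff
    torch.random.manual_seed(42)
    torch.cuda.manual_seed_all(42)

_seed()
first = torch.randn(4)
_seed()
second = torch.randn(4)
assert torch.equal(first, second)  # same seed, same draw: reproducible inputs
```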