From 4b774ee4d0ecd645ce3f3e1fc0dda9b1c332dcd1 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Wed, 15 Mar 2023 19:06:38 -0700 Subject: [PATCH] [OPS/BLOCKSPARSE] remove unnecessary mask (#1351) This PR applies a minor patch that removes unnecessary masks in `_dsd_kernel()`. ### Details `offs_bn` is defined as follows and is never updated afterwards. ```py offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) ``` Because the values of `offs_bn` are reduced modulo `DS0` (`offs_bn % DS0`), every element is already smaller than `DS0`, so the following mask is always `True`. ```py b = tl.load(pb, mask=offs_bn[None, :] < DS0) ``` This PR removes this mask (as well as the explicit `mask=True` on the load of `a`). --- python/triton/ops/blocksparse/matmul.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index 4b6d98aac3e4..2c60f1a1a5e3 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -181,8 +181,8 @@ def _dsd_kernel( inc_b = tl.load(pinc) inc_b = tl.multiple_of(inc_b, 8) for k in range(K, 0, -TILE_K): - a = tl.load(pa, mask=True) - b = tl.load(pb, mask=offs_bn[None, :] < DS0) + a = tl.load(pa) + b = tl.load(pb) acc += tl.dot(a, b) pa += inc_a pb += inc_b * stride_bk