[OPS/BLOCKSPARSE] remove unnecessary mask (#1351)
This PR applies a minor patch that removes unnecessary masks in
`_dsd_kernel()`.

### Details

`offs_bn` is defined as follows and is not updated afterwards.
```py
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
```

Because `offs_bn` is reduced modulo `DS0`, every element satisfies `0 <= offs_bn < DS0`, so the following mask is always `True`.
```py
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
```
This PR removes this mask, along with the explicit `mask=True` on the load of `a`.
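
The argument above can be checked without Triton or a GPU. The following is a minimal plain-Python sketch, assuming non-negative program ids and a positive `DS0`; the concrete values of `TILE_N` and `DS0`, and the helper names `wrapped_offsets` and `bounds_mask`, are hypothetical and chosen only for illustration. `tl.multiple_of` and `tl.max_contiguous` are compiler hints that do not change the values, so they are omitted here.

```py
# Plain-Python model of the offset computation in `_dsd_kernel()`.
# TILE_N, DS0 and the helper names are illustrative assumptions,
# not values or functions taken from the Triton source.
TILE_N = 16
DS0 = 40  # deliberately not a multiple of TILE_N, so some offsets wrap around


def wrapped_offsets(pid_m, tile_n=TILE_N, ds0=DS0):
    """Mimic `(pid_m * TILE_N + tl.arange(0, TILE_N)) % DS0` with Python ints."""
    return [(pid_m * tile_n + i) % ds0 for i in range(tile_n)]


def bounds_mask(offsets, ds0=DS0):
    """Mimic the removed mask `offs_bn[None, :] < DS0`."""
    return [off < ds0 for off in offsets]


# Every program id yields an all-True mask, including ids whose
# unwrapped offsets would run past the end of the axis.
always_true = all(all(bounds_mask(wrapped_offsets(pid))) for pid in range(8))
print(always_true)  # True
```

With these values, `wrapped_offsets(2)` starts at 32 and wraps back to 0 after 39, which is the case the mask would normally guard against; the modulo has already handled it.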
shintaro-iwasaki authored Mar 16, 2023
1 parent c175473 commit 4b774ee
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/triton/ops/blocksparse/matmul.py
@@ -181,8 +181,8 @@ def _dsd_kernel(
     inc_b = tl.load(pinc)
     inc_b = tl.multiple_of(inc_b, 8)
     for k in range(K, 0, -TILE_K):
-        a = tl.load(pa, mask=True)
-        b = tl.load(pb, mask=offs_bn[None, :] < DS0)
+        a = tl.load(pa)
+        b = tl.load(pb)
         acc += tl.dot(a, b)
         pa += inc_a
         pb += inc_b * stride_bk
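
To illustrate why dropping an always-true mask does not change the loaded values, here is a simplified plain-Python model of a masked load. It is only a sketch of the semantics, not Triton's implementation: `masked_load`, `memory`, and the fallback value `other=0` are hypothetical names and choices made for this example.

```py
# Simplified model of a masked load: positions whose mask entry is False
# receive the fallback value `other` instead of being read from memory.
# This is an illustrative assumption, not the actual `tl.load` code path.
def masked_load(memory, offsets, mask=None, other=0):
    if mask is None:
        # Unmasked load: every offset is read.
        return [memory[off] for off in offsets]
    return [memory[off] if keep else other for off, keep in zip(offsets, mask)]


memory = [10 * i for i in range(8)]  # stand-in for the buffer behind `pb`
DS0 = len(memory)
offsets = [(6 + i) % DS0 for i in range(4)]  # wraps around: 6, 7, 0, 1
mask = [off < DS0 for off in offsets]        # all True after the modulo

with_mask = masked_load(memory, offsets, mask=mask)
without_mask = masked_load(memory, offsets)
print(with_mask == without_mask)  # True
```

Because the mask never excludes an element, the masked and unmasked loads return the same data, so the mask only adds work for the kernel.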
