Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ops/blocksparse] remove unnecessary mask #1351

Merged

Conversation

shintaro-iwasaki
Copy link
Contributor

This PR applies a minor patch that removes unnecessary masks in _dsd_kernel().

Details

offs_bn is defined as follows and not updated after that.

offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)

Because offs_bn = offs_bn % DS0, this mask is always True.

b = tl.load(pb, mask=offs_bn[None, :] < DS0)

This PR removes this mask (as well as explicit mask=True).

@ptillet ptillet merged commit 4b774ee into triton-lang:main Mar 16, 2023
pingzhuu pushed a commit to siliconflow/triton that referenced this pull request Apr 2, 2024
This PR applies a minor patch that removes unnecessary masks in
`_dsd_kernel()`.

### Details

`offs_bn` is defined as follows and not updated after that.
```py
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
```

Because `offs_bn = offs_bn % DS0`, this mask is always `True`.
```py
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
```
This PR removes this mask (as well as explicit `mask=True`).
ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this pull request Aug 5, 2024
Addressing the compilation failure issue of
`test_core.py::test_fp8_dot_acc` (triton-lang#1351).
Enabling `allow_fp8e4b15` in `compiler.py` solves the problem.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants