Skip to content

Commit

Permalink
update triton version to 2.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxu2048 committed Mar 7, 2024
1 parent 9970b79 commit 5d31fe7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
triton >= 2.2.0
outlines >= 0.0.27
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
7 changes: 3 additions & 4 deletions vllm/model_executor/layers/triton_kernel/prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import packaging

assert packaging.version.parse(triton.__version__) >= packaging.version.parse(
"2.1.0"), "Triton version >= 2.1.0 is required."

"2.2.0"), "Triton version >= 2.2.0 is required."

@triton.jit
def _fwd_kernel(
Expand Down Expand Up @@ -99,7 +98,7 @@ def _fwd_kernel(
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
other=0.0).to(q.dtype)

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
Expand All @@ -126,7 +125,7 @@ def _fwd_kernel(
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
other=0.0).to(k.dtype)

p = p.to(v.dtype)
acc += tl.dot(p, v)
Expand Down

0 comments on commit 5d31fe7

Please sign in to comment.