[Fix][Dlight] Fix GeneralReduction for log-sum-exp
This PR fixes the GeneralReduction dlight rule so that it can support
scheduling the log-sum-exp function.

Previously, the rule made a strong assumption about the pattern of the
given function, which allowed scheduling softmax but failed to schedule
log-sum-exp due to a pattern mismatch. This PR enhances the rule so
that it can match the pattern of log-sum-exp and apply the subsequent
scheduling.

A regression test is added.
MasterJH5574 committed Apr 25, 2024
1 parent 4f8c03f commit 1580ab5
Showing 2 changed files with 176 additions and 8 deletions.
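In short, the old rule unconditionally prepended zero iters to the last block's layout, while the fixed rule first tries to line the last block's iters up with the first block's iters, proving that their domains are equal with an arithmetic analyzer, and only falls back to zero-prepending when no match is found. Below is a minimal standalone sketch of the two index maps involved (extents and names here are illustrative, not taken from the rule itself):

from tvm import tir

# Old behavior: always prepend zeros, e.g. (i, j) -> (0, i, j).
pad_front = tir.IndexMap.from_func(
    lambda i, j: [tir.const(0, i.dtype), i, j]
)

# New behavior when the last block's iter domains provably match the
# first block's leading domains: keep the iters in place and pad the
# remaining positions with zeros, e.g. (i, j) -> (i, j, 0).
match_leading = tir.IndexMap.from_func(
    lambda i, j: [i, j, tir.const(0, i.dtype)]
)

In the regression test below, the final `log` block of log-sum-exp iterates over the same leading (batch_size, num_chunks) domains as the reduction blocks, so the matched layout (i, j) -> (i, j, 0) is the one that lets it be scheduled together with them.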
35 changes: 27 additions & 8 deletions python/tvm/dlight/gpu/general_reduction.py
@@ -18,7 +18,7 @@
 """Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
 from typing import List, Union
 
-from tvm import tir
+from tvm import arith, tir
 from tvm.target import Target
 
 from ..base import normalize_prim_func, try_inline_contiguous_spatial
@@ -57,13 +57,32 @@ def apply(  # pylint: disable=too-many-locals
         # Align the number of block iters of the last block.
         num_last_block_iter = len(block_infos[-1].dom_kind())
         if num_last_block_iter < len(dom_kind):
-            index_map = tir.IndexMap.from_func(
-                lambda *iters: (
-                    [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter)
-                    + list(iters)
-                ),
-                ndim=num_last_block_iter,
-            )
+
+            def f_layout_mapping(*iters):
+                analyzer = arith.Analyzer()
+                # Try to match the iters of the last block to the iters of the first
+                # block. For matched positions, use the iter from the input `iters`;
+                # for unmatched positions, use a new iter which is constant 0.
+                num_matched = 0
+                target_layout_iters = []
+                for block_iter in block_infos[0].iters:
+                    if num_matched < len(iters) and analyzer.can_prove_equal(
+                        block_iter.dom, block_infos[-1].iters[num_matched].dom
+                    ):
+                        target_layout_iters.append(iters[num_matched])
+                        num_matched += 1
+                    else:
+                        target_layout_iters.append(tir.const(0, iters[0].dtype))
+
+                # If all the iters of the last block can be matched, return the new layout.
+                if num_matched == len(iters):
+                    return target_layout_iters
+                # Otherwise, fall back to prepending zeros at the beginning.
+                return [tir.const(0, iters[0].dtype)] * (
+                    len(dom_kind) - num_last_block_iter
+                ) + list(iters)
+
+            index_map = tir.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter)
         sch.transform_block_layout(block_infos[-1].block_rv, index_map)
 
         try:
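The matching above relies on `arith.Analyzer.can_prove_equal` to decide whether two iter domains have the same extent symbolically, not just structurally. A small sketch of that API (the expressions are illustrative):

from tvm import arith, tir

analyzer = arith.Analyzer()
n = tir.Var("n", "int64")

# Provably equal despite different structure:
assert analyzer.can_prove_equal(n * 2, n + n)
# Not provable (indeed false for every n):
assert not analyzer.can_prove_equal(n * 2, n * 2 + 1)

Since dlight block extents are often symbolic shapes such as batch_size or num_chunks, symbolic proof is what lets the rule recognize that the last block's iters correspond to the first block's leading iters.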
149 changes: 149 additions & 0 deletions tests/python/dlight/test_gpu_general_reduction.py
@@ -453,5 +453,154 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C:
    _check(Before, After)


def test_logsumexp():
    @I.ir_module
    class Before:
        @T.prim_func
        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
            T.func_attr({"tir.noalias": T.bool(True)})
            batch_size = T.int64(is_size_var=True)
            vocab_size = T.int64(is_size_var=True)
            num_chunks = T.int64(is_size_var=True)
            A = T.match_buffer(var_A, (batch_size, vocab_size), dtype="float32")
            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks), dtype="float32")
            A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)), dtype="float32")
            temp_max = T.alloc_buffer((batch_size, num_chunks), dtype="float32")
            temp_sum = T.alloc_buffer((batch_size, num_chunks), dtype="float32")

            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
                with T.block("pad"):
                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                    A_pad[v0, v1, v2] = T.if_then_else(
                        v1 * T.int64(4096) + v2 < vocab_size,
                        A[v0, v1 * T.int64(4096) + v2],
                        T.min_value("float32"),
                    )

            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
                with T.block("max"):
                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                    with T.init():
                        temp_max[v0, v1] = T.min_value("float32")
                    temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])

            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
                with T.block("sum_exp"):
                    v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                    with T.init():
                        temp_sum[v0, v1] = T.float32(0)
                    temp_sum[v0, v1] += T.if_then_else(
                        v1 * T.int64(4096) + v2 < vocab_size,
                        T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),
                        T.float32(0),
                    )

            for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
                with T.block("log"):
                    v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                    blocked_lse[v0, v1] = T.log(temp_sum[v0, v1]) + temp_max[v0, v1]
    @I.ir_module
    class After:
        @T.prim_func
        def compute_lse(var_A: T.handle, var_blocked_lse: T.handle):
            T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
            batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
            A = T.match_buffer(var_A, (batch_size, vocab_size))
            num_chunks = T.int64(is_size_var=True)
            blocked_lse = T.match_buffer(var_blocked_lse, (batch_size, num_chunks))
            temp_max_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
            temp_sum_shared = T.alloc_buffer((batch_size, num_chunks), scope="shared")
            for ax0_ax1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax2_fused_0 in T.serial(
                            T.int64(16),
                            annotations={
                                "pragma_auto_unroll_max_step": 256,
                                "pragma_unroll_explicit": 1,
                            },
                        ):
                            with T.block("max"):
                                v0 = T.axis.spatial(
                                    batch_size,
                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
                                )
                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                                v2 = T.axis.reduce(
                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
                                )
                                T.reads(A[v0, v1 * T.int64(4096) + v2])
                                T.writes(temp_max_shared[v0, v1])
                                with T.init():
                                    temp_max_shared[v0, v1] = T.min_value("float32")
                                temp_max_shared[v0, v1] = T.max(
                                    temp_max_shared[v0, v1],
                                    T.if_then_else(
                                        v1 * T.int64(4096) + v2 < vocab_size,
                                        A[v0, v1 * T.int64(4096) + v2],
                                        T.min_value("float32"),
                                    ),
                                )
                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                    for ax2_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                        for ax2_fused_0 in T.serial(
                            T.int64(16),
                            annotations={
                                "pragma_auto_unroll_max_step": 256,
                                "pragma_unroll_explicit": 1,
                            },
                        ):
                            with T.block("sum_exp"):
                                v0 = T.axis.spatial(
                                    batch_size,
                                    ax0_ax1_fused % (num_chunks * batch_size) // num_chunks + ax0,
                                )
                                v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks + ax1)
                                v2 = T.axis.reduce(
                                    T.int64(4096), ax2_fused_0 * T.int64(256) + ax2_fused_1
                                )
                                T.reads(A[v0, v1 * T.int64(4096) + v2], temp_max_shared[v0, v1])
                                T.writes(temp_sum_shared[v0, v1])
                                with T.init():
                                    temp_sum_shared[v0, v1] = T.float32(0)
                                temp_sum_shared[v0, v1] = temp_sum_shared[v0, v1] + T.if_then_else(
                                    v1 * T.int64(4096) + v2 < vocab_size,
                                    T.exp(
                                        T.if_then_else(
                                            v1 * T.int64(4096) + v2 < vocab_size,
                                            A[v0, v1 * T.int64(4096) + v2],
                                            T.min_value("float32"),
                                        )
                                        - temp_max_shared[v0, v1]
                                    ),
                                    T.float32(0),
                                )
                for ax2_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    for ax2_0 in T.serial(
                        T.int64(1),
                        annotations={
                            "pragma_auto_unroll_max_step": 256,
                            "pragma_unroll_explicit": 1,
                        },
                    ):
                        with T.block("log"):
                            v0 = T.axis.spatial(
                                batch_size, ax0_ax1_fused % (num_chunks * batch_size) // num_chunks
                            )
                            v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(1), ax2_0 * T.int64(256) + ax2_1)
                            T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
                            T.reads(temp_sum_shared[v0, v1], temp_max_shared[v0, v1])
                            T.writes(blocked_lse[v0, v1])
                            blocked_lse[v0, v1] = (
                                T.log(temp_sum_shared[v0, v1]) + temp_max_shared[v0, v1]
                            )

    _check(Before, After)


if __name__ == "__main__":
    tvm.testing.main()
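For reference, the `Before` kernel in `test_logsumexp` computes a chunked log-sum-exp over the vocabulary axis with a chunk size of 4096. A NumPy sketch of the same computation (a reference under the assumption that every chunk holds at least one in-range element; `blocked_lse_ref` is a hypothetical helper, not part of the test):

import numpy as np

def blocked_lse_ref(a: np.ndarray, chunk: int = 4096) -> np.ndarray:
    """Per-chunk log-sum-exp with max-shifting, mirroring compute_lse."""
    batch, vocab = a.shape
    num_chunks = -(-vocab // chunk)  # ceiling division
    # Pad with -inf so padded entries contribute exp(-inf) = 0 to the sum.
    padded = np.full((batch, num_chunks * chunk), -np.inf, dtype=a.dtype)
    padded[:, :vocab] = a
    padded = padded.reshape(batch, num_chunks, chunk)
    m = padded.max(axis=-1)                         # per-chunk max
    s = np.exp(padded - m[..., None]).sum(axis=-1)  # shifted sum of exp
    return np.log(s) + m                            # blocked_lse[b, c]

Subtracting the per-chunk max before exponentiating is the standard numerical-stability trick; the scheduled `After` kernel performs the same three steps, with each (batch, chunk) pair mapped to one blockIdx.x and the max and sum reductions bound to threadIdx.x.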
