From 5e6ef3769e9f472001621136f025efa41c1d2ba3 Mon Sep 17 00:00:00 2001
From: Harsh Menon
Date: Mon, 13 Jan 2025 13:55:47 -0800
Subject: [PATCH] Paged Attention v2

Signed-off-by: Harsh Menon
---
 iree/turbine/kernel/_support/regions.py       |   8 +-
 iree/turbine/kernel/ops/wave_ops.py           |   3 +-
 iree/turbine/kernel/wave/codegen.py           |  15 +-
 iree/turbine/kernel/wave/constraints.py       |   7 +-
 .../kernel/wave/expansion/expansion.py        |  44 ++-
 .../kernel/wave/expansion/expansion_utils.py  |  23 +-
 .../kernel/wave/index_sequence_analysis.py    |  16 +-
 .../kernel/wave/symbolic_constraints.py       |   2 +-
 .../wave/templates/paged_decode_attention.py  | 340 +++++++++++++++++
 iree/turbine/kernel/wave/type_inference.py    |  18 +-
 iree/turbine/kernel/wave/utils.py             |  26 +-
 iree/turbine/kernel/wave/wave.py              |   2 +
 lit_tests/kernel/wave/attention.py            |  89 +++++
 .../wave/attention/paged_attention_test.py    | 358 ++++++++++++++++++
 14 files changed, 929 insertions(+), 22 deletions(-)
 create mode 100644 iree/turbine/kernel/wave/templates/paged_decode_attention.py
 create mode 100644 tests/kernel/wave/attention/paged_attention_test.py

diff --git a/iree/turbine/kernel/_support/regions.py b/iree/turbine/kernel/_support/regions.py
index d39b15a8c..48c3a21ec 100644
--- a/iree/turbine/kernel/_support/regions.py
+++ b/iree/turbine/kernel/_support/regions.py
@@ -94,10 +94,10 @@ def trace(self, *args, **kwargs) -> Tuple[str, List[fx.Proxy]]:
        subgraph_name = self.region_graph.add_subgraph("region", traced, inner_freevars)
        return subgraph_name, implicit_capture

-    def _create_graph_input(self, name: str, type_expr=None) -> fx.Proxy:
+    def _create_graph_input(self, node: fx.Node, name: str, type_expr=None) -> fx.Proxy:
        proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
        # Can use this to check where the freevar has been lifted from.
-        proxy.node.meta["lifted"] = None
+        proxy.node.meta["lifted"] = node
        return proxy

    def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
@@ -109,7 +109,9 @@ def _lift_tracked_freevar_to_input(self, proxy: fx.Proxy):
            return self.lifted_freevars[proxy]

        # Otherwise, create a new input and store it.
-        new_proxy = self._create_graph_input(proxy.node.name, proxy.node.type)
+        new_proxy = self._create_graph_input(
+            proxy.node, proxy.node.name, proxy.node.type
+        )
        self.lifted_freevars[proxy] = new_proxy

        # Propagate freevar usage upwards.
diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py
index 55c143b63..8b32d4d67 100644
--- a/iree/turbine/kernel/ops/wave_ops.py
+++ b/iree/turbine/kernel/ops/wave_ops.py
@@ -1075,13 +1075,14 @@ def transform_index_backwards(
            iters = self.mapping.iters
            mapping = self.mapping.dynamic_val_mappings[i]

-            # This logic relies on fact out mapping is identity.
+            # This logic relies on the fact that our mapping is the identity.
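+            # e.g. with iterators (i, j) and an identity output mapping
+            # {M: i, N: j}, ``subs`` below becomes {i: index[M], j: index[N]},
+            # which is then substituted into each dynamic-value mapping expression.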
            subs = {
                k: index[v] for k, v in zip(iters, self.mapping.output_mapping.keys())
            }
            return {
                k: IndexSequence.from_expr(mapping[k], subs)
                for k in arg.type.symbolic_shape
+                if k in mapping
            }
        return index
diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py
index 8a1f3118d..e00da5f4c 100644
--- a/iree/turbine/kernel/wave/codegen.py
+++ b/iree/turbine/kernel/wave/codegen.py
@@ -354,6 +354,7 @@ def _ceiling(value):
            cmp = arith_d.cmpi(
                arith_d.CmpIPredicate.eq, *_broadcast(value.numerator, zero)
            )
+            zero, result = _broadcast(zero, result)
            value = arith_d.select(cmp, zero, result)
        else:
            value = arith_d.ceildivsi(
@@ -464,11 +465,23 @@ def _get_const(val):
                rhs = stack.pop()
                lhs = stack.pop()
                _enforce_non_rational(rhs, term)
                _enforce_non_rational(lhs, term)
-                if _is_integer_like_type(rhs.type):
+                type = get_type_or_element_type(rhs.type)
+                if _is_integer_like_type(type):
                    res = arith_d.maxsi(*_broadcast(lhs, rhs))
                else:
                    res = arith_d.maximumf(*_broadcast(lhs, rhs))
                stack.append(res)
+            case sympy.Min():
+                rhs = stack.pop()
+                lhs = stack.pop()
+                _enforce_non_rational(rhs, term)
+                _enforce_non_rational(lhs, term)
+                type = get_type_or_element_type(rhs.type)
+                if _is_integer_like_type(type):
+                    res = arith_d.minsi(*_broadcast(lhs, rhs))
+                else:
+                    res = arith_d.minimumf(*_broadcast(lhs, rhs))
+                stack.append(res)
            case sympy.logic.boolalg.BooleanFalse():
                res = arith_d.constant(IntegerType.get_signless(1), 0)
                stack.append(res)
diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py
index e288ffe98..ea65f7965 100644
--- a/iree/turbine/kernel/wave/constraints.py
+++ b/iree/turbine/kernel/wave/constraints.py
@@ -374,6 +374,7 @@ class WorkgroupConstraint(Constraint):
    dim: IndexExpr
    tile_size: IndexExpr
    workgroup_dim: int
+    primary: Optional[bool] = True

    def __post_init__(self):
        self.wg_dim = None
@@ -397,8 +398,10 @@ def apply(self) -> IndexSequence:

def get_grid_shape(wg_constraints: list[WorkgroupConstraint]) -> list[IndexExpr]:
-    sorted_constraints = sorted(wg_constraints, key=lambda x: x.workgroup_dim)
-    # Currently not more than one constraint in each dimension supported.
+    sorted_constraints = sorted(
+        [x for x in wg_constraints if x.primary], key=lambda x: x.workgroup_dim
+    )
+    # Currently, no more than one primary constraint per dimension is supported.
    if any(
        sorted_constraints[i].workgroup_dim == sorted_constraints[i + 1].workgroup_dim
        for i in range(len(sorted_constraints) - 1)
    )
diff --git a/iree/turbine/kernel/wave/expansion/expansion.py b/iree/turbine/kernel/wave/expansion/expansion.py
index 10395546f..1bbed8b34 100644
--- a/iree/turbine/kernel/wave/expansion/expansion.py
+++ b/iree/turbine/kernel/wave/expansion/expansion.py
@@ -21,6 +21,9 @@
    Reshape,
    GetResult,
    MMA,
+    SetSymbol,
+    ApplyExpr,
+    Broadcast,
)
from ..._support.indexing import IndexingContext, IndexSymbol
import itertools
@@ -453,7 +456,7 @@ def populate_inputs(
                    mma_metadata.last_mma_node = True
                    new_nodes_to_expand.append((arg, mma_metadata))
                    continue
-                case Allocate():
+                case Allocate() | SetSymbol() | ApplyExpr():
                    alloc_metadata = deepcopy(metadata)
                    alloc_metadata.do_not_expand = True
                    new_nodes_to_expand.append((arg, alloc_metadata))
@@ -462,6 +465,7 @@ def populate_inputs(
            new_nodes_to_expand.append((arg, metadata))

    nodes_to_expand.extend(new_nodes_to_expand)
+
    return nodes_to_expand

@@ -519,7 +523,9 @@ def expand_node(
        )

    # Check if the node has already been expanded, if so return early.
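+    # Expanded copies are keyed by the (node, indexed_dims) pair; e.g. a node
+    # expanded 2x along M and 2x along K2 is registered under keys such as
+    # (node, ((M, 0), (K2, 0))), (node, ((M, 1), (K2, 1))), and so on.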
- key = ExpansionInfo(node, get_indexed_dims(expanded_dims, node)) + indexed_dims = get_indexed_dims(expanded_dims, node) + + key = ExpansionInfo(node, indexed_dims) if key in expansion_context: update_users(node, expansion_context[key], metadata, expansion_context) return nodes_to_expand @@ -614,7 +620,33 @@ def fixup_mma_nodes(trace: CapturedTrace, expansion_context: ExpansionContext): first.replace_all_uses_with_except(second, [exclude]) -def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionContext): +def get_mma_indexed_dims( + mma: MMA, + original_indexed_dims: tuple[tuple[IndexSymbol, int]], + expansion_context: ExpansionContext, +): + dim = mma.reduction_dim + indexed_dims = None + max_reduction_dim = -1 + original_indexed_dims_dict = dict(original_indexed_dims) + for key in expansion_context.expansion_context.keys(): + indexed_dims_dict = dict(key.indexed_dims) + if any( + dim not in indexed_dims_dict + or indexed_dims_dict[dim] != original_indexed_dims_dict[dim] + for dim in original_indexed_dims_dict + ): + continue + if key.node == mma and indexed_dims_dict[dim] > max_reduction_dim: + indexed_dims = key.indexed_dims + max_reduction_dim = indexed_dims_dict[dim] + return indexed_dims + + +def fixup_reduction_nodes( + trace: CapturedTrace, + expansion_context: ExpansionContext, +): reduction_context = expansion_context.reduction_context for reduction in trace.walk(lambda x: isinstance(get_custom(x), Reduction)): reduction = get_custom(reduction) @@ -629,6 +661,12 @@ def fixup_reduction_nodes(trace: CapturedTrace, expansion_context: ExpansionCont sorted_keys = dict(sorted(reduction_info.outputs.items(), key=lambda x: x[0])) new_outputs = [] for key in sorted_keys.values(): + if key not in expansion_context and isinstance(key.node, MMA): + key = ExpansionInfo( + key.node, + get_mma_indexed_dims(key.node, key.indexed_dims, expansion_context), + ) + assert key in expansion_context, f"Key not found: {key}" new_outputs.append(expansion_context[key].fx_node) output.update_arg("return_vals", new_outputs) diff --git a/iree/turbine/kernel/wave/expansion/expansion_utils.py b/iree/turbine/kernel/wave/expansion/expansion_utils.py index 1986772a6..3b8c63516 100644 --- a/iree/turbine/kernel/wave/expansion/expansion_utils.py +++ b/iree/turbine/kernel/wave/expansion/expansion_utils.py @@ -23,6 +23,8 @@ Write, Reshape, NewRegister, + ReduceOp, + MMA, ) from ...lang.global_symbols import SHARED_ADDRESS_SPACE import itertools @@ -45,7 +47,7 @@ def __str__(self): def get_dim_scaling( - constraints: Sequence[Constraint], node: fx.Node + constraints: Sequence[Constraint], node: CustomOp ) -> dict[IndexSymbol, int]: """Get the number of expansions for the dimensions based on the constraints for a specific node.""" dim_scaling: dict[IndexSymbol, int] = {} @@ -92,6 +94,25 @@ def get_dim_scaling( ) dim_scaling[constraint.dim] = tile_size // wave_count // vector_size + # Also include dimensions that have no constraints on them and are known. + idxc = IndexingContext.current() + is_static_dim = lambda dim: dim in idxc.subs + is_non_batch = lambda dim: node.vector_shapes[dim] > 0 + not_computed = lambda dim: dim not in dim_scaling + + for dim in node.indexing_dims: + if not_computed(dim) and is_static_dim(dim) and is_non_batch(dim): + dim_scaling[dim] = idxc.get_static_value(dim) // node.vector_shapes[dim] + + # For reduce ops, also include the reduction dimension. 
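+    # For instance, with K2 = 64 bound in the indexing context and
+    # vector_shapes[K2] = 16, the reduce op gets dim_scaling[K2] = 64 // 16 = 4.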
+ if isinstance(node, ReduceOp): + reduction_dim = node.reduction_dim + if not_computed(reduction_dim) and is_static_dim(reduction_dim): + dim_scaling[reduction_dim] = ( + idxc.get_static_value(reduction_dim) + // node.vector_shapes[reduction_dim] + ) + return dim_scaling diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 0b49f68b1..bb700b96b 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -71,6 +71,8 @@ def _get_symbolic_shape_and_vector_shapes( # as it includes the aliased variables. symbolic_shape = custom.register_type.symbolic_shape vector_shapes = custom.vector_shapes + # TODO: Remove the need for this. + return custom.memory_type.symbolic_shape, vector_shapes if any([x in custom.memory_type.symbolic_shape for x in aliases]): symbolic_shape = custom.memory_type.symbolic_shape return symbolic_shape, vector_shapes @@ -334,7 +336,7 @@ def set_derived_index(trace): worklist.append((inp, new_index)) -def verify_nodes(trace: CapturedTrace): +def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]): """ Verify that all the valid nodes have their index and vector shapes set. """ @@ -348,6 +350,16 @@ def verify_nodes(trace: CapturedTrace): if isinstance(custom, (Output, Reduction)): continue assert custom.index, f"Index not set for node {custom.fx_node}" + if not custom.vector_shapes: + # If vector_shapes is not set, see if it can be derived from the hardware constraints. + hw_constraint = get_hardware_constraint(constraints) + update_vector_shapes = [ + dim for dim in custom.index if dim in hw_constraint.vector_shapes + ] + if update_vector_shapes: + custom.vector_shapes = {} + for dim in update_vector_shapes: + custom.vector_shapes[dim] = hw_constraint.vector_shapes[dim] assert custom.vector_shapes, f"Vector shapes not set for node {custom.fx_node}" @@ -357,7 +369,7 @@ def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]): set_thread_dependent_index(constraints, mma_index, trace) set_derived_index(trace) resolve_thread_shapes(trace, constraints) - verify_nodes(trace) + verify_nodes(trace, constraints) def compute_stride( diff --git a/iree/turbine/kernel/wave/symbolic_constraints.py b/iree/turbine/kernel/wave/symbolic_constraints.py index f0f976730..b51fe7350 100644 --- a/iree/turbine/kernel/wave/symbolic_constraints.py +++ b/iree/turbine/kernel/wave/symbolic_constraints.py @@ -6,7 +6,7 @@ from iree.turbine.kernel._support.indexing import IndexExpr, IndexSymbol from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional, Sequence from .utils import subs_idxc from .constraints import ( Constraint, diff --git a/iree/turbine/kernel/wave/templates/paged_decode_attention.py b/iree/turbine/kernel/wave/templates/paged_decode_attention.py new file mode 100644 index 000000000..17a01da44 --- /dev/null +++ b/iree/turbine/kernel/wave/templates/paged_decode_attention.py @@ -0,0 +1,340 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.turbine.kernel.lang as tkl +import iree.turbine.kernel.wave as tkw +from iree.turbine.kernel.lang.global_symbols import * +from iree.turbine.kernel.wave.constraints import MMAType +from iree.turbine.kernel._support.dtype import DataType +from iree.turbine.kernel.wave.utils import ( + get_mfma_load_elems_per_thread, + get_mfma_store_elems_per_thread, +) +from ..symbolic_constraints import SymbolicAlias +import sympy +from enum import Enum +import math + + +def get_paged_decode_attention_kernels( + shape: tuple[int], + max_tokens: int, + mfma_variant: MMAType, + num_kv_splits: int, +): + # Input sizes + T = tkl.sym.T + S = tkl.sym.S + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + U = tkl.sym.U + BH = tkl.sym.BH + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_BH = tkl.sym.BLOCK_BH + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + BLOCK_U = tkl.sym.BLOCK_U + BLOCK_T = tkl.sym.BLOCK_T + BLOCK_S = tkl.sym.BLOCK_S + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + class Phase(Enum): + PHASE_0 = (0,) + PHASE_1 = (1,) + + B_WAVES = 2 + M_WAVES = 2 + N_WAVES = 2 + K_WAVES = 2 + THREADS_PER_WAVE = 64 + PHASE_1_BLOCK_B = 64 + PHASE_1_ELEMS_PER_THREAD = PHASE_1_BLOCK_B // THREADS_PER_WAVE + PHASE_1_BLOCK_N = 1 + + def phase_0_constraints(): + # K1, K2 are reduction dimensions that are fixed (not distributed) so + # they are not part of the constraints. + + constraints: list[tkw.Constraint] = [] + # U represents the number of splits of the key-value sequence. + # U is parallelizable and is distributed across workgroups. + constraints += [tkw.WorkgroupConstraint(U, BLOCK_U, 0)] + + # T represents the indices of the sequence tokens. + # T is dynamic and is distributed across workgroups and is tiled. + # + # Each workgroup will process: + # T = min(BLOCK, max(SEQ_LEN - WG_IDX[U] * BLOCK, 0)) tokens + # where BLOCK = ceil(SEQ_LEN / U) + # + # While T and U are related to one another, since we do not know SEQ_LEN + # we define them as symbolic aliases with different workgroup tile sizes. + # The tile size for T is set to BLOCK_T = ceil(SEQ_LEN / U) and will also + # be defined within the kernel. + constraints += [tkw.WorkgroupConstraint(T, BLOCK_T, 0, primary=False)] + constraints += [tkw.TilingConstraint(T, 1)] + + # BH is the kv-head index and is distributed across workgroups. + # B is the query index and is distributed like BH but with a different + # workgroup and wave tile size. + # TODO: We will want to add a function to the workgroup constraint to + # allow for using WG / ceil(kv_group_num, BLOCK_B) instead of just WG. + # This can be done by adding an optional additional argument to the WorkgroupConstraint. 
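+        # e.g. with the test configuration below (128 query heads, 4 kv heads),
+        # BLOCK_B = 128 // 4 = 32, so each workgroup along dimension 1 covers
+        # one kv head together with its 32 grouped query heads.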
+ + constraints += [tkw.WorkgroupConstraint(BH, BLOCK_BH, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 1, primary=False)] + constraints += [tkw.WaveConstraint(B, BLOCK_B / B_WAVES)] + + constraints += [tkw.WorkgroupConstraint(S, BLOCK_S, 2)] + + vector_shapes = {BH: 0, T: 0, S: 0, U: 1} + waves_per_block = (1, B_WAVES, 1) + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=THREADS_PER_WAVE, + waves_per_block=waves_per_block, + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + return constraints + + def phase_1_constraints() -> list[tkw.Constraint]: + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(B, BLOCK_B, 0)] + constraints += [tkw.WaveConstraint(B, BLOCK_B)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(N, BLOCK_N)] + constraints += [tkw.WorkgroupConstraint(S, BLOCK_S, 2)] + constraints += [tkw.TilingConstraint(U, BLOCK_U)] + vector_shapes = { + S: 0, + B: BLOCK_B, + N: BLOCK_N, + U: 1, + } + waves_per_block = (1, 1, 1) + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=THREADS_PER_WAVE, + waves_per_block=waves_per_block, + mma_type=mfma_variant, + vector_shapes=vector_shapes, + ) + ] + return constraints + + def get_constraints(phase: Phase) -> list[tkw.Constraint]: + if phase == Phase.PHASE_0: + return phase_0_constraints() + else: + return phase_1_constraints() + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + l = tkw.IndexMapping.iterator(3) + d0 = tkw.IndexMapping.dynamic_val(0) + + mapping = tkw.IndexMapping( + num_iterators=3, + inputs={S: i, B: j, N: k}, + outputs={S: i, B: j, N: k}, + ) + + # Returns the key for the given token index. + k_mapping = tkw.IndexMapping( + num_iterators=4, + inputs={T: d0, BH: j, K2: k, K1: l}, + outputs={T: i, BH: j, K2: k, K1: l}, + dynamic_val_mappings={T: i}, + ) + + # Returns the value for the given token index. + v_mapping = tkw.IndexMapping( + num_iterators=4, + inputs={T: d0, BH: j, N: k, K2: l}, + outputs={T: i, BH: j, N: k, K2: l}, + dynamic_val_mappings={T: i}, + ) + + # Returns token indices into the k-v cache for the given sequence (d0). 
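+    # Here i iterates over S and j over T: the read fetches block_table[d0, j],
+    # where d0 = request_indices[s], and places it at logical position (s, j).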
+ block_table_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={S: d0, T: j}, + outputs={S: i, T: j}, + dynamic_val_mappings={S: i}, + ) + + @tkw.wave(get_constraints(Phase.PHASE_0)) + def phase_0( + q: tkl.Memory[S, B, K1, GLOBAL_ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[T, BH, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[T, BH, N, K2, ADDRESS_SPACE, tkl.f16], + request_indices: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32], + sequence_lengths: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32], + block_table: tkl.Memory[S, T, GLOBAL_ADDRESS_SPACE, tkl.i32], + output: tkl.Memory[U, S, N, B, GLOBAL_ADDRESS_SPACE, tkl.f32], + output_max: tkl.Memory[U, S, B, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + # ========================================================================= + # Query has shape [NUM_SEQS, NUM_HEADS, HEAD_DIM] + # Key has shape [NUM_BLOCKS, NUM_KV_HEADS, BLOCK_SIZE, HEAD_DIM] + # Value has shape [NUM_BLOCKS, NUM_KV_HEADS, HEAD_DIM, BLOCK_SIZE] + # (TODO: This is a transposed version of the original) + # Sequence lengths has shape [NUM_SEQS] + # Request indices has shape [NUM_SEQS] + # Block table has shape [NUM_SEQS, MAX_KV_SEQ_LEN] + # Output has shape [NUM_KV_SPLITS, NUM_SEQS, NUM_HEADS, HEAD_DIM] + # ========================================================================= + + init_max = tkl.Register[S, B, tkl.f32](-1e6) + init_sum = tkl.Register[S, B, tkl.f32](0.0) + new_acc = tkl.Register[S, N, B, tkl.f32](0.0) + + # The request index is used to load the appropriate entries from the block table. + req_index = tkw.read(request_indices, elements_per_thread=1) + # The sequence length is used to control the bounds of the loop over T. + orig_seq_length = tkw.read(sequence_lengths, elements_per_thread=1) + + # The dimension T and its workgroup tile size BLOCK_T are both dynamic + # and set below. + tile_size = tkw.apply_expr(orig_seq_length, lambda x: sympy.ceiling(x / U)) + tkw.set_symbol(BLOCK_T, tile_size) + seq_length = tkw.apply_expr( + orig_seq_length, + lambda x: sympy.Min(BLOCK_T, sympy.Max(0, x - WORKGROUP_0 * BLOCK_T)), + ) + tkw.set_symbol(T, seq_length) + + # TODO: Add if statement here in cases where T is 0 to avoid writing nans for the output. + # While the for loop will be skipped, the calculations and writes outside the for + # loop will still be executed. 
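+        # The loop below is the standard flash-decoding online softmax over
+        # tiles of K2: m_j tracks the running max and d_j the running
+        # denominator (both base-2, since the query is pre-scaled by log2e
+        # outside the kernel), and the accumulator is rescaled by
+        # exp2(m_{j-1} - m_j) before the new V contribution is added via MMA.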
+ + @tkw.reduction(T, init_args=[init_max, init_sum, new_acc]) + def loop( + partial_max: tkl.Register[S, B, tkl.f32], + partial_sum: tkl.Register[S, B, tkl.f32], + acc: tkl.Register[S, N, B, tkl.f32], + ): + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + tkw.set_symbol(T, orig_seq_length) + block_indices = tkw.read( + block_table, + elements_per_thread=1, + mapping=block_table_mapping, + mapping_dynamic_vals=(req_index,), + ) + k_reg = tkw.read( + k, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=k_mapping, + mapping_dynamic_vals=(block_indices,), + ) + imm_reg = tkl.Register[S, K2, B, tkl.f32](0.0) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[S, B, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read( + v, + elements_per_thread=LOAD_ELEMS_PER_THREAD, + mapping=v_mapping, + mapping_dynamic_vals=(block_indices,), + ) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + res_max, res_sum, res_mm = loop + reciprocal_sum = tkw.reciprocal(res_sum) + res = res_mm * reciprocal_sum + res_max_log_sum = res_max + tkw.log2(res_sum) + + tkw.write(res_max_log_sum, output_max, elements_per_thread=1) + tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD) + + @tkw.wave(get_constraints(Phase.PHASE_1)) + def phase_1( + logits: tkl.Memory[U, S, N, B, GLOBAL_ADDRESS_SPACE, tkl.f32], + logits_max: tkl.Memory[U, S, B, GLOBAL_ADDRESS_SPACE, tkl.f32], + output: tkl.Memory[S, B, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[S, B, N, tkl.f32](0.0) + init_sum = tkl.Register[S, B, tkl.f32](0.0) + init_max = tkl.Register[S, B, tkl.f32](-1e6) + + @tkw.reduction(U, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[S, B, tkl.f32], + partial_sum: tkl.Register[S, B, tkl.f32], + acc: tkl.Register[S, B, N, tkl.f32], + ): + x_j = tkw.read(logits, elements_per_thread=PHASE_1_ELEMS_PER_THREAD) + xm_j = tkw.read(logits_max, elements_per_thread=PHASE_1_ELEMS_PER_THREAD) + m_j = tkw.maximum(xm_j, partial_max) + old_scale = tkw.exp2(partial_max - m_j) + new_scale = tkw.exp2(xm_j - m_j) + d_j = partial_sum * old_scale + new_scale + new_acc = acc * old_scale + term = new_scale * x_j + new_acc = new_acc + term + return m_j, d_j, new_acc + + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write( + res, output, mapping=mapping, elements_per_thread=PHASE_1_ELEMS_PER_THREAD + ) + + symbols_0 = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_BH: 1, + BLOCK_B: shape[0] // shape[5], + BLOCK_S: 1, + BLOCK_U: 1, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + BH: shape[5], + S: shape[6], + U: num_kv_splits, + } + symbols_1 = dict(symbols_0) + symbols_1[BLOCK_B] = PHASE_1_BLOCK_B + symbols_1[BLOCK_N] = PHASE_1_BLOCK_N + + dynamic_symbols_0 = [T] + dynamic_symbols_1 = [] + dynamic_symbols_map_0 = {T: 1} + dynamic_symbols_map_1 = {} + + return ( + phase_0, + phase_1, + symbols_0, + symbols_1, + dynamic_symbols_0, + dynamic_symbols_map_0, + dynamic_symbols_1, + dynamic_symbols_map_1, + ) diff --git a/iree/turbine/kernel/wave/type_inference.py 
b/iree/turbine/kernel/wave/type_inference.py index db574cfc9..41fae3443 100644 --- a/iree/turbine/kernel/wave/type_inference.py +++ b/iree/turbine/kernel/wave/type_inference.py @@ -13,9 +13,17 @@ def infer_types(trace: CapturedTrace | fx.Graph): + if isinstance(trace, fx.Graph): + all_nodes = trace.nodes + else: + all_nodes = trace.get_root_graph().nodes # Infer and set the types for all nodes in the graph. - for subgraph in trace.region_graph.subgraphs.values(): - for node in subgraph.nodes: - custom = get_custom(node) - custom.infer_type() - logger.debug(f"Setting type for {custom.fx_node} = {custom.type}") + for node in all_nodes: + custom = get_custom(node) + if isinstance(custom, Reduction): + infer_types(trace.region_graph.subgraphs[custom.subgraph_name]) + custom.infer_type() + # For implicit captures, get type from variables in root graph. + if "lifted" in custom.fx_node.meta: + custom.type = custom.fx_node.meta["lifted"].type + logger.debug(f"Setting type for {custom.fx_node} = {custom.type}") diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 45d3b2c77..1c2d50405 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -27,6 +27,9 @@ ExtractSlice, IterArg, Reshape, + Read, + SetSymbol, + ApplyExpr, ) from ..lang.wave_types import IndexMapping from .constraints import ( @@ -189,7 +192,9 @@ def is_removable_operator(node: fx.Node) -> bool: ) return ( - not custom.users and not isinstance(custom, Output) and not is_global_write + not custom.users + and not isinstance(custom, (Output, SetSymbol, ApplyExpr)) + and not is_global_write ) while removable_nodes := trace.walk(is_removable_operator): @@ -789,9 +794,16 @@ def get_users( if isinstance(custom, Reduction): # Map init arg to iter arg reduction = custom - init_arg_idx = custom.init_args.index(node) graph = custom.get_root_graph().subgraphs[custom.subgraph_name] - users.append(custom.iter_args(graph)[init_arg_idx]) + if node in custom.init_args: + init_arg_idx = custom.init_args.index(node) + users.append(custom.iter_args(graph)[init_arg_idx]) + else: + assert node in custom.implicit_captures + for outside_node in graph.nodes: + if outside_node.meta.get("lifted", None) == node: + users.append(outside_node) + break continue if isinstance(custom, Output): # Map output to get result @@ -1082,10 +1094,18 @@ def to_default_device(tensor: torch.Tensor) -> torch.Tensor: return tensor.to(get_default_device()) +def device_arange(*args, **kwargs): + return to_default_device(torch.arange(*args, **kwargs)) + + def device_randn(*args, **kwargs): return to_default_device(torch.randn(*args, **kwargs)) +def device_randn_like(*args, **kwargs): + return to_default_device(torch.randn_like(*args, **kwargs)) + + def device_randint(*args, **kwargs): return to_default_device(torch.randint(*args, **kwargs)) diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py index 27c02d6b4..1eb1f231b 100644 --- a/iree/turbine/kernel/wave/wave.py +++ b/iree/turbine/kernel/wave/wave.py @@ -392,6 +392,8 @@ def _trace_and_get_kernel_signature( for constraint in self.workgroup_constraints: if constraint.dim in aliases: continue + if not constraint.primary: + continue dim = ( constraint.workgroup_dim if constraint.workgroup_dim < max_workgroup_dim diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 43122fdbe..06f8cdbdf 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -14,6 +14,9 @@ from 
iree.turbine.kernel.wave.templates.decode_attention import ( get_decode_attention_kernels, ) +from iree.turbine.kernel.wave.templates.paged_decode_attention import ( + get_paged_decode_attention_kernels, +) import torch import sympy import math @@ -1080,3 +1083,89 @@ def repeat( # CHECK-COUNT-4: {{.*}} = arith.addf # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}} # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + + +@run_test +def test_paged_flash_decoding(): + # (B, M, N, K1, K2, BH, S) + shape = (128, 1, 32, 32, 64, 4, 8) + max_tokens = 2048 + num_kv_splits = 8 + mfma_variant = tkw.MMAType.F32_16x16x16_F16 + ( + phase_0, + phase_1, + hyperparams_0, + hyperparams_1, + dynamic_symbols_0, + dynamic_symbols_map_0, + dynamic_symbols_1, + dynamic_symbols_map_1, + ) = get_paged_decode_attention_kernels( + shape, max_tokens, mfma_variant, num_kv_splits + ) + + torch.manual_seed(0) + q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) + logits = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + logits_max = torch.zeros(shape[0], shape[1], dtype=torch.float32) + output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + + with tk.gen.TestLaunchContext( + hyperparams_0, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + dynamic_symbols=dynamic_symbols_0, + dynamic_symbols_map=dynamic_symbols_map_0, + ): + print(phase_0(q, k, v, logits, logits_max).module_op) + + # CHECK: func.func @phase_0 + # CHECK-NOT: {{.*}} = scf.for + # CHECK-COUNT-9: {{.*}} = vector.load + # CHECK-COUNT-1: vector.store + # CHECK-COUNT-4: {{.*}} = vector.load + # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-2: {{.*}} = arith.subf + # CHECK-COUNT-2: {{.*}} = math.exp2 + # CHECK-COUNT-2: {{.*}} = arith.subf + # CHECK-COUNT-2: {{.*}} = math.exp2 + # CHECK-COUNT-4: {{.*}} = gpu.shuffle + # CHECK-COUNT-2: {{.*}} = amdgpu.mfma + # CHECK-COUNT-2: {{.*}} = arith.divf + # CHECK-COUNT-2: {{.*}} = math.log2 + # CHECK-COUNT-18: vector.store + + with tk.gen.TestLaunchContext( + hyperparams_1, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + dynamic_symbols=dynamic_symbols_1, + dynamic_symbols_map=dynamic_symbols_map_1, + ): + print(phase_1(logits, logits_max, output).module_op) + + # CHECK: func.func @phase_1 + # CHECK: {{.*}} = scf.for + # CHECK-COUNT-2: {{.*}} = vector.load + # CHECK-COUNT-1: {{.*}} = arith.maximumf + # CHECK-COUNT-1: {{.*}} = arith.subf + # CHECK-COUNT-1: {{.*}} = math.exp2 + # CHECK-COUNT-1: {{.*}} = arith.subf + # CHECK-COUNT-1: {{.*}} = math.exp2 + # CHECK-COUNT-1: {{.*}} = arith.mulf + # CHECK-COUNT-1: {{.*}} = arith.addf + # CHECK-COUNT-2: {{.*}} = arith.mulf + # CHECK-COUNT-1: {{.*}} = arith.addf + # CHECK-COUNT-1: {{.*}} = arith.divf + # TODO: Remove vector.scatter when optimizing for performance + # CHECK-COUNT-1: vector.scatter diff --git a/tests/kernel/wave/attention/paged_attention_test.py b/tests/kernel/wave/attention/paged_attention_test.py new file mode 100644 index 000000000..aa4712297 --- /dev/null +++ b/tests/kernel/wave/attention/paged_attention_test.py @@ -0,0 +1,358 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import pytest
+import torch
+import math
+import iree.turbine.kernel as tk
+from iree.turbine.kernel.lang.global_symbols import *
+from iree.turbine.kernel.wave.utils import (
+    get_default_run_config,
+    get_default_scheduling_params,
+    device_arange,
+    device_randn,
+    device_randint,
+    device_randn_like,
+    device_zeros,
+)
+from iree.turbine.kernel.wave.constraints import MMAType
+from iree.turbine.kernel.wave.templates.paged_decode_attention import (
+    get_paged_decode_attention_kernels,
+)
+import os
+from torch.testing import assert_allclose
+from ..common.utils import (
+    require_e2e,
+    enable_scheduling_barriers,
+    dump_generated_mlir,
+)
+from ..common.shapes import get_test_shapes
+from typing import List, Optional
+
+# Reference paged attention implementation from vLLM and sglang.
+# From: https://github.com/vllm-project/vllm/blob/main/tests/kernels/test_flash_attn.py
+NUM_HEADS = [(128, 4)]
+HEAD_SIZES = [64]
+BLOCK_SIZES = [64]
+DTYPES = [torch.float16]
+NUM_BLOCKS = [128]
+# First item is query length, second item is key/value length.
+# In decode, query length is always one.
+# TODO: Check with more queries and unaligned shapes.
+SEQ_LENS = [[(1, 16), (1, 8)]]
+
+
+# From: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/torch_native_backend.py
+def _run_sdpa_forward_decode(
+    query: torch.Tensor,
+    output: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    seq_lens: torch.Tensor,
+    scaling=None,
+    enable_gqa=False,
+    causal=False,
+):
+    """Run the decode forward by using torch native sdpa op.
+
+    Args:
+        query: [num_tokens, num_heads, head_size]
+        output: [num_tokens, num_heads, head_size]
+        k_cache: [max_total_num_tokens, num_heads, head_size]
+        v_cache: [max_total_num_tokens, num_heads, head_size]
+        req_to_token: [max_num_reqs, max_context_len]
+        req_pool_indices: [num_seqs]
+        seq_lens: [num_seqs]
+        scaling: float or None
+        enable_gqa: bool
+        causal: bool
+
+    Returns:
+        output: [num_tokens, num_heads, head_size]
+    """
+
+    # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+    query = query.movedim(0, query.dim() - 2)
+
+    start_q, start_kv = 0, 0
+    for seq_idx in range(seq_lens.shape[0]):
+        # TODO: this loop processes one sequence per iteration, which is
+        # inefficient. Need to optimize the performance later.
+
+        seq_len_q = 1
+        seq_len_kv = seq_lens[seq_idx]
+        end_q = start_q + seq_len_q
+        end_kv = start_kv + seq_len_kv
+
+        per_req_query = query[:, start_q:end_q, :]
+
+        # get key and value from cache. per_req_tokens contains the kv cache
+        # index for each token in the sequence.
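+        # e.g. for seq_len_kv = 16, per_req_tokens holds 16 cache indices and
+        # per_req_key / per_req_value come out as [num_heads, 16, head_size]
+        # after the movedim below.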
+ req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out = ( + torch.nn.functional.scaled_dot_product_attention( + per_req_query.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out + start_q, start_kv = end_q, end_kv + + return output + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + causal: Optional[bool] = False, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: List[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx : start_idx + query_len] + q *= scale + + block_indices = block_tables[i, :kv_len] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = ( + torch.triu( + empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1 + ) + .bool() + .logical_not() + ) + mask |= sliding_window_mask + if soft_cap is not None: + attn = soft_cap * torch.tanh(attn / soft_cap) + if causal: + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@require_e2e +@pytest.mark.parametrize("seq_lens", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("enable_scheduling", [False]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x16_F16, + ], +) +def testPagedFlashDecoding( + seq_lens: List[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + block_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + soft_cap: Optional[float], + num_blocks: int, + enable_scheduling: bool, + mfma_variant: MMAType, + request, +): + + torch.manual_seed(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) + scale = head_size**-0.5 + 
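+    # With NUM_HEADS = [(128, 4)] this exercises grouped-query attention:
+    # 128 // 4 = 32 query heads share each kv head.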
+ query = device_randn(sum(query_lens), num_query_heads, head_size, dtype=dtype) + key_cache = device_randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype + ) + value_cache = device_randn_like(key_cache) + # TODO: The block table entries should be able to be a random number + # in the range [0, num_blocks * block_size), but that fails for now. + # As a workaround, the maximum value is set to num_seqs - 1. + block_table = device_randint(0, num_seqs, (num_seqs, max_kv_len), dtype=torch.int32) + request_indices = device_arange(num_seqs, dtype=torch.int32) + kv_lens_tensor = device_zeros(num_seqs, dtype=torch.int32) + for i in range(len(kv_lens)): + kv_lens_tensor[i] = kv_lens[i] + + # Run the wave kernel. + # TODO: Currently all but K1 is set to dynamic. This may not be the case. + S = num_seqs + B = num_query_heads + K1 = head_size + K2 = block_size + M = 1 + N = head_size + BH = num_kv_heads + shape = (B, M, N, K1, K2, BH, S) + num_kv_splits = 8 + ( + phase_0, + phase_1, + hyperparams_0, + hyperparams_1, + dynamic_symbols_0, + dynamic_symbols_map_0, + dynamic_symbols_1, + dynamic_symbols_map_1, + ) = get_paged_decode_attention_kernels( + shape, num_blocks * block_size, mfma_variant, num_kv_splits + ) + hyperparams_0.update(get_default_scheduling_params()) + hyperparams_1.update(get_default_scheduling_params()) + config = get_default_run_config() + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + sym_U = index_symbol("U") + if sym_U in hyperparams_0: + U = hyperparams_0[sym_U] + else: + U = dynamic_symbols_map_0[sym_U] + phase_0_output = device_zeros(U, S, N, B, dtype=torch.float32) + phase_0_output_max = device_zeros(U, S, B, dtype=torch.float32) + output = device_zeros(S, B, N, dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / K1) + + with tk.gen.TestLaunchContext( + hyperparams_0, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols_0, + dynamic_symbols_map=dynamic_symbols_map_0, + ): + # TODO: Add scaling of QK as part of kernel. + mb_qk = phase_0( + query * dk_sqrt * log2e, + key_cache.permute([0, 2, 1, 3]), + value_cache.permute([0, 2, 3, 1]), + request_indices, + kv_lens_tensor, + block_table, + phase_0_output, + phase_0_output_max, + ) + + with tk.gen.TestLaunchContext( + hyperparams_1, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols_1, + dynamic_symbols_map=dynamic_symbols_map_1, + ): + # TODO: Add variant of non-transposed V attention kernel. + mb_sv = phase_1(phase_0_output, phase_0_output_max, output) + + if dump_generated_mlir: + filename = f"wave_paged_phase_0_kernel_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_qk.module_op.get_asm()) + filename = f"wave_paged_phase_1_kernel_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb_sv.module_op.get_asm()) + + # Run the reference implementation (vllm or sglang). 
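+    # ref_paged_attn applies its own softmax scale to the raw query, whereas
+    # the wave kernel consumed a query pre-scaled by dk_sqrt * log2e for its
+    # exp2-based softmax; both paths compute the same attention.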
+ ref_vllm_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_table, + scale=scale, + causal=False, + sliding_window=sliding_window, + soft_cap=soft_cap, + ) + + compare_sglang = False + if compare_sglang: + # Since the query gets scaled in the first call, we don't + # scale it again below. + ref_sglang_output = _run_sdpa_forward_decode( + query=query, + output=torch.zeros_like(query), + k_cache=key_cache, + v_cache=value_cache, + req_to_token=block_table, + req_pool_indices=torch.arange(num_seqs), + seq_lens=torch.tensor(kv_lens, dtype=torch.int32), + scaling=1, + enable_gqa=True, + causal=False, + ) + + assert_allclose(ref_vllm_output, ref_sglang_output, rtol=1e-3, atol=1e-3) + + assert_allclose(output, ref_vllm_output, rtol=1e-3, atol=1e-3)
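
Reviewer note: the phase_1 combine above has a compact closed form that can be
checked standalone. Below is a minimal PyTorch sketch of that combine, under
the assumption that phase_0's `output` holds per-split results already divided
by the local softmax sum and `output_max` holds the per-split base-2
log-sum-exp (`res_max + log2(res_sum)` above). The helper name
`combine_splits` is illustrative only and not part of this patch.

import torch

def combine_splits(o: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
    # o:   [U, S, N, B] per-split normalized outputs (phase_0 "output")
    # lse: [U, S, B]    per-split base-2 log-sum-exp (phase_0 "output_max")
    m = lse.max(dim=0).values                 # [S, B] global max over splits
    w = torch.exp2(lse - m.unsqueeze(0))      # [U, S, B] per-split weights
    denom = w.sum(dim=0)                      # [S, B] global softmax sum / 2^m
    # Undo each split's local normalization, reweight, and renormalize.
    num = (o * w.unsqueeze(2)).sum(dim=0)     # [S, N, B]
    # Result is in phase_0's [S, N, B] layout; permute(0, 2, 1) to compare
    # with the kernel's final [S, B, N] output.
    return num / denom.unsqueeze(1)

Fed with phase_0_output and phase_0_output_max from the test above, the
(permuted) result should match `output` to within the test tolerances.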