From 2ef7724978a74309c54739dc099373e655654cec Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Fri, 22 Nov 2024 14:08:33 -0800 Subject: [PATCH] Add attention with bias tests Signed-off-by: Harsh Menon --- iree/turbine/kernel/ops/wave_ops.py | 42 ++- .../kernel/wave/index_sequence_analysis.py | 329 +++++++++++++----- .../kernel/wave/thread_shape_analysis.py | 23 +- iree/turbine/kernel/wave/utils.py | 34 +- lit_tests/kernel/wave/attention.py | 116 ++++++ lit_tests/kernel/wave/expansion.py | 136 ++++++++ tests/kernel/wave/wave_attention_test.py | 212 +++++++++++ 7 files changed, 760 insertions(+), 132 deletions(-) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 627f2170..c4cac8c0 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -538,20 +538,6 @@ def expanded_dims(self, value: dict[IndexSymbol, int]): raise ValueError("Expanded dims must be a dict") self.fx_node.expanded_dims = value - @property - def anchor(self) -> fx.Node: - """ - The anchor is a node that provides information to the node - such as vector_shapes, indexing information etc. - """ - if hasattr(self.fx_node, "anchor"): - return self.fx_node.anchor - return None - - @anchor.setter - def anchor(self, value: fx.Node): - self.fx_node.anchor = value - @property def vector_shapes(self) -> dict[IndexSymbol, int]: if hasattr(self.fx_node, "vector_shapes"): @@ -590,6 +576,14 @@ def align_index(self, constraints: list["Constraint"]) -> None: """ pass + def transform_index( + self, index: dict[IndexSymbol, IndexSequence] + ) -> dict[IndexSymbol, IndexSequence]: + """ + Transform the index of the node based on the provided mapping. + """ + return index + @define_py_op(operator.add) @define_py_op(operator.sub) @@ -1426,6 +1420,26 @@ def infer_type(self): ), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}" self.type = Register[*self.target_shape, src_type.dtype] + def transform_index( + self, index: dict[IndexSymbol, IndexSequence] + ) -> dict[IndexSymbol, IndexSequence]: + """ + The permute operation swaps the strides of the permuted indices. + So say we have a permute operation that swaps [B, M, N] to + [M, N, B], then we swap the strides of the dimensions. 
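+        For example, with src_shape = [B, M, N] and target_shape = [M, N, B],
+        B maps to M, M maps to N and N maps to B: each dimension keeps its
+        original start and size but takes the stride of the dimension it
+        maps to in the target shape.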
+ """ + custom_src = get_custom(self.arg) + src_shape = custom_src.type.symbolic_shape + src_to_target = { + src: self.target_shape[src_shape.index(src)] for src in src_shape + } + permuted_index = { + k: IndexSequence(v.start, v.size, index[src_to_target[k]].stride) + for k, v in index.items() + if k in src_shape + } + return permuted_index + def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]: return input if isinstance(input, Sequence) else (input,) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index 8e174fc5..53643b4c 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -28,13 +28,13 @@ get_mma_dimensional_mapping, get_hardware_constraint, subs_idxc, - specialize_index_sequence, - capture_backward_slice, + get_inputs, + get_users, ) import torch.fx as fx import numpy as np from functools import partial -from typing import Sequence +from typing import Sequence, Callable from ...support.logging import get_logger import sympy from itertools import groupby @@ -248,25 +248,16 @@ def has_gpr_offsets(node: fx.Node) -> bool: reshape = Reshape(ops_to_combine, custom.vector_shapes).add_to_graph( custom.graph ) + reshape.expanded_dims = custom.expanded_dims + reshape.vector_shapes = custom.vector_shapes custom.replace_all_uses_with(reshape) custom.graph.erase_node(custom.fx_node) -def preprocess_nodes( - constraints: Sequence[Constraint], - mma_index: dict[MMA, dict[IndexSymbol, int]], - mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], - node: fx.Node, -): - set_vector_shapes(constraints, mma_index, mma_slices, node) - set_node_index(constraints, mma_index, mma_slices, node) - - def set_node_indices(trace: CapturedTrace, constraints: list[Constraint]): - mma_index, mma_slices = get_mma_dimensional_mapping( - trace, get_hardware_constraint(constraints) - ) - trace.walk(partial(preprocess_nodes, constraints, mma_index, mma_slices)) + mma_index = get_mma_dimensional_mapping(trace, get_hardware_constraint(constraints)) + trace.walk(partial(set_thread_independent_index, constraints)) + set_thread_dependent_index(constraints, mma_index, trace) def compute_stride( @@ -355,100 +346,26 @@ def set_vector_shapes( return -def set_node_index( +def set_thread_independent_index( constraints: Sequence[Constraint], - mma_index: dict[MMA, dict[IndexSymbol, int]], - mma_slices: dict[MMA, dict[IndexSymbol, list[fx.Node]]], node: fx.Node, ): """ - Set the index of the node based on the user constraints. In certain - operators (like read, write), there is only a single index associated - with the node (the index to read from, the index to write to). But for - other operators like mma, each operand reads from a different index. - - Rather than maintain operand specific indices for operators, we maintain - dimension specific indices for each operator. So for an mma operator that - has a signature of (MxK, NxK) -> MxN, we maintain only 3 mappings for - dimensions M, N and K, but allow each mapping to be piecewise conditioned - on the operand. + Set the index of the node based on all constraints except the hardware constraint. 
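+    Only workgroup, tiling and wave constraints contribute here; the
+    thread-dependent part of the index (derived from the hardware
+    constraint) is filled in later by set_thread_dependent_index.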
""" custom = get_custom(node) - anchor = custom.anchor if isinstance(custom, (Reduction, Placeholder)) and not isinstance(custom, IterArg): return - hardware_constraint = [get_hardware_constraint(constraints)] - workgroup_constraints = { - c.dim: c for c in constraints if isinstance(c, WorkgroupConstraint) - } - other_constraints = [ + constraints = [ c for c in constraints if not isinstance(c, (HardwareConstraint, Assumption)) ] - # Apply hardware constraint first since it dictates the stride and size. - sorted_constraints = hardware_constraint + other_constraints index = {} - # The semantics of elements_per_thread are that it represents the number of - # elements that are loaded contiguously from memory. - elements_per_thread = getattr(custom, "elements_per_thread", None) - # For elementwise operations that do not have an elements per thread attribute, - # look back to the backward slice to see if they can find an appropriate value. - # TODO: Remove this once set_node_index is integrated with thread_shape_analysis. - if elements_per_thread is None: - backward_slice = capture_backward_slice(node) - for bwd_node in backward_slice: - custom_node = get_custom(bwd_node) - elements_per_thread = getattr(custom_node, "elements_per_thread", None) - if elements_per_thread: - break - for dim in custom.indexing_dims: index_seq = None - for constraint in sorted_constraints: - if isinstance(constraint, HardwareConstraint): - inputs = None - if anchor and dim in mma_index[anchor]: - inputs = (mma_index[anchor][dim], elements_per_thread, None) - else: - # Assumes vector shapes are associated with workgroup dims. - if dim not in workgroup_constraints: - continue - assert ( - dim in constraint.vector_shapes - ), f"Dimension {dim} not found in vector shapes" - if constraint.vector_shapes[dim] == 0: - continue - inputs = ( - workgroup_constraints[dim].workgroup_dim, - ( - 1 - if not is_contiguous_dim( - dim, - custom.indexing_dims, - constraint.vector_shapes, - ) - else elements_per_thread - ), - compute_stride( - custom.indexing_dims, constraint.vector_shapes, dim - ), - ) - if elements_per_thread is None: - # Here we end up with a situation where there will be no thread level - # dependence in the dimensional index. - # TODO: Evaluate if this is a valid case. - continue - mma_type = anchor.mma_type if anchor else None - index_seq = constraint.apply( - dim, *inputs, anchor and dim in mma_index[anchor], mma_type - ) - if anchor and dim in mma_index[anchor]: - index_seq = specialize_index_sequence( - index_seq, mma_slices[anchor], custom - ) - - elif constraint.dim == dim: + for constraint in constraints: + if constraint.dim == dim: if index_seq is None: index_seq = constraint.apply() else: @@ -462,6 +379,226 @@ def set_node_index( custom.index = index +def specialize_index( + index: dict[IndexSymbol, IndexSequence], subs: dict[IndexSymbol, int] +): + """ + Specialize the index sequence with the given substitutions. + """ + return {dim: seq.subs(subs) for dim, seq in index.items()} + + +def populate_mma_sources( + node: MMA, + mma_index: dict[MMA, dict[IndexSymbol, int]], + hardware_constraint: HardwareConstraint, +): + """ + Initialize the sources with the LHS, RHS, ACC and MMA node + and their index sequences and vector shapes. These will + be propagated to the rest of the graph. 
+ """ + index: dict[IndexSymbol, IndexSequence] = {} + mapping = mma_index[node] + for dim, dim_index in mapping.items(): + index[dim] = hardware_constraint.apply( + dim, dim_index, None, None, True, node.mma_type + ) + node.index = combine_indices(node.index, index) + return [ + ( + get_custom(node.lhs), + specialize_index(index, {MMA_LHS: 1, MMA_RHS: 0, MMA_ACC: 0}), + node.vector_shapes, + ), + ( + get_custom(node.rhs), + specialize_index(index, {MMA_LHS: 0, MMA_RHS: 1, MMA_ACC: 0}), + node.vector_shapes, + ), + ( + get_custom(node.acc), + specialize_index(index, {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}), + node.vector_shapes, + ), + ( + node, + specialize_index(index, {MMA_LHS: 0, MMA_RHS: 0, MMA_ACC: 1}), + node.vector_shapes, + ), + ] + + +def populate_non_mma_sources( + node: Read | Write, + hardware_constraint: HardwareConstraint, + workgroup_constraints: list[WorkgroupConstraint], +): + """ + Initialize the sources with the read and/or write nodes + and their index sequences and vector shapes. These will + be propagated to the rest of the graph. + """ + index: dict[IndexSymbol, IndexSequence] = {} + for dim in node.indexing_dims: + elements_per_thread = ( + 1 + if not is_contiguous_dim( + dim, node.indexing_dims, hardware_constraint.vector_shapes + ) + else node.elements_per_thread + ) + stride = compute_stride( + node.indexing_dims, hardware_constraint.vector_shapes, dim + ) + wg_constraint = [x for x in workgroup_constraints if x.dim == dim] + if not wg_constraint: + continue + index[dim] = hardware_constraint.apply( + dim, + wg_constraint[0].workgroup_dim, + elements_per_thread, + stride, + False, + None, + ) + return [(node, index, hardware_constraint.vector_shapes)] + + +def combine_indices( + thread_independent_index: dict[IndexSymbol, IndexSequence], + thread_dependent_index: dict[IndexSymbol, IndexSequence], +) -> dict[IndexSymbol, IndexSequence]: + combined_index = {k: v for k, v in thread_independent_index.items()} + for k in combined_index: + if k in thread_dependent_index: + combined_index[k].start += thread_dependent_index[k].start + combined_index[k].size = thread_dependent_index[k].size + combined_index[k].stride = thread_dependent_index[k].stride + return combined_index + + +def add_nodes_to_sources( + source: CustomOp, + reduction: Reduction, + fn: Callable, + source_index: dict[IndexSymbol, IndexSequence], + source_vector_shapes: dict[IndexSymbol, int], + sources: list[ + tuple[CustomOp, dict[IndexSymbol, IndexSequence], dict[IndexSymbol, int]] + ], +) -> tuple[list[CustomOp], Reduction]: + """ + Populate the sources with the inputs and users of the source node. + """ + for args, reduction in [fn(source.fx_node, reduction)]: + logger.debug(f"{source.fx_node} -> {args}") + if not args: + break + for arg in args: + custom = get_custom(arg) + if isinstance(custom, (Allocate, Placeholder)) and not isinstance( + custom, IterArg + ): + continue + vector_shapes = ( + custom.vector_shapes if custom.vector_shapes else source_vector_shapes + ) + sources.append((custom, source_index, vector_shapes)) + return sources, reduction + + +def should_update_index( + source: CustomOp, + source_index: dict[IndexSymbol, IndexSequence], + source_vector_shapes: dict[IndexSymbol, int], +): + # Determine if we should update the idx based on the source. + # We update the source only if the source index provides + # information about all the non-batch dimensions of the source. 
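+    # Dimensions with a vector shape of 0 or 1 are treated as batch
+    # dimensions here, so the incoming index only needs to cover the
+    # remaining dimensions.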
+ non_batch_dims = [x for x in source.indexing_dims if source_vector_shapes[x] > 1] + + # If the source index is smaller than the non-batch dims, check if the + # source index is a subset of the non-batch dims. + if len(source_index.keys()) < len(non_batch_dims): + return set(source_index.keys()).issubset(set(non_batch_dims)) + + # Otherwise, check if the non-batch dims are a subset of the source index. + if not set(non_batch_dims).issubset(set(source_index.keys())): + return False + + return True + + +def propagate_index( + node: CustomOp, + hardware_constraint: HardwareConstraint, + workgroup_constraints: list[WorkgroupConstraint], + mma_index: dict[MMA, dict[IndexSymbol, int]], + visited: set[CustomOp], +): + """ + Propagate the index and vector shapes through the graph + starting with priveleged nodes (like MMA, Read, Write). + """ + sources = set() + if isinstance(node, MMA): + sources = populate_mma_sources(node, mma_index, hardware_constraint) + else: + sources = populate_non_mma_sources( + node, hardware_constraint, workgroup_constraints + ) + reduction = None + while sources: + source, source_index, source_vector_shapes = sources.pop(0) + if source in visited: + continue + if not isinstance(source, (Reduction, MMA)): + if not should_update_index(source, source_index, source_vector_shapes): + continue + source_index = source.transform_index(source_index) + source.index = combine_indices(source.index, source_index) + source.vector_shapes = source_vector_shapes + visited.add(source) + for func in [get_inputs, get_users]: + sources, reduction = add_nodes_to_sources( + source, + reduction, + func, + source_index, + source_vector_shapes, + sources, + ) + return visited + + +def set_thread_dependent_index( + constraints: Sequence[Constraint], + mma_index: dict[MMA, dict[IndexSymbol, int]], + trace: CapturedTrace, +): + """ + Set the thread dependent index based on the hardware constraint. + """ + hardware_constraint = get_hardware_constraint(constraints) + sources: list[MMA] = list(mma_index.keys()) + if not sources: + sources = trace.walk(lambda node: isinstance(get_custom(node), (Read, Write))) + sources = [get_custom(x) for x in sources] + assert sources, "No read or mma nodes found in the graph." + + visited = set() + workgroup_constraints = [ + c for c in constraints if isinstance(c, WorkgroupConstraint) + ] + for source in sources: + visited = visited.union(set([x for x in sources])) + visited.remove(source) + visited = propagate_index( + source, hardware_constraint, workgroup_constraints, mma_index, visited + ) + + def set_post_expansion_indices(trace: CapturedTrace, constraints: list[Constraint]): """ Add offsets to the indices based on the expanded dims. 
diff --git a/iree/turbine/kernel/wave/thread_shape_analysis.py b/iree/turbine/kernel/wave/thread_shape_analysis.py index 129f7551..5573cc53 100644 --- a/iree/turbine/kernel/wave/thread_shape_analysis.py +++ b/iree/turbine/kernel/wave/thread_shape_analysis.py @@ -68,6 +68,17 @@ def propagatable_op(node: fx.Node): ) +def propagate_resolutions( + custom_node: CustomOp, dst_op: CustomOp = None +) -> list[fx.Node]: + propagated_resolutions = capture_forward_slice(custom_node.fx_node, propagatable_op) + if dst_op: + for node in propagated_resolutions: + get_custom(node).index = dst_op.index + resolved_resolutions = capture_backward_slice(custom_node.fx_node, propagatable_op) + return propagated_resolutions.union(resolved_resolutions) + + def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: """ This function will attempt to resolve binaryOp conflicts @@ -81,7 +92,6 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: lhs_dim_set = set(lhs.type.symbolic_shape) rhs_dim_set = set(rhs.type.symbolic_shape) if lhs_dim_set == rhs_dim_set: - # Could be caused by consumers(likely also binaryOp) of this node. return [] if lhs_dim_set.isdisjoint(rhs_dim_set): raise ValueError("Cannot broadcast if lhs and rhs has disjointed shapes.") @@ -94,17 +104,8 @@ def handle_binaryop_conflict(custom_node: CustomOp) -> list[fx.Node]: ) custom_broadcast = get_custom(broadcast) custom_broadcast.vector_shapes = broadcast_src.vector_shapes - custom_broadcast.anchor = broadcast_src.anchor custom_node.update_arg(broadcast_idx, custom_broadcast.fx_node) - propagated_resolutions = capture_forward_slice( - custom_broadcast.fx_node, propagatable_op - ) - for node in propagated_resolutions: - get_custom(node).index = dst_op.index - resolved_resolutions = capture_backward_slice( - custom_broadcast.fx_node, propagatable_op - ) - return propagated_resolutions.union(resolved_resolutions) + return propagate_resolutions(custom_broadcast, dst_op) # Returns True iff all conflicts are handled succesfully. diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index be4cd09e..e2e6879e 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -297,7 +297,6 @@ def is_mma(node): } if hardware_constraint.vector_shapes: custom.vector_shapes.update(hardware_constraint.vector_shapes) - custom.anchor = custom custom.reduction_dim = k # Since expansion proceeds bottom-up, we set the vector shapes @@ -305,9 +304,6 @@ def is_mma(node): if hasattr(custom.graph, "parent_op"): reduction = get_custom(custom.graph.parent_op) reduction.vector_shapes = custom.vector_shapes - reduction.anchor = custom - - mma_slices = {get_custom(x): capture_mma_slices(get_custom(x)) for x in mma_nodes} # Determine if any reshapes are required. 
Reshapes are added for # chained matmuls when the vector shapes of the operands in one matmul @@ -325,7 +321,6 @@ def add_reshape_if_needed(mma: MMA, prev_mma: MMA, arg_index: int): ) custom_reshape = get_custom(reshape) custom_reshape.vector_shapes = custom.vector_shapes - custom_reshape.anchor = custom custom.update_arg(arg_index, reshape) def find_mma_in_slice(node: CustomOp) -> Optional[MMA]: @@ -351,7 +346,7 @@ def find_mma_in_slice(node: CustomOp) -> Optional[MMA]: if prev_mma: add_reshape_if_needed(custom_mma, prev_mma, 1) - return mapping, mma_slices + return mapping def get_hardware_vector_size( @@ -718,7 +713,8 @@ def get_users( # Map init arg to iter arg reduction = custom init_arg_idx = custom.init_args.index(node) - users.append(custom.iter_args[init_arg_idx]) + graph = custom.get_root_graph().subgraphs[custom.subgraph_name] + users.append(custom.iter_args(graph)[init_arg_idx]) continue if isinstance(custom, Output): # Map output to get result @@ -939,9 +935,17 @@ def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int: return 4 case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 4 - case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8: + case ( + MMAType.F32_16x16x32_F8 + | MMAType.F32_16x16x32_K4_F8 + | MMAType.I32_16x16x32_I8 + ): return 8 - case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8: + case ( + MMAType.F32_32x32x16_F8 + | MMAType.F32_32x32x16_K4_F8 + | MMAType.I32_32x32x16_I8 + ): return 8 @@ -951,9 +955,17 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int: return 4 case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 16 - case MMAType.F32_16x16x32_F8 | MMAType.F32_16x16x32_K4_F8 | MMAType.I32_16x16x32_I8: + case ( + MMAType.F32_16x16x32_F8 + | MMAType.F32_16x16x32_K4_F8 + | MMAType.I32_16x16x32_I8 + ): return 4 - case MMAType.F32_32x32x16_F8 | MMAType.F32_32x32x16_K4_F8 | MMAType.I32_32x32x16_I8: + case ( + MMAType.F32_32x32x16_F8 + | MMAType.F32_32x32x16_K4_F8 + | MMAType.I32_32x32x16_I8 + ): return 16 diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py index 843e2578..b7ec5889 100644 --- a/lit_tests/kernel/wave/attention.py +++ b/lit_tests/kernel/wave/attention.py @@ -529,3 +529,119 @@ def repeat( # CHECK-COUNT-16: {{.*}} = amdgpu.mfma # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}} # CHECK-COUNT-8: {{.*}} = amdgpu.mfma + + +@run_test +def test_attention_bias(): + shape = (8, 128, 128, 64, 256) + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mfma_variant = tkw.MMAType.F32_16x16x16_F16 + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: 16, N: 16}, + ) + ] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave(constraints) + def base_attention_bias( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + bias: tkl.Memory[B, 
M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[B, N, K, tkl.f16] + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[B, N, M, tkl.f32] + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) + x_j = x_j + bias_reg + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K2: 32, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + } + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=False, + run_bench=False, + schedule=False, + use_scheduling_barriers=False, + ): + torch.manual_seed(0) + q = torch.randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = torch.randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = torch.randn(shape[0], shape[4], shape[2], dtype=torch.float16) + output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + print(base_attention_bias(q, k, v, output).module_op) diff --git a/lit_tests/kernel/wave/expansion.py b/lit_tests/kernel/wave/expansion.py index b881c6f2..77ac683a 100644 --- a/lit_tests/kernel/wave/expansion.py +++ b/lit_tests/kernel/wave/expansion.py @@ -765,6 +765,142 @@ def test_gemm_reduction_expansion_only(): # CHECK-NEXT: ----- +@tkw.wave_trace_only() +def attention( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[B, N, M, GLOBAL_ADDRESS_SPACE, tkl.f32], +): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to 
materialize a loop. + @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=4) + k_reg = tkw.read(k, elements_per_thread=4) + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=4) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, elements_per_thread=4) + + +@run_test +def test_attention(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2, ARGK)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, THREAD_0 / 64)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2, THREAD_1)] + + mfma_variant = tkw.MMAType.F32_16x16x16_F16 + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: 16, N: 16}, + ) + ] + + with tk.gen.TestLaunchContext( + { + K1: 64, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_B: 1, + BLOCK_K2: 32, + } + ): + graph = attention() + IndexingContext.current().finalize() + infer_types(graph) + set_node_indices(graph, constraints) + expand_graph(graph, constraints) + set_post_expansion_indices(graph, constraints) + print_trace(graph) + + # Root graph: + # CHECK: write(register_=truediv_0_0_0, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16)}) + # CHECK: write(register_=truediv_1_1_0, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16}) + # CHECK: write(register_=truediv_1_0_0, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16}) + # CHECK: write(register_=truediv_0_1_0, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 16, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16)}) + + # Reduction graph: + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + 
Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=q, + # CHECK-SAME: index={B: $WG2*BLOCK_B, M: $T0*BLOCK_M/128 + $WG0*BLOCK_M + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16) + 16, K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 32 : 4 : 1}) + # CHECK: read(memory=k, + # CHECK-SAME: index={B: $WG2*BLOCK_B, K2: ARGK*BLOCK_K2 + Mod($T0, 16), K1: 4*floor((Mod($T0, 64))/16) + 48 : 4 : 1}) + + # CHECK: read(memory=v, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K2: ARGK*BLOCK_K2 + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=v, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16), K2: ARGK*BLOCK_K2 + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + # CHECK: read(memory=v, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K2: ARGK*BLOCK_K2 + 4*floor((Mod($T0, 64))/16) : 4 : 1}) + # CHECK: read(memory=v, + # CHECK-SAME: index={B: $WG2*BLOCK_B, N: $T1*BLOCK_N/2 + $WG1*BLOCK_N + Mod($T0, 16) + 16, K2: ARGK*BLOCK_K2 + 4*floor((Mod($T0, 64))/16) + 16 : 4 : 1}) + + @tkw.wave_trace_only() def py_arithmetic_different_dims( a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16], diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index b1680e3c..6d67fb7e 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -7,6 +7,7 @@ import logging import pytest import torch +from torch.nn import functional as F import math import unittest import iree.turbine.kernel as tk @@ -563,6 +564,216 @@ def repeat( assert_allclose(output, torch_ref) +@require_e2e +@pytest.mark.parametrize("shape", get_test_shapes("test_attention")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +@pytest.mark.parametrize("dynamic_dims", [False, True]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x16_F16, + MMAType.F32_32x32x8_F16, + ], +) +def testAttentionBias( + shape: tuple[int], + enable_scheduling: bool, + 
dynamic_dims: bool, + mfma_variant: MMAType, + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + B = tkl.sym.B + M = tkl.sym.M + N = tkl.sym.N + K1 = tkl.sym.K1 + K2 = tkl.sym.K2 + # Workgroup tile sizes + BLOCK_B = tkl.sym.BLOCK_B + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K2 = tkl.sym.BLOCK_K2 + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WorkgroupConstraint(B, BLOCK_B, 2)] + constraints += [tkw.TilingConstraint(K2, BLOCK_K2)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 4)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 1)] + + if mfma_variant == MMAType.F32_16x16x16_F16: + Mvec = 16 + Nvec = 16 + if mfma_variant == MMAType.F32_32x32x8_F16: + Mvec = 32 + Nvec = 32 + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(4, 1, 1), + mma_type=mfma_variant, + vector_shapes={B: 0, M: Mvec, N: Nvec}, + ) + ] + + if dynamic_dims: + constraints += [tkw.Assumption(K2 > BLOCK_K2 * 4)] + + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + k = tkw.IndexMapping.iterator(2) + mapping = tkw.IndexMapping( + num_iterators=3, inputs={B: i, N: j, M: k}, outputs={B: i, M: k, N: j} + ) + + @tkw.wave(constraints) + def base_attention_bias( + q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16], + k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16], + v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16], + bias: tkl.Memory[B, M, K2, GLOBAL_ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[B, N, M, tkl.f32](0.0) + init_sum = tkl.Register[B, M, tkl.f32](0.0) + init_max = tkl.Register[B, M, tkl.f32](-1e6) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
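+            # The bias is read with the same layout as the Q @ K^T scores
+            # (B x M x K2) and added to them before the running max and
+            # running sum of the online softmax are updated.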
+ @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg]) + def repeat( + partial_max: tkl.Register[B, M, tkl.f32], + partial_sum: tkl.Register[B, M, tkl.f32], + acc: tkl.Register[B, N, M, tkl.f32], + ) -> ( + tkl.Register[B, M, tkl.f32], + tkl.Register[B, M, tkl.f32], + tkl.Register[B, N, M, tkl.f32], + ): + imm_reg = tkl.Register[B, K2, M, tkl.f32](0.0) + q_reg = tkw.read(q, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[B, N, K, tkl.f16] + k_reg = tkw.read(k, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[B, N, M, tkl.f32] + inner_acc = tkw.mma(k_reg, q_reg, imm_reg) + x_j = tkw.permute(inner_acc, target_shape=[B, M, K2]) + bias_reg = tkw.read(bias, elements_per_thread=STORE_ELEMS_PER_THREAD) + x_j = x_j + bias_reg + m_j = tkw.max(x_j, partial_max, dim=K2) + e_delta_max = tkw.exp2(partial_max - m_j) + e_delta = tkw.exp2(x_j - m_j) + e_init = partial_sum * e_delta_max + d_j = tkw.sum(e_delta, e_init, dim=K2) + imm_f16 = tkw.cast(e_delta, tkl.f16) + v_reg = tkw.read(v, elements_per_thread=LOAD_ELEMS_PER_THREAD) + new_acc = acc * e_delta_max + acc = tkw.mma(v_reg, imm_f16, new_acc) + return m_j, d_j, acc + + # repeat represents the results of the loop + res_max, res_sum, res_mm = repeat + res = res_mm / res_sum + tkw.write(res, c, mapping=mapping, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_B: 1, + BLOCK_M: 128, + BLOCK_N: 64, + BLOCK_K2: 64, + B: shape[0], + M: shape[1], + N: shape[2], + K1: shape[3], + K2: shape[4], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 2, + SHUFFLE_UNITS: 2, + } + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[B] = hyperparams[B] + dynamic_symbols_map[K2] = hyperparams[K2] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(B) + dynamic_symbols.append(K2) + del hyperparams[M] + del hyperparams[N] + del hyperparams[B] + del hyperparams[K2] + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + torch.manual_seed(0) + q = device_randn(shape[0], shape[1], shape[3], dtype=torch.float16) + k = device_randn(shape[0], shape[4], shape[3], dtype=torch.float16) + v = device_randn(shape[0], shape[4], shape[2], dtype=torch.float16) + bias = device_randn(shape[0], shape[1], shape[4], dtype=torch.float32) + output = device_zeros(shape[0], shape[1], shape[2], dtype=torch.float32) + log2e = 1.44269504089 + dk_sqrt = math.sqrt(1.0 / shape[3]) + # TODO: Add scaling of QK as part of kernel. + # TODO: Add variant of non-transposed V attention kernel. 
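+        # The kernel computes the softmax with exp2, so the pre-scaled Q and
+        # the bias are multiplied by log2(e) on the host; the torch reference
+        # below applies the 1/sqrt(K1) scaling and the bias directly before
+        # F.softmax.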
+ mb = base_attention_bias( + q * dk_sqrt * log2e, k, v.permute([0, 2, 1]), bias * log2e, output + ) + k_t = k.transpose(-1, -2) + a = torch.matmul(q, k_t) * dk_sqrt + a += bias + a = F.softmax(a, dim=-1) + torch_ref = torch.matmul(a, v) + + if test_dump_generated_mlir: + filename = f"wave_attention_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + assert_allclose(output, torch_ref, atol=2e-3, rtol=5e-3) + + @require_e2e @require_cdna3 @pytest.mark.parametrize("shape", get_test_shapes("test_attention")) @@ -633,6 +844,7 @@ def base_attention( c_reg = tkl.Register[B, N, M, tkl.f32](0.0) init_sum = tkl.Register[B, M, tkl.f32](0.0) init_max = tkl.Register[B, M, tkl.f32](-1e6) + # This microkernel encodes the fact that if the reduction # dimension were tiled, then we would need to materialize a loop. @tkw.reduction(K2, init_args=[init_max, init_sum, c_reg])