From ed7d74724269bce64b909ce2bfaf9a80fdb8271a Mon Sep 17 00:00:00 2001 From: Harsh Menon Date: Thu, 19 Dec 2024 07:44:37 -0800 Subject: [PATCH] More cleanups Signed-off-by: Harsh Menon --- .../kernel/wave/index_sequence_analysis.py | 26 +++++++++++++------ .../kernel/wave/templates/decode_attention.py | 9 ++----- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/iree/turbine/kernel/wave/index_sequence_analysis.py b/iree/turbine/kernel/wave/index_sequence_analysis.py index e65cebd5d..29a6e0c76 100644 --- a/iree/turbine/kernel/wave/index_sequence_analysis.py +++ b/iree/turbine/kernel/wave/index_sequence_analysis.py @@ -68,17 +68,11 @@ def _get_symbolic_shape_and_vector_shapes( hw_constraint: HardwareConstraint, ): # When the memory type has symbolic aliases, use the memory type - # as it includes the aliased variables. Also, add the aliased variables - # vector shapes. + # as it includes the aliased variables. symbolic_shape = custom.register_type.symbolic_shape vector_shapes = custom.vector_shapes if any([x in custom.memory_type.symbolic_shape for x in aliases]): symbolic_shape = custom.memory_type.symbolic_shape - for dim in symbolic_shape: - if dim in aliases and aliases[dim].target in vector_shapes: - vector_shapes[dim] = aliases[dim].apply( - vector_shapes[aliases[dim].target] - ) return symbolic_shape, vector_shapes @@ -651,6 +645,21 @@ def should_update_index( return True +def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]): + """ + Append the aliased shapes to the vector shapes of the source, if they + are present in the source index. + """ + for constraint in symbolic_constraints: + if ( + constraint.target in source.vector_shapes + and constraint.source in source.index + ): + source.vector_shapes[constraint.source] = constraint.apply( + source.vector_shapes[constraint.target] + ) + + def propagate_index( node: CustomOp, hardware_constraint: HardwareConstraint, @@ -683,6 +692,7 @@ def propagate_index( source_index = source.transform_index(source_index) source.index = combine_indices(source.index, source_index) source.vector_shapes = source_vector_shapes + append_aliased_shapes(source, symbolic_constraints) visited.add(source) for func in [get_inputs, get_users]: sources, reduction = add_nodes_to_sources( @@ -761,7 +771,7 @@ def create_broadcast( binary_op.graph ) custom = get_custom(broadcasted) - custom.vector_shapes = to_broadcast.vector_shapes + custom.vector_shapes = binary_op.vector_shapes custom.index = deepcopy(target_node.index) custom.index[broadcast_dim].size = broadcast_size broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast) diff --git a/iree/turbine/kernel/wave/templates/decode_attention.py b/iree/turbine/kernel/wave/templates/decode_attention.py index 02b5901bc..3cfd4551f 100644 --- a/iree/turbine/kernel/wave/templates/decode_attention.py +++ b/iree/turbine/kernel/wave/templates/decode_attention.py @@ -65,12 +65,7 @@ def phase_0_constraints(): constraints += [ SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES))) ] - if mfma_variant == MMAType.F32_16x16x16_F16: - vector_shapes = {B: 0, M: 16, N: 16} - elif mfma_variant == MMAType.F32_32x32x8_F16: - vector_shapes = {B: 0, M: 32, N: 32} - else: - raise NotImplementedError(f"Unsupported mfma_variant: {mfma_variant}") + vector_shapes = {B: 0} waves_per_block = (M_WAVES, N_WAVES, K_WAVES) constraints += [ tkw.HardwareConstraint( @@ -147,8 +142,8 @@ def phase_0( acc = tkw.mma(v_reg, imm_f16, new_acc) res = acc / d_j dm_j = m_j + tkw.log2(d_j) - tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD) tkw.write(dm_j, output_max, elements_per_thread=1) + tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD) @tkw.wave(get_constraints(Phase.PHASE_1)) def phase_1(