diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 4ecd66a50..6f00caa55 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -261,28 +261,41 @@ def is_mma(node): # Determine if any reshapes are required. Reshapes are added for # chained matmuls when the vector shapes of the operands in one matmul - # differ from those in another matmul. - for src in mma_nodes: - custom_src = get_custom(src) - for dst in mma_nodes: - if src == dst: - continue - custom_dst = get_custom(dst) - lhs_slice = capture_backward_slice(custom_dst.lhs) - rhs_slice = capture_backward_slice(custom_dst.rhs) - if src in lhs_slice or src in rhs_slice: - with custom_dst.graph.inserting_before(dst): - for i, arg in custom_dst.node_args.items(): - if is_reshape_needed( - arg, custom_dst.vector_shapes, custom_src.vector_shapes - ): - reshape = Reshape( - arg.fx_node, custom_src.vector_shapes - ).add_to_graph(custom.graph) - custom_reshape = get_custom(reshape) - custom_reshape.vector_shapes = custom.vector_shapes - custom_reshape.anchor = custom - custom.update_arg(i, reshape) + # differ from those in another matmul. The mma_slices contain all the ops + # in the backward slice of the lhs and rhs up to a previous mma (if one exists). + # So we check for the previous node of the first operator in the slice to see + # if it is an MMA and if so check if a reshape is required. 
+    def add_reshape_if_needed(mma: MMA, prev_mma: MMA):
+        """
+        Insert a Reshape in front of every operand of `mma` for which
+        `is_reshape_needed` reports a mismatch between the vector shapes of
+        `mma` and those of `prev_mma`.
+
+        Each reshape is created before `mma` in `mma`'s graph, is built from
+        `prev_mma`'s vector shapes, takes `mma`'s vector shapes, is anchored
+        to `mma`, and replaces the corresponding operand of `mma`.
+        """
+        with mma.graph.inserting_before(mma.fx_node):
+            # update_arg is called inside the loop, so iterate over a snapshot
+            # of the operands rather than the live mapping.
+            for i, arg in list(mma.node_args.items()):
+                if not is_reshape_needed(
+                    arg, mma.vector_shapes, prev_mma.vector_shapes
+                ):
+                    continue
+                # Everything is attached to `mma`, the op whose operands are
+                # being rewritten, not to a name captured from the enclosing
+                # scope.
+                reshape = Reshape(arg.fx_node, prev_mma.vector_shapes).add_to_graph(
+                    mma.graph
+                )
+                custom_reshape = get_custom(reshape)
+                custom_reshape.vector_shapes = mma.vector_shapes
+                custom_reshape.anchor = mma
+                mma.update_arg(i, reshape)
+
+    def find_mma_in_slice(node: CustomOp) -> Optional[MMA]:
+        """
+        Find the closest mma by iterating through the backward slice of a node
+        in reverse. Returns None when the slice contains no MMA.
+        """
+        # Materialized so it can be walked in reverse; the name avoids
+        # shadowing the builtin `slice`.
+        backward_slice = list(capture_backward_slice(node))
+        for arg in reversed(backward_slice):
+            prev_mma = get_custom(arg)
+            if isinstance(prev_mma, MMA):
+                return prev_mma
+        return None
+
+    for mma in mma_nodes:
+        custom_mma = get_custom(mma)
+        # Look for a producing MMA behind the lhs first; only search the rhs
+        # slice when the lhs slice has none.
+        prev_mma = find_mma_in_slice(custom_mma.lhs)
+        if prev_mma is None:
+            prev_mma = find_mma_in_slice(custom_mma.rhs)
+        if prev_mma is not None:
+            add_reshape_if_needed(custom_mma, prev_mma)
     return mapping, mma_slices