Skip to content

Commit

Permalink
More cleanups
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <harsh@nod-labs.com>
  • Loading branch information
harsh-nod committed Dec 19, 2024
1 parent 0edf8ac commit ed7d747
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
26 changes: 18 additions & 8 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,11 @@ def _get_symbolic_shape_and_vector_shapes(
hw_constraint: HardwareConstraint,
):
# When the memory type has symbolic aliases, use the memory type
# as it includes the aliased variables. Also, add the aliased variables
# vector shapes.
# as it includes the aliased variables.
symbolic_shape = custom.register_type.symbolic_shape
vector_shapes = custom.vector_shapes
if any([x in custom.memory_type.symbolic_shape for x in aliases]):
symbolic_shape = custom.memory_type.symbolic_shape
for dim in symbolic_shape:
if dim in aliases and aliases[dim].target in vector_shapes:
vector_shapes[dim] = aliases[dim].apply(
vector_shapes[aliases[dim].target]
)
return symbolic_shape, vector_shapes


Expand Down Expand Up @@ -651,6 +645,21 @@ def should_update_index(
return True


def append_aliased_shapes(source: CustomOp, symbolic_constraints: list[SymbolicAlias]):
    """
    Propagate aliased dimensions into the source's vector shapes.

    For each symbolic alias whose target dimension already carries a vector
    shape on ``source`` and whose aliased (source) dimension appears in the
    source index, compute the aliased dimension's vector shape by applying
    the alias mapping to the target dimension's vector shape.
    """
    shapes = source.vector_shapes
    for alias in symbolic_constraints:
        # Skip aliases whose target has no known vector shape here.
        if alias.target not in shapes:
            continue
        # Only materialize the aliased dim if it participates in the index.
        if alias.source not in source.index:
            continue
        shapes[alias.source] = alias.apply(shapes[alias.target])


def propagate_index(
node: CustomOp,
hardware_constraint: HardwareConstraint,
Expand Down Expand Up @@ -683,6 +692,7 @@ def propagate_index(
source_index = source.transform_index(source_index)
source.index = combine_indices(source.index, source_index)
source.vector_shapes = source_vector_shapes
append_aliased_shapes(source, symbolic_constraints)
visited.add(source)
for func in [get_inputs, get_users]:
sources, reduction = add_nodes_to_sources(
Expand Down Expand Up @@ -761,7 +771,7 @@ def create_broadcast(
binary_op.graph
)
custom = get_custom(broadcasted)
custom.vector_shapes = to_broadcast.vector_shapes
custom.vector_shapes = binary_op.vector_shapes
custom.index = deepcopy(target_node.index)
custom.index[broadcast_dim].size = broadcast_size
broadcast_idx = list(binary_op.node_args.values()).index(to_broadcast)
Expand Down
9 changes: 2 additions & 7 deletions iree/turbine/kernel/wave/templates/decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,7 @@ def phase_0_constraints():
constraints += [
SymbolicAlias(U, K2, lambda x: sympy.ceiling(x / (BLOCK_K2 / K_WAVES)))
]
if mfma_variant == MMAType.F32_16x16x16_F16:
vector_shapes = {B: 0, M: 16, N: 16}
elif mfma_variant == MMAType.F32_32x32x8_F16:
vector_shapes = {B: 0, M: 32, N: 32}
else:
raise NotImplementedError(f"Unsupported mfma_variant: {mfma_variant}")
vector_shapes = {B: 0}
waves_per_block = (M_WAVES, N_WAVES, K_WAVES)
constraints += [
tkw.HardwareConstraint(
Expand Down Expand Up @@ -147,8 +142,8 @@ def phase_0(
acc = tkw.mma(v_reg, imm_f16, new_acc)
res = acc / d_j
dm_j = m_j + tkw.log2(d_j)
tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD)
tkw.write(dm_j, output_max, elements_per_thread=1)
tkw.write(res, output, elements_per_thread=STORE_ELEMS_PER_THREAD)

@tkw.wave(get_constraints(Phase.PHASE_1))
def phase_1(
Expand Down

0 comments on commit ed7d747

Please sign in to comment.