diff --git a/shark_turbine/kernel/wave/index_sequence_analysis.py b/shark_turbine/kernel/wave/index_sequence_analysis.py index 85b8761ca..4fb926738 100644 --- a/shark_turbine/kernel/wave/index_sequence_analysis.py +++ b/shark_turbine/kernel/wave/index_sequence_analysis.py @@ -36,7 +36,12 @@ def has_strided_access(node: fx.Node) -> bool: custom = get_custom(node) if isinstance(custom, Write) and len(custom.type.symbolic_shape) == 2: strides = [simplify_index(custom.index[dim]).stride for dim in custom.index] - return sum(1 for stride in strides if stride > 1) == 1 + num_strided_accesses = sum(1 for stride in strides if stride > 1) + if num_strided_accesses > 1: + raise NotImplementedError( + "Support for strided accesses on more than one dimension not implemented yet!" + ) + return num_strided_accesses == 1 return False strided_operators = trace.walk(has_strided_access)