Skip to content

Commit

Permalink
[inductor] Add some typing to simd.py (pytorch#145690)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#145690
Approved by: https://github.com/malfet
  • Loading branch information
jansel authored and nWEIdia committed Jan 27, 2025
1 parent 7ca5a12 commit 1ace012
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 55 deletions.
132 changes: 81 additions & 51 deletions torch/_inductor/codegen/simd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import operator
import textwrap
from collections import Counter
from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
Iterator,
no_type_check,
Optional,
TYPE_CHECKING,
Union,
)

import sympy

Expand Down Expand Up @@ -125,7 +133,7 @@ def __init__(
def is_reduction(self) -> bool:
return prefix_is_reduction(self.prefix)

def symbol(self):
def symbol(self) -> sympy.Symbol:
return sympy_index_symbol(self.name)

@property
Expand All @@ -144,7 +152,7 @@ def __init__(
prefix: str,
index: int,
kernel: SIMDKernel,
pid_cache=None,
pid_cache: Optional[dict[str, str]] = None,
*,
is_loop: bool,
tensor_dim: Optional[int],
Expand Down Expand Up @@ -182,14 +190,14 @@ def __init__(
def __repr__(self) -> str:
return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)"

def cache_clear(self):
def cache_clear(self) -> None:
for node in self.nodes.values():
node.cache_clear()

def index_sym(self):
def index_sym(self) -> sympy.Symbol:
return sympy_index_symbol(f"{self.prefix}index")

def lookup(self, divisor, length):
def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry:
"""
Lookup a given RangeTreeEntry, creating it if needed
"""
Expand All @@ -212,18 +220,22 @@ def lookup(self, divisor, length):
self.nodes[expr] = node
return self.nodes[expr]

def construct_entries(self, lengths: list[sympy.Expr]):
def construct_entries(
self, lengths: list[sympy.Expr]
) -> list[IterationRangesEntry]:
divisor = sympy.S.One
itervars = []
for length in reversed(lengths):
itervars.append(self.lookup(divisor, length))
divisor = divisor * length
return list(reversed(itervars))
return [*reversed(itervars)]

def construct(self, lengths: list[sympy.Expr]):
def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]:
return [e.symbol() for e in self.construct_entries(lengths)]

def vars_and_sizes(self, index: sympy.Expr):
def vars_and_sizes(
self, index: sympy.Expr
) -> tuple[list[sympy.Symbol], list[sympy.Expr]]:
"""Figure out vars from this tree used in index"""
nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols]
nodes = [n for n in nodes if n and n.prefix == self.prefix]
Expand Down Expand Up @@ -252,7 +264,7 @@ def add(node):
# fill in unused index var
add(self.lookup(divisor, FloorDiv(self.numel, divisor)))

return list(reversed(index_vars)), list(reversed(sizes))
return [*reversed(index_vars)], [*reversed(sizes)]


class IterationRangesEntry(IterationRanges):
Expand Down Expand Up @@ -282,19 +294,19 @@ def __init__(
def __repr__(self) -> str:
return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})"

def set_name(self, name):
def set_name(self, name: str) -> None:
self.codegen = lambda: name # type: ignore[assignment]
self.codegen.cache_clear = lambda: None # type: ignore[method-assign]
self.name = name

def cache_clear(self):
def cache_clear(self) -> None:
self.codegen.cache_clear()

def _codegen(self):
def _codegen(self) -> str:
V.kernel.codegen_iteration_ranges_entry(self)
return self.name

def precomputed_args(self):
def precomputed_args(self) -> list[sympy.Expr]:
# for dynamic shapes, find parts of indexing expressions that have to be precomputed
precomputed_args: list[sympy.Expr] = []
if isinstance(self.expr, sympy.Symbol):
Expand All @@ -309,14 +321,15 @@ def precomputed_args(self):
precomputed_args.append(arg)
return precomputed_args

def __hash__(self):
def __hash__(self) -> int:
return hash(self.name)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
assert isinstance(other, IterationRangesEntry)
return self.name == other.name


def constant_repr(value):
def constant_repr(value: Union[int, float]) -> str:
if value == float("inf"):
return 'float("inf")'
elif value == float("-inf"):
Expand All @@ -331,18 +344,18 @@ class SIMDKernel(Kernel):
Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
"""

sexpr = pexpr
sexpr: Callable[[sympy.Expr], str] = pexpr
kexpr: Callable[[sympy.Expr], str]
allow_block_ptr = False
allow_block_ptr: bool = False
kernel_name: str

def __init__(
self,
tiling: dict[str, sympy.Expr],
features: SIMDKernelFeatures,
pid_cache=None,
override_persistent_reduction=None,
override_cooperative_reduction=None,
pid_cache: Optional[dict[str, str]] = None,
override_persistent_reduction: Optional[bool] = None,
override_cooperative_reduction: Optional[bool] = None,
) -> None:
if pid_cache is None:
pid_cache = {}
Expand Down Expand Up @@ -396,12 +409,17 @@ def dtype_to_str(self, dtype: torch.dtype) -> str:
def index_dtype(self) -> str:
return self.dtype_to_str(self.features.select_index_dtype())

def want_no_x_dim(self):
def want_no_x_dim(self) -> bool:
return False

def construct_range_trees(
self, pid_cache, inside_reduction, is_reduction, numels, no_x_dim
):
self,
pid_cache: Optional[dict[str, str]],
inside_reduction: bool,
is_reduction: bool,
numels: dict[str, sympy.Expr],
no_x_dim: bool,
) -> list[IterationRangesRoot]:
active_prefixes = OrderedSet(
prefix for prefix in all_prefixes if prefix in numels
)
Expand Down Expand Up @@ -448,7 +466,7 @@ def filtered_index_map(seq, mask) -> dict[Any, int]:
)
return range_trees

def initialize_range_tree(self, pid_cache):
def initialize_range_tree(self, pid_cache: dict[str, str]) -> None:
range_trees = self.construct_range_trees(
pid_cache,
self.inside_reduction,
Expand All @@ -458,13 +476,13 @@ def initialize_range_tree(self, pid_cache):
)
self.range_trees.extend(range_trees)

def finalize_indexing(self, indices: Sequence[sympy.Expr]):
def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None:
"""
Hook called right before codegen with every index that will be
used in the fused kernel.
"""

def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None:
prior = self.inside_reduction
self.inside_reduction = False
try:
Expand All @@ -478,17 +496,17 @@ def should_use_cooperative_reduction(self) -> bool:
def should_use_persistent_reduction(self) -> bool:
return False # defined in subclass

def var_ranges(self):
def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
return dict(
itertools.chain.from_iterable(
tree.var_ranges.items() for tree in self.range_trees
)
)

def triton_tensor_ndim(self):
def triton_tensor_ndim(self) -> int:
return sum(int(tree.tensor_dim is not None) for tree in self.range_trees)

def indexing_size_str(self, i):
def indexing_size_str(self, i: int) -> str:
sizes = ["None"] * self.triton_tensor_ndim()
sizes[i] = ":"
return f"[{', '.join(sizes)}]"
Expand All @@ -503,11 +521,11 @@ def dense_size_list(self) -> list[str]:
sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK"
return sizes

def dense_size_str(self):
def dense_size_str(self) -> str:
sizes = self.dense_size_list()
return f"[{', '.join(sizes)}]"

def combine_modular_indexing_pairs(self, index):
def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
if not isinstance(index, ModularIndexing):
return index
x = index.args[0]
Expand All @@ -525,14 +543,18 @@ def combine_modular_indexing_pairs(self, index):
},
)

def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
def combine_contiguous_dims(
self, index: sympy.Expr, tree: IterationRangesRoot
) -> sympy.Expr:
if expand_res := V.graph.sizevars.expand_floor_div(index):
new_index, denominator = expand_res # type: ignore[misc]
return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator)
else:
return self._combine_contiguous_dims(index, tree)

def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot):
def _combine_contiguous_dims(
self, index: sympy.Expr, tree: IterationRangesRoot
) -> sympy.Expr:
"""
More aggressive simplification to merge contiguous dims
"""
Expand All @@ -550,7 +572,7 @@ def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot)
new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars))))
return new_index

def disable_reduction(self):
def disable_reduction(self) -> contextlib.AbstractContextManager[None]:
should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction

@contextlib.contextmanager
Expand All @@ -574,7 +596,7 @@ def ctx():

return ctx()

def set_ranges(self, *lengths):
def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]:
assert len(lengths) == len(self.range_trees)
return [
ranges.construct(length)
Expand All @@ -584,7 +606,9 @@ def set_ranges(self, *lengths):
@staticmethod
def _split_iteration_ranges(
groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]]
):
) -> tuple[
list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]]
]:
# Special case: if a node's sizes are ([], []), there's nothing to split.
if all(len(length) == 0 for length in lengths):
return [[] for group in groups], []
Expand All @@ -594,7 +618,7 @@ def _split_iteration_ranges(
remaining = [sv.simplify(g) for g in groups]
var_count = itertools.count()

def add_range(i, expr):
def add_range(i: int, expr: sympy.Expr) -> int:
expr = sv.simplify(expr)
if not sv.statically_known_multiple_of(remaining[i], expr):
raise CantSplit
Expand All @@ -603,8 +627,10 @@ def add_range(i, expr):
new_ranges[i].append(expr)
return next(var_count)

def make_combined(size, idx1, idx2):
def getter(flat_vars):
def make_combined(
size: sympy.Expr, idx1: int, idx2: int
) -> Callable[[list[sympy.Expr]], sympy.Expr]:
def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
return size * flat_vars[idx1] + flat_vars[idx2]

return getter
Expand Down Expand Up @@ -659,7 +685,7 @@ def is_compatible(
groups: Iterable[sympy.Expr],
lengths: Sequence[Sequence[sympy.Expr]],
reduction_numel: sympy.Expr = sympy.S.One,
):
) -> bool:
# Fill in the reduction numel, in case the node is missing it.
sizevars = V.graph.sizevars
if len(lengths[1]) == 0 and (
Expand All @@ -676,7 +702,9 @@ def is_compatible(
except CantSplit:
return False

def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]):
def split_and_set_ranges(
self, lengths: Sequence[Sequence[sympy.Expr]]
) -> list[list[sympy.Expr]]:
tiling = {rt.prefix: rt.numel for rt in self.range_trees}
if not self.inside_reduction:
for prefix in tiling:
Expand Down Expand Up @@ -715,11 +743,11 @@ def map_kernel_groups_to_node_sizes(
itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

def is_indirect_indexing(self, index: sympy.Expr):
def is_indirect_indexing(self, index: sympy.Expr) -> bool:
# tmpX means indirect indexing
return free_symbol_is_type(index, SymT.TMP)

def is_broadcasted(self, index: sympy.Expr):
def is_broadcasted(self, index: sympy.Expr) -> bool:
# Note. This may not be correct when there is indirect indexing
if self.is_indirect_indexing(index):
return False
Expand Down Expand Up @@ -757,7 +785,7 @@ def index_to_str(self, index: sympy.Expr) -> str:
def prepare_indexing(
self,
index: sympy.Expr,
):
) -> sympy.Expr:
index = self.simplify_indexing(index)
index = sympy_subs(index, V.graph.sizevars.precomputed_replacements)
# if simple replacements didn't get rid of floor/ceil, try full subs
Expand Down Expand Up @@ -792,7 +820,7 @@ def prepare_indexing(

return self.codegen_indexing(simp_index)

def active_range_trees(self, reorder=False):
def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]:
trees = [
t for t in self.range_trees if not t.is_reduction or self.inside_reduction
]
Expand All @@ -804,7 +832,7 @@ def active_range_trees(self, reorder=False):
trees[:count] = reversed(trees[:count])
return trees

def codegen_indexing(self, expr: sympy.Expr):
def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr:
expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges())
for sym in sorted(expr.free_symbols, key=str):
if sym in self.range_tree_nodes:
Expand All @@ -827,7 +855,9 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None:
raise NotImplementedError("NYI: call_kernel")

@contextlib.contextmanager
def mask_loads(self, mask, value):
def mask_loads(
self, mask: Union[str, OpsWrapper], value: Union[int, float]
) -> Iterator[str]:
"""Context manager to add an additional mask to tl.load/store"""
prior = self._load_mask
prior_val = self._load_other
Expand All @@ -844,7 +874,7 @@ def mask_loads(self, mask, value):
self._load_mask = prior
self._load_other = prior_val

def get_strides_of_load(self, index: sympy.Expr):
def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]:
"""
This gets the stride of the index for each of the tiling variables
(technically, it does it at index 0)
Expand Down
Loading

0 comments on commit 1ace012

Please sign in to comment.