From 1ace01224d672df8c68b21b5ae8bca80eefefafc Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 26 Jan 2025 18:50:56 -0800 Subject: [PATCH] [inductor] Add some typing to simd.py (#145690) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145690 Approved by: https://github.com/malfet --- torch/_inductor/codegen/simd.py | 132 ++++++++++++++++++------------ torch/_inductor/codegen/triton.py | 8 +- 2 files changed, 85 insertions(+), 55 deletions(-) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 155687fde889b3..8faafe3dd95c01 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -11,7 +11,15 @@ import operator import textwrap from collections import Counter -from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + Iterator, + no_type_check, + Optional, + TYPE_CHECKING, + Union, +) import sympy @@ -125,7 +133,7 @@ def __init__( def is_reduction(self) -> bool: return prefix_is_reduction(self.prefix) - def symbol(self): + def symbol(self) -> sympy.Symbol: return sympy_index_symbol(self.name) @property @@ -144,7 +152,7 @@ def __init__( prefix: str, index: int, kernel: SIMDKernel, - pid_cache=None, + pid_cache: Optional[dict[str, str]] = None, *, is_loop: bool, tensor_dim: Optional[int], @@ -182,14 +190,14 @@ def __init__( def __repr__(self) -> str: return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" - def cache_clear(self): + def cache_clear(self) -> None: for node in self.nodes.values(): node.cache_clear() - def index_sym(self): + def index_sym(self) -> sympy.Symbol: return sympy_index_symbol(f"{self.prefix}index") - def lookup(self, divisor, length): + def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry: """ Lookup a given RangeTreeEntry, creating it if needed """ @@ -212,18 +220,22 @@ def lookup(self, divisor, length): self.nodes[expr] = node return self.nodes[expr] - def construct_entries(self, lengths: list[sympy.Expr]): + def construct_entries( + self, lengths: list[sympy.Expr] + ) -> list[IterationRangesEntry]: divisor = sympy.S.One itervars = [] for length in reversed(lengths): itervars.append(self.lookup(divisor, length)) divisor = divisor * length - return list(reversed(itervars)) + return [*reversed(itervars)] - def construct(self, lengths: list[sympy.Expr]): + def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]: return [e.symbol() for e in self.construct_entries(lengths)] - def vars_and_sizes(self, index: sympy.Expr): + def vars_and_sizes( + self, index: sympy.Expr + ) -> tuple[list[sympy.Symbol], list[sympy.Expr]]: """Figure out vars from this tree used in index""" nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] nodes = [n for n in nodes if n and n.prefix == self.prefix] @@ -252,7 +264,7 @@ def add(node): # fill in unused index var add(self.lookup(divisor, FloorDiv(self.numel, divisor))) - return list(reversed(index_vars)), list(reversed(sizes)) + return [*reversed(index_vars)], [*reversed(sizes)] class IterationRangesEntry(IterationRanges): @@ -282,19 +294,19 @@ def __init__( def __repr__(self) -> str: return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" - def set_name(self, name): + def set_name(self, name: str) -> None: self.codegen = lambda: name # type: ignore[assignment] self.codegen.cache_clear = lambda: None # type: ignore[method-assign] self.name = name - def cache_clear(self): + def cache_clear(self) -> None: self.codegen.cache_clear() - def _codegen(self): + def _codegen(self) -> str: V.kernel.codegen_iteration_ranges_entry(self) return self.name - def precomputed_args(self): + def precomputed_args(self) -> list[sympy.Expr]: # for dynamic shapes, find parts of indexing expressions that have to be precomputed precomputed_args: list[sympy.Expr] = [] if isinstance(self.expr, sympy.Symbol): @@ -309,14 +321,15 @@ def precomputed_args(self): precomputed_args.append(arg) return precomputed_args - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + assert isinstance(other, IterationRangesEntry) return self.name == other.name -def constant_repr(value): +def constant_repr(value: Union[int, float]) -> str: if value == float("inf"): return 'float("inf")' elif value == float("-inf"): @@ -331,18 +344,18 @@ class SIMDKernel(Kernel): Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. """ - sexpr = pexpr + sexpr: Callable[[sympy.Expr], str] = pexpr kexpr: Callable[[sympy.Expr], str] - allow_block_ptr = False + allow_block_ptr: bool = False kernel_name: str def __init__( self, tiling: dict[str, sympy.Expr], features: SIMDKernelFeatures, - pid_cache=None, - override_persistent_reduction=None, - override_cooperative_reduction=None, + pid_cache: Optional[dict[str, str]] = None, + override_persistent_reduction: Optional[bool] = None, + override_cooperative_reduction: Optional[bool] = None, ) -> None: if pid_cache is None: pid_cache = {} @@ -396,12 +409,17 @@ def dtype_to_str(self, dtype: torch.dtype) -> str: def index_dtype(self) -> str: return self.dtype_to_str(self.features.select_index_dtype()) - def want_no_x_dim(self): + def want_no_x_dim(self) -> bool: return False def construct_range_trees( - self, pid_cache, inside_reduction, is_reduction, numels, no_x_dim - ): + self, + pid_cache: Optional[dict[str, str]], + inside_reduction: bool, + is_reduction: bool, + numels: dict[str, sympy.Expr], + no_x_dim: bool, + ) -> list[IterationRangesRoot]: active_prefixes = OrderedSet( prefix for prefix in all_prefixes if prefix in numels ) @@ -448,7 +466,7 @@ def filtered_index_map(seq, mask) -> dict[Any, int]: ) return range_trees - def initialize_range_tree(self, pid_cache): + def initialize_range_tree(self, pid_cache: dict[str, str]) -> None: range_trees = self.construct_range_trees( pid_cache, self.inside_reduction, @@ -458,13 +476,13 @@ def initialize_range_tree(self, pid_cache): ) self.range_trees.extend(range_trees) - def finalize_indexing(self, indices: Sequence[sympy.Expr]): + def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None: """ Hook called right before codegen with every index that will be used in the fused kernel. """ - def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable): + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: prior = self.inside_reduction self.inside_reduction = False try: @@ -478,17 +496,17 @@ def should_use_cooperative_reduction(self) -> bool: def should_use_persistent_reduction(self) -> bool: return False # defined in subclass - def var_ranges(self): + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: return dict( itertools.chain.from_iterable( tree.var_ranges.items() for tree in self.range_trees ) ) - def triton_tensor_ndim(self): + def triton_tensor_ndim(self) -> int: return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) - def indexing_size_str(self, i): + def indexing_size_str(self, i: int) -> str: sizes = ["None"] * self.triton_tensor_ndim() sizes[i] = ":" return f"[{', '.join(sizes)}]" @@ -503,11 +521,11 @@ def dense_size_list(self) -> list[str]: sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" return sizes - def dense_size_str(self): + def dense_size_str(self) -> str: sizes = self.dense_size_list() return f"[{', '.join(sizes)}]" - def combine_modular_indexing_pairs(self, index): + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: if not isinstance(index, ModularIndexing): return index x = index.args[0] @@ -525,14 +543,18 @@ def combine_modular_indexing_pairs(self, index): }, ) - def combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + def combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: if expand_res := V.graph.sizevars.expand_floor_div(index): new_index, denominator = expand_res # type: ignore[misc] return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) else: return self._combine_contiguous_dims(index, tree) - def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot): + def _combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: """ More aggressive simplification to merge contiguous dims """ @@ -550,7 +572,7 @@ def _combine_contiguous_dims(self, index: sympy.Expr, tree: IterationRangesRoot) new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) return new_index - def disable_reduction(self): + def disable_reduction(self) -> contextlib.AbstractContextManager[None]: should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction @contextlib.contextmanager @@ -574,7 +596,7 @@ def ctx(): return ctx() - def set_ranges(self, *lengths): + def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]: assert len(lengths) == len(self.range_trees) return [ ranges.construct(length) @@ -584,7 +606,9 @@ def set_ranges(self, *lengths): @staticmethod def _split_iteration_ranges( groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] - ): + ) -> tuple[ + list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]] + ]: # Special case: if a node's sizes are ([], []), there's nothing to split. if all(len(length) == 0 for length in lengths): return [[] for group in groups], [] @@ -594,7 +618,7 @@ def _split_iteration_ranges( remaining = [sv.simplify(g) for g in groups] var_count = itertools.count() - def add_range(i, expr): + def add_range(i: int, expr: sympy.Expr) -> int: expr = sv.simplify(expr) if not sv.statically_known_multiple_of(remaining[i], expr): raise CantSplit @@ -603,8 +627,10 @@ def add_range(i, expr): new_ranges[i].append(expr) return next(var_count) - def make_combined(size, idx1, idx2): - def getter(flat_vars): + def make_combined( + size: sympy.Expr, idx1: int, idx2: int + ) -> Callable[[list[sympy.Expr]], sympy.Expr]: + def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: return size * flat_vars[idx1] + flat_vars[idx2] return getter @@ -659,7 +685,7 @@ def is_compatible( groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]], reduction_numel: sympy.Expr = sympy.S.One, - ): + ) -> bool: # Fill in the reduction numel, in case the node is missing it. sizevars = V.graph.sizevars if len(lengths[1]) == 0 and ( @@ -676,7 +702,9 @@ def is_compatible( except CantSplit: return False - def split_and_set_ranges(self, lengths: Sequence[Sequence[sympy.Expr]]): + def split_and_set_ranges( + self, lengths: Sequence[Sequence[sympy.Expr]] + ) -> list[list[sympy.Expr]]: tiling = {rt.prefix: rt.numel for rt in self.range_trees} if not self.inside_reduction: for prefix in tiling: @@ -715,11 +743,11 @@ def map_kernel_groups_to_node_sizes( itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))] return [[fn(itervars) for fn in fns] for fns in return_getters_groups] - def is_indirect_indexing(self, index: sympy.Expr): + def is_indirect_indexing(self, index: sympy.Expr) -> bool: # tmpX means indirect indexing return free_symbol_is_type(index, SymT.TMP) - def is_broadcasted(self, index: sympy.Expr): + def is_broadcasted(self, index: sympy.Expr) -> bool: # Note. This may not be correct when there is indirect indexing if self.is_indirect_indexing(index): return False @@ -757,7 +785,7 @@ def index_to_str(self, index: sympy.Expr) -> str: def prepare_indexing( self, index: sympy.Expr, - ): + ) -> sympy.Expr: index = self.simplify_indexing(index) index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) # if simple replacements didn't get rid of floor/ceil, try full subs @@ -792,7 +820,7 @@ def prepare_indexing( return self.codegen_indexing(simp_index) - def active_range_trees(self, reorder=False): + def active_range_trees(self, reorder: bool = False) -> list[IterationRangesRoot]: trees = [ t for t in self.range_trees if not t.is_reduction or self.inside_reduction ] @@ -804,7 +832,7 @@ def active_range_trees(self, reorder=False): trees[:count] = reversed(trees[:count]) return trees - def codegen_indexing(self, expr: sympy.Expr): + def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr: expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) for sym in sorted(expr.free_symbols, key=str): if sym in self.range_tree_nodes: @@ -827,7 +855,9 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: raise NotImplementedError("NYI: call_kernel") @contextlib.contextmanager - def mask_loads(self, mask, value): + def mask_loads( + self, mask: Union[str, OpsWrapper], value: Union[int, float] + ) -> Iterator[str]: """Context manager to add an additional mask to tl.load/store""" prior = self._load_mask prior_val = self._load_other @@ -844,7 +874,7 @@ def mask_loads(self, mask, value): self._load_mask = prior self._load_other = prior_val - def get_strides_of_load(self, index: sympy.Expr): + def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]: """ This gets the stride of the index for each of the tiling variables (technically, it does it at index 0) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index c1e9039b6492de..e539edf00cb21f 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -292,7 +292,7 @@ def create( *, params: BlockParameters, constant_offset: sympy.Expr, - range_trees: list[IterationRangesEntry], + range_trees: list[IterationRangesRoot], mask_vars: OrderedSet[str], get_max_block: Callable[[str], int], ) -> BlockPtrOptions: @@ -1740,7 +1740,7 @@ def indexing( ): def match_strided_block( - index: sympy.Expr, range_tree: IterationRangesEntry + index: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ Matches expressions of the form: @@ -1762,7 +1762,7 @@ def match_strided_block( ) def match_mod_div_block( - index: sympy.Expr, range_tree: IterationRangesEntry + index: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. @@ -1860,7 +1860,7 @@ def match_mod_div_block( ) def match_block_pointer_subexpr( - expr: sympy.Expr, range_tree: IterationRangesEntry + expr: sympy.Expr, range_tree: IterationRangesRoot ) -> Optional[BlockParameters]: """ Match a block indexing subexpression involving a single range tree.