From 2775da42248013265b5591b1c9684f32eefeaf51 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 26 Jan 2025 18:50:56 -0800 Subject: [PATCH] [inductor] Add some typing to common.py (#145691) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145691 Approved by: https://github.com/malfet ghstack dependencies: #145690 --- .../cpp/extension_codegen_backend.py | 2 +- .../triton/extension_codegen_backend.py | 2 +- torch/_inductor/codegen/common.py | 160 ++++++++++-------- torch/_inductor/codegen/cpp.py | 5 +- torch/_inductor/codegen/cpp_wrapper_cpu.py | 4 +- .../codegen/cpp_wrapper_cpu_array_ref.py | 4 +- torch/_inductor/codegen/cpp_wrapper_gpu.py | 4 +- .../codegen/cuda/cuda_cpp_scheduling.py | 8 +- .../codegen/cuda_combined_scheduling.py | 7 +- torch/_inductor/codegen/mps.py | 2 +- .../codegen/rocm/rocm_cpp_scheduling.py | 8 +- torch/_inductor/codegen/simd.py | 10 +- torch/_inductor/codegen/triton.py | 2 +- torch/_inductor/codegen/wrapper.py | 12 +- torch/_inductor/metrics.py | 2 +- torch/_inductor/scheduler.py | 8 + 16 files changed, 134 insertions(+), 106 deletions(-) diff --git a/test/inductor/extension_backends/cpp/extension_codegen_backend.py b/test/inductor/extension_backends/cpp/extension_codegen_backend.py index f241284c4aef3a..f6afd87db75b9f 100644 --- a/test/inductor/extension_backends/cpp/extension_codegen_backend.py +++ b/test/inductor/extension_backends/cpp/extension_codegen_backend.py @@ -15,7 +15,7 @@ def __init__(self, *args, **kwargs): class ExtensionScheduling(BaseScheduling): def __init__(self, scheduler): - self.scheduler = scheduler + super().__init__(scheduler) self._scheduling = cpp.CppScheduling(scheduler) def can_fuse_vertical(self, node1, node2): diff --git a/test/inductor/extension_backends/triton/extension_codegen_backend.py b/test/inductor/extension_backends/triton/extension_codegen_backend.py index 9a292678b3f87c..3e77a29caacc7d 100644 --- a/test/inductor/extension_backends/triton/extension_codegen_backend.py +++ b/test/inductor/extension_backends/triton/extension_codegen_backend.py @@ -10,7 +10,7 @@ def __init__(self, *args, **kwargs): class ExtensionScheduling(BaseScheduling): def __init__(self, scheduler): - self.scheduler = scheduler + super().__init__(scheduler) self._triton_scheduling = triton.TritonScheduling(scheduler) def can_fuse_vertical(self, node1, node2): diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0276ee6d8d9c0d..92be62dcbf8400 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -14,12 +14,6 @@ from itertools import chain from typing import Any, Callable, ClassVar, NamedTuple, Optional, TYPE_CHECKING, Union -from torch.utils._ordered_set import OrderedSet - - -if TYPE_CHECKING: - from typing import Never - import sympy import torch @@ -27,6 +21,7 @@ from torch._inductor.dtype_propagation import DtypePropagationOpsHandler from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT @@ -46,11 +41,24 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V +if TYPE_CHECKING: + from typing import Never, TypeVar + + from ..ir import FixedLayout + from ..loop_body import LoopBody + from ..scheduler import BaseScheduling, Scheduler 
+ from .wrapper import PythonWrapperCodegen + + _T = TypeVar("_T") + SchedulingConstructor = Callable[[Optional[Scheduler]], BaseScheduling] + WrapperConstructor = type[PythonWrapperCodegen] + SymbolLike = Union[str, sympy.Symbol] + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") log = logging.getLogger(__name__) -def data_type_logger(msg): +def data_type_logger(msg: str) -> None: if schedule_log.isEnabledFor(logging.DEBUG): schedule_log.debug("Data type propagation: %s", msg) @@ -61,7 +69,7 @@ class WorkspaceZeroMode(enum.Enum): ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel @staticmethod - def combine(a, b): + def combine(a: WorkspaceZeroMode, b: WorkspaceZeroMode) -> WorkspaceZeroMode: if a == b or b == WorkspaceZeroMode.UNINITIALIZED: return a if a == WorkspaceZeroMode.UNINITIALIZED: @@ -69,7 +77,7 @@ def combine(a, b): raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") @staticmethod - def from_bool(zero_fill): + def from_bool(zero_fill: bool) -> WorkspaceZeroMode: if zero_fill: return WorkspaceZeroMode.ZERO_ON_CALL return WorkspaceZeroMode.UNINITIALIZED @@ -96,17 +104,17 @@ class WorkspaceArg: dtype: torch.dtype = torch.uint8 @staticmethod - def unique_name(prefix="workspace_"): + def unique_name(prefix="workspace_") -> str: return f"{prefix}{next(V.graph.workspace_id)}" @staticmethod - def can_join(a, b) -> bool: + def can_join(a: WorkspaceArg, b: WorkspaceArg) -> bool: return ( a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device ) @staticmethod - def join(a, b): + def join(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: return WorkspaceArg( count=a.count + b.count, zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), @@ -117,7 +125,7 @@ def join(a, b): ) @staticmethod - def maximum(a, b): + def maximum(a: WorkspaceArg, b: WorkspaceArg) -> WorkspaceArg: assert ( a.dtype == b.dtype and a.device == b.device and a.inner_name == b.inner_name ) @@ -131,15 +139,15 @@ def maximum(a, b): ) # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code - def get_device(self): + def get_device(self) -> torch.device: return self.device get_device_or_error = get_device - def get_dtype(self): + def get_dtype(self) -> torch.dtype: return self.dtype - def get_layout(self): + def get_layout(self) -> FixedLayout: from ..ir import FixedLayout return FixedLayout( @@ -150,23 +158,23 @@ def get_layout(self): ) @property - def layout(self): + def layout(self) -> FixedLayout: return self.get_layout() get_output_spec = get_layout maybe_get_output_spec = get_layout maybe_get_layout = get_layout - def get_size(self): + def get_size(self) -> list[sympy.Expr]: return [self.count] - def get_stride(self): - return [1] + def get_stride(self) -> list[sympy.Expr]: + return [sympy.S.One] - def get_name(self): + def get_name(self) -> str: return self.outer_name - def get_inputs_that_alias_output(self): + def get_inputs_that_alias_output(self) -> list[str]: return [] @@ -185,7 +193,7 @@ class SizeArg: expr: sympy.Expr @property - def alias_of(self): + def alias_of(self) -> Optional[str]: return None @@ -201,9 +209,9 @@ class TMADescriptorArg: @dataclasses.dataclass class DeviceCodegen: - scheduling: Any - wrapper_codegen: type - cpp_wrapper_codegen: type = type(None) + scheduling: SchedulingConstructor + wrapper_codegen: WrapperConstructor + cpp_wrapper_codegen: Optional[WrapperConstructor] = None KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg] @@ -212,16 +220,16 @@ class DeviceCodegen: class 
DeviceOpOverrides: - def import_get_raw_stream_as(self, name): + def import_get_raw_stream_as(self, name: str): raise NotImplementedError - def set_device(self, device_idx): + def set_device(self, device_idx: int) -> str: raise NotImplementedError def synchronize(self): raise NotImplementedError - def device_guard(self, device_idx): + def device_guard(self, device_idx: int) -> str: raise NotImplementedError def cpp_device_guard(self): @@ -290,10 +298,10 @@ def tma_descriptor_helpers(self): # https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9 def register_backend_for_device( device: str, - device_scheduling: Any, - device_wrapper_codegen: type, - device_cpp_wrapper_codegen: type = type(None), -): + device_scheduling: SchedulingConstructor, + device_wrapper_codegen: WrapperConstructor, + device_cpp_wrapper_codegen: Optional[WrapperConstructor] = None, +) -> None: device_codegens[device] = DeviceCodegen( device_scheduling, device_wrapper_codegen, device_cpp_wrapper_codegen ) @@ -322,21 +330,27 @@ def get_backend_features(device: Union[torch.device, str, None]): assert isinstance(device, str) device_type = device device = torch.device(device_type) - scheduling = get_scheduling_for_device(device_type) - return scheduling(None).get_backend_features(device) + scheduling_ctor = get_scheduling_for_device(device_type) + assert scheduling_ctor + scheduling = scheduling_ctor(None) + return scheduling.get_backend_features(device) -def has_backend_feature(device, feature): +def has_backend_feature( + device: Union[torch.device, str, None], feature: BackendFeature +) -> bool: """See also V.graph.has_feature""" assert isinstance(feature, BackendFeature) return feature in get_backend_features(device) -def get_scheduling_for_device(device: str): +def get_scheduling_for_device(device: str) -> Optional[SchedulingConstructor]: return device_codegens[device].scheduling if device in device_codegens else None -def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): +def get_wrapper_codegen_for_device( + device: str, cpp_wrapper: bool = False +) -> Optional[WrapperConstructor]: if device in device_codegens: wrapper_codegen_obj: DeviceCodegen = device_codegens[device] return ( @@ -348,7 +362,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): @functools.lru_cache(None) -def init_backend_registration(): +def init_backend_registration() -> None: from .cpp import CppScheduling from .cpp_wrapper_cpu import CppWrapperCpu from .cpp_wrapper_cpu_array_ref import CppWrapperCpuArrayRef @@ -367,7 +381,7 @@ def init_backend_registration(): } register_backend_for_device( "cpu", - lambda *args, **kwargs: cpu_backends[config.cpu_backend](*args, **kwargs), + lambda scheduling: cpu_backends[config.cpu_backend](scheduling), PythonWrapperCodegen, CppWrapperCpuArrayRef if config.aot_inductor.allow_stack_allocation @@ -376,10 +390,13 @@ def init_backend_registration(): if get_scheduling_for_device("cuda") is None: # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation - cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling} + cuda_backends = { + "triton": CUDACombinedScheduling, + "halide": HalideScheduling, + } register_backend_for_device( "cuda", - lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs), + lambda scheduling: cuda_backends[config.cuda_backend](scheduling), PythonWrapperCodegen, 
CppWrapperGpu, ) @@ -422,30 +439,33 @@ def init_backend_registration(): pass -def index_prevent_reordering(index: list[sympy.Expr], index_vars, sizes): +def index_prevent_reordering( + index: list[sympy.Expr], index_vars, sizes +) -> list[sympy.Expr]: from ..ir import FlexibleLayout # added contiguous index prevents reordering return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))] -def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides): +def register_device_op_overrides( + device: str, device_op_overrides: DeviceOpOverrides +) -> None: device_op_overrides_dict[device] = device_op_overrides -def get_device_op_overrides(device: str): +def get_device_op_overrides(device: str) -> DeviceOpOverrides: assert isinstance(device, str) - if not device_op_overrides_dict.keys(): + if not device_op_overrides_dict: from . import cpu_device_op_overrides, mps_device_op_overrides # noqa: F401 from .cuda import device_op_overrides # noqa: F401 from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 - if device in device_op_overrides_dict.keys(): - return device_op_overrides_dict[device] + return device_op_overrides_dict[device] -DTYPE_TO_COMPUTATION_DTYPE = { +DTYPE_TO_COMPUTATION_DTYPE: dict[torch.dtype, torch.dtype] = { torch.bfloat16: torch.float, torch.float16: torch.float, **{ @@ -469,8 +489,8 @@ def get_device_op_overrides(device: str): def deduce_output_dtype_by_name( op_name: str, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> Optional[torch.dtype]: """ Given op name and a list of input dtypes, deduce the output dtype @@ -511,7 +531,7 @@ def deduce_output_dtype_by_name( class DataTypePropagation: - def __init__(self, body) -> None: + def __init__(self, body: LoopBody) -> None: self.body = body self.graphs: dict[Union[Callable[..., Any], str], Any] = { "root": body.root_block.graph @@ -519,7 +539,7 @@ def __init__(self, body) -> None: for k, v in body.subblocks.items(): self.graphs[k] = v.graph - def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): + def deduce_node_dtype_by_inputs(self, node: torch.fx.Node) -> Optional[torch.dtype]: inputs = node.all_input_nodes input_nodes = [ n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder" @@ -540,13 +560,13 @@ def deduce_node_dtype_by_inputs(self, node: torch.fx.Node): [n.meta[OptimizationContext.key].dtype for n in input_nodes], ) - def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node): + def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node) -> torch.dtype: sub_graph = self.graphs[node.target] dtype = self.propagate_graph(sub_graph) assert dtype return dtype - def deduce_node_dtype(self, node: torch.fx.Node): + def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: if node.op == "placeholder": return None @@ -573,9 +593,9 @@ def deduce_node_dtype(self, node: torch.fx.Node): return self.deduce_node_dtype_by_inputs(node) - def propagate_graph(self, graph: torch.fx.Graph): + def propagate_graph(self, graph: torch.fx.Graph) -> Optional[torch.dtype]: assert graph.nodes - graph_dtype = None + graph_dtype: Optional[torch.dtype] = None # For masked_subblock, we use output's dtype to represent # the dtype of this subgraph. 
For other cases, graph_dtype # might be None @@ -591,25 +611,27 @@ def propagate_graph(self, graph: torch.fx.Graph): graph_dtype = opt_ctx.dtype return graph_dtype - def propagate(self): - self.propagate_graph(self.graphs["root"]) + def propagate(self) -> Optional[torch.dtype]: + return self.propagate_graph(self.graphs["root"]) @classmethod - def propagate_loopbody(cls, body): + def propagate_loopbody(cls, body) -> Optional[torch.dtype]: return cls(body).propagate() @classmethod - def propagate_scheduler_node(cls, node): + def propagate_scheduler_node(cls, node) -> Optional[torch.dtype]: from ..loop_body import LoopBody from ..scheduler import SchedulerNode assert isinstance(node, SchedulerNode) assert isinstance(node._body, LoopBody) - DataTypePropagation.propagate_loopbody(node._body) + return DataTypePropagation.propagate_loopbody(node._body) class PythonPrinter(_PythonPrinter): - def doprint(self, expr, *, simplify: bool = True, p=True): + def doprint( + self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True + ) -> str: # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): expr = V.graph.sizevars.simplify(expr) @@ -622,7 +644,7 @@ class OpDecompositions: """ @staticmethod - def identity(value): + def identity(value: _T) -> _T: # used to trigger cse return value @@ -1093,7 +1115,7 @@ def _new_line(self, line): class BracesBuffer(IndentedBuffer): - def indent(self, offset=1): + def indent(self, offset=1) -> contextlib.AbstractContextManager[None]: @contextlib.contextmanager def ctx(): for _ in range(offset): @@ -1120,7 +1142,7 @@ class InplacedBuffer(NamedTuple): class KernelArgs: @staticmethod - def _lookup(prefix, odict, name): + def _lookup(prefix: str, odict: dict[SymbolLike, str], name: SymbolLike) -> str: assert isinstance(name, (str, sympy.Symbol)) if name not in odict: odict[name] = f"{prefix}{len(odict)}" @@ -1133,7 +1155,7 @@ def __init__(self, sizevars=None): self.sizevars = sizevars or {} self.workspace_args = [] - def __repr__(self): + def __repr__(self) -> str: return "KernelArgs({})".format( ", ".join( map( @@ -1185,7 +1207,7 @@ def make_inplace(self, input_name, output_name): self.inplace_buffers[input_name] = buf self.inplace_buffers[output_name] = buf - def workspace(self, nbytes: sympy.Expr, zero_fill: bool): + def workspace(self, nbytes: sympy.Expr, zero_fill: bool) -> tuple[str, int]: """ Allocate or extend a workspace buffer of nbytes bytes. @@ -1226,7 +1248,7 @@ def workspace(self, nbytes: sympy.Expr, zero_fill: bool): self.workspace_args.append(arg) return arg.inner_name, 0 - def semaphores(self, min_size: sympy.Expr): + def semaphores(self, min_size: sympy.Expr) -> str: """ Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit. 
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 61d1419ee7f53b..101e99d0f5c733 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -4251,8 +4251,7 @@ def get_backend_features(cls, device: torch.device): return cls.backend_features def __init__(self, scheduler): - super().__init__() - self.scheduler = scheduler + super().__init__(scheduler) if scheduler: self.reset_kernel_group() self._ready_to_flush = False @@ -4955,7 +4954,7 @@ def template_buffer_has_other_users( kernel.call_kernel(kernel_name, ctb) V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() def _get_scheduled_num_args(self): return self.kernel_group.get_num_args() diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1f71a9eccffb1a..d1e833600042cd 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -70,7 +70,9 @@ def __init__(self): @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], ): # TODO - support subgraph codegen by lifting functions. Check the # comment at CppWrapperCpu `codegen_subgraph` function. diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index bc79f2bbb73487..44b76299efb9ba 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -64,7 +64,9 @@ def __init__(self): @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], ): # TODO - support subgraph codegen by lifting functions. Check the # comment at CppWrapperCpu `codegen_subgraph` function. diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index d0d5cccb595503..64b5b2812b0665 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -191,7 +191,9 @@ def __init__(self) -> None: @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], ): # TODO - support subgraph codegen by lifting functions. Check the # comment at CppWrapperCpu `codegen_subgraph` function. diff --git a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py index 64e6f08ee0748c..13934b83528c73 100644 --- a/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py +++ b/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -7,7 +7,7 @@ from ... import config from ...codecache import code_hash, get_path from ...ir import CUDATemplateBuffer -from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product from ...virtualized import V from ..common import IndentedBuffer @@ -25,10 +25,6 @@ class CUDACPPScheduling(BaseScheduling): It handles fusion decisions and CUDA C++ specific template code generation. 
""" - def __init__(self, scheduler: Scheduler) -> None: - super().__init__() - self.scheduler = scheduler - @classmethod def get_backend_features(cls, device): return {} @@ -115,4 +111,4 @@ def codegen_template( kernel.call_kernel(kernel_name, ctb) V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() diff --git a/torch/_inductor/codegen/cuda_combined_scheduling.py b/torch/_inductor/codegen/cuda_combined_scheduling.py index 6d1f94fed977ff..df7faa510f93b4 100644 --- a/torch/_inductor/codegen/cuda_combined_scheduling.py +++ b/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs from collections.abc import Sequence -from typing import Union +from typing import Optional, Union from ..scheduler import ( BaseSchedulerNode, @@ -24,9 +24,8 @@ class CUDACombinedScheduling(BaseScheduling): this would also be the place to do it. """ - def __init__(self, scheduler: Scheduler) -> None: - super().__init__() - self._scheduler = scheduler + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) self._triton_scheduling = TritonScheduling(scheduler) self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index b0a4117f7f4142..8f46d3a759f339 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -452,7 +452,7 @@ def check_bounds( class MetalScheduling(SIMDScheduling): kernel_type = MetalKernel # type: ignore[assignment] - def __init__(self, scheduler: Scheduler) -> None: + def __init__(self, scheduler: Optional[Scheduler]) -> None: super().__init__(scheduler) wrapper = V.graph.wrapper_code if wrapper is not None: diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index 2b7a935a4c1013..0b8b16d36cb669 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -5,7 +5,7 @@ from ... import config from ...codecache import code_hash, get_path -from ...scheduler import BaseSchedulerNode, BaseScheduling, Scheduler, SchedulerNode +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product from ...virtualized import V from ..common import IndentedBuffer @@ -24,10 +24,6 @@ class ROCmCPPScheduling(BaseScheduling): It handles fusion decisions and ROCm C++ specific template code generation. 
""" - def __init__(self, scheduler: Scheduler) -> None: - super().__init__() - self.scheduler = scheduler - def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @@ -100,4 +96,4 @@ def codegen_template( kernel_name = self.define_kernel(src_code, node_schedule) kernel.call_kernel(kernel_name, ctb) V.graph.removed_buffers |= kernel.removed_buffers - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 8faafe3dd95c01..a011fa82d431cf 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1062,10 +1062,6 @@ def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): class SIMDScheduling(BaseScheduling): kernel_type: type[Any] = SIMDKernel # override in subclass - def __init__(self, scheduler) -> None: - super().__init__() - self.scheduler = scheduler - def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @@ -1411,7 +1407,7 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): f"run_intermediate_hooks({origin_node.name!r}, {name})" ) - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() def create_kernel_choices( self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs @@ -1575,7 +1571,7 @@ def codegen_template( V.graph.removed_buffers |= kernel.removed_buffers V.graph.inplaced_to_remove |= kernel.inplaced_to_remove - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() return None def codegen_sync(self): @@ -1660,7 +1656,7 @@ def codegen_combo_kernel(self, combo_kernel_node): log.debug("ComboKernels: generated kernel %s.", kernel_name) kernel.call_kernel(V.graph.wrapper_code, kernel_name) - self.scheduler.free_buffers() + self.free_buffers_in_scheduler() @classmethod @functools.lru_cache(32) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index e539edf00cb21f..f6c44abaae0bee 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -3831,7 +3831,7 @@ class TritonScheduling(SIMDScheduling): ) ) - def __init__(self, scheduler: Scheduler) -> None: + def __init__(self, scheduler: Optional[Scheduler]) -> None: super().__init__(scheduler) if scheduler is None or not hasattr(scheduler, "nodes"): return diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index d76f943fde2584..508f325f09f980 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -727,9 +727,13 @@ def add_import_once(line: str) -> None: @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], ): if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper) return PythonWrapperCodegen() @@ -1774,7 +1778,9 @@ def generate_workspace_allocation(self, ws: WorkspaceArg): elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH: prior = self.allocated_workspaces.get(name) if prior: - assert isinstance(prior, AllocateLine) + assert isinstance(prior, AllocateLine) and isinstance( + prior.node, WorkspaceArg + ) # expand existing allocation prior.node = WorkspaceArg.maximum(prior.node, ws) else: @@ -2554,7 +2560,7 @@ class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): imports twice in the 
output code) """ - def __init__(self, subgraph_name, parent_wrapper): + def __init__(self, subgraph_name: str, parent_wrapper: PythonWrapperCodegen): # It is necessary to set the subgraph_name before calling super __init__ # because __init__ calls set_launcher_fn_name self.subgraph_name = subgraph_name diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index f4aba743fcfee3..200c111d87fea9 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -143,7 +143,7 @@ def add_row( ), f"{len(self.column_names)} v.s. {len(row_dict)}" assert OrderedSet(self.column_names) == OrderedSet( row_dict.keys() - ), f"{set(self.column_names)} v.s. {set(row_dict.keys())}" + ), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}" row = [ get_benchmark_name(), diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index ec29fb075f3b13..6a539f28adb317 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3915,6 +3915,14 @@ def update_zero_dim_cpu_tensor(self) -> None: class BaseScheduling: + def __init__(self, scheduler: Optional[Scheduler]): + super().__init__() + self.scheduler = scheduler + + def free_buffers_in_scheduler(self) -> None: + if self.scheduler: + self.scheduler.free_buffers() + @classmethod def get_backend_features(cls, device: torch.device) -> Sequence[BackendFeature]: """Return a set of .codegen.common.BackendFeature()"""
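
Illustrative sketch (not part of the patch): the newly typed API expects backends to
hand the scheduler to BaseScheduling.__init__ and to release buffers through the new
free_buffers_in_scheduler() helper, as the extension-backend tests above now do. The
backend name "my_device" and the MyScheduling class below are hypothetical
placeholders; this is a minimal sketch of the registration pattern, assuming the
register_backend_for_device API shown in common.py.

    from typing import Optional

    from torch._inductor.codegen.common import register_backend_for_device
    from torch._inductor.codegen.wrapper import PythonWrapperCodegen
    from torch._inductor.scheduler import BaseScheduling, Scheduler


    class MyScheduling(BaseScheduling):  # hypothetical out-of-tree backend
        def __init__(self, scheduler: Optional[Scheduler]) -> None:
            # BaseScheduling.__init__ now stores self.scheduler; subclasses no
            # longer assign it themselves.
            super().__init__(scheduler)

        def codegen_node(self, node) -> None:
            ...  # emit kernel code for `node` (backend-specific)
            # Replaces the old `self.scheduler.free_buffers()` call sites; it is
            # a no-op when the instance was constructed with scheduler=None
            # (e.g. when only get_backend_features() is needed).
            self.free_buffers_in_scheduler()


    # SchedulingConstructor is Callable[[Optional[Scheduler]], BaseScheduling],
    # so the class itself (or a one-argument lambda) serves as the constructor.
    register_backend_for_device("my_device", MyScheduling, PythonWrapperCodegen)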