From 2e7759814eb3e8c7b7622f3dc033b8806a9b3246 Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Thu, 20 Feb 2025 16:01:37 -0800 Subject: [PATCH] Update buck deps for new replace_scalar_with_tensor transforms, test; fix type annotations (#8588) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/8588 #8519 (D69765463) adds this new module. Update the buck deps to reflect the new dependencies. Also fix some type annotations; the Meta-internal type checker seems to be more strict. Reviewed By: swolchok Differential Revision: D69883148 --- backends/arm/_passes/TARGETS | 1 + backends/cadence/aot/TARGETS | 1 + backends/cadence/aot/pass_utils.py | 16 ++++++++-------- backends/cadence/aot/replace_ops.py | 6 +++--- backends/transforms/targets.bzl | 14 ++++++++++++++ 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/backends/arm/_passes/TARGETS b/backends/arm/_passes/TARGETS index 6ca59cfee27..843d6b159dc 100644 --- a/backends/arm/_passes/TARGETS +++ b/backends/arm/_passes/TARGETS @@ -7,6 +7,7 @@ python_library( deps = [ "//executorch/backends/arm:tosa_quant_utils", "//executorch/backends/arm:tosa_utils", + "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/backends/xnnpack/_passes:xnnpack_passes", "//executorch/exir:lib", ], diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 78a78bbda30..2dd3c4dc49d 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -256,6 +256,7 @@ python_library( "//executorch/backends/cadence/aot:pass_utils", "//executorch/backends/cadence/aot:remove_ops", "//executorch/backends/cadence/aot:utils", + "//executorch/backends/transforms:replace_scalar_with_tensor", "//executorch/exir:pass_base", "//executorch/exir/dialects:lib", "//executorch/exir/dialects/edge:lib", diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index d0166061c7f..3d73e7f8c1e 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Callable, List, Optional, Set, Union +from typing import Callable, List, Optional, Set, Type, Union import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet @@ -32,33 +32,33 @@ class CadencePassAttribute: # A dictionary that maps an ExportPass to its attributes. -ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} +ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {} -def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute: +def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute: return ALL_CADENCE_PASSES[p] # A decorator that registers a pass. def register_cadence_pass( pass_attribute: CadencePassAttribute, -) -> Callable[[ExportPass], ExportPass]: - def wrapper(cls: ExportPass) -> ExportPass: +) -> Callable[[Type[ExportPass]], Type[ExportPass]]: + def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]: ALL_CADENCE_PASSES[cls] = pass_attribute return cls return wrapper -def get_all_available_cadence_passes() -> Set[ExportPass]: +def get_all_available_cadence_passes() -> Set[Type[ExportPass]]: return set(ALL_CADENCE_PASSES.keys()) # Create a new filter to filter out relevant passes from all passes. def create_cadence_pass_filter( opt_level: int, debug: bool = False -) -> Callable[[ExportPass], bool]: - def _filter(p: ExportPass) -> bool: +) -> Callable[[Type[ExportPass]], bool]: + def _filter(p: Type[ExportPass]) -> bool: pass_attribute = get_cadence_pass_attribute(p) return ( pass_attribute.opt_level is not None diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 120f69008c1..f91fb26ddc8 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1719,9 +1719,9 @@ def call_operator(self, op, args, kwargs, meta): ) -@register_cadence_pass(CadencePassAttribute(opt_level=0))( - ReplaceScalarWithTensorArgPass() -) +register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) + + @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceScalarTensorWithFullPass(ExportPass): """ diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index c532798546d..ec4e1412862 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -201,6 +201,20 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "replace_scalar_with_tensor", + srcs = [ + "replace_scalar_with_tensor.py", + ], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], + ) + runtime.python_test( name = "test_duplicate_dynamic_quant_chain", srcs = [