Skip to content

Commit

Permalink
Update buck deps for new replace_scalar_with_tensor transforms, test;…
Browse files Browse the repository at this point in the history
… fix type annotations (pytorch#8588)

Summary:
Pull Request resolved: pytorch#8588

pytorch#8519 (D69765463) adds this new module. Update the buck deps to reflect the new dependencies.

Also fix some type annotations; the Meta-internal type checker seems to be more strict.

Reviewed By: swolchok

Differential Revision: D69883148
  • Loading branch information
dbort authored and facebook-github-bot committed Feb 21, 2025
1 parent 735f16e commit 2e77598
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 11 deletions.
1 change: 1 addition & 0 deletions backends/arm/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python_library(
deps = [
"//executorch/backends/arm:tosa_quant_utils",
"//executorch/backends/arm:tosa_utils",
"//executorch/backends/transforms:replace_scalar_with_tensor",
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
"//executorch/exir:lib",
],
Expand Down
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ python_library(
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:remove_ops",
"//executorch/backends/cadence/aot:utils",
"//executorch/backends/transforms:replace_scalar_with_tensor",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//executorch/exir/dialects/edge:lib",
Expand Down
16 changes: 8 additions & 8 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from dataclasses import dataclass
from typing import Callable, List, Optional, Set, Union
from typing import Callable, List, Optional, Set, Type, Union

import torch
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
Expand All @@ -32,33 +32,33 @@ class CadencePassAttribute:


# A dictionary that maps an ExportPass to its attributes.
ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {}
ALL_CADENCE_PASSES: dict[Type[ExportPass], CadencePassAttribute] = {}


def get_cadence_pass_attribute(p: ExportPass) -> CadencePassAttribute:
def get_cadence_pass_attribute(p: Type[ExportPass]) -> CadencePassAttribute:
return ALL_CADENCE_PASSES[p]


# A decorator that registers a pass.
def register_cadence_pass(
pass_attribute: CadencePassAttribute,
) -> Callable[[ExportPass], ExportPass]:
def wrapper(cls: ExportPass) -> ExportPass:
) -> Callable[[Type[ExportPass]], Type[ExportPass]]:
def wrapper(cls: Type[ExportPass]) -> Type[ExportPass]:
ALL_CADENCE_PASSES[cls] = pass_attribute
return cls

return wrapper


def get_all_available_cadence_passes() -> Set[ExportPass]:
def get_all_available_cadence_passes() -> Set[Type[ExportPass]]:
return set(ALL_CADENCE_PASSES.keys())


# Create a new filter to filter out relevant passes from all passes.
def create_cadence_pass_filter(
opt_level: int, debug: bool = False
) -> Callable[[ExportPass], bool]:
def _filter(p: ExportPass) -> bool:
) -> Callable[[Type[ExportPass]], bool]:
def _filter(p: Type[ExportPass]) -> bool:
pass_attribute = get_cadence_pass_attribute(p)
return (
pass_attribute.opt_level is not None
Expand Down
6 changes: 3 additions & 3 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,9 +1719,9 @@ def call_operator(self, op, args, kwargs, meta):
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))(
ReplaceScalarWithTensorArgPass()
)
register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
"""
Expand Down
14 changes: 14 additions & 0 deletions backends/transforms/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,20 @@ def define_common_targets():
],
)

runtime.python_library(
name = "replace_scalar_with_tensor",
srcs = [
"replace_scalar_with_tensor.py",
],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
],
)

runtime.python_test(
name = "test_duplicate_dynamic_quant_chain",
srcs = [
Expand Down

0 comments on commit 2e77598

Please sign in to comment.