diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 5332dd9184..28d70591e5 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -77,6 +77,9 @@
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -102,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(ConvertFullLikeToFullPass())
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -125,7 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -176,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(ScalarsToAttributePass())
+        self.add_pass(ReplaceScalarWithTensorArgPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 67c5b660ec..1fa626efce 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -113,6 +113,10 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.le.Tensor,
             exir_ops.edge.aten.lt.Tensor,
             exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.add.Scalar,
+            exir_ops.edge.aten.sub.Scalar,
+            exir_ops.edge.aten.mul.Scalar,
+            exir_ops.edge.aten.div.Scalar,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.sigmoid.default,
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
index 574ae1a5d4..d9bc4e363c 100644
--- a/backends/arm/test/models/test_conformer.py
+++ b/backends/arm/test/models/test_conformer.py
@@ -32,7 +32,6 @@ class TestConformer(unittest.TestCase):
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
         "executorch_exir_dialects_edge__ops_aten_max_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
         "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
         "executorch_exir_dialects_edge__ops_aten_where_self": 4,
         "executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
@@ -40,7 +39,7 @@ class TestConformer(unittest.TestCase):
         "torch.ops.aten._assert_scalar.default": 10,
         "torch.ops.aten._local_scalar_dense.default": 1,
         "torch.ops.aten.scalar_tensor.default": 2,
-        "torch.ops.higher_order.executorch_call_delegate": 5,
+        "torch.ops.higher_order.executorch_call_delegate": 4,
     }
 
     dim = 16
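For context on the pipeline changes above and the tests that follow: the shared ReplaceScalarWithTensorArgPass rewrites a .Scalar binary op into its .Tensor variant, materializing the scalar operand as a rank-1 full tensor with the tensor operand's dtype. A rough sketch of the node-level effect (illustrative only, not part of this patch; it uses the eager aten overloads, which the pass also maps):

    import torch

    x = torch.randn(4)

    # Before the pass: the scalar 2.0 is an argument of the op itself.
    y = torch.ops.aten.mul.Scalar(x, 2.0)

    # After the pass: the scalar is materialized as a rank-1 full tensor with
    # the tensor operand's dtype, and the .Tensor overload is used instead.
    scalar_tensor = torch.ops.aten.full((1,), 2.0, dtype=x.dtype)
    y = torch.ops.aten.mul.Tensor(x, scalar_tensor)
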
diff --git a/backends/arm/test/ops/test_scalars.py b/backends/arm/test/ops/test_scalars.py
index bcf294de4a..58b8eb83a6 100644
--- a/backends/arm/test/ops/test_scalars.py
+++ b/backends/arm/test/ops/test_scalars.py
@@ -1,3 +1,8 @@
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 import unittest
 
 import torch
@@ -50,6 +55,22 @@ class Mul(torch.nn.Module):
         def forward(self, x, y):
             return x * y
 
+    class MulScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.mul.Scalar(x, y)
+
+    class DivScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.div.Scalar(x, y)
+
+    class AddScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.add.Scalar(x, y)
+
+    class SubScalar(torch.nn.Module):
+        def forward(self, x, y):
+            return torch.ops.aten.sub.Scalar(x, y)
+
     class AddInplace(torch.nn.Module):
         def forward(self, x, y):
             x += y
@@ -91,6 +112,10 @@ def forward(self, x):
         ("Sub_", SubInplace()),
         ("Mul_", MulInplace()),
         ("Div_", DivInplace()),
+        ("MulScalar", MulScalar()),
+        ("DivScalar", DivScalar()),
+        ("AddScalar", AddScalar()),
+        ("SubScalar", SubScalar()),
     ]
 
     const_ops = [("Add", AddConst())]
@@ -108,8 +133,8 @@ def forward(self, x):
                 scalar = dtype[1]
                 tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
 
-                # Don't add (scalar, tensor) test case for inplace ops.
-                if op[0][-1] == "_":
+                # Don't add (scalar, tensor) test case for inplace and .Scalar ops.
+                if op[0][-1] == "_" or op[0][-6:] == "Scalar":
                     continue
 
                 # sub(scalar, tensor) does not work in any case.
diff --git a/backends/arm/tosa_mapping.py b/backends/arm/tosa_mapping.py
index 13eb53dfa8..d1849a7f47 100644
--- a/backends/arm/tosa_mapping.py
+++ b/backends/arm/tosa_mapping.py
@@ -11,6 +11,8 @@
 # the standardised TOSA representation.
 #
 
+from typing import Sequence
+
 import serializer.tosa_serializer as ts  # type: ignore
 import torch
 
@@ -99,7 +101,7 @@ def __init__(self, argument) -> None:
         if isinstance(argument, torch.fx.Node):
             self.__process_node(argument)
             return
-        if isinstance(argument, list):
+        if isinstance(argument, Sequence):
             self.__process_list(argument)
             return
         if isinstance(argument, (int, float)):
diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 633733ead3..487d374fb8 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -37,6 +38,9 @@
 )
 from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
+from executorch.backends.transforms.replace_scalar_with_tensor import (
+    ReplaceScalarWithTensorArgPass,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
@@ -1713,65 +1717,9 @@ def call_operator(self, op, args, kwargs, meta):
         )


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceScalarWithTensorArgPass(ExportPass):
-    """
-    For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar,
-    replace the scalar arg with Tensor arg.
-    """
-
-    scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
-        exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
-        exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
-        exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
-        exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
-    }
-
-    def get_replacement(self, op, args, kwargs, meta):
-        return super().call_operator(
-            # Replace with .Tensor variant.
-            op=self.scalar_to_tensor_ops[op],
-            args=(
-                # Tensor arg.
-                args[0],
-                # Scalar arg - replace with aten.full tensor.
-                super().call_operator(
-                    exir_ops.edge.aten.full.default,
-                    args=(
-                        (1,),
-                        args[1],
-                    ),
-                    kwargs={"dtype": args[0].to_tensor().dtype},
-                    meta=meta,
-                ),
-                # Other args.
-                *args[2:],
-            ),
-            kwargs=kwargs,
-            meta=meta,
-        )
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in self.scalar_to_tensor_ops:
-            return super().call_operator(op, args, kwargs, meta)
-
-        # There must be exactly 2 args (3 for add and sub containing alpha)
-        assert len(args) == 2 or len(args) == 3
-
-        # If there are two args, just replace the op.
-        if len(args) == 2:
-            return self.get_replacement(op, args, kwargs, meta)
-
-        # In case the op has three args, it must be scalar add/sub op.
-        if (
-            op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
-            or "alpha" in kwargs
-        ):
-            return super().call_operator(op, args, kwargs, meta)
-
-        return self.get_replacement(op, args, kwargs, meta)
-
-
+@register_cadence_pass(CadencePassAttribute(opt_level=0))(
+    ReplaceScalarWithTensorArgPass()
+)
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceScalarTensorWithFullPass(ExportPass):
     """
diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py
new file mode 100644
index 0000000000..9e2383654d
--- /dev/null
+++ b/backends/transforms/replace_scalar_with_tensor.py
@@ -0,0 +1,75 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass
+
+
+class ReplaceScalarWithTensorArgPass(ExportPass):
+    """
+    For binary ops like add.Scalar, sub.Scalar, mul.Scalar, and div.Scalar,
+    replace the scalar arg with Tensor arg.
+ """ + + scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { + exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, + torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, + torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, + torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor, + } + + def get_replacement(self, op, args, kwargs, meta): + return super().call_operator( + # Replace with .Tensor variant. + op=self.scalar_to_tensor_ops[op], + args=( + # Tensor arg. + args[0], + # Scalar arg - replace with aten.full tensor. + super().call_operator( + exir_ops.edge.aten.full.default, + args=( + (1,), + args[1], + ), + kwargs={"dtype": args[0].to_tensor().dtype}, + meta=meta, + ), + # Other args. + *args[2:], + ), + kwargs=kwargs, + meta=meta, + ) + + def call_operator(self, op, args, kwargs, meta): + if op not in self.scalar_to_tensor_ops: + return super().call_operator(op, args, kwargs, meta) + + # There must be exactly 2 args (3 for add and sub containing alpha) + assert len(args) == 2 or len(args) == 3 + + # If there are two args, just replace the op. + if len(args) == 2: + return self.get_replacement(op, args, kwargs, meta) + + # In case the op has three args, it must be scalar add/sub op. + if ( + op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar} + or "alpha" in kwargs + ): + return super().call_operator(op, args, kwargs, meta) + + return self.get_replacement(op, args, kwargs, meta)