Move ReplaceScalarWithTensorArgPass to transforms
The pass is general and can be used by multiple backends. Use it in
the Arm backend and make the small adjustments needed for it to work
there.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
Change-Id: I61863d9cefb1753c604d67a6e44845af46ef7c60
Erik-Lundell committed Feb 17, 2025
1 parent 5e4d6b6 commit 8fcd2b2
Showing 7 changed files with 123 additions and 65 deletions.
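
For orientation, a minimal sketch of what the move means for a consumer, assuming only the module path introduced by this diff (the trailing print is illustrative, not part of the commit): the pass previously lived in backends/cadence/aot/replace_ops.py and is now imported from backends/transforms, so any backend pass manager can register it the same way the Arm pipelines below do.

# Sketch only: the new shared import path introduced by this commit.
# Before the move, the class was defined in backends/cadence/aot/replace_ops.py.
from executorch.backends.transforms.replace_scalar_with_tensor import (
    ReplaceScalarWithTensorArgPass,
)

# The pass is an ordinary ExportPass subclass, so it can be handed to any
# pass manager, e.g. add_pass(ReplaceScalarWithTensorArgPass()) as done in
# arm_pass_manager.py below.
print(ReplaceScalarWithTensorArgPass())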
7 changes: 6 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -77,6 +77,9 @@
)
from executorch.backends.arm.tosa_specification import TosaSpecification

from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.pass_manager import PassManager
@@ -102,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(ConvertFullLikeToFullPass())

self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -125,7 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
@@ -176,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ScalarsToAttributePass())
self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
4 changes: 4 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -111,6 +111,10 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.le.Tensor,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.add.Scalar,
exir_ops.edge.aten.sub.Scalar,
exir_ops.edge.aten.mul.Scalar,
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.sigmoid.default,
3 changes: 1 addition & 2 deletions backends/arm/test/models/test_conformer.py
@@ -32,15 +32,14 @@ class TestConformer(unittest.TestCase):
ops_after_partitioner = {
"executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
"executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
"executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
"executorch_exir_dialects_edge__ops_aten_any_dim": 2,
"torch.ops.aten._assert_scalar.default": 10,
"torch.ops.aten._local_scalar_dense.default": 1,
"torch.ops.aten.scalar_tensor.default": 2,
"torch.ops.higher_order.executorch_call_delegate": 5,
"torch.ops.higher_order.executorch_call_delegate": 4,
}

dim = 16
29 changes: 27 additions & 2 deletions backends/arm/test/ops/test_scalars.py
@@ -1,3 +1,8 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
@@ -50,6 +55,22 @@ class Mul(torch.nn.Module):
def forward(self, x, y):
return x * y

class MulScalar(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.mul.Scalar(x, y)

class DivScalar(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.div.Scalar(x, y)

class AddScalar(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.add.Scalar(x, y)

class SubScalar(torch.nn.Module):
def forward(self, x, y):
return torch.ops.aten.sub.Scalar(x, y)

class AddInplace(torch.nn.Module):
def forward(self, x, y):
x += y
@@ -91,6 +112,10 @@ def forward(self, x):
("Sub_", SubInplace()),
("Mul_", MulInplace()),
("Div_", DivInplace()),
("MulScalar", MulScalar()),
("DivScalar", DivScalar()),
("AddScalar", AddScalar()),
("SubScalar", SubScalar()),
]

const_ops = [("Add", AddConst())]
@@ -108,8 +133,8 @@ def forward(self, x):
scalar = dtype[1]
tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))

# Don't add (scalar, tensor) test case for inplace ops.
if op[0][-1] == "_":
# Don't add (scalar, tensor) test case for inplace and .Scalar ops.
if op[0][-1] == "_" or op[0][-6:] == "Scalar":
continue

# sub(scalar, tensor) does not work in any case.
4 changes: 3 additions & 1 deletion backends/arm/tosa_mapping.py
@@ -11,6 +11,8 @@
# the standardised TOSA representation.
#

from typing import Sequence

import serializer.tosa_serializer as ts # type: ignore
import torch

@@ -95,7 +97,7 @@ def __init__(self, argument) -> None:
if isinstance(argument, torch.fx.node.Node):
self.__process_node(argument)
return
if isinstance(argument, list):
if isinstance(argument, Sequence):
self.__process_list(argument)
return
if isinstance(argument, int):
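
Presumably one of the "small adjustments" from the commit message: the widened isinstance check lets tuple arguments (for example a (1,) shape like the one the pass's full() call uses) take the same path as lists, since a tuple is a Sequence but not a list. A tiny illustration, as a sketch rather than part of the diff:

from typing import Sequence

shape = (1,)                        # tuple argument, like the full() shape above
print(isinstance(shape, list))      # False -- the old check would have fallen through
print(isinstance(shape, Sequence))  # True  -- the new check handles it as list-like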
66 changes: 7 additions & 59 deletions backends/cadence/aot/replace_ops.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -37,6 +38,9 @@
)
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
@@ -1713,65 +1717,9 @@ def call_operator(self, op, args, kwargs, meta):
)


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarWithTensorArgPass(ExportPass):
"""
For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar,
replace the scalar arg with Tensor arg.
"""

scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
}

def get_replacement(self, op, args, kwargs, meta):
return super().call_operator(
# Replace with .Tensor variant.
op=self.scalar_to_tensor_ops[op],
args=(
# Tensor arg.
args[0],
# Scalar arg - replace with aten.full tensor.
super().call_operator(
exir_ops.edge.aten.full.default,
args=(
(1,),
args[1],
),
kwargs={"dtype": args[0].to_tensor().dtype},
meta=meta,
),
# Other args.
*args[2:],
),
kwargs=kwargs,
meta=meta,
)

def call_operator(self, op, args, kwargs, meta):
if op not in self.scalar_to_tensor_ops:
return super().call_operator(op, args, kwargs, meta)

# There must be exactly 2 args (3 for add and sub containing alpha)
assert len(args) == 2 or len(args) == 3

# If there are two args, just replace the op.
if len(args) == 2:
return self.get_replacement(op, args, kwargs, meta)

# In case the op has three args, it must be scalar add/sub op.
if (
op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
or "alpha" in kwargs
):
return super().call_operator(op, args, kwargs, meta)

return self.get_replacement(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=0))(
ReplaceScalarWithTensorArgPass()
)
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceScalarTensorWithFullPass(ExportPass):
"""
75 changes: 75 additions & 0 deletions backends/transforms/replace_scalar_with_tensor.py
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass


class ReplaceScalarWithTensorArgPass(ExportPass):
"""
For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar,
replace the scalar arg with Tensor arg.
"""

scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
}

def get_replacement(self, op, args, kwargs, meta):
return super().call_operator(
# Replace with .Tensor variant.
op=self.scalar_to_tensor_ops[op],
args=(
# Tensor arg.
args[0],
# Scalar arg - replace with aten.full tensor.
super().call_operator(
exir_ops.edge.aten.full.default,
args=(
(1,),
args[1],
),
kwargs={"dtype": args[0].to_tensor().dtype},
meta=meta,
),
# Other args.
*args[2:],
),
kwargs=kwargs,
meta=meta,
)

def call_operator(self, op, args, kwargs, meta):
if op not in self.scalar_to_tensor_ops:
return super().call_operator(op, args, kwargs, meta)

# There must be exactly 2 args (3 for add and sub containing alpha)
assert len(args) == 2 or len(args) == 3

# If there are two args, just replace the op.
if len(args) == 2:
return self.get_replacement(op, args, kwargs, meta)

# In case the op has three args, it must be scalar add/sub op.
if (
op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
or "alpha" in kwargs
):
return super().call_operator(op, args, kwargs, meta)

return self.get_replacement(op, args, kwargs, meta)
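
To make the new pass's behavior concrete, a minimal usage sketch follows. It assumes only the public names shown above plus the standard torch.export entry point; the AddTwo module, the input shape, and the final print are illustrative and not part of the commit, and depending on export/decomposition settings the .Scalar overload may already be canonicalized before the pass runs.

import torch

from executorch.backends.transforms.replace_scalar_with_tensor import (
    ReplaceScalarWithTensorArgPass,
)


class AddTwo(torch.nn.Module):
    def forward(self, x):
        # Explicit .Scalar overload, mirroring the AddScalar test module above.
        return torch.ops.aten.add.Scalar(x, 2.0)


exported = torch.export.export(AddTwo(), (torch.randn(4),))
# ExportPass subclasses are callable on a GraphModule and return a PassResult.
result = ReplaceScalarWithTensorArgPass()(exported.graph_module)
# The rewritten graph calls the .Tensor overload, with the scalar argument
# replaced by a rank-1 tensor produced by a full() call (see get_replacement).
print(result.graph_module.graph)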
