Skip to content

Commit

Permalink
Arm: Add FLOOR operator (#8563)
Browse files Browse the repository at this point in the history
Implement an unary operator factory for creating one input NodeVisitors.

Change-Id: I59ba0407b763e9e0cb79f214b7679465eda94825
  • Loading branch information
YufengShi-dudu authored Feb 19, 2025
1 parent 43efc37 commit 4f90ce4
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 0 deletions.
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class InsertTableOpsPass(ExportPass):

table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.floor.default: torch.floor,
exir_ops.edge.aten.log.default: torch.log,
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.linear.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.full_like.default,
exir_ops.edge.aten.ge.Tensor,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@
op_upsample_nearest2d,
op_view,
ops_binary,
ops_unary,
)
57 changes: 57 additions & 0 deletions backends/arm/operators/ops_unary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp


def unary_operator_factory(unary_target: str, tosa_op):
"Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op."

class UnaryOperator_080_MI(NodeVisitor):
target = unary_target

tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)

if not (inputs[0].dtype == ts.DType.FP32):
raise ValueError(
"All inputs need to be FP32." f"Got {inputs[0].dtype=}"
)

# MI lowering
tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name])

register_node_visitor(UnaryOperator_080_MI)


unary_operator_factory("aten.floor.default", TosaOp.Op().FLOOR)
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _match_pattern(
_one_to_one = [
torch.ops.aten.abs.default,
torch.ops.aten.exp.default,
torch.ops.aten.floor.default,
torch.ops.aten.log.default,
torch.ops.aten.reciprocal.default,
torch.ops.aten.rsqrt.default,
Expand Down
82 changes: 82 additions & 0 deletions backends/arm/test/ops/test_floor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)


aten_op = "torch.ops.aten.floor.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_floor_default"

input_t1 = Tuple[torch.Tensor] # Input x


class Floor(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.floor(x)

test_data: dict[str, input_t1] = {
"zeros": (torch.zeros(1, 10, 10, 10),),
"ones": (torch.ones(10, 10, 10),),
"rand": ((torch.rand(10, 10) - 0.5),),
"randn_pos": ((torch.randn(1, 4, 4, 4) + 10),),
"randn_neg": ((torch.randn(1, 4, 4, 4) - 10),),
"ramp": (torch.arange(-16, 16, 0.2),),
}


@common.parametrize("test_data", Floor.test_data)
def test_floor_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](Floor(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Floor.test_data)
def test_floor_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](Floor(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Floor.test_data)
def test_floor_u55_BI(test_data: input_t1):
pipeline = EthosU55PipelineBI[input_t1](
Floor(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()


@common.parametrize("test_data", Floor.test_data)
def test_floor_u85_BI(test_data: input_t1):
pipeline = EthosU85PipelineBI[input_t1](
Floor(), test_data, aten_op, exir_op, run_on_fvp=False
)
pipeline.run()


@common.parametrize("test_data", Floor.test_data)
@common.SkipIfNoCorstone300
def test_floor_u55_BI_on_fvp(test_data: input_t1):
pipeline = EthosU55PipelineBI[input_t1](
Floor(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


@common.parametrize("test_data", Floor.test_data)
@common.SkipIfNoCorstone320
def test_floor_u85_BI_on_fvp(test_data: input_t1):
pipeline = EthosU85PipelineBI[input_t1](
Floor(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()

0 comments on commit 4f90ce4

Please sign in to comment.