diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 5332dd9184..3097d64197 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -52,7 +52,6 @@ from executorch.backends.arm._passes.fuse_quantized_activation_pass import (  # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )
-from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -76,7 +75,6 @@
     UnsqueezeScalarPlaceholdersPass,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -121,7 +119,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
-        self.add_pass(InsertRescalePass())
+
         return self._transform(exported_program.graph_module)
 
     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -159,7 +157,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(AnnotateChannelsLastDimOrder())
-        self.add_pass(InsertRescalePass())
 
         return self._transform(exported_program.graph_module)
 
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index 7a965539f8..da2bd28a17 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -131,9 +131,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.op != "call_function":
                 continue
-            # Don't fold chains of quant-ops into each other.
-            if n.target in (q_op, dq_op):
-                continue
 
             # Make sure we haven't already set qparams meta information on the node
             assert "input_qparams" not in n.meta.keys()
diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py
deleted file mode 100644
index e9f6eec63a..0000000000
--- a/backends/arm/_passes/insert_rescales_pass.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-from copy import copy
-from typing import cast
-
-import torch
-from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
-from executorch.exir.pass_base import ExportPass, PassResult
-from torch import Tensor
-from torch.fx import GraphModule, Node
-from torch.library import custom_op, register_fake
-
-logger = logging.getLogger(__name__)
-
-
-@custom_op("tosa::_rescale", mutates_args=())  # type: ignore[misc]
-def rescale(
-    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
-) -> Tensor:
-    logger.warning(
-        "Ran default implementation of tosa::_rescale."
-        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
-    )
-    # Clone is needed to not return reference when rescaling to same dtype.
-    # This is a neccessary requirement for non-mutating custom ops.
-    return x.to(dtype=dtype).clone()
-
-
-@register_fake("tosa::_rescale")  # type: ignore[misc]
-def rescale_fake(
-    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
-) -> Tensor:
-    """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
-    Additionally validates TOSA constraints of a RESCALE op.
-    """
-    if not (dtype == torch.int32 or dtype == torch.int8):
-        raise NotImplementedError(
-            "tosa::rescale currently only supports int32 and int8."
-        )
-    if dtype == torch.int32 and out_zp != 0:
-        raise ValueError(
-            "TOSA requires output_zp to be zero when the output dtype is int32."
-        )
-    if x.dtype == torch.int32 and in_zp != 0:
-        raise ValueError(
-            "TOSA requires input_zp to be zero when the input dtype is int32."
-        )
-    if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
-        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
-    if dtype == torch.int8 and not -128 <= out_zp <= 127:
-        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
-
-    return x.to(dtype=dtype).clone()
-
-
-class InsertRescalePass(ExportPass):
-    """Finds patterns of dq -> q, and replaces them
-    with passthrough_to_tosa::rescales.
-
-    Does not garantuee that the dtypes and zero points are valid
-    in TOSA, that is the job of the quantization annotator that
-    produced the dq and q nodes. The TOSA constraints are validated
-    in the fake implementation of passthrough_to_tosa:rescale.
-    """
-
-    def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
-        dq_args = QuantArgs.from_operator(node.target, node.args)
-        q_args = QuantArgs.from_operator(user.target, user.args)
-        new_scale = dq_args.scale / q_args.scale
-
-        with graph_module.graph.inserting_before(node):
-            rescale_node = create_node(
-                graph_module.graph,
-                torch.ops.tosa._rescale.default,
-                (
-                    node.all_input_nodes[0],
-                    q_args.dtype,
-                    new_scale,
-                    dq_args.zp,
-                    q_args.zp,
-                ),
-            )
-            rescale_node.meta = copy(user.meta)
-            user.replace_all_uses_with(rescale_node)
-            graph_module.graph.erase_node(user)
-
-    def call(self, graph_module: GraphModule) -> PassResult:
-        modified = False
-        for node in graph_module.graph.nodes:
-            node = cast(Node, node)
-
-            if node.target is not dq_op:
-                continue
-            # Copy users since we remove them while iterating, modyfing the node.users list.
-            for user in copy(node.users):
-                if user.target is q_op:
-                    self.fold_dq_q_to_rescale(node, user, graph_module)
-                    modified = True
-            if len(node.users) == 0:
-                graph_module.graph.erase_node(node)
-
-        graph_module = super().call(graph_module).graph_module
-        graph_module.recompile()
-        return PassResult(graph_module, modified)
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index 73395b247e..f57ba092bc 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -31,7 +31,6 @@
     op_reciprocal,
     op_relu,
     op_repeat,
-    op_rescale,
     op_rshift,
     op_rsqrt,
     op_sigmoid,
diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_rescale.py
deleted file mode 100644
index e0589de4b3..0000000000
--- a/backends/arm/operators/op_rescale.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-
-from typing import cast, List
-
-import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
-import serializer.tosa_serializer as ts  # type: ignore
-import torch
-
-import tosa.Op as TosaOp  # type: ignore
-from executorch.backends.arm.operators.node_visitor import (
-    NodeVisitor,
-    register_node_visitor,
-)
-from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
-from torch.fx import Node
-
-
-@register_node_visitor
-class RescaleVisitor(NodeVisitor):
-    target = "_rescale.default"
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: ts.TosaSerializer,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-
-        input_dtype = inputs[0].dtype
-        output_dtype = cast(torch.dtype, node.args[1])
-        scale = cast(float, node.args[2])
-        input_zp = cast(int, node.args[3])
-        output_zp = cast(int, node.args[4])
-
-        # Skip int16 cases for now.
-        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
-            raise ValueError(
-                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
-            )
-        if output_dtype != torch.int8 and output_zp != 0:
-            raise ValueError(
-                f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
-            )
-
-        scale_width = 32 if output_dtype == torch.int32 else 16
-        multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
-            scale, scale_width
-        )
-        attr_rescale = ts.TosaSerializerAttribute()
-        attr_rescale.RescaleAttribute(
-            input_zp=input_zp,
-            output_zp=output_zp,
-            multiplier=[multiplier],
-            shift=[shift],
-            scale32=output_dtype == torch.int32,
-            double_round=False,
-            per_channel=False,
-            input_unsigned=False,
-            output_unsigned=False,
-        )
-
-        tosa_graph.addOperator(
-            TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
-        )
diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py
index d87a109dae..b4b43f88c7 100644
--- a/backends/arm/test/ops/test_add.py
+++ b/backends/arm/test/ops/test_add.py
@@ -9,8 +9,6 @@
 from typing import Tuple
 
 import torch
-from executorch.backends.arm.arm_backend import get_tosa_version
-from executorch.backends.arm.quantizer import arm_quantizer
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineBI,
@@ -18,10+16,6 @@
     TosaPipelineBI,
     TosaPipelineMI,
 )
-from executorch.backends.xnnpack.test.tester import Quantize
-from torch.ao.quantization.observer import HistogramObserver
-from torch.ao.quantization.quantizer import QuantizationSpec
-
 aten_op = "torch.ops.aten.add.Tensor"
 exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor"
 
@@ -73,38 +67,6 @@ def test_add_tosa_BI(test_data: input_t1):
     pipeline.run()
 
 
-@common.parametrize("test_data", Add.test_data)
-def test_add_i32_tosa_BI(test_data: input_t1):
-    pipeline = TosaPipelineBI[input_t1](Add(), test_data, aten_op, exir_op)
-
-    # Create a quantizer with int8 quantization on the input and output but int32 on everything else.
-    quantizer = arm_quantizer.ArmQuantizer(
-        get_tosa_version(common.get_tosa_compile_spec("TOSA-0.80+BI"))
-    )
-    quantizer.set_io(arm_quantizer.get_symmetric_quantization_config())
-    observer_options = {"eps": 2**-16}
-    observer = HistogramObserver.with_args(**observer_options)
-    input_act_qspec = QuantizationSpec(
-        torch.int32,
-        observer,
-        qscheme=torch.per_tensor_symmetric,
-        quant_max=2**31 - 1,
-        quant_min=-(2**31),
-    )
-    # This quantization_config will be set as global config.
-    quantization_config = arm_quantizer.QuantizationConfig(
-        input_act_qspec, None, None, None
-    )
-    quantize_stage = Quantize(quantizer, quantization_config)
-    pipeline.change_args("quantize", quantize_stage)
-
-    # Check that we get the additional (dq -> q
-    pipeline.add_stage_after(
-        "export", pipeline.tester.check_count, {"torch.ops.quantized_decomposed": 8}
-    )
-    pipeline.run()
-
-
 @common.parametrize("test_data", Add.test_data)
 def test_add_u55_BI(test_data: input_t1):
     pipeline = EthosU55PipelineBI[input_t1](
diff --git a/backends/arm/test/passes/test_rescale_pass.py b/backends/arm/test/passes/test_rescale_pass.py
deleted file mode 100644
index 25052c448d..0000000000
--- a/backends/arm/test/passes/test_rescale_pass.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import unittest
-
-import pytest
-
-import torch
-import torch.library
-from executorch.backends.arm.test import common, conftest
-from executorch.backends.arm.test.tester.arm_tester import ArmTester
-from parameterized import parameterized
-from torch.testing._internal import optests
-
-
-def test_rescale_op():
-    sample_inputs = [
-        # (data, out_dtype, scale, in_zp, out_zp)
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8),
-            torch.int32,
-            0.2,
-            2,
-            0,
-        ),
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32),
-            torch.int8,
-            0.2,
-            0,
-            -128,
-        ),
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8),
-            torch.int8,
-            0.8,
-            10,
-            127,
-        ),
-    ]
-    for sample_input in sample_inputs[1:2]:
-        torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
-
-
-def test_nonzero_zp_for_int32():
-
-    sample_inputs = [
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8),
-            torch.int32,
-            0.2,
-            2,  # Should be 0, expect error
-            1,
-        ),
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32),
-            torch.int8,
-            0.2,
-            1,
-            1,  # Should be 0, expect error
-        ),
-    ]
-    for sample_input in sample_inputs:
-        with pytest.raises(optests.generate_tests.OpCheckError):
-            torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
-
-
-def test_zp_outside_range():
-
-    sample_inputs = [
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int8),
-            torch.int32,
-            0.2,
-            128,  # Should be <128, expect error
-            0,
-        ),
-        (
-            torch.randint(low=0, high=100, size=(4, 4, 4), dtype=torch.int32),
-            torch.int8,
-            0.2,
-            0,
-            -129,  # Should be >-129m expect error
-        ),
-    ]
-    for sample_input in sample_inputs:
-        with pytest.raises(optests.generate_tests.OpCheckError):
-            torch.library.opcheck(torch.ops.tosa._rescale, sample_input)
-
-
-class RescaleNetwork(torch.nn.Module):
-    test_parameters = [
-        (torch.rand(5), torch.rand(5)),
-        (torch.randn(5, 2), torch.randn(5, 1)),
-        (torch.ones(1, 10, 4, 6), torch.ones(1, 10, 4, 6)),
-        (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
-        (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)),
-    ]
-
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
-        a = y.exp()
-        g = (a + 5).log()
-        c = a + x
-        d = c - g
-        e = c * d
-        f = e.sigmoid()
-
-        return f
-
-
-def _test_rescale_pipeline(
-    module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
-):
-    """Tests a model with many ops that requires rescales. As more ops are quantized to int32 and
-    need the InsertRescalesPass, make sure that they play nicely together."""
-    (
-        ArmTester(
-            module,
-            example_inputs=test_data,
-            compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"),
-        )
-        .quantize()
-        .export()
-        .to_edge_transform_and_lower()
-        .to_executorch()
-        .run_method_and_compare_outputs(test_data)
-    )
-
-
-def _test_rescale_pipeline_ethosu(
-    module: torch.nn.Module, compile_spec, test_data: tuple[torch.Tensor, torch.Tensor]
-):
-    tester = (
-        ArmTester(
-            module,
-            example_inputs=test_data,
-            compile_spec=compile_spec,
-        )
-        .quantize()
-        .export()
-        .to_edge_transform_and_lower()
-        .to_executorch()
-        .serialize()
-    )
-    if conftest.is_option_enabled("corstone_fvp"):
-        tester.run_method_and_compare_outputs(inputs=test_data)
-
-
-class TestRescales(unittest.TestCase):
-
-    @parameterized.expand(RescaleNetwork.test_parameters)
-    def test_quantized_rescale(self, x, y):
-        _test_rescale_pipeline(RescaleNetwork(), (x, y))
-
-    @parameterized.expand(RescaleNetwork.test_parameters)
-    @pytest.mark.corstone_fvp
-    def test_quantized_rescale_U55(self, x, y):
-        _test_rescale_pipeline_ethosu(
-            RescaleNetwork(), common.get_u55_compile_spec(), (x, y)
-        )
-
-    @parameterized.expand(RescaleNetwork.test_parameters)
-    @pytest.mark.corstone_fvp
-    def test_quantized_rescale_U85(self, x, y):
-        _test_rescale_pipeline_ethosu(
-            RescaleNetwork(), common.get_u85_compile_spec(), (x, y)
-        )