From fbb0395110724717c42720582bb8804b752241e3 Mon Sep 17 00:00:00 2001 From: Dave Bort Date: Thu, 16 Jan 2025 17:59:38 -0800 Subject: [PATCH 1/7] Validate tensor sizes during method load Differential Revision: D68180029 Pull Request resolved: https://github.com/pytorch/executorch/pull/7663 --- runtime/executor/tensor_parser_portable.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/runtime/executor/tensor_parser_portable.cpp b/runtime/executor/tensor_parser_portable.cpp index 3f190060f7..79e4c4bd96 100644 --- a/runtime/executor/tensor_parser_portable.cpp +++ b/runtime/executor/tensor_parser_portable.cpp @@ -101,6 +101,19 @@ Result parseTensor( sizes = const_cast(serialized_sizes); dim_order = const_cast(serialized_dim_order); } + // Validate sizes before using them in case the PTE data is bad. We can't + // detect bad positive values, but we can reject negative values, which would + // otherwise panic in the TensorImpl ctor. dim_order_to_stride() will validate + // dim_order. + for (int i = 0; i < dim; i++) { + ET_CHECK_OR_RETURN_ERROR( + sizes[i] >= 0, + InvalidProgram, + "Negative size[%d] %" PRId32, + i, + sizes[i]); + } + // We will remove strides from schema. // Allocating strides buffer here and populating it. // In subsequent diffs we can remove strides accessor, however this From 1a6b7a6f14c75d87b21c4fc517b0d7c0fe17f761 Mon Sep 17 00:00:00 2001 From: JP <46308822+zonglinpeng@users.noreply.github.com> Date: Thu, 16 Jan 2025 22:08:59 -0800 Subject: [PATCH 2/7] refactor test targets Differential Revision: D68194772 Pull Request resolved: https://github.com/pytorch/executorch/pull/7673 --- examples/cadence/operators/TARGETS | 25 ++----------- examples/cadence/operators/targets.bzl | 36 +++++++++++++++++++ ...nv1d_op.py => test_quantized_conv1d_op.py} | 4 ++- ...near_op.py => test_quantized_linear_op.py} | 0 4 files changed, 41 insertions(+), 24 deletions(-) create mode 100644 examples/cadence/operators/targets.bzl rename examples/cadence/operators/{quantized_conv1d_op.py => test_quantized_conv1d_op.py} (93%) rename examples/cadence/operators/{quantized_linear_op.py => test_quantized_linear_op.py} (100%) diff --git a/examples/cadence/operators/TARGETS b/examples/cadence/operators/TARGETS index 732f1ced09..67f2bab681 100644 --- a/examples/cadence/operators/TARGETS +++ b/examples/cadence/operators/TARGETS @@ -1,26 +1,5 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") +load("targets.bzl", "define_common_targets") oncall("odai_jarvis") - -python_unittest( - name = "test_add_op", - srcs = [ - "test_add_op.py", - ], - typing = True, - supports_static_listing = False, - deps = [ - "fbsource//third-party/pypi/parameterized:parameterized", - "//caffe2:torch", - "//executorch/backends/cadence/aot:ops_registrations", - "//executorch/backends/cadence/aot:export_example", - "//executorch/backends/cadence/aot:compiler", - ], -) +define_common_targets() diff --git a/examples/cadence/operators/targets.bzl b/examples/cadence/operators/targets.bzl new file mode 100644 index 0000000000..e1fbeb9fdf --- /dev/null +++ b/examples/cadence/operators/targets.bzl @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +TESTS_LIST = [ + "add_op", + "quantized_conv1d_op", + "quantized_linear_op", +] + +def define_common_targets(): + for op in TESTS_LIST: + _define_test_target(op) + + +def _define_test_target(test_name): + file_name = "test_{}".format(test_name) + python_unittest( + name = file_name, + srcs = [ + "{}.py".format(file_name), + ], + typing = True, + supports_static_listing = False, + deps = [ + "fbsource//third-party/pypi/parameterized:parameterized", + "fbcode//caffe2:torch", + "fbcode//executorch/backends/cadence/aot:ops_registrations", + "fbcode//executorch/backends/cadence/aot:export_example", + "fbcode//executorch/backends/cadence/aot:compiler", + ], + ) diff --git a/examples/cadence/operators/quantized_conv1d_op.py b/examples/cadence/operators/test_quantized_conv1d_op.py similarity index 93% rename from examples/cadence/operators/quantized_conv1d_op.py rename to examples/cadence/operators/test_quantized_conv1d_op.py index 3247cb690d..e2457077b2 100644 --- a/examples/cadence/operators/quantized_conv1d_op.py +++ b/examples/cadence/operators/test_quantized_conv1d_op.py @@ -8,6 +8,8 @@ import logging +from typing import cast, Sequence + import torch from executorch.backends.cadence.aot.ops_registrations import * # noqa @@ -53,6 +55,6 @@ def forward(self, x: torch.Tensor): model = QuantizedConv() model.eval() - example_inputs = (torch.randn(shape),) + example_inputs = (torch.randn(cast(Sequence[int], shape)),) export_model(model, example_inputs) diff --git a/examples/cadence/operators/quantized_linear_op.py b/examples/cadence/operators/test_quantized_linear_op.py similarity index 100% rename from examples/cadence/operators/quantized_linear_op.py rename to examples/cadence/operators/test_quantized_linear_op.py From dad73ca6240429e2f79d666547cd61c95c05c427 Mon Sep 17 00:00:00 2001 From: SaoirseARM <44364573+SaoirseARM@users.noreply.github.com> Date: Fri, 17 Jan 2025 08:57:38 +0000 Subject: [PATCH 3/7] Fix for multiple outputs in FVP tests (#7650) Fix for multiple outputs in corstone - Update to ensure all output nodes are consumed. - Update to ensure output quant scales are used. 
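
A minimal self-contained sketch of the resulting comparison flow (the module
and names here are illustrative stand-ins, not the tester code itself):

    import torch
    import torch.utils._pytree as pytree

    class TwoOutputs(torch.nn.Module):
        def forward(self, x):
            return x + 1, x * 2

    module = TwoOutputs()
    inputs = (torch.randn(2, 3),)
    # Flatten both result pytrees so each output tensor can be compared pairwise.
    reference_outputs, _ = pytree.tree_flatten(module(*inputs))
    test_outputs, _ = pytree.tree_flatten(module(*inputs))
    # One (possibly None) quantization scale per output node.
    quantization_scales = [None] * len(reference_outputs)
    for ref, test, scale in zip(reference_outputs, test_outputs, quantization_scales):
        # A real quantized run would widen the tolerance by the output scale.
        atol = 1e-5 if scale is None else scale
        torch.testing.assert_close(ref, test, atol=atol, rtol=1e-5)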
--- .../arm/test/misc/test_multiple_outputs.py | 47 ++++++++++- backends/arm/test/runner_utils.py | 79 ++++++++++--------- .../arm/test/tester/analyze_output_utils.py | 8 +- backends/arm/test/tester/arm_tester.py | 37 +++++---- 4 files changed, 114 insertions(+), 57 deletions(-) diff --git a/backends/arm/test/misc/test_multiple_outputs.py b/backends/arm/test/misc/test_multiple_outputs.py index 7762c7dc2f..ddddc94d27 100644 --- a/backends/arm/test/misc/test_multiple_outputs.py +++ b/backends/arm/test/misc/test_multiple_outputs.py @@ -6,9 +6,11 @@ import unittest +import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir.backend.compile_spec_schema import CompileSpec class TestMultipleOutputs(unittest.TestCase): @@ -51,3 +53,46 @@ def test_tosa_BI_pipeline(self): .to_executorch() .run_method_and_compare_outputs(inputs=inputs, qtol=1.0) ) + + def _test_ethosu_BI_pipeline( + self, + module: torch.nn.Module, + test_data: tuple[torch.Tensor], + compile_spec: CompileSpec, + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge_transform_and_lower() + .to_executorch() + .serialize() + ) + if conftest.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) + + @pytest.mark.corstone_fvp + def test_u85_BI(self): + module = self.MultipleOutputsModule() + test_data = module.get_inputs() + self._test_ethosu_BI_pipeline( + module, + test_data, + common.get_u85_compile_spec(), + ) + + @pytest.mark.corstone_fvp + @conftest.expectedFailureOnFVP + # TODO MLETORCH-598 + def test_u55_BI(self): + module = self.MultipleOutputsModule() + test_data = module.get_inputs() + self._test_ethosu_BI_pipeline( + module, + test_data, + common.get_u55_compile_spec(), + ) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index b206e5585b..3851e41b73 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -115,50 +115,53 @@ def _get_input_quantization_params( return quant_params -def _get_output_node(program: ExportedProgram) -> Node: +def _get_output_nodes(program: ExportedProgram) -> list[Node]: """ Get output node to this model. Args: - program (ExportedProgram): The program to get output node from. + program (ExportedProgram): The program to get the output nodes from. Returns: - The node that is the output of 'program'. + The nodes that are the outputs of the 'program'. """ - + output_nodes = [] for node in program.graph.nodes: if node.op == "output": - return node - raise RuntimeError("No output node found.") + for output in node.args[0]: + output_nodes.append(output) + if len(output_nodes) == 0: + raise RuntimeError("No output nodes found.") + else: + return output_nodes def _get_output_quantization_params( - program: ExportedProgram, output_node: Node -) -> Optional[QuantizationParams]: + output_nodes: list[Node], +) -> List[QuantizationParams]: """ Get output QuantizationParams from a program. Args: - program (ExportedProgram): The program to get output quantization parameters from. + output_nodes (list(Node)): A list of output nodes to get output quantization parameters from. Returns: QuantizationParams: The found quantization parameters. Raises: RuntimeError if no output quantization parameters are found. 
""" - - quant_params = None - for node in program.graph.nodes: - if ( - node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default - and node == output_node.args[0][0] - ): - quant_params = QuantizationParams( - node_name=node.args[0].name, - scale=node.args[1], - zp=node.args[2], - qmin=node.args[3], - qmax=node.args[4], - dtype=node.args[5], + quant_params = [] + for node in output_nodes: + if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default: + quant_params.append( + QuantizationParams( + node_name=node.args[0].name, + scale=node.args[1], + zp=node.args[2], + qmin=node.args[3], + qmax=node.args[4], + dtype=node.args[5], + ) ) - break # break early, there's only one output node + if len(quant_params) == 0: + raise RuntimeError("No Quantization parameters not found in exported model.") return quant_params @@ -211,7 +214,7 @@ def __init__( self.input_names: list[str] = None self.output_name: str = None self.qp_input: list[QuantizationParams] = None - self.qp_output: QuantizationParams = None + self.qp_output: list[QuantizationParams] = None self.timeout = 480 self.target_board: str = None @@ -226,19 +229,17 @@ def init_run( ): self.input_names = _get_input_names(edge_program) - self.output_node = _get_output_node(exported_program) - self.output_name = self.output_node.name + self.output_nodes = _get_output_nodes(exported_program) + self.is_quantized = is_quantized self.target_board = target_board if is_quantized: self.qp_input = _get_input_quantization_params(exported_program) - self.qp_output = _get_output_quantization_params( - exported_program, self.output_node - ) + self.qp_output = _get_output_quantization_params(self.output_nodes) else: self.qp_input = [None] * len(self.input_names) - self.qp_output = None + self.qp_output = [None] * len(self.output_nodes) self._has_init_run = True @@ -265,7 +266,7 @@ def run_corstone( save_bytes(self.intermediate_path, data, False, input_name, quant_param) out_path = os.path.join(self.intermediate_path, "out") - out_path_with_suffix = out_path + "-0.bin" + input_paths = [] for name in self.input_names: input_paths.append( @@ -281,6 +282,7 @@ def run_corstone( ), f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?" cmd_line = f"executor_runner -m {pte_path} -o {out_path}" + for input_path in input_paths: cmd_line += f" -i {input_path}" @@ -362,11 +364,14 @@ def run_corstone( raise RuntimeError( f"Corstone simulation failed:\ncmd: {command_args[self.target_board]}\n, log: \n {result_stdout}\n{result.stderr.decode()}" ) - - tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32) - output_shape = self.output_node.args[0][0].meta["val"].shape - tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape) - return tosa_ref_output + output_np = [] + for i, node in enumerate(self.output_nodes): + tosa_ref_output = np.fromfile( + os.path.join(self.intermediate_path, f"out-{i}.bin"), dtype=np.float32 + ) + output_shape = node.meta["val"].shape + output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape)) + return tuple(output_np) def run_tosa_graph( self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor] diff --git a/backends/arm/test/tester/analyze_output_utils.py b/backends/arm/test/tester/analyze_output_utils.py index d70f86c4f2..477a96652f 100644 --- a/backends/arm/test/tester/analyze_output_utils.py +++ b/backends/arm/test/tester/analyze_output_utils.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. 
+# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -9,7 +9,7 @@ import torch from executorch.backends.arm.test.runner_utils import ( _get_input_quantization_params, - _get_output_node, + _get_output_nodes, _get_output_quantization_params, ) @@ -228,9 +228,9 @@ def dump_error_output( export_stage = tester.stages.get(tester.stage_name(Export), None) quantize_stage = tester.stages.get(tester.stage_name(Quantize), None) if export_stage is not None and quantize_stage is not None: - output_node = _get_output_node(export_stage.artifact) + output_nodes = _get_output_nodes(export_stage.artifact) qp_input = _get_input_quantization_params(export_stage.artifact) - qp_output = _get_output_quantization_params(export_stage.artifact, output_node) + qp_output = _get_output_quantization_params(output_nodes) logger.error(f"Input QuantArgs: {qp_input}") logger.error(f"Output QuantArgs: {qp_output}") diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index e5c700ec3c..5b2f9201fc 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -14,6 +14,7 @@ import serializer.tosa_serializer as ts import torch.fx +import torch.utils._pytree as pytree from executorch.backends.arm.arm_backend import get_intermediate_path from executorch.backends.arm.arm_partitioner import ArmPartitioner @@ -302,6 +303,7 @@ def run_method_and_compare_outputs( exported_program = self.stages[self.stage_name(tester.Export)].artifact edge_program = edge_stage.artifact.exported_program() + self.runner_util.init_run( exported_program, edge_program, @@ -309,14 +311,14 @@ def run_method_and_compare_outputs( target_board, ) - quantization_scale = None if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] # bool output is quantized with none quantized output so allow # self.runner_util.qp_output to be none if self.runner_util.qp_output is not None: - quantization_scale = self.runner_util.qp_output.scale + quantization_scales = [qp.scale for qp in self.runner_util.qp_output] else: + quantization_scales = [None] * len(self.runner_util.output_nodes) reference_stage = self.stages[self.stage_name(InitialModel)] logger.info( @@ -334,21 +336,26 @@ def run_method_and_compare_outputs( input_shape_str = ", ".join([str(list(i)) for i in input_shapes]) logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}") - reference_output = reference_stage.run_artifact(reference_input) - if not isinstance(reference_output, tuple): - reference_output = (reference_output,) - test_output = test_stage.run_artifact(reference_input) - - self._compare_outputs( - reference_output, - test_output, - quantization_scale, - atol, - rtol, - qtol, - error_callbacks, + reference_outputs, _ = pytree.tree_flatten( + reference_stage.run_artifact(reference_input) + ) + test_outputs, _ = pytree.tree_flatten( + test_stage.run_artifact(reference_input) ) + for reference_output, test_output, quantization_scale in zip( + reference_outputs, test_outputs, quantization_scales + ): + self._compare_outputs( + reference_output, + test_output, + quantization_scale, + atol, + rtol, + qtol, + error_callbacks, + ) + return self def get_graph(self, stage: str | None = None) -> Graph: From cb45fb6ccb1a1b2dd170bc047617cc2e9ff592ab Mon Sep 17 00:00:00 2001 From: Thibaut Goetghebuer-Planchon Date: Fri, 17 Jan 2025 08:59:39 +0000 Subject: [PATCH 4/7] 
Fix uninitialized variable type-check in FuseQuantizedActivationPass (#7671) --- backends/arm/_passes/fuse_quantized_activation_pass.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backends/arm/_passes/fuse_quantized_activation_pass.py b/backends/arm/_passes/fuse_quantized_activation_pass.py index 86836842bb..4eccea1a14 100644 --- a/backends/arm/_passes/fuse_quantized_activation_pass.py +++ b/backends/arm/_passes/fuse_quantized_activation_pass.py @@ -19,12 +19,13 @@ def _is_fuseable_quantized_activation(self, node: Node): is_fuseable = min_val == 0 is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op - if is_quantized: + if is_fuseable and is_quantized: quant_node = next(iter(node.users)) zp = quant_node.args[2] qmin = quant_node.args[3] - - return is_fuseable and is_quantized and zp == qmin + return zp == qmin + else: + return False def _is_fuseable_input(self, node: Node): return ( From ffc20208dae8f4900da11bfffb76f749e7514132 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Fri, 17 Jan 2025 11:24:37 +0100 Subject: [PATCH 5/7] Remove unused functions for quantization handling (#7700) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove functions not used for searching/finding quantization information. Signed-off-by: Per Åstrand --- .../annotate_channels_last_dim_order_pass.py | 7 +- backends/arm/operators/__init__.py | 2 - backends/arm/operators/op_dequant.py | 35 --- backends/arm/operators/op_hardtanh.py | 7 +- backends/arm/operators/op_quant.py | 35 --- backends/arm/operators/op_relu.py | 8 +- backends/arm/process_node.py | 22 +- backends/arm/tosa_quant_utils.py | 270 +----------------- backends/arm/tosa_utils.py | 28 -- examples/arm/aot_arm_compiler.py | 6 +- 10 files changed, 21 insertions(+), 399 deletions(-) delete mode 100644 backends/arm/operators/op_dequant.py delete mode 100644 backends/arm/operators/op_quant.py diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 80c5f3c442..4aff46de67 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -15,7 +15,7 @@ get_node_arg, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,9 +43,6 @@ def _transpose_impl(*args, **kwargs): return args[0] -register_passable_op(torch.ops.passthrough_to_tosa._transpose) - - class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. 
tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 157e5ec092..a21bde535e 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -13,7 +13,6 @@ op_bmm, op_cat, op_conv2d, - op_dequant, op_exp, op_full, op_get_item, @@ -24,7 +23,6 @@ op_min, op_mul, op_permute, - op_quant, op_reciprocal, op_relu, op_repeat, diff --git a/backends/arm/operators/op_dequant.py b/backends/arm/operators/op_dequant.py deleted file mode 100644 index 022f4e45ce..0000000000 --- a/backends/arm/operators/op_dequant.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class DequantVisitor(NodeVisitor): - target = "quantized_decomposed.dequantize_per_tensor.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - item_name = inputs[0].name - ## Simply add an identityOp - tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index bfbab55b92..c971b50b66 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. +# Copyright 2023-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,7 +19,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import quantize_value from serializer.tosa_serializer import TosaOp @@ -44,8 +43,8 @@ def define_node( input_qparams = get_input_qparams(node) # pyre-ignore[16] qargs = input_qparams[0] # Convert to quantized representation - clamp_min_qs = quantize_value(inputs[1].number, qargs) - clamp_max_qs = quantize_value(inputs[2].number, qargs) + clamp_min_qs = qargs.quantize_value(inputs[1].number).item() + clamp_max_qs = qargs.quantize_value(inputs[2].number).item() # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_quant.py b/backends/arm/operators/op_quant.py deleted file mode 100644 index fcf9372c11..0000000000 --- a/backends/arm/operators/op_quant.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class QuantVisitor(NodeVisitor): - target = "quantized_decomposed.quantize_per_tensor.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - item_name = inputs[0].name - ## Simply add an identityOp - tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 4df13e71b7..b5ffa2aa70 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -1,11 +1,10 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe -import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx @@ -43,9 +42,8 @@ def define_node( clamp_max_qs = 0 if inputs[0].dtype == ts.DType.INT8: out_qargs = get_output_qparams(node) # pyre-ignore[16] - clamp_min_qs = tqutils.quantize_value(0, out_qargs[0]) - clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0]) - + clamp_min_qs = out_qargs[0].quantize_value(0).item() + clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item() else: clamp_min_fp = 0 clamp_max_fp = float("inf") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 9ab9c49044..36a1567df9 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,12 +12,7 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - dq_op, - get_quantized_node_output_dtype, - is_node_quantized, -) +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch.export.exported_program import ExportedProgram @@ -35,15 +30,8 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) - is_dq_node = node.target == dq_op - if is_dq_node: - output_dtype = ts.DType.INT8 - else: - output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, - tosa_shape(output.shape, output.dim_order), - output_dtype, + output.name, tosa_shape(output.shape, output.dim_order), output.dtype ) # Visiting each Node @@ -79,11 +67,7 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - ( - map_dtype(get_quantized_node_output_dtype(node)) - if is_node_quantized(node) - else inputs[0].dtype - ), + inputs[0].dtype, data=None, placeholderFilename=inputs[0].name + ".npy", ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index dff7b12cdd..9869a08c0b 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Arm 
Limited and/or its affiliates. +# Copyright 2023-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,9 +8,7 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import Callable, cast, NamedTuple, Sequence - -import numpy as np +from typing import cast, NamedTuple import serializer.tosa_serializer as ts import torch.fx @@ -24,22 +22,6 @@ q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default dq_q_ops = (q_op, dq_op) -passable_ops = [ - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.squeeze_copy.dims, - exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.cat.default, -] - - -def register_passable_op(op): - """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" - passable_ops.append(op) def insert_rescale_ops_to_int32( @@ -53,8 +35,7 @@ def insert_rescale_ops_to_int32( This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict as opposed to 'rescale_nodes_to_int32' which search - the graph upstream for DQ nodes. + in the node meta dict. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' @@ -100,13 +81,12 @@ def insert_rescale_op_to_int8( Parameters: node: The original node that is being handled by the rescales. last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' + scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32' tosa_graph: the tosa_graph to manipulate. This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict as opposed to 'rescale_node_back_to_int8' which search - the graph downstream for Q nodes. + in the node meta dict. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' 
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -148,17 +128,6 @@ def quantize_value(self, x): def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: return (qx - self.zp) * self.scale - def __eq__(self, other): - if isinstance(other, QuantArgs): - return ( - self.scale == other.scale - and self.zp == other.zp - and self.qmin == other.qmin - and self.qmax == other.qmax - and self.dtype == other.dtype - ) - return False - @classmethod def from_operator(cls, op, args): if op in dq_q_ops: @@ -174,172 +143,6 @@ def from_operator(cls, op, args): raise NotImplementedError -def quantize_value(x, qargs: QuantArgs, dtype=np.int8): - return np.clip( - np.round(x / qargs.scale) + qargs.zp, - qargs.qmin, - qargs.qmax, - ).astype(dtype) - - -def dequantize_value(qx, qargs: QuantArgs): - return (np.int64(qx) - qargs.zp) * qargs.scale - - -def qargs_from_qnode(node: torch.fx.Node): - assert node.target in dq_q_ops, f"Op {node} is not a quant node." - - return QuantArgs.from_operator(node.target, node.args) - - -def get_neighbour_quant_args( - node: torch.fx.Node, -) -> tuple[list[QuantArgs], list[QuantArgs]]: - user_q_args = [] - - for user in node.users: - q_args = search_quant_arg_downstream(user) - if q_args: - user_q_args.append(q_args) - - input_q_nodes = [] - for input_node in node.all_input_nodes: - q_args = search_quant_arg_upstream(input_node) - if q_args: - input_q_nodes.append(q_args) - return user_q_args, input_q_nodes - - -def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: - first_q_arg = q_arg_list[0] - for q_arg in q_arg_list: - if q_arg != first_q_arg: - return False - return True - - -def is_node_quantized(node: torch.fx.Node) -> bool: - if node.target in dq_q_ops: - return True - - user_q_args, input_q_args = get_neighbour_quant_args(node) - - # If we did not find any neighbouring quant nodes, we are not quantized. - if len(input_q_args) == 0 and len(user_q_args) == 0: - return False - - if node.target in passable_ops: - assert all_q_args_equal( - user_q_args + input_q_args - ), f"Node {node} needs same quantization parameters on all inputs and outputs." - - return True - - -def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple consumers is encountered, - find QuantArgs for all consumers and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without consumers is encountered, return None. - """ - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - consumer_nodes = list(node.users) - if len(consumer_nodes) == 0: - return None - elif len(consumer_nodes) == 1: - return search_quant_arg_downstream(consumer_nodes[0]) - else: - consumer_qargs: list[QuantArgs] = [] - for input in consumer_nodes: - quant_args = search_quant_arg_downstream(input) - if quant_args: - consumer_qargs.append(quant_args) - if len(consumer_qargs) == 0: - return None - assert all_q_args_equal( - consumer_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." - return consumer_qargs[0] - - -def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: - """Calls search_quant_arg_downstream and asserts that QuantArgs are found, - meaning return value can't be None. 
- """ - qargs = search_quant_arg_downstream(node) - assert qargs, f"Did not find QuantArgs downstream for node {node}" - return qargs - - -def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple inputs is encountered, - find QuantArgs for all inputs and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without inputs is encountered, return None. - """ - - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - input_nodes = list(node.all_input_nodes) - if len(input_nodes) == 0: - return None - elif len(input_nodes) == 1: - return search_quant_arg_upstream(input_nodes[0]) - else: - input_qargs: list[QuantArgs] = [] - for input in input_nodes: - quant_args = search_quant_arg_upstream(input) - if quant_args: - input_qargs.append(quant_args) - if len(input_qargs) == 0: - return None - assert all_q_args_equal( - input_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." - return input_qargs[0] - - -def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: - """Calls search_quant_arg_upstream and asserts that QuantArgs are found, - meaning return value can't be None. - """ - qargs = search_quant_arg_upstream(node) - assert qargs, f"Did not find QuantArgs upstream for node {node}" - return qargs - - -def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: - if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys(): - # Check if the node has had it's quantization parameters folded - # and retrieve the dtype from the meta dict in that case. - assert len(node.meta["output_qparams"]) == 1 - qargs = cast(QuantArgs, node.meta["output_qparams"][0]) - return qargs.dtype - - if node.target in dq_q_ops: - return cast(torch.dtype, node.args[5]) - - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - user_q_args, input_q_args = get_neighbour_quant_args(node) - if len(user_q_args) > 0: - return user_q_args[0].dtype - elif node.target in passable_ops and len(input_q_args) > 0: - return input_q_args[0].dtype - else: - raise RuntimeError("No quantized node found in graph") - - # Check if scale32 mode is used for given output element type def is_scale32(type): return type == ts.DType.INT8 @@ -476,69 +279,6 @@ def build_rescale_from_int32( return -def rescale_nodes_to_int32( - nodes: Sequence[Node], tosa_graph: ts.TosaSerializer -) -> tuple[list[TosaSerializerTensor], float]: - """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. - The scales are adjusted using the smallest scale of all 'nodes'. - - Returns a list of the rescaled nodes and the scale factor used, - needed by rescale_node_back_to_int8. 
- """ - - tensors = [TosaArg(node) for node in nodes] - - # Reshape tensor according to tosa dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - qargs = [get_quant_arg_upstream(node) for node in nodes] - - # Scale the int8 quantized input to a common scale in the integer - # domain - min_scale = min([qarg.scale for qarg in qargs]) - scales = [qarg.scale / min_scale for qarg in qargs] - - rescaled_nodes: list[TosaSerializerTensor] = [] - for tensor, qarg, scale in zip(tensors, qargs, scales): - rescaled_nodes.append( - build_rescale_to_int32( - tosa_graph, - tensor, - qarg.zp, - scale, - ) - ) - return rescaled_nodes, min_scale - - -def rescale_node_back_to_int8( - node: Node, - last_tensor: TosaSerializerTensor, - scale: float, - tosa_graph: ts.TosaSerializer, -): - """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' - tosa_graph: the tosa_graph to manipulate. - """ - qargs_out = get_quant_arg_downstream(list(node.users)[0]) - output_rescale_scale = scale / qargs_out.scale - - # Rescale Back to INT8 - build_rescale_from_int32( - tosa_graph, - last_tensor.name, - node.name, - qargs_out.zp, - output_rescale_scale, - ) - - """ Creates a TOSA rescale op based on conv2d parameters. """ diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index c03e0ef0bb..9fefdbb3ff 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -115,10 +115,6 @@ def getNodeArgs(node: Node) -> list[TosaArg]: return [TosaArg(arg) for arg in node.args] -def get_input_tensor(node: Node) -> TosaArg: - return TosaArg(node.args[0]) - - def get_output_node(node: Node) -> Node: return list(node.users)[0] @@ -146,30 +142,6 @@ def is_consumer_node_depthwise_conv2d(node): return False -def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: - """Returns two input nodes to 'node' in order. If 'node' only has one input, - it is returned twice. - - Fails if there are no input nodes. - Fails if there are >2 input nodes and 'check' is True, - """ - - num_inputs = len(node.all_input_nodes) - assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}." - - input1 = node.all_input_nodes[0] - if num_inputs == 1: - input2 = node.all_input_nodes[0] - else: - input2 = node.all_input_nodes[1] - if check: - assert ( - num_inputs <= 2 - ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}." 
- - return input1, input2 - - def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order]) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 9563be93aa..a49436193b 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -264,7 +264,11 @@ def get_compile_spec( ) -> list[CompileSpec]: spec_builder = None if target == "TOSA": - spec_builder = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80+BI") + spec_builder = ( + ArmCompileSpecBuilder() + .tosa_compile_spec("TOSA-0.80+BI") + .set_quantize_io(True) + ) elif "ethos-u55" in target: spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec( target, From eaad7ff1ece5524b8892be9a3c40a3636ec2b64f Mon Sep 17 00:00:00 2001 From: Oscar Andersson <87121123+oscarandersson8218@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:36:03 +0100 Subject: [PATCH 6/7] Revert "Remove unused functions for quantization handling" (#7724) Revert "Remove unused functions for quantization handling (#7700)" This reverts commit ffc20208dae8f4900da11bfffb76f749e7514132. --- .../annotate_channels_last_dim_order_pass.py | 7 +- backends/arm/operators/__init__.py | 2 + backends/arm/operators/op_dequant.py | 35 +++ backends/arm/operators/op_hardtanh.py | 7 +- backends/arm/operators/op_quant.py | 35 +++ backends/arm/operators/op_relu.py | 8 +- backends/arm/process_node.py | 22 +- backends/arm/tosa_quant_utils.py | 270 +++++++++++++++++- backends/arm/tosa_utils.py | 28 ++ examples/arm/aot_arm_compiler.py | 6 +- 10 files changed, 399 insertions(+), 21 deletions(-) create mode 100644 backends/arm/operators/op_dequant.py create mode 100644 backends/arm/operators/op_quant.py diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 4aff46de67..80c5f3c442 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -15,7 +15,7 @@ get_node_arg, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,6 +43,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. 
tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a21bde535e..157e5ec092 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -13,6 +13,7 @@ op_bmm, op_cat, op_conv2d, + op_dequant, op_exp, op_full, op_get_item, @@ -23,6 +24,7 @@ op_min, op_mul, op_permute, + op_quant, op_reciprocal, op_relu, op_repeat, diff --git a/backends/arm/operators/op_dequant.py b/backends/arm/operators/op_dequant.py new file mode 100644 index 0000000000..022f4e45ce --- /dev/null +++ b/backends/arm/operators/op_dequant.py @@ -0,0 +1,35 @@ +# Copyright 2023-2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class DequantVisitor(NodeVisitor): + target = "quantized_decomposed.dequantize_per_tensor.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + item_name = inputs[0].name + ## Simply add an identityOp + tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index c971b50b66..bfbab55b92 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,6 +19,7 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import quantize_value from serializer.tosa_serializer import TosaOp @@ -43,8 +44,8 @@ def define_node( input_qparams = get_input_qparams(node) # pyre-ignore[16] qargs = input_qparams[0] # Convert to quantized representation - clamp_min_qs = qargs.quantize_value(inputs[1].number).item() - clamp_max_qs = qargs.quantize_value(inputs[2].number).item() + clamp_min_qs = quantize_value(inputs[1].number, qargs) + clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_quant.py b/backends/arm/operators/op_quant.py new file mode 100644 index 0000000000..fcf9372c11 --- /dev/null +++ b/backends/arm/operators/op_quant.py @@ -0,0 +1,35 @@ +# Copyright 2023-2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class QuantVisitor(NodeVisitor): + target = "quantized_decomposed.quantize_per_tensor.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + item_name = inputs[0].name + ## Simply add an identityOp + tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index b5ffa2aa70..4df13e71b7 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -1,10 +1,11 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe +import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx @@ -42,8 +43,9 @@ def define_node( clamp_max_qs = 0 if inputs[0].dtype == ts.DType.INT8: out_qargs = get_output_qparams(node) # pyre-ignore[16] - clamp_min_qs = out_qargs[0].quantize_value(0).item() - clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item() + clamp_min_qs = tqutils.quantize_value(0, out_qargs[0]) + clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0]) + else: clamp_min_fp = 0 clamp_max_fp = float("inf") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 36a1567df9..9ab9c49044 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,7 +12,12 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_quant_utils import ( + dq_op, + get_quantized_node_output_dtype, + is_node_quantized, +) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch.export.exported_program import ExportedProgram @@ -30,8 +35,15 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_dq_node = node.target == dq_op + if is_dq_node: + output_dtype = ts.DType.INT8 + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, tosa_shape(output.shape, output.dim_order), output.dtype + output.name, + tosa_shape(output.shape, output.dim_order), + output_dtype, ) # Visiting each Node @@ -67,7 +79,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 9869a08c0b..dff7b12cdd 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm 
Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +8,9 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import cast, NamedTuple +from typing import Callable, cast, NamedTuple, Sequence + +import numpy as np import serializer.tosa_serializer as ts import torch.fx @@ -22,6 +24,22 @@ q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) def insert_rescale_ops_to_int32( @@ -35,7 +53,8 @@ def insert_rescale_ops_to_int32( This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. + in the node meta dict as opposed to 'rescale_nodes_to_int32' which search + the graph upstream for DQ nodes. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' @@ -81,12 +100,13 @@ def insert_rescale_op_to_int8( Parameters: node: The original node that is being handled by the rescales. last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32' + scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. + in the node meta dict as opposed to 'rescale_node_back_to_int8' which search + the graph downstream for Q nodes. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' 
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -128,6 +148,17 @@ def quantize_value(self, x): def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: return (qx - self.zp) * self.scale + def __eq__(self, other): + if isinstance(other, QuantArgs): + return ( + self.scale == other.scale + and self.zp == other.zp + and self.qmin == other.qmin + and self.qmax == other.qmax + and self.dtype == other.dtype + ) + return False + @classmethod def from_operator(cls, op, args): if op in dq_q_ops: @@ -143,6 +174,172 @@ def from_operator(cls, op, args): raise NotImplementedError +def quantize_value(x, qargs: QuantArgs, dtype=np.int8): + return np.clip( + np.round(x / qargs.scale) + qargs.zp, + qargs.qmin, + qargs.qmax, + ).astype(dtype) + + +def dequantize_value(qx, qargs: QuantArgs): + return (np.int64(qx) - qargs.zp) * qargs.scale + + +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." + + return QuantArgs.from_operator(node.target, node.args) + + +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] + + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes + + +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True + + +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True + + user_q_args, input_q_args = get_neighbour_quant_args(node) + + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False + + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." + + return True + + +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. + """ + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(consumer_qargs) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_downstream and asserts that QuantArgs are found, + meaning return value can't be None. 
+ """ + qargs = search_quant_arg_downstream(node) + assert qargs, f"Did not find QuantArgs downstream for node {node}" + return qargs + + +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. + """ + + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(input_qargs) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_upstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_upstream(node) + assert qargs, f"Did not find QuantArgs upstream for node {node}" + return qargs + + +def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: + if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys(): + # Check if the node has had it's quantization parameters folded + # and retrieve the dtype from the meta dict in that case. + assert len(node.meta["output_qparams"]) == 1 + qargs = cast(QuantArgs, node.meta["output_qparams"][0]) + return qargs.dtype + + if node.target in dq_q_ops: + return cast(torch.dtype, node.args[5]) + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args) > 0: + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") + + # Check if scale32 mode is used for given output element type def is_scale32(type): return type == ts.DType.INT8 @@ -279,6 +476,69 @@ def build_rescale_from_int32( return +def rescale_nodes_to_int32( + nodes: Sequence[Node], tosa_graph: ts.TosaSerializer +) -> tuple[list[TosaSerializerTensor], float]: + """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. + The scales are adjusted using the smallest scale of all 'nodes'. + + Returns a list of the rescaled nodes and the scale factor used, + needed by rescale_node_back_to_int8. 
+ """ + + tensors = [TosaArg(node) for node in nodes] + + # Reshape tensor according to tosa dim order + for tensor in tensors: + dim_order = tensor.dim_order + tensor.shape = [tensor.shape[i] for i in dim_order] + + qargs = [get_quant_arg_upstream(node) for node in nodes] + + # Scale the int8 quantized input to a common scale in the integer + # domain + min_scale = min([qarg.scale for qarg in qargs]) + scales = [qarg.scale / min_scale for qarg in qargs] + + rescaled_nodes: list[TosaSerializerTensor] = [] + for tensor, qarg, scale in zip(tensors, qargs, scales): + rescaled_nodes.append( + build_rescale_to_int32( + tosa_graph, + tensor, + qarg.zp, + scale, + ) + ) + return rescaled_nodes, min_scale + + +def rescale_node_back_to_int8( + node: Node, + last_tensor: TosaSerializerTensor, + scale: float, + tosa_graph: ts.TosaSerializer, +): + """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. + Parameters: + node: The original node that is being handled by the rescales. + last_tensor:the tosa tensor to rescale back. + scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' + tosa_graph: the tosa_graph to manipulate. + """ + qargs_out = get_quant_arg_downstream(list(node.users)[0]) + output_rescale_scale = scale / qargs_out.scale + + # Rescale Back to INT8 + build_rescale_from_int32( + tosa_graph, + last_tensor.name, + node.name, + qargs_out.zp, + output_rescale_scale, + ) + + """ Creates a TOSA rescale op based on conv2d parameters. """ diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 9fefdbb3ff..c03e0ef0bb 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -115,6 +115,10 @@ def getNodeArgs(node: Node) -> list[TosaArg]: return [TosaArg(arg) for arg in node.args] +def get_input_tensor(node: Node) -> TosaArg: + return TosaArg(node.args[0]) + + def get_output_node(node: Node) -> Node: return list(node.users)[0] @@ -142,6 +146,30 @@ def is_consumer_node_depthwise_conv2d(node): return False +def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: + """Returns two input nodes to 'node' in order. If 'node' only has one input, + it is returned twice. + + Fails if there are no input nodes. + Fails if there are >2 input nodes and 'check' is True, + """ + + num_inputs = len(node.all_input_nodes) + assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}." + + input1 = node.all_input_nodes[0] + if num_inputs == 1: + input2 = node.all_input_nodes[0] + else: + input2 = node.all_input_nodes[1] + if check: + assert ( + num_inputs <= 2 + ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}." 
+ + return input1, input2 + + def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order]) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index a49436193b..9563be93aa 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -264,11 +264,7 @@ def get_compile_spec( ) -> list[CompileSpec]: spec_builder = None if target == "TOSA": - spec_builder = ( - ArmCompileSpecBuilder() - .tosa_compile_spec("TOSA-0.80+BI") - .set_quantize_io(True) - ) + spec_builder = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80+BI") elif "ethos-u55" in target: spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec( target, From 5b9ab56657dabda161e866d4a574172f974b20c8 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 17 Jan 2025 09:16:12 -0800 Subject: [PATCH 7/7] install_requirements.py: reorganize requirements (#7705) Duplicate requirements with the pyproject.toml > /dev/null One unique devel reqiurement > requirements-dev.txt Examples requirements > requirements-examples.txt Nightlies stayed in the script. Rationale: be as "normal" a Python project as seemed possible. Test Plan: install_requirements.sh in a clean venv succeeded --- install_requirements.py | 25 +++++-------------------- requirements-examples.txt | 5 +++++ 2 files changed, 10 insertions(+), 20 deletions(-) create mode 100644 requirements-examples.txt diff --git a/install_requirements.py b/install_requirements.py index 409460ca10..52ba89edd7 100644 --- a/install_requirements.py +++ b/install_requirements.py @@ -104,34 +104,15 @@ def install_requirements(use_pytorch_nightly): if use_pytorch_nightly else "torchvision" ), # For testing. - "typing-extensions", ] - # pip packages needed to run examples. - # TODO: Make each example publish its own requirements.txt EXAMPLES_REQUIREMENTS = [ - "timm==1.0.7", f"torchaudio==2.6.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torchaudio", - "torchsr==1.0.4", - "transformers==4.47.1", - ] - - # pip packages needed for development. - DEVEL_REQUIREMENTS = [ - "cmake", # For building binary targets. - "pip>=23", # For building the pip package. - "pyyaml", # Imported by the kernel codegen tools. - "setuptools>=63", # For building the pip package. - "tomli", # Imported by extract_sources.py when using python < 3.11. - "wheel", # For building the pip package archive. - "zstd", # Imported by resolve_buck.py. ] # Assemble the list of requirements to actually install. # TODO: Add options for reducing the number of requirements. - REQUIREMENTS_TO_INSTALL = ( - EXIR_REQUIREMENTS + DEVEL_REQUIREMENTS + EXAMPLES_REQUIREMENTS - ) + REQUIREMENTS_TO_INSTALL = EXIR_REQUIREMENTS + EXAMPLES_REQUIREMENTS # Install the requirements. `--extra-index-url` tells pip to look for package # versions on the provided URL if they aren't available on the default URL. @@ -141,6 +122,8 @@ def install_requirements(use_pytorch_nightly): "-m", "pip", "install", + "-r", + "requirements-examples.txt", *REQUIREMENTS_TO_INSTALL, "--extra-index-url", TORCH_NIGHTLY_URL, @@ -160,6 +143,8 @@ def install_requirements(use_pytorch_nightly): "-m", "pip", "install", + # Without --no-build-isolation, setup.py can't find the torch module. + "--no-build-isolation", *LOCAL_REQUIREMENTS, ], check=True, diff --git a/requirements-examples.txt b/requirements-examples.txt new file mode 100644 index 0000000000..d4126a178a --- /dev/null +++ b/requirements-examples.txt @@ -0,0 +1,5 @@ +# pip packages needed to run examples. 
+# TODO: Make each example publish its own requirements.txt
+timm == 1.0.7
+torchsr == 1.0.4
+transformers == 4.47.1