Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
swolchok committed Jan 17, 2025
2 parents c208b1f + 5b9ab56 commit 4b3f654
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 84 deletions.
7 changes: 4 additions & 3 deletions backends/arm/_passes/fuse_quantized_activation_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ def _is_fuseable_quantized_activation(self, node: Node):
is_fuseable = min_val == 0

is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
if is_quantized:
if is_fuseable and is_quantized:
quant_node = next(iter(node.users))
zp = quant_node.args[2]
qmin = quant_node.args[3]

return is_fuseable and is_quantized and zp == qmin
return zp == qmin
else:
return False

def _is_fuseable_input(self, node: Node):
return (
Expand Down
47 changes: 46 additions & 1 deletion backends/arm/test/misc/test_multiple_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import unittest

import pytest
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec


class TestMultipleOutputs(unittest.TestCase):
Expand Down Expand Up @@ -51,3 +53,46 @@ def test_tosa_BI_pipeline(self):
.to_executorch()
.run_method_and_compare_outputs(inputs=inputs, qtol=1.0)
)

def _test_ethosu_BI_pipeline(
    self,
    module: torch.nn.Module,
    test_data: tuple[torch.Tensor],
    compile_spec: CompileSpec,
):
    """Drive ``module`` through the ArmTester stages up to serialization.

    The stages run in this order: quantize, export,
    to_edge_transform_and_lower, to_executorch, serialize. Outputs are
    compared against the reference only when the ``corstone_fvp`` option
    is enabled.

    Args:
        module: The model under test.
        test_data: Example inputs, also reused as the comparison inputs.
        compile_spec: Compile spec handed to ``ArmTester``.
    """
    pipeline = ArmTester(
        module,
        example_inputs=test_data,
        compile_spec=compile_spec,
    )
    # Each stage call's return value feeds the next one, exactly like the
    # equivalent fluent chain.
    pipeline = pipeline.quantize()
    pipeline = pipeline.export()
    pipeline = pipeline.to_edge_transform_and_lower()
    pipeline = pipeline.to_executorch()
    pipeline = pipeline.serialize()

    if not conftest.is_option_enabled("corstone_fvp"):
        return
    pipeline.run_method_and_compare_outputs(qtol=1, inputs=test_data)

@pytest.mark.corstone_fvp
def test_u85_BI(self):
    """Run the multiple-outputs module with the U85 compile spec."""
    model = self.MultipleOutputsModule()
    example_inputs = model.get_inputs()
    u85_spec = common.get_u85_compile_spec()
    self._test_ethosu_BI_pipeline(model, example_inputs, u85_spec)

@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
# TODO MLETORCH-598
def test_u55_BI(self):
    """Run the multiple-outputs module with the U55 compile spec.

    Decorated as an expected failure on FVP; see the TODO above.
    """
    model = self.MultipleOutputsModule()
    example_inputs = model.get_inputs()
    u55_spec = common.get_u55_compile_spec()
    self._test_ethosu_BI_pipeline(model, example_inputs, u55_spec)
79 changes: 42 additions & 37 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,50 +115,53 @@ def _get_input_quantization_params(
return quant_params


def _get_output_node(program: ExportedProgram) -> Node:
def _get_output_nodes(program: ExportedProgram) -> list[Node]:
"""
Get output node to this model.
Args:
program (ExportedProgram): The program to get output node from.
program (ExportedProgram): The program to get the output nodes from.
Returns:
The node that is the output of 'program'.
The nodes that are the outputs of the 'program'.
"""

output_nodes = []
for node in program.graph.nodes:
if node.op == "output":
return node
raise RuntimeError("No output node found.")
for output in node.args[0]:
output_nodes.append(output)
if len(output_nodes) == 0:
raise RuntimeError("No output nodes found.")
else:
return output_nodes


def _get_output_quantization_params(
program: ExportedProgram, output_node: Node
) -> Optional[QuantizationParams]:
output_nodes: list[Node],
) -> List[QuantizationParams]:
"""
Get output QuantizationParams from a program.
Args:
program (ExportedProgram): The program to get output quantization parameters from.
output_nodes (list(Node)): A list of output nodes to get output quantization parameters from.
Returns:
QuantizationParams: The found quantization parameters.
Raises:
RuntimeError if no output quantization parameters are found.
"""

quant_params = None
for node in program.graph.nodes:
if (
node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default
and node == output_node.args[0][0]
):
quant_params = QuantizationParams(
node_name=node.args[0].name,
scale=node.args[1],
zp=node.args[2],
qmin=node.args[3],
qmax=node.args[4],
dtype=node.args[5],
quant_params = []
for node in output_nodes:
if node.target == torch.ops.quantized_decomposed.dequantize_per_tensor.default:
quant_params.append(
QuantizationParams(
node_name=node.args[0].name,
scale=node.args[1],
zp=node.args[2],
qmin=node.args[3],
qmax=node.args[4],
dtype=node.args[5],
)
)
break # break early, there's only one output node
if len(quant_params) == 0:
raise RuntimeError("No Quantization parameters not found in exported model.")
return quant_params


Expand Down Expand Up @@ -211,7 +214,7 @@ def __init__(
self.input_names: list[str] = None
self.output_name: str = None
self.qp_input: list[QuantizationParams] = None
self.qp_output: QuantizationParams = None
self.qp_output: list[QuantizationParams] = None
self.timeout = 480
self.target_board: str = None

Expand All @@ -226,19 +229,17 @@ def init_run(
):

self.input_names = _get_input_names(edge_program)
self.output_node = _get_output_node(exported_program)
self.output_name = self.output_node.name
self.output_nodes = _get_output_nodes(exported_program)

self.is_quantized = is_quantized
self.target_board = target_board

if is_quantized:
self.qp_input = _get_input_quantization_params(exported_program)
self.qp_output = _get_output_quantization_params(
exported_program, self.output_node
)
self.qp_output = _get_output_quantization_params(self.output_nodes)
else:
self.qp_input = [None] * len(self.input_names)
self.qp_output = None
self.qp_output = [None] * len(self.output_nodes)

self._has_init_run = True

Expand All @@ -265,7 +266,7 @@ def run_corstone(
save_bytes(self.intermediate_path, data, False, input_name, quant_param)

out_path = os.path.join(self.intermediate_path, "out")
out_path_with_suffix = out_path + "-0.bin"

input_paths = []
for name in self.input_names:
input_paths.append(
Expand All @@ -281,6 +282,7 @@ def run_corstone(
), f"Did not find build arm_executor_runner in path {elf_path}, run setup_testing.sh?"

cmd_line = f"executor_runner -m {pte_path} -o {out_path}"

for input_path in input_paths:
cmd_line += f" -i {input_path}"

Expand Down Expand Up @@ -362,11 +364,14 @@ def run_corstone(
raise RuntimeError(
f"Corstone simulation failed:\ncmd: {command_args[self.target_board]}\n, log: \n {result_stdout}\n{result.stderr.decode()}"
)

tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
output_shape = self.output_node.args[0][0].meta["val"].shape
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
return tosa_ref_output
output_np = []
for i, node in enumerate(self.output_nodes):
tosa_ref_output = np.fromfile(
os.path.join(self.intermediate_path, f"out-{i}.bin"), dtype=np.float32
)
output_shape = node.meta["val"].shape
output_np.append(torch.from_numpy(tosa_ref_output).reshape(output_shape))
return tuple(output_np)

def run_tosa_graph(
self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/test/tester/analyze_output_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -9,7 +9,7 @@
import torch
from executorch.backends.arm.test.runner_utils import (
_get_input_quantization_params,
_get_output_node,
_get_output_nodes,
_get_output_quantization_params,
)

Expand Down Expand Up @@ -228,9 +228,9 @@ def dump_error_output(
export_stage = tester.stages.get(tester.stage_name(Export), None)
quantize_stage = tester.stages.get(tester.stage_name(Quantize), None)
if export_stage is not None and quantize_stage is not None:
output_node = _get_output_node(export_stage.artifact)
output_nodes = _get_output_nodes(export_stage.artifact)
qp_input = _get_input_quantization_params(export_stage.artifact)
qp_output = _get_output_quantization_params(export_stage.artifact, output_node)
qp_output = _get_output_quantization_params(output_nodes)
logger.error(f"Input QuantArgs: {qp_input}")
logger.error(f"Output QuantArgs: {qp_output}")

Expand Down
37 changes: 22 additions & 15 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import serializer.tosa_serializer as ts

import torch.fx
import torch.utils._pytree as pytree

from executorch.backends.arm.arm_backend import get_intermediate_path
from executorch.backends.arm.arm_partitioner import ArmPartitioner
Expand Down Expand Up @@ -302,21 +303,22 @@ def run_method_and_compare_outputs(

exported_program = self.stages[self.stage_name(tester.Export)].artifact
edge_program = edge_stage.artifact.exported_program()

self.runner_util.init_run(
exported_program,
edge_program,
is_quantized,
target_board,
)

quantization_scale = None
if is_quantized:
reference_stage = self.stages[self.stage_name(tester.Quantize)]
# bool output is quantized with none quantized output so allow
# self.runner_util.qp_output to be none
if self.runner_util.qp_output is not None:
quantization_scale = self.runner_util.qp_output.scale
quantization_scales = [qp.scale for qp in self.runner_util.qp_output]
else:
quantization_scales = [None] * len(self.runner_util.output_nodes)
reference_stage = self.stages[self.stage_name(InitialModel)]

logger.info(
Expand All @@ -334,21 +336,26 @@ def run_method_and_compare_outputs(
input_shape_str = ", ".join([str(list(i)) for i in input_shapes])
logger.info(f"Run #{run_iteration}, input shapes: {input_shape_str}")

reference_output = reference_stage.run_artifact(reference_input)
if not isinstance(reference_output, tuple):
reference_output = (reference_output,)
test_output = test_stage.run_artifact(reference_input)

self._compare_outputs(
reference_output,
test_output,
quantization_scale,
atol,
rtol,
qtol,
error_callbacks,
reference_outputs, _ = pytree.tree_flatten(
reference_stage.run_artifact(reference_input)
)
test_outputs, _ = pytree.tree_flatten(
test_stage.run_artifact(reference_input)
)

for reference_output, test_output, quantization_scale in zip(
reference_outputs, test_outputs, quantization_scales
):
self._compare_outputs(
reference_output,
test_output,
quantization_scale,
atol,
rtol,
qtol,
error_callbacks,
)

return self

def get_graph(self, stage: str | None = None) -> Graph:
Expand Down
25 changes: 2 additions & 23 deletions examples/cadence/operators/TARGETS
Original file line number Diff line number Diff line change
@@ -1,26 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load("targets.bzl", "define_common_targets")

oncall("odai_jarvis")


python_unittest(
name = "test_add_op",
srcs = [
"test_add_op.py",
],
typing = True,
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:export_example",
"//executorch/backends/cadence/aot:compiler",
],
)
define_common_targets()
36 changes: 36 additions & 0 deletions examples/cadence/operators/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

TESTS_LIST = [
"add_op",
"quantized_conv1d_op",
"quantized_linear_op",
]

def define_common_targets():
    """Declares one unit-test target for every entry in TESTS_LIST."""
    for test_name in TESTS_LIST:
        _define_test_target(test_name)


def _define_test_target(test_name):
    """Declares the python_unittest target for a single operator test.

    Args:
        test_name: Test suffix, e.g. "add_op". The target is named
            "test_<test_name>" and its only source is "test_<test_name>.py".
    """
    target_name = "test_{}".format(test_name)
    source_file = target_name + ".py"

    # Every operator test shares the same dependency set.
    common_deps = [
        "fbsource//third-party/pypi/parameterized:parameterized",
        "fbcode//caffe2:torch",
        "fbcode//executorch/backends/cadence/aot:ops_registrations",
        "fbcode//executorch/backends/cadence/aot:export_example",
        "fbcode//executorch/backends/cadence/aot:compiler",
    ]

    python_unittest(
        name = target_name,
        srcs = [source_file],
        typing = True,
        supports_static_listing = False,
        deps = common_deps,
    )
Loading

0 comments on commit 4b3f654

Please sign in to comment.