-
Notifications
You must be signed in to change notification settings - Fork 452
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
2025-02-07 nightly release (456928f)
- Loading branch information
pytorchbot
committed
Feb 7, 2025
1 parent
333230d
commit 1506bee
Showing
80 changed files
with
1,999 additions
and
209 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.devtools.visualization.visualization_utils import visualize_graph | ||
from executorch.exir import ExportedProgram | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class VisualizePass(ExportPass): | ||
""" | ||
This pass visualizes the graph at the point of insertion in the pass manager | ||
""" | ||
|
||
def __init__(self, exported_program: ExportedProgram) -> None: | ||
super().__init__() | ||
self.exported_program = exported_program | ||
|
||
def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||
visualize_graph(graph_module, self.exported_program) | ||
return PassResult(graph_module, False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,15 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# Copyright 2024-2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-unsafe | ||
|
||
from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa | ||
from . import ( # noqa | ||
convolution_support, | ||
pool_2d_support, | ||
reduce_sum_support, | ||
right_shift_support, | ||
to_copy_support, | ||
tosa_supported_operators, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast | ||
|
||
import torch | ||
import torch.fx as fx | ||
from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
register_tosa_support_check, | ||
SupportedTOSAOperatorCheck, | ||
) | ||
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
|
||
@register_tosa_support_check | ||
class ConvolutionSupported(SupportedTOSAOperatorCheck): | ||
targets = [exir_ops.edge.aten.convolution.default] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80+MI"), | ||
] | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): | ||
|
||
# Not implemented | ||
transposed = cast(bool, node.args[6]) | ||
output_padding = cast(list[int], node.args[7]) | ||
if transposed: | ||
return False | ||
|
||
for pad in output_padding: | ||
if pad != 0: | ||
return False | ||
|
||
# Hardware specific constraints | ||
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): | ||
return True | ||
else: | ||
return self._is_node_supported_u55(node) | ||
|
||
def _is_node_supported_u55(self, node: fx.Node): | ||
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)""" | ||
|
||
shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape | ||
shape_out = node.meta["val"].shape | ||
kernel = cast(fx.Node, node.args[1]).meta["val"].shape | ||
group = cast(int, node.args[8]) | ||
|
||
C_in = shape_in[1] | ||
C_out = shape_out[1] | ||
if (C_in == group) and (C_out % C_in) == 0: | ||
# Depthwise convolution | ||
for dim in shape_in[1:]: | ||
if not 1 <= dim <= 65536: | ||
return False | ||
else: | ||
# Convolution | ||
if not 1 <= C_in <= 65536: | ||
return False | ||
|
||
kernel_w = kernel[2] | ||
kernel_h = kernel[3] if len(kernel) > 3 else 1 | ||
# Kernel condition misses constraint on sum of absolute weights | ||
if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096: | ||
return False | ||
|
||
if not self._stride_condition(node): | ||
return False | ||
|
||
return True | ||
|
||
def _stride_condition(self, node: fx.Node) -> bool: | ||
"""This condition is somewhat complex but boils down | ||
to not supporting stride > 3, unless we have some special conditions. | ||
This condition is a simplified, relaxed version of the hardware constraint, | ||
since the actual constraint requires information not available | ||
here (without a lot of work). | ||
This means that we might accept ops that are not actually supported. | ||
""" | ||
strides = cast(list[int], node.args[3]) | ||
has_padding = any(pad > 0 for pad in cast(list[int], node.args[4])) | ||
dilations = cast(list[int], node.args[5]) | ||
if len(dilations) == 1: | ||
dilations = [dilations[0]] * 2 | ||
if len(strides) == 1: | ||
strides = [strides[0]] * 2 | ||
|
||
for stride, dilation in zip(strides, dilations): | ||
stride_condition = 1 <= stride <= 3 | ||
dilation_condition = (not has_padding) and (dilation == 1) | ||
if (not stride_condition) and (not dilation_condition): | ||
return False | ||
|
||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast | ||
|
||
import torch | ||
import torch.fx as fx | ||
from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
register_tosa_support_check, | ||
SupportedTOSAOperatorCheck, | ||
) | ||
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
|
||
def kernel_check(kernel: tuple[int, int]) -> bool: | ||
if not (1 <= kernel[0] * kernel[1] <= 65536): | ||
return False | ||
return 1 <= kernel[1] <= 256 | ||
|
||
|
||
def stride_check(strides: tuple[int, int]) -> bool: | ||
return all(1 <= stride <= 3 for stride in strides) | ||
|
||
|
||
def dim_check(shape=torch.Size) -> bool: | ||
check = shape[0] == 1 | ||
for dim in shape: | ||
check &= 1 <= dim <= 65536 | ||
return check | ||
|
||
|
||
@register_tosa_support_check | ||
class AvgPool2dSupported(SupportedTOSAOperatorCheck): | ||
targets = [ | ||
exir_ops.edge.aten.avg_pool2d.default, | ||
] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80+MI"), | ||
] | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): | ||
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): | ||
return True | ||
|
||
# U55 case, Vela 4.2.0 (25.02 release) | ||
shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape | ||
kernel = cast(tuple[int, int], node.args[1]) | ||
stride = cast(tuple[int, int], node.args[2]) | ||
if len(node.args) > 3: | ||
# Padding case | ||
if not all(1 <= k <= 8 for k in kernel): | ||
return False | ||
else: | ||
if not kernel_check(kernel): | ||
return False | ||
|
||
return dim_check(shape) and stride_check(stride) | ||
|
||
|
||
@register_tosa_support_check | ||
class MaxPool2dSupported(SupportedTOSAOperatorCheck): | ||
targets = [ | ||
exir_ops.edge.aten.max_pool2d_with_indices.default, | ||
] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80+MI"), | ||
] | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): | ||
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): | ||
return True | ||
|
||
# U55 case, Vela 4.2.0 (25.02 release) | ||
shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape | ||
kernel = cast(tuple[int, int], node.args[1]) | ||
stride = cast(tuple[int, int], node.args[2]) | ||
|
||
return kernel_check(kernel) and dim_check(shape) and stride_check(stride) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import cast | ||
|
||
import torch.fx as fx | ||
from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
register_tosa_support_check, | ||
SupportedTOSAOperatorCheck, | ||
) | ||
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
|
||
|
||
@register_tosa_support_check | ||
class SumSupported(SupportedTOSAOperatorCheck): | ||
targets = [exir_ops.edge.aten.sum.dim_IntList] | ||
|
||
tosa_specs = [ | ||
TosaSpecification.create_from_string("TOSA-0.80+BI"), | ||
TosaSpecification.create_from_string("TOSA-0.80+MI"), | ||
] | ||
|
||
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): | ||
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): | ||
return True | ||
|
||
# U55 case, Vela 4.2.0 (25.02 release) | ||
input_shape = node.all_input_nodes[0].meta["val"].shape | ||
dim_list = cast(list[int], node.args[1]) | ||
dim_list = [dim % len(input_shape) for dim in dim_list] | ||
|
||
for dim in dim_list: | ||
if not 1 <= input_shape[dim] <= 65536: | ||
return False | ||
|
||
# We can't be certain of which dim is the last in memory yet, | ||
# Always go for stricter condition. | ||
pre_R_product = 1.0 | ||
for length in input_shape[:dim]: | ||
pre_R_product *= length | ||
post_R_product = 1.0 | ||
for length in input_shape[dim + 1 :]: | ||
post_R_product *= length | ||
if not 1 <= pre_R_product <= 65536: | ||
return False | ||
if not 1 <= post_R_product <= 65536: | ||
return False | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.