-
Notifications
You must be signed in to change notification settings - Fork 453
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move ReplaceScalarWithTensorArgPass to transforms
The pass is general and can be used by multiple backends. Use it in Arm backend and make small adjustments to make it work. Signed-off-by: Erik Lundell <erik.lundell@arm.com> Change-Id: I61863d9cefb1753c604d67a6e44845af46ef7c60
- Loading branch information
1 parent
5e4d6b6
commit 8fcd2b2
Showing
7 changed files
with
123 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# Copyright 2025 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Dict | ||
|
||
import torch | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.dialects.edge._ops import EdgeOpOverload | ||
from executorch.exir.pass_base import ExportPass | ||
|
||
|
||
class ReplaceScalarWithTensorArgPass(ExportPass): | ||
""" | ||
For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar, | ||
replace the scalar arg with Tensor arg. | ||
""" | ||
|
||
scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { | ||
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, | ||
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, | ||
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, | ||
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, | ||
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, | ||
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, | ||
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, | ||
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor, | ||
} | ||
|
||
def get_replacement(self, op, args, kwargs, meta): | ||
return super().call_operator( | ||
# Replace with .Tensor variant. | ||
op=self.scalar_to_tensor_ops[op], | ||
args=( | ||
# Tensor arg. | ||
args[0], | ||
# Scalar arg - replace with aten.full tensor. | ||
super().call_operator( | ||
exir_ops.edge.aten.full.default, | ||
args=( | ||
(1,), | ||
args[1], | ||
), | ||
kwargs={"dtype": args[0].to_tensor().dtype}, | ||
meta=meta, | ||
), | ||
# Other args. | ||
*args[2:], | ||
), | ||
kwargs=kwargs, | ||
meta=meta, | ||
) | ||
|
||
def call_operator(self, op, args, kwargs, meta): | ||
if op not in self.scalar_to_tensor_ops: | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
# There must be exactly 2 args (3 for add and sub containing alpha) | ||
assert len(args) == 2 or len(args) == 3 | ||
|
||
# If there are two args, just replace the op. | ||
if len(args) == 2: | ||
return self.get_replacement(op, args, kwargs, meta) | ||
|
||
# In case the op has three args, it must be scalar add/sub op. | ||
if ( | ||
op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar} | ||
or "alpha" in kwargs | ||
): | ||
return super().call_operator(op, args, kwargs, meta) | ||
|
||
return self.get_replacement(op, args, kwargs, meta) |