Commit 75d4abc: Enable quantized add

Differential Revision: D69441041

Pull Request resolved: #8584
mcremon-meta authored Feb 20, 2025
1 parent 3e188fe commit 75d4abc
Showing 5 changed files with 146 additions and 3 deletions.
44 changes: 44 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -99,6 +99,10 @@
"quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
"quantized_add.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
"quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
@@ -175,6 +179,10 @@
"quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -290,6 +298,42 @@ def dequantize_per_tensor_meta(
    return input.new_empty(input.size(), dtype=torch.float)


@register_fake("cadence::quantized_add")
def quantized_add_meta(
X: torch.Tensor,
X_scale: torch.Tensor,
X_zero_point: torch.Tensor,
Y: torch.Tensor,
Y_scale: torch.Tensor,
Y_zero_point: torch.Tensor,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
out_size = X.size()
if list(X.size()) == [1]:
out_size = Y.size()

return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::quantized_add.per_tensor")
def quantized_add_per_tensor_meta(
X: torch.Tensor,
X_scale: float,
X_zero_point: int,
Y: torch.Tensor,
Y_scale: float,
Y_zero_point: int,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
out_size = X.size()
if list(X.size()) == [1]:
out_size = Y.size()

return X.new_empty(out_size, dtype=X.dtype)
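
For reference, the per_tensor variant registered above follows the usual dequantize-add-requantize contract. The sketch below is a minimal illustration of that contract, not the Cadence kernel itself; the rounding and saturation details are assumptions.

import torch

def quantized_add_per_tensor_ref(
    X: torch.Tensor, X_scale: float, X_zero_point: int,
    Y: torch.Tensor, Y_scale: float, Y_zero_point: int,
    out_scale: float, out_zero_point: int,
) -> torch.Tensor:
    # Dequantize both inputs, add in float, then requantize with the
    # output scale/zero point.
    x_fp = (X.to(torch.float) - X_zero_point) * X_scale
    y_fp = (Y.to(torch.float) - Y_zero_point) * Y_scale
    z_q = torch.round((x_fp + y_fp) / out_scale) + out_zero_point
    info = torch.iinfo(X.dtype)  # assumes an integer input dtype, e.g. int8
    return z_q.clamp(info.min, info.max).to(X.dtype)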


@register_fake("cadence::quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
53 changes: 51 additions & 2 deletions backends/cadence/aot/quantizer/fusion_pass.py
@@ -11,6 +11,7 @@
import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    AddPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
@@ -41,6 +42,47 @@
ReluPatterns = (ReluPattern0, ReluPattern1)


def get_args_and_kwargs_add(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    X_scale_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[0].args[1]),
        {"dtype": torch.float},
    )
    X_zero_point_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[0].args[2]),
        {"dtype": torch.int32},
    )
    Y_scale_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[1].args[1]),
        {"dtype": torch.float},
    )
    Y_zero_point_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[1].args[2]),
        {"dtype": torch.int32},
    )
    args = (
        inputs_inputs[0],
        X_scale_,
        X_zero_point_,
        inputs_inputs[1],
        Y_scale_,
        Y_zero_point_,
        quant_node.args[1],
        quant_node.args[2],
    )

    kwargs = {}
    return args, kwargs
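
The aten.full nodes above materialize the pattern's scalar scales and zero points as single-element tensors, because the default quantized_add overload takes Tensor arguments. A small illustration of what one such node holds (the 0.05 value is hypothetical):

import torch

X_scale_value = 0.05  # hypothetical dequantize scale read from the matched pattern
X_scale_tensor = torch.full([1], X_scale_value, dtype=torch.float)
assert X_scale_tensor.shape == (1,)
# ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass (see the
# replace_ops.py hunk below) later folds such single-element tensors back
# into the scalar quantized_add.per_tensor overload.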


# Helper function to get the args and kwargs for the linear replacement op
def get_args_and_kwargs_linear(
graph_module: GraphModule,
@@ -339,7 +381,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
            )
            for fused_partition in fused_partitions:
                anchors = pattern.get_anchors(graph_module, fused_partition)
-                if not anchors:
+                if not anchors or anchors.empty:
                    continue
                if any(self.is_fused(p.nodes) for p in fused_partition):
                    continue
@@ -385,7 +427,14 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                    inputs_inputs + weights_inputs + other_inputs + bias_inputs
                )
                kwargs = {}
-                if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
+                if isinstance(pattern, AddPattern):
+                    args, kwargs = get_args_and_kwargs_add(
+                        graph_module,
+                        inputs_inputs,
+                        dequants_inputs,
+                        quant_node,
+                    )
+                elif isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                    args, kwargs = get_args_and_kwargs_conv(
                        graph_module,
                        inputs_inputs,
33 changes: 33 additions & 0 deletions backends/cadence/aot/quantizer/patterns.py
@@ -43,6 +43,7 @@ class PartitionAnchors:
    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
        default_factory=list
    )
    empty: bool = False


class QuantizationPattern(ABC):
@@ -101,6 +102,38 @@ def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_linear


class AddPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.add.Tensor]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        add_node = fused_partition[0].nodes[-1]

        # Bail if:
        # - the add node is not a tensor add
        # - the add node has kwargs (e.g. alpha)
        is_tensor_add = isinstance(add_node.args[0], fx.Node) and isinstance(
            add_node.args[1], fx.Node
        )
        if not is_tensor_add or len(add_node.kwargs) > 0:
            return PartitionAnchors(
                empty=True,
            )

        return PartitionAnchors(
            inputs=[(add_node, 0), (add_node, 1)],
            weights=[],
            biases=[],
            output=[(add_node,)],
        )

    def replacement_op(self) -> OpOverload:
        return torch.ops.cadence.quantized_add.default
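
For orientation, the partition this pattern anchors is a lone aten.add.Tensor whose inputs and output the quantizer then surrounds with quantize/dequantize pairs. A hedged sketch of the post-annotation graph region the fusion pass consumes — the quantized_decomposed calls and all scale/zero-point values here are illustrative assumptions:

import torch

def add_region_example(xq: torch.Tensor, yq: torch.Tensor) -> torch.Tensor:
    # dequantize -> add -> quantize: the chain that fuses into quantized_add.
    x = torch.ops.quantized_decomposed.dequantize_per_tensor(
        xq, 0.05, 0, -128, 127, torch.int8
    )
    y = torch.ops.quantized_decomposed.dequantize_per_tensor(
        yq, 0.08, 2, -128, 127, torch.int8
    )
    z = x + y  # both args are fx.Nodes and there is no alpha kwarg, so it matches
    return torch.ops.quantized_decomposed.quantize_per_tensor(
        z, 0.1, -3, -128, 127, torch.int8
    )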


class BmmPattern(QuantizationPattern):
    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.bmm.default]
15 changes: 14 additions & 1 deletion backends/cadence/aot/quantizer/quantizer.py
@@ -12,6 +12,7 @@
import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    AddPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
@@ -109,7 +110,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                continue

            anchors = self.pattern.get_anchors(model, fused_partition)
-            if not anchors:
+            if not anchors or anchors.empty:
                continue
            if is_annotated(
                [
@@ -211,3 +212,15 @@ def __init__(
        self,
    ) -> None:
        super().__init__([])


class CadenceWakeWordQuantizer(CadenceQuantizer):
    """
    Quantizer for WakeWord, including add
    """

    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
        if quantizers is None:
            quantizers = get_cadence_default_quantizers()
        quantizers.append(CadenceAtenQuantizer(AddPattern(), qconfig_A8uW8u))
        super().__init__(quantizers)
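
A usage sketch, assuming the standard PT2E export-prepare-convert flow; the TinyAdd model and calibration inputs are hypothetical, and only CadenceWakeWordQuantizer comes from this diff:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

class TinyAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y  # aten.add.Tensor, annotated via AddPattern

example_inputs = (torch.randn(1, 16), torch.randn(1, 16))
gm = torch.export.export(TinyAdd(), example_inputs).module()
quantizer = CadenceWakeWordQuantizer()
prepared = prepare_pt2e(gm, quantizer)
prepared(*example_inputs)  # calibrate the observers
converted = convert_pt2e(prepared)  # q/dq pairs ready for the fusion pass
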
4 changes: 4 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -1839,6 +1839,10 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
    replaced_scalar_args: dict[
        EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
    ] = {
        exir_ops.edge.cadence.quantized_add: (
            exir_ops.edge.cadence.quantized_add.per_tensor,
            [1, 2, 4, 5],
        ),
        exir_ops.edge.cadence.quantized_conv: (
            exir_ops.edge.cadence.quantized_conv.per_tensor,
            [8, 9, 12, 13],
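The [1, 2, 4, 5] entry lists the quantized_add argument positions holding single-element tensors (X_scale, X_zero_point, Y_scale, Y_zero_point). The before/after below illustrates the rewrite this table drives, with made-up values rather than actual pass output:

# Before the pass (scales/zero points are aten.full single-element tensors):
#   cadence.quantized_add(X, full([1], 0.05), full([1], 0),
#                         Y, full([1], 0.08), full([1], 2), 0.1, -3)
# After the pass (args 1, 2, 4, 5 folded to scalars):
#   cadence.quantized_add.per_tensor(X, 0.05, 0, Y, 0.08, 2, 0.1, -3)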
