Skip to content

Commit

Permalink
fix the expand_copy lower issue
Browse files Browse the repository at this point in the history
Differential Revision: D69470884

Pull Request resolved: #8380
  • Loading branch information
billmguo authored Feb 12, 2025
1 parent 994d94d commit 73bb1f9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion backends/qualcomm/_passes/convert_bmm_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _get_ordered_inputs(
def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
partitions = get_source_partitions(
graph, [operator.matmul, torch.matmul, torch.bmm]
graph,
[operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/convert_to_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.No
return ret

def _convert(self, graph_module: torch.fx.GraphModule):
partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear])
partitions = get_source_partitions(
graph_module.graph, [torch.nn.Linear, torch.ops.aten.linear.default]
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
op_cnt = Counter(
Expand Down

0 comments on commit 73bb1f9

Please sign in to comment.