Skip to content

Commit

Permalink
Fix unsqueeze optimize pass
Browse files Browse the repository at this point in the history
Differential Revision: D69812661

Pull Request resolved: #8564
  • Loading branch information
daseyb authored Feb 19, 2025
1 parent b6ffe1a commit e5dc18a
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion backends/transforms/view_copy_to_squeeze_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def find_unsqueeze_dim(
j = 0
idx = -1
while j < len(view_shape):
if input_shape[i] != view_shape[j]:
# account for added dim being last dim in view_shape
if i == j and j == len(input_shape):
if view_shape[j] != 1:
return None
elif input_shape[i] != view_shape[j]:
if view_shape[j] == 1:
idx = j
i -= 1
Expand Down

0 comments on commit e5dc18a

Please sign in to comment.