
torch_mlir.tools.import_onnx misses shape inference #3593

Closed
jinchen62 opened this issue Aug 5, 2024 · 8 comments

Comments

@jinchen62
Collaborator

Repro:
https://gist.github.com/jinchen62/554020991f814bfa3b6ee61287a78d49

Command:
torch-mlir-opt -pass-pipeline='builtin.module(func.func(convert-torch-onnx-to-torch),torch-lower-to-backend-contract,func.func(cse,canonicalize))' repro_transpose.mlir > repro_transpose_torch.mlir

Error:

repro_transpose.mlir:6:12: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
%241 = torch.operator "onnx.Transpose"(%240) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64]} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>
^
repro_transpose.mlir:6:12: note: see current operation: %8 = "torch.operator"(%7) <{name = "onnx.Transpose"}> {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64]} : (!torch.vtensor<[],f32>) -> !torch.vtensor<[],f32>

@jinchen62
Collaborator Author

This also seems like an onnx.Identity issue to me. Is the result of onnx.Reshape supposed to have a specific shape? Or should we add support for unknown shapes to onnx.Transpose?

@rsuderman
Contributor

Where did you get this IR from? The shapes on the onnx.Transpose operation are incorrect: they specify a rank-0 tensor where they should be rank 3. onnx.Reshape is not returning the correct shape.

@jinchen62
Collaborator Author

@rsuderman It's from one of the models imported from ONNX, and a lot of models have the same problem. So is it an onnx.Identity issue? Because in the same model I saw

%237 = torch.operator "onnx.Reshape"(%236#0, %73) : (!torch.vtensor<[1,256,16,16],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,64,256],f32>
%238 = torch.operator "onnx.Transpose"(%237) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64]} : (!torch.vtensor<[4,64,256],f32>) -> !torch.vtensor<[4,256,64],f32>

where %73 is the %0 constant in the repro IR.

@ScottTodd
Member

Is a verifier missing? I'd expect you could compile with --verify (or without --verify=false) to have the compiler stop as soon as the onnx.Transpose with invalid shapes is constructed.

@zjgarvey
Collaborator

zjgarvey commented Aug 5, 2024

This looks like a shape inference issue to me. Are you using the torch-mlir import_onnx?

@jinchen62
Collaborator Author

Using the onnx_importer from the Python bindings, I got dynamic shapes instead of missing shapes. The IR looks like

%237 = torch.operator "onnx.Reshape"(%236#0, %73) : (!torch.vtensor<[1,256,16,16],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,64,256],f32>
%238 = torch.operator "onnx.Transpose"(%237) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64]} : (!torch.vtensor<[4,64,256],f32>) -> !torch.vtensor<[4,256,64],f32>
%239 = torch.operator "onnx.Reshape"(%236#1, %182) : (!torch.vtensor<[1,256,16,16],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32>
%240 = torch.operator "onnx.Reshape"(%236#2, %181) : (!torch.vtensor<[1,256,16,16],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32>
%241 = torch.operator "onnx.Transpose"(%240) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64]} : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32>

@jinchen62 jinchen62 changed the title Failed to lower onnx.Transpose ONNX models imported by torch_mlir.tools.import_onnx miss shape inference Aug 9, 2024
@jinchen62
Collaborator Author

iree-org/iree#18153

@jinchen62 jinchen62 changed the title ONNX models imported by torch_mlir.tools.import_onnx miss shape inference torch_mlir.tools.import_onnx misses shape inference Aug 9, 2024
@jinchen62
Collaborator Author

Setting the opset version to 21 fixes the issue.
