Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ONNX] simplify shapes fed to broadcast in Expand lowering (#3756)
Addresses ~200 onnx model compile failures in <https://github.com/nod-ai/SHARK-TestSuite> related to <iree-org/iree#18631>. This change simplifies the result of the generated broadcast op substantially, but reduces the case coverage slightly. The case which will become unsupported: - trying to actually broadcast a dynamic dim that is secretly 1. When does this case appear in practical scenarios? - for a model where onnx shape inference cannot figure out that a dim should be 1. Why do I think we should not support this case for now? 1. For all models with dynamic dim expand ops, the previous path uniformly generates uglier linalg IR (making it harder for IREE to fuse properly with other ops). 2. For models failing shape inference castastrophically enough to fail to see a dim is statically 1, we can try to apply constant folding in the onnx model before importing. Leaving this as a draft PR, since it may be more appropriate to fix the compilation failure in IREE rather than torch-mlir. ### Example of broadcast required in previous path: ```mlir %300 = linalg.generic {indexing_maps = [#map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} outs(%299 : tensor<?x12x?x?xi1>) { ^bb0(%out: i1): %306 = linalg.index 0 : index %307 = linalg.index 3 : index %308 = arith.index_cast %285 : i64 to index %309 = arith.cmpi eq, %308, %c1 : index %310 = arith.select %309, %c0, %306 : index %311 = arith.index_cast %286 : i64 to index %312 = arith.cmpi eq, %311, %c1 : index %313 = arith.select %312, %c0, %307 : index %extracted_79 = tensor.extract %reshape_78[%310, %c0, %c0, %313] : tensor<?x1x1x?xi1> linalg.yield %extracted_79 : i1 } -> tensor<?x12x?x?xi1> ``` ### Example of broadcast with simplified shape list: ```mlir %409 = linalg.generic {indexing_maps = [#map15, #map11], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%reshape_135 : tensor<?x1x1x?xi1>) outs(%408 : tensor<?x12x?x?xi1>) { ^bb0(%in: i1, %out: i1): linalg.yield %in : i1 } -> tensor<?x12x?x?xi1> ```
- Loading branch information