diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index d5c8adf35f00..a61f041d8263 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -2521,7 +2521,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
           return failure();
 
         auto shapeSizes = shapeType.getSizes();
-        int64_t dataRank = dataType.getSizes().size();
+        ArrayRef<int64_t> dataShape = dataType.getSizes();
+        int64_t dataRank = dataShape.size();
         int64_t shapeRank = shapeSizes.size();
         if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize)
           return failure();
@@ -2543,22 +2544,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         // we are using torch implementation Torch::AtenBroadcastToOp which
         // takes list of int
         for (int i = 0; i < shapeSizes[0]; i++) {
+          // extract dim from shape
           Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getType<Torch::IntType>(),
               rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
           Value extract = rewriter.create<Torch::AtenSelectIntOp>(
               loc, selectResultType, shape, zero, selectIndex);
-          Value dim = rewriter.create<Torch::AtenItemOp>(
+          Value selectDim = rewriter.create<Torch::AtenItemOp>(
               loc, rewriter.getType<Torch::IntType>(), extract);
-
-          if (i + rankDifference >= 0) {
+          // compute dim to pass to broadcast op. For non-broadcastable dims,
+          // pass -1
+          Value dim;
+          if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) {
+            // 1. if dataShape[i + rankDiff] > 1, then this dim cannot be
+            // broadcast
+            // 2. we explicitly disallow broadcasting dynamic dims that are
+            // secretly 1.
+            dim = rewriter.create<Torch::ConstantIntOp>(loc, -1);
+            // Assert dataShape[i + rankDiff] >= selectDim. If both are
+            // constant, this should fold out.
             Value iv =
                 rewriter.create<Torch::ConstantIntOp>(loc, i + rankDifference);
             auto sz = rewriter.create<Torch::AtenSizeIntOp>(
                 loc, rewriter.getType<Torch::IntType>(), data, iv);
-            dim = rewriter.create<Torch::PrimMaxIntOp>(loc, dim, sz);
+            Value geSelect =
+                rewriter.create<Torch::AtenGeIntOp>(loc, sz, selectDim);
+            rewriter.create<Torch::RuntimeAssertOp>(
+                loc, geSelect,
+                rewriter.getStringAttr(
+                    "onnx.Expand input has a dim that is not statically 1; "
+                    "expected this dim >= the dim provided by shape."));
+          } else {
+            // 1. excess selectDims get included in broadcast (shapeSizes[0] >
+            // dataRank)
+            // 2. selectDims which correspond to dataShape == 1 get included in
+            // broadcast
+            dim = selectDim;
           }
-
           dimList.push_back(dim);
         }
         Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
index fc0d488b4787..a6e42e278757 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
@@ -42,7 +42,7 @@ def import_onnx(contents):
     # Import the ONNX model proto from the file contents:
     raw_model = onnx.load_from_string(contents)
-    # since it does not affect current e2e tests, data_prop is left false here
-    model_proto = onnx.shape_inference.infer_shapes(raw_model)
+    # data_prop=True propagates values of shape-computing ops during inference
+    model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True)
 
     # Import the ONNX module into an MLIR module:
     context = Context()
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
index d3672941acdb..d9c2df1d83a0 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1608,16 +1608,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor
   // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0
   // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
   // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int
-  // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0
-  // CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
-  // CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int
   // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
   // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32>
   // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int
-  // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1
-  // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
-  // CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
-  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1
+  // CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list<int>
   // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list<int> -> !torch.vtensor<[3,4],f32>
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32>
   return %0 : !torch.vtensor<[3,4],f32>
@@ -1634,16 +1631,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor
   // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1
   // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]]
   // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]]
+  // CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1
   // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0
   // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]]
-  // CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int
+  // CHECK-NEXT: torch.runtime.assert %[[GE]]
   // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2
   // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]]
   // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]]
-  // CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1
-  // CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]]
-  // CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]]
-  // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]]
+  // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]]
   // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]]
   // CHECK: return %[[EXPAND]]
   %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32>
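
The -1 entries passed to torch.aten.broadcast_to above rely on torch's expand
semantics, where -1 means "keep this dimension's size". A minimal PyTorch
sketch (illustrative shapes, not part of the patch) mirroring the
[3,1] -> [2,3,6] case from test_expand_dim2_shape3:

    import torch

    # data has shape [3,1]; onnx.Expand requests shape [2,3,6].
    # Dim 3 is not statically 1, so the lowering passes -1 ("keep this dim")
    # and emits a runtime assert that 3 >= the requested size; the size-1 dim
    # and the new leading dim broadcast as usual.
    x = torch.rand(3, 1)
    y = torch.broadcast_to(x, (2, -1, 6))  # equivalent to x.expand(2, -1, 6)
    assert y.shape == (2, 3, 6)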
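On the e2e-config side, a small self-contained sketch (hypothetical graph,
real onnx APIs) of what data_prop=True changes: infer_shapes additionally
propagates values computed by shape-producing ops such as Shape, so an Expand
whose target shape is computed rather than constant can still receive a
static result type:

    import onnx
    from onnx import TensorProto, helper

    # y = Expand(x, Shape(t)): the target shape is a computed value. With
    # data_prop=True, inference propagates Shape(t) == [3, 4] into Expand and
    # should type y as float[3,4]; without it, y's dims stay unknown.
    nodes = [
        helper.make_node("Shape", ["t"], ["shape"]),
        helper.make_node("Expand", ["x", "shape"], ["y"]),
    ]
    graph = helper.make_graph(
        nodes,
        "expand_example",
        [
            helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]),
            helper.make_tensor_value_info("t", TensorProto.FLOAT, [3, 4]),
        ],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, None)],
    )
    model = helper.make_model(graph)
    inferred = onnx.shape_inference.infer_shapes(model, data_prop=True)
    print(inferred.graph.output[0])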