diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 5fb17c79a65b7..a1f05e4548a36 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -206,6 +206,20 @@ Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); MLIRContext *context = binder.op->getContext(); + auto loc = binder.getLoc(); + for (int64_t i = 0; i < 2; i++) { + Value selectIndex = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + loc, selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + + Value constantFalse = rewriter.create(loc, false); + + rewriter.create( + loc, constantFalse, rewriter.getStringAttr("Expected a scalar value")); + } for (int i = 2; i < sizes[0]; i++) { Value selectIndex = rewriter.create( binder.getLoc(), rewriter.getType(),