Skip to content

Commit

Permalink
wip - fix(ONNX): adds guard against mismatched dynamic meta dimension…
Browse files Browse the repository at this point in the history
…s in getValueList
  • Loading branch information
bjacobgordon committed Jan 7, 2025
1 parent bf594b0 commit 0ec6f6e
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter,
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

MLIRContext *context = binder.op->getContext();
auto loc = binder.getLoc();
for (int64_t i = 0; i < 2; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
loc, selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);

Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);

rewriter.create<Torch::RuntimeAssertOp>(
loc, constantFalse, rewriter.getStringAttr("Expected a scalar value"));
}
for (int i = 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
Expand Down

0 comments on commit 0ec6f6e

Please sign in to comment.