From 617c1c76ce4d0410e2318dbd25d69c68db45388c Mon Sep 17 00:00:00 2001 From: Prathamesh Tagore <63031630+meshtag@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:25:54 +0530 Subject: [PATCH] [torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#3751) The op can be valid with no attached shape symbols if they are not required by the corresponding affine map. Fix the verifier to consider number of arguments for both. --- lib/Dialect/Torch/IR/TorchOps.cpp | 7 +++++-- test/Dialect/Torch/invalid.mlir | 10 +++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index bed228671de1..e10564bbe26b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) { } LogicalResult BindSymbolicShapeOp::verify() { - if (getShapeSymbols().empty()) - return emitOpError() << "requires non-empty shapeSymbols"; + if (getShapeSymbols().size() != + getShapeExpressions().getValue().getNumSymbols()) + return emitOpError() + << "requires equal number of shape symbol args and symbol args to " + "the attached affine map, since they are 1:1 mapped"; for (auto symbol : getShapeSymbols()) { Operation *definingOp = symbol.getDefiningOp(); diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 5b732788faef..8f38c66ad154 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int - // expected-error @+1 {{op requires non-empty shapeSymbols}} + // expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}} torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> return %arg0 : !torch.vtensor<[?],f32> } // ----- +// Verifier should not fail here since the op does not require shapeSymbols. +func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { %int0 = torch.constant.int 0 // expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}