diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index bed228671de1..e10564bbe26b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) { } LogicalResult BindSymbolicShapeOp::verify() { - if (getShapeSymbols().empty()) - return emitOpError() << "requires non-empty shapeSymbols"; + if (getShapeSymbols().size() != + getShapeExpressions().getValue().getNumSymbols()) + return emitOpError() + << "requires equal number of shape symbol args and symbol args to " + "the attached affine map, since they are 1:1 mapped"; for (auto symbol : getShapeSymbols()) { Operation *definingOp = symbol.getDefiningOp(); diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 5b732788faef..8f38c66ad154 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int - // expected-error @+1 {{op requires non-empty shapeSymbols}} + // expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}} torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> return %arg0 : !torch.vtensor<[?],f32> } // ----- +// Verifier should not fail here since the op does not require shapeSymbols. +func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { %int0 = torch.constant.int 0 // expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}