Skip to content

Commit

Permalink
[torch.bind_symbolic_shape] Fix verifier for shapeSymbol detection (#…
Browse files Browse the repository at this point in the history
…3751)

The op can be valid with no attached shape symbols if they are not
required by the corresponding affine map. Fix the verifier to consider
number of arguments for both.
  • Loading branch information
meshtag authored Oct 2, 2024
1 parent b1413a6 commit 617c1c7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 5 additions & 2 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5405,8 +5405,11 @@ void BindSymbolicShapeOp::print(OpAsmPrinter &p) {
}

LogicalResult BindSymbolicShapeOp::verify() {
if (getShapeSymbols().empty())
return emitOpError() << "requires non-empty shapeSymbols";
if (getShapeSymbols().size() !=
getShapeExpressions().getValue().getNumSymbols())
return emitOpError()
<< "requires equal number of shape symbol args and symbol args to "
"the attached affine map, since they are 1:1 mapped";

for (auto symbol : getShapeSymbols()) {
Operation *definingOp = symbol.getDefiningOp();
Expand Down
10 changes: 9 additions & 1 deletion test/Dialect/Torch/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -381,13 +381,21 @@ func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345>

func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int
// expected-error @+1 {{op requires non-empty shapeSymbols}}
// expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}}
torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}

// -----

// Verifier should not fail here since the op does not require shapeSymbols.
func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32>
return %arg0 : !torch.vtensor<[?],f32>
}

// -----

func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
%int0 = torch.constant.int 0
// expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}}
Expand Down

0 comments on commit 617c1c7

Please sign in to comment.