
Non-value semantic ops should have stronger verifiers #2490

Closed
stellaraccident opened this issue Sep 27, 2023 · 5 comments

@stellaraccident
Collaborator

stellaraccident commented Sep 27, 2023

Currently, in-place ops are defined like this:

def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenTanh_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

The result is that if a !torch.vtensor is inadvertently passed, bad things happen later during lowering. We should tighten up verification by one of:

  1. Making the IsTrailingUnderscoreInplaceVariant trait perform stronger verification of the first argument/result type.
  2. Teaching torch_ods_gen.py to emit the proper, more constrained types (vs the generic AnyTorchTensorType).

Either would produce the correct result, but option 2 is more explicit and provides better documentation value. Doing that would require some minor refactoring of the raw_emit_op helper in torch_ods_gen.py so that it tweaks the types when the trait is present.

On the other hand, implementing option 1 would allow us to emit a much friendlier error message, which may help for exotic cases.
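
For illustration, here is a minimal Python sketch of what option 2 could look like inside torch_ods_gen.py. The helper and argument names below are hypothetical; only the IsTrailingUnderscoreInplaceVariant trait and the Torch_NonValueTensorType constraint come from this issue and the fix that followed.

# Hypothetical sketch of option 2: when an op carries the
# IsTrailingUnderscoreInplaceVariant trait, emit the more constrained
# non-value tensor type instead of the generic AnyTorchTensorType.
INPLACE_TRAIT = "IsTrailingUnderscoreInplaceVariant"

def constrain_type_for_inplace(ods_type, traits):
    """Swap the generic tensor constraint for the non-value variant."""
    if INPLACE_TRAIT in traits and ods_type == "AnyTorchTensorType":
        return "Torch_NonValueTensorType"
    return ods_type

def emit_arguments(arg_names, arg_types, traits):
    """Render a `let arguments = (ins ...)` clause with constrained types."""
    ins = ",\n    ".join(
        f"{constrain_type_for_inplace(t, traits)}:${n}"
        for n, t in zip(arg_names, arg_types)
    )
    return f"let arguments = (ins\n    {ins}\n  );"

# For aten.tanh_ this would emit Torch_NonValueTensorType:$self
# instead of AnyTorchTensorType:$self.
print(emit_arguments(["self"], ["AnyTorchTensorType"], [INPLACE_TRAIT]))

This is essentially the direction the commit referenced later in this thread took.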

@gptsarthak
Contributor

I wanted to start contributing to torch-mlir and I think this would be a good beginner issue. I have experience with some very small contributions to upstream MLIR.

From what I understand,

  1. This would require us to set let hasVerifier = 1 in the definition in GeneratedTorchOps.td and create a LogicalResult AtenTanh_Op::verify() function, but where? I don't know, since the ops are generated, so we do not have them defined in TorchOps.cpp. I got this idea from looking at Tosa.

  2. Assuming that Tanh can only accept floating-point tensors, this would be as easy as replacing AnyTorchTensorType with something like Torch_FloatType.

How would you like me to proceed? What else can we do?

@stellaraccident
Collaborator Author

stellaraccident commented Oct 17, 2023

It would be great to get a patch for this.

The primary complexity here is that the code in that TD file is generated from the PyTorch op registry, so you have to teach torch_ods_gen.py how to do it.

I'm afk, but I believe there is a TableGen type for value tensors vs non-value tensors. You probably don't need to further constrain these by dtype.

@gptsarthak
Contributor

Have a look at #2519 when you have time. Thank you.

stellaraccident pushed a commit that referenced this issue Oct 21, 2023
Attempt to solve #2490

Changes for non-value-semantic ops having the `IsTrailingUnderscoreInplaceVariant` trait:
- AnyTorchTensorType -> Torch_NonValueTensorType
- AnyTorchOptionalTensorType -> AnyTorchOptionalNonValueTensorType
- AnyTorchListOfOptionalTensorType -> AnyTorchListOfOptionalNonValueTensorType
- AnyTorchListOfTensorType -> AnyTorchListOfNonValueTensorType

Created three new tensor types for optional and list non-value tensors.
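
To make the substitutions above concrete, here is a small Python sketch of how such a mapping could be applied in torch_ods_gen.py. The type names are taken from the commit message; the dictionary and helper are illustrative only and are not the actual code from #2519.

# Illustrative only: the type names below come from the commit message,
# but this helper is a sketch, not the code that landed in #2519.
NON_VALUE_TYPE_MAP = {
    "AnyTorchTensorType": "Torch_NonValueTensorType",
    "AnyTorchOptionalTensorType": "AnyTorchOptionalNonValueTensorType",
    "AnyTorchListOfOptionalTensorType": "AnyTorchListOfOptionalNonValueTensorType",
    "AnyTorchListOfTensorType": "AnyTorchListOfNonValueTensorType",
}

def to_non_value_type(ods_type):
    """Map a generic tensor constraint to its non-value (in-place) counterpart."""
    return NON_VALUE_TYPE_MAP.get(ods_type, ods_type)
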
@ramiro050
Collaborator

Can we close this, @stellaraccident @gptsarthak?

@stellaraccident
Collaborator Author

Yep. Looks fixed
