Skip to content

Commit

Permalink
Add more descriptive error message to torch_ods_gen.py. (#3108)
Browse files Browse the repository at this point in the history
Added error message when adding new torch op to
[torch_ods_gen.py](https://github.com/llvm/torch-mlir/compare/main...IanWood1:torch-mlir:ods_gen_error_message?expand=1#diff-889b60b904ed67a5065a14e8de6fc89e00e199577e4d2bfa134ac4d1c89832d2).


New message displays which op key is failing and possible matches in the
torch `Registry`.
```Op does not match any Torch ops in Registry 
Given op: 
    "aten::hardtanh_wrong : (Tensor, Scalar) -> (Tensor)" 
Possible matches: 
    "aten::hardshrink : (Tensor, Scalar) -> (Tensor)" 
    "aten::hardtanh_ : (Tensor, Scalar, Scalar) -> (Tensor)" 
    "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)"
    "aten::clamp_min : (Tensor, Scalar) -> (Tensor)" 
    "aten::linalg_cond : (Tensor, Scalar?) -> (Tensor)"```



Also, ran black formatting on file. Based on LLVM style guides this seems to be correct, but I can revert the formatting if needed.
  • Loading branch information
IanWood1 authored Apr 9, 2024
1 parent 8d5e257 commit 8ff2852
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Expand All @@ -9,6 +10,7 @@

import io
import itertools
import difflib

from .utils import TextEmitter

Expand Down Expand Up @@ -407,6 +409,18 @@ def get_by_triple(self, key: Tuple[str, str, str]):
"""Looks up a JitOperator by its unique "triple"."""
return self.by_triple[key]

def assert_key_in_registry(self, key: str):
if key in self.by_unique_key:
return
print(
f'ERROR: Op does not match any Torch ops in Registry\nGiven op:\n\t"{key}"'
)
matches = difflib.get_close_matches(key, self.by_unique_key.keys(), n=5)
if len(matches):
print("Possible matches:")
print("\n".join(f'\t"{match}"' for match in matches))
exit(1)


# A Dict[str, _] mapping attribute names to:
# - str (e.g. {'name': 'dim'} )
Expand All @@ -420,3 +434,4 @@ def get_by_triple(self, key: Tuple[str, str, str]):
# - Tuple[str] (e.g. {'name': ('aten::size', 'int')} )
# - SIGLIST_TYPE (e.g. {'arguments': [...], 'returns': [...]} )
OP_INFO_DICT = Dict[str, Union[bool, Tuple[str], SIGLIST_TYPE]]

Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,11 @@ def emit_op(operator: JitOperator,

def emit_ops(emitter_td: TextEmitter, registry: Registry):
def emit(key, **kwargs):
registry.assert_key_in_registry(key)
emit_op(registry[key], emitter_td, **kwargs)

def emit_with_mutating_variants(key, **kwargs):
registry.assert_key_in_registry(key)
operator = registry[key]
emit_op(operator, emitter_td, **kwargs)
ns, unqual, overload = operator.triple
Expand Down

0 comments on commit 8ff2852

Please sign in to comment.