From be9849d774f536a23808a4c9a3cbbd8cd7b755ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Mon, 5 Jun 2023 10:35:57 +0200 Subject: [PATCH 01/54] stuff --- src/gt4py/next/errors/__init__.py | 1 + src/gt4py/next/errors/exceptions.py | 31 +++++++++ src/gt4py/next/ffront/dialect_parser.py | 65 ++++++------------- src/gt4py/next/ffront/func_to_foast.py | 50 +++++++------- src/gt4py/next/ffront/func_to_past.py | 24 +++---- .../ffront_tests/ffront_test_utils.py | 5 +- .../errors_tests/test_compilation_error.py | 15 +++++ 7 files changed, 107 insertions(+), 84 deletions(-) create mode 100644 src/gt4py/next/errors/__init__.py create mode 100644 src/gt4py/next/errors/exceptions.py create mode 100644 tests/next_tests/unit_tests/errors_tests/test_compilation_error.py diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py new file mode 100644 index 0000000000..431e0c5229 --- /dev/null +++ b/src/gt4py/next/errors/__init__.py @@ -0,0 +1 @@ +from .exceptions import CompilationError \ No newline at end of file diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py new file mode 100644 index 0000000000..5c58c6cdd0 --- /dev/null +++ b/src/gt4py/next/errors/exceptions.py @@ -0,0 +1,31 @@ +from gt4py.eve import SourceLocation + + +class CompilationError(SyntaxError): + def __init__(self, location: SourceLocation, message: str): + super().__init__( + message, + ( + location.source, + location.line, + location.column, + None, + location.end_line, + location.end_column + ) + ) + + @property + def location(self): + return SourceLocation( + source=self.filename, + line=self.lineno, + column=self.offset, + end_line=self.end_lineno, + end_column=self.end_offset + ) + + +class UndefinedSymbolError(CompilationError): + def __init__(self, location: SourceLocation, name: str): + super().__init__(location, f"name '{name}' is not defined") \ No newline at end of file diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 97afa0f7da..c23a93d0c0 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -107,51 +107,24 @@ def apply( annotations: dict[str, Any], ) -> DialectRootT: source, filename, starting_line = source_definition - try: - line_offset = starting_line - 1 - definition_ast: ast.AST - definition_ast = ast.parse(textwrap.dedent(source)).body[0] - definition_ast = ast.increment_lineno(definition_ast, line_offset) - line_offset = 0 # line numbers are correct from now on - - definition_ast = RemoveDocstrings.apply(definition_ast) - definition_ast = FixMissingLocations.apply(definition_ast) - output_ast = cls._postprocess_dialect_ast( - cls( - source_definition=source_definition, - closure_vars=closure_vars, - annotations=annotations, - ).visit(cls._preprocess_definition_ast(definition_ast)), - closure_vars, - annotations, - ) - except SyntaxError as err: - _ensure_syntax_error_invariants(err) - - # The ast nodes do not contain information about the path of the - # source file or its contents. We add this information here so - # that raising an error using :func:`DialectSyntaxError.from_AST` - # does not require passing the information on every invocation. - if not err.filename: - err.filename = filename - - # ensure line numbers are relative to the file (respects `starting_line`) - if err.lineno: - err.lineno = err.lineno + line_offset - if err.end_lineno: - err.end_lineno = err.end_lineno + line_offset - - if not err.text: - if err.lineno: - source_lineno = err.lineno - starting_line - source_end_lineno = ( - (err.end_lineno - starting_line) if err.end_lineno else source_lineno - ) - err.text = "\n".join(source.splitlines()[source_lineno : source_end_lineno + 1]) - else: - err.text = source - - raise err + + line_offset = starting_line - 1 + definition_ast: ast.AST + definition_ast = ast.parse(textwrap.dedent(source)).body[0] + definition_ast = ast.increment_lineno(definition_ast, line_offset) + line_offset = 0 # line numbers are correct from now on + + definition_ast = RemoveDocstrings.apply(definition_ast) + definition_ast = FixMissingLocations.apply(definition_ast) + output_ast = cls._postprocess_dialect_ast( + cls( + source_definition=source_definition, + closure_vars=closure_vars, + annotations=annotations, + ).visit(cls._preprocess_definition_ast(definition_ast)), + closure_vars, + annotations, + ) return output_ast @@ -178,5 +151,5 @@ def generic_visit(self, node: ast.AST) -> None: msg=f"Nodes of type {type(node).__module__}.{type(node).__qualname__} not supported in dialect.", ) - def _make_loc(self, node: ast.AST) -> SourceLocation: + def get_location(self, node: ast.AST) -> SourceLocation: return SourceLocation.from_AST(node, source=self.source_definition.filename) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 03f1187a7a..dd97ffa259 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -151,7 +151,7 @@ def _builtin_type_constructor_symbols( def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDefinition: closure_var_symbols, skip_names = self._builtin_type_constructor_symbols( - self.closure_vars, self._make_loc(node) + self.closure_vars, self.get_location(node) ) for name in self.closure_vars.keys(): if name in skip_names: @@ -161,11 +161,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe id=name, type=ts.DeferredType(constraint=None), namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self._make_loc(node), + location=self.get_location(node), ) ) - new_body = self._visit_stmts(node.body, self._make_loc(node), **kwargs) + new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: raise FieldOperatorSyntaxError.from_AST( @@ -177,7 +177,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe params=self.visit(node.args, **kwargs), body=new_body, closure_vars=closure_var_symbols, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: @@ -191,7 +191,7 @@ def visit_arg(self, node: ast.arg) -> foast.DataSymbol: raise FieldOperatorSyntaxError.from_AST( node, msg="Only arguments of type DataType are allowed." ) - return foast.DataSymbol(id=node.arg, location=self._make_loc(node), type=new_type) + return foast.DataSymbol(id=node.arg, location=self.get_location(node), type=new_type) def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.TupleTargetAssign: target = node.targets[0] # there is only one element after assignment passes @@ -207,10 +207,10 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple foast.Starred( id=foast.DataSymbol( id=self.visit(elt.value).id, - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ), - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ) ) @@ -218,13 +218,13 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple new_targets.append( foast.DataSymbol( id=self.visit(elt).id, - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ) ) return foast.TupleTargetAssign( - targets=new_targets, value=self.visit(node.value), location=self._make_loc(node) + targets=new_targets, value=self.visit(node.value), location=self.get_location(node) ) if not isinstance(target, ast.Name): @@ -241,11 +241,11 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple return foast.Assign( target=foast.DataSymbol( id=target.id, - location=self._make_loc(target), + location=self.get_location(target), type=ts.DeferredType(constraint=constraint_type), ), value=new_value, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: @@ -265,11 +265,11 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: return foast.Assign( target=foast.Symbol[ts.FieldType]( id=node.target.id, - location=self._make_loc(node.target), + location=self.get_location(node.target), type=target_type, ), value=self.visit(node.value) if node.value else None, - location=self._make_loc(node), + location=self.get_location(node), ) @staticmethod @@ -298,33 +298,33 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: return foast.Subscript( value=self.visit(node.value), index=index, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Attribute(self, node: ast.Attribute) -> Any: return foast.Attribute( - value=self.visit(node.value), attr=node.attr, location=self._make_loc(node) + value=self.visit(node.value), attr=node.attr, location=self.get_location(node) ) def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: return foast.TupleExpr( - elts=[self.visit(item) for item in node.elts], location=self._make_loc(node) + elts=[self.visit(item) for item in node.elts], location=self.get_location(node) ) def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: if not node.value: raise FieldOperatorSyntaxError.from_AST(node, msg="Empty return not allowed") - return foast.Return(value=self.visit(node.value), location=self._make_loc(node)) + return foast.Return(value=self.visit(node.value), location=self.get_location(node)) def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs) -> foast.Name: - return foast.Name(id=node.id, location=self._make_loc(node)) + return foast.Name(id=node.id, location=self.get_location(node)) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs) -> foast.UnaryOp: return foast.UnaryOp( - op=self.visit(node.op), operand=self.visit(node.operand), location=self._make_loc(node) + op=self.visit(node.op), operand=self.visit(node.operand), location=self.get_location(node) ) def visit_UAdd(self, node: ast.UAdd, **kwargs) -> dialect_ast_enums.UnaryOperator: @@ -344,7 +344,7 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs) -> foast.BinOp: op=self.visit(node.op), left=self.visit(node.left), right=self.visit(node.right), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Add(self, node: ast.Add, **kwargs) -> dialect_ast_enums.BinaryOperator: @@ -385,12 +385,12 @@ def visit_IfExp(self, node: ast.IfExp, **kwargs) -> foast.TernaryExpr: condition=self.visit(node.test), true_expr=self.visit(node.body), false_expr=self.visit(node.orelse), - location=self._make_loc(node), + location=self.get_location(node), type=ts.DeferredType(constraint=ts.DataType), ) def visit_If(self, node: ast.If, **kwargs) -> foast.IfStmt: - loc = self._make_loc(node) + loc = self.get_location(node) return foast.IfStmt( condition=self.visit(node.test, **kwargs), true_branch=self._visit_stmts(node.body, loc, **kwargs), @@ -415,7 +415,7 @@ def visit_Compare(self, node: ast.Compare, **kwargs) -> foast.Compare: op=self.visit(node.ops[0]), left=self.visit(node.left), right=self.visit(node.comparators[0]), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Gt(self, node: ast.Gt, **kwargs) -> foast.CompareOperator: @@ -485,7 +485,7 @@ def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: func=self.visit(node.func, **kwargs), args=[self.visit(arg, **kwargs) for arg in node.args], kwargs={keyword.arg: self.visit(keyword.value, **kwargs) for keyword in node.keywords}, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: @@ -498,6 +498,6 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: return foast.Constant( value=node.value, - location=self._make_loc(node), + location=self.get_location(node), type=type_, ) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 0f51ca88c1..30d77ec7a6 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -54,7 +54,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: id=name, type=type_translation.from_value(val), namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self._make_loc(node), + location=self.get_location(node), ) for name, val in self.closure_vars.items() ] @@ -65,7 +65,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: params=self.visit(node.args), body=[self.visit(node) for node in node.body], closure_vars=closure_symbols, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: @@ -79,7 +79,7 @@ def visit_arg(self, node: ast.arg) -> past.DataSymbol: raise ProgramSyntaxError.from_AST( node, msg="Only arguments of type DataType are allowed." ) - return past.DataSymbol(id=node.arg, location=self._make_loc(node), type=new_type) + return past.DataSymbol(id=node.arg, location=self.get_location(node), type=new_type) def visit_Expr(self, node: ast.Expr) -> past.LocatedNode: return self.visit(node.value) @@ -119,17 +119,17 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs) -> past.BinOp: op=self.visit(node.op), left=self.visit(node.left), right=self.visit(node.right), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Name(self, node: ast.Name) -> past.Name: - return past.Name(id=node.id, location=self._make_loc(node)) + return past.Name(id=node.id, location=self.get_location(node)) def visit_Dict(self, node: ast.Dict) -> past.Dict: return past.Dict( keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], values_=[self.visit(param) for param in node.values], - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Call(self, node: ast.Call) -> past.Call: @@ -141,20 +141,20 @@ def visit_Call(self, node: ast.Call) -> past.Call: func=new_func, args=[self.visit(arg) for arg in node.args], kwargs={arg.arg: self.visit(arg.value) for arg in node.keywords}, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Subscript(self, node: ast.Subscript) -> past.Subscript: return past.Subscript( value=self.visit(node.value), slice_=self.visit(node.slice), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Tuple(self, node: ast.Tuple) -> past.TupleExpr: return past.TupleExpr( elts=[self.visit(item) for item in node.elts], - location=self._make_loc(node), + location=self.get_location(node), type=ts.DeferredType(constraint=ts.TupleType), ) @@ -163,17 +163,17 @@ def visit_Slice(self, node: ast.Slice) -> past.Slice: lower=self.visit(node.lower) if node.lower is not None else None, upper=self.visit(node.upper) if node.upper is not None else None, step=self.visit(node.step) if node.step is not None else None, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant( - value=-node.operand.value, type=symbol_type, location=self._make_loc(node) + value=-node.operand.value, type=symbol_type, location=self.get_location(node) ) raise ProgramSyntaxError.from_AST(node, msg="Unary operators can only be used on literals.") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) - return past.Constant(value=node.value, type=symbol_type, location=self._make_loc(node)) + return past.Constant(value=node.value, type=symbol_type, location=self.get_location(node)) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index ab383e37a8..8d81482c73 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -30,7 +30,10 @@ from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip -@pytest.fixture(params=[roundtrip.executor, gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]) +@pytest.fixture(params=[roundtrip.executor, + #gtfn_cpu.run_gtfn, + # gtfn_cpu.run_gtfn_imperative +]) def fieldview_backend(request): yield request.param diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py new file mode 100644 index 0000000000..fff2886e8b --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -0,0 +1,15 @@ +from gt4py.next.errors import CompilationError +from gt4py.eve import SourceLocation + +loc = SourceLocation(5, 1, "/source/file.py", end_line=5, end_column=9) +msg = 'a message' + + +def test_message(): + assert CompilationError(loc, msg).msg == msg + + +def test_location(): + assert CompilationError(loc, msg).location == loc + + From 743042e83d69af4d776eab6438d2e3a95c135e6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Mon, 5 Jun 2023 15:33:12 +0200 Subject: [PATCH 02/54] refactor past and foast parsers --- src/gt4py/next/errors/__init__.py | 11 ++- src/gt4py/next/errors/exceptions.py | 41 +++++++- src/gt4py/next/errors/tools.py | 14 +++ src/gt4py/next/ffront/dialect_parser.py | 99 +++++-------------- .../foast_passes/closure_var_folding.py | 15 +-- src/gt4py/next/ffront/func_to_foast.py | 85 ++++++---------- src/gt4py/next/ffront/func_to_past.py | 30 +++--- src/gt4py/next/ffront/source_utils.py | 25 ++--- .../ffront_tests/test_func_to_foast.py | 22 ++--- .../test_func_to_foast_error_line_number.py | 13 +-- 10 files changed, 164 insertions(+), 191 deletions(-) create mode 100644 src/gt4py/next/errors/tools.py diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 431e0c5229..6db84e4c7a 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -1 +1,10 @@ -from .exceptions import CompilationError \ No newline at end of file +from .exceptions import ( + CompilationError, + UndefinedSymbolError, + UnsupportedPythonFeatureError, + MissingParameterTypeError, + InvalidParameterTypeError, + IncorrectArgumentCountError, + UnexpectedKeywordArgError, + MissingAttributeError +) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 5c58c6cdd0..210eb7284f 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -1,15 +1,20 @@ from gt4py.eve import SourceLocation - +from typing import Any +from . import tools class CompilationError(SyntaxError): def __init__(self, location: SourceLocation, message: str): + try: + source_code = tools.get_code_at_location(location) + except ValueError: + source_code = None super().__init__( message, ( location.source, location.line, location.column, - None, + source_code, location.end_line, location.end_column ) @@ -28,4 +33,34 @@ def location(self): class UndefinedSymbolError(CompilationError): def __init__(self, location: SourceLocation, name: str): - super().__init__(location, f"name '{name}' is not defined") \ No newline at end of file + super().__init__(location, f"name '{name}' is not defined") + + +class UnsupportedPythonFeatureError(CompilationError): + def __init__(self, location: SourceLocation, feature: str): + super().__init__(location, f"unsupported Python syntax: '{feature}'") + + +class MissingParameterTypeError(CompilationError): + def __init__(self, location: SourceLocation, param_name: str): + super().__init__(location, f"parameter '{param_name}' is missing type annotations") + + +class InvalidParameterTypeError(CompilationError): + def __init__(self, location: SourceLocation, param_name: str, type_: Any): + super().__init__(location, f"parameter '{param_name}' has invalid type annotation '{type_}'") + + +class IncorrectArgumentCountError(CompilationError): + def __init__(self, location: SourceLocation, num_expected: int, num_provided: int): + super().__init__(location, f"expected {num_expected} arguments but {num_provided} were provided") + + +class UnexpectedKeywordArgError(CompilationError): + def __init__(self, location: SourceLocation, provided_names: str): + super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") + + +class MissingAttributeError(CompilationError): + def __init__(self, location: SourceLocation, attr_name: str): + super().__init__(location, f"object does not have attribute '{attr_name}'") \ No newline at end of file diff --git a/src/gt4py/next/errors/tools.py b/src/gt4py/next/errors/tools.py new file mode 100644 index 0000000000..7dac7aca3a --- /dev/null +++ b/src/gt4py/next/errors/tools.py @@ -0,0 +1,14 @@ +import pathlib +from gt4py.eve import SourceLocation + +def get_code_at_location(location: SourceLocation): + try: + source_file = pathlib.Path(location.source) + source_code = source_file.read_text() + source_lines = source_code.splitlines(False) + start_line = location.line + end_line = location.end_line + 1 if location.end_line else start_line + 1 + relevant_lines = source_lines[(start_line-1):(end_line-1)] + return "\n".join(relevant_lines) + except Exception as ex: + raise ValueError("failed to get source code for source location") from ex \ No newline at end of file diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index c23a93d0c0..1ff4d0c0b2 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -24,72 +24,22 @@ from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function +from gt4py.next.errors import UnsupportedPythonFeatureError DialectRootT = TypeVar("DialectRootT") -class DialectSyntaxError(common.GTSyntaxError): - dialect_name: ClassVar[str] = "" - - def __init__( - self, - msg="", - *, - lineno: int = 0, - offset: int = 0, - filename: Optional[str] = None, - end_lineno: Optional[int] = None, - end_offset: Optional[int] = None, - text: Optional[str] = None, - ): - msg = f"Invalid {self.dialect_name} Syntax: {msg}" - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_AST( - cls, - node: ast.AST, - *, - msg: str = "", - filename: Optional[str] = None, - text: Optional[str] = None, - ): - return cls( - msg, - lineno=node.lineno, - offset=node.col_offset + 1, # offset is 1-based for syntax errors - filename=filename, - end_lineno=getattr(node, "end_lineno", None), - end_offset=(node.end_col_offset + 1) - if hasattr(node, "end_col_offset") and node.end_col_offset is not None - else None, - text=text, - ) - - @classmethod - def from_location(cls, msg="", *, location: SourceLocation): - return cls( - msg, - lineno=location.line, - offset=location.column, - filename=location.source, - end_lineno=location.end_line, - end_offset=location.end_column, - text=None, - ) - - -def _ensure_syntax_error_invariants(err: SyntaxError): - """Ensure syntax error invariants required to print meaningful error messages.""" - # If offsets are provided so must line numbers. For example `err.offset` determines - # if carets (`^^^^`) are printed below `err.text`. They would be misleading if we - # don't know on which line the error occurs. - assert err.lineno or not err.offset - assert err.end_lineno or not err.end_offset - # If the ends are provided so must starts. - assert err.lineno or not err.end_lineno - assert err.offset or not err.end_offset +def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: + try: + return ast.parse(textwrap.dedent(source_definition.source)).body[0] + except SyntaxError as err: + err.filename = source_definition.filename + err.lineno = err.lineno + source_definition.starting_line if err.lineno is not None else None + err.offset = err.offset + source_definition.starting_column if err.offset is not None else None + err.end_lineno = err.end_lineno + source_definition.starting_line if err.end_lineno is not None else None + err.end_offset = err.end_offset + source_definition.starting_column if err.end_offset is not None else None + raise err @dataclass(frozen=True, kw_only=True) @@ -97,7 +47,6 @@ class DialectParser(ast.NodeVisitor, Generic[DialectRootT]): source_definition: SourceDefinition closure_vars: dict[str, Any] annotations: dict[str, Any] - syntax_error_cls: ClassVar[Type[DialectSyntaxError]] = DialectSyntaxError @classmethod def apply( @@ -106,13 +55,8 @@ def apply( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> DialectRootT: - source, filename, starting_line = source_definition - - line_offset = starting_line - 1 definition_ast: ast.AST - definition_ast = ast.parse(textwrap.dedent(source)).body[0] - definition_ast = ast.increment_lineno(definition_ast, line_offset) - line_offset = 0 # line numbers are correct from now on + definition_ast = parse_source_definition(source_definition) definition_ast = RemoveDocstrings.apply(definition_ast) definition_ast = FixMissingLocations.apply(definition_ast) @@ -146,10 +90,19 @@ def _postprocess_dialect_ast( return output_ast def generic_visit(self, node: ast.AST) -> None: - raise self.syntax_error_cls.from_AST( - node, - msg=f"Nodes of type {type(node).__module__}.{type(node).__qualname__} not supported in dialect.", - ) + loc = self.get_location(node) + feature = f"{type(node).__module__}.{type(node).__qualname__}" + raise UnsupportedPythonFeatureError(loc, feature) def get_location(self, node: ast.AST) -> SourceLocation: - return SourceLocation.from_AST(node, source=self.source_definition.filename) + file = self.source_definition.filename + line_offset = self.source_definition.starting_line + col_offset = self.source_definition.starting_column + + line = node.lineno + line_offset if node.lineno is not None else None + end_line = node.end_lineno + line_offset if node.end_lineno is not None else None + column = 1 + node.col_offset + col_offset if node.col_offset is not None else None + end_column = 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None + + loc = SourceLocation(line, column, file, end_line=end_line, end_column=end_column) + return loc diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 32a77fe155..45a93d74a0 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -18,6 +18,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import FrozenNamespace +from gt4py.next.errors import * @dataclass @@ -50,22 +51,12 @@ def visit_Name( return node def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: - # TODO: fix import form parent module by restructuring exception classis - from gt4py.next.ffront.func_to_foast import FieldOperatorSyntaxError - value = self.visit(node.value, **kwargs) if isinstance(value, foast.Constant): if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) - # TODO: use proper exception type (requires refactoring `FieldOperatorSyntaxError`) - raise FieldOperatorSyntaxError.from_location( - msg="Constant does not have the attribute specified by the AST.", - location=node.location, - ) - # TODO: use proper exception type (requires refactoring `FieldOperatorSyntaxError`) - raise FieldOperatorSyntaxError.from_location( - msg="Attribute can only be used on constants.", location=node.location - ) + raise MissingAttributeError(node.location, node.attr) + raise CompilationError(node.location, "attribute access only applicable to constants") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index dd97ffa259..6bda016826 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -27,7 +27,8 @@ StringifyAnnotationsPass, UnchainComparesPass, ) -from gt4py.next.ffront.dialect_parser import DialectParser, DialectSyntaxError +from gt4py.next.ffront.dialect_parser import DialectParser +from gt4py.next.errors import * from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind from gt4py.next.ffront.foast_passes.closure_var_folding import ClosureVarFolding from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction @@ -37,10 +38,6 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -class FieldOperatorSyntaxError(DialectSyntaxError): - dialect_name = "Field Operator" - - class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): """ Parse field operator function definition from source code into FOAST. @@ -75,13 +72,11 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) - ... except FieldOperatorSyntaxError as err: + ... except CompilationError as err: ... print(f"Error at [{err.lineno}, {err.offset}] in {err.filename})") Error at [2, 5] in ...gt4py.next.ffront.func_to_foast.FieldOperatorParser[...]>) """ - syntax_error_cls = FieldOperatorSyntaxError - @classmethod def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST: sta = StringifyAnnotationsPass.apply(definition_ast) @@ -150,6 +145,7 @@ def _builtin_type_constructor_symbols( return result, to_be_inserted.keys() def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDefinition: + loc = self.get_location(node) closure_var_symbols, skip_names = self._builtin_type_constructor_symbols( self.closure_vars, self.get_location(node) ) @@ -168,30 +164,27 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise FieldOperatorSyntaxError.from_AST( - node, msg="Function must return a value, but no return statement was found." - ) + raise CompilationError(loc, "function is expected to return a value, return statement not found") return foast.FunctionDefinition( id=node.name, params=self.visit(node.args, **kwargs), body=new_body, closure_vars=closure_var_symbols, - location=self.get_location(node), + location=loc, ) def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> foast.DataSymbol: + loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise FieldOperatorSyntaxError.from_AST(node, msg="Untyped parameters not allowed!") + raise MissingParameterTypeError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise FieldOperatorSyntaxError.from_AST( - node, msg="Only arguments of type DataType are allowed." - ) - return foast.DataSymbol(id=node.arg, location=self.get_location(node), type=new_type) + raise InvalidParameterTypeError(loc, node.arg, new_type) + return foast.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.TupleTargetAssign: target = node.targets[0] # there is only one element after assignment passes @@ -228,7 +221,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise FieldOperatorSyntaxError.from_AST(node, msg="Can only assign to names!") + raise CompilationError(self.get_location(node), "can only assign to names") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -250,7 +243,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise FieldOperatorSyntaxError.from_AST(node, msg="Can only assign to names!") + raise CompilationError(self.get_location(node), "can only assign to names") if node.annotation is not None: assert isinstance( @@ -291,9 +284,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise FieldOperatorSyntaxError.from_AST( - node, msg="""Only index is supported in subscript!""" - ) + raise CompilationError(self.get_location(node.slice), "expected an integral index") from None return foast.Subscript( value=self.visit(node.value), @@ -312,9 +303,10 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: ) def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: + loc = self.get_location(node) if not node.value: - raise FieldOperatorSyntaxError.from_AST(node, msg="Empty return not allowed") - return foast.Return(value=self.visit(node.value), location=self.get_location(node)) + raise CompilationError(loc, "must return a value, not None") + return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) @@ -378,7 +370,7 @@ def visit_BitXor(self, node: ast.BitXor, **kwargs) -> dialect_ast_enums.BinaryOp return dialect_ast_enums.BinaryOperator.BIT_XOR def visit_BoolOp(self, node: ast.BoolOp, **kwargs) -> None: - raise FieldOperatorSyntaxError.from_AST(node, msg="`and`/`or` operator not allowed!") + raise UnsupportedPythonFeatureError(self.get_location(node), "logical operators `and`, `or`") def visit_IfExp(self, node: ast.IfExp, **kwargs) -> foast.TernaryExpr: return foast.TernaryExpr( @@ -407,15 +399,16 @@ def _visit_stmts( ) def visit_Compare(self, node: ast.Compare, **kwargs) -> foast.Compare: + loc = self.get_location(node) if len(node.ops) != 1 or len(node.comparators) != 1: - raise FieldOperatorSyntaxError.from_AST( - node, msg="Comparison chains not allowed, run a preprocessing pass!" - ) + # Remove comparison chains in a preprocessing pass + # TODO: maybe add a note to the error about preprocessing passes? + raise UnsupportedPythonFeatureError(loc, "comparison chains") return foast.Compare( op=self.visit(node.ops[0]), left=self.visit(node.left), right=self.visit(node.comparators[0]), - location=self.get_location(node), + location=loc, ) def visit_Gt(self, node: ast.Gt, **kwargs) -> foast.CompareOperator: @@ -437,36 +430,23 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_function(self, node: ast.Call): + loc = self.get_location(node) func_name = self._func_name(node) func_info = getattr(fbuiltins, func_name).__gt_type__() if not len(node.args) == len(func_info.args): - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{func_name}() expected {len(func_info.args)} positional arguments, {len(node.args)} given!", - ) + raise IncorrectArgumentCountError(loc, len(func_info.args), len(node.args)) elif unexpected_kwargs := set(k.arg for k in node.keywords) - set(func_info.kwargs): - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{self._func_name(node)}() got unexpected keyword arguments: {unexpected_kwargs}!", - ) + raise UnexpectedKeywordArgError(loc, ", ".join(unexpected_kwargs)) def _verify_builtin_type_constructor(self, node: ast.Call): + loc = self.get_location(node) if not len(node.args) == 1: - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{self._func_name(node)}() expected 1 positional argument, {len(node.args)} given!", - ) + raise IncorrectArgumentCountError(loc, 1, len(node.args)) elif node.keywords: unexpected_kwargs = set(k.arg for k in node.keywords) - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{self._func_name(node)}() got unexpected keyword arguments: {unexpected_kwargs}!", - ) + raise UnexpectedKeywordArgError(loc, ", ".join(unexpected_kwargs)) elif not isinstance(node.args[0], ast.Constant): - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{self._func_name(node)}() only takes literal arguments!", - ) + raise CompilationError(self.get_location(node.args[0]), "expected a literal expression") def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. @@ -489,15 +469,14 @@ def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: ) def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: + loc = self.get_location(node) try: type_ = type_translation.from_value(node.value) except common.GTTypeError as e: - raise FieldOperatorSyntaxError.from_AST( - node, msg=f"Constants of type {type(node.value)} are not permitted." - ) from e + raise CompilationError(loc, f"constants of type {type(node.value)} are not permitted") from None return foast.Constant( value=node.value, - location=self.get_location(node), + location=loc, type=type_, ) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 30d77ec7a6..85276fd3ff 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -23,28 +23,21 @@ program_ast as past, type_specifications as ts_ffront, ) -from gt4py.next.ffront.dialect_parser import DialectParser, DialectSyntaxError +from gt4py.next.ffront.dialect_parser import DialectParser +from gt4py.next.errors import CompilationError, MissingParameterTypeError, InvalidParameterTypeError from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.type_system import type_specifications as ts, type_translation -class ProgramSyntaxError(DialectSyntaxError): - dialect_name = "Program" - - @dataclass(frozen=True, kw_only=True) class ProgramParser(DialectParser[past.Program]): """Parse program definition from Python source code into PAST.""" - syntax_error_cls = ProgramSyntaxError - @classmethod def _postprocess_dialect_ast( cls, output_node: past.Program, closure_vars: dict[str, Any], annotations: dict[str, Any] ) -> past.Program: - if "return" in annotations and not isinstance(None, annotations["return"]): - raise ProgramSyntaxError("Program should not have a return value!") output_node = ClosureVarTypeDeduction.apply(output_node, closure_vars) return ProgramTypeDeduction.apply(output_node) @@ -72,14 +65,13 @@ def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> past.DataSymbol: + loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise ProgramSyntaxError.from_AST(node, msg="Untyped parameters not allowed!") + raise MissingParameterTypeError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise ProgramSyntaxError.from_AST( - node, msg="Only arguments of type DataType are allowed." - ) - return past.DataSymbol(id=node.arg, location=self.get_location(node), type=new_type) + raise InvalidParameterTypeError(loc, node.arg, new_type) + return past.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Expr(self, node: ast.Expr) -> past.LocatedNode: return self.visit(node.value) @@ -133,15 +125,16 @@ def visit_Dict(self, node: ast.Dict) -> past.Dict: ) def visit_Call(self, node: ast.Call) -> past.Call: + loc = self.get_location(node) new_func = self.visit(node.func) if not isinstance(new_func, past.Name): - raise ProgramSyntaxError.from_AST(node, msg="Functions can only be called directly!") + raise CompilationError(loc, "functions must be referenced by their name in function calls") return past.Call( func=new_func, args=[self.visit(arg) for arg in node.args], kwargs={arg.arg: self.visit(arg.value) for arg in node.keywords}, - location=self.get_location(node), + location=loc, ) def visit_Subscript(self, node: ast.Subscript) -> past.Subscript: @@ -167,12 +160,13 @@ def visit_Slice(self, node: ast.Slice) -> past.Slice: ) def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: + loc = self.get_location(node) if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant( - value=-node.operand.value, type=symbol_type, location=self.get_location(node) + value=-node.operand.value, type=symbol_type, location=loc ) - raise ProgramSyntaxError.from_AST(node, msg="Unary operators can only be used on literals.") + raise CompilationError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index e0f428dbc2..a112c946d9 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -36,19 +36,19 @@ def get_closure_vars_from_function(function: Callable) -> dict[str, Any]: def make_source_definition_from_function(func: Callable) -> SourceDefinition: try: - filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) or MISSING_FILENAME - source = textwrap.dedent(inspect.getsource(func)) - starting_line = ( - inspect.getsourcelines(func)[1] if not filename.endswith(MISSING_FILENAME) else 1 + filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) + if not filename: + raise ValueError("Can not create field operator from a function that is not in a source file!") + source_lines, line_offset = inspect.getsourcelines(func) + source_code = textwrap.dedent(inspect.getsource(func)) + column_offset = min( + [len(line) - len(line.lstrip()) for line in source_lines if line.lstrip()], + default=0 ) - except OSError as err: - if filename.endswith(MISSING_FILENAME): - message = "Can not create field operator from a function that is not in a source file!" - else: - message = f"Can not get source code of passed function ({func})" - raise ValueError(message) from err + return SourceDefinition(source_code, filename, line_offset - 1, column_offset) - return SourceDefinition(source, filename, starting_line) + except OSError as err: + raise ValueError(f"Can not get source code of passed function ({func})") from err def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) -> SymbolNames: @@ -119,7 +119,8 @@ def foo(a): source: str filename: str = MISSING_FILENAME - starting_line: int = 1 + starting_line: int = 0 + starting_column: int = 0 def __iter__(self) -> Iterator: yield self.source diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 5a892070e2..9d8853e740 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -54,7 +54,7 @@ where, ) from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError -from gt4py.next.ffront.func_to_foast import FieldOperatorParser, FieldOperatorSyntaxError +from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.builtins import ( and_, @@ -76,6 +76,7 @@ ) from gt4py.next.type_system import type_specifications as ts from gt4py.next.type_system.type_translation import TypingError +from gt4py.next.errors import * DEREF = itir.SymRef(id=deref.fun.__name__) @@ -106,8 +107,7 @@ def untyped(inp): return inp with pytest.raises( - FieldOperatorSyntaxError, - match="Untyped parameters not allowed!", + MissingParameterTypeError ): _ = FieldOperatorParser.apply_to_function(untyped) @@ -146,8 +146,8 @@ def no_return(inp: Field[[TDim], "float64"]): tmp = inp # noqa with pytest.raises( - FieldOperatorSyntaxError, - match="Function must return a value, but no return statement was found\.", + CompilationError, + match=".*return.*", ): _ = FieldOperatorParser.apply_to_function(no_return) @@ -160,7 +160,7 @@ def invalid_assign_to_expr(inp1: Field[[TDim], "float64"], inp2: Field[[TDim], " tmp[-1] = inp2 return tmp - with pytest.raises(FieldOperatorSyntaxError, match=r"Can only assign to names! \(.*\)"): + with pytest.raises(CompilationError, match=r".*assign.*"): _ = FieldOperatorParser.apply_to_function(invalid_assign_to_expr) @@ -219,8 +219,8 @@ def bool_and(a: Field[[TDim], "bool"], b: Field[[TDim], "bool"]): return a and b with pytest.raises( - FieldOperatorSyntaxError, - match=(r"`and`/`or` operator not allowed!"), + UnsupportedPythonFeatureError, + match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_and) @@ -230,8 +230,8 @@ def bool_or(a: Field[[TDim], "bool"], b: Field[[TDim], "bool"]): return a or b with pytest.raises( - FieldOperatorSyntaxError, - match=(r"`and`/`or` operator not allowed!"), + UnsupportedPythonFeatureError, + match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_or) @@ -265,7 +265,7 @@ def cast_scalar_temp(): tmp = int64(1) return int32(tmp) - with pytest.raises(FieldOperatorSyntaxError, match=(r"only takes literal arguments!")): + with pytest.raises(CompilationError, match=r".*literal.*"): _ = FieldOperatorParser.apply_to_function(cast_scalar_temp) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 3ad2bf757c..adf7a47602 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -21,7 +21,7 @@ from gt4py.next.common import Dimension, Field from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction -from gt4py.next.ffront.func_to_foast import FieldOperatorParser, FieldOperatorSyntaxError +from gt4py.next.ffront.func_to_foast import FieldOperatorParser # NOTE: These tests are sensitive to filename and the line number of the marked statement @@ -38,12 +38,9 @@ def wrong_syntax(inp: common.Field[[TDim], float]): return # <-- this line triggers the syntax error with pytest.raises( - f2f.FieldOperatorSyntaxError, + f2f.CompilationError, match=( - r"Invalid Field Operator Syntax: " - r"Empty return not allowed \(test_func_to_foast_error_line_number.py, line " - + str(line + 3) - + r"\)" + r".*return.*" ), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) @@ -65,7 +62,7 @@ def wrong_line_syntax_error(inp: common.Field[[TDim], float]): return inp - with pytest.raises(f2f.FieldOperatorSyntaxError) as exc_info: + with pytest.raises(f2f.CompilationError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_line_syntax_error) exc = exc_info.value @@ -99,7 +96,7 @@ def test_syntax_error_without_function(): """Dialect parsers report line numbers correctly when applied to `SourceDefinition`.""" source_definition = src_utils.SourceDefinition( - starting_line=62, + starting_line=61, source=""" def invalid_python_syntax(): # This function contains a python syntax error From c38116ff3eac9fe9bd5ac2cf067f33da1746369d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 6 Jun 2023 11:05:00 +0200 Subject: [PATCH 03/54] foast type deduction uses new exceptions --- .../ffront/foast_passes/type_deduction.py | 213 ++++++------------ .../ffront_tests/test_execution.py | 8 +- .../ffront_tests/test_scalar_if.py | 8 +- .../ffront_tests/test_type_deduction.py | 28 +-- .../ffront_tests/test_func_to_foast.py | 14 +- .../test_func_to_foast_error_line_number.py | 3 +- 6 files changed, 101 insertions(+), 173 deletions(-) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 33ea81ff2d..4448704946 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -25,6 +25,7 @@ ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.errors import * def boolified_type(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.FieldType: @@ -144,9 +145,8 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise FieldOperatorTypeDeductionError.from_foast_node( - stmt, - msg=f"If statement contains return statements with inconsistent types:" + raise CompilationError(stmt.location, + f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", ) return_type = return_types[0] or return_types[1] @@ -161,9 +161,8 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise FieldOperatorTypeDeductionError.from_foast_node( - stmt, - msg=f"If statement contains return statements with inconsistent types:" + raise CompilationError(stmt.location, + f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", ) @@ -247,9 +246,8 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", + raise CompilationError(node.location, + f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", ) new_type = ts.FunctionType( args=[new_param.type for new_param in new_params], kwargs={}, returns=return_type @@ -276,28 +274,24 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `axis` to scan operator `{node.id}` must be a dimension.", + raise CompilationError(node.location, + f"Argument `axis` to scan operator `{node.id}` must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", + raise CompilationError(node.location, + f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Argument `forward` to scan operator `{node.id}` must be a boolean." + raise CompilationError(node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `init` to scan operator `{node.id}` must " + raise CompilationError(node.location, + f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", ) new_definition = self.visit(node.definition, **kwargs) @@ -318,8 +312,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Undeclared symbol `{node.id}`." + raise CompilationError(node.location, f"Undeclared symbol `{node.id}`." ) symbol = symtable[node.id] @@ -344,8 +337,7 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Too many values to unpack (expected {len(indices)})." + raise CompilationError(node.location, f"Too many values to unpack (expected {len(indices)})." ) new_targets: TargetType = [] @@ -376,8 +368,7 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Assignment value must be of type tuple! Got: {values.type}" + raise CompilationError(node.location, f"Assignment value must be of type tuple! Got: {values.type}" ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -395,16 +386,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="Condition for `if` must be scalar. " + raise CompilationError(node.location, + "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="Condition for `if` must be of boolean type. " + raise CompilationError(node.location, + "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", ) @@ -412,9 +401,8 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Inconsistent types between two branches for variable `{sym}`. " + raise CompilationError(node.location, + f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", ) # TODO: properly patch symtable (new node?) @@ -433,9 +421,8 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=( + raise CompilationError(node.location, + ( "type inconsistency: expression was deduced to be " f"of type {refine_type}, instead of the expected type {node.type}" ), @@ -455,23 +442,20 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, msg="Second dimension in offset must be a local dimension." - ) + raise CompilationError(new_value.location, "Second dimension in offset must be a local dimension.") new_type = ts.OffsetType(source=source, target=(target1,)) case ts.OffsetType(source=source, target=(target,)): # for cartesian axes (e.g. I, J) the index of the subscript only # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, - msg="Source and target must be equal for offsets with a single target.", + raise CompilationError(new_value.location, + "Source and target must be equal for offsets with a single target.", ) new_type = new_value.type case _: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, msg="Could not deduce type of subscript expression!" + raise CompilationError( + new_value.location, "Could not deduce type of subscript expression!" ) return foast.Subscript( @@ -510,15 +494,13 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise FieldOperatorTypeDeductionError.from_foast_node( - condition, - msg=f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", + raise CompilationError(condition.location, + f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", ) if true_expr.type != false_expr.type: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", + raise CompilationError(node.location, + f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", ) return true_expr.type @@ -536,8 +518,7 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, msg=f"Type {arg.type} can not be used in operator '{node.op}'!" + raise CompilationError(arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -547,9 +528,8 @@ def _deduce_compare_type( # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Could not promote `{left.type}` and `{right.type}` to common type" + raise CompilationError(node.location, + f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -571,8 +551,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, msg=f"Type {arg.type} can not be used in operator `{node.op}`!" + raise CompilationError(arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -584,17 +563,15 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, - msg=f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + raise CompilationError(arg.location, + f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) try: return type_info.promote(left_type, right_type) except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Could not promote `{left_type}` and `{right_type}` to common type" + raise CompilationError(node.location, + f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -603,9 +580,8 @@ def _check_operand_dtypes_match( ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", + raise CompilationError(node.location, + f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: @@ -620,9 +596,8 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", + raise CompilationError(node.location, + f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type @@ -652,15 +627,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg="Functions can only be called directly!" + raise CompilationError(node.location, "Functions can only be called directly!" ) elif isinstance(new_func.type, ts.FieldType): pass else: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", + raise CompilationError(node.location, + f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", ) # ensure signature is valid @@ -672,8 +645,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: raise_exception=True, ) except GTTypeError as err: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Invalid argument types in call to `{new_func}`!" + raise CompilationError(node.location, f"Invalid argument types in call to `{new_func}`!" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -730,9 +702,8 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), + raise CompilationError(node.location, + "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): @@ -753,8 +724,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=error_msg_preamble + raise CompilationError(node.location, error_msg_preamble ) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -774,9 +744,8 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible field argument in call to `{str(node.func)}`. " + raise CompilationError(node.location, + f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " f"{field_dims_str}.", ) @@ -830,17 +799,15 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. " + raise CompilationError(node.location, + f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" f"{node.location}", ) if arg_0.source not in arg_1.dims: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. " + raise CompilationError(node.location, + f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " f"{node.location}", ) @@ -859,9 +826,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. Expected " + raise CompilationError(node.location, + f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype bool, but got `{mask_type}`.", ) @@ -877,9 +843,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Return arguments need to be of same type in {str(node.func)}, but got: " + raise CompilationError(node.location, + f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", ) else: @@ -889,9 +854,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: return_type = promote_to_mask_type(mask_type, promoted_type) except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`.", + raise CompilationError(node.location, + f"Incompatible argument in call to `{str(node.func)}`.", ) from ex return foast.Call( @@ -907,18 +871,16 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected " + raise CompilationError(node.location, + f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected " + raise CompilationError(node.location, + f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", ) @@ -939,41 +901,6 @@ def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except GTTypeError as e: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg="Could not deduce type of constant." + raise CompilationError(node.location, "Could not deduce type of constant." ) from e return foast.Constant(value=node.value, location=node.location, type=type_) - - -class FieldOperatorTypeDeductionError(GTSyntaxError, SyntaxWarning): - """Exception for problematic type deductions that originate in user code.""" - - def __init__( - self, - msg="", - *, - lineno=0, - offset=0, - filename=None, - end_lineno=None, - end_offset=None, - text=None, - ): - msg = "Could not deduce type: " + msg - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_foast_node( - cls, - node: foast.LocatedNode, - *, - msg: str = "", - ): - return cls( - msg, - lineno=node.location.line, - offset=node.location.column, - filename=node.location.source, - end_lineno=node.location.end_line, - end_offset=node.location.end_column, - ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index d85da94103..63fd1b1765 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -34,9 +34,9 @@ neighbor_sum, where, ) -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.iterator.embedded import index_field, np_as_located_field from gt4py.next.program_processors.runners import gtfn_cpu +from gt4py.next.errors import * from next_tests.integration_tests.feature_tests import cases from next_tests.integration_tests.feature_tests.cases import ( @@ -875,7 +875,7 @@ def fieldop_where_k_offset( def test_undefined_symbols(): - with pytest.raises(FieldOperatorTypeDeductionError, match="Undeclared symbol"): + with pytest.raises(CompilationError, match="Undeclared symbol"): @field_operator def return_undefined(): @@ -982,7 +982,7 @@ def unpack( def test_tuple_unpacking_too_many_values(fieldview_backend): with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), ): @@ -994,7 +994,7 @@ def _star_unpack() -> tuple[int, float64, int]: def test_tuple_unpacking_too_many_values(fieldview_backend): with pytest.raises( - FieldOperatorTypeDeductionError, match=(r"Assignment value must be of type tuple!") + CompilationError, match=(r"Assignment value must be of type tuple!") ): @field_operator(backend=fieldview_backend) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 08aceac249..174a34556a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.decorator import field_operator from gt4py.next.ffront.fbuiltins import Field, float64 -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError +from gt4py.next.errors import * from gt4py.next.iterator.embedded import index_field, np_as_located_field from gt4py.next.program_processors.runners import gtfn_cpu @@ -362,7 +362,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(FieldOperatorTypeDeductionError, match="Condition for `if` must be scalar."): + with pytest.raises(CompilationError, match="Condition for `if` must be scalar."): @field_operator def if_non_scalar_condition( @@ -376,7 +376,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): with pytest.raises( - FieldOperatorTypeDeductionError, match="Condition for `if` must be of boolean type." + CompilationError, match="Condition for `if` must be of boolean type." ): @field_operator @@ -392,7 +392,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match="Inconsistent types between two branches for variable", ): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 894c496ae7..5757723768 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -32,9 +32,9 @@ neighbor_sum, where, ) -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.errors import * TDim = Dimension("TDim") # Meaningless dimension, used for tests. @@ -386,7 +386,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -401,7 +401,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=( r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." ), @@ -414,7 +414,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`! "), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -425,7 +425,7 @@ def sign_bool(a: Field[[TDim], bool]): return -a with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -436,7 +436,7 @@ def not_int(a: Field[[TDim], int64]): return not a with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -508,7 +508,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: return float32("1.0") + float64("1.0") with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=(r"Could not promote `float32` and `float64` to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -538,7 +538,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Expected broadcast dimension is missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -553,7 +553,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Expected all broadcast dimensions to be of type Dimension.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -619,7 +619,7 @@ def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Return arguments need to be of same type", ): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -674,7 +674,7 @@ def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=r"Type float64 can not be used in operator `%`", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -684,7 +684,7 @@ def test_undefined_symbols(): def return_undefined(): return undefined_symbol - with pytest.raises(FieldOperatorTypeDeductionError, match="Undeclared symbol"): + with pytest.raises(CompilationError, match="Undeclared symbol"): _ = FieldOperatorParser.apply_to_function(return_undefined) @@ -697,7 +697,7 @@ def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=f"not in list of offset field dimensions", ): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -712,7 +712,7 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) with pytest.raises( - FieldOperatorTypeDeductionError, + CompilationError, match=f"Excepted integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 9d8853e740..da72a28930 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -53,7 +53,7 @@ int64, where, ) -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError +from gt4py.next.errors import * from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import ir as itir from gt4py.next.iterator.builtins import ( @@ -186,7 +186,7 @@ def clashing(inp: Field[[TDim], "float64"]): tmp: Field[[TDim], "int64"] = inp return tmp - with pytest.raises(FieldOperatorTypeDeductionError, match="type inconsistency"): + with pytest.raises(CompilationError, match="type inconsistency"): _ = FieldOperatorParser.apply_to_function(clashing) @@ -276,7 +276,7 @@ def conditional_wrong_mask_type( return where(a, a, a) msg = r"Expected a field with dtype bool." - with pytest.raises(FieldOperatorTypeDeductionError, match=msg): + with pytest.raises(CompilationError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -289,7 +289,7 @@ def conditional_wrong_arg_type( return where(mask, a, b) msg = r"Could not promote scalars of different dtype \(not implemented\)." - with pytest.raises(FieldOperatorTypeDeductionError) as exc_info: + with pytest.raises(CompilationError) as exc_info: _ = FieldOperatorParser.apply_to_function(conditional_wrong_arg_type) assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -299,7 +299,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: Field[[], bool]): return 1 if cond else 2 - with pytest.raises(FieldOperatorTypeDeductionError, match=r"should be .* `bool`"): + with pytest.raises(CompilationError, match=r"should be .* `bool`"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -426,8 +426,8 @@ def zero_dims_ternary( ): return a if cond == 1 else b - msg = r"Could not deduce type" - with pytest.raises(FieldOperatorTypeDeductionError) as exc_info: + msg = r"Incompatible datatypes in operator `==`" + with pytest.raises(CompilationError) as exc_info: _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index adf7a47602..edbdd5952d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -22,6 +22,7 @@ from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.errors import * # NOTE: These tests are sensitive to filename and the line number of the marked statement @@ -126,7 +127,7 @@ def test_fo_type_deduction_error(): def field_operator_with_undeclared_symbol(): return undeclared_symbol - with pytest.raises(type_deduction.FieldOperatorTypeDeductionError) as exc_info: + with pytest.raises(CompilationError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(field_operator_with_undeclared_symbol) exc = exc_info.value From d59808e297535d7706d8739bfaf560735cced5e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 6 Jun 2023 11:17:02 +0200 Subject: [PATCH 04/54] past type deduction uses new exceptions --- src/gt4py/next/ffront/decorator.py | 9 ++- .../next/ffront/past_passes/type_deduction.py | 58 +++++-------------- .../ffront_tests/test_program.py | 6 +- .../feature_tests/test_util_cases.py | 3 +- .../ffront_tests/test_func_to_past.py | 16 ++--- .../ffront_tests/test_past_to_itir.py | 3 +- 6 files changed, 32 insertions(+), 63 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 25d194e161..53b26fac36 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -47,13 +47,14 @@ from gt4py.next.ffront.past_passes.closure_var_type_deduction import ( ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, ) -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction, ProgramTypeError +from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +from gt4py.next.errors import * DEFAULT_BACKEND: Callable = roundtrip.executor @@ -101,7 +102,7 @@ def _canonicalize_args( for param_i, param in enumerate(node_params): if param.id in new_kwargs: if param_i < len(args): - raise ProgramTypeError(f"got multiple values for argument {param.id}.") + raise ValueError(f"got multiple values for argument {param.id}.") new_args.append(kwargs[param.id]) new_kwargs.pop(param.id) elif param_i < len(args): @@ -330,9 +331,7 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except GTTypeError as err: - raise ProgramTypeError.from_past_node( - self.past_node, msg=f"Invalid argument types in call to `{self.past_node.id}`!" - ) from err + raise ValueError(f"Invalid argument types in call to `{self.past_node.id}`!") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: args, kwargs = _canonicalize_args(self.past_node.params, args, kwargs) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 3c8c362bbe..3fc5bce3bc 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -22,6 +22,7 @@ type_specifications as ts_ffront, ) from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.errors import * def _ensure_no_sliced_field(entry: past.Expr): @@ -145,8 +146,8 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): - raise ProgramTypeError.from_past_node( - arg, msg=f"Type {arg.type} can not be used in operator `{node.op}`!" + raise CompilationError( + arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) left_type = cast(ts.ScalarType, left.type) @@ -158,17 +159,17 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise ProgramTypeError.from_past_node( - arg, - msg=f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + raise CompilationError( + arg.location, + f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) try: return type_info.promote(left_type, right_type) except GTTypeError as ex: - raise ProgramTypeError.from_past_node( - node, - msg=f"Could not promote `{left_type}` and `{right_type}` to common type" + raise CompilationError( + node.location, + f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -228,8 +229,8 @@ def visit_Call(self, node: past.Call, **kwargs): ) except GTTypeError as ex: - raise ProgramTypeError.from_past_node( - node, msg=f"Invalid call to `{node.func.id}`." + raise CompilationError( + node.location, f"Invalid call to `{node.func.id}`." ) from ex return past.Call( @@ -243,41 +244,8 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise ProgramTypeError.from_past_node( - node, msg=f"Undeclared or untyped symbol `{node.id}`." + raise CompilationError( + node.location, f"Undeclared or untyped symbol `{node.id}`." ) return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) - - -class ProgramTypeError(GTTypeError): - """Exception for problematic type deductions that originate in user code.""" - - def __init__( - self, - msg="", - *, - lineno=0, - offset=0, - filename=None, - end_lineno=None, - end_offset=None, - text=None, - ): - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_past_node( - cls, - node: past.LocatedNode, - *, - msg: str = "", - ): - return cls( - msg, - lineno=node.location.line, - offset=node.location.column, - filename=node.location.source, - end_lineno=node.location.end_line, - end_offset=node.location.end_column, - ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 39ec935174..d6e7cb0937 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,8 +20,8 @@ import pytest from gt4py.next.common import Field, GTTypeError +from gt4py.next.errors import * from gt4py.next.ffront.decorator import field_operator, program -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeError from gt4py.next.iterator.embedded import np_as_located_field from gt4py.next.program_processors.runners import gtfn_cpu, roundtrip @@ -237,7 +237,7 @@ def test_wrong_argument_type(fieldview_backend, copy_program_def): copy_program = program(copy_program_def, backend=fieldview_backend) - with pytest.raises(ProgramTypeError) as exc_info: + with pytest.raises(ValueError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) @@ -309,5 +309,5 @@ def program_input_kwargs( program_input_kwargs(a=input_1, b=input_2, c=input_3, out=out, offset_provider={}) assert np.allclose(expected, out) - with pytest.raises(GTTypeError, match="got multiple values for argument"): + with pytest.raises(ValueError, match="got multiple values for argument"): program_input_kwargs(input_2, input_3, a=input_1, out=out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 3b5cf1a5b9..8cdbe02c5e 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -16,6 +16,7 @@ import pytest from gt4py.next import common +from gt4py.next.errors import * from gt4py.next.ffront.decorator import field_operator from gt4py.next.program_processors.runners import roundtrip @@ -88,7 +89,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures b = cases.allocate(cartesian_case, addition, "b")() out = cases.allocate(cartesian_case, addition, cases.RETURN)() - with pytest.raises(common.GTTypeError): + with pytest.raises(CompilationError): cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index cf681c325f..02d512018d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -23,8 +23,8 @@ from gt4py.next.ffront.decorator import field_operator from gt4py.next.ffront.fbuiltins import float64 from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeError from gt4py.next.type_system import type_specifications as ts +from gt4py.next.errors import * from next_tests.past_common_fixtures import ( IDim, @@ -113,7 +113,7 @@ def undefined_field_program(in_field: Field[[IDim], "float64"]): identity(in_field, out=out_field) with pytest.raises( - ProgramTypeError, + CompilationError, match=(r"Undeclared or untyped symbol `out_field`."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -162,7 +162,7 @@ def domain_format_1_program(in_field: Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) @@ -181,7 +181,7 @@ def domain_format_2_program(in_field: Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) @@ -200,7 +200,7 @@ def domain_format_3_program(in_field: Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) @@ -221,7 +221,7 @@ def domain_format_4_program(in_field: Field[[IDim], float64]): ) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) @@ -240,7 +240,7 @@ def domain_format_5_program(in_field: Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) @@ -259,7 +259,7 @@ def domain_format_6_program(in_field: Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index d8d46aed3b..9c9fa73f34 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -23,6 +23,7 @@ from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir +from gt4py.next.errors import * from next_tests.past_common_fixtures import ( IDim, @@ -165,7 +166,7 @@ def inout_field_program(inout_field: Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): with pytest.raises( - GTTypeError, + CompilationError, ) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), From 601737a8174c4f80479c68e7fe22546469d8720c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 6 Jun 2023 15:12:46 +0200 Subject: [PATCH 05/54] remove unused old exception classes --- src/gt4py/next/common.py | 20 +------------------ .../ffront/foast_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/source_utils.py | 4 ++-- 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 5ece3a23ec..0ae0729b6d 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -124,25 +124,7 @@ class GTError: ... -class GTRuntimeError(GTError, RuntimeError): - """Base class for GridTools run-time errors.""" - - ... - - -class GTSyntaxError(GTError, SyntaxError): - """Base class for GridTools syntax errors.""" - - ... - - class GTTypeError(GTError, TypeError): """Base class for GridTools type errors.""" - ... - - -class GTValueError(GTError, ValueError): - """Base class for GridTools value errors.""" - - ... + ... \ No newline at end of file diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 4448704946..14200aa3ec 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -16,7 +16,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next.common import DimensionKind, GTSyntaxError, GTTypeError +from gt4py.next.common import DimensionKind, GTTypeError from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index a112c946d9..43519927fd 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -55,13 +55,13 @@ def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) try: mod_st = symtable.symtable(source, filename, "exec") except SyntaxError as err: - raise common.GTValueError( + raise ValueError( f"Unexpected error when parsing provided source code (\n{source}\n)" ) from err assert mod_st.get_type() == "module" if len(children := mod_st.get_children()) != 1: - raise common.GTValueError( + raise ValueError( f"Sources with multiple function definitions are not yet supported (\n{source}\n)" ) From 05b5ef2fa391bb0ce0abe5c4bb151ad685fd3936 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 6 Jun 2023 15:45:56 +0200 Subject: [PATCH 06/54] rename fields in source definition to better reflect meaning --- src/gt4py/next/ffront/dialect_parser.py | 12 ++++++------ src/gt4py/next/ffront/source_utils.py | 6 +++--- .../test_func_to_foast_error_line_number.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 1ff4d0c0b2..72ff372d10 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -35,10 +35,10 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: return ast.parse(textwrap.dedent(source_definition.source)).body[0] except SyntaxError as err: err.filename = source_definition.filename - err.lineno = err.lineno + source_definition.starting_line if err.lineno is not None else None - err.offset = err.offset + source_definition.starting_column if err.offset is not None else None - err.end_lineno = err.end_lineno + source_definition.starting_line if err.end_lineno is not None else None - err.end_offset = err.end_offset + source_definition.starting_column if err.end_offset is not None else None + err.lineno = err.lineno + source_definition.line_offset if err.lineno is not None else None + err.offset = err.offset + source_definition.column_offset if err.offset is not None else None + err.end_lineno = err.end_lineno + source_definition.line_offset if err.end_lineno is not None else None + err.end_offset = err.end_offset + source_definition.column_offset if err.end_offset is not None else None raise err @@ -96,8 +96,8 @@ def generic_visit(self, node: ast.AST) -> None: def get_location(self, node: ast.AST) -> SourceLocation: file = self.source_definition.filename - line_offset = self.source_definition.starting_line - col_offset = self.source_definition.starting_column + line_offset = self.source_definition.line_offset + col_offset = self.source_definition.column_offset line = node.lineno + line_offset if node.lineno is not None else None end_line = node.end_lineno + line_offset if node.end_lineno is not None else None diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index 43519927fd..5d9c648ae6 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -119,13 +119,13 @@ def foo(a): source: str filename: str = MISSING_FILENAME - starting_line: int = 0 - starting_column: int = 0 + line_offset: int = 0 + column_offset: int = 0 def __iter__(self) -> Iterator: yield self.source yield self.filename - yield self.starting_line + yield self.line_offset from_function = staticmethod(make_source_definition_from_function) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index edbdd5952d..f4e8f4d82a 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -97,7 +97,7 @@ def test_syntax_error_without_function(): """Dialect parsers report line numbers correctly when applied to `SourceDefinition`.""" source_definition = src_utils.SourceDefinition( - starting_line=61, + line_offset=61, source=""" def invalid_python_syntax(): # This function contains a python syntax error From 99785f35dcda59ab09091893a2bff519d0b8fb97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 6 Jun 2023 16:46:44 +0200 Subject: [PATCH 07/54] experiment exception printing --- tests/next_tests/exception_printing.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/next_tests/exception_printing.py diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py new file mode 100644 index 0000000000..65603af7cc --- /dev/null +++ b/tests/next_tests/exception_printing.py @@ -0,0 +1,8 @@ +from gt4py.next.ffront.decorator import field_operator +from gt4py.next.errors import * + +@field_operator +def testee(a) -> float: + return 1 + +testee(1, offset_provider={}) \ No newline at end of file From 4c6081121765f916f8907b984954d303c4763c28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 13 Jun 2023 16:47:26 +0200 Subject: [PATCH 08/54] exception hook and cleanups --- src/gt4py/next/errors/__init__.py | 1 + src/gt4py/next/errors/exceptions.py | 23 +++++++++++++------ src/gt4py/next/errors/formatting.py | 14 +++++++++++ src/gt4py/next/errors/tools.py | 3 ++- tests/next_tests/exception_printing.py | 10 ++++---- .../errors_tests/test_compilation_error.py | 5 ++-- 6 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 src/gt4py/next/errors/formatting.py diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 6db84e4c7a..778954b1f3 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -8,3 +8,4 @@ UnexpectedKeywordArgError, MissingAttributeError ) +from . import formatting \ No newline at end of file diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 210eb7284f..4e93fe3af8 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -1,13 +1,15 @@ from gt4py.eve import SourceLocation -from typing import Any +from typing import Any, Optional from . import tools + class CompilationError(SyntaxError): - def __init__(self, location: SourceLocation, message: str): - try: - source_code = tools.get_code_at_location(location) - except ValueError: - source_code = None + def __init__(self, location: SourceLocation, message: str, *, snippet: str | bool = True): + source_code = None + if isinstance(snippet, str): + source_code = snippet + if snippet is True: + source_code = CompilationError.get_source_from_location(location) super().__init__( message, ( @@ -21,7 +23,7 @@ def __init__(self, location: SourceLocation, message: str): ) @property - def location(self): + def location(self) -> SourceLocation: return SourceLocation( source=self.filename, line=self.lineno, @@ -30,6 +32,13 @@ def location(self): end_column=self.end_offset ) + @staticmethod + def get_source_from_location(location: SourceLocation) -> Optional[str]: + try: + return tools.get_source_from_location(location) + except ValueError: + return None + class UndefinedSymbolError(CompilationError): def __init__(self, location: SourceLocation, name: str): diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py new file mode 100644 index 0000000000..5495d6125f --- /dev/null +++ b/src/gt4py/next/errors/formatting.py @@ -0,0 +1,14 @@ +import sys +import traceback +from . import exceptions +from typing import Callable + + +def compilation_error_hook(fallback: Callable, type_: type, value: exceptions.CompilationError, _): + if issubclass(type_, exceptions.CompilationError): + print("".join(traceback.format_exception(value, limit=0)), file=sys.stderr) + else: + fallback(type_, value, traceback) + + +sys.excepthook = lambda ty, val, tb: compilation_error_hook(sys.excepthook, ty, val, tb) \ No newline at end of file diff --git a/src/gt4py/next/errors/tools.py b/src/gt4py/next/errors/tools.py index 7dac7aca3a..843e3d5437 100644 --- a/src/gt4py/next/errors/tools.py +++ b/src/gt4py/next/errors/tools.py @@ -1,7 +1,8 @@ import pathlib from gt4py.eve import SourceLocation -def get_code_at_location(location: SourceLocation): + +def get_source_from_location(location: SourceLocation): try: source_file = pathlib.Path(location.source) source_code = source_file.read_text() diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py index 65603af7cc..94c611de95 100644 --- a/tests/next_tests/exception_printing.py +++ b/tests/next_tests/exception_printing.py @@ -1,8 +1,8 @@ -from gt4py.next.ffront.decorator import field_operator from gt4py.next.errors import * +import inspect +from gt4py.eve import SourceLocation -@field_operator -def testee(a) -> float: - return 1 -testee(1, offset_provider={}) \ No newline at end of file +frameinfo = inspect.getframeinfo(inspect.currentframe()) +loc = SourceLocation(frameinfo.lineno, 1, frameinfo.filename, end_line=frameinfo.lineno, end_column=5) +raise CompilationError(loc, "this is an error message") \ No newline at end of file diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py index fff2886e8b..e4784277bf 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -1,8 +1,9 @@ from gt4py.next.errors import CompilationError from gt4py.eve import SourceLocation -loc = SourceLocation(5, 1, "/source/file.py", end_line=5, end_column=9) -msg = 'a message' + +loc = SourceLocation(5, 2, "/source/file.py", end_line=5, end_column=9) +msg = "a message" def test_message(): From 1451e79a209062e5892126f76ba2fe02eedbda63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 14 Jun 2023 12:19:10 +0200 Subject: [PATCH 09/54] fix infinite recursion --- src/gt4py/next/errors/formatting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 5495d6125f..dd9049ab39 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -11,4 +11,5 @@ def compilation_error_hook(fallback: Callable, type_: type, value: exceptions.Co fallback(type_, value, traceback) -sys.excepthook = lambda ty, val, tb: compilation_error_hook(sys.excepthook, ty, val, tb) \ No newline at end of file +_fallback = sys.excepthook +sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) \ No newline at end of file From ba17c6319b9c7300c999af37d7752542e6a8969d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 14 Jun 2023 15:50:43 +0200 Subject: [PATCH 10/54] remove old exceptions --- src/gt4py/next/common.py | 20 -------- src/gt4py/next/ffront/decorator.py | 8 ++-- .../ffront/foast_passes/type_deduction.py | 16 +++---- src/gt4py/next/ffront/func_to_foast.py | 5 +- .../next/ffront/past_passes/type_deduction.py | 25 +++++----- src/gt4py/next/ffront/past_to_itir.py | 12 ++--- src/gt4py/next/ffront/type_info.py | 6 +-- src/gt4py/next/type_system/type_info.py | 36 +++++++------- .../next/type_system/type_translation.py | 47 +++++++------------ .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_type_deduction.py | 4 +- .../test_decorator_domain_deduction.py | 4 +- .../ffront_tests/test_func_to_foast.py | 11 ++--- .../ffront_tests/test_func_to_past.py | 2 +- .../ffront_tests/test_past_to_itir.py | 4 +- .../test_type_translation.py | 30 ++++++------ 16 files changed, 100 insertions(+), 134 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 0ae0729b6d..b0f0b8ac11 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -108,23 +108,3 @@ class NeighborTable(Connectivity, Protocol): class GridType(StrEnum): CARTESIAN = "cartesian" UNSTRUCTURED = "unstructured" - - -class GTError: - """Base class for GridTools exceptions. - - Notes: - This base class has to be always inherited together with a standard - exception, and thus it should not be used as direct superclass - for custom exceptions. Inherit directly from :class:`GTTypeError`, - :class:`GTTypeError`, ... - - """ - - ... - - -class GTTypeError(GTError, TypeError): - """Base class for GridTools type errors.""" - - ... \ No newline at end of file diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 53b26fac36..8ea93caeb6 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -31,7 +31,7 @@ from gt4py.eve.extended_typing import Any, Optional from gt4py.eve.utils import UIDGenerator -from gt4py.next.common import DimensionKind, GridType, GTTypeError, Scalar +from gt4py.next.common import DimensionKind, GridType, Scalar from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -142,7 +142,7 @@ def is_cartesian_offset(o: FieldOffset): break if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: - raise GTTypeError( + raise ValueError( "grid_type == GridType.CARTESIAN was requested, but unstructured `FieldOffset` or local `Dimension` was found." ) @@ -330,7 +330,7 @@ def _validate_args(self, *args, **kwargs) -> None: with_kwargs=kwarg_types, raise_exception=True, ) - except GTTypeError as err: + except ValueError as err: raise ValueError(f"Invalid argument types in call to `{self.past_node.id}`!") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: @@ -383,7 +383,7 @@ def _column_axis(self): f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() ] - raise GTTypeError( + raise TypeError( "Only `ScanOperator`s defined on the same axis " + "can be used in a `Program`, but found:\n" + "\n".join(scanops_per_axis_strs) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 14200aa3ec..d53215f420 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -16,7 +16,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next.common import DimensionKind, GTTypeError +from gt4py.next.common import DimensionKind from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, @@ -52,7 +52,7 @@ def boolified_type(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.FieldType: return scalar_bool elif type_class is ts.FieldType: return ts.FieldType(dtype=scalar_bool, dims=type_info.extract_dims(symbol_type)) - raise GTTypeError(f"Can not boolify type {symbol_type}!") + raise ValueError(f"Can not boolify type {symbol_type}!") def construct_tuple_type( @@ -527,7 +527,7 @@ def _deduce_compare_type( # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError(node.location, f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", @@ -569,7 +569,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError(node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -644,7 +644,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: with_kwargs=kwarg_types, raise_exception=True, ) - except GTTypeError as err: + except ValueError as err: raise CompilationError(node.location, f"Invalid argument types in call to `{new_func}`!" ) from err @@ -723,7 +723,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: return_type = type_info.promote( *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError(node.location, error_msg_preamble ) from ex else: @@ -853,7 +853,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) return_type = promote_to_mask_type(mask_type, promoted_type) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError(node.location, f"Incompatible argument in call to `{str(node.func)}`.", ) from ex @@ -900,7 +900,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) - except GTTypeError as e: + except ValueError as e: raise CompilationError(node.location, "Could not deduce type of constant." ) from e return foast.Constant(value=node.value, location=node.location, type=type_) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 6bda016826..e8cba554d3 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -104,7 +104,8 @@ def _postprocess_dialect_ast( # TODO(tehrengruber): use `type_info.return_type` when the type of the # arguments becomes available here if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - raise common.GTTypeError( + raise CompilationError( + foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented f", but got `{annotated_return_type}`." ) @@ -472,7 +473,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: loc = self.get_location(node) try: type_ = type_translation.from_value(node.value) - except common.GTTypeError as e: + except ValueError as e: raise CompilationError(loc, f"constants of type {type(node.value)} are not permitted") from None return foast.Constant( diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 3fc5bce3bc..bc4fb66ed1 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,6 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next.common import GTTypeError from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -34,7 +33,7 @@ def _ensure_no_sliced_field(entry: past.Expr): For example, if argument is of type past.Subscript, this function will throw an error as both slicing and domain are being applied """ if not isinstance(entry, past.Name) and not isinstance(entry, past.TupleExpr): - raise GTTypeError("Either only domain or slicing allowed") + raise ValueError("Either only domain or slicing allowed") elif isinstance(entry, past.TupleExpr): for param in entry.elts: _ensure_no_sliced_field(param) @@ -57,39 +56,39 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): new_func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): - raise GTTypeError( + raise ValueError( f"Only calls `FieldOperator`s and `ScanOperator`s " f"allowed in `Program`, but got `{new_func.type}`." ) if "out" not in new_kwargs: - raise GTTypeError("Missing required keyword argument(s) `out`.") + raise ValueError("Missing required keyword argument(s) `out`.") if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] if not isinstance(domain_kwarg, past.Dict): - raise GTTypeError( + raise ValueError( f"Only Dictionaries allowed in domain, but got `{type(domain_kwarg)}`." ) if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: - raise GTTypeError("Empty domain not allowed.") + raise ValueError("Empty domain not allowed.") for dim in domain_kwarg.keys_: if not isinstance(dim.type, ts.DimensionType): - raise GTTypeError( + raise ValueError( f"Only Dimension allowed in domain dictionary keys, but got `{dim}` which is of type `{dim.type}`." ) for domain_values in domain_kwarg.values_: if len(domain_values.elts) != 2: - raise GTTypeError( + raise ValueError( f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." ) if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( domain_values.elts[1] ): - raise GTTypeError( + raise ValueError( f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." ) @@ -166,7 +165,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" @@ -211,14 +210,14 @@ def visit_Call(self, node: past.Call, **kwargs): new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) if operator_return_type != new_kwargs["out"].type: - raise GTTypeError( + raise ValueError( f"Expected keyword argument `out` to be of " f"type {operator_return_type}, but got " f"{new_kwargs['out'].type}." ) elif new_func.id in ["minimum", "maximum"]: if new_args[0].type != new_args[1].type: - raise GTTypeError( + raise ValueError( f"First and second argument in {new_func.id} must be the same type." f"Got `{new_args[0].type}` and `{new_args[1].type}`." ) @@ -228,7 +227,7 @@ def visit_Call(self, node: past.Call, **kwargs): "Only calls `FieldOperator`s, `ScanOperator`s or minimum and maximum builtins allowed" ) - except GTTypeError as ex: + except ValueError as ex: raise CompilationError( node.location, f"Invalid call to `{node.func.id}`." ) from ex diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 7c9ee5e6a0..92b0831903 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -17,7 +17,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.next.common import Dimension, DimensionKind, GridType, GTTypeError +from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_info, type_specifications as ts @@ -37,7 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise GTTypeError( + raise ValueError( "Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed." ) @@ -183,7 +183,7 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: args=[self._construct_itir_out_arg(el) for el in node.elts], ) else: - raise GTTypeError( + raise ValueError( "Unexpected `out` argument. Must be a `past.Name`, `past.Subscript`" " or a `past.TupleExpr` thereof." ) @@ -227,7 +227,7 @@ def _construct_itir_domain_arg( ) if dim.kind == DimensionKind.LOCAL: - raise GTTypeError(f"Dimension {dim.value} must not be local.") + raise ValueError(f"Dimension {dim.value} must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -253,7 +253,7 @@ def _construct_itir_initialized_domain_arg( assert len(node_domain.values_[dim_i].elts) == 2 keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: - raise GTTypeError( + raise ValueError( f"Dimensions in out field and field domain are not equivalent" f"Expected {dim}, but got {keys_dims_types} " ) @@ -277,7 +277,7 @@ def _compute_field_slice(node: past.Subscript): node_dims_ls = cast(ts.FieldType, node.type).dims assert isinstance(node_dims_ls, list) if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls): - raise GTTypeError( + raise ValueError( f"Too many indices for field {out_field_name}: field is {len(node_dims_ls)}" f"-dimensional, but {len(out_field_slice_)} were indexed." ) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 72955962ac..dbcd496f6e 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -17,7 +17,7 @@ import gt4py.next.ffront.type_specifications as ts_ffront import gt4py.next.type_system.type_specifications as ts -from gt4py.next.common import Dimension, GTTypeError +from gt4py.next.common import Dimension from gt4py.next.type_system import type_info @@ -49,7 +49,7 @@ def _as_field(arg: ts.TypeSpec, path: tuple): if type_info.extract_dtype(el_def_type) == type_info.extract_dtype(arg): return el_def_type else: - raise GTTypeError(f"{arg} is not compatible with {el_def_type}.") + raise ValueError(f"{arg} is not compatible with {el_def_type}.") return arg new_args.append( @@ -99,7 +99,7 @@ def function_signature_incompatibilities_scanop( ] try: type_info.promote_dims(*arg_dims) - except GTTypeError as e: + except ValueError as e: yield e.args[0] if len(args) != len(scanop_type.definition.args) - 1: diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 7e1d7cf91f..15fab5e3f5 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -18,7 +18,7 @@ import numpy as np from gt4py.eve.utils import XIterable, xiter -from gt4py.next.common import Dimension, DimensionKind, GTTypeError +from gt4py.next.common import Dimension, DimensionKind from gt4py.next.type_system import type_specifications as ts @@ -49,13 +49,13 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: match symbol_type: case ts.DeferredType(constraint): if constraint is None: - raise GTTypeError(f"No type information available for {symbol_type}!") + raise ValueError(f"No type information available for {symbol_type}!") elif isinstance(constraint, tuple): - raise GTTypeError(f"Not sufficient type information available for {symbol_type}!") + raise ValueError(f"Not sufficient type information available for {symbol_type}!") return constraint case ts.TypeSpec() as concrete_type: return concrete_type.__class__ - raise GTTypeError( + raise ValueError( f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!" ) @@ -139,7 +139,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: return dtype case ts.ScalarType() as dtype: return dtype - raise GTTypeError(f"Can not unambiguosly extract data type from {symbol_type}!") + raise ValueError(f"Can not unambiguosly extract data type from {symbol_type}!") def is_floating_point(symbol_type: ts.TypeSpec) -> bool: @@ -299,7 +299,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[Dimension]: return [] case ts.FieldType(dims): return dims - raise GTTypeError(f"Can not extract dimensions from {symbol_type}!") + raise ValueError(f"Can not extract dimensions from {symbol_type}!") def is_local_field(type_: ts.FieldType) -> bool: @@ -401,11 +401,11 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp ... ) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. """ if all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): - raise GTTypeError("Could not promote scalars of different dtype (not implemented).") + raise ValueError("Could not promote scalars of different dtype (not implemented).") if not all(type_.shape is None for type_ in types): # type: ignore[union-attr] raise NotImplementedError("Shape promotion not implemented.") return types[0] @@ -435,11 +435,11 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. """ # build a graph with the vertices being dimensions and edges representing # the order between two dimensions. The graph is encoded as a dictionary @@ -472,7 +472,7 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: # TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]: if len(zero_in_degree_vertex_list) != 1: - raise GTTypeError( + raise ValueError( f"Dimensions can not be promoted. Could not determine " f"order of the following dimensions: " f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}." @@ -485,7 +485,7 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: in_degree[predecessor] -= 1 if len(in_degree.items()) > 0: - raise GTTypeError( + raise ValueError( f"Dimensions can not be promoted. The following dimensions " f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}." ) @@ -524,11 +524,11 @@ def return_type_field( ): try: accepts_args(field_type, with_args=with_args, with_kwargs=with_kwargs, raise_exception=True) - except GTTypeError as ex: - raise GTTypeError("Could not deduce return type of invalid remap operation.") from ex + except ValueError as ex: + raise ValueError("Could not deduce return type of invalid remap operation.") from ex if not isinstance(with_args[0], ts.OffsetType): - raise GTTypeError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") + raise ValueError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") source_dim = with_args[0].source target_dims = with_args[0].target @@ -617,7 +617,7 @@ def accepts_args( """ Check if a function can be called for given arguments. - If ``raise_exception`` is given a :class:`GTTypeError` is raised with a + If ``raise_exception`` is given a :class:`ValueError` is raised with a detailed description of why the function is not callable. Note that all types must be concrete/complete. @@ -636,14 +636,14 @@ def accepts_args( """ if not isinstance(callable_type, ts.CallableType): if raise_exception: - raise GTTypeError(f"Expected a callable type, but got `{callable_type}`.") + raise ValueError(f"Expected a callable type, but got `{callable_type}`.") return False errors = function_signature_incompatibilities(callable_type, with_args, with_kwargs) if raise_exception: error_list = list(errors) if len(error_list) > 0: - raise GTTypeError( + raise ValueError( f"Invalid call to function of type `{callable_type}`:\n" + ("\n".join([f" - {error}" for error in error_list])) ) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 7ffa8795d4..1ae0ea18b6 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -37,7 +37,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: try: dt = np.dtype(dtype) except TypeError as err: - raise common.GTTypeError(f"Invalid scalar type definition ({dtype})") from err + raise ValueError(f"Invalid scalar type definition ({dtype})") from err if dt.shape == () and dt.fields is None: match dt: @@ -54,9 +54,9 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: case np.str_: return ts.ScalarKind.STRING case _: - raise common.GTTypeError(f"Impossible to map '{dtype}' value to a ScalarKind") + raise ValueError(f"Impossible to map '{dtype}' value to a ScalarKind") else: - raise common.GTTypeError(f"Non-trivial dtypes like '{dtype}' are not yet supported") + raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported") def from_type_hint( @@ -75,7 +75,7 @@ def from_type_hint( try: type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) except Exception as error: - raise TypingError( + raise ValueError( f"Type annotation ({type_hint}) has undefined forward references!" ) from error @@ -98,50 +98,50 @@ def from_type_hint( case builtins.tuple: if not args: - raise TypingError(f"Tuple annotation ({type_hint}) requires at least one argument!") + raise ValueError(f"Tuple annotation ({type_hint}) requires at least one argument!") if Ellipsis in args: - raise TypingError(f"Unbound tuples ({type_hint}) are not allowed!") + raise ValueError(f"Unbound tuples ({type_hint}) are not allowed!") return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) case common.Field: if (n_args := len(args)) != 2: - raise TypingError(f"Field type requires two arguments, got {n_args}! ({type_hint})") + raise ValueError(f"Field type requires two arguments, got {n_args}! ({type_hint})") dims: Union[Ellipsis, list[common.Dimension]] = [] dim_arg, dtype_arg = args if isinstance(dim_arg, list): for d in dim_arg: if not isinstance(d, common.Dimension): - raise TypingError(f"Invalid field dimension definition '{d}'") + raise ValueError(f"Invalid field dimension definition '{d}'") dims.append(d) elif dim_arg is Ellipsis: dims = dim_arg else: - raise TypingError(f"Invalid field dimensions '{dim_arg}'") + raise ValueError(f"Invalid field dimensions '{dim_arg}'") try: dtype = recursive_make_symbol(dtype_arg) - except TypingError as error: - raise TypingError( + except ValueError as error: + raise ValueError( f"Field dtype argument must be a scalar type (got '{dtype_arg}')!" ) from error if not isinstance(dtype, ts.ScalarType) or dtype.kind == ts.ScalarKind.STRING: - raise TypingError("Field dtype argument must be a scalar type (got '{dtype}')!") + raise ValueError("Field dtype argument must be a scalar type (got '{dtype}')!") return ts.FieldType(dims=dims, dtype=dtype) case collections.abc.Callable: if not args: - raise TypingError("Not annotated functions are not supported!") + raise ValueError("Not annotated functions are not supported!") try: arg_types, return_type = args args = [recursive_make_symbol(arg) for arg in arg_types] except Exception as error: - raise TypingError(f"Invalid callable annotations in {type_hint}") from error + raise ValueError(f"Invalid callable annotations in {type_hint}") from error kwargs_info = [arg for arg in extra_args if isinstance(arg, xtyping.CallableKwargsInfo)] if len(kwargs_info) != 1: - raise TypingError(f"Invalid callable annotations in {type_hint}") + raise ValueError(f"Invalid callable annotations in {type_hint}") kwargs = { arg: recursive_make_symbol(arg_type) for arg, arg_type in kwargs_info[0].data.items() @@ -152,7 +152,7 @@ def from_type_hint( args=args, kwargs=kwargs, returns=recursive_make_symbol(return_type) ) - raise TypingError(f"'{type_hint}' type is not supported") + raise ValueError(f"'{type_hint}' type is not supported") def from_value(value: Any) -> ts.TypeSpec: @@ -184,17 +184,4 @@ def from_value(value: Any) -> ts.TypeSpec: if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type else: - raise common.GTTypeError(f"Impossible to map '{value}' value to a Symbol") - - -# TODO(egparedes): Add source location info (maybe subclassing FieldOperatorSyntaxError) -class TypingError(common.GTTypeError): - def __init__( - self, - msg="", - *, - info=None, - ): - msg = f"Invalid type declaration: {msg}" - args = tuple([msg, info] if info else [msg]) - super().__init__(*args) + raise ValueError(f"Impossible to map '{value}' value to a Symbol") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d6e7cb0937..4154b03675 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -19,7 +19,7 @@ import numpy as np import pytest -from gt4py.next.common import Field, GTTypeError +from gt4py.next.common import Field from gt4py.next.errors import * from gt4py.next.ffront.decorator import field_operator, program from gt4py.next.iterator.embedded import np_as_located_field @@ -266,7 +266,7 @@ def empty_domain_program( empty_domain_fieldop(a, out=out_field, domain={JDim: (0, 1), IDim: (0, 1)}) with pytest.raises( - GTTypeError, + ValueError, match=(r"Dimensions in out field and field domain are not equivalent"), ): empty_domain_program(a, out_field, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 5757723768..45a83ba932 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -17,7 +17,7 @@ import pytest import gt4py.next.ffront.type_specifications -from gt4py.next.common import DimensionKind, GTTypeError +from gt4py.next.common import DimensionKind from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.fbuiltins import ( @@ -271,7 +271,7 @@ def test_accept_args( if len(expected) > 0: with pytest.raises( - GTTypeError, + ValueError, ) as exc_info: type_info.accepts_args( func_type, with_args=args, with_kwargs=kwargs, raise_exception=True diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py index 745a1e6cff..eccea4dd73 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py @@ -14,7 +14,7 @@ import pytest -from gt4py.next.common import Dimension, DimensionKind, GridType, GTTypeError +from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront.decorator import _deduce_grid_type from gt4py.next.ffront.fbuiltins import FieldOffset @@ -38,7 +38,7 @@ def test_domain_deduction_unstructured(): def test_domain_complies_with_request_cartesian(): assert _deduce_grid_type(GridType.CARTESIAN, {CartesianOffset}) == GridType.CARTESIAN - with pytest.raises(GTTypeError, match="unstructured.*FieldOffset.*found"): + with pytest.raises(ValueError, match="unstructured.*FieldOffset.*found"): _deduce_grid_type(GridType.CARTESIAN, {UnstructuredOffset}) _deduce_grid_type(GridType.CARTESIAN, {LocalDim}) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index da72a28930..971b6ee08e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -40,7 +40,7 @@ import pytest from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.common import Field, GTTypeError +from gt4py.next.common import Field from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.fbuiltins import ( @@ -75,7 +75,6 @@ xor_, ) from gt4py.next.type_system import type_specifications as ts -from gt4py.next.type_system.type_translation import TypingError from gt4py.next.errors import * @@ -119,7 +118,7 @@ def mistyped(inp: Field): return inp with pytest.raises( - TypingError, + ValueError, match="Field type requires two arguments, got 0!", ): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -318,7 +317,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> Field[[], float]: return 1.0 - with pytest.raises(GTTypeError, match=r"Expected `float.*`"): + with pytest.raises(CompilationError, match=r"Expected `float.*`"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -401,7 +400,7 @@ def wrong_return_type_annotation(a: Field[[ADim], float64]) -> Field[[BDim], flo return a with pytest.raises( - GTTypeError, + CompilationError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -412,7 +411,7 @@ def empty_dims() -> Field[[], float]: return 1.0 with pytest.raises( - GTTypeError, + CompilationError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(empty_dims) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 02d512018d..08041a7ec9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -18,7 +18,7 @@ import gt4py.eve as eve from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.common import Field, GTTypeError +from gt4py.next.common import Field from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.decorator import field_operator from gt4py.next.ffront.fbuiltins import float64 diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 9c9fa73f34..9b3b366bb2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -18,7 +18,7 @@ import gt4py.eve as eve from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.common import Field, GridType, GTTypeError +from gt4py.next.common import Field, GridType from gt4py.next.ffront.decorator import field_operator from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering @@ -154,7 +154,7 @@ def inout_field_program(inout_field: Field[[IDim], "float64"]): identity(inout_field, out=inout_field) with pytest.raises( - GTTypeError, + ValueError, match=(r"Call to function with field as input and output not allowed."), ): ProgramLowering.apply( diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index b0723ee14c..2be95c3f23 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -50,7 +50,7 @@ def test_valid_scalar_kind(value, expected): def test_invalid_scalar_kind(): - with pytest.raises(common.GTTypeError, match="Non-trivial dtypes"): + with pytest.raises(ValueError, match="Non-trivial dtypes"): type_translation.get_scalar_kind(np.dtype("i4, (2,3)f8, f4")) @@ -129,42 +129,42 @@ def test_make_symbol_type_from_typing(value, expected): def test_invalid_symbol_types(): # Forward references - with pytest.raises(type_translation.TypingError, match="undefined forward references"): + with pytest.raises(ValueError, match="undefined forward references"): type_translation.from_type_hint("foo") # Tuples - with pytest.raises(type_translation.TypingError, match="least one argument"): + with pytest.raises(ValueError, match="least one argument"): type_translation.from_type_hint(typing.Tuple) - with pytest.raises(type_translation.TypingError, match="least one argument"): + with pytest.raises(ValueError, match="least one argument"): type_translation.from_type_hint(tuple) - with pytest.raises(type_translation.TypingError, match="Unbound tuples"): + with pytest.raises(ValueError, match="Unbound tuples"): type_translation.from_type_hint(tuple[int, ...]) - with pytest.raises(type_translation.TypingError, match="Unbound tuples"): + with pytest.raises(ValueError, match="Unbound tuples"): type_translation.from_type_hint(typing.Tuple["float", ...]) # Fields - with pytest.raises(type_translation.TypingError, match="Field type requires two arguments"): + with pytest.raises(ValueError, match="Field type requires two arguments"): type_translation.from_type_hint(common.Field) - with pytest.raises(type_translation.TypingError, match="Invalid field dimensions"): + with pytest.raises(ValueError, match="Invalid field dimensions"): type_translation.from_type_hint(common.Field[int, int]) - with pytest.raises(type_translation.TypingError, match="Invalid field dimension"): + with pytest.raises(ValueError, match="Invalid field dimension"): type_translation.from_type_hint(common.Field[[int, int], int]) - with pytest.raises(type_translation.TypingError, match="Field dtype argument"): + with pytest.raises(ValueError, match="Field dtype argument"): type_translation.from_type_hint(common.Field[[IDim], str]) - with pytest.raises(type_translation.TypingError, match="Field dtype argument"): + with pytest.raises(ValueError, match="Field dtype argument"): type_translation.from_type_hint(common.Field[[IDim], None]) # Functions with pytest.raises( - type_translation.TypingError, match="Not annotated functions are not supported" + ValueError, match="Not annotated functions are not supported" ): type_translation.from_type_hint(typing.Callable) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[..., float]) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[[int], str]) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[[int], float]) From 384b62f815100c1ea70f03367e22ab5f55f053a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 15 Jun 2023 17:14:29 +0200 Subject: [PATCH 11/54] improve printing & exception hierarchy --- src/gt4py/next/errors/__init__.py | 12 +- src/gt4py/next/errors/excepthook.py | 14 ++ src/gt4py/next/errors/exceptions.py | 116 +++++++------ src/gt4py/next/errors/formatting.py | 53 ++++-- src/gt4py/next/errors/tools.py | 15 -- src/gt4py/next/ffront/dialect_parser.py | 15 +- .../foast_passes/closure_var_folding.py | 2 +- .../ffront/foast_passes/type_deduction.py | 158 +++++++++--------- src/gt4py/next/ffront/func_to_foast.py | 30 ++-- src/gt4py/next/ffront/func_to_past.py | 10 +- .../next/ffront/past_passes/type_deduction.py | 10 +- tests/next_tests/exception_printing.py | 2 +- .../ffront_tests/test_execution.py | 6 +- .../ffront_tests/test_scalar_if.py | 6 +- .../ffront_tests/test_type_deduction.py | 28 ++-- .../feature_tests/test_util_cases.py | 2 +- .../errors_tests/test_compilation_error.py | 6 +- .../ffront_tests/test_func_to_foast.py | 24 +-- .../test_func_to_foast_error_line_number.py | 83 +++------ .../ffront_tests/test_func_to_past.py | 14 +- .../ffront_tests/test_past_to_itir.py | 2 +- 21 files changed, 306 insertions(+), 302 deletions(-) create mode 100644 src/gt4py/next/errors/excepthook.py delete mode 100644 src/gt4py/next/errors/tools.py diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 778954b1f3..0fc590bca0 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -1,11 +1,11 @@ from .exceptions import ( - CompilationError, + CompilerError, UndefinedSymbolError, UnsupportedPythonFeatureError, - MissingParameterTypeError, - InvalidParameterTypeError, - IncorrectArgumentCountError, - UnexpectedKeywordArgError, + MissingParameterAnnotationError, + InvalidParameterAnnotationError, + ArgumentCountError, + KeywordArgumentError, MissingAttributeError ) -from . import formatting \ No newline at end of file +from . import excepthook \ No newline at end of file diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py new file mode 100644 index 0000000000..6b49a4b3df --- /dev/null +++ b/src/gt4py/next/errors/excepthook.py @@ -0,0 +1,14 @@ +from . import formatting +from . import exceptions +from typing import Callable +import sys + +def compilation_error_hook(fallback: Callable, type_: type, value: exceptions.CompilerError, tb): + if issubclass(type_, exceptions.CompilerError): + print("".join(formatting.format_compilation_error(type_, value.message, value.location_trace)), file=sys.stderr) + else: + fallback(type_, value, tb) + + +_fallback = sys.excepthook +sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) \ No newline at end of file diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 4e93fe3af8..4e89df2173 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -1,75 +1,85 @@ +import textwrap + from gt4py.eve import SourceLocation -from typing import Any, Optional -from . import tools - - -class CompilationError(SyntaxError): - def __init__(self, location: SourceLocation, message: str, *, snippet: str | bool = True): - source_code = None - if isinstance(snippet, str): - source_code = snippet - if snippet is True: - source_code = CompilationError.get_source_from_location(location) - super().__init__( - message, - ( - location.source, - location.line, - location.column, - source_code, - location.end_line, - location.end_column - ) - ) +from typing import Any, Optional, TypeVar +from . import formatting + + +LocationTraceT = TypeVar("LocationTraceT", SourceLocation, list[SourceLocation], None) + + +class CompilerError(Exception): + location_trace: list[SourceLocation] + + def __init__(self, location: LocationTraceT, message: str): + self.location_trace = CompilerError._make_location_trace(location) + super().__init__(message) + + @property + def message(self) -> str: + return self.args[0] @property - def location(self) -> SourceLocation: - return SourceLocation( - source=self.filename, - line=self.lineno, - column=self.offset, - end_line=self.end_lineno, - end_column=self.end_offset - ) + def location(self) -> Optional[SourceLocation]: + return self.location_trace[0] if self.location_trace else None + + def with_location(self, location: LocationTraceT) -> "CompilerError": + self.location_trace = CompilerError._make_location_trace(location) + return self + + def __str__(self): + if self.location: + loc_str = formatting.format_location(self.location, caret=True) + return f"{self.message}\n{textwrap.indent(loc_str, ' ')}" + return self.message @staticmethod - def get_source_from_location(location: SourceLocation) -> Optional[str]: - try: - return tools.get_source_from_location(location) - except ValueError: - return None + def _make_location_trace(location: LocationTraceT) -> list[SourceLocation]: + if isinstance(location, SourceLocation): + return [location] + elif isinstance(location, list): + return location + elif location is None: + return [] + else: + raise TypeError("expected 'SourceLocation', 'list', or 'None' for 'location'") + + +class UnsupportedPythonFeatureError(CompilerError): + def __init__(self, location: LocationTraceT, feature: str): + super().__init__(location, f"unsupported Python syntax: '{feature}'") -class UndefinedSymbolError(CompilationError): - def __init__(self, location: SourceLocation, name: str): +class UndefinedSymbolError(CompilerError): + def __init__(self, location: LocationTraceT, name: str): super().__init__(location, f"name '{name}' is not defined") -class UnsupportedPythonFeatureError(CompilationError): - def __init__(self, location: SourceLocation, feature: str): - super().__init__(location, f"unsupported Python syntax: '{feature}'") +class MissingAttributeError(CompilerError): + def __init__(self, location: LocationTraceT, attr_name: str): + super().__init__(location, f"object does not have attribute '{attr_name}'") -class MissingParameterTypeError(CompilationError): - def __init__(self, location: SourceLocation, param_name: str): +class CompilerTypeError(CompilerError): + def __init__(self, location: LocationTraceT, message: str): + super().__init__(location, message) + + +class MissingParameterAnnotationError(CompilerTypeError): + def __init__(self, location: LocationTraceT, param_name: str): super().__init__(location, f"parameter '{param_name}' is missing type annotations") -class InvalidParameterTypeError(CompilationError): - def __init__(self, location: SourceLocation, param_name: str, type_: Any): +class InvalidParameterAnnotationError(CompilerTypeError): + def __init__(self, location: LocationTraceT, param_name: str, type_: Any): super().__init__(location, f"parameter '{param_name}' has invalid type annotation '{type_}'") -class IncorrectArgumentCountError(CompilationError): - def __init__(self, location: SourceLocation, num_expected: int, num_provided: int): +class ArgumentCountError(CompilerTypeError): + def __init__(self, location: LocationTraceT, num_expected: int, num_provided: int): super().__init__(location, f"expected {num_expected} arguments but {num_provided} were provided") -class UnexpectedKeywordArgError(CompilationError): - def __init__(self, location: SourceLocation, provided_names: str): +class KeywordArgumentError(CompilerTypeError): + def __init__(self, location: LocationTraceT, provided_names: str): super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") - - -class MissingAttributeError(CompilationError): - def __init__(self, location: SourceLocation, attr_name: str): - super().__init__(location, f"object does not have attribute '{attr_name}'") \ No newline at end of file diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index dd9049ab39..0e6fb532bd 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -1,15 +1,48 @@ -import sys -import traceback -from . import exceptions -from typing import Callable +import textwrap +from gt4py.eve import SourceLocation +import pathlib -def compilation_error_hook(fallback: Callable, type_: type, value: exceptions.CompilationError, _): - if issubclass(type_, exceptions.CompilationError): - print("".join(traceback.format_exception(value, limit=0)), file=sys.stderr) +def get_source_from_location(location: SourceLocation): + try: + source_file = pathlib.Path(location.source) + source_code = source_file.read_text() + source_lines = source_code.splitlines(False) + start_line = location.line + end_line = location.end_line + 1 if location.end_line else start_line + 1 + relevant_lines = source_lines[(start_line-1):(end_line-1)] + return "\n".join(relevant_lines) + except Exception as ex: + raise ValueError("failed to get source code for source location") from ex + + +def format_location(loc: SourceLocation, caret: bool = False): + filename = loc.source or "" + lineno = loc.line or "" + loc_str = f"File \"{filename}\", line {lineno}" + + if caret and loc.column is not None: + offset = loc.column - 1 + width = loc.end_column - loc.column if loc.end_column is not None else 1 + caret_str = "".join([" "] * offset + ["^"] * width) else: - fallback(type_, value, traceback) + caret_str = None + + try: + snippet_str = get_source_from_location(loc) + if caret_str: + snippet_str = f"{snippet_str}\n{caret_str}" + return f"{loc_str}\n{textwrap.indent(snippet_str, ' ')}" + except ValueError: + return loc_str + +def format_compilation_error(type_: type[Exception], message: str, location_trace: list[SourceLocation]): + msg_str = f"{type_.__module__}.{type_.__name__}: {message}" -_fallback = sys.excepthook -sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) \ No newline at end of file + try: + loc_str = "".join([format_location(loc) for loc in location_trace]) + stack_str = f"Source location (most recent call last):\n{textwrap.indent(loc_str, ' ')}\n" + return [stack_str, msg_str] + except ValueError: + return [msg_str] diff --git a/src/gt4py/next/errors/tools.py b/src/gt4py/next/errors/tools.py deleted file mode 100644 index 843e3d5437..0000000000 --- a/src/gt4py/next/errors/tools.py +++ /dev/null @@ -1,15 +0,0 @@ -import pathlib -from gt4py.eve import SourceLocation - - -def get_source_from_location(location: SourceLocation): - try: - source_file = pathlib.Path(location.source) - source_code = source_file.read_text() - source_lines = source_code.splitlines(False) - start_line = location.line - end_line = location.end_line + 1 if location.end_line else start_line + 1 - relevant_lines = source_lines[(start_line-1):(end_line-1)] - return "\n".join(relevant_lines) - except Exception as ex: - raise ValueError("failed to get source code for source location") from ex \ No newline at end of file diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 72ff372d10..72861ee4ee 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -21,6 +21,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, ClassVar, Generic, Optional, Type, TypeVar from gt4py.next import common +from gt4py.next.errors import * from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -34,12 +35,14 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: try: return ast.parse(textwrap.dedent(source_definition.source)).body[0] except SyntaxError as err: - err.filename = source_definition.filename - err.lineno = err.lineno + source_definition.line_offset if err.lineno is not None else None - err.offset = err.offset + source_definition.column_offset if err.offset is not None else None - err.end_lineno = err.end_lineno + source_definition.line_offset if err.end_lineno is not None else None - err.end_offset = err.end_offset + source_definition.column_offset if err.end_offset is not None else None - raise err + loc = SourceLocation( + line=err.lineno + source_definition.line_offset if err.lineno is not None else None, + column=err.offset + source_definition.column_offset if err.offset is not None else None, + source=source_definition.filename, + end_line=err.end_lineno + source_definition.line_offset if err.end_lineno is not None else None, + end_column=err.end_offset + source_definition.column_offset if err.end_offset is not None else None + ) + raise CompilerError(loc, err.msg).with_traceback(err.__traceback__) @dataclass(frozen=True, kw_only=True) diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 45a93d74a0..3504864f78 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -56,7 +56,7 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise MissingAttributeError(node.location, node.attr) - raise CompilationError(node.location, "attribute access only applicable to constants") + raise CompilerError(node.location, "attribute access only applicable to constants") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d53215f420..fb18d53919 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -145,10 +145,10 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise CompilationError(stmt.location, + raise CompilerError(stmt.location, f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", - ) + ) return_type = return_types[0] or return_types[1] elif isinstance(stmt, foast.BlockStmt): # just forward to nested BlockStmt @@ -161,10 +161,10 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise CompilationError(stmt.location, + raise CompilerError(stmt.location, f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", - ) + ) if is_unconditional_return: # found a statement that always returns assert return_type @@ -246,9 +246,9 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", - ) + ) new_type = ts.FunctionType( args=[new_param.type for new_param in new_params], kwargs={}, returns=return_type ) @@ -274,26 +274,26 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Argument `axis` to scan operator `{node.id}` must be a dimension.", - ) + ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", - ) + ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise CompilationError(node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." - ) + raise CompilerError(node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." + ) new_init = self.visit(node.init, **kwargs) if not all( type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", - ) + ) new_definition = self.visit(node.definition, **kwargs) new_type = ts_ffront.ScanOperatorType( axis=new_axis.type.dim, @@ -312,8 +312,8 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilationError(node.location, f"Undeclared symbol `{node.id}`." - ) + raise CompilerError(node.location, f"Undeclared symbol `{node.id}`." + ) symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -337,8 +337,8 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise CompilationError(node.location, f"Too many values to unpack (expected {len(indices)})." - ) + raise CompilerError(node.location, f"Too many values to unpack (expected {len(indices)})." + ) new_targets: TargetType = [] new_type: ts.TupleType | ts.DataType @@ -368,8 +368,8 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise CompilationError(node.location, f"Assignment value must be of type tuple! Got: {values.type}" - ) + raise CompilerError(node.location, f"Assignment value must be of type tuple! Got: {values.type}" + ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -386,25 +386,25 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise CompilationError(node.location, + raise CompilerError(node.location, "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", - ) + ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise CompilationError(node.location, + raise CompilerError(node.location, "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", - ) + ) for sym in node.annex.propagated_symbols.keys(): if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", - ) + ) # TODO: properly patch symtable (new node?) symtable[sym].type = new_node.annex.propagated_symbols[ sym @@ -421,12 +421,12 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise CompilationError(node.location, - ( + raise CompilerError(node.location, + ( "type inconsistency: expression was deduced to be " f"of type {refine_type}, instead of the expected type {node.type}" ), - ) + ) new_node: foast.Symbol = foast.Symbol( id=node.id, type=refine_type, location=node.location ) @@ -442,19 +442,19 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise CompilationError(new_value.location, "Second dimension in offset must be a local dimension.") + raise CompilerError(new_value.location, "Second dimension in offset must be a local dimension.") new_type = ts.OffsetType(source=source, target=(target1,)) case ts.OffsetType(source=source, target=(target,)): # for cartesian axes (e.g. I, J) the index of the subscript only # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise CompilationError(new_value.location, + raise CompilerError(new_value.location, "Source and target must be equal for offsets with a single target.", - ) + ) new_type = new_value.type case _: - raise CompilationError( + raise CompilerError( new_value.location, "Could not deduce type of subscript expression!" ) @@ -494,14 +494,14 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise CompilationError(condition.location, + raise CompilerError(condition.location, f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", - ) + ) if true_expr.type != false_expr.type: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", - ) + ) return true_expr.type def visit_Compare(self, node: foast.Compare, **kwargs) -> foast.Compare: @@ -518,8 +518,8 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise CompilationError(arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" - ) + raise CompilerError(arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" + ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -528,10 +528,10 @@ def _deduce_compare_type( # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) except ValueError as ex: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", - ) from ex + ) from ex def _deduce_binop_type( self, @@ -551,8 +551,8 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise CompilationError(arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" - ) + raise CompilerError(arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) right_type = cast(ts.FieldType | ts.ScalarType, right.type) @@ -563,26 +563,26 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise CompilationError(arg.location, + raise CompilerError(arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", - ) + ) try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", - ) from ex + ) from ex def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", - ) + ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: new_operand = self.visit(node.operand, **kwargs) @@ -596,9 +596,9 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", - ) + ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type ) @@ -627,14 +627,14 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise CompilationError(node.location, "Functions can only be called directly!" - ) + raise CompilerError(node.location, "Functions can only be called directly!" + ) elif isinstance(new_func.type, ts.FieldType): pass else: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", - ) + ) # ensure signature is valid try: @@ -645,8 +645,8 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: raise_exception=True, ) except ValueError as err: - raise CompilationError(node.location, f"Invalid argument types in call to `{new_func}`!" - ) from err + raise CompilerError(node.location, f"Invalid argument types in call to `{new_func}`!" + ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -702,9 +702,9 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise CompilationError(node.location, + raise CompilerError(node.location, "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), - ) + ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): print(f"Warning: return type of {func_name} might be inconsistent (not implemented).") @@ -724,8 +724,8 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) except ValueError as ex: - raise CompilationError(node.location, error_msg_preamble - ) from ex + raise CompilerError(node.location, error_msg_preamble + ) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -744,11 +744,11 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " f"{field_dims_str}.", - ) + ) return_type = ts.FieldType( dims=[dim for dim in field_type.dims if dim != reduction_dim], dtype=field_type.dtype, @@ -799,18 +799,18 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" f"{node.location}", - ) + ) if arg_0.source not in arg_1.dims: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " f"{node.location}", - ) + ) return foast.Call( func=node.func, @@ -826,10 +826,10 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype bool, but got `{mask_type}`.", - ) + ) try: if isinstance(true_branch_type, ts.TupleType) and isinstance( @@ -843,10 +843,10 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", - ) + ) else: true_branch_fieldtype = cast(ts.FieldType, true_branch_type) false_branch_fieldtype = cast(ts.FieldType, false_branch_type) @@ -854,9 +854,9 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: return_type = promote_to_mask_type(mask_type, promoted_type) except ValueError as ex: - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible argument in call to `{str(node.func)}`.", - ) from ex + ) from ex return foast.Call( func=node.func, @@ -871,18 +871,18 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", - ) + ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise CompilationError(node.location, + raise CompilerError(node.location, f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", - ) + ) return_type = ts.FieldType( dims=broadcast_dims, @@ -901,6 +901,6 @@ def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError as e: - raise CompilationError(node.location, "Could not deduce type of constant." - ) from e + raise CompilerError(node.location, "Could not deduce type of constant." + ) from e return foast.Constant(value=node.value, location=node.location, type=type_) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index e8cba554d3..63c2c4aaed 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -72,7 +72,7 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) - ... except CompilationError as err: + ... except CompilerError as err: ... print(f"Error at [{err.lineno}, {err.offset}] in {err.filename})") Error at [2, 5] in ...gt4py.next.ffront.func_to_foast.FieldOperatorParser[...]>) """ @@ -104,7 +104,7 @@ def _postprocess_dialect_ast( # TODO(tehrengruber): use `type_info.return_type` when the type of the # arguments becomes available here if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - raise CompilationError( + raise CompilerError( foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented f", but got `{annotated_return_type}`." @@ -165,7 +165,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise CompilationError(loc, "function is expected to return a value, return statement not found") + raise CompilerError(loc, "function is expected to return a value, return statement not found") return foast.FunctionDefinition( id=node.name, @@ -181,10 +181,10 @@ def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: def visit_arg(self, node: ast.arg) -> foast.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise MissingParameterTypeError(loc, node.arg) + raise MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise InvalidParameterTypeError(loc, node.arg, new_type) + raise InvalidParameterAnnotationError(loc, node.arg, new_type) return foast.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.TupleTargetAssign: @@ -222,7 +222,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise CompilationError(self.get_location(node), "can only assign to names") + raise CompilerError(self.get_location(node), "can only assign to names") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -244,7 +244,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise CompilationError(self.get_location(node), "can only assign to names") + raise CompilerError(self.get_location(node), "can only assign to names") if node.annotation is not None: assert isinstance( @@ -285,7 +285,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise CompilationError(self.get_location(node.slice), "expected an integral index") from None + raise CompilerError(self.get_location(node.slice), "expected an integral index") from None return foast.Subscript( value=self.visit(node.value), @@ -306,7 +306,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise CompilationError(loc, "must return a value, not None") + raise CompilerError(loc, "must return a value, not None") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -435,19 +435,19 @@ def _verify_builtin_function(self, node: ast.Call): func_name = self._func_name(node) func_info = getattr(fbuiltins, func_name).__gt_type__() if not len(node.args) == len(func_info.args): - raise IncorrectArgumentCountError(loc, len(func_info.args), len(node.args)) + raise ArgumentCountError(loc, len(func_info.args), len(node.args)) elif unexpected_kwargs := set(k.arg for k in node.keywords) - set(func_info.kwargs): - raise UnexpectedKeywordArgError(loc, ", ".join(unexpected_kwargs)) + raise KeywordArgumentError(loc, ", ".join(unexpected_kwargs)) def _verify_builtin_type_constructor(self, node: ast.Call): loc = self.get_location(node) if not len(node.args) == 1: - raise IncorrectArgumentCountError(loc, 1, len(node.args)) + raise ArgumentCountError(loc, 1, len(node.args)) elif node.keywords: unexpected_kwargs = set(k.arg for k in node.keywords) - raise UnexpectedKeywordArgError(loc, ", ".join(unexpected_kwargs)) + raise KeywordArgumentError(loc, ", ".join(unexpected_kwargs)) elif not isinstance(node.args[0], ast.Constant): - raise CompilationError(self.get_location(node.args[0]), "expected a literal expression") + raise CompilerError(self.get_location(node.args[0]), "expected a literal expression") def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. @@ -474,7 +474,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError as e: - raise CompilationError(loc, f"constants of type {type(node.value)} are not permitted") from None + raise CompilerError(loc, f"constants of type {type(node.value)} are not permitted") from None return foast.Constant( value=node.value, diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 85276fd3ff..2f8e16b84b 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -24,7 +24,7 @@ type_specifications as ts_ffront, ) from gt4py.next.ffront.dialect_parser import DialectParser -from gt4py.next.errors import CompilationError, MissingParameterTypeError, InvalidParameterTypeError +from gt4py.next.errors import CompilerError, MissingParameterAnnotationError, InvalidParameterAnnotationError from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.type_system import type_specifications as ts, type_translation @@ -67,10 +67,10 @@ def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: def visit_arg(self, node: ast.arg) -> past.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise MissingParameterTypeError(loc, node.arg) + raise MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise InvalidParameterTypeError(loc, node.arg, new_type) + raise InvalidParameterAnnotationError(loc, node.arg, new_type) return past.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Expr(self, node: ast.Expr) -> past.LocatedNode: @@ -128,7 +128,7 @@ def visit_Call(self, node: ast.Call) -> past.Call: loc = self.get_location(node) new_func = self.visit(node.func) if not isinstance(new_func, past.Name): - raise CompilationError(loc, "functions must be referenced by their name in function calls") + raise CompilerError(loc, "functions must be referenced by their name in function calls") return past.Call( func=new_func, @@ -166,7 +166,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: return past.Constant( value=-node.operand.value, type=symbol_type, location=loc ) - raise CompilationError(loc, "unary operators are only applicable to literals") + raise CompilerError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index bc4fb66ed1..68880db29e 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -145,7 +145,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): - raise CompilationError( + raise CompilerError( arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) @@ -158,7 +158,7 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise CompilationError( + raise CompilerError( arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) @@ -166,7 +166,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise CompilationError( + raise CompilerError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -228,7 +228,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise CompilationError( + raise CompilerError( node.location, f"Invalid call to `{node.func.id}`." ) from ex @@ -243,7 +243,7 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilationError( + raise CompilerError( node.location, f"Undeclared or untyped symbol `{node.id}`." ) diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py index 94c611de95..b1e07cac0e 100644 --- a/tests/next_tests/exception_printing.py +++ b/tests/next_tests/exception_printing.py @@ -5,4 +5,4 @@ frameinfo = inspect.getframeinfo(inspect.currentframe()) loc = SourceLocation(frameinfo.lineno, 1, frameinfo.filename, end_line=frameinfo.lineno, end_column=5) -raise CompilationError(loc, "this is an error message") \ No newline at end of file +raise CompilerError(loc, "this is an error message") \ No newline at end of file diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 63fd1b1765..32c40b9532 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -875,7 +875,7 @@ def fieldop_where_k_offset( def test_undefined_symbols(): - with pytest.raises(CompilationError, match="Undeclared symbol"): + with pytest.raises(CompilerError, match="Undeclared symbol"): @field_operator def return_undefined(): @@ -982,7 +982,7 @@ def unpack( def test_tuple_unpacking_too_many_values(fieldview_backend): with pytest.raises( - CompilationError, + CompilerError, match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), ): @@ -994,7 +994,7 @@ def _star_unpack() -> tuple[int, float64, int]: def test_tuple_unpacking_too_many_values(fieldview_backend): with pytest.raises( - CompilationError, match=(r"Assignment value must be of type tuple!") + CompilerError, match=(r"Assignment value must be of type tuple!") ): @field_operator(backend=fieldview_backend) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 174a34556a..1b5b021c89 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -362,7 +362,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(CompilationError, match="Condition for `if` must be scalar."): + with pytest.raises(CompilerError, match="Condition for `if` must be scalar."): @field_operator def if_non_scalar_condition( @@ -376,7 +376,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): with pytest.raises( - CompilationError, match="Condition for `if` must be of boolean type." + CompilerError, match="Condition for `if` must be of boolean type." ): @field_operator @@ -392,7 +392,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - CompilationError, + CompilerError, match="Inconsistent types between two branches for variable", ): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 45a83ba932..e2fa4d9b9f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -386,7 +386,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - CompilationError, + CompilerError, match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -401,7 +401,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b with pytest.raises( - CompilationError, + CompilerError, match=( r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." ), @@ -414,8 +414,8 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( - CompilationError, - match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`! "), + CompilerError, + match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -425,7 +425,7 @@ def sign_bool(a: Field[[TDim], bool]): return -a with pytest.raises( - CompilationError, + CompilerError, match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -436,7 +436,7 @@ def not_int(a: Field[[TDim], int64]): return not a with pytest.raises( - CompilationError, + CompilerError, match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -508,7 +508,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: return float32("1.0") + float64("1.0") with pytest.raises( - CompilationError, + CompilerError, match=(r"Could not promote `float32` and `float64` to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -538,7 +538,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - CompilationError, + CompilerError, match=r"Expected broadcast dimension is missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -553,7 +553,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - CompilationError, + CompilerError, match=r"Expected all broadcast dimensions to be of type Dimension.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -619,7 +619,7 @@ def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) with pytest.raises( - CompilationError, + CompilerError, match=r"Return arguments need to be of same type", ): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -674,7 +674,7 @@ def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 with pytest.raises( - CompilationError, + CompilerError, match=r"Type float64 can not be used in operator `%`", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -684,7 +684,7 @@ def test_undefined_symbols(): def return_undefined(): return undefined_symbol - with pytest.raises(CompilationError, match="Undeclared symbol"): + with pytest.raises(CompilerError, match="Undeclared symbol"): _ = FieldOperatorParser.apply_to_function(return_undefined) @@ -697,7 +697,7 @@ def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) with pytest.raises( - CompilationError, + CompilerError, match=f"not in list of offset field dimensions", ): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -712,7 +712,7 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) with pytest.raises( - CompilationError, + CompilerError, match=f"Excepted integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 8cdbe02c5e..43ffcc318c 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -89,7 +89,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures b = cases.allocate(cartesian_case, addition, "b")() out = cases.allocate(cartesian_case, addition, cases.RETURN)() - with pytest.raises(CompilationError): + with pytest.raises(CompilerError): cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py index e4784277bf..ea75165a2b 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -1,4 +1,4 @@ -from gt4py.next.errors import CompilationError +from gt4py.next.errors import CompilerError from gt4py.eve import SourceLocation @@ -7,10 +7,10 @@ def test_message(): - assert CompilationError(loc, msg).msg == msg + assert CompilerError(loc, msg).message == msg def test_location(): - assert CompilationError(loc, msg).location == loc + assert CompilerError(loc, msg).location_trace[0] == loc diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 971b6ee08e..da4a182387 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -106,7 +106,7 @@ def untyped(inp): return inp with pytest.raises( - MissingParameterTypeError + MissingParameterAnnotationError ): _ = FieldOperatorParser.apply_to_function(untyped) @@ -145,7 +145,7 @@ def no_return(inp: Field[[TDim], "float64"]): tmp = inp # noqa with pytest.raises( - CompilationError, + CompilerError, match=".*return.*", ): _ = FieldOperatorParser.apply_to_function(no_return) @@ -159,7 +159,7 @@ def invalid_assign_to_expr(inp1: Field[[TDim], "float64"], inp2: Field[[TDim], " tmp[-1] = inp2 return tmp - with pytest.raises(CompilationError, match=r".*assign.*"): + with pytest.raises(CompilerError, match=r".*assign.*"): _ = FieldOperatorParser.apply_to_function(invalid_assign_to_expr) @@ -185,7 +185,7 @@ def clashing(inp: Field[[TDim], "float64"]): tmp: Field[[TDim], "int64"] = inp return tmp - with pytest.raises(CompilationError, match="type inconsistency"): + with pytest.raises(CompilerError, match="type inconsistency"): _ = FieldOperatorParser.apply_to_function(clashing) @@ -264,7 +264,7 @@ def cast_scalar_temp(): tmp = int64(1) return int32(tmp) - with pytest.raises(CompilationError, match=r".*literal.*"): + with pytest.raises(CompilerError, match=r".*literal.*"): _ = FieldOperatorParser.apply_to_function(cast_scalar_temp) @@ -275,7 +275,7 @@ def conditional_wrong_mask_type( return where(a, a, a) msg = r"Expected a field with dtype bool." - with pytest.raises(CompilationError, match=msg): + with pytest.raises(CompilerError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -288,7 +288,7 @@ def conditional_wrong_arg_type( return where(mask, a, b) msg = r"Could not promote scalars of different dtype \(not implemented\)." - with pytest.raises(CompilationError) as exc_info: + with pytest.raises(CompilerError) as exc_info: _ = FieldOperatorParser.apply_to_function(conditional_wrong_arg_type) assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -298,7 +298,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: Field[[], bool]): return 1 if cond else 2 - with pytest.raises(CompilationError, match=r"should be .* `bool`"): + with pytest.raises(CompilerError, match=r"should be .* `bool`"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -317,7 +317,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> Field[[], float]: return 1.0 - with pytest.raises(CompilationError, match=r"Expected `float.*`"): + with pytest.raises(CompilerError, match=r"Expected `float.*`"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -400,7 +400,7 @@ def wrong_return_type_annotation(a: Field[[ADim], float64]) -> Field[[BDim], flo return a with pytest.raises( - CompilationError, + CompilerError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -411,7 +411,7 @@ def empty_dims() -> Field[[], float]: return 1.0 with pytest.raises( - CompilationError, + CompilerError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(empty_dims) @@ -426,7 +426,7 @@ def zero_dims_ternary( return a if cond == 1 else b msg = r"Incompatible datatypes in operator `==`" - with pytest.raises(CompilationError) as exc_info: + with pytest.raises(CompilerError) as exc_info: _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index f4e8f4d82a..4ca7b405bc 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -39,58 +39,19 @@ def wrong_syntax(inp: common.Field[[TDim], float]): return # <-- this line triggers the syntax error with pytest.raises( - f2f.CompilationError, + f2f.CompilerError, match=( r".*return.*" ), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) - assert traceback.format_exception_only(exc_info.value)[1:3] == [ - " return # <-- this line triggers the syntax error\n", - " ^^^^^^\n", - ] - - -def test_wrong_caret_placement_bug(): - """Field operator syntax errors respect python's carets (`^^^^^`) placement.""" - - line = inspect.getframeinfo(inspect.currentframe()).lineno - - def wrong_line_syntax_error(inp: common.Field[[TDim], float]): - # the next line triggers the syntax error - inp = inp.this_attribute_surely_doesnt_exist - - return inp - - with pytest.raises(f2f.CompilationError) as exc_info: - _ = f2f.FieldOperatorParser.apply_to_function(wrong_line_syntax_error) - - exc = exc_info.value - - assert (exc.lineno, exc.end_lineno) == (line + 4, line + 4) - - # if `offset` is set, python will display carets (`^^^^`) after printing `text`. - # So `text` has to be the line where the error occurs (otherwise the carets - # will be very misleading). - - # See https://github.com/python/cpython/blob/6ad47b41a650a13b4a9214309c10239726331eb8/Lib/traceback.py#L852-L855 - python_printed_text = exc.text.rstrip("\n").lstrip(" \n\f") - - assert python_printed_text == "inp = inp.this_attribute_surely_doesnt_exist" - - # test that `offset` is aligned with `exc.text` - return_offset = ( - exc.text.find("inp.this_attribute_surely_doesnt_exist") + 1 - ) # offset is 1-based for syntax errors - assert (exc.offset, exc.end_offset) == (return_offset, return_offset + 38) - - print("".join(traceback.format_exception_only(exc))) - - assert traceback.format_exception_only(exc)[1:3] == [ - " inp = inp.this_attribute_surely_doesnt_exist\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == line + 3 + assert exc_info.value.location.end_line == line + 3 + assert exc_info.value.location.column == 9 + assert exc_info.value.location.end_column == 15 def test_syntax_error_without_function(): @@ -106,17 +67,15 @@ def invalid_python_syntax(): """, ) - with pytest.raises(SyntaxError) as exc_info: + with pytest.raises(CompilerError) as exc_info: _ = f2f.FieldOperatorParser.apply(source_definition, {}, {}) - exc = exc_info.value - - assert (exc.lineno, exc.end_lineno) == (66, 66) - - assert traceback.format_exception_only(exc)[1:3] == [ - " ret%% # <-- this line triggers the syntax error\n", - " ^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == 66 + assert exc_info.value.location.end_line == 66 + assert exc_info.value.location.column == 9 + assert exc_info.value.location.end_column == 10 def test_fo_type_deduction_error(): @@ -127,17 +86,17 @@ def test_fo_type_deduction_error(): def field_operator_with_undeclared_symbol(): return undeclared_symbol - with pytest.raises(CompilationError) as exc_info: + with pytest.raises(CompilerError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(field_operator_with_undeclared_symbol) exc = exc_info.value - assert (exc.lineno, exc.end_lineno) == (line + 3, line + 3) - - assert traceback.format_exception_only(exc)[1:3] == [ - " return undeclared_symbol\n", - " ^^^^^^^^^^^^^^^^^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == line + 3 + assert exc_info.value.location.end_line == line + 3 + assert exc_info.value.location.column == 16 + assert exc_info.value.location.end_column == 33 # TODO: test program type deduction? diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 08041a7ec9..9730ba1b6c 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -113,7 +113,7 @@ def undefined_field_program(in_field: Field[[IDim], "float64"]): identity(in_field, out=out_field) with pytest.raises( - CompilationError, + CompilerError, match=(r"Undeclared or untyped symbol `out_field`."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -162,7 +162,7 @@ def domain_format_1_program(in_field: Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) @@ -181,7 +181,7 @@ def domain_format_2_program(in_field: Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) @@ -200,7 +200,7 @@ def domain_format_3_program(in_field: Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) @@ -221,7 +221,7 @@ def domain_format_4_program(in_field: Field[[IDim], float64]): ) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) @@ -240,7 +240,7 @@ def domain_format_5_program(in_field: Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) @@ -259,7 +259,7 @@ def domain_format_6_program(in_field: Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 9b3b366bb2..e6975edf0a 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -166,7 +166,7 @@ def inout_field_program(inout_field: Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): with pytest.raises( - CompilationError, + CompilerError, ) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), From 3ea99d5da61dab3e9743b4237693d79c9e07d3ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 21 Jun 2023 09:52:50 +0200 Subject: [PATCH 12/54] tune --- src/gt4py/next/errors/formatting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 0e6fb532bd..fe158a4992 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -41,7 +41,7 @@ def format_compilation_error(type_: type[Exception], message: str, location_trac msg_str = f"{type_.__module__}.{type_.__name__}: {message}" try: - loc_str = "".join([format_location(loc) for loc in location_trace]) + loc_str = "".join([format_location(loc, caret=True) for loc in location_trace]) stack_str = f"Source location (most recent call last):\n{textwrap.indent(loc_str, ' ')}\n" return [stack_str, msg_str] except ValueError: From 8aef5f0df2e391cc134ef4150a8d1220c36365c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 21 Jun 2023 11:09:44 +0200 Subject: [PATCH 13/54] fix & clean SourceLocation --- src/gt4py/cartesian/frontend/node_util.py | 2 +- .../cartesian/gtc/dace/expansion/utils.py | 2 +- src/gt4py/eve/concepts.py | 46 ++++++------------- src/gt4py/next/errors/formatting.py | 4 +- src/gt4py/next/ffront/dialect_parser.py | 4 +- tests/eve_tests/definitions.py | 4 +- tests/eve_tests/unit_tests/test_codegen.py | 2 +- tests/eve_tests/unit_tests/test_concepts.py | 40 ++++------------ .../test_math_builtin_execution.py | 2 +- .../errors_tests/test_compilation_error.py | 2 +- .../test_func_to_foast_error_line_number.py | 6 +-- 11 files changed, 39 insertions(+), 75 deletions(-) diff --git a/src/gt4py/cartesian/frontend/node_util.py b/src/gt4py/cartesian/frontend/node_util.py index 5da908de5c..9595d5f76a 100644 --- a/src/gt4py/cartesian/frontend/node_util.py +++ b/src/gt4py/cartesian/frontend/node_util.py @@ -147,4 +147,4 @@ def recurse(node: Node) -> Generator[Node, None, None]: def location_to_source_location(loc: Optional[Location]) -> Optional[eve.SourceLocation]: if loc is None or loc.line <= 0 or loc.column <= 0: return None - return eve.SourceLocation(line=loc.line, column=loc.column, source=loc.scope) + return eve.SourceLocation(line=loc.line, column=loc.column, filename=loc.scope) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 613d161fc0..dc10c53f21 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -35,7 +35,7 @@ def get_dace_debuginfo(node: common.LocNode): node.loc.column, node.loc.line, node.loc.column, - node.loc.source, + node.loc.filename, ) else: return dace.dtypes.DebugInfo(0) diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index d66e8583fd..bbb161b824 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -58,58 +58,42 @@ class SymbolRef(ConstrainedStr, regex=_SYMBOL_NAME_RE): @datamodels.datamodel(slots=True, frozen=True) class SourceLocation: - """Source code location (line, column, source).""" + """File-line-column information for source code.""" + filename: Optional[str] line: int = datamodels.field(validator=_validators.ge(1)) column: int = datamodels.field(validator=_validators.ge(1)) - source: str end_line: Optional[int] = datamodels.field(validator=_validators.optional(_validators.ge(1))) end_column: Optional[int] = datamodels.field(validator=_validators.optional(_validators.ge(1))) - @classmethod - def from_AST(cls, ast_node: ast.AST, source: Optional[str] = None) -> SourceLocation: - if ( - not isinstance(ast_node, ast.AST) - or getattr(ast_node, "lineno", None) is None - or getattr(ast_node, "col_offset", None) is None - ): - raise ValueError( - f"Passed AST node '{ast_node}' does not contain a valid source location." - ) - if source is None: - source = f"" - return cls( - ast_node.lineno, - ast_node.col_offset + 1, - source, - end_line=ast_node.end_lineno, - end_column=ast_node.end_col_offset + 1 if ast_node.end_col_offset is not None else None, - ) - def __init__( self, + filename: Optional[str], line: int, column: int, - source: str, - *, end_line: Optional[int] = None, end_column: Optional[int] = None, ) -> None: assert end_column is None or end_line is not None self.__auto_init__( # type: ignore[attr-defined] # __auto_init__ added dynamically - line=line, column=column, source=source, end_line=end_line, end_column=end_column + filename=filename, line=line, column=column, end_line=end_line, end_column=end_column ) def __str__(self) -> str: - src = self.source or "" + filename_str = self.filename or "-" + + end_line_str = self.end_line if self.end_line is not None else "-" + end_column_str = self.end_column if self.end_column is not None else "-" - end_part = "" - if self.end_line is not None: - end_part += f" to {self.end_line}" + end_str: Optional[str] = None if self.end_column is not None: - end_part += f":{self.end_column}" + end_str = f"{end_line_str}:{end_column_str}" + elif self.end_line is not None: + end_str = f"{end_line_str}" - return f"<{src}:{self.line}:{self.column}{end_part}>" + if end_str is not None: + return f"{filename_str}:{self.line}:{self.column} to {end_str}" + return f"{filename_str}:{self.line}:{self.column}" @datamodels.datamodel(slots=True, frozen=True) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index fe158a4992..1b811feab0 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -5,7 +5,7 @@ def get_source_from_location(location: SourceLocation): try: - source_file = pathlib.Path(location.source) + source_file = pathlib.Path(location.filename) source_code = source_file.read_text() source_lines = source_code.splitlines(False) start_line = location.line @@ -17,7 +17,7 @@ def get_source_from_location(location: SourceLocation): def format_location(loc: SourceLocation, caret: bool = False): - filename = loc.source or "" + filename = loc.filename or "" lineno = loc.line or "" loc_str = f"File \"{filename}\", line {lineno}" diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 72861ee4ee..7bb1ebcc5d 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -38,7 +38,7 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: loc = SourceLocation( line=err.lineno + source_definition.line_offset if err.lineno is not None else None, column=err.offset + source_definition.column_offset if err.offset is not None else None, - source=source_definition.filename, + filename=source_definition.filename, end_line=err.end_lineno + source_definition.line_offset if err.end_lineno is not None else None, end_column=err.end_offset + source_definition.column_offset if err.end_offset is not None else None ) @@ -107,5 +107,5 @@ def get_location(self, node: ast.AST) -> SourceLocation: column = 1 + node.col_offset + col_offset if node.col_offset is not None else None end_column = 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None - loc = SourceLocation(line, column, file, end_line=end_line, end_column=end_column) + loc = SourceLocation(file, line, column, end_line=end_line, end_column=end_column) return loc diff --git a/tests/eve_tests/definitions.py b/tests/eve_tests/definitions.py index b20fb28ddb..f4f0232ae4 100644 --- a/tests/eve_tests/definitions.py +++ b/tests/eve_tests/definitions.py @@ -297,7 +297,7 @@ def make_source_location(*, fixed: bool = False) -> SourceLocation: str_value = make_str_value(fixed=fixed) source = f"file_{str_value}.py" - return SourceLocation(line=line, column=column, source=source) + return SourceLocation(line=line, column=column, filename=source) def make_source_location_group(*, fixed: bool = False) -> SourceLocationGroup: @@ -472,7 +472,7 @@ def make_frozen_simple_node(*, fixed: bool = False) -> FrozenSimpleNode: # -- Makers of invalid nodes -- def make_invalid_location_node(*, fixed: bool = False) -> LocationNode: - return LocationNode(loc=SourceLocation(line=0, column=-1, source="")) + return LocationNode(loc=SourceLocation(line=0, column=-1, filename="")) def make_invalid_at_int_simple_node(*, fixed: bool = False) -> SimpleNode: diff --git a/tests/eve_tests/unit_tests/test_codegen.py b/tests/eve_tests/unit_tests/test_codegen.py index dcd81ef4fd..7e7ec2244e 100644 --- a/tests/eve_tests/unit_tests/test_codegen.py +++ b/tests/eve_tests/unit_tests/test_codegen.py @@ -106,7 +106,7 @@ def visit_IntKind(self, node, **kwargs): return f"ONE INTKIND({node.value})" def visit_SourceLocation(self, node, **kwargs): - return f"SourceLocation" + return f"SourceLocation" LocationNode = codegen.FormatTemplate("LocationNode {{{loc}}}") diff --git a/tests/eve_tests/unit_tests/test_concepts.py b/tests/eve_tests/unit_tests/test_concepts.py index 10cea03836..f01e79f626 100644 --- a/tests/eve_tests/unit_tests/test_concepts.py +++ b/tests/eve_tests/unit_tests/test_concepts.py @@ -46,49 +46,29 @@ class LettersOnlySymbol(SymbolName, regex=re.compile(r"[a-zA-Z]+$")): class TestSourceLocation: def test_valid_position(self): - eve.concepts.SourceLocation(line=1, column=1, source="source.py") + eve.concepts.SourceLocation(line=1, column=1, filename="source.py") def test_invalid_position(self): with pytest.raises(ValueError, match="column"): - eve.concepts.SourceLocation(line=1, column=-1, source="source.py") + eve.concepts.SourceLocation(line=1, column=-1, filename="source.py") def test_str(self): - loc = eve.concepts.SourceLocation(line=1, column=1, source="dir/source.py") + loc = eve.concepts.SourceLocation(line=1, column=1, filename="dir/source.py") assert str(loc) == "" - loc = eve.concepts.SourceLocation(line=1, column=1, source="dir/source.py", end_line=2) + loc = eve.concepts.SourceLocation(line=1, column=1, filename="dir/source.py", end_line=2) assert str(loc) == "" loc = eve.concepts.SourceLocation( - line=1, column=1, source="dir/source.py", end_line=2, end_column=2 + line=1, column=1, filename="dir/source.py", end_line=2, end_column=2 ) assert str(loc) == "" - def test_construction_from_ast(self): - import ast - - ast_node = ast.parse("a = b + 1").body[0] - loc = eve.concepts.SourceLocation.from_AST(ast_node, "source.py") - - assert loc.line == ast_node.lineno - assert loc.column == ast_node.col_offset + 1 - assert loc.source == "source.py" - assert loc.end_line == ast_node.end_lineno - assert loc.end_column == ast_node.end_col_offset + 1 - - loc = eve.concepts.SourceLocation.from_AST(ast_node) - - assert loc.line == ast_node.lineno - assert loc.column == ast_node.col_offset + 1 - assert loc.source == f"" - assert loc.end_line == ast_node.end_lineno - assert loc.end_column == ast_node.end_col_offset + 1 - class TestSourceLocationGroup: def test_valid_locations(self): - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source1.py") - loc2 = eve.concepts.SourceLocation(line=2, column=2, source="source2.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source1.py") + loc2 = eve.concepts.SourceLocation(line=2, column=2, filename="source2.py") eve.concepts.SourceLocationGroup(loc1) eve.concepts.SourceLocationGroup(loc1, loc2) eve.concepts.SourceLocationGroup(loc1, loc1, loc2, loc2, context="test context") @@ -96,13 +76,13 @@ def test_valid_locations(self): def test_invalid_locations(self): with pytest.raises(ValueError): eve.concepts.SourceLocationGroup() - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source.py") with pytest.raises(TypeError): eve.concepts.SourceLocationGroup(loc1, "loc2") def test_str(self): - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source1.py") - loc2 = eve.concepts.SourceLocation(line=2, column=2, source="source2.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source1.py") + loc2 = eve.concepts.SourceLocation(line=2, column=2, filename="source2.py") loc = eve.concepts.SourceLocationGroup(loc1, loc2, context="some context") assert str(loc) == "<#some context#[, ]>" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index d0e3969e85..aec9380958 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -58,7 +58,7 @@ def make_builtin_field_operator(builtin_name: str): closure_vars = {"IDim": IDim, builtin_name: getattr(fbuiltins, builtin_name)} - loc = foast.SourceLocation(line=1, column=1, source="none") + loc = foast.SourceLocation(line=1, column=1, filename="none") params = [ foast.Symbol(id=k, type=type_translation.from_type_hint(type), location=loc) diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py index ea75165a2b..f377c13065 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -2,7 +2,7 @@ from gt4py.eve import SourceLocation -loc = SourceLocation(5, 2, "/source/file.py", end_line=5, end_column=9) +loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) msg = "a message" diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 4ca7b405bc..1b205a1eb0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -47,7 +47,7 @@ def wrong_syntax(inp: common.Field[[TDim], float]): _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) assert exc_info.value.location - assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") assert exc_info.value.location.line == line + 3 assert exc_info.value.location.end_line == line + 3 assert exc_info.value.location.column == 9 @@ -71,7 +71,7 @@ def invalid_python_syntax(): _ = f2f.FieldOperatorParser.apply(source_definition, {}, {}) assert exc_info.value.location - assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") assert exc_info.value.location.line == 66 assert exc_info.value.location.end_line == 66 assert exc_info.value.location.column == 9 @@ -92,7 +92,7 @@ def field_operator_with_undeclared_symbol(): exc = exc_info.value assert exc_info.value.location - assert exc_info.value.location.source.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") assert exc_info.value.location.line == line + 3 assert exc_info.value.location.end_line == line + 3 assert exc_info.value.location.column == 16 From f498e008ec01e7ab00e5597ceeeca73e0c63d1f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 21 Jun 2023 11:56:31 +0200 Subject: [PATCH 14/54] revert source location str --- src/gt4py/eve/concepts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index bbb161b824..f518aff42e 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -92,8 +92,8 @@ def __str__(self) -> str: end_str = f"{end_line_str}" if end_str is not None: - return f"{filename_str}:{self.line}:{self.column} to {end_str}" - return f"{filename_str}:{self.line}:{self.column}" + return f"<{filename_str}:{self.line}:{self.column} to {end_str}>" + return f"<{filename_str}:{self.line}:{self.column}>" @datamodels.datamodel(slots=True, frozen=True) From e1c1f04b0ae09e3e38698f632431565b9adccbb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 21 Jun 2023 12:47:22 +0200 Subject: [PATCH 15/54] fix qa --- src/gt4py/eve/concepts.py | 1 - src/gt4py/next/errors/__init__.py | 40 +++- src/gt4py/next/errors/excepthook.py | 36 +++- src/gt4py/next/errors/exceptions.py | 25 ++- src/gt4py/next/errors/formatting.py | 44 ++-- src/gt4py/next/ffront/decorator.py | 1 - src/gt4py/next/ffront/dialect_parser.py | 24 ++- .../foast_passes/closure_var_folding.py | 2 +- .../ffront/foast_passes/type_deduction.py | 190 ++++++++++-------- src/gt4py/next/ffront/func_to_foast.py | 32 ++- src/gt4py/next/ffront/func_to_past.py | 10 +- .../next/ffront/past_passes/type_deduction.py | 10 +- src/gt4py/next/ffront/past_to_itir.py | 4 +- src/gt4py/next/ffront/source_utils.py | 9 +- src/gt4py/next/type_system/type_info.py | 4 +- tests/next_tests/exception_printing.py | 23 ++- .../ffront_tests/test_execution.py | 6 +- .../ffront_tests/test_program.py | 4 +- .../ffront_tests/test_scalar_if.py | 6 +- .../ffront_tests/test_type_deduction.py | 2 +- .../feature_tests/test_util_cases.py | 3 +- .../errors_tests/test_compilation_error.py | 18 +- .../ffront_tests/test_func_to_foast.py | 6 +- .../test_func_to_foast_error_line_number.py | 6 +- .../ffront_tests/test_func_to_past.py | 2 +- .../ffront_tests/test_past_to_itir.py | 2 +- .../test_type_translation.py | 4 +- 27 files changed, 331 insertions(+), 183 deletions(-) diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index f518aff42e..67991f6db0 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -17,7 +17,6 @@ from __future__ import annotations -import ast import copy import re diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 0fc590bca0..3b7f123d5f 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -1,11 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. + excepthook, +) from .exceptions import ( + ArgumentCountError, CompilerError, - UndefinedSymbolError, - UnsupportedPythonFeatureError, - MissingParameterAnnotationError, InvalidParameterAnnotationError, - ArgumentCountError, KeywordArgumentError, - MissingAttributeError + MissingAttributeError, + MissingParameterAnnotationError, + UndefinedSymbolError, + UnsupportedPythonFeatureError, ) -from . import excepthook \ No newline at end of file + + +__all__ = [ + "ArgumentCountError", + "CompilerError", + "InvalidParameterAnnotationError", + "KeywordArgumentError", + "MissingAttributeError", + "MissingParameterAnnotationError", + "UndefinedSymbolError", + "UnsupportedPythonFeatureError", +] diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 6b49a4b3df..4b180fc74c 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -1,14 +1,36 @@ -from . import formatting -from . import exceptions -from typing import Callable +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import sys +from typing import Callable + +from . import exceptions, formatting + -def compilation_error_hook(fallback: Callable, type_: type, value: exceptions.CompilerError, tb): - if issubclass(type_, exceptions.CompilerError): - print("".join(formatting.format_compilation_error(type_, value.message, value.location_trace)), file=sys.stderr) +def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb): + if isinstance(value, exceptions.CompilerError): + print( + "".join( + formatting.format_compilation_error( + type(value), value.message, value.location_trace + ) + ), + file=sys.stderr, + ) else: fallback(type_, value, tb) _fallback = sys.excepthook -sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) \ No newline at end of file +sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 4e89df2173..03a5af0207 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -1,7 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import textwrap +from typing import Any, Optional, TypeVar from gt4py.eve import SourceLocation -from typing import Any, Optional, TypeVar + from . import formatting @@ -72,12 +87,16 @@ def __init__(self, location: LocationTraceT, param_name: str): class InvalidParameterAnnotationError(CompilerTypeError): def __init__(self, location: LocationTraceT, param_name: str, type_: Any): - super().__init__(location, f"parameter '{param_name}' has invalid type annotation '{type_}'") + super().__init__( + location, f"parameter '{param_name}' has invalid type annotation '{type_}'" + ) class ArgumentCountError(CompilerTypeError): def __init__(self, location: LocationTraceT, num_expected: int, num_provided: int): - super().__init__(location, f"expected {num_expected} arguments but {num_provided} were provided") + super().__init__( + location, f"expected {num_expected} arguments but {num_provided} were provided" + ) class KeywordArgumentError(CompilerTypeError): diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 1b811feab0..70609a6284 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -1,25 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib import textwrap + from gt4py.eve import SourceLocation -import pathlib def get_source_from_location(location: SourceLocation): - try: - source_file = pathlib.Path(location.filename) - source_code = source_file.read_text() - source_lines = source_code.splitlines(False) - start_line = location.line - end_line = location.end_line + 1 if location.end_line else start_line + 1 - relevant_lines = source_lines[(start_line-1):(end_line-1)] - return "\n".join(relevant_lines) - except Exception as ex: - raise ValueError("failed to get source code for source location") from ex + if not location.filename: + raise FileNotFoundError() + source_file = pathlib.Path(location.filename) + source_code = source_file.read_text() + source_lines = source_code.splitlines(False) + start_line = location.line + end_line = location.end_line + 1 if location.end_line else start_line + 1 + relevant_lines = source_lines[(start_line - 1) : (end_line - 1)] + return "\n".join(relevant_lines) def format_location(loc: SourceLocation, caret: bool = False): filename = loc.filename or "" lineno = loc.line or "" - loc_str = f"File \"{filename}\", line {lineno}" + loc_str = f'File "{filename}", line {lineno}' if caret and loc.column is not None: offset = loc.column - 1 @@ -33,11 +47,13 @@ def format_location(loc: SourceLocation, caret: bool = False): if caret_str: snippet_str = f"{snippet_str}\n{caret_str}" return f"{loc_str}\n{textwrap.indent(snippet_str, ' ')}" - except ValueError: + except Exception: return loc_str -def format_compilation_error(type_: type[Exception], message: str, location_trace: list[SourceLocation]): +def format_compilation_error( + type_: type[Exception], message: str, location_trace: list[SourceLocation] +): msg_str = f"{type_.__module__}.{type_.__name__}: {message}" try: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a7ec8dd88d..6f3e8ca97b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -54,7 +54,6 @@ from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -from gt4py.next.errors import * DEFAULT_BACKEND: Callable = roundtrip.executor diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 7bb1ebcc5d..38c6fd838b 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -19,13 +19,11 @@ from typing import Callable from gt4py.eve.concepts import SourceLocation -from gt4py.eve.extended_typing import Any, ClassVar, Generic, Optional, Type, TypeVar -from gt4py.next import common -from gt4py.next.errors import * +from gt4py.eve.extended_typing import Any, Generic, TypeVar +from gt4py.next.errors import CompilerError, UnsupportedPythonFeatureError from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function -from gt4py.next.errors import UnsupportedPythonFeatureError DialectRootT = TypeVar("DialectRootT") @@ -35,12 +33,18 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: try: return ast.parse(textwrap.dedent(source_definition.source)).body[0] except SyntaxError as err: + assert err.lineno + assert err.offset loc = SourceLocation( - line=err.lineno + source_definition.line_offset if err.lineno is not None else None, - column=err.offset + source_definition.column_offset if err.offset is not None else None, + line=err.lineno + source_definition.line_offset, + column=err.offset + source_definition.column_offset, filename=source_definition.filename, - end_line=err.end_lineno + source_definition.line_offset if err.end_lineno is not None else None, - end_column=err.end_offset + source_definition.column_offset if err.end_offset is not None else None + end_line=err.end_lineno + source_definition.line_offset + if err.end_lineno is not None + else None, + end_column=err.end_offset + source_definition.column_offset + if err.end_offset is not None + else None, ) raise CompilerError(loc, err.msg).with_traceback(err.__traceback__) @@ -105,7 +109,9 @@ def get_location(self, node: ast.AST) -> SourceLocation: line = node.lineno + line_offset if node.lineno is not None else None end_line = node.end_lineno + line_offset if node.end_lineno is not None else None column = 1 + node.col_offset + col_offset if node.col_offset is not None else None - end_column = 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None + end_column = ( + 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None + ) loc = SourceLocation(file, line, column, end_line=end_line, end_column=end_column) return loc diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 3504864f78..a97090e6a8 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -18,7 +18,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import FrozenNamespace -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError, MissingAttributeError @dataclass diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 899f7f20f8..6a949a0245 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -17,6 +17,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits from gt4py.next.common import DimensionKind +from gt4py.next.errors import CompilerError from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, @@ -25,7 +26,6 @@ ) from gt4py.next.ffront.foast_passes.utils import compute_assign_indices from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -from gt4py.next.errors import * def boolified_type(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.FieldType: @@ -145,10 +145,11 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise CompilerError(stmt.location, + raise CompilerError( + stmt.location, f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", - ) + ) return_type = return_types[0] or return_types[1] elif isinstance(stmt, foast.BlockStmt): # just forward to nested BlockStmt @@ -161,10 +162,11 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise CompilerError(stmt.location, + raise CompilerError( + stmt.location, f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", - ) + ) if is_unconditional_return: # found a statement that always returns assert return_type @@ -246,9 +248,10 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", - ) + ) new_type = ts.FunctionType( pos_only_args=[], pos_or_kw_args={str(new_param.id): new_param.type for new_param in new_params}, @@ -277,26 +280,30 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Argument `axis` to scan operator `{node.id}` must be a dimension.", - ) + ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", - ) + ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise CompilerError(node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." - ) + raise CompilerError( + node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." + ) new_init = self.visit(node.init, **kwargs) if not all( type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", - ) + ) new_definition = self.visit(node.definition, **kwargs) new_type = ts_ffront.ScanOperatorType( axis=new_axis.type.dim, @@ -315,8 +322,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilerError(node.location, f"Undeclared symbol `{node.id}`." - ) + raise CompilerError(node.location, f"Undeclared symbol `{node.id}`.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -340,8 +346,9 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise CompilerError(node.location, f"Too many values to unpack (expected {len(indices)})." - ) + raise CompilerError( + node.location, f"Too many values to unpack (expected {len(indices)})." + ) new_targets: TargetType = [] new_type: ts.TupleType | ts.DataType @@ -371,8 +378,9 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise CompilerError(node.location, f"Assignment value must be of type tuple! Got: {values.type}" - ) + raise CompilerError( + node.location, f"Assignment value must be of type tuple! Got: {values.type}" + ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -389,25 +397,28 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise CompilerError(node.location, + raise CompilerError( + node.location, "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", - ) + ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise CompilerError(node.location, + raise CompilerError( + node.location, "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", - ) + ) for sym in node.annex.propagated_symbols.keys(): if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", - ) + ) # TODO: properly patch symtable (new node?) symtable[sym].type = new_node.annex.propagated_symbols[ sym @@ -424,12 +435,13 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise CompilerError(node.location, - ( + raise CompilerError( + node.location, + ( "type inconsistency: expression was deduced to be " f"of type {refine_type}, instead of the expected type {node.type}" ), - ) + ) new_node: foast.Symbol = foast.Symbol( id=node.id, type=refine_type, location=node.location ) @@ -445,16 +457,19 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise CompilerError(new_value.location, "Second dimension in offset must be a local dimension.") + raise CompilerError( + new_value.location, "Second dimension in offset must be a local dimension." + ) new_type = ts.OffsetType(source=source, target=(target1,)) case ts.OffsetType(source=source, target=(target,)): # for cartesian axes (e.g. I, J) the index of the subscript only # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise CompilerError(new_value.location, + raise CompilerError( + new_value.location, "Source and target must be equal for offsets with a single target.", - ) + ) new_type = new_value.type case _: raise CompilerError( @@ -497,14 +512,16 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise CompilerError(condition.location, + raise CompilerError( + condition.location, f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", - ) + ) if true_expr.type != false_expr.type: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", - ) + ) return true_expr.type def visit_Compare(self, node: foast.Compare, **kwargs) -> foast.Compare: @@ -521,8 +538,9 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise CompilerError(arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" - ) + raise CompilerError( + arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" + ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -531,10 +549,11 @@ def _deduce_compare_type( # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) except ValueError as ex: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", - ) from ex + ) from ex def _deduce_binop_type( self, @@ -554,8 +573,9 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise CompilerError(arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" - ) + raise CompilerError( + arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) right_type = cast(ts.FieldType | ts.ScalarType, right.type) @@ -566,26 +586,29 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise CompilerError(arg.location, + raise CompilerError( + arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", - ) + ) try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", - ) from ex + ) from ex def _check_operand_dtypes_match( self, node: foast.BinOp | foast.Compare, left: foast.Expr, right: foast.Expr ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", - ) + ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: new_operand = self.visit(node.operand, **kwargs) @@ -599,9 +622,10 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", - ) + ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type ) @@ -630,14 +654,14 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise CompilerError(node.location, "Functions can only be called directly!" - ) + raise CompilerError(node.location, "Functions can only be called directly!") elif isinstance(new_func.type, ts.FieldType): pass else: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", - ) + ) # ensure signature is valid try: @@ -648,8 +672,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: raise_exception=True, ) except ValueError as err: - raise CompilerError(node.location, f"Invalid argument types in call to `{new_func}`!" - ) from err + raise CompilerError( + node.location, f"Invalid argument types in call to `{new_func}`!" + ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -705,9 +730,10 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise CompilerError(node.location, + raise CompilerError( + node.location, "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), - ) + ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): print(f"Warning: return type of {func_name} might be inconsistent (not implemented).") @@ -727,8 +753,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) except ValueError as ex: - raise CompilerError(node.location, error_msg_preamble - ) from ex + raise CompilerError(node.location, error_msg_preamble) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -747,11 +772,12 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " f"{field_dims_str}.", - ) + ) return_type = ts.FieldType( dims=[dim for dim in field_type.dims if dim != reduction_dim], dtype=field_type.dtype, @@ -802,18 +828,20 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" f"{node.location}", - ) + ) if arg_0.source not in arg_1.dims: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " f"{node.location}", - ) + ) return foast.Call( func=node.func, @@ -829,10 +857,11 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype `bool`, but got `{mask_type}`.", - ) + ) try: if isinstance(true_branch_type, ts.TupleType) and isinstance( @@ -846,10 +875,11 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", - ) + ) else: true_branch_fieldtype = cast(ts.FieldType, true_branch_type) false_branch_fieldtype = cast(ts.FieldType, false_branch_type) @@ -857,9 +887,10 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: return_type = promote_to_mask_type(mask_type, promoted_type) except ValueError as ex: - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible argument in call to `{str(node.func)}`.", - ) from ex + ) from ex return foast.Call( func=node.func, @@ -874,18 +905,20 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", - ) + ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise CompilerError(node.location, + raise CompilerError( + node.location, f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", - ) + ) return_type = ts.FieldType( dims=broadcast_dims, @@ -904,6 +937,5 @@ def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError as e: - raise CompilerError(node.location, "Could not deduce type of constant." - ) from e + raise CompilerError(node.location, "Could not deduce type of constant.") from e return foast.Constant(value=node.value, location=node.location, type=type_) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 74097a0dff..5d7cb5d2f2 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -19,7 +19,12 @@ from typing import Any, Callable, Iterable, Mapping, Type, cast import gt4py.eve as eve -from gt4py.next import common +from gt4py.next.errors import ( + CompilerError, + InvalidParameterAnnotationError, + MissingParameterAnnotationError, + UnsupportedPythonFeatureError, +) from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.ast_passes import ( SingleAssignTargetPass, @@ -28,7 +33,6 @@ UnchainComparesPass, ) from gt4py.next.ffront.dialect_parser import DialectParser -from gt4py.next.errors import * from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind from gt4py.next.ffront.foast_passes.closure_var_folding import ClosureVarFolding from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction @@ -107,7 +111,7 @@ def _postprocess_dialect_ast( raise CompilerError( foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - f", but got `{annotated_return_type}`." + f", but got `{annotated_return_type}`.", ) return foast_node @@ -166,7 +170,9 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise CompilerError(loc, "function is expected to return a value, return statement not found") + raise CompilerError( + loc, "function is expected to return a value, return statement not found" + ) return foast.FunctionDefinition( id=node.name, @@ -286,7 +292,9 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise CompilerError(self.get_location(node.slice), "expected an integral index") from None + raise CompilerError( + self.get_location(node.slice), "expected an integral index" + ) from None return foast.Subscript( value=self.visit(node.value), @@ -318,7 +326,9 @@ def visit_Name(self, node: ast.Name, **kwargs) -> foast.Name: def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs) -> foast.UnaryOp: return foast.UnaryOp( - op=self.visit(node.op), operand=self.visit(node.operand), location=self.get_location(node) + op=self.visit(node.op), + operand=self.visit(node.operand), + location=self.get_location(node), ) def visit_UAdd(self, node: ast.UAdd, **kwargs) -> dialect_ast_enums.UnaryOperator: @@ -372,7 +382,9 @@ def visit_BitXor(self, node: ast.BitXor, **kwargs) -> dialect_ast_enums.BinaryOp return dialect_ast_enums.BinaryOperator.BIT_XOR def visit_BoolOp(self, node: ast.BoolOp, **kwargs) -> None: - raise UnsupportedPythonFeatureError(self.get_location(node), "logical operators `and`, `or`") + raise UnsupportedPythonFeatureError( + self.get_location(node), "logical operators `and`, `or`" + ) def visit_IfExp(self, node: ast.IfExp, **kwargs) -> foast.TernaryExpr: return foast.TernaryExpr( @@ -459,8 +471,10 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: loc = self.get_location(node) try: type_ = type_translation.from_value(node.value) - except ValueError as e: - raise CompilerError(loc, f"constants of type {type(node.value)} are not permitted") from None + except ValueError: + raise CompilerError( + loc, f"constants of type {type(node.value)} are not permitted" + ) from None return foast.Constant( value=node.value, diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 2f8e16b84b..8ed466cc29 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -18,13 +18,17 @@ from dataclasses import dataclass from typing import Any, cast +from gt4py.next.errors import ( + CompilerError, + InvalidParameterAnnotationError, + MissingParameterAnnotationError, +) from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, type_specifications as ts_ffront, ) from gt4py.next.ffront.dialect_parser import DialectParser -from gt4py.next.errors import CompilerError, MissingParameterAnnotationError, InvalidParameterAnnotationError from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.type_system import type_specifications as ts, type_translation @@ -163,9 +167,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: loc = self.get_location(node) if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) - return past.Constant( - value=-node.operand.value, type=symbol_type, location=loc - ) + return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) raise CompilerError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 2402016624..a45e85231e 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,13 +15,13 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits +from gt4py.next.errors import CompilerError from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, type_specifications as ts_ffront, ) from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.errors import * def _ensure_no_sliced_field(entry: past.Expr): @@ -231,9 +231,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise CompilerError( - node.location, f"Invalid call to `{node.func.id}`." - ) from ex + raise CompilerError(node.location, f"Invalid call to `{node.func.id}`.") from ex return past.Call( func=new_func, @@ -246,8 +244,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilerError( - node.location, f"Undeclared or untyped symbol `{node.id}`." - ) + raise CompilerError(node.location, f"Undeclared or untyped symbol `{node.id}`.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a89385c592..2c5dfc6e2f 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -37,9 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise ValueError( - "Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed." - ) + raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index 5d9c648ae6..cc3c20e7fc 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -23,8 +23,6 @@ from dataclasses import dataclass from typing import Any, cast -from gt4py.next import common - MISSING_FILENAME = "" @@ -38,12 +36,13 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition: try: filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) if not filename: - raise ValueError("Can not create field operator from a function that is not in a source file!") + raise ValueError( + "Can not create field operator from a function that is not in a source file!" + ) source_lines, line_offset = inspect.getsourcelines(func) source_code = textwrap.dedent(inspect.getsource(func)) column_offset = min( - [len(line) - len(line.lstrip()) for line in source_lines if line.lstrip()], - default=0 + [len(line) - len(line.lstrip()) for line in source_lines if line.lstrip()], default=0 ) return SourceDefinition(source_code, filename, line_offset - 1, column_offset) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 9e7d64e081..e4ce2e9173 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -56,9 +56,7 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: return constraint case ts.TypeSpec() as concrete_type: return concrete_type.__class__ - raise ValueError( - f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!" - ) + raise ValueError(f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!") def primitive_constituents( diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py index b1e07cac0e..78cba01632 100644 --- a/tests/next_tests/exception_printing.py +++ b/tests/next_tests/exception_printing.py @@ -1,8 +1,25 @@ -from gt4py.next.errors import * +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + import inspect + from gt4py.eve import SourceLocation +from gt4py.next.errors import CompilerError frameinfo = inspect.getframeinfo(inspect.currentframe()) -loc = SourceLocation(frameinfo.lineno, 1, frameinfo.filename, end_line=frameinfo.lineno, end_column=5) -raise CompilerError(loc, "this is an error message") \ No newline at end of file +loc = SourceLocation( + frameinfo.filename, frameinfo.lineno, 1, end_line=frameinfo.lineno, end_column=5 +) +raise CompilerError(loc, "this is an error message") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 6b9af54913..f8d816b96a 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -31,9 +31,9 @@ neighbor_sum, where, ) +from gt4py.next.errors import * from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn_cpu -from gt4py.next.errors import * from next_tests.integration_tests.feature_tests import cases from next_tests.integration_tests.feature_tests.cases import ( @@ -929,9 +929,7 @@ def _star_unpack() -> tuple[int32, float64, int32]: def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises( - CompilerError, match=(r"Assignment value must be of type tuple!") - ): + with pytest.raises(CompilerError, match=(r"Assignment value must be of type tuple!")): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 1702ae1fe0..1619bf343f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -292,5 +292,7 @@ def program_input_kwargs( program_input_kwargs(a=input_1, b=input_2, c=input_3, out=out, offset_provider={}) assert np.allclose(expected, out) - with pytest.raises(ValueError, match="Invalid argument types in call to `program_input_kwargs`!"): + with pytest.raises( + ValueError, match="Invalid argument types in call to `program_input_kwargs`!" + ): program_input_kwargs(input_2, input_3, a=input_1, out=out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index b931d28f88..98728d5d5b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,8 +19,8 @@ import pytest from gt4py.next import Field, field_operator, float64, index_field, np_as_located_field -from gt4py.next.program_processors.runners import gtfn_cpu from gt4py.next.errors import * +from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.feature_tests import cases from next_tests.integration_tests.feature_tests.cases import ( @@ -370,9 +370,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises( - CompilerError, match="Condition for `if` must be of boolean type." - ): + with pytest.raises(CompilerError, match="Condition for `if` must be of boolean type."): @field_operator def if_non_boolean_condition( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 51897b487c..1908a924d6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -31,11 +31,11 @@ neighbor_sum, where, ) +from gt4py.next.errors import * from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts -from gt4py.next.errors import * TDim = Dimension("TDim") # Meaningless dimension, used for tests. diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 402eb2d56c..63a65d60ed 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -16,8 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import common -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.program_processors.runners import roundtrip from next_tests.integration_tests.feature_tests import cases diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py index f377c13065..d815ea1228 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -1,5 +1,19 @@ -from gt4py.next.errors import CompilerError +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + from gt4py.eve import SourceLocation +from gt4py.next.errors import CompilerError loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) @@ -12,5 +26,3 @@ def test_message(): def test_location(): assert CompilerError(loc, msg).location_trace[0] == loc - - diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 25249b1407..c7b6faa034 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -42,12 +42,12 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import astype, broadcast, float32, float64, int32, int64, where +from gt4py.next.errors import * from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import builtins as itb, ir as itir from gt4py.next.type_system import type_specifications as ts -from gt4py.next.errors import * DEREF = itir.SymRef(id=itb.deref.fun.__name__) @@ -77,9 +77,7 @@ def test_untyped_arg(): def untyped(inp): return inp - with pytest.raises( - MissingParameterAnnotationError - ): + with pytest.raises(MissingParameterAnnotationError): _ = FieldOperatorParser.apply_to_function(untyped) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 91e5cad343..272b1428dd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -18,9 +18,9 @@ import pytest import gt4py.next as gtx +from gt4py.next.errors import * from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction -from gt4py.next.errors import * # NOTE: These tests are sensitive to filename and the line number of the marked statement @@ -38,9 +38,7 @@ def wrong_syntax(inp: gtx.Field[[TDim], float]): with pytest.raises( f2f.CompilerError, - match=( - r".*return.*" - ), + match=(r".*return.*"), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 20201736c5..013684972e 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -20,10 +20,10 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import float64 +from gt4py.next.errors import * from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.type_system import type_specifications as ts -from gt4py.next.errors import * from next_tests.past_common_fixtures import ( IDim, diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index f06c2e9f44..b0069ddf5f 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -19,10 +19,10 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P +from gt4py.next.errors import * from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir -from gt4py.next.errors import * from next_tests.past_common_fixtures import ( IDim, diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index 78ea97f9dd..d281f5cd90 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -158,9 +158,7 @@ def test_invalid_symbol_types(): type_translation.from_type_hint(common.Field[[IDim], None]) # Functions - with pytest.raises( - ValueError, match="Not annotated functions are not supported" - ): + with pytest.raises(ValueError, match="Not annotated functions are not supported"): type_translation.from_type_hint(typing.Callable) with pytest.raises(ValueError, match="Invalid callable annotations"): From 360a6f0a3c69527be9221cb51fafd16c9396d089 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 21 Jun 2023 15:31:56 +0200 Subject: [PATCH 16/54] fix doctests --- src/gt4py/next/ffront/func_to_foast.py | 4 ++-- src/gt4py/next/ffront/source_utils.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 5d7cb5d2f2..1fecc2b049 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -77,8 +77,8 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) ... except CompilerError as err: - ... print(f"Error at [{err.lineno}, {err.offset}] in {err.filename})") - Error at [2, 5] in ...gt4py.next.ffront.func_to_foast.FieldOperatorParser[...]>) + ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") + Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ @classmethod diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index cc3c20e7fc..17b2050b1b 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -106,11 +106,11 @@ class SourceDefinition: >>> def foo(a): ... return a >>> src_def = SourceDefinition.from_function(foo) - >>> print(src_def) - SourceDefinition(source='def foo(a):... starting_line=1) + >>> print(src_def) # doctest:+ELLIPSIS + SourceDefinition(source='def foo(a):...', filename='...', line_offset=0, column_offset=0) >>> source, filename, starting_line = src_def - >>> print(source) + >>> print(source) # doctest:+ELLIPSIS def foo(a): return a ... From d0813746f78635ced61d758f71e9af8a932ad676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Fri, 30 Jun 2023 19:35:26 +0200 Subject: [PATCH 17/54] print extra info for uncaught exceptions for gt4py developers --- src/gt4py/next/errors/__init__.py | 2 + src/gt4py/next/errors/excepthook.py | 85 +++++++++++++++++++++++--- tests/next_tests/exception_printing.py | 2 +- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 3b7f123d5f..de9cce0055 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -15,6 +15,7 @@ from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. excepthook, ) +from .excepthook import set_developer_mode from .exceptions import ( ArgumentCountError, CompilerError, @@ -36,4 +37,5 @@ "MissingParameterAnnotationError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", + "set_developer_mode", ] diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 4b180fc74c..1067119d52 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -11,23 +11,90 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - +import os import sys -from typing import Callable +import traceback +from typing import Callable, Optional + +import importlib_metadata from . import exceptions, formatting +def _get_developer_mode_python() -> bool: + """Guess if the Python environment is used to develop gt4py.""" + # Import gt4py and use its __name__ because hard-coding "gt4py" would fail + # silently if the module's name changes for whatever reason. + import gt4py + + package_name = gt4py.__name__ + + # Check if any package requires gt4py as a dependency. If not, we are + # probably developing gt4py itself rather than something else using gt4py. + dists = importlib_metadata.distributions() + for dist in dists: + for req in dist.requires or []: + if req.startswith(package_name): + return False + return True + + +def _get_developer_mode_os() -> Optional[bool]: + """Detect if the user set developer mode in environment variables.""" + env_var_name = "GT4PY_DEVELOPER_MODE" + if env_var_name in os.environ: + try: + return bool(os.environ[env_var_name]) + except TypeError: + return False + return None + + +def _guess_developer_mode() -> bool: + """Guess if gt4py is run by its developers or by third party users.""" + env = _get_developer_mode_os() + if env is not None: + return env + return _get_developer_mode_python() + + +_developer_mode = _guess_developer_mode() + + +def set_developer_mode(enabled: bool = False): + """In developer mode, information useful for gt4py developers is also shown.""" + global _developer_mode + _developer_mode = enabled + + +def _print_cause(exc: BaseException): + """Print the cause of an exception plus the bridging message to STDERR.""" + bridging_message = "The above exception was the direct cause of the following exception:" + + if exc.__cause__ or exc.__context__: + traceback.print_exception(exc.__cause__ or exc.__context__) + print(f"\n{bridging_message}\n", file=sys.stderr) + + +def _print_traceback(exc: BaseException): + """Print the traceback of an exception to STDERR.""" + intro_message = "Traceback (most recent call last):" + traceback_strs = [ + f"{intro_message}\n", + *traceback.format_tb(exc.__traceback__), + ] + print("".join(traceback_strs), file=sys.stderr) + + def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb): if isinstance(value, exceptions.CompilerError): - print( - "".join( - formatting.format_compilation_error( - type(value), value.message, value.location_trace - ) - ), - file=sys.stderr, + if _developer_mode: + _print_cause(value) + _print_traceback(value) + exc_strs = formatting.format_compilation_error( + type(value), value.message, value.location_trace ) + print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py index 78cba01632..a3175b45d2 100644 --- a/tests/next_tests/exception_printing.py +++ b/tests/next_tests/exception_printing.py @@ -22,4 +22,4 @@ loc = SourceLocation( frameinfo.filename, frameinfo.lineno, 1, end_line=frameinfo.lineno, end_column=5 ) -raise CompilerError(loc, "this is an error message") +raise CompilerError(loc, "this is an error message") from ValueError("asd") From 79dca318eee6ca20e1c2e7a645dd2586cfee4762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Fri, 30 Jun 2023 19:56:38 +0200 Subject: [PATCH 18/54] use only single source location and not stack --- src/gt4py/next/errors/excepthook.py | 4 +- src/gt4py/next/errors/exceptions.py | 46 ++++++------------- src/gt4py/next/errors/formatting.py | 12 ++--- .../errors_tests/test_compilation_error.py | 2 +- 4 files changed, 22 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 1067119d52..f4699c4a6f 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -91,9 +91,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException if _developer_mode: _print_cause(value) _print_traceback(value) - exc_strs = formatting.format_compilation_error( - type(value), value.message, value.location_trace - ) + exc_strs = formatting.format_compilation_error(type(value), value.message, value.location) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 03a5af0207..8b403526e3 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -13,33 +13,26 @@ # SPDX-License-Identifier: GPL-3.0-or-later import textwrap -from typing import Any, Optional, TypeVar +from typing import Any, Optional from gt4py.eve import SourceLocation from . import formatting -LocationTraceT = TypeVar("LocationTraceT", SourceLocation, list[SourceLocation], None) - - class CompilerError(Exception): - location_trace: list[SourceLocation] + location: Optional[SourceLocation] - def __init__(self, location: LocationTraceT, message: str): - self.location_trace = CompilerError._make_location_trace(location) + def __init__(self, location: Optional[SourceLocation], message: str): + self.location = location super().__init__(message) @property def message(self) -> str: return self.args[0] - @property - def location(self) -> Optional[SourceLocation]: - return self.location_trace[0] if self.location_trace else None - - def with_location(self, location: LocationTraceT) -> "CompilerError": - self.location_trace = CompilerError._make_location_trace(location) + def with_location(self, location: Optional[SourceLocation]) -> "CompilerError": + self.location = location return self def __str__(self): @@ -48,57 +41,46 @@ def __str__(self): return f"{self.message}\n{textwrap.indent(loc_str, ' ')}" return self.message - @staticmethod - def _make_location_trace(location: LocationTraceT) -> list[SourceLocation]: - if isinstance(location, SourceLocation): - return [location] - elif isinstance(location, list): - return location - elif location is None: - return [] - else: - raise TypeError("expected 'SourceLocation', 'list', or 'None' for 'location'") - class UnsupportedPythonFeatureError(CompilerError): - def __init__(self, location: LocationTraceT, feature: str): + def __init__(self, location: Optional[SourceLocation], feature: str): super().__init__(location, f"unsupported Python syntax: '{feature}'") class UndefinedSymbolError(CompilerError): - def __init__(self, location: LocationTraceT, name: str): + def __init__(self, location: Optional[SourceLocation], name: str): super().__init__(location, f"name '{name}' is not defined") class MissingAttributeError(CompilerError): - def __init__(self, location: LocationTraceT, attr_name: str): + def __init__(self, location: Optional[SourceLocation], attr_name: str): super().__init__(location, f"object does not have attribute '{attr_name}'") class CompilerTypeError(CompilerError): - def __init__(self, location: LocationTraceT, message: str): + def __init__(self, location: Optional[SourceLocation], message: str): super().__init__(location, message) class MissingParameterAnnotationError(CompilerTypeError): - def __init__(self, location: LocationTraceT, param_name: str): + def __init__(self, location: Optional[SourceLocation], param_name: str): super().__init__(location, f"parameter '{param_name}' is missing type annotations") class InvalidParameterAnnotationError(CompilerTypeError): - def __init__(self, location: LocationTraceT, param_name: str, type_: Any): + def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any): super().__init__( location, f"parameter '{param_name}' has invalid type annotation '{type_}'" ) class ArgumentCountError(CompilerTypeError): - def __init__(self, location: LocationTraceT, num_expected: int, num_provided: int): + def __init__(self, location: Optional[SourceLocation], num_expected: int, num_provided: int): super().__init__( location, f"expected {num_expected} arguments but {num_provided} were provided" ) class KeywordArgumentError(CompilerTypeError): - def __init__(self, location: LocationTraceT, provided_names: str): + def __init__(self, location: Optional[SourceLocation], provided_names: str): super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 70609a6284..3240f094c3 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -14,6 +14,7 @@ import pathlib import textwrap +from typing import Optional from gt4py.eve import SourceLocation @@ -52,13 +53,12 @@ def format_location(loc: SourceLocation, caret: bool = False): def format_compilation_error( - type_: type[Exception], message: str, location_trace: list[SourceLocation] + type_: type[Exception], message: str, location: Optional[SourceLocation] ): msg_str = f"{type_.__module__}.{type_.__name__}: {message}" - try: - loc_str = "".join([format_location(loc, caret=True) for loc in location_trace]) - stack_str = f"Source location (most recent call last):\n{textwrap.indent(loc_str, ' ')}\n" + if location is not None: + loc_str = format_location(location, caret=True) + stack_str = f"Source location:\n{textwrap.indent(loc_str, ' ')}\n" return [stack_str, msg_str] - except ValueError: - return [msg_str] + return [msg_str] diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py index d815ea1228..52a1f30734 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py @@ -25,4 +25,4 @@ def test_message(): def test_location(): - assert CompilerError(loc, msg).location_trace[0] == loc + assert CompilerError(loc, msg).location == loc From 1bbeec7626ac78cee0bddb506658c2bf182d8f4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 5 Jul 2023 17:00:15 +0200 Subject: [PATCH 19/54] add more tests for error handling, refactor --- src/gt4py/next/errors/excepthook.py | 39 +++++------ src/gt4py/next/errors/formatting.py | 39 +++++++++-- ...ompilation_error.py => test_exceptions.py} | 28 ++++++++ .../errors_tests/test_formatting.py | 66 +++++++++++++++++++ 4 files changed, 144 insertions(+), 28 deletions(-) rename tests/next_tests/unit_tests/errors_tests/{test_compilation_error.py => test_exceptions.py} (55%) create mode 100644 tests/next_tests/unit_tests/errors_tests/test_formatting.py diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index f4699c4a6f..c3513e76a5 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -13,7 +13,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import os import sys -import traceback from typing import Callable, Optional import importlib_metadata @@ -67,31 +66,25 @@ def set_developer_mode(enabled: bool = False): _developer_mode = enabled -def _print_cause(exc: BaseException): - """Print the cause of an exception plus the bridging message to STDERR.""" - bridging_message = "The above exception was the direct cause of the following exception:" - - if exc.__cause__ or exc.__context__: - traceback.print_exception(exc.__cause__ or exc.__context__) - print(f"\n{bridging_message}\n", file=sys.stderr) - - -def _print_traceback(exc: BaseException): - """Print the traceback of an exception to STDERR.""" - intro_message = "Traceback (most recent call last):" - traceback_strs = [ - f"{intro_message}\n", - *traceback.format_tb(exc.__traceback__), - ] - print("".join(traceback_strs), file=sys.stderr) - - def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb): + """ + Format `CompilationError`s in a neat way. + + All other Python exceptions are formatted by the `fallback` hook. + """ if isinstance(value, exceptions.CompilerError): if _developer_mode: - _print_cause(value) - _print_traceback(value) - exc_strs = formatting.format_compilation_error(type(value), value.message, value.location) + exc_strs = formatting.format_compilation_error( + type(value), + value.message, + value.location, + value.__traceback__, + value.__cause__, + ) + else: + exc_strs = formatting.format_compilation_error( + type(value), value.message, value.location + ) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 3240f094c3..c18506ed63 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -14,6 +14,8 @@ import pathlib import textwrap +import traceback +import types from typing import Optional from gt4py.eve import SourceLocation @@ -52,13 +54,40 @@ def format_location(loc: SourceLocation, caret: bool = False): return loc_str +def _format_cause(cause: BaseException) -> list[str]: + """Print the cause of an exception plus the bridging message to STDERR.""" + bridging_message = "The above exception was the direct cause of the following exception:" + cause_strs = [*traceback.format_exception(cause), "\n", f"{bridging_message}\n\n"] + return cause_strs + + +def _format_traceback(tb: types.TracebackType) -> list[str]: + """Format the traceback of an exception.""" + intro_message = "Traceback (most recent call last):" + traceback_strs = [ + f"{intro_message}\n", + *traceback.format_tb(tb), + ] + return traceback_strs + + def format_compilation_error( - type_: type[Exception], message: str, location: Optional[SourceLocation] + type_: type[Exception], + message: str, + location: Optional[SourceLocation], + tb: Optional[types.TracebackType] = None, + cause: Optional[BaseException] = None, ): - msg_str = f"{type_.__module__}.{type_.__name__}: {message}" + bits: list[str] = [] + if cause is not None: + bits = [*bits, *_format_cause(cause)] + if tb is not None: + bits = [*bits, *_format_traceback(tb)] if location is not None: loc_str = format_location(location, caret=True) - stack_str = f"Source location:\n{textwrap.indent(loc_str, ' ')}\n" - return [stack_str, msg_str] - return [msg_str] + loc_str_all = f"Source location:\n{textwrap.indent(loc_str, ' ')}\n" + bits = [*bits, loc_str_all] + msg_str = f"{type_.__module__}.{type_.__name__}: {message}" + bits = [*bits, msg_str] + return bits diff --git a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py similarity index 55% rename from tests/next_tests/unit_tests/errors_tests/test_compilation_error.py rename to tests/next_tests/unit_tests/errors_tests/test_exceptions.py index 52a1f30734..e3fde95e90 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_compilation_error.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -12,11 +12,18 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import inspect +import re + from gt4py.eve import SourceLocation from gt4py.next.errors import CompilerError +frameinfo = inspect.getframeinfo(inspect.currentframe()) loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) +loc_snippet = SourceLocation( + frameinfo.filename, frameinfo.lineno + 2, 15, end_line=frameinfo.lineno + 2, end_column=29 +) msg = "a message" @@ -26,3 +33,24 @@ def test_message(): def test_location(): assert CompilerError(loc, msg).location == loc + + +def test_with_location(): + assert CompilerError(None, msg).with_location(loc).location == loc + + +def test_str(): + pattern = f'{msg}\\n File ".*", line.*' + s = str(CompilerError(loc, msg)) + assert re.match(pattern, s) + + +def test_str_snippet(): + pattern = ( + f"{msg}\\n" + ' File ".*", line.*\\n' + " loc_snippet = SourceLocation.*\\n" + " \^\^\^\^\^\^\^\^\^\^\^\^\^\^" + ) + s = str(CompilerError(loc_snippet, msg)) + assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py new file mode 100644 index 0000000000..dbd4b74f8d --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -0,0 +1,66 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.eve import SourceLocation +from gt4py.next.errors import CompilerError +from gt4py.next.errors.formatting import format_compilation_error +import re +import inspect + + +frameinfo = inspect.getframeinfo(inspect.currentframe()) +loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) +msg = "a message" + +module = CompilerError.__module__ +name = CompilerError.__name__ +try: + raise Exception() +except Exception as ex: + tb = ex.__traceback__ + + +def test_format(): + pattern = f"{module}.{name}: {msg}" + s = "\n".join(format_compilation_error(CompilerError, msg, None, None, None)) + assert re.match(pattern, s); + + +def test_format_loc(): + pattern = \ + "Source location.*\\n" \ + " File \"/source.*\".*\\n" \ + f"{module}.{name}: {msg}" + s = "".join(format_compilation_error(CompilerError, msg, loc, None, None)) + assert re.match(pattern, s); + + +def test_format_traceback(): + pattern = \ + "Traceback.*\\n" \ + " File \".*\".*\\n" \ + ".*\\n" \ + f"{module}.{name}: {msg}" + s = "".join(format_compilation_error(CompilerError, msg, None, tb, None)) + assert re.match(pattern, s); + + +def test_format_cause(): + cause = ValueError("asd") + pattern = \ + "ValueError: asd\\n\\n" \ + "The above.*\\n\\n" \ + f"{module}.{name}: {msg}" + s = "".join(format_compilation_error(CompilerError, msg, None, None, cause)) + assert re.match(pattern, s); From 9f95b76ab692b1c81a73d4a144377124a0265cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 5 Jul 2023 17:04:48 +0200 Subject: [PATCH 20/54] keep parameters of exception for further access --- src/gt4py/next/errors/exceptions.py | 25 ++++++++++++++++ .../errors_tests/test_formatting.py | 29 +++++++------------ 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 8b403526e3..542a4838d1 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -43,18 +43,27 @@ def __str__(self): class UnsupportedPythonFeatureError(CompilerError): + feature: str + def __init__(self, location: Optional[SourceLocation], feature: str): super().__init__(location, f"unsupported Python syntax: '{feature}'") + self.feature = feature class UndefinedSymbolError(CompilerError): + sym_name: str + def __init__(self, location: Optional[SourceLocation], name: str): super().__init__(location, f"name '{name}' is not defined") + self.sym_name = name class MissingAttributeError(CompilerError): + attr_name: str + def __init__(self, location: Optional[SourceLocation], attr_name: str): super().__init__(location, f"object does not have attribute '{attr_name}'") + self.attr_name = attr_name class CompilerTypeError(CompilerError): @@ -63,24 +72,40 @@ def __init__(self, location: Optional[SourceLocation], message: str): class MissingParameterAnnotationError(CompilerTypeError): + param_name: str + def __init__(self, location: Optional[SourceLocation], param_name: str): super().__init__(location, f"parameter '{param_name}' is missing type annotations") + self.param_name = param_name class InvalidParameterAnnotationError(CompilerTypeError): + param_name: str + annotated_type: Any + def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any): super().__init__( location, f"parameter '{param_name}' has invalid type annotation '{type_}'" ) + self.param_name = param_name + self.annotated_type = type_ class ArgumentCountError(CompilerTypeError): + num_excected: int + num_provided: int + def __init__(self, location: Optional[SourceLocation], num_expected: int, num_provided: int): super().__init__( location, f"expected {num_expected} arguments but {num_provided} were provided" ) + self.num_expected = num_expected + self.num_provided = num_provided class KeywordArgumentError(CompilerTypeError): + provided_names: str + def __init__(self, location: Optional[SourceLocation], provided_names: str): super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") + self.provided_names = provided_names diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index dbd4b74f8d..5328a4f228 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -12,11 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import inspect +import re + from gt4py.eve import SourceLocation from gt4py.next.errors import CompilerError from gt4py.next.errors.formatting import format_compilation_error -import re -import inspect frameinfo = inspect.getframeinfo(inspect.currentframe()) @@ -34,33 +35,23 @@ def test_format(): pattern = f"{module}.{name}: {msg}" s = "\n".join(format_compilation_error(CompilerError, msg, None, None, None)) - assert re.match(pattern, s); + assert re.match(pattern, s) def test_format_loc(): - pattern = \ - "Source location.*\\n" \ - " File \"/source.*\".*\\n" \ - f"{module}.{name}: {msg}" + pattern = "Source location.*\\n" ' File "/source.*".*\\n' f"{module}.{name}: {msg}" s = "".join(format_compilation_error(CompilerError, msg, loc, None, None)) - assert re.match(pattern, s); + assert re.match(pattern, s) def test_format_traceback(): - pattern = \ - "Traceback.*\\n" \ - " File \".*\".*\\n" \ - ".*\\n" \ - f"{module}.{name}: {msg}" + pattern = "Traceback.*\\n" ' File ".*".*\\n' ".*\\n" f"{module}.{name}: {msg}" s = "".join(format_compilation_error(CompilerError, msg, None, tb, None)) - assert re.match(pattern, s); + assert re.match(pattern, s) def test_format_cause(): cause = ValueError("asd") - pattern = \ - "ValueError: asd\\n\\n" \ - "The above.*\\n\\n" \ - f"{module}.{name}: {msg}" + pattern = "ValueError: asd\\n\\n" "The above.*\\n\\n" f"{module}.{name}: {msg}" s = "".join(format_compilation_error(CompilerError, msg, None, None, cause)) - assert re.match(pattern, s); + assert re.match(pattern, s) From 5d013110a933e0e31454938d73416cf93a33baac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:25:35 +0200 Subject: [PATCH 21/54] Update src/gt4py/next/errors/excepthook.py Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/errors/excepthook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index c3513e76a5..813fc6acfb 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -60,7 +60,7 @@ def _guess_developer_mode() -> bool: _developer_mode = _guess_developer_mode() -def set_developer_mode(enabled: bool = False): +def set_developer_mode(enabled: bool = False) -> None: """In developer mode, information useful for gt4py developers is also shown.""" global _developer_mode _developer_mode = enabled From 03c01c7021a81d62c0bf05fae60eb36243eb2b30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:25:47 +0200 Subject: [PATCH 22/54] Update src/gt4py/next/errors/excepthook.py Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/errors/excepthook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 813fc6acfb..c99ab752bd 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -66,7 +66,7 @@ def set_developer_mode(enabled: bool = False) -> None: _developer_mode = enabled -def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb): +def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None: """ Format `CompilationError`s in a neat way. From 61696cc48d89db1c49b887ff9cd4ae81018cec27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:26:04 +0200 Subject: [PATCH 23/54] Update src/gt4py/next/errors/exceptions.py Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/errors/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 542a4838d1..a51e1c4857 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -23,7 +23,7 @@ class CompilerError(Exception): location: Optional[SourceLocation] - def __init__(self, location: Optional[SourceLocation], message: str): + def __init__(self, location: Optional[SourceLocation], message: str) -> None: self.location = location super().__init__(message) From 5c09705f754271f4ad7134634a9e1231989c0240 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:23:32 +0200 Subject: [PATCH 24/54] rename functions --- src/gt4py/next/errors/excepthook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index c99ab752bd..397503dcba 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -20,7 +20,7 @@ from . import exceptions, formatting -def _get_developer_mode_python() -> bool: +def _get_developer_mode_python_env() -> bool: """Guess if the Python environment is used to develop gt4py.""" # Import gt4py and use its __name__ because hard-coding "gt4py" would fail # silently if the module's name changes for whatever reason. @@ -38,7 +38,7 @@ def _get_developer_mode_python() -> bool: return True -def _get_developer_mode_os() -> Optional[bool]: +def _get_developer_mode_envvar() -> Optional[bool]: """Detect if the user set developer mode in environment variables.""" env_var_name = "GT4PY_DEVELOPER_MODE" if env_var_name in os.environ: @@ -51,10 +51,10 @@ def _get_developer_mode_os() -> Optional[bool]: def _guess_developer_mode() -> bool: """Guess if gt4py is run by its developers or by third party users.""" - env = _get_developer_mode_os() + env = _get_developer_mode_envvar() if env is not None: return env - return _get_developer_mode_python() + return _get_developer_mode_python_env() _developer_mode = _guess_developer_mode() From 3b5f1f4db204b841869ef78da0a7c6cf43479c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:28:54 +0200 Subject: [PATCH 25/54] avoid stringifying annotations --- src/gt4py/next/errors/exceptions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index a51e1c4857..a282b5e779 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +from __future__ import annotations + import textwrap from typing import Any, Optional @@ -31,7 +33,7 @@ def __init__(self, location: Optional[SourceLocation], message: str) -> None: def message(self) -> str: return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> "CompilerError": + def with_location(self, location: Optional[SourceLocation]) -> CompilerError: self.location = location return self From a17ec99258d5988d2be740ba428ccbfea00ffedc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:30:23 +0200 Subject: [PATCH 26/54] Update src/gt4py/next/errors/exceptions.py Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/errors/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index a282b5e779..387a88061f 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -37,7 +37,7 @@ def with_location(self, location: Optional[SourceLocation]) -> CompilerError: self.location = location return self - def __str__(self): + def __str__(self) -> str: if self.location: loc_str = formatting.format_location(self.location, caret=True) return f"{self.message}\n{textwrap.indent(loc_str, ' ')}" From e010bb81938420c67cc2119c4886a8fac001d326 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:32:58 +0200 Subject: [PATCH 27/54] type annotations --- src/gt4py/next/errors/exceptions.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 387a88061f..9bf72fa5cd 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -47,7 +47,7 @@ def __str__(self) -> str: class UnsupportedPythonFeatureError(CompilerError): feature: str - def __init__(self, location: Optional[SourceLocation], feature: str): + def __init__(self, location: Optional[SourceLocation], feature: str) -> None: super().__init__(location, f"unsupported Python syntax: '{feature}'") self.feature = feature @@ -55,7 +55,7 @@ def __init__(self, location: Optional[SourceLocation], feature: str): class UndefinedSymbolError(CompilerError): sym_name: str - def __init__(self, location: Optional[SourceLocation], name: str): + def __init__(self, location: Optional[SourceLocation], name: str) -> None: super().__init__(location, f"name '{name}' is not defined") self.sym_name = name @@ -63,20 +63,20 @@ def __init__(self, location: Optional[SourceLocation], name: str): class MissingAttributeError(CompilerError): attr_name: str - def __init__(self, location: Optional[SourceLocation], attr_name: str): + def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: super().__init__(location, f"object does not have attribute '{attr_name}'") self.attr_name = attr_name class CompilerTypeError(CompilerError): - def __init__(self, location: Optional[SourceLocation], message: str): + def __init__(self, location: Optional[SourceLocation], message: str) -> None: super().__init__(location, message) class MissingParameterAnnotationError(CompilerTypeError): param_name: str - def __init__(self, location: Optional[SourceLocation], param_name: str): + def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: super().__init__(location, f"parameter '{param_name}' is missing type annotations") self.param_name = param_name @@ -85,7 +85,7 @@ class InvalidParameterAnnotationError(CompilerTypeError): param_name: str annotated_type: Any - def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any): + def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any) -> None: super().__init__( location, f"parameter '{param_name}' has invalid type annotation '{type_}'" ) @@ -97,7 +97,9 @@ class ArgumentCountError(CompilerTypeError): num_excected: int num_provided: int - def __init__(self, location: Optional[SourceLocation], num_expected: int, num_provided: int): + def __init__( + self, location: Optional[SourceLocation], num_expected: int, num_provided: int + ) -> None: super().__init__( location, f"expected {num_expected} arguments but {num_provided} were provided" ) @@ -108,6 +110,6 @@ def __init__(self, location: Optional[SourceLocation], num_expected: int, num_pr class KeywordArgumentError(CompilerTypeError): provided_names: str - def __init__(self, location: Optional[SourceLocation], provided_names: str): + def __init__(self, location: Optional[SourceLocation], provided_names: str) -> None: super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") self.provided_names = provided_names From 73b8a9159e44f7d5454d249a12b4b6fe881b2592 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 11:52:53 +0200 Subject: [PATCH 28/54] type annotations --- src/gt4py/next/errors/formatting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index c18506ed63..c565cfdceb 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -21,7 +21,7 @@ from gt4py.eve import SourceLocation -def get_source_from_location(location: SourceLocation): +def get_source_from_location(location: SourceLocation) -> str: if not location.filename: raise FileNotFoundError() source_file = pathlib.Path(location.filename) @@ -33,7 +33,7 @@ def get_source_from_location(location: SourceLocation): return "\n".join(relevant_lines) -def format_location(loc: SourceLocation, caret: bool = False): +def format_location(loc: SourceLocation, caret: bool = False) -> str: filename = loc.filename or "" lineno = loc.line or "" loc_str = f'File "{filename}", line {lineno}' @@ -77,7 +77,7 @@ def format_compilation_error( location: Optional[SourceLocation], tb: Optional[types.TracebackType] = None, cause: Optional[BaseException] = None, -): +) -> list[str]: bits: list[str] = [] if cause is not None: From 4951df7657a2d0251373730960b92b6291ea2259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 12:10:26 +0200 Subject: [PATCH 29/54] rename vars to be consistent with class name --- src/gt4py/next/errors/exceptions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 9bf72fa5cd..28a41a9137 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -94,17 +94,17 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A class ArgumentCountError(CompilerTypeError): - num_excected: int - num_provided: int + expected_count: int + provided_count: int def __init__( - self, location: Optional[SourceLocation], num_expected: int, num_provided: int + self, location: Optional[SourceLocation], expected_count: int, provided_count: int ) -> None: super().__init__( - location, f"expected {num_expected} arguments but {num_provided} were provided" + location, f"expected {expected_count} arguments but {provided_count} were provided" ) - self.num_expected = num_expected - self.num_provided = num_provided + self.num_expected = expected_count + self.provided_count = provided_count class KeywordArgumentError(CompilerTypeError): From 519f242c89f40ffcb953d110a83d03a76534475f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 12:11:23 +0200 Subject: [PATCH 30/54] use linecache instead of loading file from disc --- src/gt4py/next/errors/formatting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index c565cfdceb..1679682182 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import pathlib +import linecache import textwrap import traceback import types @@ -24,9 +24,9 @@ def get_source_from_location(location: SourceLocation) -> str: if not location.filename: raise FileNotFoundError() - source_file = pathlib.Path(location.filename) - source_code = source_file.read_text() - source_lines = source_code.splitlines(False) + source_lines = linecache.getlines(location.filename) + if not source_lines: + raise FileNotFoundError() start_line = location.line end_line = location.end_line + 1 if location.end_line else start_line + 1 relevant_lines = source_lines[(start_line - 1) : (end_line - 1)] @@ -48,7 +48,7 @@ def format_location(loc: SourceLocation, caret: bool = False) -> str: try: snippet_str = get_source_from_location(loc) if caret_str: - snippet_str = f"{snippet_str}\n{caret_str}" + snippet_str = f"{snippet_str}{caret_str}" return f"{loc_str}\n{textwrap.indent(snippet_str, ' ')}" except Exception: return loc_str From f6bfdc4324006675046744356284f4338f870002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 12:28:20 +0200 Subject: [PATCH 31/54] remove blanket exception handling --- src/gt4py/next/errors/formatting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 1679682182..75b7fe408d 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -28,14 +28,14 @@ def get_source_from_location(location: SourceLocation) -> str: if not source_lines: raise FileNotFoundError() start_line = location.line - end_line = location.end_line + 1 if location.end_line else start_line + 1 + end_line = location.end_line + 1 if location.end_line is not None else start_line + 1 relevant_lines = source_lines[(start_line - 1) : (end_line - 1)] return "\n".join(relevant_lines) def format_location(loc: SourceLocation, caret: bool = False) -> str: filename = loc.filename or "" - lineno = loc.line or "" + lineno = loc.line loc_str = f'File "{filename}", line {lineno}' if caret and loc.column is not None: @@ -50,7 +50,7 @@ def format_location(loc: SourceLocation, caret: bool = False) -> str: if caret_str: snippet_str = f"{snippet_str}{caret_str}" return f"{loc_str}\n{textwrap.indent(snippet_str, ' ')}" - except Exception: + except (FileNotFoundError, IndexError): return loc_str From c7d7eb7820f8b32df9a57381b04988bf29bebc05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 12:39:21 +0200 Subject: [PATCH 32/54] improve docstrings --- src/gt4py/next/errors/exceptions.py | 2 +- src/gt4py/next/errors/formatting.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 28a41a9137..9854dcdef4 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -39,7 +39,7 @@ def with_location(self, location: Optional[SourceLocation]) -> CompilerError: def __str__(self) -> str: if self.location: - loc_str = formatting.format_location(self.location, caret=True) + loc_str = formatting.format_location(self.location, show_caret=True) return f"{self.message}\n{textwrap.indent(loc_str, ' ')}" return self.message diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 75b7fe408d..face3873b0 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -33,12 +33,18 @@ def get_source_from_location(location: SourceLocation) -> str: return "\n".join(relevant_lines) -def format_location(loc: SourceLocation, caret: bool = False) -> str: +def format_location(loc: SourceLocation, show_caret: bool = False) -> str: + """ + Format the source file location. + + Args: + show_caret (bool): Indicate the position within the source line by placing carets underneath. + """ filename = loc.filename or "" lineno = loc.line loc_str = f'File "{filename}", line {lineno}' - if caret and loc.column is not None: + if show_caret and loc.column is not None: offset = loc.column - 1 width = loc.end_column - loc.column if loc.end_column is not None else 1 caret_str = "".join([" "] * offset + ["^"] * width) @@ -55,7 +61,7 @@ def format_location(loc: SourceLocation, caret: bool = False) -> str: def _format_cause(cause: BaseException) -> list[str]: - """Print the cause of an exception plus the bridging message to STDERR.""" + """Format the cause of an exception plus the bridging message to the current exception.""" bridging_message = "The above exception was the direct cause of the following exception:" cause_strs = [*traceback.format_exception(cause), "\n", f"{bridging_message}\n\n"] return cause_strs @@ -85,7 +91,7 @@ def format_compilation_error( if tb is not None: bits = [*bits, *_format_traceback(tb)] if location is not None: - loc_str = format_location(location, caret=True) + loc_str = format_location(location, show_caret=True) loc_str_all = f"Source location:\n{textwrap.indent(loc_str, ' ')}\n" bits = [*bits, loc_str_all] msg_str = f"{type_.__module__}.{type_.__name__}: {message}" From 89fc3a0a051cf4c1d85143c8a1351b32375b4c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 12 Jul 2023 13:04:18 +0200 Subject: [PATCH 33/54] remove blanket imports --- .../feature_tests/ffront_tests/test_execution.py | 2 +- .../feature_tests/ffront_tests/test_scalar_if.py | 2 +- .../feature_tests/ffront_tests/test_type_deduction.py | 2 +- .../unit_tests/ffront_tests/test_func_to_foast.py | 6 +++++- .../ffront_tests/test_func_to_foast_error_line_number.py | 2 +- .../next_tests/unit_tests/ffront_tests/test_func_to_past.py | 2 +- .../next_tests/unit_tests/ffront_tests/test_past_to_itir.py | 2 +- 7 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index f8d816b96a..871f853eb7 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -31,7 +31,7 @@ neighbor_sum, where, ) -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn_cpu diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 98728d5d5b..d2a6d91b8c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,7 +19,7 @@ import pytest from gt4py.next import Field, field_operator, float64, index_field, np_as_located_field -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.feature_tests import cases diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 1908a924d6..6d88eafcf8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -31,7 +31,7 @@ neighbor_sum, where, ) -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index c7b6faa034..a9ab474f22 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -42,7 +42,11 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import astype, broadcast, float32, float64, int32, int64, where -from gt4py.next.errors import * +from gt4py.next.errors import ( + CompilerError, + MissingParameterAnnotationError, + UnsupportedPythonFeatureError, +) from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.func_to_foast import FieldOperatorParser diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 272b1428dd..e83d8acbcb 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 013684972e..3b2ac53fde 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -20,7 +20,7 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import float64 -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.type_system import type_specifications as ts diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index b0069ddf5f..f33837458a 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.errors import * +from gt4py.next.errors import CompilerError from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir From 935d37b7c55b4951a5caec4d433340f3f6301373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 10:30:03 +0200 Subject: [PATCH 34/54] tests for the exception hook --- src/gt4py/next/errors/excepthook.py | 41 ++++++++++--------- .../errors_tests/test_excepthook.py | 33 +++++++++++++++ 2 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 tests/next_tests/unit_tests/errors_tests/test_excepthook.py diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 397503dcba..d99f762e6b 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -49,15 +49,14 @@ def _get_developer_mode_envvar() -> Optional[bool]: return None -def _guess_developer_mode() -> bool: - """Guess if gt4py is run by its developers or by third party users.""" - env = _get_developer_mode_envvar() - if env is not None: - return env - return _get_developer_mode_python_env() +def _determine_developer_mode(python_env_enabled: bool, envvar_enabled: Optional[bool]) -> bool: + """Determine if gt4py is run by its developers or by third party users.""" + if envvar_enabled is not None: + return envvar_enabled + return python_env_enabled -_developer_mode = _guess_developer_mode() +_developer_mode = _determine_developer_mode(_get_developer_mode_python_env(), _get_developer_mode_envvar()) def set_developer_mode(enabled: bool = False) -> None: @@ -66,6 +65,21 @@ def set_developer_mode(enabled: bool = False) -> None: _developer_mode = enabled +def _format_uncaught_error(err: exceptions.CompilerError, developer_mode: bool) -> list[str]: + if developer_mode: + return formatting.format_compilation_error( + type(err), + err.message, + err.location, + err.__traceback__, + err.__cause__, + ) + else: + return formatting.format_compilation_error( + type(err), err.message, err.location + ) + + def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None: """ Format `CompilationError`s in a neat way. @@ -73,18 +87,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException All other Python exceptions are formatted by the `fallback` hook. """ if isinstance(value, exceptions.CompilerError): - if _developer_mode: - exc_strs = formatting.format_compilation_error( - type(value), - value.message, - value.location, - value.__traceback__, - value.__cause__, - ) - else: - exc_strs = formatting.format_compilation_error( - type(value), value.message, value.location - ) + exc_strs = _format_uncaught_error(value, _developer_mode) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py new file mode 100644 index 0000000000..9c8e132b06 --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -0,0 +1,33 @@ +from gt4py.next.errors import excepthook +from gt4py.next.errors import exceptions +from gt4py import eve + + +def test_determine_developer_mode(): + # Env var overrides python env. + assert excepthook._determine_developer_mode(False, False) == False + assert excepthook._determine_developer_mode(True, False) == False + assert excepthook._determine_developer_mode(False, True) == True + assert excepthook._determine_developer_mode(True, True) == True + + # Defaults to python env if no env var specified. + assert excepthook._determine_developer_mode(False, None) == False + assert excepthook._determine_developer_mode(True, None) == True + + +def test_format_uncaught_error(): + try: + loc = eve.SourceLocation("/src/file.py", 1, 1) + msg = "compile error msg" + raise exceptions.CompilerError(loc, msg) from ValueError("value error msg") + except exceptions.CompilerError as err: + str_devmode = "".join(excepthook._format_uncaught_error(err, True)) + assert str_devmode.find("Source location") >= 0 + assert str_devmode.find("Traceback") >= 0 + assert str_devmode.find("cause") >= 0 + assert str_devmode.find("ValueError") >= 0 + str_usermode = "".join(excepthook._format_uncaught_error(err, False)) + assert str_usermode.find("Source location") >= 0 + assert str_usermode.find("Traceback") < 0 + assert str_usermode.find("cause") < 0 + assert str_usermode.find("ValueError") < 0 \ No newline at end of file From d8d603f8f4ae33059ec5e834764c146a051b3b8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 10:45:50 +0200 Subject: [PATCH 35/54] use typeerror --- src/gt4py/next/ffront/decorator.py | 2 +- .../feature_tests/ffront_tests/test_program.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 6f3e8ca97b..2a343454e1 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -290,7 +290,7 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise ValueError(f"Invalid argument types in call to `{self.past_node.id}`!") from err + raise TypeError(f"Invalid argument types in call to `{self.past_node.id}`!") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 1619bf343f..d2c80d26be 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -223,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): inp = gtx.np_as_located_field(JDim)(np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() - with pytest.raises(ValueError) as exc_info: + with pytest.raises(TypeError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) @@ -293,6 +293,6 @@ def program_input_kwargs( assert np.allclose(expected, out) with pytest.raises( - ValueError, match="Invalid argument types in call to `program_input_kwargs`!" + TypeError, match="Invalid argument types in call to `program_input_kwargs`!" ): program_input_kwargs(input_2, input_3, a=input_1, out=out, offset_provider={}) From fb2489588e5ebe80f1a955f10888ed34e3cc890c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 10:46:09 +0200 Subject: [PATCH 36/54] return None for invalid env var --- src/gt4py/next/errors/excepthook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index d99f762e6b..7bc7661623 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -45,7 +45,7 @@ def _get_developer_mode_envvar() -> Optional[bool]: try: return bool(os.environ[env_var_name]) except TypeError: - return False + return None return None From 084d218f489f6c6ae99d1ebad52022c6e2594e0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:01:10 +0200 Subject: [PATCH 37/54] change names of tests to be more descriptive --- .../unit_tests/ffront_tests/test_func_to_foast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index a9ab474f22..271b2b849d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -189,7 +189,7 @@ def modulo(inp: gtx.Field[[TDim], "int32"]): ) -def test_bool_and(): +def test_boolean_and_op_unsupported(): def bool_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a and b @@ -200,7 +200,7 @@ def bool_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): _ = FieldOperatorParser.apply_to_function(bool_and) -def test_bool_or(): +def test_boolean_or_op_unsupported(): def bool_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a or b @@ -235,7 +235,7 @@ def unary_tilde(a: gtx.Field[[TDim], "bool"]): ) -def test_scalar_cast(): +def test_scalar_cast_disallow_non_literals(): def cast_scalar_temp(): tmp = int64(1) return int32(tmp) From c53fdc178a084250e06d8991d743986efbe227f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:07:18 +0200 Subject: [PATCH 38/54] remove developer mode autodetection --- src/gt4py/next/errors/excepthook.py | 41 +++---------------- .../errors_tests/test_excepthook.py | 31 +++++++------- 2 files changed, 22 insertions(+), 50 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 7bc7661623..c003df3f49 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -13,50 +13,23 @@ # SPDX-License-Identifier: GPL-3.0-or-later import os import sys -from typing import Callable, Optional - -import importlib_metadata +from typing import Callable from . import exceptions, formatting -def _get_developer_mode_python_env() -> bool: - """Guess if the Python environment is used to develop gt4py.""" - # Import gt4py and use its __name__ because hard-coding "gt4py" would fail - # silently if the module's name changes for whatever reason. - import gt4py - - package_name = gt4py.__name__ - - # Check if any package requires gt4py as a dependency. If not, we are - # probably developing gt4py itself rather than something else using gt4py. - dists = importlib_metadata.distributions() - for dist in dists: - for req in dist.requires or []: - if req.startswith(package_name): - return False - return True - - -def _get_developer_mode_envvar() -> Optional[bool]: +def _get_developer_mode_envvar() -> bool: """Detect if the user set developer mode in environment variables.""" env_var_name = "GT4PY_DEVELOPER_MODE" if env_var_name in os.environ: try: return bool(os.environ[env_var_name]) except TypeError: - return None - return None + return False + return False -def _determine_developer_mode(python_env_enabled: bool, envvar_enabled: Optional[bool]) -> bool: - """Determine if gt4py is run by its developers or by third party users.""" - if envvar_enabled is not None: - return envvar_enabled - return python_env_enabled - - -_developer_mode = _determine_developer_mode(_get_developer_mode_python_env(), _get_developer_mode_envvar()) +_developer_mode: bool = _get_developer_mode_envvar() def set_developer_mode(enabled: bool = False) -> None: @@ -75,9 +48,7 @@ def _format_uncaught_error(err: exceptions.CompilerError, developer_mode: bool) err.__cause__, ) else: - return formatting.format_compilation_error( - type(err), err.message, err.location - ) + return formatting.format_compilation_error(type(err), err.message, err.location) def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None: diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py index 9c8e132b06..d14ac49934 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -1,18 +1,19 @@ -from gt4py.next.errors import excepthook -from gt4py.next.errors import exceptions -from gt4py import eve - +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later -def test_determine_developer_mode(): - # Env var overrides python env. - assert excepthook._determine_developer_mode(False, False) == False - assert excepthook._determine_developer_mode(True, False) == False - assert excepthook._determine_developer_mode(False, True) == True - assert excepthook._determine_developer_mode(True, True) == True - - # Defaults to python env if no env var specified. - assert excepthook._determine_developer_mode(False, None) == False - assert excepthook._determine_developer_mode(True, None) == True +from gt4py import eve +from gt4py.next.errors import excepthook, exceptions def test_format_uncaught_error(): @@ -30,4 +31,4 @@ def test_format_uncaught_error(): assert str_usermode.find("Source location") >= 0 assert str_usermode.find("Traceback") < 0 assert str_usermode.find("cause") < 0 - assert str_usermode.find("ValueError") < 0 \ No newline at end of file + assert str_usermode.find("ValueError") < 0 From 8371cabd84d5f30537fa061462a67b3f59216c73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:16:24 +0200 Subject: [PATCH 39/54] delete some unused exception classes --- src/gt4py/next/errors/__init__.py | 4 ---- src/gt4py/next/errors/exceptions.py | 22 ---------------------- 2 files changed, 26 deletions(-) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index de9cce0055..b84bdde251 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -17,10 +17,8 @@ ) from .excepthook import set_developer_mode from .exceptions import ( - ArgumentCountError, CompilerError, InvalidParameterAnnotationError, - KeywordArgumentError, MissingAttributeError, MissingParameterAnnotationError, UndefinedSymbolError, @@ -29,10 +27,8 @@ __all__ = [ - "ArgumentCountError", "CompilerError", "InvalidParameterAnnotationError", - "KeywordArgumentError", "MissingAttributeError", "MissingParameterAnnotationError", "UndefinedSymbolError", diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 9854dcdef4..3293b2d88e 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -91,25 +91,3 @@ def __init__(self, location: Optional[SourceLocation], param_name: str, type_: A ) self.param_name = param_name self.annotated_type = type_ - - -class ArgumentCountError(CompilerTypeError): - expected_count: int - provided_count: int - - def __init__( - self, location: Optional[SourceLocation], expected_count: int, provided_count: int - ) -> None: - super().__init__( - location, f"expected {expected_count} arguments but {provided_count} were provided" - ) - self.num_expected = expected_count - self.provided_count = provided_count - - -class KeywordArgumentError(CompilerTypeError): - provided_names: str - - def __init__(self, location: Optional[SourceLocation], provided_names: str) -> None: - super().__init__(location, f"unexpected keyword argument(s) '{provided_names}' provided") - self.provided_names = provided_names From 8e8e0081c6ba7063ba19d423602d70ce8a5cc0aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:24:40 +0200 Subject: [PATCH 40/54] rename developer mode to verbose exceptions --- src/gt4py/next/errors/__init__.py | 4 ++-- src/gt4py/next/errors/excepthook.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index b84bdde251..1e54a231c4 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -15,7 +15,7 @@ from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. excepthook, ) -from .excepthook import set_developer_mode +from .excepthook import set_verbose_exceptions from .exceptions import ( CompilerError, InvalidParameterAnnotationError, @@ -33,5 +33,5 @@ "MissingParameterAnnotationError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", - "set_developer_mode", + "set_verbose_exceptions", ] diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index c003df3f49..7d1e5b145e 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -18,9 +18,9 @@ from . import exceptions, formatting -def _get_developer_mode_envvar() -> bool: - """Detect if the user set developer mode in environment variables.""" - env_var_name = "GT4PY_DEVELOPER_MODE" +def _get_verbose_exceptions_envvar() -> bool: + """Detect if the user enabled verbose exceptions in the environment variables.""" + env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" if env_var_name in os.environ: try: return bool(os.environ[env_var_name]) @@ -29,17 +29,17 @@ def _get_developer_mode_envvar() -> bool: return False -_developer_mode: bool = _get_developer_mode_envvar() +_verbose_exceptions: bool = _get_verbose_exceptions_envvar() -def set_developer_mode(enabled: bool = False) -> None: - """In developer mode, information useful for gt4py developers is also shown.""" - global _developer_mode - _developer_mode = enabled +def set_verbose_exceptions(enabled: bool = False) -> None: + """With verbose exceptions, the stack trace and cause of the error is also printed.""" + global _verbose_exceptions + _verbose_exceptions = enabled -def _format_uncaught_error(err: exceptions.CompilerError, developer_mode: bool) -> list[str]: - if developer_mode: +def _format_uncaught_error(err: exceptions.CompilerError, verbose_exceptions: bool) -> list[str]: + if verbose_exceptions: return formatting.format_compilation_error( type(err), err.message, @@ -58,7 +58,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException All other Python exceptions are formatted by the `fallback` hook. """ if isinstance(value, exceptions.CompilerError): - exc_strs = _format_uncaught_error(value, _developer_mode) + exc_strs = _format_uncaught_error(value, _verbose_exceptions) print("".join(exc_strs), file=sys.stderr) else: fallback(type_, value, tb) From 5e08a501df33fad5c32be6ced453d21c9b8ae7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:33:26 +0200 Subject: [PATCH 41/54] Update src/gt4py/next/ffront/func_to_foast.py Co-authored-by: Till Ehrengruber --- src/gt4py/next/ffront/func_to_foast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 1fecc2b049..4f1e21d08b 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -171,7 +171,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: raise CompilerError( - loc, "function is expected to return a value, return statement not found" + loc, "Function is expected to return a value." ) return foast.FunctionDefinition( From 1bcbdf734201e5f6718fa6a7ff5ec2a63089da21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 11:37:25 +0200 Subject: [PATCH 42/54] delete test files --- tests/next_tests/exception_printing.py | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 tests/next_tests/exception_printing.py diff --git a/tests/next_tests/exception_printing.py b/tests/next_tests/exception_printing.py deleted file mode 100644 index a3175b45d2..0000000000 --- a/tests/next_tests/exception_printing.py +++ /dev/null @@ -1,25 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import inspect - -from gt4py.eve import SourceLocation -from gt4py.next.errors import CompilerError - - -frameinfo = inspect.getframeinfo(inspect.currentframe()) -loc = SourceLocation( - frameinfo.filename, frameinfo.lineno, 1, end_line=frameinfo.lineno, end_column=5 -) -raise CompilerError(loc, "this is an error message") from ValueError("asd") From cc162b311bb324741e530869ed205f5d84dd0cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 12:09:45 +0200 Subject: [PATCH 43/54] use fixtures --- src/gt4py/next/ffront/func_to_foast.py | 4 +- .../errors_tests/test_exceptions.py | 51 ++++++++++------ .../errors_tests/test_formatting.py | 61 ++++++++++++------- 3 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 4f1e21d08b..27d09f2473 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -170,9 +170,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise CompilerError( - loc, "Function is expected to return a value." - ) + raise CompilerError(loc, "Function is expected to return a value.") return foast.FunctionDefinition( id=node.name, diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py index e3fde95e90..2449659755 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -15,42 +15,55 @@ import inspect import re +import pytest + from gt4py.eve import SourceLocation from gt4py.next.errors import CompilerError -frameinfo = inspect.getframeinfo(inspect.currentframe()) -loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) -loc_snippet = SourceLocation( - frameinfo.filename, frameinfo.lineno + 2, 15, end_line=frameinfo.lineno + 2, end_column=29 -) -msg = "a message" +@pytest.fixture +def loc_snippet(): + frameinfo = inspect.getframeinfo(inspect.currentframe()) + # This very line of comment should be shown in the snippet. + return SourceLocation( + frameinfo.filename, frameinfo.lineno + 1, 15, end_line=frameinfo.lineno + 1, end_column=29 + ) + + +@pytest.fixture +def loc_plain(): + return SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) + + +@pytest.fixture +def message(): + return "a message" -def test_message(): - assert CompilerError(loc, msg).message == msg +def test_message(loc_plain, message): + assert CompilerError(loc_plain, message).message == message -def test_location(): - assert CompilerError(loc, msg).location == loc +def test_location(loc_plain, message): + assert CompilerError(loc_plain, message).location == loc_plain -def test_with_location(): - assert CompilerError(None, msg).with_location(loc).location == loc +def test_with_location(loc_plain, message): + assert CompilerError(None, message).with_location(loc_plain).location == loc_plain -def test_str(): - pattern = f'{msg}\\n File ".*", line.*' - s = str(CompilerError(loc, msg)) +def test_str(loc_plain, message): + pattern = f'{message}\\n File ".*", line.*' + s = str(CompilerError(loc_plain, message)) assert re.match(pattern, s) -def test_str_snippet(): +def test_str_snippet(loc_snippet, message): pattern = ( - f"{msg}\\n" + f"{message}\\n" ' File ".*", line.*\\n' - " loc_snippet = SourceLocation.*\\n" + " # This very line of comment should be shown in the snippet.\\n" " \^\^\^\^\^\^\^\^\^\^\^\^\^\^" ) - s = str(CompilerError(loc_snippet, msg)) + s = str(CompilerError(loc_snippet, message)) assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index 5328a4f228..84eee36969 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -12,46 +12,63 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import inspect import re +import pytest + from gt4py.eve import SourceLocation from gt4py.next.errors import CompilerError from gt4py.next.errors.formatting import format_compilation_error -frameinfo = inspect.getframeinfo(inspect.currentframe()) -loc = SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) -msg = "a message" +@pytest.fixture +def message(): + return "a message" + + +@pytest.fixture +def location(): + return SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) + + +@pytest.fixture +def tb(): + try: + raise Exception() + except Exception as ex: + return ex.__traceback__ + + +@pytest.fixture +def type_(): + return CompilerError + -module = CompilerError.__module__ -name = CompilerError.__name__ -try: - raise Exception() -except Exception as ex: - tb = ex.__traceback__ +@pytest.fixture +def qualname(type_): + return f"{type_.__module__}.{type_.__name__}" -def test_format(): - pattern = f"{module}.{name}: {msg}" - s = "\n".join(format_compilation_error(CompilerError, msg, None, None, None)) +def test_format(type_, qualname, message): + pattern = f"{qualname}: {message}" + s = "\n".join(format_compilation_error(type_, message, None, None, None)) assert re.match(pattern, s) -def test_format_loc(): - pattern = "Source location.*\\n" ' File "/source.*".*\\n' f"{module}.{name}: {msg}" - s = "".join(format_compilation_error(CompilerError, msg, loc, None, None)) +def test_format_loc(type_, qualname, message, location): + pattern = "Source location.*\\n" ' File "/source.*".*\\n' f"{qualname}: {message}" + s = "".join(format_compilation_error(type_, message, location, None, None)) assert re.match(pattern, s) -def test_format_traceback(): - pattern = "Traceback.*\\n" ' File ".*".*\\n' ".*\\n" f"{module}.{name}: {msg}" - s = "".join(format_compilation_error(CompilerError, msg, None, tb, None)) +def test_format_traceback(type_, qualname, message, tb): + pattern = "Traceback.*\\n" ' File ".*".*\\n' ".*\\n" f"{qualname}: {message}" + s = "".join(format_compilation_error(type_, message, None, tb, None)) assert re.match(pattern, s) -def test_format_cause(): +def test_format_cause(type_, qualname, message): cause = ValueError("asd") - pattern = "ValueError: asd\\n\\n" "The above.*\\n\\n" f"{module}.{name}: {msg}" - s = "".join(format_compilation_error(CompilerError, msg, None, None, cause)) + pattern = "ValueError: asd\\n\\n" "The above.*\\n\\n" f"{qualname}: {message}" + s = "".join(format_compilation_error(type_, message, None, None, cause)) assert re.match(pattern, s) From e04c7fd4cd6034f5d602557390af02dc05ff33c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 12:39:43 +0200 Subject: [PATCH 44/54] rename exception classes --- src/gt4py/next/errors/__init__.py | 4 +- src/gt4py/next/errors/excepthook.py | 4 +- src/gt4py/next/errors/exceptions.py | 16 ++-- src/gt4py/next/ffront/dialect_parser.py | 4 +- .../foast_passes/closure_var_folding.py | 4 +- .../ffront/foast_passes/type_deduction.py | 84 +++++++++---------- src/gt4py/next/ffront/func_to_foast.py | 24 +++--- src/gt4py/next/ffront/func_to_past.py | 6 +- .../next/ffront/past_passes/type_deduction.py | 12 +-- .../ffront_tests/test_execution.py | 8 +- .../ffront_tests/test_scalar_if.py | 8 +- .../ffront_tests/test_type_deduction.py | 28 +++---- .../feature_tests/test_util_cases.py | 4 +- .../errors_tests/test_excepthook.py | 4 +- .../errors_tests/test_exceptions.py | 12 +-- .../errors_tests/test_formatting.py | 4 +- .../ffront_tests/test_func_to_foast.py | 24 +++--- .../test_func_to_foast_error_line_number.py | 8 +- .../ffront_tests/test_func_to_past.py | 16 ++-- .../ffront_tests/test_past_to_itir.py | 4 +- 20 files changed, 136 insertions(+), 142 deletions(-) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 1e54a231c4..21550bff60 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -17,7 +17,7 @@ ) from .excepthook import set_verbose_exceptions from .exceptions import ( - CompilerError, + DSLError, InvalidParameterAnnotationError, MissingAttributeError, MissingParameterAnnotationError, @@ -27,7 +27,7 @@ __all__ = [ - "CompilerError", + "DSLError", "InvalidParameterAnnotationError", "MissingAttributeError", "MissingParameterAnnotationError", diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 7d1e5b145e..1dfbbc15d9 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -38,7 +38,7 @@ def set_verbose_exceptions(enabled: bool = False) -> None: _verbose_exceptions = enabled -def _format_uncaught_error(err: exceptions.CompilerError, verbose_exceptions: bool) -> list[str]: +def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -> list[str]: if verbose_exceptions: return formatting.format_compilation_error( type(err), @@ -57,7 +57,7 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException All other Python exceptions are formatted by the `fallback` hook. """ - if isinstance(value, exceptions.CompilerError): + if isinstance(value, exceptions.DSLError): exc_strs = _format_uncaught_error(value, _verbose_exceptions) print("".join(exc_strs), file=sys.stderr) else: diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 3293b2d88e..375bd832c3 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -22,7 +22,7 @@ from . import formatting -class CompilerError(Exception): +class DSLError(Exception): location: Optional[SourceLocation] def __init__(self, location: Optional[SourceLocation], message: str) -> None: @@ -33,7 +33,7 @@ def __init__(self, location: Optional[SourceLocation], message: str) -> None: def message(self) -> str: return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> CompilerError: + def with_location(self, location: Optional[SourceLocation]) -> DSLError: self.location = location return self @@ -44,7 +44,7 @@ def __str__(self) -> str: return self.message -class UnsupportedPythonFeatureError(CompilerError): +class UnsupportedPythonFeatureError(DSLError): feature: str def __init__(self, location: Optional[SourceLocation], feature: str) -> None: @@ -52,7 +52,7 @@ def __init__(self, location: Optional[SourceLocation], feature: str) -> None: self.feature = feature -class UndefinedSymbolError(CompilerError): +class UndefinedSymbolError(DSLError): sym_name: str def __init__(self, location: Optional[SourceLocation], name: str) -> None: @@ -60,7 +60,7 @@ def __init__(self, location: Optional[SourceLocation], name: str) -> None: self.sym_name = name -class MissingAttributeError(CompilerError): +class MissingAttributeError(DSLError): attr_name: str def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: @@ -68,12 +68,12 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: self.attr_name = attr_name -class CompilerTypeError(CompilerError): +class TypeError_(DSLError): def __init__(self, location: Optional[SourceLocation], message: str) -> None: super().__init__(location, message) -class MissingParameterAnnotationError(CompilerTypeError): +class MissingParameterAnnotationError(TypeError_): param_name: str def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: @@ -81,7 +81,7 @@ def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: self.param_name = param_name -class InvalidParameterAnnotationError(CompilerTypeError): +class InvalidParameterAnnotationError(TypeError_): param_name: str annotated_type: Any diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 38c6fd838b..c69f0a8f9e 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -20,7 +20,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, Generic, TypeVar -from gt4py.next.errors import CompilerError, UnsupportedPythonFeatureError +from gt4py.next.errors import DSLError, UnsupportedPythonFeatureError from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -46,7 +46,7 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: if err.end_offset is not None else None, ) - raise CompilerError(loc, err.msg).with_traceback(err.__traceback__) + raise DSLError(loc, err.msg).with_traceback(err.__traceback__) @dataclass(frozen=True, kw_only=True) diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index a97090e6a8..f30b0c856a 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -18,7 +18,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import FrozenNamespace -from gt4py.next.errors import CompilerError, MissingAttributeError +from gt4py.next.errors import DSLError, MissingAttributeError @dataclass @@ -56,7 +56,7 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise MissingAttributeError(node.location, node.attr) - raise CompilerError(node.location, "attribute access only applicable to constants") + raise DSLError(node.location, "attribute access only applicable to constants") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6a949a0245..6b66b07e8a 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -17,7 +17,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits from gt4py.next.common import DimensionKind -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, @@ -145,7 +145,7 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise CompilerError( + raise DSLError( stmt.location, f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", @@ -162,7 +162,7 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise CompilerError( + raise DSLError( stmt.location, f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", @@ -248,7 +248,7 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise CompilerError( + raise DSLError( node.location, f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", ) @@ -280,18 +280,18 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise CompilerError( + raise DSLError( node.location, f"Argument `axis` to scan operator `{node.id}` must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise CompilerError( + raise DSLError( node.location, f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise CompilerError( + raise DSLError( node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." ) new_init = self.visit(node.init, **kwargs) @@ -299,7 +299,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise CompilerError( + raise DSLError( node.location, f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", @@ -322,7 +322,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilerError(node.location, f"Undeclared symbol `{node.id}`.") + raise DSLError(node.location, f"Undeclared symbol `{node.id}`.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -346,7 +346,7 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise CompilerError( + raise DSLError( node.location, f"Too many values to unpack (expected {len(indices)})." ) @@ -378,7 +378,7 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise CompilerError( + raise DSLError( node.location, f"Assignment value must be of type tuple! Got: {values.type}" ) @@ -397,14 +397,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise CompilerError( + raise DSLError( node.location, "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise CompilerError( + raise DSLError( node.location, "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", @@ -414,7 +414,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise CompilerError( + raise DSLError( node.location, f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", @@ -435,7 +435,7 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise CompilerError( + raise DSLError( node.location, ( "type inconsistency: expression was deduced to be " @@ -457,7 +457,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise CompilerError( + raise DSLError( new_value.location, "Second dimension in offset must be a local dimension." ) new_type = ts.OffsetType(source=source, target=(target1,)) @@ -466,15 +466,13 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise CompilerError( + raise DSLError( new_value.location, "Source and target must be equal for offsets with a single target.", ) new_type = new_value.type case _: - raise CompilerError( - new_value.location, "Could not deduce type of subscript expression!" - ) + raise DSLError(new_value.location, "Could not deduce type of subscript expression!") return foast.Subscript( value=new_value, index=node.index, type=new_type, location=node.location @@ -512,13 +510,13 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise CompilerError( + raise DSLError( condition.location, f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", ) if true_expr.type != false_expr.type: - raise CompilerError( + raise DSLError( node.location, f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", ) @@ -538,7 +536,7 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise CompilerError( + raise DSLError( arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" ) @@ -549,7 +547,7 @@ def _deduce_compare_type( # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) except ValueError as ex: - raise CompilerError( + raise DSLError( node.location, f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", @@ -573,7 +571,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise CompilerError( + raise DSLError( arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) @@ -586,7 +584,7 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise CompilerError( + raise DSLError( arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) @@ -594,7 +592,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise CompilerError( + raise DSLError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -605,7 +603,7 @@ def _check_operand_dtypes_match( ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise CompilerError( + raise DSLError( node.location, f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", ) @@ -622,7 +620,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise CompilerError( + raise DSLError( node.location, f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", ) @@ -654,11 +652,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise CompilerError(node.location, "Functions can only be called directly!") + raise DSLError(node.location, "Functions can only be called directly!") elif isinstance(new_func.type, ts.FieldType): pass else: - raise CompilerError( + raise DSLError( node.location, f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", ) @@ -672,7 +670,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: raise_exception=True, ) except ValueError as err: - raise CompilerError( + raise DSLError( node.location, f"Invalid argument types in call to `{new_func}`!" ) from err @@ -730,7 +728,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise CompilerError( + raise DSLError( node.location, "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), ) @@ -753,7 +751,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) except ValueError as ex: - raise CompilerError(node.location, error_msg_preamble) from ex + raise DSLError(node.location, error_msg_preamble) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -772,7 +770,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise CompilerError( + raise DSLError( node.location, f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " @@ -828,7 +826,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise CompilerError( + raise DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" @@ -836,7 +834,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: ) if arg_0.source not in arg_1.dims: - raise CompilerError( + raise DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " @@ -857,7 +855,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise CompilerError( + raise DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype `bool`, but got `{mask_type}`.", @@ -875,7 +873,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise CompilerError( + raise DSLError( node.location, f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", @@ -887,7 +885,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: return_type = promote_to_mask_type(mask_type, promoted_type) except ValueError as ex: - raise CompilerError( + raise DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`.", ) from ex @@ -905,7 +903,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise CompilerError( + raise DSLError( node.location, f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", @@ -914,7 +912,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise CompilerError( + raise DSLError( node.location, f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", @@ -937,5 +935,5 @@ def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError as e: - raise CompilerError(node.location, "Could not deduce type of constant.") from e + raise DSLError(node.location, "Could not deduce type of constant.") from e return foast.Constant(value=node.value, location=node.location, type=type_) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 27d09f2473..9fe9bd09f8 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -20,7 +20,7 @@ import gt4py.eve as eve from gt4py.next.errors import ( - CompilerError, + DSLError, InvalidParameterAnnotationError, MissingParameterAnnotationError, UnsupportedPythonFeatureError, @@ -76,7 +76,7 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) - ... except CompilerError as err: + ... except DSLError as err: ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ @@ -108,7 +108,7 @@ def _postprocess_dialect_ast( # TODO(tehrengruber): use `type_info.return_type` when the type of the # arguments becomes available here if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - raise CompilerError( + raise DSLError( foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented f", but got `{annotated_return_type}`.", @@ -170,7 +170,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise CompilerError(loc, "Function is expected to return a value.") + raise DSLError(loc, "Function is expected to return a value.") return foast.FunctionDefinition( id=node.name, @@ -227,7 +227,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise CompilerError(self.get_location(node), "can only assign to names") + raise DSLError(self.get_location(node), "can only assign to names") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -249,7 +249,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise CompilerError(self.get_location(node), "can only assign to names") + raise DSLError(self.get_location(node), "can only assign to names") if node.annotation is not None: assert isinstance( @@ -290,9 +290,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise CompilerError( - self.get_location(node.slice), "expected an integral index" - ) from None + raise DSLError(self.get_location(node.slice), "expected an integral index") from None return foast.Subscript( value=self.visit(node.value), @@ -313,7 +311,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise CompilerError(loc, "must return a value, not None") + raise DSLError(loc, "must return a value, not None") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -443,7 +441,7 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): - raise CompilerError( + raise DSLError( self.get_location(node), f"{self._func_name(node)}() only takes literal arguments!", ) @@ -470,9 +468,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError: - raise CompilerError( - loc, f"constants of type {type(node.value)} are not permitted" - ) from None + raise DSLError(loc, f"constants of type {type(node.value)} are not permitted") from None return foast.Constant( value=node.value, diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 8ed466cc29..7911283646 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -19,7 +19,7 @@ from typing import Any, cast from gt4py.next.errors import ( - CompilerError, + DSLError, InvalidParameterAnnotationError, MissingParameterAnnotationError, ) @@ -132,7 +132,7 @@ def visit_Call(self, node: ast.Call) -> past.Call: loc = self.get_location(node) new_func = self.visit(node.func) if not isinstance(new_func, past.Name): - raise CompilerError(loc, "functions must be referenced by their name in function calls") + raise DSLError(loc, "functions must be referenced by their name in function calls") return past.Call( func=new_func, @@ -168,7 +168,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) - raise CompilerError(loc, "unary operators are only applicable to literals") + raise DSLError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index a45e85231e..c00d8710f3 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -148,7 +148,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): - raise CompilerError( + raise DSLError( arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) @@ -161,7 +161,7 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise CompilerError( + raise DSLError( arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) @@ -169,7 +169,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise CompilerError( + raise DSLError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -231,7 +231,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise CompilerError(node.location, f"Invalid call to `{node.func.id}`.") from ex + raise DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex return past.Call( func=new_func, @@ -244,6 +244,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise CompilerError(node.location, f"Undeclared or untyped symbol `{node.id}`.") + raise DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 871f853eb7..cf18055a2b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -31,7 +31,7 @@ neighbor_sum, where, ) -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn_cpu @@ -823,7 +823,7 @@ def fieldop_where_k_offset( def test_undefined_symbols(cartesian_case): - with pytest.raises(CompilerError, match="Undeclared symbol"): + with pytest.raises(DSLError, match="Undeclared symbol"): @gtx.field_operator(backend=cartesian_case.backend) def return_undefined(): @@ -918,7 +918,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( - CompilerError, + DSLError, match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), ): @@ -929,7 +929,7 @@ def _star_unpack() -> tuple[int32, float64, int32]: def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises(CompilerError, match=(r"Assignment value must be of type tuple!")): + with pytest.raises(DSLError, match=(r"Assignment value must be of type tuple!")): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index d2a6d91b8c..f5390f914b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -19,7 +19,7 @@ import pytest from gt4py.next import Field, field_operator, float64, index_field, np_as_located_field -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.feature_tests import cases @@ -357,7 +357,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(CompilerError, match="Condition for `if` must be scalar."): + with pytest.raises(DSLError, match="Condition for `if` must be scalar."): @field_operator def if_non_scalar_condition( @@ -370,7 +370,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises(CompilerError, match="Condition for `if` must be of boolean type."): + with pytest.raises(DSLError, match="Condition for `if` must be of boolean type."): @field_operator def if_non_boolean_condition( @@ -385,7 +385,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - CompilerError, + DSLError, match="Inconsistent types between two branches for variable", ): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 6d88eafcf8..bacf91c275 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -31,7 +31,7 @@ neighbor_sum, where, ) -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -505,7 +505,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - CompilerError, + DSLError, match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -520,7 +520,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b with pytest.raises( - CompilerError, + DSLError, match=( r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." ), @@ -533,7 +533,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( - CompilerError, + DSLError, match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -544,7 +544,7 @@ def sign_bool(a: Field[[TDim], bool]): return -a with pytest.raises( - CompilerError, + DSLError, match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -555,7 +555,7 @@ def not_int(a: Field[[TDim], int64]): return not a with pytest.raises( - CompilerError, + DSLError, match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -627,7 +627,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: return float32("1.0") + float64("1.0") with pytest.raises( - CompilerError, + DSLError, match=(r"Could not promote `float32` and `float64` to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -657,7 +657,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - CompilerError, + DSLError, match=r"Expected broadcast dimension is missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -672,7 +672,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - CompilerError, + DSLError, match=r"Expected all broadcast dimensions to be of type Dimension.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -738,7 +738,7 @@ def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) with pytest.raises( - CompilerError, + DSLError, match=r"Return arguments need to be of same type", ): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -793,7 +793,7 @@ def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 with pytest.raises( - CompilerError, + DSLError, match=r"Type float64 can not be used in operator `%`", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -803,7 +803,7 @@ def test_undefined_symbols(): def return_undefined(): return undefined_symbol - with pytest.raises(CompilerError, match="Undeclared symbol"): + with pytest.raises(DSLError, match="Undeclared symbol"): _ = FieldOperatorParser.apply_to_function(return_undefined) @@ -816,7 +816,7 @@ def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) with pytest.raises( - CompilerError, + DSLError, match=f"not in list of offset field dimensions", ): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -831,7 +831,7 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) with pytest.raises( - CompilerError, + DSLError, match=f"Excepted integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 63a65d60ed..bde6c7c247 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.program_processors.runners import roundtrip from next_tests.integration_tests.feature_tests import cases @@ -87,7 +87,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures b = cases.allocate(cartesian_case, addition, "b")() out = cases.allocate(cartesian_case, addition, cases.RETURN)() - with pytest.raises(CompilerError): + with pytest.raises(DSLError): cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py index d14ac49934..f7db8e7a0d 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -20,8 +20,8 @@ def test_format_uncaught_error(): try: loc = eve.SourceLocation("/src/file.py", 1, 1) msg = "compile error msg" - raise exceptions.CompilerError(loc, msg) from ValueError("value error msg") - except exceptions.CompilerError as err: + raise exceptions.DSLError(loc, msg) from ValueError("value error msg") + except exceptions.DSLError as err: str_devmode = "".join(excepthook._format_uncaught_error(err, True)) assert str_devmode.find("Source location") >= 0 assert str_devmode.find("Traceback") >= 0 diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py index 2449659755..90d1ebb8b0 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -18,7 +18,7 @@ import pytest from gt4py.eve import SourceLocation -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError @pytest.fixture @@ -41,20 +41,20 @@ def message(): def test_message(loc_plain, message): - assert CompilerError(loc_plain, message).message == message + assert DSLError(loc_plain, message).message == message def test_location(loc_plain, message): - assert CompilerError(loc_plain, message).location == loc_plain + assert DSLError(loc_plain, message).location == loc_plain def test_with_location(loc_plain, message): - assert CompilerError(None, message).with_location(loc_plain).location == loc_plain + assert DSLError(None, message).with_location(loc_plain).location == loc_plain def test_str(loc_plain, message): pattern = f'{message}\\n File ".*", line.*' - s = str(CompilerError(loc_plain, message)) + s = str(DSLError(loc_plain, message)) assert re.match(pattern, s) @@ -65,5 +65,5 @@ def test_str_snippet(loc_snippet, message): " # This very line of comment should be shown in the snippet.\\n" " \^\^\^\^\^\^\^\^\^\^\^\^\^\^" ) - s = str(CompilerError(loc_snippet, message)) + s = str(DSLError(loc_snippet, message)) assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index 84eee36969..78a206eda3 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -17,7 +17,7 @@ import pytest from gt4py.eve import SourceLocation -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.errors.formatting import format_compilation_error @@ -41,7 +41,7 @@ def tb(): @pytest.fixture def type_(): - return CompilerError + return DSLError @pytest.fixture diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 271b2b849d..21612ea8db 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -43,7 +43,7 @@ from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import astype, broadcast, float32, float64, int32, int64, where from gt4py.next.errors import ( - CompilerError, + DSLError, MissingParameterAnnotationError, UnsupportedPythonFeatureError, ) @@ -119,7 +119,7 @@ def no_return(inp: gtx.Field[[TDim], "float64"]): tmp = inp # noqa with pytest.raises( - CompilerError, + DSLError, match=".*return.*", ): _ = FieldOperatorParser.apply_to_function(no_return) @@ -135,7 +135,7 @@ def invalid_assign_to_expr( tmp[-1] = inp2 return tmp - with pytest.raises(CompilerError, match=r".*assign.*"): + with pytest.raises(DSLError, match=r".*assign.*"): _ = FieldOperatorParser.apply_to_function(invalid_assign_to_expr) @@ -161,7 +161,7 @@ def clashing(inp: gtx.Field[[TDim], "float64"]): tmp: gtx.Field[[TDim], "int64"] = inp return tmp - with pytest.raises(CompilerError, match="type inconsistency"): + with pytest.raises(DSLError, match="type inconsistency"): _ = FieldOperatorParser.apply_to_function(clashing) @@ -240,7 +240,7 @@ def cast_scalar_temp(): tmp = int64(1) return int32(tmp) - with pytest.raises(CompilerError, match=r".*literal.*"): + with pytest.raises(DSLError, match=r".*literal.*"): _ = FieldOperatorParser.apply_to_function(cast_scalar_temp) @@ -251,7 +251,7 @@ def conditional_wrong_mask_type( return where(a, a, a) msg = r"Expected a field with dtype `bool`." - with pytest.raises(CompilerError, match=msg): + with pytest.raises(DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -264,7 +264,7 @@ def conditional_wrong_arg_type( return where(mask, a, b) msg = r"Could not promote scalars of different dtype \(not implemented\)." - with pytest.raises(CompilerError) as exc_info: + with pytest.raises(DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(conditional_wrong_arg_type) assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -274,7 +274,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: gtx.Field[[], bool]): return 1 if cond else 2 - with pytest.raises(CompilerError, match=r"should be .* `bool`"): + with pytest.raises(DSLError, match=r"should be .* `bool`"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -293,7 +293,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> gtx.Field[[], float]: return 1.0 - with pytest.raises(CompilerError, match=r"Expected `float.*`"): + with pytest.raises(DSLError, match=r"Expected `float.*`"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -375,7 +375,7 @@ def wrong_return_type_annotation(a: gtx.Field[[ADim], float64]) -> gtx.Field[[BD return a with pytest.raises( - CompilerError, + DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -386,7 +386,7 @@ def empty_dims() -> gtx.Field[[], float]: return 1.0 with pytest.raises( - CompilerError, + DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(empty_dims) @@ -401,7 +401,7 @@ def zero_dims_ternary( return a if cond == 1 else b msg = r"Incompatible datatypes in operator `==`" - with pytest.raises(CompilerError) as exc_info: + with pytest.raises(DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index e83d8acbcb..d83bf298f1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction @@ -37,7 +37,7 @@ def wrong_syntax(inp: gtx.Field[[TDim], float]): return # <-- this line triggers the syntax error with pytest.raises( - f2f.CompilerError, + f2f.DSLError, match=(r".*return.*"), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) @@ -63,7 +63,7 @@ def invalid_python_syntax(): """, ) - with pytest.raises(CompilerError) as exc_info: + with pytest.raises(DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply(source_definition, {}, {}) assert exc_info.value.location @@ -82,7 +82,7 @@ def test_fo_type_deduction_error(): def field_operator_with_undeclared_symbol(): return undeclared_symbol # noqa: F821 # undefined on purpose - with pytest.raises(CompilerError) as exc_info: + with pytest.raises(DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(field_operator_with_undeclared_symbol) exc = exc_info.value diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index 3b2ac53fde..d92b5abbb7 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -20,7 +20,7 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next import float64 -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.type_system import type_specifications as ts @@ -113,7 +113,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): identity(in_field, out=out_field) # noqa: F821 # undefined on purpose with pytest.raises( - CompilerError, + DSLError, match=(r"Undeclared or untyped symbol `out_field`."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -162,7 +162,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) @@ -181,7 +181,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) @@ -200,7 +200,7 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) @@ -221,7 +221,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): ) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) @@ -240,7 +240,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) @@ -259,7 +259,7 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index f33837458a..083a9796b9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.errors import CompilerError +from gt4py.next.errors import DSLError from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir @@ -169,7 +169,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): with pytest.raises( - CompilerError, + DSLError, ) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), From 8fe6d067a9a1396c21e6d908c15ffe1924c34989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 13 Jul 2023 13:26:19 +0200 Subject: [PATCH 45/54] remove test removed in another PR --- .../ffront_tests/test_program.py | 43 ------------------- 1 file changed, 43 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index d2c80d26be..92c27e0a2b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -253,46 +253,3 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): match=(r"Dimensions in out field and field domain are not equivalent"), ): empty_domain_program(a, out_field, offset_provider={}) - - -def test_input_kwargs(fieldview_backend): - size = 10 - input_1 = gtx.np_as_located_field(IDim, JDim)(np.ones((size, size))) - input_2 = gtx.np_as_located_field(IDim, JDim)(np.ones((size, size)) * 2) - input_3 = gtx.np_as_located_field(IDim, JDim)(np.ones((size, size)) * 3) - - expected = np.asarray(input_3) * np.asarray(input_1) - np.asarray(input_2) - - @gtx.field_operator(backend=fieldview_backend) - def fieldop_input_kwargs( - a: gtx.Field[[IDim, JDim], float64], - b: gtx.Field[[IDim, JDim], float64], - c: gtx.Field[[IDim, JDim], float64], - ) -> gtx.Field[[IDim, JDim], float64]: - return c * a - b - - out = gtx.np_as_located_field(IDim, JDim)(np.zeros((size, size))) - fieldop_input_kwargs(input_1, b=input_2, c=input_3, out=out, offset_provider={}) - assert np.allclose(expected, out) - - @gtx.program(backend=fieldview_backend) - def program_input_kwargs( - a: gtx.Field[[IDim, JDim], float64], - b: gtx.Field[[IDim, JDim], float64], - c: gtx.Field[[IDim, JDim], float64], - out: gtx.Field[[IDim, JDim], float64], - ): - fieldop_input_kwargs(a, b, c, out=out) - - out = gtx.np_as_located_field(IDim, JDim)(np.zeros((size, size))) - program_input_kwargs(input_1, b=input_2, c=input_3, out=out, offset_provider={}) - assert np.allclose(expected, out) - - out = gtx.np_as_located_field(IDim, JDim)(np.zeros((size, size))) - program_input_kwargs(a=input_1, b=input_2, c=input_3, out=out, offset_provider={}) - assert np.allclose(expected, out) - - with pytest.raises( - TypeError, match="Invalid argument types in call to `program_input_kwargs`!" - ): - program_input_kwargs(input_2, input_3, a=input_1, out=out, offset_provider={}) From 7a14225e497639ba4b041087847b150d3ed1b293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Tue, 18 Jul 2023 12:32:03 +0200 Subject: [PATCH 46/54] add doscstrings to modules --- src/gt4py/next/errors/__init__.py | 2 ++ src/gt4py/next/errors/excepthook.py | 9 +++++++++ src/gt4py/next/errors/exceptions.py | 11 +++++++++++ src/gt4py/next/errors/formatting.py | 2 ++ 4 files changed, 24 insertions(+) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 21550bff60..61441e83b9 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +"""Contains the exception classes and other utilities for error handling.""" + from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. excepthook, ) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 1dfbbc15d9..41693f0af6 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -11,6 +11,15 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + +""" +Loading this module registers an excepthook that formats py:class:: DSLError. + +The excepthook is necessary because the default hook prints DSLErrors in an +inconvenient way. The previously set excepthook is used to print all other +errors. +""" + import os import sys from typing import Callable diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 375bd832c3..830a933981 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -12,6 +12,17 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +""" +The list of exception classes used in the library. + +Exception classes that represent errors within an IR go here as a subclass of +py:class:: DSLError. Exception classes that represent other errors, like +the builtin ValueError, go here as well, although you should use Python's +builtin error classes if you can. Exception classes that are specific to a +certain submodule and have no use for the entire application may be better off +in that submodule as opposed to being in this file. +""" + from __future__ import annotations import textwrap diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index face3873b0..1126861cc9 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -12,6 +12,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +"""Utility functions for formatting py:class:: DSLError and its subclasses.""" + import linecache import textwrap import traceback From a7e756603f84e0fc633ac4979e1ec40985aa32d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 19 Jul 2023 12:30:17 +0200 Subject: [PATCH 47/54] fix doc string formatting --- src/gt4py/next/errors/excepthook.py | 8 ++++---- src/gt4py/next/errors/exceptions.py | 2 +- src/gt4py/next/errors/formatting.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 41693f0af6..9e8c0a3e58 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -13,11 +13,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later """ -Loading this module registers an excepthook that formats py:class:: DSLError. +Loading this module registers an excepthook that formats :class:`DSLError`. -The excepthook is necessary because the default hook prints DSLErrors in an -inconvenient way. The previously set excepthook is used to print all other -errors. +The excepthook is necessary because the default hook prints :class:`DSLError`s +in an inconvenient way. The previously set excepthook is used to print all +other errors. """ import os diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 830a933981..74230263db 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -16,7 +16,7 @@ The list of exception classes used in the library. Exception classes that represent errors within an IR go here as a subclass of -py:class:: DSLError. Exception classes that represent other errors, like +:class:`DSLError`. Exception classes that represent other errors, like the builtin ValueError, go here as well, although you should use Python's builtin error classes if you can. Exception classes that are specific to a certain submodule and have no use for the entire application may be better off diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 1126861cc9..1c20a12507 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -"""Utility functions for formatting py:class:: DSLError and its subclasses.""" +"""Utility functions for formatting :class:`DSLError` and its subclasses.""" import linecache import textwrap From 2cac27eaca4bebd7a9191489ea1585d32ffb0dc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 19 Jul 2023 12:33:12 +0200 Subject: [PATCH 48/54] import module for type checking --- src/gt4py/next/errors/formatting.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py index 1c20a12507..0176607971 100644 --- a/src/gt4py/next/errors/formatting.py +++ b/src/gt4py/next/errors/formatting.py @@ -14,15 +14,21 @@ """Utility functions for formatting :class:`DSLError` and its subclasses.""" +from __future__ import annotations + import linecache import textwrap import traceback import types -from typing import Optional +from typing import TYPE_CHECKING, Optional from gt4py.eve import SourceLocation +if TYPE_CHECKING: + from . import exceptions + + def get_source_from_location(location: SourceLocation) -> str: if not location.filename: raise FileNotFoundError() @@ -80,7 +86,7 @@ def _format_traceback(tb: types.TracebackType) -> list[str]: def format_compilation_error( - type_: type[Exception], + type_: type[exceptions.DSLError], message: str, location: Optional[SourceLocation], tb: Optional[types.TracebackType] = None, From 98b0badd2ce443ab7cc9eab3c3bf7474b9643b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 19 Jul 2023 15:19:57 +0200 Subject: [PATCH 49/54] remove direct imports --- src/gt4py/next/ffront/dialect_parser.py | 6 +- .../foast_passes/closure_var_folding.py | 6 +- .../ffront/foast_passes/type_deduction.py | 84 ++++++++++--------- src/gt4py/next/ffront/func_to_foast.py | 37 ++++---- src/gt4py/next/ffront/func_to_past.py | 16 ++-- .../next/ffront/past_passes/type_deduction.py | 12 +-- .../ffront_tests/test_execution.py | 8 +- .../ffront_tests/test_scalar_if.py | 9 +- .../ffront_tests/test_type_deduction.py | 28 +++---- .../feature_tests/test_util_cases.py | 4 +- .../cpp_backend_tests/test_driver.py | 3 +- .../errors_tests/test_excepthook.py | 10 +-- .../errors_tests/test_exceptions.py | 12 +-- .../errors_tests/test_formatting.py | 4 +- .../ffront_tests/test_func_to_foast.py | 35 ++++---- .../test_func_to_foast_error_line_number.py | 8 +- .../ffront_tests/test_func_to_past.py | 17 ++-- .../ffront_tests/test_past_to_itir.py | 4 +- 18 files changed, 148 insertions(+), 155 deletions(-) diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index c69f0a8f9e..c04e978e51 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -20,7 +20,7 @@ from gt4py.eve.concepts import SourceLocation from gt4py.eve.extended_typing import Any, Generic, TypeVar -from gt4py.next.errors import DSLError, UnsupportedPythonFeatureError +from gt4py.next import errors from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -46,7 +46,7 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: if err.end_offset is not None else None, ) - raise DSLError(loc, err.msg).with_traceback(err.__traceback__) + raise errors.DSLError(loc, err.msg).with_traceback(err.__traceback__) @dataclass(frozen=True, kw_only=True) @@ -99,7 +99,7 @@ def _postprocess_dialect_ast( def generic_visit(self, node: ast.AST) -> None: loc = self.get_location(node) feature = f"{type(node).__module__}.{type(node).__qualname__}" - raise UnsupportedPythonFeatureError(loc, feature) + raise errors.UnsupportedPythonFeatureError(loc, feature) def get_location(self, node: ast.AST) -> SourceLocation: file = self.source_definition.filename diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index f30b0c856a..9afd22de2c 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -18,7 +18,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import FrozenNamespace -from gt4py.next.errors import DSLError, MissingAttributeError +from gt4py.next import errors @dataclass @@ -55,8 +55,8 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if isinstance(value, foast.Constant): if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) - raise MissingAttributeError(node.location, node.attr) - raise DSLError(node.location, "attribute access only applicable to constants") + raise errors.MissingAttributeError(node.location, node.attr) + raise errors.DSLError(node.location, "attribute access only applicable to constants") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 6b66b07e8a..2ebf690580 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -16,8 +16,8 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits +from gt4py.next import errors from gt4py.next.common import DimensionKind -from gt4py.next.errors import DSLError from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, @@ -145,7 +145,7 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise DSLError( + raise errors.DSLError( stmt.location, f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", @@ -162,7 +162,7 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise DSLError( + raise errors.DSLError( stmt.location, f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", @@ -248,7 +248,7 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise DSLError( + raise errors.DSLError( node.location, f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", ) @@ -280,18 +280,18 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise DSLError( + raise errors.DSLError( node.location, f"Argument `axis` to scan operator `{node.id}` must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise DSLError( + raise errors.DSLError( node.location, f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise DSLError( + raise errors.DSLError( node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." ) new_init = self.visit(node.init, **kwargs) @@ -299,7 +299,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise DSLError( + raise errors.DSLError( node.location, f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", @@ -322,7 +322,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise DSLError(node.location, f"Undeclared symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared symbol `{node.id}`.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -346,7 +346,7 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise DSLError( + raise errors.DSLError( node.location, f"Too many values to unpack (expected {len(indices)})." ) @@ -378,7 +378,7 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise DSLError( + raise errors.DSLError( node.location, f"Assignment value must be of type tuple! Got: {values.type}" ) @@ -397,14 +397,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise DSLError( + raise errors.DSLError( node.location, "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise DSLError( + raise errors.DSLError( node.location, "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", @@ -414,7 +414,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise DSLError( + raise errors.DSLError( node.location, f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", @@ -435,7 +435,7 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise DSLError( + raise errors.DSLError( node.location, ( "type inconsistency: expression was deduced to be " @@ -457,7 +457,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise DSLError( + raise errors.DSLError( new_value.location, "Second dimension in offset must be a local dimension." ) new_type = ts.OffsetType(source=source, target=(target1,)) @@ -466,13 +466,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise DSLError( + raise errors.DSLError( new_value.location, "Source and target must be equal for offsets with a single target.", ) new_type = new_value.type case _: - raise DSLError(new_value.location, "Could not deduce type of subscript expression!") + raise errors.DSLError( + new_value.location, "Could not deduce type of subscript expression!" + ) return foast.Subscript( value=new_value, index=node.index, type=new_type, location=node.location @@ -510,13 +512,13 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise DSLError( + raise errors.DSLError( condition.location, f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", ) if true_expr.type != false_expr.type: - raise DSLError( + raise errors.DSLError( node.location, f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", ) @@ -536,7 +538,7 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise DSLError( + raise errors.DSLError( arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" ) @@ -547,7 +549,7 @@ def _deduce_compare_type( # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) except ValueError as ex: - raise DSLError( + raise errors.DSLError( node.location, f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", @@ -571,7 +573,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise DSLError( + raise errors.DSLError( arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) @@ -584,7 +586,7 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise DSLError( + raise errors.DSLError( arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) @@ -592,7 +594,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise DSLError( + raise errors.DSLError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -603,7 +605,7 @@ def _check_operand_dtypes_match( ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", ) @@ -620,7 +622,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", ) @@ -652,11 +654,11 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise DSLError(node.location, "Functions can only be called directly!") + raise errors.DSLError(node.location, "Functions can only be called directly!") elif isinstance(new_func.type, ts.FieldType): pass else: - raise DSLError( + raise errors.DSLError( node.location, f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", ) @@ -670,7 +672,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: raise_exception=True, ) except ValueError as err: - raise DSLError( + raise errors.DSLError( node.location, f"Invalid argument types in call to `{new_func}`!" ) from err @@ -728,7 +730,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise DSLError( + raise errors.DSLError( node.location, "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), ) @@ -751,7 +753,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) except ValueError as ex: - raise DSLError(node.location, error_msg_preamble) from ex + raise errors.DSLError(node.location, error_msg_preamble) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -770,7 +772,7 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " @@ -826,7 +828,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" @@ -834,7 +836,7 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: ) if arg_0.source not in arg_1.dims: - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " @@ -855,7 +857,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype `bool`, but got `{mask_type}`.", @@ -873,7 +875,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise DSLError( + raise errors.DSLError( node.location, f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", @@ -885,7 +887,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: return_type = promote_to_mask_type(mask_type, promoted_type) except ValueError as ex: - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible argument in call to `{str(node.func)}`.", ) from ex @@ -903,7 +905,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", @@ -912,7 +914,7 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise DSLError( + raise errors.DSLError( node.location, f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", @@ -935,5 +937,5 @@ def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError as e: - raise DSLError(node.location, "Could not deduce type of constant.") from e + raise errors.DSLError(node.location, "Could not deduce type of constant.") from e return foast.Constant(value=node.value, location=node.location, type=type_) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 9fe9bd09f8..082939c938 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -19,12 +19,7 @@ from typing import Any, Callable, Iterable, Mapping, Type, cast import gt4py.eve as eve -from gt4py.next.errors import ( - DSLError, - InvalidParameterAnnotationError, - MissingParameterAnnotationError, - UnsupportedPythonFeatureError, -) +from gt4py.next import errors from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.ast_passes import ( SingleAssignTargetPass, @@ -76,7 +71,7 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) - ... except DSLError as err: + ... except errors.DSLError as err: ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ @@ -108,7 +103,7 @@ def _postprocess_dialect_ast( # TODO(tehrengruber): use `type_info.return_type` when the type of the # arguments becomes available here if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - raise DSLError( + raise errors.DSLError( foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented f", but got `{annotated_return_type}`.", @@ -170,7 +165,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise DSLError(loc, "Function is expected to return a value.") + raise errors.DSLError(loc, "Function is expected to return a value.") return foast.FunctionDefinition( id=node.name, @@ -186,10 +181,10 @@ def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: def visit_arg(self, node: ast.arg) -> foast.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise MissingParameterAnnotationError(loc, node.arg) + raise errors.MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise InvalidParameterAnnotationError(loc, node.arg, new_type) + raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return foast.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.TupleTargetAssign: @@ -227,7 +222,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "can only assign to names") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -249,7 +244,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "can only assign to names") if node.annotation is not None: assert isinstance( @@ -290,7 +285,9 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise DSLError(self.get_location(node.slice), "expected an integral index") from None + raise errors.DSLError( + self.get_location(node.slice), "expected an integral index" + ) from None return foast.Subscript( value=self.visit(node.value), @@ -311,7 +308,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise DSLError(loc, "must return a value, not None") + raise errors.DSLError(loc, "must return a value, not None") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -378,7 +375,7 @@ def visit_BitXor(self, node: ast.BitXor, **kwargs) -> dialect_ast_enums.BinaryOp return dialect_ast_enums.BinaryOperator.BIT_XOR def visit_BoolOp(self, node: ast.BoolOp, **kwargs) -> None: - raise UnsupportedPythonFeatureError( + raise errors.UnsupportedPythonFeatureError( self.get_location(node), "logical operators `and`, `or`" ) @@ -413,7 +410,7 @@ def visit_Compare(self, node: ast.Compare, **kwargs) -> foast.Compare: if len(node.ops) != 1 or len(node.comparators) != 1: # Remove comparison chains in a preprocessing pass # TODO: maybe add a note to the error about preprocessing passes? - raise UnsupportedPythonFeatureError(loc, "comparison chains") + raise errors.UnsupportedPythonFeatureError(loc, "comparison chains") return foast.Compare( op=self.visit(node.ops[0]), left=self.visit(node.left), @@ -441,7 +438,7 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): - raise DSLError( + raise errors.DSLError( self.get_location(node), f"{self._func_name(node)}() only takes literal arguments!", ) @@ -468,7 +465,9 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) except ValueError: - raise DSLError(loc, f"constants of type {type(node.value)} are not permitted") from None + raise errors.DSLError( + loc, f"constants of type {type(node.value)} are not permitted" + ) from None return foast.Constant( value=node.value, diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 7911283646..7b04e90902 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -18,11 +18,7 @@ from dataclasses import dataclass from typing import Any, cast -from gt4py.next.errors import ( - DSLError, - InvalidParameterAnnotationError, - MissingParameterAnnotationError, -) +from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -71,10 +67,10 @@ def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: def visit_arg(self, node: ast.arg) -> past.DataSymbol: loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise MissingParameterAnnotationError(loc, node.arg) + raise errors.MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise InvalidParameterAnnotationError(loc, node.arg, new_type) + raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) return past.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Expr(self, node: ast.Expr) -> past.LocatedNode: @@ -132,7 +128,9 @@ def visit_Call(self, node: ast.Call) -> past.Call: loc = self.get_location(node) new_func = self.visit(node.func) if not isinstance(new_func, past.Name): - raise DSLError(loc, "functions must be referenced by their name in function calls") + raise errors.DSLError( + loc, "functions must be referenced by their name in function calls" + ) return past.Call( func=new_func, @@ -168,7 +166,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) - raise DSLError(loc, "unary operators are only applicable to literals") + raise errors.DSLError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index c00d8710f3..ed3bdae3ff 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next.errors import DSLError +from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -148,7 +148,7 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): - raise DSLError( + raise errors.DSLError( arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) @@ -161,7 +161,7 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise DSLError( + raise errors.DSLError( arg.location, f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) @@ -169,7 +169,7 @@ def _deduce_binop_type( try: return type_info.promote(left_type, right_type) except ValueError as ex: - raise DSLError( + raise errors.DSLError( node.location, f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", @@ -231,7 +231,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex + raise errors.DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex return past.Call( func=new_func, @@ -244,6 +244,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index cf18055a2b..93eef541da 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -22,6 +22,7 @@ from gt4py.next import ( astype, broadcast, + errors, float32, float64, int32, @@ -31,7 +32,6 @@ neighbor_sum, where, ) -from gt4py.next.errors import DSLError from gt4py.next.ffront.experimental import as_offset from gt4py.next.program_processors.runners import gtfn_cpu @@ -823,7 +823,7 @@ def fieldop_where_k_offset( def test_undefined_symbols(cartesian_case): - with pytest.raises(DSLError, match="Undeclared symbol"): + with pytest.raises(errors.DSLError, match="Undeclared symbol"): @gtx.field_operator(backend=cartesian_case.backend) def return_undefined(): @@ -918,7 +918,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( - DSLError, + errors.DSLError, match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), ): @@ -929,7 +929,7 @@ def _star_unpack() -> tuple[int32, float64, int32]: def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises(DSLError, match=(r"Assignment value must be of type tuple!")): + with pytest.raises(errors.DSLError, match=(r"Assignment value must be of type tuple!")): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index f5390f914b..b16790bbde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -18,8 +18,7 @@ import numpy as np import pytest -from gt4py.next import Field, field_operator, float64, index_field, np_as_located_field -from gt4py.next.errors import DSLError +from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field from gt4py.next.program_processors.runners import gtfn_cpu from next_tests.integration_tests.feature_tests import cases @@ -357,7 +356,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(DSLError, match="Condition for `if` must be scalar."): + with pytest.raises(errors.DSLError, match="Condition for `if` must be scalar."): @field_operator def if_non_scalar_condition( @@ -370,7 +369,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises(DSLError, match="Condition for `if` must be of boolean type."): + with pytest.raises(errors.DSLError, match="Condition for `if` must be of boolean type."): @field_operator def if_non_boolean_condition( @@ -385,7 +384,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - DSLError, + errors.DSLError, match="Inconsistent types between two branches for variable", ): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index bacf91c275..e62c14b1bb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -24,6 +24,7 @@ FieldOffset, astype, broadcast, + errors, float32, float64, int32, @@ -31,7 +32,6 @@ neighbor_sum, where, ) -from gt4py.next.errors import DSLError from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -505,7 +505,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - DSLError, + errors.DSLError, match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -520,7 +520,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b with pytest.raises( - DSLError, + errors.DSLError, match=( r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." ), @@ -533,7 +533,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( - DSLError, + errors.DSLError, match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -544,7 +544,7 @@ def sign_bool(a: Field[[TDim], bool]): return -a with pytest.raises( - DSLError, + errors.DSLError, match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -555,7 +555,7 @@ def not_int(a: Field[[TDim], int64]): return not a with pytest.raises( - DSLError, + errors.DSLError, match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -627,7 +627,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: return float32("1.0") + float64("1.0") with pytest.raises( - DSLError, + errors.DSLError, match=(r"Could not promote `float32` and `float64` to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -657,7 +657,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - DSLError, + errors.DSLError, match=r"Expected broadcast dimension is missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -672,7 +672,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - DSLError, + errors.DSLError, match=r"Expected all broadcast dimensions to be of type Dimension.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -738,7 +738,7 @@ def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) with pytest.raises( - DSLError, + errors.DSLError, match=r"Return arguments need to be of same type", ): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -793,7 +793,7 @@ def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 with pytest.raises( - DSLError, + errors.DSLError, match=r"Type float64 can not be used in operator `%`", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -803,7 +803,7 @@ def test_undefined_symbols(): def return_undefined(): return undefined_symbol - with pytest.raises(DSLError, match="Undeclared symbol"): + with pytest.raises(errors.DSLError, match="Undeclared symbol"): _ = FieldOperatorParser.apply_to_function(return_undefined) @@ -816,7 +816,7 @@ def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) with pytest.raises( - DSLError, + errors.DSLError, match=f"not in list of offset field dimensions", ): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -831,7 +831,7 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) with pytest.raises( - DSLError, + errors.DSLError, match=f"Excepted integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index bde6c7c247..32132dbc98 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.errors import DSLError +from gt4py.next import errors from gt4py.next.program_processors.runners import roundtrip from next_tests.integration_tests.feature_tests import cases @@ -87,7 +87,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures b = cases.allocate(cartesian_case, addition, "b")() out = cases.allocate(cartesian_case, addition, cases.RETURN)() - with pytest.raises(DSLError): + with pytest.raises(errors.DSLError): cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py index 14df97c523..7de56cb5bb 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py @@ -14,6 +14,7 @@ import os import subprocess +import sys from pathlib import Path import pytest @@ -31,7 +32,7 @@ def _execute_cmake(backend_str: str): build_dir = _build_dir(backend_str) build_dir.mkdir(exist_ok=True) cmake = ["cmake", "-B", build_dir, f"-DBACKEND={backend_str}"] - subprocess.run(cmake, cwd=_source_dir(), check=True) + subprocess.run(cmake, cwd=_source_dir(), check=True, stderr=sys.stderr) def _get_available_cpu_count(): diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py index f7db8e7a0d..d9f29b99d5 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -13,21 +13,21 @@ # SPDX-License-Identifier: GPL-3.0-or-later from gt4py import eve -from gt4py.next.errors import excepthook, exceptions +from gt4py.next import errors def test_format_uncaught_error(): try: loc = eve.SourceLocation("/src/file.py", 1, 1) msg = "compile error msg" - raise exceptions.DSLError(loc, msg) from ValueError("value error msg") - except exceptions.DSLError as err: - str_devmode = "".join(excepthook._format_uncaught_error(err, True)) + raise errors.exceptions.DSLError(loc, msg) from ValueError("value error msg") + except errors.exceptions.DSLError as err: + str_devmode = "".join(errors.excepthook._format_uncaught_error(err, True)) assert str_devmode.find("Source location") >= 0 assert str_devmode.find("Traceback") >= 0 assert str_devmode.find("cause") >= 0 assert str_devmode.find("ValueError") >= 0 - str_usermode = "".join(excepthook._format_uncaught_error(err, False)) + str_usermode = "".join(errors.excepthook._format_uncaught_error(err, False)) assert str_usermode.find("Source location") >= 0 assert str_usermode.find("Traceback") < 0 assert str_usermode.find("cause") < 0 diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py index 90d1ebb8b0..8111597a7c 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -18,7 +18,7 @@ import pytest from gt4py.eve import SourceLocation -from gt4py.next.errors import DSLError +from gt4py.next import errors @pytest.fixture @@ -41,20 +41,20 @@ def message(): def test_message(loc_plain, message): - assert DSLError(loc_plain, message).message == message + assert errors.DSLError(loc_plain, message).message == message def test_location(loc_plain, message): - assert DSLError(loc_plain, message).location == loc_plain + assert errors.DSLError(loc_plain, message).location == loc_plain def test_with_location(loc_plain, message): - assert DSLError(None, message).with_location(loc_plain).location == loc_plain + assert errors.DSLError(None, message).with_location(loc_plain).location == loc_plain def test_str(loc_plain, message): pattern = f'{message}\\n File ".*", line.*' - s = str(DSLError(loc_plain, message)) + s = str(errors.DSLError(loc_plain, message)) assert re.match(pattern, s) @@ -65,5 +65,5 @@ def test_str_snippet(loc_snippet, message): " # This very line of comment should be shown in the snippet.\\n" " \^\^\^\^\^\^\^\^\^\^\^\^\^\^" ) - s = str(DSLError(loc_snippet, message)) + s = str(errors.DSLError(loc_snippet, message)) assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index 78a206eda3..5074166bc2 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -17,7 +17,7 @@ import pytest from gt4py.eve import SourceLocation -from gt4py.next.errors import DSLError +from gt4py.next import errors from gt4py.next.errors.formatting import format_compilation_error @@ -41,7 +41,7 @@ def tb(): @pytest.fixture def type_(): - return DSLError + return errors.DSLError @pytest.fixture diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 21612ea8db..e5bbed19fd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -41,12 +41,7 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import astype, broadcast, float32, float64, int32, int64, where -from gt4py.next.errors import ( - DSLError, - MissingParameterAnnotationError, - UnsupportedPythonFeatureError, -) +from gt4py.next import astype, broadcast, errors, float32, float64, int32, int64, where from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.func_to_foast import FieldOperatorParser @@ -81,7 +76,7 @@ def test_untyped_arg(): def untyped(inp): return inp - with pytest.raises(MissingParameterAnnotationError): + with pytest.raises(errors.MissingParameterAnnotationError): _ = FieldOperatorParser.apply_to_function(untyped) @@ -119,7 +114,7 @@ def no_return(inp: gtx.Field[[TDim], "float64"]): tmp = inp # noqa with pytest.raises( - DSLError, + errors.DSLError, match=".*return.*", ): _ = FieldOperatorParser.apply_to_function(no_return) @@ -135,7 +130,7 @@ def invalid_assign_to_expr( tmp[-1] = inp2 return tmp - with pytest.raises(DSLError, match=r".*assign.*"): + with pytest.raises(errors.DSLError, match=r".*assign.*"): _ = FieldOperatorParser.apply_to_function(invalid_assign_to_expr) @@ -161,7 +156,7 @@ def clashing(inp: gtx.Field[[TDim], "float64"]): tmp: gtx.Field[[TDim], "int64"] = inp return tmp - with pytest.raises(DSLError, match="type inconsistency"): + with pytest.raises(errors.DSLError, match="type inconsistency"): _ = FieldOperatorParser.apply_to_function(clashing) @@ -194,7 +189,7 @@ def bool_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a and b with pytest.raises( - UnsupportedPythonFeatureError, + errors.UnsupportedPythonFeatureError, match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_and) @@ -205,7 +200,7 @@ def bool_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a or b with pytest.raises( - UnsupportedPythonFeatureError, + errors.UnsupportedPythonFeatureError, match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_or) @@ -240,7 +235,7 @@ def cast_scalar_temp(): tmp = int64(1) return int32(tmp) - with pytest.raises(DSLError, match=r".*literal.*"): + with pytest.raises(errors.DSLError, match=r".*literal.*"): _ = FieldOperatorParser.apply_to_function(cast_scalar_temp) @@ -251,7 +246,7 @@ def conditional_wrong_mask_type( return where(a, a, a) msg = r"Expected a field with dtype `bool`." - with pytest.raises(DSLError, match=msg): + with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -264,7 +259,7 @@ def conditional_wrong_arg_type( return where(mask, a, b) msg = r"Could not promote scalars of different dtype \(not implemented\)." - with pytest.raises(DSLError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(conditional_wrong_arg_type) assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -274,7 +269,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: gtx.Field[[], bool]): return 1 if cond else 2 - with pytest.raises(DSLError, match=r"should be .* `bool`"): + with pytest.raises(errors.DSLError, match=r"should be .* `bool`"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -293,7 +288,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> gtx.Field[[], float]: return 1.0 - with pytest.raises(DSLError, match=r"Expected `float.*`"): + with pytest.raises(errors.DSLError, match=r"Expected `float.*`"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -375,7 +370,7 @@ def wrong_return_type_annotation(a: gtx.Field[[ADim], float64]) -> gtx.Field[[BD return a with pytest.raises( - DSLError, + errors.DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -386,7 +381,7 @@ def empty_dims() -> gtx.Field[[], float]: return 1.0 with pytest.raises( - DSLError, + errors.DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(empty_dims) @@ -401,7 +396,7 @@ def zero_dims_ternary( return a if cond == 1 else b msg = r"Incompatible datatypes in operator `==`" - with pytest.raises(DSLError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index d83bf298f1..123d57baf1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -18,7 +18,7 @@ import pytest import gt4py.next as gtx -from gt4py.next.errors import DSLError +from gt4py.next import errors from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction @@ -37,7 +37,7 @@ def wrong_syntax(inp: gtx.Field[[TDim], float]): return # <-- this line triggers the syntax error with pytest.raises( - f2f.DSLError, + f2f.errors.DSLError, match=(r".*return.*"), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) @@ -63,7 +63,7 @@ def invalid_python_syntax(): """, ) - with pytest.raises(DSLError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply(source_definition, {}, {}) assert exc_info.value.location @@ -82,7 +82,7 @@ def test_fo_type_deduction_error(): def field_operator_with_undeclared_symbol(): return undeclared_symbol # noqa: F821 # undefined on purpose - with pytest.raises(DSLError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(field_operator_with_undeclared_symbol) exc = exc_info.value diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index d92b5abbb7..6e617f77a2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -19,8 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import float64 -from gt4py.next.errors import DSLError +from gt4py.next import errors, float64 from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.type_system import type_specifications as ts @@ -113,7 +112,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): identity(in_field, out=out_field) # noqa: F821 # undefined on purpose with pytest.raises( - DSLError, + errors.DSLError, match=(r"Undeclared or untyped symbol `out_field`."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -162,7 +161,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) @@ -181,7 +180,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) @@ -200,7 +199,7 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) @@ -221,7 +220,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): ) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) @@ -240,7 +239,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) @@ -259,7 +258,7 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index 083a9796b9..e56dc85322 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.errors import DSLError +from gt4py.next import errors from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir @@ -169,7 +169,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): with pytest.raises( - DSLError, + errors.DSLError, ) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), From de7f9e9feb5146c35f7d7616fdb2f5fc56677693 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 19 Jul 2023 15:42:12 +0200 Subject: [PATCH 50/54] split string into multiple lines proper --- .../errors_tests/test_formatting.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index 5074166bc2..b14a62b8bd 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -50,25 +50,36 @@ def qualname(type_): def test_format(type_, qualname, message): - pattern = f"{qualname}: {message}" + cls_pattern = f"{qualname}: {message}" s = "\n".join(format_compilation_error(type_, message, None, None, None)) - assert re.match(pattern, s) + assert re.match(cls_pattern, s) def test_format_loc(type_, qualname, message, location): - pattern = "Source location.*\\n" ' File "/source.*".*\\n' f"{qualname}: {message}" + loc_pattern = "Source location.*" + file_pattern = ' File "/source.*".*' + cls_pattern = f"{qualname}: {message}" + pattern = "\\n".join([loc_pattern, file_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, location, None, None)) assert re.match(pattern, s) def test_format_traceback(type_, qualname, message, tb): - pattern = "Traceback.*\\n" ' File ".*".*\\n' ".*\\n" f"{qualname}: {message}" + tb_pattern = "Traceback.*" + file_pattern = ' File ".*".*' + line_pattern = ".*" + cls_pattern = f"{qualname}: {message}" + pattern = "\\n".join([tb_pattern, file_pattern, line_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, None, tb, None)) assert re.match(pattern, s) def test_format_cause(type_, qualname, message): cause = ValueError("asd") - pattern = "ValueError: asd\\n\\n" "The above.*\\n\\n" f"{qualname}: {message}" + blank_pattern = "" + cause_pattern = "ValueError: asd" + bridge_pattern = "The above.*" + cls_pattern = f"{qualname}: {message}" + pattern = "\\n".join([cause_pattern, blank_pattern, bridge_pattern, blank_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, None, None, cause)) assert re.match(pattern, s) From 932789332c7e6ea64193c7b5d07ad5ae174a1287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 19 Jul 2023 15:57:20 +0200 Subject: [PATCH 51/54] fix concatenated strings --- tests/next_tests/unit_tests/errors_tests/test_formatting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py index b14a62b8bd..ebb6cf9a37 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_formatting.py +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -59,7 +59,7 @@ def test_format_loc(type_, qualname, message, location): loc_pattern = "Source location.*" file_pattern = ' File "/source.*".*' cls_pattern = f"{qualname}: {message}" - pattern = "\\n".join([loc_pattern, file_pattern, cls_pattern]) + pattern = r"\n".join([loc_pattern, file_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, location, None, None)) assert re.match(pattern, s) @@ -69,7 +69,7 @@ def test_format_traceback(type_, qualname, message, tb): file_pattern = ' File ".*".*' line_pattern = ".*" cls_pattern = f"{qualname}: {message}" - pattern = "\\n".join([tb_pattern, file_pattern, line_pattern, cls_pattern]) + pattern = r"\n".join([tb_pattern, file_pattern, line_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, None, tb, None)) assert re.match(pattern, s) @@ -80,6 +80,6 @@ def test_format_cause(type_, qualname, message): cause_pattern = "ValueError: asd" bridge_pattern = "The above.*" cls_pattern = f"{qualname}: {message}" - pattern = "\\n".join([cause_pattern, blank_pattern, bridge_pattern, blank_pattern, cls_pattern]) + pattern = r"\n".join([cause_pattern, blank_pattern, bridge_pattern, blank_pattern, cls_pattern]) s = "".join(format_compilation_error(type_, message, None, None, cause)) assert re.match(pattern, s) From 1bf640cef77b10258871ed17b4f0ab032c9d5fe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 20 Jul 2023 13:08:46 +0200 Subject: [PATCH 52/54] fix escape sequence, reformat --- .../unit_tests/errors_tests/test_exceptions.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py index 8111597a7c..60a382d989 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -59,11 +59,13 @@ def test_str(loc_plain, message): def test_str_snippet(loc_snippet, message): - pattern = ( - f"{message}\\n" - ' File ".*", line.*\\n' - " # This very line of comment should be shown in the snippet.\\n" - " \^\^\^\^\^\^\^\^\^\^\^\^\^\^" + pattern = r"\n".join( + [ + f"{message}", + ' File ".*", line.*', + " # This very line of comment should be shown in the snippet.", + r" \^\^\^\^\^\^\^\^\^\^\^\^\^\^", + ] ) s = str(errors.DSLError(loc_snippet, message)) assert re.match(pattern, s) From eafdde2635e4ebf6fbfcd88b2326c7dbaf6fddd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 20 Jul 2023 13:09:07 +0200 Subject: [PATCH 53/54] fix verbose exceptions env var, tests for it --- src/gt4py/next/errors/excepthook.py | 14 ++++++++--- .../errors_tests/test_excepthook.py | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index 9e8c0a3e58..b3fc271d62 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -22,6 +22,7 @@ import os import sys +import warnings from typing import Callable from . import exceptions, formatting @@ -31,10 +32,17 @@ def _get_verbose_exceptions_envvar() -> bool: """Detect if the user enabled verbose exceptions in the environment variables.""" env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" if env_var_name in os.environ: - try: - return bool(os.environ[env_var_name]) - except TypeError: + false_values = ["0", "false", "off"] + true_values = ["1", "true", "on"] + value = os.environ[env_var_name].lower() + if value in false_values: return False + elif value in true_values: + return True + else: + values = ", ".join([*false_values, *true_values]) + msg = f"the 'GT4PY_VERBOSE_EXCEPTIONS' environment variable must be one of {values} (case insensitive)" + warnings.warn(msg) return False diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py index d9f29b99d5..526844d730 100644 --- a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -12,8 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import os + from gt4py import eve from gt4py.next import errors +from gt4py.next.errors import excepthook def test_format_uncaught_error(): @@ -32,3 +35,23 @@ def test_format_uncaught_error(): assert str_usermode.find("Traceback") < 0 assert str_usermode.find("cause") < 0 assert str_usermode.find("ValueError") < 0 + + +def test_get_verbose_exceptions(): + env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" + + # Make sure to save and restore the environment variable, we don't want to + # affect other tests running in the same process. + saved = os.environ.get(env_var_name, None) + try: + os.environ[env_var_name] = "False" + assert excepthook._get_verbose_exceptions_envvar() is False + os.environ[env_var_name] = "True" + assert excepthook._get_verbose_exceptions_envvar() is True + os.environ[env_var_name] = "invalid value" # Should emit a warning too + assert excepthook._get_verbose_exceptions_envvar() is False + del os.environ[env_var_name] + assert excepthook._get_verbose_exceptions_envvar() is False + finally: + if saved is not None: + os.environ[env_var_name] = saved From 763876390e09f5650e57147090e2aacfbcbfdc6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Thu, 20 Jul 2023 13:55:27 +0200 Subject: [PATCH 54/54] improve docstrings --- src/gt4py/next/errors/excepthook.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py index b3fc271d62..f1dd18e1b4 100644 --- a/src/gt4py/next/errors/excepthook.py +++ b/src/gt4py/next/errors/excepthook.py @@ -50,7 +50,7 @@ def _get_verbose_exceptions_envvar() -> bool: def set_verbose_exceptions(enabled: bool = False) -> None: - """With verbose exceptions, the stack trace and cause of the error is also printed.""" + """Programmatically set whether to use verbose printing for uncaught errors.""" global _verbose_exceptions _verbose_exceptions = enabled @@ -72,7 +72,9 @@ def compilation_error_hook(fallback: Callable, type_: type, value: BaseException """ Format `CompilationError`s in a neat way. - All other Python exceptions are formatted by the `fallback` hook. + All other Python exceptions are formatted by the `fallback` hook. When + verbose exceptions are enabled, the stack trace and cause of the error is + also printed. """ if isinstance(value, exceptions.DSLError): exc_strs = _format_uncaught_error(value, _verbose_exceptions)