diff --git a/src/gt4py/cartesian/frontend/node_util.py b/src/gt4py/cartesian/frontend/node_util.py index 5da908de5c..9595d5f76a 100644 --- a/src/gt4py/cartesian/frontend/node_util.py +++ b/src/gt4py/cartesian/frontend/node_util.py @@ -147,4 +147,4 @@ def recurse(node: Node) -> Generator[Node, None, None]: def location_to_source_location(loc: Optional[Location]) -> Optional[eve.SourceLocation]: if loc is None or loc.line <= 0 or loc.column <= 0: return None - return eve.SourceLocation(line=loc.line, column=loc.column, source=loc.scope) + return eve.SourceLocation(line=loc.line, column=loc.column, filename=loc.scope) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 613d161fc0..dc10c53f21 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -35,7 +35,7 @@ def get_dace_debuginfo(node: common.LocNode): node.loc.column, node.loc.line, node.loc.column, - node.loc.source, + node.loc.filename, ) else: return dace.dtypes.DebugInfo(0) diff --git a/src/gt4py/eve/concepts.py b/src/gt4py/eve/concepts.py index d66e8583fd..67991f6db0 100644 --- a/src/gt4py/eve/concepts.py +++ b/src/gt4py/eve/concepts.py @@ -17,7 +17,6 @@ from __future__ import annotations -import ast import copy import re @@ -58,58 +57,42 @@ class SymbolRef(ConstrainedStr, regex=_SYMBOL_NAME_RE): @datamodels.datamodel(slots=True, frozen=True) class SourceLocation: - """Source code location (line, column, source).""" + """File-line-column information for source code.""" + filename: Optional[str] line: int = datamodels.field(validator=_validators.ge(1)) column: int = datamodels.field(validator=_validators.ge(1)) - source: str end_line: Optional[int] = datamodels.field(validator=_validators.optional(_validators.ge(1))) end_column: Optional[int] = datamodels.field(validator=_validators.optional(_validators.ge(1))) - @classmethod - def from_AST(cls, ast_node: ast.AST, source: Optional[str] = None) -> SourceLocation: - if ( - not isinstance(ast_node, ast.AST) - or getattr(ast_node, "lineno", None) is None - or getattr(ast_node, "col_offset", None) is None - ): - raise ValueError( - f"Passed AST node '{ast_node}' does not contain a valid source location." - ) - if source is None: - source = f"" - return cls( - ast_node.lineno, - ast_node.col_offset + 1, - source, - end_line=ast_node.end_lineno, - end_column=ast_node.end_col_offset + 1 if ast_node.end_col_offset is not None else None, - ) - def __init__( self, + filename: Optional[str], line: int, column: int, - source: str, - *, end_line: Optional[int] = None, end_column: Optional[int] = None, ) -> None: assert end_column is None or end_line is not None self.__auto_init__( # type: ignore[attr-defined] # __auto_init__ added dynamically - line=line, column=column, source=source, end_line=end_line, end_column=end_column + filename=filename, line=line, column=column, end_line=end_line, end_column=end_column ) def __str__(self) -> str: - src = self.source or "" + filename_str = self.filename or "-" + + end_line_str = self.end_line if self.end_line is not None else "-" + end_column_str = self.end_column if self.end_column is not None else "-" - end_part = "" - if self.end_line is not None: - end_part += f" to {self.end_line}" + end_str: Optional[str] = None if self.end_column is not None: - end_part += f":{self.end_column}" + end_str = f"{end_line_str}:{end_column_str}" + elif self.end_line is not None: + end_str = f"{end_line_str}" - return f"<{src}:{self.line}:{self.column}{end_part}>" + if end_str is not None: + return f"<{filename_str}:{self.line}:{self.column} to {end_str}>" + return f"<{filename_str}:{self.line}:{self.column}>" @datamodels.datamodel(slots=True, frozen=True) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 5ece3a23ec..b0f0b8ac11 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -108,41 +108,3 @@ class NeighborTable(Connectivity, Protocol): class GridType(StrEnum): CARTESIAN = "cartesian" UNSTRUCTURED = "unstructured" - - -class GTError: - """Base class for GridTools exceptions. - - Notes: - This base class has to be always inherited together with a standard - exception, and thus it should not be used as direct superclass - for custom exceptions. Inherit directly from :class:`GTTypeError`, - :class:`GTTypeError`, ... - - """ - - ... - - -class GTRuntimeError(GTError, RuntimeError): - """Base class for GridTools run-time errors.""" - - ... - - -class GTSyntaxError(GTError, SyntaxError): - """Base class for GridTools syntax errors.""" - - ... - - -class GTTypeError(GTError, TypeError): - """Base class for GridTools type errors.""" - - ... - - -class GTValueError(GTError, ValueError): - """Base class for GridTools value errors.""" - - ... diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py new file mode 100644 index 0000000000..61441e83b9 --- /dev/null +++ b/src/gt4py/next/errors/__init__.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Contains the exception classes and other utilities for error handling.""" + +from . import ( # noqa: module needs to be loaded for pretty printing of uncaught exceptions. + excepthook, +) +from .excepthook import set_verbose_exceptions +from .exceptions import ( + DSLError, + InvalidParameterAnnotationError, + MissingAttributeError, + MissingParameterAnnotationError, + UndefinedSymbolError, + UnsupportedPythonFeatureError, +) + + +__all__ = [ + "DSLError", + "InvalidParameterAnnotationError", + "MissingAttributeError", + "MissingParameterAnnotationError", + "UndefinedSymbolError", + "UnsupportedPythonFeatureError", + "set_verbose_exceptions", +] diff --git a/src/gt4py/next/errors/excepthook.py b/src/gt4py/next/errors/excepthook.py new file mode 100644 index 0000000000..f1dd18e1b4 --- /dev/null +++ b/src/gt4py/next/errors/excepthook.py @@ -0,0 +1,87 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +""" +Loading this module registers an excepthook that formats :class:`DSLError`. + +The excepthook is necessary because the default hook prints :class:`DSLError`s +in an inconvenient way. The previously set excepthook is used to print all +other errors. +""" + +import os +import sys +import warnings +from typing import Callable + +from . import exceptions, formatting + + +def _get_verbose_exceptions_envvar() -> bool: + """Detect if the user enabled verbose exceptions in the environment variables.""" + env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" + if env_var_name in os.environ: + false_values = ["0", "false", "off"] + true_values = ["1", "true", "on"] + value = os.environ[env_var_name].lower() + if value in false_values: + return False + elif value in true_values: + return True + else: + values = ", ".join([*false_values, *true_values]) + msg = f"the 'GT4PY_VERBOSE_EXCEPTIONS' environment variable must be one of {values} (case insensitive)" + warnings.warn(msg) + return False + + +_verbose_exceptions: bool = _get_verbose_exceptions_envvar() + + +def set_verbose_exceptions(enabled: bool = False) -> None: + """Programmatically set whether to use verbose printing for uncaught errors.""" + global _verbose_exceptions + _verbose_exceptions = enabled + + +def _format_uncaught_error(err: exceptions.DSLError, verbose_exceptions: bool) -> list[str]: + if verbose_exceptions: + return formatting.format_compilation_error( + type(err), + err.message, + err.location, + err.__traceback__, + err.__cause__, + ) + else: + return formatting.format_compilation_error(type(err), err.message, err.location) + + +def compilation_error_hook(fallback: Callable, type_: type, value: BaseException, tb) -> None: + """ + Format `CompilationError`s in a neat way. + + All other Python exceptions are formatted by the `fallback` hook. When + verbose exceptions are enabled, the stack trace and cause of the error is + also printed. + """ + if isinstance(value, exceptions.DSLError): + exc_strs = _format_uncaught_error(value, _verbose_exceptions) + print("".join(exc_strs), file=sys.stderr) + else: + fallback(type_, value, tb) + + +_fallback = sys.excepthook +sys.excepthook = lambda ty, val, tb: compilation_error_hook(_fallback, ty, val, tb) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py new file mode 100644 index 0000000000..74230263db --- /dev/null +++ b/src/gt4py/next/errors/exceptions.py @@ -0,0 +1,104 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +""" +The list of exception classes used in the library. + +Exception classes that represent errors within an IR go here as a subclass of +:class:`DSLError`. Exception classes that represent other errors, like +the builtin ValueError, go here as well, although you should use Python's +builtin error classes if you can. Exception classes that are specific to a +certain submodule and have no use for the entire application may be better off +in that submodule as opposed to being in this file. +""" + +from __future__ import annotations + +import textwrap +from typing import Any, Optional + +from gt4py.eve import SourceLocation + +from . import formatting + + +class DSLError(Exception): + location: Optional[SourceLocation] + + def __init__(self, location: Optional[SourceLocation], message: str) -> None: + self.location = location + super().__init__(message) + + @property + def message(self) -> str: + return self.args[0] + + def with_location(self, location: Optional[SourceLocation]) -> DSLError: + self.location = location + return self + + def __str__(self) -> str: + if self.location: + loc_str = formatting.format_location(self.location, show_caret=True) + return f"{self.message}\n{textwrap.indent(loc_str, ' ')}" + return self.message + + +class UnsupportedPythonFeatureError(DSLError): + feature: str + + def __init__(self, location: Optional[SourceLocation], feature: str) -> None: + super().__init__(location, f"unsupported Python syntax: '{feature}'") + self.feature = feature + + +class UndefinedSymbolError(DSLError): + sym_name: str + + def __init__(self, location: Optional[SourceLocation], name: str) -> None: + super().__init__(location, f"name '{name}' is not defined") + self.sym_name = name + + +class MissingAttributeError(DSLError): + attr_name: str + + def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: + super().__init__(location, f"object does not have attribute '{attr_name}'") + self.attr_name = attr_name + + +class TypeError_(DSLError): + def __init__(self, location: Optional[SourceLocation], message: str) -> None: + super().__init__(location, message) + + +class MissingParameterAnnotationError(TypeError_): + param_name: str + + def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: + super().__init__(location, f"parameter '{param_name}' is missing type annotations") + self.param_name = param_name + + +class InvalidParameterAnnotationError(TypeError_): + param_name: str + annotated_type: Any + + def __init__(self, location: Optional[SourceLocation], param_name: str, type_: Any) -> None: + super().__init__( + location, f"parameter '{param_name}' has invalid type annotation '{type_}'" + ) + self.param_name = param_name + self.annotated_type = type_ diff --git a/src/gt4py/next/errors/formatting.py b/src/gt4py/next/errors/formatting.py new file mode 100644 index 0000000000..0176607971 --- /dev/null +++ b/src/gt4py/next/errors/formatting.py @@ -0,0 +1,107 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +"""Utility functions for formatting :class:`DSLError` and its subclasses.""" + +from __future__ import annotations + +import linecache +import textwrap +import traceback +import types +from typing import TYPE_CHECKING, Optional + +from gt4py.eve import SourceLocation + + +if TYPE_CHECKING: + from . import exceptions + + +def get_source_from_location(location: SourceLocation) -> str: + if not location.filename: + raise FileNotFoundError() + source_lines = linecache.getlines(location.filename) + if not source_lines: + raise FileNotFoundError() + start_line = location.line + end_line = location.end_line + 1 if location.end_line is not None else start_line + 1 + relevant_lines = source_lines[(start_line - 1) : (end_line - 1)] + return "\n".join(relevant_lines) + + +def format_location(loc: SourceLocation, show_caret: bool = False) -> str: + """ + Format the source file location. + + Args: + show_caret (bool): Indicate the position within the source line by placing carets underneath. + """ + filename = loc.filename or "" + lineno = loc.line + loc_str = f'File "{filename}", line {lineno}' + + if show_caret and loc.column is not None: + offset = loc.column - 1 + width = loc.end_column - loc.column if loc.end_column is not None else 1 + caret_str = "".join([" "] * offset + ["^"] * width) + else: + caret_str = None + + try: + snippet_str = get_source_from_location(loc) + if caret_str: + snippet_str = f"{snippet_str}{caret_str}" + return f"{loc_str}\n{textwrap.indent(snippet_str, ' ')}" + except (FileNotFoundError, IndexError): + return loc_str + + +def _format_cause(cause: BaseException) -> list[str]: + """Format the cause of an exception plus the bridging message to the current exception.""" + bridging_message = "The above exception was the direct cause of the following exception:" + cause_strs = [*traceback.format_exception(cause), "\n", f"{bridging_message}\n\n"] + return cause_strs + + +def _format_traceback(tb: types.TracebackType) -> list[str]: + """Format the traceback of an exception.""" + intro_message = "Traceback (most recent call last):" + traceback_strs = [ + f"{intro_message}\n", + *traceback.format_tb(tb), + ] + return traceback_strs + + +def format_compilation_error( + type_: type[exceptions.DSLError], + message: str, + location: Optional[SourceLocation], + tb: Optional[types.TracebackType] = None, + cause: Optional[BaseException] = None, +) -> list[str]: + bits: list[str] = [] + + if cause is not None: + bits = [*bits, *_format_cause(cause)] + if tb is not None: + bits = [*bits, *_format_traceback(tb)] + if location is not None: + loc_str = format_location(location, show_caret=True) + loc_str_all = f"Source location:\n{textwrap.indent(loc_str, ' ')}\n" + bits = [*bits, loc_str_all] + msg_str = f"{type_.__module__}.{type_.__name__}: {message}" + bits = [*bits, msg_str] + return bits diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a4efd6c168..2a343454e1 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -31,7 +31,7 @@ from gt4py.eve.extended_typing import Any, Optional from gt4py.eve.utils import UIDGenerator -from gt4py.next.common import Dimension, DimensionKind, GridType, GTTypeError, Scalar +from gt4py.next.common import Dimension, DimensionKind, GridType, Scalar from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -47,7 +47,7 @@ from gt4py.next.ffront.past_passes.closure_var_type_deduction import ( ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, ) -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction, ProgramTypeError +from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir @@ -115,7 +115,7 @@ def is_cartesian_offset(o: FieldOffset): break if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: - raise GTTypeError( + raise ValueError( "grid_type == GridType.CARTESIAN was requested, but unstructured `FieldOffset` or local `Dimension` was found." ) @@ -289,10 +289,8 @@ def _validate_args(self, *args, **kwargs) -> None: with_kwargs=kwarg_types, raise_exception=True, ) - except GTTypeError as err: - raise ProgramTypeError.from_past_node( - self.past_node, msg=f"Invalid argument types in call to `{self.past_node.id}`!" - ) from err + except ValueError as err: + raise TypeError(f"Invalid argument types in call to `{self.past_node.id}`!") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -344,7 +342,7 @@ def _column_axis(self): f"- {dim.value}: {', '.join(scanops)}" for dim, scanops in scanops_per_axis.items() ] - raise GTTypeError( + raise TypeError( "Only `ScanOperator`s defined on the same axis " + "can be used in a `Program`, but found:\n" + "\n".join(scanops_per_axis_strs) diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 97afa0f7da..c04e978e51 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -19,8 +19,8 @@ from typing import Callable from gt4py.eve.concepts import SourceLocation -from gt4py.eve.extended_typing import Any, ClassVar, Generic, Optional, Type, TypeVar -from gt4py.next import common +from gt4py.eve.extended_typing import Any, Generic, TypeVar +from gt4py.next import errors from gt4py.next.ffront.ast_passes.fix_missing_locations import FixMissingLocations from gt4py.next.ffront.ast_passes.remove_docstrings import RemoveDocstrings from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function @@ -29,67 +29,24 @@ DialectRootT = TypeVar("DialectRootT") -class DialectSyntaxError(common.GTSyntaxError): - dialect_name: ClassVar[str] = "" - - def __init__( - self, - msg="", - *, - lineno: int = 0, - offset: int = 0, - filename: Optional[str] = None, - end_lineno: Optional[int] = None, - end_offset: Optional[int] = None, - text: Optional[str] = None, - ): - msg = f"Invalid {self.dialect_name} Syntax: {msg}" - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_AST( - cls, - node: ast.AST, - *, - msg: str = "", - filename: Optional[str] = None, - text: Optional[str] = None, - ): - return cls( - msg, - lineno=node.lineno, - offset=node.col_offset + 1, # offset is 1-based for syntax errors - filename=filename, - end_lineno=getattr(node, "end_lineno", None), - end_offset=(node.end_col_offset + 1) - if hasattr(node, "end_col_offset") and node.end_col_offset is not None +def parse_source_definition(source_definition: SourceDefinition) -> ast.AST: + try: + return ast.parse(textwrap.dedent(source_definition.source)).body[0] + except SyntaxError as err: + assert err.lineno + assert err.offset + loc = SourceLocation( + line=err.lineno + source_definition.line_offset, + column=err.offset + source_definition.column_offset, + filename=source_definition.filename, + end_line=err.end_lineno + source_definition.line_offset + if err.end_lineno is not None + else None, + end_column=err.end_offset + source_definition.column_offset + if err.end_offset is not None else None, - text=text, - ) - - @classmethod - def from_location(cls, msg="", *, location: SourceLocation): - return cls( - msg, - lineno=location.line, - offset=location.column, - filename=location.source, - end_lineno=location.end_line, - end_offset=location.end_column, - text=None, ) - - -def _ensure_syntax_error_invariants(err: SyntaxError): - """Ensure syntax error invariants required to print meaningful error messages.""" - # If offsets are provided so must line numbers. For example `err.offset` determines - # if carets (`^^^^`) are printed below `err.text`. They would be misleading if we - # don't know on which line the error occurs. - assert err.lineno or not err.offset - assert err.end_lineno or not err.end_offset - # If the ends are provided so must starts. - assert err.lineno or not err.end_lineno - assert err.offset or not err.end_offset + raise errors.DSLError(loc, err.msg).with_traceback(err.__traceback__) @dataclass(frozen=True, kw_only=True) @@ -97,7 +54,6 @@ class DialectParser(ast.NodeVisitor, Generic[DialectRootT]): source_definition: SourceDefinition closure_vars: dict[str, Any] annotations: dict[str, Any] - syntax_error_cls: ClassVar[Type[DialectSyntaxError]] = DialectSyntaxError @classmethod def apply( @@ -106,52 +62,20 @@ def apply( closure_vars: dict[str, Any], annotations: dict[str, Any], ) -> DialectRootT: - source, filename, starting_line = source_definition - try: - line_offset = starting_line - 1 - definition_ast: ast.AST - definition_ast = ast.parse(textwrap.dedent(source)).body[0] - definition_ast = ast.increment_lineno(definition_ast, line_offset) - line_offset = 0 # line numbers are correct from now on - - definition_ast = RemoveDocstrings.apply(definition_ast) - definition_ast = FixMissingLocations.apply(definition_ast) - output_ast = cls._postprocess_dialect_ast( - cls( - source_definition=source_definition, - closure_vars=closure_vars, - annotations=annotations, - ).visit(cls._preprocess_definition_ast(definition_ast)), - closure_vars, - annotations, - ) - except SyntaxError as err: - _ensure_syntax_error_invariants(err) - - # The ast nodes do not contain information about the path of the - # source file or its contents. We add this information here so - # that raising an error using :func:`DialectSyntaxError.from_AST` - # does not require passing the information on every invocation. - if not err.filename: - err.filename = filename - - # ensure line numbers are relative to the file (respects `starting_line`) - if err.lineno: - err.lineno = err.lineno + line_offset - if err.end_lineno: - err.end_lineno = err.end_lineno + line_offset - - if not err.text: - if err.lineno: - source_lineno = err.lineno - starting_line - source_end_lineno = ( - (err.end_lineno - starting_line) if err.end_lineno else source_lineno - ) - err.text = "\n".join(source.splitlines()[source_lineno : source_end_lineno + 1]) - else: - err.text = source - - raise err + definition_ast: ast.AST + definition_ast = parse_source_definition(source_definition) + + definition_ast = RemoveDocstrings.apply(definition_ast) + definition_ast = FixMissingLocations.apply(definition_ast) + output_ast = cls._postprocess_dialect_ast( + cls( + source_definition=source_definition, + closure_vars=closure_vars, + annotations=annotations, + ).visit(cls._preprocess_definition_ast(definition_ast)), + closure_vars, + annotations, + ) return output_ast @@ -173,10 +97,21 @@ def _postprocess_dialect_ast( return output_ast def generic_visit(self, node: ast.AST) -> None: - raise self.syntax_error_cls.from_AST( - node, - msg=f"Nodes of type {type(node).__module__}.{type(node).__qualname__} not supported in dialect.", + loc = self.get_location(node) + feature = f"{type(node).__module__}.{type(node).__qualname__}" + raise errors.UnsupportedPythonFeatureError(loc, feature) + + def get_location(self, node: ast.AST) -> SourceLocation: + file = self.source_definition.filename + line_offset = self.source_definition.line_offset + col_offset = self.source_definition.column_offset + + line = node.lineno + line_offset if node.lineno is not None else None + end_line = node.end_lineno + line_offset if node.end_lineno is not None else None + column = 1 + node.col_offset + col_offset if node.col_offset is not None else None + end_column = ( + 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None ) - def _make_loc(self, node: ast.AST) -> SourceLocation: - return SourceLocation.from_AST(node, source=self.source_definition.filename) + loc = SourceLocation(file, line, column, end_line=end_line, end_column=end_column) + return loc diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 32a77fe155..9afd22de2c 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -18,6 +18,7 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import FrozenNamespace +from gt4py.next import errors @dataclass @@ -50,22 +51,12 @@ def visit_Name( return node def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: - # TODO: fix import form parent module by restructuring exception classis - from gt4py.next.ffront.func_to_foast import FieldOperatorSyntaxError - value = self.visit(node.value, **kwargs) if isinstance(value, foast.Constant): if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) - # TODO: use proper exception type (requires refactoring `FieldOperatorSyntaxError`) - raise FieldOperatorSyntaxError.from_location( - msg="Constant does not have the attribute specified by the AST.", - location=node.location, - ) - # TODO: use proper exception type (requires refactoring `FieldOperatorSyntaxError`) - raise FieldOperatorSyntaxError.from_location( - msg="Attribute can only be used on constants.", location=node.location - ) + raise errors.MissingAttributeError(node.location, node.attr) + raise errors.DSLError(node.location, "attribute access only applicable to constants") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 4fd760c97c..bd7eddbcdd 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -16,7 +16,8 @@ import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits -from gt4py.next.common import DimensionKind, GTSyntaxError, GTTypeError +from gt4py.next import errors +from gt4py.next.common import DimensionKind from gt4py.next.ffront import ( # noqa dialect_ast_enums, fbuiltins, @@ -51,7 +52,7 @@ def boolified_type(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.FieldType: return scalar_bool elif type_class is ts.FieldType: return ts.FieldType(dtype=scalar_bool, dims=type_info.extract_dims(symbol_type)) - raise GTTypeError(f"Can not boolify type {symbol_type}!") + raise ValueError(f"Can not boolify type {symbol_type}!") def construct_tuple_type( @@ -144,9 +145,9 @@ def deduce_stmt_return_type( if return_types[0] == return_types[1]: is_unconditional_return = True else: - raise FieldOperatorTypeDeductionError.from_foast_node( - stmt, - msg=f"If statement contains return statements with inconsistent types:" + raise errors.DSLError( + stmt.location, + f"If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", ) return_type = return_types[0] or return_types[1] @@ -161,9 +162,9 @@ def deduce_stmt_return_type( raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") if conditional_return_type and return_type and return_type != conditional_return_type: - raise FieldOperatorTypeDeductionError.from_foast_node( - stmt, - msg=f"If statement contains return statements with inconsistent types:" + raise errors.DSLError( + stmt.location, + f"If statement contains return statements with inconsistent types:" f"{conditional_return_type} != {conditional_return_type}", ) @@ -247,9 +248,9 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): new_closure_vars = self.visit(node.closure_vars, **kwargs) return_type = deduce_stmt_return_type(new_body) if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", + raise errors.DSLError( + node.location, + f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", ) new_type = ts.FunctionType( pos_only_args=[], @@ -279,44 +280,44 @@ def visit_FieldOperator(self, node: foast.FieldOperator, **kwargs) -> foast.Fiel def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOperator: new_axis = self.visit(node.axis, **kwargs) if not isinstance(new_axis.type, ts.DimensionType): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `axis` to scan operator `{node.id}` must be a dimension.", + raise errors.DSLError( + node.location, + f"Argument `axis` to scan operator `{node.id}` must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", + raise errors.DSLError( + node.location, + f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Argument `forward` to scan operator `{node.id}` must be a boolean." + raise errors.DSLError( + node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( type_info.is_arithmetic(type_) or type_info.is_logical(type_) for type_ in type_info.primitive_constituents(new_init.type) ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `init` to scan operator `{node.id}` must " + raise errors.DSLError( + node.location, + f"Argument `init` to scan operator `{node.id}` must " f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", ) new_definition = self.visit(node.definition, **kwargs) new_def_type = new_definition.type carry_type = list(new_def_type.pos_or_kw_args.values())[0] if new_init.type != new_def_type.returns: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `init` to scan operator `{node.id}` must have same type as its return. " + raise errors.DSLError( + node.location, + f"Argument `init` to scan operator `{node.id}` must have same type as its return. " f"Expected `{new_def_type.returns}`, but got `{new_init.type}`", ) elif new_init.type != carry_type: carry_arg_name = list(new_def_type.pos_or_kw_args.keys())[0] - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Argument `init` to scan operator `{node.id}` must have same type as `{carry_arg_name}` argument. " + raise errors.DSLError( + node.location, + f"Argument `init` to scan operator `{node.id}` must have same type as `{carry_arg_name}` argument. " f"Expected `{carry_type}`, but got `{new_init.type}`", ) @@ -337,9 +338,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Undeclared symbol `{node.id}`." - ) + raise errors.DSLError(node.location, f"Undeclared symbol `{node.id}`.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -363,8 +362,8 @@ def visit_TupleTargetAssign( indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Too many values to unpack (expected {len(indices)})." + raise errors.DSLError( + node.location, f"Too many values to unpack (expected {len(indices)})." ) new_targets: TargetType = [] @@ -395,8 +394,8 @@ def visit_TupleTargetAssign( ) new_targets.append(new_target) else: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Assignment value must be of type tuple! Got: {values.type}" + raise errors.DSLError( + node.location, f"Assignment value must be of type tuple! Got: {values.type}" ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -414,16 +413,16 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: ) if not isinstance(new_node.condition.type, ts.ScalarType): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="Condition for `if` must be scalar. " + raise errors.DSLError( + node.location, + "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="Condition for `if` must be of boolean type. " + raise errors.DSLError( + node.location, + "Condition for `if` must be of boolean type. " f"But got `{new_node.condition.type}` instead.", ) @@ -431,9 +430,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if (true_type := new_true_branch.annex.symtable[sym].type) != ( false_type := new_false_branch.annex.symtable[sym].type ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Inconsistent types between two branches for variable `{sym}`. " + raise errors.DSLError( + node.location, + f"Inconsistent types between two branches for variable `{sym}`. " f"Got types `{true_type}` and `{false_type}.", ) # TODO: properly patch symtable (new node?) @@ -452,9 +451,9 @@ def visit_Symbol( symtable = kwargs["symtable"] if refine_type: if not type_info.is_concretizable(node.type, to_type=refine_type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=( + raise errors.DSLError( + node.location, + ( "type inconsistency: expression was deduced to be " f"of type {refine_type}, instead of the expected type {node.type}" ), @@ -474,8 +473,8 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = types[node.index] case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, msg="Second dimension in offset must be a local dimension." + raise errors.DSLError( + new_value.location, "Second dimension in offset must be a local dimension." ) new_type = ts.OffsetType(source=source, target=(target1,)) case ts.OffsetType(source=source, target=(target,)): @@ -483,14 +482,14 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: # signifies the displacement in the respective dimension, # but does not change the target type. if source != target: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, - msg="Source and target must be equal for offsets with a single target.", + raise errors.DSLError( + new_value.location, + "Source and target must be equal for offsets with a single target.", ) new_type = new_value.type case _: - raise FieldOperatorTypeDeductionError.from_foast_node( - new_value, msg="Could not deduce type of subscript expression!" + raise errors.DSLError( + new_value.location, "Could not deduce type of subscript expression!" ) return foast.Subscript( @@ -529,15 +528,15 @@ def _deduce_ternaryexpr_type( false_expr: foast.Expr, ) -> Optional[ts.TypeSpec]: if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): - raise FieldOperatorTypeDeductionError.from_foast_node( - condition, - msg=f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", + raise errors.DSLError( + condition.location, + f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", ) if true_expr.type != false_expr.type: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", + raise errors.DSLError( + node.location, + f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", ) return true_expr.type @@ -555,8 +554,8 @@ def _deduce_compare_type( # check both types compatible for arg in (left, right): if not type_info.is_arithmetic(arg.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, msg=f"Type {arg.type} can not be used in operator '{node.op}'!" + raise errors.DSLError( + arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -565,10 +564,10 @@ def _deduce_compare_type( # transform operands to have bool dtype and use regular promotion # mechanism to handle dimension promotion return type_info.promote(boolified_type(left.type), boolified_type(right.type)) - except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Could not promote `{left.type}` and `{right.type}` to common type" + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote `{left.type}` and `{right.type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -590,8 +589,8 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not is_compatible(arg.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, msg=f"Type {arg.type} can not be used in operator `{node.op}`!" + raise errors.DSLError( + arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -603,17 +602,17 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise FieldOperatorTypeDeductionError.from_foast_node( - arg, - msg=f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + raise errors.DSLError( + arg.location, + f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) try: return type_info.promote(left_type, right_type) - except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Could not promote `{left_type}` and `{right_type}` to common type" + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -622,9 +621,9 @@ def _check_operand_dtypes_match( ) -> None: # check dtypes match if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", + raise errors.DSLError( + node.location, + f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: @@ -639,9 +638,9 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: else type_info.is_arithmetic ) if not is_compatible(new_operand.type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", + raise errors.DSLError( + node.location, + f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type @@ -671,15 +670,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg="Functions can only be called directly!" - ) + raise errors.DSLError(node.location, "Functions can only be called directly!") elif isinstance(new_func.type, ts.FieldType): pass else: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", + raise errors.DSLError( + node.location, + f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", ) # ensure signature is valid @@ -690,9 +687,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: with_kwargs=kwarg_types, raise_exception=True, ) - except GTTypeError as err: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=f"Invalid argument types in call to `{new_func}`!" + except ValueError as err: + raise errors.DSLError( + node.location, f"Invalid argument types in call to `{new_func}`!" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -749,9 +746,9 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." ) if error_msgs: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg="\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), + raise errors.DSLError( + node.location, + "\n".join([error_msg_preamble] + [f" - {error}" for error in error_msgs]), ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): @@ -771,10 +768,8 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: return_type = type_info.promote( *((cast(ts.FieldType | ts.ScalarType, arg.type)) for arg in node.args) ) - except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg=error_msg_preamble - ) from ex + except ValueError as ex: + raise errors.DSLError(node.location, error_msg_preamble) from ex else: raise AssertionError(f"Unknown math builtin `{func_name}`.") @@ -793,9 +788,9 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: assert field_type.dims is not ... if reduction_dim not in field_type.dims: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible field argument in call to `{str(node.func)}`. " + raise errors.DSLError( + node.location, + f"Incompatible field argument in call to `{str(node.func)}`. " f"Expected a field with dimension {reduction_dim}, but got " f"{field_dims_str}.", ) @@ -849,17 +844,17 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: assert isinstance(arg_0, ts.OffsetType) assert isinstance(arg_1, ts.FieldType) if not type_info.is_integral(arg_1): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. " + raise errors.DSLError( + node.location, + f"Incompatible argument in call to `{str(node.func)}`. " f"Excepted integer for offset field dtype, but got {arg_1.dtype}" f"{node.location}", ) if arg_0.source not in arg_1.dims: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. " + raise errors.DSLError( + node.location, + f"Incompatible argument in call to `{str(node.func)}`. " f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " f"{node.location}", ) @@ -878,9 +873,9 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: false_branch_type = node.args[2].type return_type: ts.TupleType | ts.FieldType if not type_info.is_logical(mask_type): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`. Expected " + raise errors.DSLError( + node.location, + f"Incompatible argument in call to `{str(node.func)}`. Expected " f"a field with dtype `bool`, but got `{mask_type}`.", ) @@ -896,9 +891,9 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: elif isinstance(true_branch_type, ts.TupleType) or isinstance( false_branch_type, ts.TupleType ): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Return arguments need to be of same type in {str(node.func)}, but got: " + raise errors.DSLError( + node.location, + f"Return arguments need to be of same type in {str(node.func)}, but got: " f"{node.args[1].type} and {node.args[2].type}", ) else: @@ -907,10 +902,10 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: promoted_type = type_info.promote(true_branch_fieldtype, false_branch_fieldtype) return_type = promote_to_mask_type(mask_type, promoted_type) - except GTTypeError as ex: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible argument in call to `{str(node.func)}`.", + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Incompatible argument in call to `{str(node.func)}`.", ) from ex return foast.Call( @@ -926,18 +921,18 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: broadcast_dims_expr = cast(foast.TupleExpr, node.args[1]).elts if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible broadcast dimension type in {str(node.func)}. Expected " + raise errors.DSLError( + node.location, + f"Incompatible broadcast dimension type in {str(node.func)}. Expected " f"all broadcast dimensions to be of type Dimension.", ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): - raise FieldOperatorTypeDeductionError.from_foast_node( - node, - msg=f"Incompatible broadcast dimensions in {str(node.func)}. Expected " + raise errors.DSLError( + node.location, + f"Incompatible broadcast dimensions in {str(node.func)}. Expected " f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", ) @@ -957,42 +952,6 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: def visit_Constant(self, node: foast.Constant, **kwargs) -> foast.Constant: try: type_ = type_translation.from_value(node.value) - except GTTypeError as e: - raise FieldOperatorTypeDeductionError.from_foast_node( - node, msg="Could not deduce type of constant." - ) from e + except ValueError as e: + raise errors.DSLError(node.location, "Could not deduce type of constant.") from e return foast.Constant(value=node.value, location=node.location, type=type_) - - -class FieldOperatorTypeDeductionError(GTSyntaxError, SyntaxWarning): - """Exception for problematic type deductions that originate in user code.""" - - def __init__( - self, - msg="", - *, - lineno=0, - offset=0, - filename=None, - end_lineno=None, - end_offset=None, - text=None, - ): - msg = "Could not deduce type: " + msg - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_foast_node( - cls, - node: foast.LocatedNode, - *, - msg: str = "", - ): - return cls( - msg, - lineno=node.location.line, - offset=node.location.column, - filename=node.location.source, - end_lineno=node.location.end_line, - end_offset=node.location.end_column, - ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 7ef4f597ab..082939c938 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -19,7 +19,7 @@ from typing import Any, Callable, Iterable, Mapping, Type, cast import gt4py.eve as eve -from gt4py.next import common +from gt4py.next import errors from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.ast_passes import ( SingleAssignTargetPass, @@ -27,7 +27,7 @@ StringifyAnnotationsPass, UnchainComparesPass, ) -from gt4py.next.ffront.dialect_parser import DialectParser, DialectSyntaxError +from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind from gt4py.next.ffront.foast_passes.closure_var_folding import ClosureVarFolding from gt4py.next.ffront.foast_passes.closure_var_type_deduction import ClosureVarTypeDeduction @@ -37,10 +37,6 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -class FieldOperatorSyntaxError(DialectSyntaxError): - dialect_name = "Field Operator" - - class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): """ Parse field operator function definition from source code into FOAST. @@ -75,13 +71,11 @@ class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): >>> >>> try: # doctest: +ELLIPSIS ... FieldOperatorParser.apply_to_function(wrong_syntax) - ... except FieldOperatorSyntaxError as err: - ... print(f"Error at [{err.lineno}, {err.offset}] in {err.filename})") - Error at [2, 5] in ...gt4py.next.ffront.func_to_foast.FieldOperatorParser[...]>) + ... except errors.DSLError as err: + ... print(f"Error at [{err.location.line}, {err.location.column}] in {err.location.filename})") + Error at [2, 5] in ...func_to_foast.FieldOperatorParser[...]>) """ - syntax_error_cls = FieldOperatorSyntaxError - @classmethod def _preprocess_definition_ast(cls, definition_ast: ast.AST) -> ast.AST: sta = StringifyAnnotationsPass.apply(definition_ast) @@ -109,9 +103,10 @@ def _postprocess_dialect_ast( # TODO(tehrengruber): use `type_info.return_type` when the type of the # arguments becomes available here if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - raise common.GTTypeError( + raise errors.DSLError( + foast_node.location, f"Annotated return type does not match deduced return type. Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - f", but got `{annotated_return_type}`." + f", but got `{annotated_return_type}`.", ) return foast_node @@ -151,8 +146,9 @@ def _builtin_type_constructor_symbols( return result, to_be_inserted.keys() def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDefinition: + loc = self.get_location(node) closure_var_symbols, skip_names = self._builtin_type_constructor_symbols( - self.closure_vars, self._make_loc(node) + self.closure_vars, self.get_location(node) ) for name in self.closure_vars.keys(): if name in skip_names: @@ -162,37 +158,34 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe id=name, type=ts.DeferredType(constraint=None), namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self._make_loc(node), + location=self.get_location(node), ) ) - new_body = self._visit_stmts(node.body, self._make_loc(node), **kwargs) + new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise FieldOperatorSyntaxError.from_AST( - node, msg="Function must return a value, but no return statement was found." - ) + raise errors.DSLError(loc, "Function is expected to return a value.") return foast.FunctionDefinition( id=node.name, params=self.visit(node.args, **kwargs), body=new_body, closure_vars=closure_var_symbols, - location=self._make_loc(node), + location=loc, ) def visit_arguments(self, node: ast.arguments) -> list[foast.DataSymbol]: return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> foast.DataSymbol: + loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise FieldOperatorSyntaxError.from_AST(node, msg="Untyped parameters not allowed!") + raise errors.MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise FieldOperatorSyntaxError.from_AST( - node, msg="Only arguments of type DataType are allowed." - ) - return foast.DataSymbol(id=node.arg, location=self._make_loc(node), type=new_type) + raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) + return foast.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.TupleTargetAssign: target = node.targets[0] # there is only one element after assignment passes @@ -208,10 +201,10 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple foast.Starred( id=foast.DataSymbol( id=self.visit(elt.value).id, - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ), - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ) ) @@ -219,17 +212,17 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple new_targets.append( foast.DataSymbol( id=self.visit(elt).id, - location=self._make_loc(elt), + location=self.get_location(elt), type=ts.DeferredType(constraint=ts.DataType), ) ) return foast.TupleTargetAssign( - targets=new_targets, value=self.visit(node.value), location=self._make_loc(node) + targets=new_targets, value=self.visit(node.value), location=self.get_location(node) ) if not isinstance(target, ast.Name): - raise FieldOperatorSyntaxError.from_AST(node, msg="Can only assign to names!") + raise errors.DSLError(self.get_location(node), "can only assign to names") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -242,16 +235,16 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple return foast.Assign( target=foast.DataSymbol( id=target.id, - location=self._make_loc(target), + location=self.get_location(target), type=ts.DeferredType(constraint=constraint_type), ), value=new_value, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise FieldOperatorSyntaxError.from_AST(node, msg="Can only assign to names!") + raise errors.DSLError(self.get_location(node), "can only assign to names") if node.annotation is not None: assert isinstance( @@ -266,11 +259,11 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: return foast.Assign( target=foast.Symbol[ts.FieldType]( id=node.target.id, - location=self._make_loc(node.target), + location=self.get_location(node.target), type=target_type, ), value=self.visit(node.value) if node.value else None, - location=self._make_loc(node), + location=self.get_location(node), ) @staticmethod @@ -292,40 +285,43 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: - raise FieldOperatorSyntaxError.from_AST( - node, msg="""Only index is supported in subscript!""" - ) + raise errors.DSLError( + self.get_location(node.slice), "expected an integral index" + ) from None return foast.Subscript( value=self.visit(node.value), index=index, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Attribute(self, node: ast.Attribute) -> Any: return foast.Attribute( - value=self.visit(node.value), attr=node.attr, location=self._make_loc(node) + value=self.visit(node.value), attr=node.attr, location=self.get_location(node) ) def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: return foast.TupleExpr( - elts=[self.visit(item) for item in node.elts], location=self._make_loc(node) + elts=[self.visit(item) for item in node.elts], location=self.get_location(node) ) def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: + loc = self.get_location(node) if not node.value: - raise FieldOperatorSyntaxError.from_AST(node, msg="Empty return not allowed") - return foast.Return(value=self.visit(node.value), location=self._make_loc(node)) + raise errors.DSLError(loc, "must return a value, not None") + return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs) -> foast.Name: - return foast.Name(id=node.id, location=self._make_loc(node)) + return foast.Name(id=node.id, location=self.get_location(node)) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs) -> foast.UnaryOp: return foast.UnaryOp( - op=self.visit(node.op), operand=self.visit(node.operand), location=self._make_loc(node) + op=self.visit(node.op), + operand=self.visit(node.operand), + location=self.get_location(node), ) def visit_UAdd(self, node: ast.UAdd, **kwargs) -> dialect_ast_enums.UnaryOperator: @@ -345,7 +341,7 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs) -> foast.BinOp: op=self.visit(node.op), left=self.visit(node.left), right=self.visit(node.right), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Add(self, node: ast.Add, **kwargs) -> dialect_ast_enums.BinaryOperator: @@ -379,19 +375,21 @@ def visit_BitXor(self, node: ast.BitXor, **kwargs) -> dialect_ast_enums.BinaryOp return dialect_ast_enums.BinaryOperator.BIT_XOR def visit_BoolOp(self, node: ast.BoolOp, **kwargs) -> None: - raise FieldOperatorSyntaxError.from_AST(node, msg="`and`/`or` operator not allowed!") + raise errors.UnsupportedPythonFeatureError( + self.get_location(node), "logical operators `and`, `or`" + ) def visit_IfExp(self, node: ast.IfExp, **kwargs) -> foast.TernaryExpr: return foast.TernaryExpr( condition=self.visit(node.test), true_expr=self.visit(node.body), false_expr=self.visit(node.orelse), - location=self._make_loc(node), + location=self.get_location(node), type=ts.DeferredType(constraint=ts.DataType), ) def visit_If(self, node: ast.If, **kwargs) -> foast.IfStmt: - loc = self._make_loc(node) + loc = self.get_location(node) return foast.IfStmt( condition=self.visit(node.test, **kwargs), true_branch=self._visit_stmts(node.body, loc, **kwargs), @@ -408,15 +406,16 @@ def _visit_stmts( ) def visit_Compare(self, node: ast.Compare, **kwargs) -> foast.Compare: + loc = self.get_location(node) if len(node.ops) != 1 or len(node.comparators) != 1: - raise FieldOperatorSyntaxError.from_AST( - node, msg="Comparison chains not allowed, run a preprocessing pass!" - ) + # Remove comparison chains in a preprocessing pass + # TODO: maybe add a note to the error about preprocessing passes? + raise errors.UnsupportedPythonFeatureError(loc, "comparison chains") return foast.Compare( op=self.visit(node.ops[0]), left=self.visit(node.left), right=self.visit(node.comparators[0]), - location=self._make_loc(node), + location=loc, ) def visit_Gt(self, node: ast.Gt, **kwargs) -> foast.CompareOperator: @@ -439,9 +438,9 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs) -> foast.CompareOperator: def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): - raise FieldOperatorSyntaxError.from_AST( - node, - msg=f"{self._func_name(node)}() only takes literal arguments!", + raise errors.DSLError( + self.get_location(node), + f"{self._func_name(node)}() only takes literal arguments!", ) def _func_name(self, node: ast.Call) -> str: @@ -458,19 +457,20 @@ def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: func=self.visit(node.func, **kwargs), args=[self.visit(arg, **kwargs) for arg in node.args], kwargs={keyword.arg: self.visit(keyword.value, **kwargs) for keyword in node.keywords}, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: + loc = self.get_location(node) try: type_ = type_translation.from_value(node.value) - except common.GTTypeError as e: - raise FieldOperatorSyntaxError.from_AST( - node, msg=f"Constants of type {type(node.value)} are not permitted." - ) from e + except ValueError: + raise errors.DSLError( + loc, f"constants of type {type(node.value)} are not permitted" + ) from None return foast.Constant( value=node.value, - location=self._make_loc(node), + location=loc, type=type_, ) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 0f51ca88c1..7b04e90902 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -18,33 +18,26 @@ from dataclasses import dataclass from typing import Any, cast +from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, type_specifications as ts_ffront, ) -from gt4py.next.ffront.dialect_parser import DialectParser, DialectSyntaxError +from gt4py.next.ffront.dialect_parser import DialectParser from gt4py.next.ffront.past_passes.closure_var_type_deduction import ClosureVarTypeDeduction from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction from gt4py.next.type_system import type_specifications as ts, type_translation -class ProgramSyntaxError(DialectSyntaxError): - dialect_name = "Program" - - @dataclass(frozen=True, kw_only=True) class ProgramParser(DialectParser[past.Program]): """Parse program definition from Python source code into PAST.""" - syntax_error_cls = ProgramSyntaxError - @classmethod def _postprocess_dialect_ast( cls, output_node: past.Program, closure_vars: dict[str, Any], annotations: dict[str, Any] ) -> past.Program: - if "return" in annotations and not isinstance(None, annotations["return"]): - raise ProgramSyntaxError("Program should not have a return value!") output_node = ClosureVarTypeDeduction.apply(output_node, closure_vars) return ProgramTypeDeduction.apply(output_node) @@ -54,7 +47,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: id=name, type=type_translation.from_value(val), namespace=dialect_ast_enums.Namespace.CLOSURE, - location=self._make_loc(node), + location=self.get_location(node), ) for name, val in self.closure_vars.items() ] @@ -65,21 +58,20 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> past.Program: params=self.visit(node.args), body=[self.visit(node) for node in node.body], closure_vars=closure_symbols, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_arguments(self, node: ast.arguments) -> list[past.DataSymbol]: return [self.visit_arg(arg) for arg in node.args] def visit_arg(self, node: ast.arg) -> past.DataSymbol: + loc = self.get_location(node) if (annotation := self.annotations.get(node.arg, None)) is None: - raise ProgramSyntaxError.from_AST(node, msg="Untyped parameters not allowed!") + raise errors.MissingParameterAnnotationError(loc, node.arg) new_type = type_translation.from_type_hint(annotation) if not isinstance(new_type, ts.DataType): - raise ProgramSyntaxError.from_AST( - node, msg="Only arguments of type DataType are allowed." - ) - return past.DataSymbol(id=node.arg, location=self._make_loc(node), type=new_type) + raise errors.InvalidParameterAnnotationError(loc, node.arg, new_type) + return past.DataSymbol(id=node.arg, location=loc, type=new_type) def visit_Expr(self, node: ast.Expr) -> past.LocatedNode: return self.visit(node.value) @@ -119,42 +111,45 @@ def visit_BinOp(self, node: ast.BinOp, **kwargs) -> past.BinOp: op=self.visit(node.op), left=self.visit(node.left), right=self.visit(node.right), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Name(self, node: ast.Name) -> past.Name: - return past.Name(id=node.id, location=self._make_loc(node)) + return past.Name(id=node.id, location=self.get_location(node)) def visit_Dict(self, node: ast.Dict) -> past.Dict: return past.Dict( keys_=[self.visit(cast(ast.AST, param)) for param in node.keys], values_=[self.visit(param) for param in node.values], - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Call(self, node: ast.Call) -> past.Call: + loc = self.get_location(node) new_func = self.visit(node.func) if not isinstance(new_func, past.Name): - raise ProgramSyntaxError.from_AST(node, msg="Functions can only be called directly!") + raise errors.DSLError( + loc, "functions must be referenced by their name in function calls" + ) return past.Call( func=new_func, args=[self.visit(arg) for arg in node.args], kwargs={arg.arg: self.visit(arg.value) for arg in node.keywords}, - location=self._make_loc(node), + location=loc, ) def visit_Subscript(self, node: ast.Subscript) -> past.Subscript: return past.Subscript( value=self.visit(node.value), slice_=self.visit(node.slice), - location=self._make_loc(node), + location=self.get_location(node), ) def visit_Tuple(self, node: ast.Tuple) -> past.TupleExpr: return past.TupleExpr( elts=[self.visit(item) for item in node.elts], - location=self._make_loc(node), + location=self.get_location(node), type=ts.DeferredType(constraint=ts.TupleType), ) @@ -163,17 +158,16 @@ def visit_Slice(self, node: ast.Slice) -> past.Slice: lower=self.visit(node.lower) if node.lower is not None else None, upper=self.visit(node.upper) if node.upper is not None else None, step=self.visit(node.step) if node.step is not None else None, - location=self._make_loc(node), + location=self.get_location(node), ) def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: + loc = self.get_location(node) if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) - return past.Constant( - value=-node.operand.value, type=symbol_type, location=self._make_loc(node) - ) - raise ProgramSyntaxError.from_AST(node, msg="Unary operators can only be used on literals.") + return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) + raise errors.DSLError(loc, "unary operators are only applicable to literals") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) - return past.Constant(value=node.value, type=symbol_type, location=self._make_loc(node)) + return past.Constant(value=node.value, type=symbol_type, location=self.get_location(node)) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 4beb5dd8da..ed3bdae3ff 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -15,7 +15,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, traits -from gt4py.next.common import GTTypeError +from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, program_ast as past, @@ -33,7 +33,7 @@ def _ensure_no_sliced_field(entry: past.Expr): For example, if argument is of type past.Subscript, this function will throw an error as both slicing and domain are being applied """ if not isinstance(entry, past.Name) and not isinstance(entry, past.TupleExpr): - raise GTTypeError("Either only domain or slicing allowed") + raise ValueError("Either only domain or slicing allowed") elif isinstance(entry, past.TupleExpr): for param in entry.elts: _ensure_no_sliced_field(param) @@ -56,39 +56,39 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): new_func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): - raise GTTypeError( + raise ValueError( f"Only calls `FieldOperator`s and `ScanOperator`s " f"allowed in `Program`, but got `{new_func.type}`." ) if "out" not in new_kwargs: - raise GTTypeError("Missing required keyword argument(s) `out`.") + raise ValueError("Missing required keyword argument(s) `out`.") if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] if not isinstance(domain_kwarg, past.Dict): - raise GTTypeError( + raise ValueError( f"Only Dictionaries allowed in domain, but got `{type(domain_kwarg)}`." ) if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: - raise GTTypeError("Empty domain not allowed.") + raise ValueError("Empty domain not allowed.") for dim in domain_kwarg.keys_: if not isinstance(dim.type, ts.DimensionType): - raise GTTypeError( + raise ValueError( f"Only Dimension allowed in domain dictionary keys, but got `{dim}` which is of type `{dim.type}`." ) for domain_values in domain_kwarg.values_: if len(domain_values.elts) != 2: - raise GTTypeError( + raise ValueError( f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." ) if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( domain_values.elts[1] ): - raise GTTypeError( + raise ValueError( f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." ) @@ -148,8 +148,8 @@ def _deduce_binop_type( # check both types compatible for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): - raise ProgramTypeError.from_past_node( - arg, msg=f"Type {arg.type} can not be used in operator `{node.op}`!" + raise errors.DSLError( + arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" ) left_type = cast(ts.ScalarType, left.type) @@ -161,17 +161,17 @@ def _deduce_binop_type( if node.op == dialect_ast_enums.BinaryOperator.MOD and not type_info.is_integral( right_type ): - raise ProgramTypeError.from_past_node( - arg, - msg=f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + raise errors.DSLError( + arg.location, + f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", ) try: return type_info.promote(left_type, right_type) - except GTTypeError as ex: - raise ProgramTypeError.from_past_node( - node, - msg=f"Could not promote `{left_type}` and `{right_type}` to common type" + except ValueError as ex: + raise errors.DSLError( + node.location, + f"Could not promote `{left_type}` and `{right_type}` to common type" f" in call to `{node.op}`.", ) from ex @@ -213,14 +213,14 @@ def visit_Call(self, node: past.Call, **kwargs): new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) if operator_return_type != new_kwargs["out"].type: - raise GTTypeError( + raise ValueError( f"Expected keyword argument `out` to be of " f"type {operator_return_type}, but got " f"{new_kwargs['out'].type}." ) elif new_func.id in ["minimum", "maximum"]: if new_args[0].type != new_args[1].type: - raise GTTypeError( + raise ValueError( f"First and second argument in {new_func.id} must be the same type." f"Got `{new_args[0].type}` and `{new_args[1].type}`." ) @@ -230,10 +230,8 @@ def visit_Call(self, node: past.Call, **kwargs): "Only calls `FieldOperator`s, `ScanOperator`s or minimum and maximum builtins allowed" ) - except GTTypeError as ex: - raise ProgramTypeError.from_past_node( - node, msg=f"Invalid call to `{node.func.id}`." - ) from ex + except ValueError as ex: + raise errors.DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex return past.Call( func=new_func, @@ -246,41 +244,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise ProgramTypeError.from_past_node( - node, msg=f"Undeclared or untyped symbol `{node.id}`." - ) + raise errors.DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) - - -class ProgramTypeError(GTTypeError): - """Exception for problematic type deductions that originate in user code.""" - - def __init__( - self, - msg="", - *, - lineno=0, - offset=0, - filename=None, - end_lineno=None, - end_offset=None, - text=None, - ): - super().__init__(msg, (filename, lineno, offset, text, end_lineno, end_offset)) - - @classmethod - def from_past_node( - cls, - node: past.LocatedNode, - *, - msg: str = "", - ): - return cls( - msg, - lineno=node.location.line, - offset=node.location.column, - filename=node.location.source, - end_lineno=node.location.end_line, - end_offset=node.location.end_column, - ) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 44a408e9f7..2c5dfc6e2f 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -17,7 +17,7 @@ from typing import Optional, cast from gt4py.eve import NodeTranslator, concepts, traits -from gt4py.next.common import Dimension, DimensionKind, GridType, GTTypeError +from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.ffront import program_ast as past, type_specifications as ts_ffront from gt4py.next.iterator import ir as itir from gt4py.next.type_system import type_info, type_specifications as ts @@ -37,9 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise GTTypeError( - "Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed." - ) + raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -190,7 +188,7 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: args=[self._construct_itir_out_arg(el) for el in node.elts], ) else: - raise GTTypeError( + raise ValueError( "Unexpected `out` argument. Must be a `past.Name`, `past.Subscript`" " or a `past.TupleExpr` thereof." ) @@ -234,7 +232,7 @@ def _construct_itir_domain_arg( ) if dim.kind == DimensionKind.LOCAL: - raise GTTypeError(f"Dimension {dim.value} must not be local.") + raise ValueError(f"Dimension {dim.value} must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -260,7 +258,7 @@ def _construct_itir_initialized_domain_arg( assert len(node_domain.values_[dim_i].elts) == 2 keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: - raise GTTypeError( + raise ValueError( f"Dimensions in out field and field domain are not equivalent" f"Expected {dim}, but got {keys_dims_types} " ) @@ -284,7 +282,7 @@ def _compute_field_slice(node: past.Subscript): node_dims_ls = cast(ts.FieldType, node.type).dims assert isinstance(node_dims_ls, list) if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls): - raise GTTypeError( + raise ValueError( f"Too many indices for field {out_field_name}: field is {len(node_dims_ls)}" f"-dimensional, but {len(out_field_slice_)} were indexed." ) diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py index e0f428dbc2..17b2050b1b 100644 --- a/src/gt4py/next/ffront/source_utils.py +++ b/src/gt4py/next/ffront/source_utils.py @@ -23,8 +23,6 @@ from dataclasses import dataclass from typing import Any, cast -from gt4py.next import common - MISSING_FILENAME = "" @@ -36,32 +34,33 @@ def get_closure_vars_from_function(function: Callable) -> dict[str, Any]: def make_source_definition_from_function(func: Callable) -> SourceDefinition: try: - filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) or MISSING_FILENAME - source = textwrap.dedent(inspect.getsource(func)) - starting_line = ( - inspect.getsourcelines(func)[1] if not filename.endswith(MISSING_FILENAME) else 1 + filename = str(pathlib.Path(inspect.getabsfile(func)).resolve()) + if not filename: + raise ValueError( + "Can not create field operator from a function that is not in a source file!" + ) + source_lines, line_offset = inspect.getsourcelines(func) + source_code = textwrap.dedent(inspect.getsource(func)) + column_offset = min( + [len(line) - len(line.lstrip()) for line in source_lines if line.lstrip()], default=0 ) - except OSError as err: - if filename.endswith(MISSING_FILENAME): - message = "Can not create field operator from a function that is not in a source file!" - else: - message = f"Can not get source code of passed function ({func})" - raise ValueError(message) from err + return SourceDefinition(source_code, filename, line_offset - 1, column_offset) - return SourceDefinition(source, filename, starting_line) + except OSError as err: + raise ValueError(f"Can not get source code of passed function ({func})") from err def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) -> SymbolNames: try: mod_st = symtable.symtable(source, filename, "exec") except SyntaxError as err: - raise common.GTValueError( + raise ValueError( f"Unexpected error when parsing provided source code (\n{source}\n)" ) from err assert mod_st.get_type() == "module" if len(children := mod_st.get_children()) != 1: - raise common.GTValueError( + raise ValueError( f"Sources with multiple function definitions are not yet supported (\n{source}\n)" ) @@ -107,11 +106,11 @@ class SourceDefinition: >>> def foo(a): ... return a >>> src_def = SourceDefinition.from_function(foo) - >>> print(src_def) - SourceDefinition(source='def foo(a):... starting_line=1) + >>> print(src_def) # doctest:+ELLIPSIS + SourceDefinition(source='def foo(a):...', filename='...', line_offset=0, column_offset=0) >>> source, filename, starting_line = src_def - >>> print(source) + >>> print(source) # doctest:+ELLIPSIS def foo(a): return a ... @@ -119,12 +118,13 @@ def foo(a): source: str filename: str = MISSING_FILENAME - starting_line: int = 1 + line_offset: int = 0 + column_offset: int = 0 def __iter__(self) -> Iterator: yield self.source yield self.filename - yield self.starting_line + yield self.line_offset from_function = staticmethod(make_source_definition_from_function) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index bdf96bc2fc..c81390ca41 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -17,7 +17,7 @@ import gt4py.next.ffront.type_specifications as ts_ffront import gt4py.next.type_system.type_specifications as ts -from gt4py.next.common import Dimension, GTTypeError +from gt4py.next.common import Dimension from gt4py.next.type_system import type_info @@ -51,7 +51,7 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec: if type_info.extract_dtype(param_el) == type_info.extract_dtype(arg_el): return param_el else: - raise GTTypeError(f"{arg_el} is not compatible with {param_el}.") + raise ValueError(f"{arg_el} is not compatible with {param_el}.") return arg_el return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True) @@ -217,7 +217,7 @@ def function_signature_incompatibilities_scanop( ] try: type_info.promote_dims(*arg_dims) - except GTTypeError as e: + except ValueError as e: yield e.args[0] assert len(scan_pass_type.pos_only_args) == 0 diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index ebc0921efe..e4ce2e9173 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -19,7 +19,7 @@ import numpy as np from gt4py.eve.utils import XIterable, xiter -from gt4py.next.common import Dimension, DimensionKind, GTTypeError +from gt4py.next.common import Dimension, DimensionKind from gt4py.next.type_system import type_specifications as ts @@ -50,15 +50,13 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: match symbol_type: case ts.DeferredType(constraint): if constraint is None: - raise GTTypeError(f"No type information available for {symbol_type}!") + raise ValueError(f"No type information available for {symbol_type}!") elif isinstance(constraint, tuple): - raise GTTypeError(f"Not sufficient type information available for {symbol_type}!") + raise ValueError(f"Not sufficient type information available for {symbol_type}!") return constraint case ts.TypeSpec() as concrete_type: return concrete_type.__class__ - raise GTTypeError( - f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!" - ) + raise ValueError(f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!") def primitive_constituents( @@ -140,7 +138,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: return dtype case ts.ScalarType() as dtype: return dtype - raise GTTypeError(f"Can not unambiguosly extract data type from {symbol_type}!") + raise ValueError(f"Can not unambiguosly extract data type from {symbol_type}!") def is_floating_point(symbol_type: ts.TypeSpec) -> bool: @@ -297,7 +295,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[Dimension]: return [] case ts.FieldType(dims): return dims - raise GTTypeError(f"Can not extract dimensions from {symbol_type}!") + raise ValueError(f"Can not extract dimensions from {symbol_type}!") def is_local_field(type_: ts.FieldType) -> bool: @@ -399,11 +397,11 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp ... ) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. """ if all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): - raise GTTypeError("Could not promote scalars of different dtype (not implemented).") + raise ValueError("Could not promote scalars of different dtype (not implemented).") if not all(type_.shape is None for type_ in types): # type: ignore[union-attr] raise NotImplementedError("Shape promotion not implemented.") return types[0] @@ -433,11 +431,11 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS Traceback (most recent call last): ... - gt4py.next.common.GTTypeError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. """ # build a graph with the vertices being dimensions and edges representing # the order between two dimensions. The graph is encoded as a dictionary @@ -470,7 +468,7 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: # TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]: if len(zero_in_degree_vertex_list) != 1: - raise GTTypeError( + raise ValueError( f"Dimensions can not be promoted. Could not determine " f"order of the following dimensions: " f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}." @@ -483,7 +481,7 @@ def promote_dims(*dims_list: list[Dimension]) -> list[Dimension]: in_degree[predecessor] -= 1 if len(in_degree.items()) > 0: - raise GTTypeError( + raise ValueError( f"Dimensions can not be promoted. The following dimensions " f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}." ) @@ -522,11 +520,11 @@ def return_type_field( ): try: accepts_args(field_type, with_args=with_args, with_kwargs=with_kwargs, raise_exception=True) - except GTTypeError as ex: - raise GTTypeError("Could not deduce return type of invalid remap operation.") from ex + except ValueError as ex: + raise ValueError("Could not deduce return type of invalid remap operation.") from ex if not isinstance(with_args[0], ts.OffsetType): - raise GTTypeError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") + raise ValueError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.") source_dim = with_args[0].source target_dims = with_args[0].target @@ -738,7 +736,7 @@ def accepts_args( """ Check if a function can be called for given arguments. - If ``raise_exception`` is given a :class:`GTTypeError` is raised with a + If ``raise_exception`` is given a :class:`ValueError` is raised with a detailed description of why the function is not callable. Note that all types must be concrete/complete. @@ -758,14 +756,14 @@ def accepts_args( """ if not isinstance(callable_type, ts.CallableType): if raise_exception: - raise GTTypeError(f"Expected a callable type, but got `{callable_type}`.") + raise ValueError(f"Expected a callable type, but got `{callable_type}`.") return False errors = function_signature_incompatibilities(callable_type, with_args, with_kwargs) if raise_exception: error_list = list(errors) if len(error_list) > 0: - raise GTTypeError( + raise ValueError( f"Invalid call to function of type `{callable_type}`:\n" + ("\n".join([f" - {error}" for error in error_list])) ) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 3d054c0746..39947db4af 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -37,7 +37,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: try: dt = np.dtype(dtype) except TypeError as err: - raise common.GTTypeError(f"Invalid scalar type definition ({dtype})") from err + raise ValueError(f"Invalid scalar type definition ({dtype})") from err if dt.shape == () and dt.fields is None: match dt: @@ -54,9 +54,9 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: case np.str_: return ts.ScalarKind.STRING case _: - raise common.GTTypeError(f"Impossible to map '{dtype}' value to a ScalarKind") + raise ValueError(f"Impossible to map '{dtype}' value to a ScalarKind") else: - raise common.GTTypeError(f"Non-trivial dtypes like '{dtype}' are not yet supported") + raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported") def from_type_hint( @@ -75,7 +75,7 @@ def from_type_hint( try: type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) except Exception as error: - raise TypingError( + raise ValueError( f"Type annotation ({type_hint}) has undefined forward references!" ) from error @@ -98,50 +98,50 @@ def from_type_hint( case builtins.tuple: if not args: - raise TypingError(f"Tuple annotation ({type_hint}) requires at least one argument!") + raise ValueError(f"Tuple annotation ({type_hint}) requires at least one argument!") if Ellipsis in args: - raise TypingError(f"Unbound tuples ({type_hint}) are not allowed!") + raise ValueError(f"Unbound tuples ({type_hint}) are not allowed!") return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) case common.Field: if (n_args := len(args)) != 2: - raise TypingError(f"Field type requires two arguments, got {n_args}! ({type_hint})") + raise ValueError(f"Field type requires two arguments, got {n_args}! ({type_hint})") dims: Union[Ellipsis, list[common.Dimension]] = [] dim_arg, dtype_arg = args if isinstance(dim_arg, list): for d in dim_arg: if not isinstance(d, common.Dimension): - raise TypingError(f"Invalid field dimension definition '{d}'") + raise ValueError(f"Invalid field dimension definition '{d}'") dims.append(d) elif dim_arg is Ellipsis: dims = dim_arg else: - raise TypingError(f"Invalid field dimensions '{dim_arg}'") + raise ValueError(f"Invalid field dimensions '{dim_arg}'") try: dtype = recursive_make_symbol(dtype_arg) - except TypingError as error: - raise TypingError( + except ValueError as error: + raise ValueError( f"Field dtype argument must be a scalar type (got '{dtype_arg}')!" ) from error if not isinstance(dtype, ts.ScalarType) or dtype.kind == ts.ScalarKind.STRING: - raise TypingError("Field dtype argument must be a scalar type (got '{dtype}')!") + raise ValueError("Field dtype argument must be a scalar type (got '{dtype}')!") return ts.FieldType(dims=dims, dtype=dtype) case collections.abc.Callable: if not args: - raise TypingError("Not annotated functions are not supported!") + raise ValueError("Not annotated functions are not supported!") try: arg_types, return_type = args args = [recursive_make_symbol(arg) for arg in arg_types] except Exception as error: - raise TypingError(f"Invalid callable annotations in {type_hint}") from error + raise ValueError(f"Invalid callable annotations in {type_hint}") from error kwargs_info = [arg for arg in extra_args if isinstance(arg, xtyping.CallableKwargsInfo)] if len(kwargs_info) != 1: - raise TypingError(f"Invalid callable annotations in {type_hint}") + raise ValueError(f"Invalid callable annotations in {type_hint}") kwargs = { arg: recursive_make_symbol(arg_type) for arg, arg_type in kwargs_info[0].data.items() @@ -155,7 +155,7 @@ def from_type_hint( returns=recursive_make_symbol(return_type), ) - raise TypingError(f"'{type_hint}' type is not supported") + raise ValueError(f"'{type_hint}' type is not supported") def from_value(value: Any) -> ts.TypeSpec: @@ -179,7 +179,7 @@ def from_value(value: Any) -> ts.TypeSpec: symbol_type = candidate_type break if not symbol_type: - raise common.GTTypeError( + raise ValueError( f"Value `{value}` is out of range to be representable as `INT32` or `INT64`." ) return candidate_type @@ -202,17 +202,4 @@ def from_value(value: Any) -> ts.TypeSpec: if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type else: - raise common.GTTypeError(f"Impossible to map `{value}` value to a Symbol") - - -# TODO(egparedes): Add source location info (maybe subclassing FieldOperatorSyntaxError) -class TypingError(common.GTTypeError): - def __init__( - self, - msg="", - *, - info=None, - ): - msg = f"Invalid type declaration: {msg}" - args = tuple([msg, info] if info else [msg]) - super().__init__(*args) + raise ValueError(f"Impossible to map '{value}' value to a Symbol") diff --git a/tests/eve_tests/definitions.py b/tests/eve_tests/definitions.py index b20fb28ddb..f4f0232ae4 100644 --- a/tests/eve_tests/definitions.py +++ b/tests/eve_tests/definitions.py @@ -297,7 +297,7 @@ def make_source_location(*, fixed: bool = False) -> SourceLocation: str_value = make_str_value(fixed=fixed) source = f"file_{str_value}.py" - return SourceLocation(line=line, column=column, source=source) + return SourceLocation(line=line, column=column, filename=source) def make_source_location_group(*, fixed: bool = False) -> SourceLocationGroup: @@ -472,7 +472,7 @@ def make_frozen_simple_node(*, fixed: bool = False) -> FrozenSimpleNode: # -- Makers of invalid nodes -- def make_invalid_location_node(*, fixed: bool = False) -> LocationNode: - return LocationNode(loc=SourceLocation(line=0, column=-1, source="")) + return LocationNode(loc=SourceLocation(line=0, column=-1, filename="")) def make_invalid_at_int_simple_node(*, fixed: bool = False) -> SimpleNode: diff --git a/tests/eve_tests/unit_tests/test_codegen.py b/tests/eve_tests/unit_tests/test_codegen.py index dcd81ef4fd..7e7ec2244e 100644 --- a/tests/eve_tests/unit_tests/test_codegen.py +++ b/tests/eve_tests/unit_tests/test_codegen.py @@ -106,7 +106,7 @@ def visit_IntKind(self, node, **kwargs): return f"ONE INTKIND({node.value})" def visit_SourceLocation(self, node, **kwargs): - return f"SourceLocation" + return f"SourceLocation" LocationNode = codegen.FormatTemplate("LocationNode {{{loc}}}") diff --git a/tests/eve_tests/unit_tests/test_concepts.py b/tests/eve_tests/unit_tests/test_concepts.py index 10cea03836..f01e79f626 100644 --- a/tests/eve_tests/unit_tests/test_concepts.py +++ b/tests/eve_tests/unit_tests/test_concepts.py @@ -46,49 +46,29 @@ class LettersOnlySymbol(SymbolName, regex=re.compile(r"[a-zA-Z]+$")): class TestSourceLocation: def test_valid_position(self): - eve.concepts.SourceLocation(line=1, column=1, source="source.py") + eve.concepts.SourceLocation(line=1, column=1, filename="source.py") def test_invalid_position(self): with pytest.raises(ValueError, match="column"): - eve.concepts.SourceLocation(line=1, column=-1, source="source.py") + eve.concepts.SourceLocation(line=1, column=-1, filename="source.py") def test_str(self): - loc = eve.concepts.SourceLocation(line=1, column=1, source="dir/source.py") + loc = eve.concepts.SourceLocation(line=1, column=1, filename="dir/source.py") assert str(loc) == "" - loc = eve.concepts.SourceLocation(line=1, column=1, source="dir/source.py", end_line=2) + loc = eve.concepts.SourceLocation(line=1, column=1, filename="dir/source.py", end_line=2) assert str(loc) == "" loc = eve.concepts.SourceLocation( - line=1, column=1, source="dir/source.py", end_line=2, end_column=2 + line=1, column=1, filename="dir/source.py", end_line=2, end_column=2 ) assert str(loc) == "" - def test_construction_from_ast(self): - import ast - - ast_node = ast.parse("a = b + 1").body[0] - loc = eve.concepts.SourceLocation.from_AST(ast_node, "source.py") - - assert loc.line == ast_node.lineno - assert loc.column == ast_node.col_offset + 1 - assert loc.source == "source.py" - assert loc.end_line == ast_node.end_lineno - assert loc.end_column == ast_node.end_col_offset + 1 - - loc = eve.concepts.SourceLocation.from_AST(ast_node) - - assert loc.line == ast_node.lineno - assert loc.column == ast_node.col_offset + 1 - assert loc.source == f"" - assert loc.end_line == ast_node.end_lineno - assert loc.end_column == ast_node.end_col_offset + 1 - class TestSourceLocationGroup: def test_valid_locations(self): - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source1.py") - loc2 = eve.concepts.SourceLocation(line=2, column=2, source="source2.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source1.py") + loc2 = eve.concepts.SourceLocation(line=2, column=2, filename="source2.py") eve.concepts.SourceLocationGroup(loc1) eve.concepts.SourceLocationGroup(loc1, loc2) eve.concepts.SourceLocationGroup(loc1, loc1, loc2, loc2, context="test context") @@ -96,13 +76,13 @@ def test_valid_locations(self): def test_invalid_locations(self): with pytest.raises(ValueError): eve.concepts.SourceLocationGroup() - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source.py") with pytest.raises(TypeError): eve.concepts.SourceLocationGroup(loc1, "loc2") def test_str(self): - loc1 = eve.concepts.SourceLocation(line=1, column=1, source="source1.py") - loc2 = eve.concepts.SourceLocation(line=2, column=2, source="source2.py") + loc1 = eve.concepts.SourceLocation(line=1, column=1, filename="source1.py") + loc2 = eve.concepts.SourceLocation(line=2, column=2, filename="source2.py") loc = eve.concepts.SourceLocationGroup(loc1, loc2, context="some context") assert str(loc) == "<#some context#[, ]>" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index b8f47a5770..cca38b3b98 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -18,10 +18,10 @@ import numpy as np import pytest +from gt4py.next import errors from gt4py.next.common import Field from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import int32, int64 -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases @@ -240,7 +240,7 @@ def testee( def test_scan_wrong_return_type(cartesian_case): with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"), ): @@ -257,7 +257,7 @@ def testee(qc: cases.IKFloatField, param_1: int32, param_2: float, scalar: float def test_scan_wrong_state_type(cartesian_case): with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=( r"Argument `init` to scan operator `testee_scan` must have same type as `state` argument" ), diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index c9ac10abf0..2e136c046d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -22,6 +22,7 @@ from gt4py.next import ( astype, broadcast, + errors, float32, float64, int32, @@ -32,7 +33,6 @@ where, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases @@ -914,7 +914,7 @@ def fieldop_where_k_offset( def test_undefined_symbols(cartesian_case): - with pytest.raises(FieldOperatorTypeDeductionError, match="Undeclared symbol"): + with pytest.raises(errors.DSLError, match="Undeclared symbol"): @gtx.field_operator(backend=cartesian_case.backend) def return_undefined(): @@ -1017,7 +1017,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), ): @@ -1028,9 +1028,7 @@ def _star_unpack() -> tuple[int32, float64, int32]: def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises( - FieldOperatorTypeDeductionError, match=(r"Assignment value must be of type tuple!") - ): + with pytest.raises(errors.DSLError, match=(r"Assignment value must be of type tuple!")): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 28ac99108c..b484fc6f31 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -62,7 +62,7 @@ def make_builtin_field_operator(builtin_name: str): closure_vars = {"IDim": IDim, builtin_name: getattr(fbuiltins, builtin_name)} - loc = foast.SourceLocation(line=1, column=1, source="none") + loc = foast.SourceLocation(line=1, column=1, filename="none") params = [ foast.Symbol(id=k, type=type_translation.from_type_hint(type), location=loc) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 18b1fab906..f12e10fcab 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -20,8 +20,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.common import GTTypeError -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeError from gt4py.next.program_processors.runners import dace_iterator from next_tests.integration_tests import cases @@ -229,7 +227,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): inp = gtx.np_as_located_field(JDim)(np.ones((cartesian_case.default_sizes[JDim],))) out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))() - with pytest.raises(ProgramTypeError) as exc_info: + with pytest.raises(TypeError) as exc_info: # program is defined on Field[[IDim], ...], but we call with # Field[[JDim], ...] copy_program(inp, out, offset_provider={}) @@ -255,7 +253,7 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): out_field = cases.allocate(cartesian_case, empty_domain_program, "out_field")() with pytest.raises( - GTTypeError, + ValueError, match=(r"Dimensions in out field and field domain are not equivalent"), ): empty_domain_program(a, out_field, offset_provider={}) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 768cf2f9a0..f44b662f22 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -18,8 +18,7 @@ import numpy as np import pytest -from gt4py.next import Field, field_operator, float64, index_field, np_as_located_field -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError +from gt4py.next import Field, errors, field_operator, float64, index_field, np_as_located_field from gt4py.next.program_processors.runners import dace_iterator, gtfn_cpu from next_tests.integration_tests import cases @@ -401,7 +400,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(FieldOperatorTypeDeductionError, match="Condition for `if` must be scalar."): + with pytest.raises(errors.DSLError, match="Condition for `if` must be scalar."): @field_operator def if_non_scalar_condition( @@ -414,9 +413,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises( - FieldOperatorTypeDeductionError, match="Condition for `if` must be of boolean type." - ): + with pytest.raises(errors.DSLError, match="Condition for `if` must be of boolean type."): @field_operator def if_non_boolean_condition( @@ -431,7 +428,7 @@ def if_non_boolean_condition( def test_if_inconsistent_types(): with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match="Inconsistent types between two branches for variable", ): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index 73509fdd17..c5013d1fa5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -24,6 +24,7 @@ FieldOffset, astype, broadcast, + errors, float32, float64, int32, @@ -31,10 +32,8 @@ neighbor_sum, where, ) -from gt4py.next.common import GTTypeError from gt4py.next.ffront.ast_passes import single_static_assign as ssa from gt4py.next.ffront.experimental import as_offset -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts @@ -420,7 +419,7 @@ def test_accept_args( if len(expected) > 0: with pytest.raises( - GTTypeError, + ValueError, ) as exc_info: type_info.accepts_args( func_type, with_args=args, with_kwargs=kwargs, raise_exception=True @@ -535,7 +534,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): return a + b with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -550,7 +549,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): return a + b with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=( r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." ), @@ -563,8 +562,8 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): return a & b with pytest.raises( - FieldOperatorTypeDeductionError, - match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`! "), + errors.DSLError, + match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -574,7 +573,7 @@ def sign_bool(a: Field[[TDim], bool]): return -a with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -585,7 +584,7 @@ def not_int(a: Field[[TDim], int64]): return not a with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -657,7 +656,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: return float32("1.0") + float64("1.0") with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=(r"Could not promote `float32` and `float64` to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -687,7 +686,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Expected broadcast dimension is missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -702,7 +701,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): return broadcast(a, (BDim, CDim)) with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Expected all broadcast dimensions to be of type Dimension.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -768,7 +767,7 @@ def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): return where(a, ((5.0, 9.0), (b, 6.0)), b) with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Return arguments need to be of same type", ): _ = FieldOperatorParser.apply_to_function(bad_dim_where) @@ -823,7 +822,7 @@ def modulo_floats(inp: Field[[TDim], float]): return inp % 3.0 with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=r"Type float64 can not be used in operator `%`", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -833,7 +832,7 @@ def test_undefined_symbols(): def return_undefined(): return undefined_symbol - with pytest.raises(FieldOperatorTypeDeductionError, match="Undeclared symbol"): + with pytest.raises(errors.DSLError, match="Undeclared symbol"): _ = FieldOperatorParser.apply_to_function(return_undefined) @@ -846,7 +845,7 @@ def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): return a(as_offset(Boff, b)) with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=f"not in list of offset field dimensions", ): _ = FieldOperatorParser.apply_to_function(as_offset_dim) @@ -861,7 +860,7 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): return a(as_offset(Boff, b)) with pytest.raises( - FieldOperatorTypeDeductionError, + errors.DSLError, match=f"Excepted integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index 2e1f6e3ff2..e3cecfa88f 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import common +from gt4py.next import errors from gt4py.next.program_processors.runners import roundtrip from next_tests.integration_tests import cases @@ -87,7 +87,7 @@ def test_verify_fails_with_wrong_type(cartesian_case): # noqa: F811 # fixtures b = cases.allocate(cartesian_case, addition, "b")() out = cases.allocate(cartesian_case, addition, cases.RETURN)() - with pytest.raises(common.GTTypeError): + with pytest.raises(errors.DSLError): cases.verify(cartesian_case, addition, a, b, out=out, ref=a.array() + b.array()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py index 14df97c523..7de56cb5bb 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/test_driver.py @@ -14,6 +14,7 @@ import os import subprocess +import sys from pathlib import Path import pytest @@ -31,7 +32,7 @@ def _execute_cmake(backend_str: str): build_dir = _build_dir(backend_str) build_dir.mkdir(exist_ok=True) cmake = ["cmake", "-B", build_dir, f"-DBACKEND={backend_str}"] - subprocess.run(cmake, cwd=_source_dir(), check=True) + subprocess.run(cmake, cwd=_source_dir(), check=True, stderr=sys.stderr) def _get_available_cpu_count(): diff --git a/tests/next_tests/unit_tests/errors_tests/test_excepthook.py b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py new file mode 100644 index 0000000000..526844d730 --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_excepthook.py @@ -0,0 +1,57 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os + +from gt4py import eve +from gt4py.next import errors +from gt4py.next.errors import excepthook + + +def test_format_uncaught_error(): + try: + loc = eve.SourceLocation("/src/file.py", 1, 1) + msg = "compile error msg" + raise errors.exceptions.DSLError(loc, msg) from ValueError("value error msg") + except errors.exceptions.DSLError as err: + str_devmode = "".join(errors.excepthook._format_uncaught_error(err, True)) + assert str_devmode.find("Source location") >= 0 + assert str_devmode.find("Traceback") >= 0 + assert str_devmode.find("cause") >= 0 + assert str_devmode.find("ValueError") >= 0 + str_usermode = "".join(errors.excepthook._format_uncaught_error(err, False)) + assert str_usermode.find("Source location") >= 0 + assert str_usermode.find("Traceback") < 0 + assert str_usermode.find("cause") < 0 + assert str_usermode.find("ValueError") < 0 + + +def test_get_verbose_exceptions(): + env_var_name = "GT4PY_VERBOSE_EXCEPTIONS" + + # Make sure to save and restore the environment variable, we don't want to + # affect other tests running in the same process. + saved = os.environ.get(env_var_name, None) + try: + os.environ[env_var_name] = "False" + assert excepthook._get_verbose_exceptions_envvar() is False + os.environ[env_var_name] = "True" + assert excepthook._get_verbose_exceptions_envvar() is True + os.environ[env_var_name] = "invalid value" # Should emit a warning too + assert excepthook._get_verbose_exceptions_envvar() is False + del os.environ[env_var_name] + assert excepthook._get_verbose_exceptions_envvar() is False + finally: + if saved is not None: + os.environ[env_var_name] = saved diff --git a/tests/next_tests/unit_tests/errors_tests/test_exceptions.py b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py new file mode 100644 index 0000000000..60a382d989 --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_exceptions.py @@ -0,0 +1,71 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import inspect +import re + +import pytest + +from gt4py.eve import SourceLocation +from gt4py.next import errors + + +@pytest.fixture +def loc_snippet(): + frameinfo = inspect.getframeinfo(inspect.currentframe()) + # This very line of comment should be shown in the snippet. + return SourceLocation( + frameinfo.filename, frameinfo.lineno + 1, 15, end_line=frameinfo.lineno + 1, end_column=29 + ) + + +@pytest.fixture +def loc_plain(): + return SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) + + +@pytest.fixture +def message(): + return "a message" + + +def test_message(loc_plain, message): + assert errors.DSLError(loc_plain, message).message == message + + +def test_location(loc_plain, message): + assert errors.DSLError(loc_plain, message).location == loc_plain + + +def test_with_location(loc_plain, message): + assert errors.DSLError(None, message).with_location(loc_plain).location == loc_plain + + +def test_str(loc_plain, message): + pattern = f'{message}\\n File ".*", line.*' + s = str(errors.DSLError(loc_plain, message)) + assert re.match(pattern, s) + + +def test_str_snippet(loc_snippet, message): + pattern = r"\n".join( + [ + f"{message}", + ' File ".*", line.*', + " # This very line of comment should be shown in the snippet.", + r" \^\^\^\^\^\^\^\^\^\^\^\^\^\^", + ] + ) + s = str(errors.DSLError(loc_snippet, message)) + assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/errors_tests/test_formatting.py b/tests/next_tests/unit_tests/errors_tests/test_formatting.py new file mode 100644 index 0000000000..ebb6cf9a37 --- /dev/null +++ b/tests/next_tests/unit_tests/errors_tests/test_formatting.py @@ -0,0 +1,85 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import re + +import pytest + +from gt4py.eve import SourceLocation +from gt4py.next import errors +from gt4py.next.errors.formatting import format_compilation_error + + +@pytest.fixture +def message(): + return "a message" + + +@pytest.fixture +def location(): + return SourceLocation("/source/file.py", 5, 2, end_line=5, end_column=9) + + +@pytest.fixture +def tb(): + try: + raise Exception() + except Exception as ex: + return ex.__traceback__ + + +@pytest.fixture +def type_(): + return errors.DSLError + + +@pytest.fixture +def qualname(type_): + return f"{type_.__module__}.{type_.__name__}" + + +def test_format(type_, qualname, message): + cls_pattern = f"{qualname}: {message}" + s = "\n".join(format_compilation_error(type_, message, None, None, None)) + assert re.match(cls_pattern, s) + + +def test_format_loc(type_, qualname, message, location): + loc_pattern = "Source location.*" + file_pattern = ' File "/source.*".*' + cls_pattern = f"{qualname}: {message}" + pattern = r"\n".join([loc_pattern, file_pattern, cls_pattern]) + s = "".join(format_compilation_error(type_, message, location, None, None)) + assert re.match(pattern, s) + + +def test_format_traceback(type_, qualname, message, tb): + tb_pattern = "Traceback.*" + file_pattern = ' File ".*".*' + line_pattern = ".*" + cls_pattern = f"{qualname}: {message}" + pattern = r"\n".join([tb_pattern, file_pattern, line_pattern, cls_pattern]) + s = "".join(format_compilation_error(type_, message, None, tb, None)) + assert re.match(pattern, s) + + +def test_format_cause(type_, qualname, message): + cause = ValueError("asd") + blank_pattern = "" + cause_pattern = "ValueError: asd" + bridge_pattern = "The above.*" + cls_pattern = f"{qualname}: {message}" + pattern = r"\n".join([cause_pattern, blank_pattern, bridge_pattern, blank_pattern, cls_pattern]) + s = "".join(format_compilation_error(type_, message, None, None, cause)) + assert re.match(pattern, s) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py index 7b2d796b29..75e23545de 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_decorator_domain_deduction.py @@ -15,7 +15,6 @@ import pytest import gt4py.next as gtx -from gt4py.next.common import GTTypeError from gt4py.next.ffront.decorator import _deduce_grid_type @@ -38,7 +37,7 @@ def test_domain_deduction_unstructured(): def test_domain_complies_with_request_cartesian(): assert _deduce_grid_type(gtx.GridType.CARTESIAN, {CartesianOffset}) == gtx.GridType.CARTESIAN - with pytest.raises(GTTypeError, match="unstructured.*FieldOffset.*found"): + with pytest.raises(ValueError, match="unstructured.*FieldOffset.*found"): _deduce_grid_type(gtx.GridType.CARTESIAN, {UnstructuredOffset}) _deduce_grid_type(gtx.GridType.CARTESIAN, {LocalDim}) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index 3f6e5ef5e6..e5bbed19fd 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -41,15 +41,12 @@ import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import astype, broadcast, float32, float64, int32, int64, where -from gt4py.next.common import GTTypeError +from gt4py.next import astype, broadcast, errors, float32, float64, int32, int64, where from gt4py.next.ffront import field_operator_ast as foast from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeductionError -from gt4py.next.ffront.func_to_foast import FieldOperatorParser, FieldOperatorSyntaxError +from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.iterator import builtins as itb, ir as itir from gt4py.next.type_system import type_specifications as ts -from gt4py.next.type_system.type_translation import TypingError DEREF = itir.SymRef(id=itb.deref.fun.__name__) @@ -79,10 +76,7 @@ def test_untyped_arg(): def untyped(inp): return inp - with pytest.raises( - FieldOperatorSyntaxError, - match="Untyped parameters not allowed!", - ): + with pytest.raises(errors.MissingParameterAnnotationError): _ = FieldOperatorParser.apply_to_function(untyped) @@ -93,7 +87,7 @@ def mistyped(inp: gtx.Field): return inp with pytest.raises( - TypingError, + ValueError, match="Field type requires two arguments, got 0!", ): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -120,8 +114,8 @@ def no_return(inp: gtx.Field[[TDim], "float64"]): tmp = inp # noqa with pytest.raises( - FieldOperatorSyntaxError, - match="Function must return a value, but no return statement was found\.", + errors.DSLError, + match=".*return.*", ): _ = FieldOperatorParser.apply_to_function(no_return) @@ -136,7 +130,7 @@ def invalid_assign_to_expr( tmp[-1] = inp2 return tmp - with pytest.raises(FieldOperatorSyntaxError, match=r"Can only assign to names! \(.*\)"): + with pytest.raises(errors.DSLError, match=r".*assign.*"): _ = FieldOperatorParser.apply_to_function(invalid_assign_to_expr) @@ -162,7 +156,7 @@ def clashing(inp: gtx.Field[[TDim], "float64"]): tmp: gtx.Field[[TDim], "int64"] = inp return tmp - with pytest.raises(FieldOperatorTypeDeductionError, match="type inconsistency"): + with pytest.raises(errors.DSLError, match="type inconsistency"): _ = FieldOperatorParser.apply_to_function(clashing) @@ -190,24 +184,24 @@ def modulo(inp: gtx.Field[[TDim], "int32"]): ) -def test_bool_and(): +def test_boolean_and_op_unsupported(): def bool_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a and b with pytest.raises( - FieldOperatorSyntaxError, - match=(r"`and`/`or` operator not allowed!"), + errors.UnsupportedPythonFeatureError, + match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_and) -def test_bool_or(): +def test_boolean_or_op_unsupported(): def bool_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): return a or b with pytest.raises( - FieldOperatorSyntaxError, - match=(r"`and`/`or` operator not allowed!"), + errors.UnsupportedPythonFeatureError, + match=r".*and.*or.*", ): _ = FieldOperatorParser.apply_to_function(bool_or) @@ -236,12 +230,12 @@ def unary_tilde(a: gtx.Field[[TDim], "bool"]): ) -def test_scalar_cast(): +def test_scalar_cast_disallow_non_literals(): def cast_scalar_temp(): tmp = int64(1) return int32(tmp) - with pytest.raises(FieldOperatorSyntaxError, match=(r"only takes literal arguments!")): + with pytest.raises(errors.DSLError, match=r".*literal.*"): _ = FieldOperatorParser.apply_to_function(cast_scalar_temp) @@ -252,7 +246,7 @@ def conditional_wrong_mask_type( return where(a, a, a) msg = r"Expected a field with dtype `bool`." - with pytest.raises(FieldOperatorTypeDeductionError, match=msg): + with pytest.raises(errors.DSLError, match=msg): _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type) @@ -265,7 +259,7 @@ def conditional_wrong_arg_type( return where(mask, a, b) msg = r"Could not promote scalars of different dtype \(not implemented\)." - with pytest.raises(FieldOperatorTypeDeductionError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(conditional_wrong_arg_type) assert re.search(msg, exc_info.value.__cause__.args[0]) is not None @@ -275,7 +269,7 @@ def test_ternary_with_field_condition(): def ternary_with_field_condition(cond: gtx.Field[[], bool]): return 1 if cond else 2 - with pytest.raises(FieldOperatorTypeDeductionError, match=r"should be .* `bool`"): + with pytest.raises(errors.DSLError, match=r"should be .* `bool`"): _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition) @@ -294,7 +288,7 @@ def test_adr13_wrong_return_type_annotation(): def wrong_return_type_annotation() -> gtx.Field[[], float]: return 1.0 - with pytest.raises(GTTypeError, match=r"Expected `float.*`"): + with pytest.raises(errors.DSLError, match=r"Expected `float.*`"): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -376,7 +370,7 @@ def wrong_return_type_annotation(a: gtx.Field[[ADim], float64]) -> gtx.Field[[BD return a with pytest.raises( - GTTypeError, + errors.DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation) @@ -387,7 +381,7 @@ def empty_dims() -> gtx.Field[[], float]: return 1.0 with pytest.raises( - GTTypeError, + errors.DSLError, match=r"Annotated return type does not match deduced return type", ): _ = FieldOperatorParser.apply_to_function(empty_dims) @@ -401,8 +395,8 @@ def zero_dims_ternary( ): return a if cond == 1 else b - msg = r"Could not deduce type" - with pytest.raises(FieldOperatorTypeDeductionError) as exc_info: + msg = r"Incompatible datatypes in operator `==`" + with pytest.raises(errors.DSLError) as exc_info: _ = FieldOperatorParser.apply_to_function(zero_dims_ternary) assert re.search(msg, exc_info.value.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py index 000351c611..123d57baf1 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast_error_line_number.py @@ -18,6 +18,7 @@ import pytest import gt4py.next as gtx +from gt4py.next import errors from gt4py.next.ffront import func_to_foast as f2f, source_utils as src_utils from gt4py.next.ffront.foast_passes import type_deduction @@ -36,68 +37,24 @@ def wrong_syntax(inp: gtx.Field[[TDim], float]): return # <-- this line triggers the syntax error with pytest.raises( - f2f.FieldOperatorSyntaxError, - match=( - r"Invalid Field Operator Syntax: " - r"Empty return not allowed \(test_func_to_foast_error_line_number.py, line " - + str(line + 3) - + r"\)" - ), + f2f.errors.DSLError, + match=(r".*return.*"), ) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(wrong_syntax) - assert traceback.format_exception_only(exc_info.value)[1:3] == [ - " return # <-- this line triggers the syntax error\n", - " ^^^^^^\n", - ] - - -def test_wrong_caret_placement_bug(): - """Field operator syntax errors respect python's carets (`^^^^^`) placement.""" - - line = inspect.getframeinfo(inspect.currentframe()).lineno - - def wrong_line_syntax_error(inp: gtx.Field[[TDim], float]): - # the next line triggers the syntax error - inp = inp.this_attribute_surely_doesnt_exist - - return inp - - with pytest.raises(f2f.FieldOperatorSyntaxError) as exc_info: - _ = f2f.FieldOperatorParser.apply_to_function(wrong_line_syntax_error) - - exc = exc_info.value - - assert (exc.lineno, exc.end_lineno) == (line + 4, line + 4) - - # if `offset` is set, python will display carets (`^^^^`) after printing `text`. - # So `text` has to be the line where the error occurs (otherwise the carets - # will be very misleading). - - # See https://github.com/python/cpython/blob/6ad47b41a650a13b4a9214309c10239726331eb8/Lib/traceback.py#L852-L855 - python_printed_text = exc.text.rstrip("\n").lstrip(" \n\f") - - assert python_printed_text == "inp = inp.this_attribute_surely_doesnt_exist" - - # test that `offset` is aligned with `exc.text` - return_offset = ( - exc.text.find("inp.this_attribute_surely_doesnt_exist") + 1 - ) # offset is 1-based for syntax errors - assert (exc.offset, exc.end_offset) == (return_offset, return_offset + 38) - - print("".join(traceback.format_exception_only(exc))) - - assert traceback.format_exception_only(exc)[1:3] == [ - " inp = inp.this_attribute_surely_doesnt_exist\n", - " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == line + 3 + assert exc_info.value.location.end_line == line + 3 + assert exc_info.value.location.column == 9 + assert exc_info.value.location.end_column == 15 def test_syntax_error_without_function(): """Dialect parsers report line numbers correctly when applied to `SourceDefinition`.""" source_definition = src_utils.SourceDefinition( - starting_line=62, + line_offset=61, source=""" def invalid_python_syntax(): # This function contains a python syntax error @@ -106,17 +63,15 @@ def invalid_python_syntax(): """, ) - with pytest.raises(SyntaxError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply(source_definition, {}, {}) - exc = exc_info.value - - assert (exc.lineno, exc.end_lineno) == (66, 66) - - assert traceback.format_exception_only(exc)[1:3] == [ - " ret%% # <-- this line triggers the syntax error\n", - " ^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == 66 + assert exc_info.value.location.end_line == 66 + assert exc_info.value.location.column == 9 + assert exc_info.value.location.end_column == 10 def test_fo_type_deduction_error(): @@ -127,17 +82,17 @@ def test_fo_type_deduction_error(): def field_operator_with_undeclared_symbol(): return undeclared_symbol # noqa: F821 # undefined on purpose - with pytest.raises(type_deduction.FieldOperatorTypeDeductionError) as exc_info: + with pytest.raises(errors.DSLError) as exc_info: _ = f2f.FieldOperatorParser.apply_to_function(field_operator_with_undeclared_symbol) exc = exc_info.value - assert (exc.lineno, exc.end_lineno) == (line + 3, line + 3) - - assert traceback.format_exception_only(exc)[1:3] == [ - " return undeclared_symbol # noqa: F821 # undefined on purpose\n", - " ^^^^^^^^^^^^^^^^^\n", - ] + assert exc_info.value.location + assert exc_info.value.location.filename.find("test_func_to_foast_error_line_number.py") + assert exc_info.value.location.line == line + 3 + assert exc_info.value.location.end_line == line + 3 + assert exc_info.value.location.column == 16 + assert exc_info.value.location.end_column == 33 # TODO: test program type deduction? diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py index f1070c93c2..6e617f77a2 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py @@ -19,11 +19,9 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import float64 -from gt4py.next.common import GTTypeError +from gt4py.next import errors, float64 from gt4py.next.ffront import program_ast as past from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeError from gt4py.next.type_system import type_specifications as ts from next_tests.past_common_fixtures import ( @@ -114,7 +112,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]): identity(in_field, out=out_field) # noqa: F821 # undefined on purpose with pytest.raises( - ProgramTypeError, + errors.DSLError, match=(r"Undeclared or untyped symbol `out_field`."), ): ProgramParser.apply_to_function(undefined_field_program) @@ -163,7 +161,7 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]): domain_format_1(in_field, out=in_field, domain=(0, 2)) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_1_program) @@ -182,7 +180,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]): domain_format_2(in_field, out=in_field, domain={IDim: (0, 1, 2)}) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_2_program) @@ -201,7 +199,7 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]): domain_format_3(in_field, domain={IDim: (0, 2)}) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_3_program) @@ -222,7 +220,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]): ) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_4_program) @@ -241,7 +239,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): domain_format_5(in_field, out=in_field, domain={IDim: ("1.0", 9.0)}) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) @@ -260,7 +258,7 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): domain_format_6(in_field, out=in_field, domain={}) with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index ff3abcf266..e56dc85322 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -19,7 +19,7 @@ import gt4py.eve as eve import gt4py.next as gtx from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next.common import GTTypeError +from gt4py.next import errors from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering from gt4py.next.iterator import ir as itir @@ -157,7 +157,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): identity(inout_field, out=inout_field) with pytest.raises( - GTTypeError, + ValueError, match=(r"Call to function with field as input and output not allowed."), ): ProgramLowering.apply( @@ -169,7 +169,7 @@ def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): def test_invalid_call_sig_program(invalid_call_sig_program_def): with pytest.raises( - GTTypeError, + errors.DSLError, ) as exc_info: ProgramLowering.apply( ProgramParser.apply_to_function(invalid_call_sig_program_def), diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index 59f53a3050..d281f5cd90 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -50,7 +50,7 @@ def test_valid_scalar_kind(value, expected): def test_invalid_scalar_kind(): - with pytest.raises(common.GTTypeError, match="Non-trivial dtypes"): + with pytest.raises(ValueError, match="Non-trivial dtypes"): type_translation.get_scalar_kind(np.dtype("i4, (2,3)f8, f4")) @@ -130,42 +130,40 @@ def test_make_symbol_type_from_typing(value, expected): def test_invalid_symbol_types(): # Forward references - with pytest.raises(type_translation.TypingError, match="undefined forward references"): + with pytest.raises(ValueError, match="undefined forward references"): type_translation.from_type_hint("foo") # Tuples - with pytest.raises(type_translation.TypingError, match="least one argument"): + with pytest.raises(ValueError, match="least one argument"): type_translation.from_type_hint(typing.Tuple) - with pytest.raises(type_translation.TypingError, match="least one argument"): + with pytest.raises(ValueError, match="least one argument"): type_translation.from_type_hint(tuple) - with pytest.raises(type_translation.TypingError, match="Unbound tuples"): + with pytest.raises(ValueError, match="Unbound tuples"): type_translation.from_type_hint(tuple[int, ...]) - with pytest.raises(type_translation.TypingError, match="Unbound tuples"): + with pytest.raises(ValueError, match="Unbound tuples"): type_translation.from_type_hint(typing.Tuple["float", ...]) # Fields - with pytest.raises(type_translation.TypingError, match="Field type requires two arguments"): - type_translation.from_type_hint(gtx.Field) - with pytest.raises(type_translation.TypingError, match="Invalid field dimensions"): - type_translation.from_type_hint(gtx.Field[int, int]) - with pytest.raises(type_translation.TypingError, match="Invalid field dimension"): - type_translation.from_type_hint(gtx.Field[[int, int], int]) - - with pytest.raises(type_translation.TypingError, match="Field dtype argument"): - type_translation.from_type_hint(gtx.Field[[IDim], str]) - with pytest.raises(type_translation.TypingError, match="Field dtype argument"): - type_translation.from_type_hint(gtx.Field[[IDim], None]) + with pytest.raises(ValueError, match="Field type requires two arguments"): + type_translation.from_type_hint(common.Field) + with pytest.raises(ValueError, match="Invalid field dimensions"): + type_translation.from_type_hint(common.Field[int, int]) + with pytest.raises(ValueError, match="Invalid field dimension"): + type_translation.from_type_hint(common.Field[[int, int], int]) + + with pytest.raises(ValueError, match="Field dtype argument"): + type_translation.from_type_hint(common.Field[[IDim], str]) + with pytest.raises(ValueError, match="Field dtype argument"): + type_translation.from_type_hint(common.Field[[IDim], None]) # Functions - with pytest.raises( - type_translation.TypingError, match="Not annotated functions are not supported" - ): + with pytest.raises(ValueError, match="Not annotated functions are not supported"): type_translation.from_type_hint(typing.Callable) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[..., float]) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[[int], str]) - with pytest.raises(type_translation.TypingError, match="Invalid callable annotations"): + with pytest.raises(ValueError, match="Invalid callable annotations"): type_translation.from_type_hint(typing.Callable[[int], float])