Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing in icon4pytools #267

Merged
merged 15 commits into from
Sep 21, 2023
1 change: 1 addition & 0 deletions base-requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ setuptools>=40.8.0
wheel>=0.37.1
tox >= 3.25
wget>=3.2
types-cffi>=1.15
4 changes: 2 additions & 2 deletions model/common/src/icon4py/model/common/grid/vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np
from gt4py.next.ffront.fbuiltins import int32

from gt4py.next import common
from icon4py.model.common.dimension import KDim


Expand All @@ -34,7 +34,7 @@ class VerticalModelParams:
rayleigh_damping_height: height of rayleigh damping in [m] mo_nonhydro_nml
"""

vct_a: Field[[KDim], float]
vct_a: common.Field
rayleigh_damping_height: Final[float]
index_of_damping_layer: Final[int32] = field(init=False)

Expand Down
4 changes: 3 additions & 1 deletion tools/.flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ extend-ignore =
# Line too long (using Bugbear's B950 warning)
E501,
# Line break occurred before a binary operator
W503
W503,
# Calling setattr with a constant attribute value
B010

exclude =
.eggs,
Expand Down
72 changes: 28 additions & 44 deletions tools/src/icon4pytools/f2ser/deserialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,40 +31,32 @@ def __init__(
self.parsed = parsed
self.directory = directory
self.prefix = prefix
self.data = {"Savepoint": [], "Init": ..., "Import": ...}

def __call__(self) -> SerialisationCodeInterface:
"""Deserialise the parsed granule and returns a serialisation interface.

Returns:
A `SerialisationInterface` object representing the deserialised data.
"""
"""Deserialise the parsed granule and returns a serialisation interface."""
self._merge_out_inout_fields()
self._make_savepoints()
self._make_init_data()
self._make_imports()
return SerialisationCodeInterface(**self.data)
savepoints = self._make_savepoints()
init_data = self._make_init_data()
import_data = self._make_imports()
return SerialisationCodeInterface(Import=import_data, Init=init_data, Savepoint=savepoints)

def _make_savepoints(self) -> None:
"""Create savepoints for each subroutine and intent in the parsed granule.
def _make_savepoints(self) -> list[SavepointData]:
"""Create savepoints for each subroutine and intent in the parsed granule."""
savepoints: list[SavepointData] = []

Returns:
None.
"""
for subroutine_name, intent_dict in self.parsed.subroutines.items():
for intent, var_dict in intent_dict.items():
self._create_savepoint(subroutine_name, intent, var_dict)
savepoints.append(self._create_savepoint(subroutine_name, intent, var_dict))

def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> None:
return savepoints

def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> SavepointData:
"""Create a savepoint for the given variables.

Args:
subroutine_name: The name of the subroutine.
intent: The intent of the fields to be serialised.
var_dict: A dictionary representing the variables to be saved.

Returns:
None.
"""
field_vals = {k: v for k, v in var_dict.items() if isinstance(v, dict)}
fields = [
Expand All @@ -80,14 +72,12 @@ def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -
for var_name, var_data in field_vals.items()
]

self.data["Savepoint"].append(
SavepointData(
subroutine=subroutine_name,
intent=intent,
startln=self._get_codegen_line(var_dict["codegen_ctx"], intent),
fields=fields,
metadata=None,
)
return SavepointData(
subroutine=subroutine_name,
intent=intent,
startln=self._get_codegen_line(var_dict["codegen_ctx"], intent),
fields=fields,
metadata=None,
)

@staticmethod
Expand Down Expand Up @@ -123,39 +113,33 @@ def _create_association(self, var_data: dict, var_name: str) -> str:
)
return var_name

def _make_init_data(self) -> None:
"""Create an `InitData` object and sets it to the `Init` key in the `data` dictionary.

Returns:
None.
"""
def _make_init_data(self) -> InitData:
"""Create an `InitData` object and sets it to the `Init` key in the `data` dictionary."""
first_intent_in_subroutine = [
var_dict
for intent_dict in self.parsed.subroutines.values()
for intent, var_dict in intent_dict.items()
if intent == "in"
][0]

startln = self._get_codegen_line(first_intent_in_subroutine["codegen_ctx"], "init")
self.data["Init"] = InitData(

return InitData(
startln=startln,
directory=self.directory,
prefix=self.prefix,
)

def _merge_out_inout_fields(self):
"""Merge the `inout` fields into the `in` and `out` fields in the `parsed` dictionary.

Returns:
None.
"""
def _merge_out_inout_fields(self) -> None:
"""Merge the `inout` fields into the `in` and `out` fields in the `parsed` dictionary."""
for _, intent_dict in self.parsed.subroutines.items():
if "inout" in intent_dict:
intent_dict["in"].update(intent_dict["inout"])
intent_dict["out"].update(intent_dict["inout"])
del intent_dict["inout"]

@staticmethod
def _get_codegen_line(ctx: CodegenContext, intent: str):
def _get_codegen_line(ctx: CodegenContext, intent: str) -> int:
if intent == "in":
return ctx.last_declaration_ln
elif intent == "out":
Expand All @@ -165,5 +149,5 @@ def _get_codegen_line(ctx: CodegenContext, intent: str):
else:
raise ValueError(f"Unrecognized intent: {intent}")

def _make_imports(self):
self.data["Import"] = ImportData(startln=self.parsed.last_import_ln)
def _make_imports(self) -> ImportData:
return ImportData(startln=self.parsed.last_import_ln)
31 changes: 13 additions & 18 deletions tools/src/icon4pytools/f2ser/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CodegenContext:
end_subroutine_ln: int


ParsedSubroutines = dict[str, dict[str, dict[str, any] | CodegenContext]]
ParsedSubroutines = dict[str, dict[str, dict[str, Any]]]


@dataclass
Expand Down Expand Up @@ -69,36 +69,34 @@ def __call__(self) -> ParsedGranule:
last_import_ln = self._find_last_fortran_use_statement()
return ParsedGranule(subroutines=subroutines, last_import_ln=last_import_ln)

def _find_last_fortran_use_statement(self) -> Optional[int]:
def _find_last_fortran_use_statement(self) -> int:
"""Find the line number of the last Fortran USE statement in the code.

Returns:
int: the line number of the last USE statement, or None if no USE statement is found.
int: the line number of the last USE statement.
"""
# Reverse the order of the lines so we can search from the end
code = self._read_code_from_file()
code_lines = code.splitlines()
code_lines.reverse()

# Look for the last USE statement
use_ln = None
for i, line in enumerate(code_lines):
if line.strip().lower().startswith("use"):
use_ln = len(code_lines) - i
if i > 0 and code_lines[i - 1].strip().lower() == "#endif":
# If the USE statement is preceded by an #endif statement, return the line number after the #endif statement
return use_ln + 1
else:
return use_ln
return None
use_ln += 1
return use_ln
raise ParsingError("Could not find any USE statements.")

def _read_code_from_file(self) -> str:
"""Read the content of the granule and returns it as a string."""
with open(self.granule_path) as f:
code = f.read()
return code

def parse_subroutines(self):
def parse_subroutines(self) -> dict:
subroutines = self._extract_subroutines(crack(self.granule_path))
variables_grouped_by_intent = {
name: self._extract_intent_vars(routine) for name, routine in subroutines.items()
Expand Down Expand Up @@ -263,7 +261,7 @@ def _combine_types(derived_type_vars: dict, intrinsic_type_vars: dict) -> dict:
combined[subroutine_name][intent].update(new_vars)
return combined

def _update_with_codegen_lines(self, parsed_types: dict) -> dict:
def _update_with_codegen_lines(self, parsed_types: dict[str, Any]) -> dict[str, Any]:
"""Update the parsed_types dictionary with the line numbers for codegen.

Args:
Expand All @@ -285,9 +283,6 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:

Args:
subroutine_name (str): Name of the subroutine to look for in the code.

Returns:
CodegenContext: Object containing the line number of the last declaration statement and the line number of the last line of the code before the end of the given subroutine.
"""
code = self._read_code_from_file()

Expand All @@ -312,7 +307,7 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln)

@staticmethod
def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]:
def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int, int]:
"""Find line numbers of a subroutine within a code block.

Args:
Expand All @@ -327,15 +322,15 @@ def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]:
start_match = re.search(start_subroutine_pattern, code)
end_match = re.search(end_subroutine_pattern, code)
if start_match is None or end_match is None:
return None
raise ParsingError(f"Could not find {start_match} or {end_match}")
start_subroutine_ln = code[: start_match.start()].count("\n") + 1
end_subroutine_ln = code[: end_match.start()].count("\n") + 1
return start_subroutine_ln, end_subroutine_ln

@staticmethod
def _find_variable_declarations(
code: str, start_subroutine_ln: int, end_subroutine_ln: int
) -> list:
) -> list[int]:
"""Find line numbers of variable declarations within a code block.

Args:
Expand Down Expand Up @@ -371,8 +366,8 @@ def _find_variable_declarations(

@staticmethod
def _get_variable_declaration_bounds(
declaration_pattern_lines: list, start_subroutine_ln: int
) -> tuple:
declaration_pattern_lines: list[int], start_subroutine_ln: int
) -> tuple[int, int]:
"""Return the line numbers of the bounds for a variable declaration block.

Args:
Expand Down
7 changes: 3 additions & 4 deletions tools/src/icon4pytools/icon4pygen/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _is_size_param(param: itir.Sym) -> bool:
@staticmethod
def _missing_domain_params(params: List[itir.Sym]) -> Iterable[itir.Sym]:
"""Get domain limit params that are not present in param list."""
return map(
lambda p: itir.Sym(id=p),
filter(lambda s: s not in map(lambda p: p.id, params), _DOMAIN_ARGS),
)
param_ids = [p.id for p in params]
missing_args = [s for s in _DOMAIN_ARGS if s not in param_ids]
return (itir.Sym(id=p) for p in missing_args)
4 changes: 2 additions & 2 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence
from typing import Any, Sequence

from gt4py import eve
from gt4py.eve.codegen import JinjaTemplate as as_jinja
Expand Down Expand Up @@ -678,7 +678,7 @@ def _get_field_data(self) -> tuple:
)
return fields, offsets

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
fields, offsets = self._get_field_data()
offset_renderer = GpuTriMeshOffsetRenderer(self.offsets)

Expand Down
14 changes: 7 additions & 7 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence, Union
from typing import Any, Sequence, Union

from gt4py import eve
from gt4py.eve import Node
Expand Down Expand Up @@ -214,7 +214,7 @@ class F90RunFun(eve.Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [F90Field(name=field.name) for field in self.all_fields] + [
F90Field(name=name) for name in _DOMAIN_ARGS
]
Expand Down Expand Up @@ -242,7 +242,7 @@ class F90RunAndVerifyFun(eve.Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
Expand Down Expand Up @@ -295,7 +295,7 @@ class F90SetupFun(Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [
F90Field(name=name)
for name in [
Expand Down Expand Up @@ -346,7 +346,7 @@ class F90WrapRunFun(Node):
run_ver_params: F90EntityList = eve.datamodels.field(init=False)
run_params: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
Expand Down Expand Up @@ -457,7 +457,7 @@ class F90WrapSetupFun(Node):
vert_conditionals: F90EntityList = eve.datamodels.field(init=False)
setup_params: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [
F90Field(name=name)
for name in [
Expand Down Expand Up @@ -534,7 +534,7 @@ class F90File(Node):
wrap_run_fun: F90WrapRunFun = eve.datamodels.field(init=False)
wrap_setup_fun: F90WrapSetupFun = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
all_fields = self.fields
out_fields = [field for field in self.fields if field.intent.out]
tol_fields = [field for field in out_fields if not field.is_integral()]
Expand Down
4 changes: 2 additions & 2 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence
from typing import Any, Sequence

from gt4py import eve
from gt4py.eve import Node
Expand Down Expand Up @@ -149,7 +149,7 @@ class CppHeaderFile(Node):
setupFunc: CppSetupFuncDeclaration = eve.datamodels.field(init=False)
freeFunc: CppFreeFunc = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
output_fields = [field for field in self.fields if field.intent.out]
tolerance_fields = [field for field in output_fields if not field.is_integral()]

Expand Down
Loading