Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove before fields in substitution mode #416

Merged
merged 15 commits into from
Apr 9, 2024
Next Next commit
first draft
  • Loading branch information
huppd committed Mar 18, 2024
commit b16c8de9954d55e706a4f6668d29c7f57734d068
205 changes: 185 additions & 20 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class F90Generator(TemplatedGenerator):
implicit none
interface
{{run_fun}}
{{run_fun}}
{{run_and_verify_fun}}
{{setup_fun}}
subroutine &
Expand All @@ -51,6 +52,7 @@ class F90Generator(TemplatedGenerator):
contains

{{wrap_run_fun}}
{{wrap_run_and_verify_fun}}
{{wrap_setup_fun}}
end module
"""
Expand Down Expand Up @@ -95,6 +97,39 @@ class F90Generator(TemplatedGenerator):
"""
)

F90WrapRunAndVerifyFun = as_jinja(
"""\
subroutine &
wrap_run_and_verify_{{stencil_name}}( &
{{params}}
)
use, intrinsic :: iso_c_binding
use openacc
{{binds}}
{{tol_decls}}
integer(c_int) :: vertical_start
integer(c_int) :: vertical_end
integer(c_int) :: horizontal_start
integer(c_int) :: horizontal_end
{{k_sizes}}
vertical_start = vertical_lower-1
vertical_end = vertical_upper
horizontal_start = horizontal_lower-1
horizontal_end = horizontal_upper
{{k_sizes_assignments}}
{{conditionals}}
!$ACC host_data use_device( &
{{host_data_run_and_verify}}
!$ACC )
call run_and_verify_{{stencil_name}} &
( &
{{run_ver_params}}
)
!$ACC end host_data
end subroutine
"""
)

F90WrapRunFun = as_jinja(
"""\
subroutine &
Expand All @@ -117,19 +152,12 @@ class F90Generator(TemplatedGenerator):
{{k_sizes_assignments}}
{{conditionals}}
!$ACC host_data use_device( &
{{openacc}}
{{host_data_run}}
!$ACC )
#ifdef __DSL_VERIFY
call run_and_verify_{{stencil_name}} &
( &
{{run_ver_params}}
)
#else
call run_{{stencil_name}} &
( &
{{run_params}}
)
#endif
call run_{{stencil_name}} &
( &
{{run_params}}
)
!$ACC end host_data
end subroutine
"""
Expand Down Expand Up @@ -163,7 +191,7 @@ class F90Generator(TemplatedGenerator):

F90Field = as_jinja("{{ name }}{% if suffix %}_{{ suffix }}{% endif %}")

F90OpenACCField = as_jinja("!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}")
F90HostDataField = as_jinja("!$ACC {{ name }}{% if suffix %}_{{ suffix }}{% endif %}")

F90TypedField = as_jinja(
"{{ dtype }}, {% if dims %}{{ dims }},{% endif %} target {% if _this_node.optional %} , optional {% endif %}:: {{ name }}{% if suffix %}_{{ suffix }}{% endif %} "
Expand All @@ -185,7 +213,7 @@ class F90Field(eve.Node):
suffix: str = ""


class F90OpenACCField(F90Field):
class F90HostDataField(F90Field):
...


Expand Down Expand Up @@ -347,7 +375,7 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
self.binds = F90EntityList(fields=bind_fields)


class F90WrapRunFun(Node):
class F90WrapRunAndVerifyFun(Node):
stencil_name: str
all_fields: Sequence[Field]
out_fields: Sequence[Field]
Expand All @@ -358,7 +386,7 @@ class F90WrapRunFun(Node):
conditionals: F90EntityList = eve.datamodels.field(init=False)
k_sizes: F90EntityList = eve.datamodels.field(init=False)
k_sizes_assignments: F90EntityList = eve.datamodels.field(init=False)
openacc: F90EntityList = eve.datamodels.field(init=False)
host_data_run_and_verify: F90EntityList = eve.datamodels.field(init=False)
tol_decls: F90EntityList = eve.datamodels.field(init=False)
run_ver_params: F90EntityList = eve.datamodels.field(init=False)
run_params: F90EntityList = eve.datamodels.field(init=False)
Expand Down Expand Up @@ -418,10 +446,10 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
for short, long in [("rel", "RELATIVE"), ("abs", "ABSOLUTE")]
for field in self.tol_fields
]
open_acc_fields = [
F90OpenACCField(name=field.name) for field in self.all_fields if field.rank() != 0
host_data_run_and_verify_fields = [
F90HostDataField(name=field.name) for field in self.all_fields if field.rank() != 0
] + [
F90OpenACCField(name=field.name, suffix="before")
F90HostDataField(name=field.name, suffix="before")
for field in self.out_fields
if field.rank() != 0
]
Expand Down Expand Up @@ -475,12 +503,141 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
self.conditionals = F90EntityList(fields=cond_fields)
self.k_sizes = F90EntityList(fields=k_sizes_fields)
self.k_sizes_assignments = F90EntityList(fields=k_sizes_assignment_fields)
self.openacc = F90EntityList(fields=open_acc_fields, line_end=", &", line_end_last=" &")
self.host_data_run_and_verify = F90EntityList(fields=host_data_run_and_verify_fields, line_end=", &", line_end_last=" &")
self.run_ver_params = F90EntityList(
fields=run_ver_param_fields, line_end=", &", line_end_last=" &"
)
self.run_params = F90EntityList(fields=run_param_fields, line_end=", &", line_end_last=" &")

class F90WrapRunFun(Node):
stencil_name: str
all_fields: Sequence[Field]
out_fields: Sequence[Field]
tol_fields: Sequence[Field]

params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)
conditionals: F90EntityList = eve.datamodels.field(init=False)
k_sizes: F90EntityList = eve.datamodels.field(init=False)
k_sizes_assignments: F90EntityList = eve.datamodels.field(init=False)
host_data_run: F90EntityList = eve.datamodels.field(init=False)
tol_decls: F90EntityList = eve.datamodels.field(init=False)
run_ver_params: F90EntityList = eve.datamodels.field(init=False)
run_params: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
+ [F90Field(name=name) for name in _DOMAIN_ARGS]
)
bind_fields = (
[
F90TypedField(
name=field.name,
dtype=field.renderer.render_ctype("f90"),
dims=field.renderer.render_ranked_dim_string(),
)
for field in self.all_fields
]
+ [
F90TypedField(
name=field.name,
suffix="before",
dtype=field.renderer.render_ctype("f90"),
dims=field.renderer.render_ranked_dim_string(),
)
for field in self.out_fields
]
+ [
F90TypedField(name=name, dtype="integer(c_int)", dims="value")
for name in _DOMAIN_ARGS
]
)
tol_fields = [
F90TypedField(name=field.name, suffix=s, dtype="real(c_double)")
for s in ["rel_err_tol", "abs_err_tol"]
for field in self.tol_fields
]
k_sizes_fields = [
F90TypedField(name=field.name, suffix=s, dtype="integer")
for s in ["k_size"]
for field in self.out_fields
]
k_sizes_assignment_fields = [
F90Assignment(
left_side=f"{field.name}_k_size",
right_side=f"SIZE({field.name}, 2)",
)
for field in self.out_fields
]
cond_fields = [
F90Conditional(
predicate=f"present({field.name}_{short}_tol)",
if_branch=f"{field.name}_{short}_err_tol = {field.name}_{short}_tol",
else_branch=f"{field.name}_{short}_err_tol = DEFAULT_{long}_ERROR_THRESHOLD",
)
for short, long in [("rel", "RELATIVE"), ("abs", "ABSOLUTE")]
for field in self.tol_fields
]
host_data_run_fields = [
F90HostDataField(name=field.name) for field in self.all_fields if field.rank() != 0
]
run_ver_param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
+ [F90Field(name=field.name, suffix="k_size") for field in self.out_fields]
+ [
F90Field(name=name)
for name in [
"vertical_start",
"vertical_end",
"horizontal_start",
"horizontal_end",
]
]
)
run_param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="k_size") for field in self.out_fields]
+ [
F90Field(name=name)
for name in [
"vertical_start",
"vertical_end",
"horizontal_start",
"horizontal_end",
]
]
)

for field in self.tol_fields:
param_fields += [F90Field(name=field.name, suffix=s) for s in ["rel_tol", "abs_tol"]]
bind_fields += [
F90TypedField(
name=field.name,
suffix=s,
dtype="real(c_double)",
dims="value",
optional=True,
)
for s in ["rel_tol", "abs_tol"]
]
run_ver_param_fields += [
F90Field(name=field.name, suffix=s) for s in ["rel_err_tol", "abs_err_tol"]
]

self.params = F90EntityList(fields=param_fields, line_end=", &", line_end_last=" &")
self.binds = F90EntityList(fields=bind_fields)
self.tol_decls = F90EntityList(fields=tol_fields)
self.conditionals = F90EntityList(fields=cond_fields)
self.k_sizes = F90EntityList(fields=k_sizes_fields)
self.k_sizes_assignments = F90EntityList(fields=k_sizes_assignment_fields)
self.host_data_run = F90EntityList(fields=host_data_run_fields, line_end=", &", line_end_last=" &")
self.run_ver_params = F90EntityList(
fields=run_ver_param_fields, line_end=", &", line_end_last=" &"
)
self.run_params = F90EntityList(fields=run_param_fields, line_end=", &", line_end_last=" &")

class F90WrapSetupFun(Node):
stencil_name: str
Expand Down Expand Up @@ -539,6 +696,7 @@ class F90File(Node):
run_fun: F90RunFun = eve.datamodels.field(init=False)
run_and_verify_fun: F90RunAndVerifyFun = eve.datamodels.field(init=False)
setup_fun: F90SetupFun = eve.datamodels.field(init=False)
wrap_run_and_verify_fun: F90WrapRunAndVerifyFun = eve.datamodels.field(init=False)
wrap_run_fun: F90WrapRunFun = eve.datamodels.field(init=False)
wrap_setup_fun: F90WrapSetupFun = eve.datamodels.field(init=False)

Expand All @@ -565,6 +723,13 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
out_fields=out_fields,
)

self.wrap_run_and_verify_fun = F90WrapRunAndVerifyFun(
stencil_name=self.stencil_name,
all_fields=all_fields,
out_fields=out_fields,
tol_fields=tol_fields,
)

self.wrap_run_fun = F90WrapRunFun(
stencil_name=self.stencil_name,
all_fields=all_fields,
Expand Down
6 changes: 3 additions & 3 deletions tools/src/icon4pytools/liskov/codegen/integration/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class EndStencilStatementGenerator(BaseEndStencilStatementGenerator):
{% if _this_node.noprofile %}{% else %}call nvtxEndRange(){% endif %}
{%- endif %}
{% if _this_node.noendif %}{% else %}#endif{% endif %}
call wrap_run_{{ name }}( &
call wrap_run_and_verify_{{ name }}( &
{{ input_fields }}
{{ output_fields }}
{{ tolerance_fields }}
Expand All @@ -214,7 +214,7 @@ class EndStencilStatementGenerator(BaseEndStencilStatementGenerator):
class EndFusedStencilStatementGenerator(BaseEndStencilStatementGenerator):
EndFusedStencilStatement = as_jinja(
"""
call wrap_run_{{ name }}( &
call wrap_run_and_verify_{{ name }}( &
{{ input_fields }}
{{ output_fields }}
{{ tolerance_fields }}
Expand Down Expand Up @@ -401,7 +401,7 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:

class ImportsStatementGenerator(TemplatedGenerator):
ImportsStatement = as_jinja(
""" {% for name in stencil_names %}USE {{ name }}, ONLY: wrap_run_{{ name }}\n{% endfor %}"""
""" {% for name in stencil_names %}USE {{ name }}, ONLY: wrap_run_and_verify_{{ name }}\n{% endfor %}"""
)


Expand Down
2 changes: 1 addition & 1 deletion tools/tests/icon4pygen/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def check_fortran_codegen(fname: str) -> None:
f"run_and_verify_{stencil_name}",
f"setup_{stencil_name}",
f"free_{stencil_name}",
f"wrap_run_{stencil_name}",
f"wrap_run_and_verify{stencil_name}",
]
check_for_matches(fname, patterns)

Expand Down