Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Storage reduction pass #203

Closed
wants to merge 99 commits into from
Closed
Show file tree
Hide file tree
Changes from 96 commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
4ac03ef
Fix CuPy invalid argument error in 'GPUStorage' class
Aug 12, 2020
51a06bb
Merge branch 'gpu_storage_fix' into storage_reduction
Aug 24, 2020
7584678
Add initial ReduceTemporaryStoragesPass
Aug 24, 2020
1cb1a94
Implement ReduceTemporaryStoragesPass.StorageReducer inner class
Aug 24, 2020
b98eba0
Implementing 2D fields in 'numpy' backend
Aug 25, 2020
ea5e01f
Comment some code
Aug 25, 2020
37fbb54
Modify StorageReducer to skip 3D->2D reductions for parallel multistages
Aug 25, 2020
e887464
Merge with master
Aug 25, 2020
d2a57ca
Implement I-J storage in numpy backend
Aug 25, 2020
d33068f
Implement IJ storages in 'debug' backend
Aug 25, 2020
7ecb778
Add tests to reproduce FORWARD bug in 'DemoteLocalTemporariesToVariab…
Aug 27, 2020
5f55d19
Merge branch 'master' into forward_loop_fixes
Aug 27, 2020
4478e9b
Modify NormalizeBlocksPass to put each statement in own stage for all…
Aug 27, 2020
944bc5c
Apply pre-commit formatting
Aug 27, 2020
2a5ce47
Revert "Apply pre-commit formatting"
Aug 27, 2020
51ab61c
Apply pre-commit formatting
Aug 27, 2020
914c7bc
Merge branch 'forward_loop_fixes' into storage_reduction
Aug 28, 2020
456caf9
Add NumpyBackend.range_args to track when new K loops need to be gene…
Aug 28, 2020
d44d7e9
Add 'axes' to 'field_attributes'
Aug 28, 2020
84abff0
Merge from master
Sep 2, 2020
39832d4
Add support for 2D storages to GT backends
Sep 4, 2020
c55fcf8
Merge branch 'master' into storage_reduction
Sep 8, 2020
daaa8bc
Merge from master
Sep 10, 2020
502f413
Collect fields to reduce before applying reductions
Sep 11, 2020
32f7631
Mark parallel fields so they will not be reduced
Sep 11, 2020
8ca99c0
Refactor to include all parallel fields
Sep 11, 2020
d2b42bd
Only reduce fields that do not span multiple multistages
Sep 14, 2020
32dcd27
Tracking full fields not necessary if multi_stage spanning fields are…
Sep 14, 2020
1e9f336
Revert last commit
Sep 14, 2020
7101d61
Clear 'range_args' between parallel regions
Sep 14, 2020
0c22b62
Add 'test_reduce_temporary_storages_pass' test
Sep 14, 2020
f1c0094
Refactor lower dim storages to use gt::selector
Sep 15, 2020
f30f7b3
Update field extents from accessors
Sep 16, 2020
09c042d
Merge branch 'master' into storage_reduction
Sep 16, 2020
352fdac
Merge from master
Sep 16, 2020
1f4e761
Fix merge_blocks_pass
jdahm Sep 23, 2020
987d9cc
Fix StageMergingWrapper.has_incompatible_intervals_with to return Tru…
Sep 23, 2020
020b4b2
Merge branch 'master' into storage_reduction
Sep 23, 2020
34e42a0
[wip]: more general definition setup
Sep 24, 2020
ac208f6
Add asserts to test
jdahm Sep 24, 2020
12c312a
Merge branch 'bugfix/multistage-data-dependencies' of github.com:jdah…
jdahm Sep 24, 2020
0303cbc
Use itertools.chain.from_iterable
jdahm Sep 24, 2020
4655ea3
Merge branch 'bugfix/multistage-data-dependencies' of https://github.…
Sep 24, 2020
db914b8
Assert statement order in 'test_no_merge_with_overlapping_intervals'
Sep 24, 2020
babf61d
generalize test setup for stencil definitions
Sep 25, 2020
1efd90f
Merge remote-tracking branch 'jdahm/bugfix/multistage-data-dependenci…
Sep 25, 2020
9865c2a
Check data dependencies for parallel loops only
Sep 25, 2020
5192ac3
Revert previous change as too invasive
Sep 25, 2020
0c488ca
Merge branch 'bugfix/multistage-data-dependencies' of https://github.…
Sep 25, 2020
f67ab96
Do not reduce 3D to 2D temp field if k-intervals vary
Sep 28, 2020
6efa4f2
add ir_maker AugAssign test
Sep 29, 2020
083b66b
Merge remote-tracking branch 'jdahm/bugfix/multistage-data-dependenci…
Sep 29, 2020
8015f11
fix the ir maker AugAssign test
Sep 29, 2020
c3f3a55
fix IRMaker.visit_AugAssign
Sep 29, 2020
cc8f557
Merge branch 'master' into storage_reduction
Sep 29, 2020
65b4fa7
Apply pre-commit changes
Sep 29, 2020
d8962c8
Merge pull request #2 from DropD/bugfix/multistage-data-dependencies
jdahm Sep 30, 2020
39a41d4
Rename PassType to AnalysisPass
jdahm Sep 30, 2020
39cbd2a
Merge branch 'master' of https://github.com/GridTools/gt4py
Sep 30, 2020
5b39500
Merge branch 'master' into storage_reduction
Sep 30, 2020
605992f
Merge branch 'master' into bugfix/multistage-data-dependencies
jdahm Sep 30, 2020
90caa58
Cleanup before PR
Sep 30, 2020
9e68fc5
Merge branch 'master' of https://github.com/GridTools/gt4py
Sep 30, 2020
8b1aafe
Merge branch 'master' into storage_reduction
Sep 30, 2020
d9af5aa
Merge branch 'bugfix/multistage-data-dependencies' of https://github.…
Sep 30, 2020
9e585fd
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 1, 2020
790e90d
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 1, 2020
d5258bc
Fix merge issue in 'definition_setup.py'
Oct 1, 2020
0a10344
Merge branch 'master' of https://github.com/eddie-c-davis/gt4py
Oct 7, 2020
882c114
Merge branch 'master' into storage_reduction
Oct 7, 2020
76042e5
Merge branch 'master' into storage_reduction
Oct 12, 2020
f754919
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 14, 2020
c78ff8b
Fix bad merge conflict resolution
Oct 15, 2020
3e6c8ab
Refactor 'ReduceTemporaryStoragesPass'
Oct 21, 2020
51ed63c
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 21, 2020
35cad3d
Merge branch 'master' into storage_reduction
Oct 21, 2020
74fd684
Commit correct version of passes.py
Oct 21, 2020
38e3fec
Expand 'test_reduce_temporary_storages_pass'
Oct 21, 2020
b07ef9a
Apply formatting
Oct 21, 2020
0540ca5
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 21, 2020
7a23e5a
Merge branch 'master' into storage_reduction
Oct 21, 2020
af9c25b
Replace '_reduce_fields' method with 'StorageReducer' visitor
Oct 21, 2020
c9ee5e0
Merge branch 'master' of https://github.com/GridTools/gt4py
Oct 29, 2020
f3523a5
Merge branch 'master' into storage_reduction
Oct 29, 2020
26a81f4
Fixes for 2D storages in numpy backend
Oct 30, 2020
d225e49
Enable lower dim data stores in GT backends
Oct 30, 2020
0c408e4
Merge branch 'master' of https://github.com/GridTools/gt4py
Nov 19, 2020
ec00dcd
Merge branch 'master' of https://github.com/GridTools/gt4py
Nov 23, 2020
806ef89
Merge branch 'master' into storage_reduction
Nov 23, 2020
53a60b9
Merge branch 'master' of https://github.com/GridTools/gt4py
Dec 2, 2020
b520408
Merge branch 'master' into storage_reduction
Dec 2, 2020
e5834e2
Fuse eligible k-loops in 'debug' backend
Dec 2, 2020
ad19e9e
Apply formatting
Dec 2, 2020
de4792b
Merge branch 'master' of https://github.com/GridTools/gt4py
Dec 3, 2020
3b8cb77
Add multi-stage constraint, documentation, and codegen test
Dec 7, 2020
a9435d5
Merge branch 'master' into storage_reduction
Dec 7, 2020
96e3ca2
Merge branch 'master' of https://github.com/GridTools/gt4py
Dec 10, 2020
b0cd02c
Merge branch 'master' into storage_reduction
Dec 10, 2020
ac1dd21
Implement reviewer feedback
Dec 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 89 additions & 4 deletions src/gt4py/analysis/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,94 @@ def apply(cls, transform_data: TransformData) -> None:
cls.DemoteSymbols.apply(transform_data.implementation_ir, demotables)


class ReduceTemporaryStoragesPass(TransformPass):
"""Demote temporary symbols used within a single multi-stage to 2D fields if the following
constraints are satisfied:

1. The multi-stage iteration order is not parallel.
2. There are no offsets in the k-direction.
3. All of the accesses are in the same k-interval.
"""

class ReducibleFieldsCollector(gt_ir.IRNodeVisitor):
@classmethod
def apply(cls, node: gt_ir.StencilImplementation) -> Set[str]:
collector = cls()
return collector(node)

def __call__(self, node: gt_ir.StencilImplementation) -> Set[str]:
assert isinstance(node, gt_ir.StencilImplementation)
self.interval: gt_ir.AxisInterval = None
self.iteration_order: gt_ir.IterationOrder = None
self.multi_stage: str = ""
self.reduced_fields: Dict[str, Dict[str, Union[str, gt_ir.AxisInterval]]] = {
temp_field: dict(multi_stage="", interval=None)
for temp_field in node.temporary_fields
}
self.visit(node)
return set(self.reduced_fields.keys())

def visit_MultiStage(self, node: gt_ir.MultiStage) -> None:
self.iteration_order = node.iteration_order
self.multi_stage = node.name
self.generic_visit(node)

def visit_ApplyBlock(self, node: gt_ir.ApplyBlock) -> None:
self.interval = node.interval
self.generic_visit(node)

def visit_FieldRef(self, node: gt_ir.FieldRef) -> None:
field_name = node.name
if field_name in self.reduced_fields:
if self.iteration_order == gt_ir.IterationOrder.PARALLEL:
self.reduced_fields.pop(field_name)
else:
offsets: List[int] = list(node.offset.values())
if offsets[-1] != 0:
self.reduced_fields.pop(field_name)
else:
interval = self.reduced_fields[field_name]["interval"]
multi_stage = self.reduced_fields[field_name]["multi_stage"]
if (interval is not None and interval != self.interval) or (
multi_stage != "" and multi_stage != self.multi_stage
):
self.reduced_fields.pop(field_name)
else:
self.reduced_fields[field_name]["interval"] = self.interval
self.reduced_fields[field_name]["multi_stage"] = self.multi_stage

class StorageReducer(gt_ir.IRNodeVisitor):
@classmethod
def apply(cls, node: gt_ir.StencilImplementation, reduced_fields: Set[str]) -> None:
instance = cls(reduced_fields)
return instance(node)

def __init__(self, reduced_fields: Set[str]):
self.reduced_fields = reduced_fields

def __call__(self, node: gt_ir.StencilImplementation) -> gt_ir.StencilImplementation:
assert isinstance(node, gt_ir.StencilImplementation)
return self.visit(node)

def visit_StencilImplementation(self, node: gt_ir.StencilImplementation) -> None:
self.iir = node
for field_name in self.reduced_fields:
assert field_name in node.temporary_fields, "Tried to reduce API field to 2D."
node.fields[field_name].axes.pop()
self.generic_visit(node)

def visit_FieldRef(self, node: gt_ir.FieldRef) -> None:
if node.name in self.reduced_fields:
field_decl = self.iir.fields[node.name]
node.offset = {axis: node.offset[axis] for axis in field_decl.axes}

@classmethod
def apply(cls, transform_data: TransformData) -> None:
reduced_fields = cls.ReducibleFieldsCollector.apply(transform_data.implementation_ir)
if len(reduced_fields) > 0:
cls.StorageReducer.apply(transform_data.implementation_ir, reduced_fields)


class HousekeepingPass(TransformPass):
class WarnIfNoEffect(gt_ir.IRNodeVisitor):
"""Warn if StencilImplementation has no effect."""
Expand All @@ -1317,10 +1405,7 @@ def __call__(self, stencil_name: str, node: gt_ir.StencilImplementation) -> None
def visit_StencilImplementation(self, node: gt_ir.StencilImplementation):
# Emit warning if stencil has no effect, i.e. does not read or write to any api fields
if not node.has_effect:
warnings.warn(
f"Stencil `{self.stencil_name}` has no effect.",
RuntimeWarning,
)
warnings.warn(f"Stencil `{self.stencil_name}` has no effect.", RuntimeWarning)

class PruneEmptyNodes(gt_ir.IRNodeMapper):
"""Removes empty multi-stages, stage groups, and stages."""
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/analysis/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
InitInfoPass,
MergeBlocksPass,
NormalizeBlocksPass,
ReduceTemporaryStoragesPass,
)


Expand Down Expand Up @@ -111,6 +112,10 @@ def __call__(self, definition_ir, options):
# into local scalars
DemoteLocalTemporariesToVariablesPass.apply(self.transform_data)

# turn temporary fields that are only written and read within the horizontal plane
# into 2D i-j fields
ReduceTemporaryStoragesPass.apply(self.transform_data)

# prune some stages that don't have effect
HousekeepingPass.apply(self.transform_data)

Expand Down
23 changes: 17 additions & 6 deletions src/gt4py/backend/debug_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,24 @@ def _make_regional_computation(self, iteration_order, interval_definition):
else:
range_args = [loop_bounds[1] + " -1", loop_bounds[0] + " -1", "-1"]

range_expr = "range({args})".format(args=", ".join(a for a in range_args))
seq_axis = self.impl_node.domain.sequential_axis.name
source_lines.append("for {ax} in {range_expr}:".format(ax=seq_axis, range_expr=range_expr))
if range_args != self.range_args:
self.range_args = range_args
range_expr = "range({args})".format(args=", ".join(a for a in range_args))
seq_axis = self.impl_node.domain.sequential_axis.name
source_lines.append(
"for {ax} in {range_expr}:".format(ax=seq_axis, range_expr=range_expr)
)

return source_lines

def make_temporary_field(
self, name: str, dtype: gt_ir.DataType, extent: gt_definitions.Extent
self,
name: str,
dtype: gt_ir.DataType,
extent: gt_definitions.Extent,
axes: list = gt_definitions.CartesianSpace.names,
):
source_lines = super().make_temporary_field(name, dtype, extent)
source_lines = super().make_temporary_field(name, dtype, extent, axes)
source_lines.extend(self._make_field_accessor(name, extent.to_boundary().lower_indices))

return source_lines
Expand Down Expand Up @@ -112,10 +120,13 @@ def make_stage_source(self, iteration_order: gt_ir.IterationOrder, regions: list
def visit_FieldRef(self, node: gt_ir.FieldRef):
assert node.name in self.block_info.accessors
index = []
for ax in self.domain.axes_names:
for ax in node.offset.keys():
offset = "{:+d}".format(node.offset[ax]) if ax in node.offset else ""
index.append("{ax}{offset}".format(ax=ax, offset=offset))

# Extend index with zeros...
index.extend(["0"] * (len(self.block_info.extent) - len(index)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use the filter_mask or NumericTuple.from_mask after merging.


source = "{name}{marker}[{index}]".format(
marker=self.origin_marker, name=node.name, index=", ".join(index)
)
Expand Down
13 changes: 11 additions & 2 deletions src/gt4py/backend/gt_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,14 +452,23 @@ def visit_StencilImplementation(
arg_fields = []
tmp_fields = []
storage_ids = []
used_axes = dict()
all_axes = [axis.lower() for axis in gt_definitions.CartesianSpace.names]

max_ndim = 0
for name, field_decl in node.fields.items():
if name not in node.unreferenced:
max_ndim = max(max_ndim, len(field_decl.axes))
axes = "".join(field_decl.axes).lower()
selector = ["1" if axis in axes else "0" for axis in all_axes]
used_axes[axes] = dict(name=axes.lower(), selector=", ".join(selector))

field_attributes = {
"name": field_decl.name,
"dtype": self._make_cpp_type(field_decl.data_type),
"axes": axes,
}

if field_decl.is_api:
if field_decl.layout_id not in storage_ids:
storage_ids.append(field_decl.layout_id)
Expand Down Expand Up @@ -497,6 +506,7 @@ def visit_StencilImplementation(
stage_functors=stage_functors,
stencil_unique_name=self.class_name,
tmp_fields=tmp_fields,
used_axes=used_axes.values(),
)

sources: Dict[str, Dict[str, str]] = {"computation": {}, "bindings": {}}
Expand Down Expand Up @@ -543,8 +553,7 @@ def generate(self) -> Type["StencilObject"]:

# Generate and return the Python wrapper class
return self.make_module(
pyext_module_name=pyext_module_name,
pyext_file_path=pyext_file_path,
pyext_module_name=pyext_module_name, pyext_file_path=pyext_file_path
)

def generate_computation(self) -> Dict[str, Union[str, Dict]]:
Expand Down
69 changes: 40 additions & 29 deletions src/gt4py/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,17 @@ def _make_regional_computation(
range_args = [loop_bounds[1] + " -1", loop_bounds[0] + " -1", "-1"]

if iteration_order != gt_ir.IterationOrder.PARALLEL:
range_expr = "range({args})".format(args=", ".join(a for a in range_args))
seq_axis = self.impl_node.domain.sequential_axis.name
source_lines.append(
"for {ax} in {range_expr}:".format(ax=seq_axis, range_expr=range_expr)
)
if self.range_args != range_args:
self.range_args = range_args
range_expr = "range({args})".format(args=", ".join(a for a in range_args))
seq_axis = self.impl_node.domain.sequential_axis.name
source_lines.append(
"for {ax} in {range_expr}:".format(ax=seq_axis, range_expr=range_expr)
)
source_lines.extend(" " * self.indent_size + line for line in body_sources)
else:
# Clear range args on parallel intervals
self.range_args.clear()
source_lines.append(
"{interval_k_start_name} = {lb}".format(
interval_k_start_name=self.interval_k_start_name, lb=loop_bounds[0]
Expand All @@ -110,9 +114,13 @@ def _make_regional_computation(
return source_lines

def make_temporary_field(
self, name: str, dtype: gt_ir.DataType, extent: gt_definitions.Extent
) -> List[str]:
source_lines = super().make_temporary_field(name, dtype, extent)
self,
name: str,
dtype: gt_ir.DataType,
extent: gt_definitions.Extent,
axes: list = gt_definitions.CartesianSpace.names,
):
source_lines = super().make_temporary_field(name, dtype, extent, axes)
source_lines.extend(self._make_field_origin(name, extent.to_boundary().lower_indices))

return source_lines
Expand All @@ -138,7 +146,7 @@ def visit_FieldRef(self, node: gt_ir.FieldRef) -> str:
lower_extent = list(extent.lower_indices)
upper_extent = list(extent.upper_indices)

for d, ax in enumerate(self.domain.axes_names):
for d, ax in enumerate(node.offset.keys()):
idx = node.offset.get(ax, 0)
if idx:
lower_extent[d] += idx
Expand All @@ -160,28 +168,31 @@ def visit_FieldRef(self, node: gt_ir.FieldRef) -> str:
)

k_ax = self.domain.sequential_axis.name
k_offset = node.offset.get(k_ax, 0)
if is_parallel:
start_expr = self.interval_k_start_name
start_expr += " {:+d}".format(k_offset) if k_offset else ""
end_expr = self.interval_k_end_name
end_expr += " {:+d}".format(k_offset) if k_offset else ""
index.append(
"{name}{marker}[2] + {start}:{name}{marker}[2] + {stop}".format(
name=node.name, start=start_expr, marker=self.origin_marker, stop=end_expr
if k_ax in node.offset:
k_offset = node.offset.get(k_ax, 0)
if is_parallel:
start_expr = self.interval_k_start_name
start_expr += " {:+d}".format(k_offset) if k_offset else ""
end_expr = self.interval_k_end_name
end_expr += " {:+d}".format(k_offset) if k_offset else ""
index.append(
"{name}{marker}[2] + {start}:{name}{marker}[2] + {stop}".format(
name=node.name, start=start_expr, marker=self.origin_marker, stop=end_expr
)
)
)
else:
idx = "{:+d}".format(k_offset) if k_offset else ""
index.append(
"{name}{marker}[{d}] + {ax}{idx}".format(
name=node.name,
marker=self.origin_marker,
d=len(self.domain.parallel_axes),
ax=k_ax,
idx=idx,
else:
idx = "{:+d}".format(k_offset) if k_offset else ""
index.append(
"{name}{marker}[{d}] + {ax}{idx}".format(
name=node.name,
marker=self.origin_marker,
d=len(self.domain.parallel_axes),
ax=k_ax,
idx=idx,
)
)
)
else:
index.append("0:1" if is_parallel else "0")

source = "{name}[{index}]".format(name=node.name, index=", ".join(index))

Expand Down
8 changes: 6 additions & 2 deletions src/gt4py/backend/python_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
self.param_names = None

self.var_refs_defined = set()
self.range_args = []

def __call__(self, impl_node: gt_ir.Node, sources: gt_text.TextBlock):
assert isinstance(impl_node, gt_ir.StencilImplementation)
Expand Down Expand Up @@ -110,7 +111,7 @@ def __call__(self, impl_node: gt_ir.Node, sources: gt_text.TextBlock):
return self.sources

def make_temporary_field(
self, name: str, data_type: gt_ir.DataType, extent: gt_definitions.Extent
self, name: str, data_type: gt_ir.DataType, extent: gt_definitions.Extent, axes: list
):
source_lines = []
boundary = extent.to_boundary()
Expand All @@ -119,7 +120,10 @@ def make_temporary_field(
domain=self.domain_arg_name, d=d, size=" {:+d}".format(size) if size > 0 else ""
)
for d, size in enumerate(boundary.frame_size)
if d < len(axes)
)
for i in range(len(extent) - len(axes)):
shape += ", 1"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

source_lines.append(
"{name} = {np_prefix}.empty(({shape}), dtype={np_prefix}.{dtype})".format(
name=name, np_prefix=self.numpy_prefix, shape=shape, dtype=data_type.dtype.name
Expand Down Expand Up @@ -293,7 +297,7 @@ def visit_StencilImplementation(self, node: gt_ir.StencilImplementation):
field = node.fields[name]
self.sources.extend(
self.make_temporary_field(
field.name, field.data_type, node.fields_extents[field.name]
field.name, field.data_type, node.fields_extents[field.name], field.axes
)
)
self.sources.empty_line()
Expand Down
Loading