Skip to content

Commit

Permalink
Merge pull request #2 from DropD/bugfix/multistage-data-dependencies
Browse files Browse the repository at this point in the history
Bugfix/multistage data dependencies
  • Loading branch information
jdahm authored Sep 30, 2020
2 parents 5192ac3 + c3f3a55 commit d8962c8
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 206 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,7 +942,7 @@ def visit_Assign(self, node: ast.Assign) -> list:
def visit_AugAssign(self, node: ast.AugAssign):
"""Implement left <op>= right in terms of left = left <op> right."""
binary_operation = ast.BinOp(left=node.target, op=node.op, right=node.value)
assignment = ast.Assign(targets=[node.target], value=node.target)
assignment = ast.Assign(targets=[node.target], value=binary_operation)
ast.copy_location(binary_operation, node)
ast.copy_location(assignment, node)
return self.visit_Assign(assignment)
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@
from gt4py import config as gt_config

from .analysis_setup import (
PassType,
compute_extents_pass,
init_pass,
merge_blocks_pass,
normalize_blocks_pass,
)
from .definition_setup import (
TAssign,
TComputationBlock,
TDefinition,
ij_offset,
ijk_domain,
iteration_order,
make_transform_data,
non_parallel_iteration_order,
)

Expand Down
323 changes: 218 additions & 105 deletions tests/definition_setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import itertools
from functools import partial
from typing import Iterator, List, Tuple
from typing import Iterator, List, Set, Tuple, Union

import pytest

Expand All @@ -16,6 +14,7 @@
ComputationBlock,
DataType,
Domain,
Expr,
FieldDecl,
FieldRef,
IterationOrder,
Expand Down Expand Up @@ -54,81 +53,6 @@ def make_offset(offset: Tuple[int, int, int]):
return {"I": offset[0], "J": offset[1], "K": offset[2]}


def make_assign(
target: str,
value: str,
offset: Tuple[int, int, int] = (0, 0, 0),
loc_line=0,
loc_scope="unnamed",
):
make_loc = partial(Location, scope=loc_scope)
return Assign(
target=FieldRef(
name=target, offset=make_offset((0, 0, 0)), loc=make_loc(line=loc_line, column=0)
),
value=FieldRef(
name=value, offset=make_offset(offset), loc=make_loc(line=loc_line, column=2)
),
loc=make_loc(line=loc_line, column=1),
)


def make_definition_multiple(
name: str,
fields: List[str],
domain: Domain,
info: List[Tuple[BodyType, IterationOrder, Tuple[int, int]]],
) -> StencilDefinition:
api_signature = [ArgumentInfo(name=n, is_keyword=False) for n in fields]
bodies = tuple(itertools.chain.from_iterable([definition[0] for definition in info]))
tmp_fields = {i[0] for i in bodies}.union({i[1] for i in bodies}).difference(fields)
api_fields = [
FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=True)
for n in fields
] + [
FieldDecl(name=n, data_type=DataType.AUTO, axes=domain.axes_names, is_api=False)
for n in tmp_fields
]
comp_blocks = [
ComputationBlock(
interval=AxisInterval(
start=AxisBound(level=LevelMarker.START, offset=interval[0]),
end=AxisBound(level=LevelMarker.END, offset=interval[1]),
),
iteration_order=iteration_order,
body=BlockStmt(
stmts=[
make_assign(*assign, loc_scope=name, loc_line=i)
for i, assign in enumerate(body)
]
),
)
for body, iteration_order, interval in info
]
return StencilDefinition(
name=name,
domain=domain,
api_signature=api_signature,
api_fields=api_fields,
parameters=[],
computations=comp_blocks,
docstring="",
)


def make_definition(
name: str, fields: List[str], domain: Domain, body: BodyType, iteration_order: IterationOrder
) -> StencilDefinition:
return make_definition_multiple(
name,
fields,
domain,
[
(body, iteration_order, (0, 0)),
],
)


def init_implementation_from_definition(definition: StencilDefinition) -> StencilImplementation:
return StencilImplementation(
name=definition.name,
Expand All @@ -146,32 +70,221 @@ def init_implementation_from_definition(definition: StencilDefinition) -> Stenci
)


def make_transform_data(
*,
name: str,
domain: Domain,
fields: List[str],
body: BodyType,
iteration_order: IterationOrder,
) -> TransformData:
definition = make_definition(name, fields, domain, body, iteration_order)
return TransformData(
definition_ir=definition,
implementation_ir=init_implementation_from_definition(definition),
options=BuildOptions(name=name, module=__name__),
)
class TObject:
def __init__(self, loc: Location, parent: "TObject" = None):
self.loc = loc
self.children = []
self.parent = parent

@property
def width(self) -> int:
return sum(child.width for child in self.children) + 1 if self.children else 1

def make_transform_data_multiple(
*,
name: str,
domain: Domain,
fields: List[str],
info: List[Tuple[BodyType, IterationOrder, Tuple[int, int]]],
) -> TransformData:
definition = make_definition_multiple(name, fields, domain, info)
return TransformData(
definition_ir=definition,
implementation_ir=init_implementation_from_definition(definition),
options=BuildOptions(name=name, module=__name__),
)
@property
def height(self) -> int:
return sum(child.height for child in self.children) + 1 if self.children else 1

def register_child(self, child: "TObject") -> None:
child.loc = Location(
line=self.loc.line + self.height,
column=self.loc.column + self.width,
scope=self.child_scope,
)
child.parent = self
self.children.append(child)

@property
def field_names(self) -> Set[str]:
return set.union(*(child.field_names for child in self.children))

@property
def child_scope(self) -> str:
return self.loc.scope


class TDefinition(TObject):
def __init__(self, *, name: str, domain: Domain, fields: List[str]):
super().__init__(Location(line=0, column=0, scope=name))
self.name = name
self.domain = domain
self.fields = fields
self.parameters = []
self.docstring = ""

def add_blocks(self, *blocks: "TComputationBlock") -> "TDefinition":
for block in blocks:
self.register_child(block)
return self

@property
def width(self) -> int:
return 0

@property
def height(self) -> int:
return super().height - 1

@property
def api_signature(self) -> List[ArgumentInfo]:
return [ArgumentInfo(name=n, is_keyword=False) for n in self.fields]

@property
def api_fields(self) -> List[FieldDecl]:
tmp_field_names = self.field_names.difference(self.fields)
tmp_fields = [
FieldDecl(name=n, data_type=DataType.AUTO, axes=self.domain.axes_names, is_api=False)
for n in tmp_field_names
]
return tmp_fields + [
FieldDecl(name=n, data_type=DataType.AUTO, axes=self.domain.axes_names, is_api=True)
for n in self.fields
]

def build(self) -> StencilDefinition:
return StencilDefinition(
name=self.name,
domain=self.domain,
api_signature=self.api_signature,
api_fields=self.api_fields,
parameters=self.parameters,
computations=[block.build() for block in self.children],
docstring=self.docstring,
loc=self.loc,
)

def build_transform(self):
definition = self.build()
return TransformData(
definition_ir=definition,
implementation_ir=init_implementation_from_definition(definition),
options=BuildOptions(name=self.name, module=__name__),
)


class TComputationBlock(TObject):
def __init__(
self, *, order: IterationOrder, start: int = 0, end: int = 0, scope: str = "<unnamed>"
):
super().__init__(Location(line=0, column=0))
self.order = order
self.start = start
self.end = end
self.scope = scope

def add_statements(self, *stmts: "TStatement") -> "TComputationBlock":
for stmt in stmts:
self.register_child(stmt)
return self

@property
def width(self) -> int:
return 0

def build(self) -> ComputationBlock:
self.loc.scope = self.parent.child_scope
return ComputationBlock(
interval=AxisInterval(
start=AxisBound(level=LevelMarker.START, offset=self.start),
end=AxisBound(level=LevelMarker.END, offset=self.end),
),
iteration_order=self.order,
body=BlockStmt(
stmts=[stmt.build() for stmt in self.children],
),
loc=self.loc,
)

@property
def child_scope(self) -> str:
return f"{self.loc.scope}:{self.scope}"


class TStatement(TObject):
pass


class TAssign(TStatement):
def __init__(self, target: str, value: Union[str, Expr], offset: Tuple[int, int, int]):
super().__init__(Location(line=0, column=0))
self._target = target
self._value = value
self.offset = offset

@property
def height(self):
return 1

@property
def width(self):
return self.target.width + 3 + self.value.width

@property
def value(self):
value = self._value
if isinstance(self._value, str):
value = TFieldRef(name=self._value, offset=self.offset)
value.loc = Location(
line=self.loc.line,
column=self.loc.column + self.target.width + 3,
scope=self.loc.scope,
)
value.parent = self
return value

@property
def field_names(self) -> Set[str]:
return set.union(self.target.field_names, self.value.field_names)

@property
def target(self):
return TFieldRef(
name=self._target,
parent=self,
loc=Location(line=self.loc.line, column=self.loc.column, scope=self.loc.scope),
)

def build(self) -> Assign:
self.loc.scope = self.parent.child_scope
return Assign(
target=self.target.build(),
value=self.value.build(),
loc=Location(
line=self.loc.line,
column=self.loc.column + self.target.width + 1,
scope=self.loc.scope,
),
)


class TFieldRef(TObject):
def __init__(
self,
*,
name: str,
offset: Tuple[int, int, int] = (0, 0, 0),
loc: Location = None,
parent: TObject = None,
):
super().__init__(loc or Location(line=0, column=0), parent=parent)
self.name = name
self.offset = make_offset(offset)

def build(self):
self.loc.scope = self.parent.child_scope
return FieldRef(
name=self.name,
offset=self.offset,
loc=self.loc,
)

@property
def height(self) -> int:
return 1

@property
def width(self) -> int:
return len(self.name)

@property
def field_names(self) -> Set[str]:
return {self.name}
14 changes: 14 additions & 0 deletions tests/test_unittest/test_ir_maker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import ast

from gt4py.frontend.gtscript_frontend import IRMaker
from gt4py.ir.nodes import BinaryOperator, BinOpExpr


def test_AugAssign():
ir_maker = IRMaker(None, None, None, domain=None, extra_temp_decls=None)
aug_assign = ast.parse("a += 1").body[0]

_, result = ir_maker.visit_AugAssign(aug_assign)

assert isinstance(result.value, BinOpExpr)
assert result.value.op == BinaryOperator.ADD
Loading

0 comments on commit d8962c8

Please sign in to comment.