Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace]: Better constant substitution #1778

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -205,27 +205,41 @@ def gt_substitute_compiletime_symbols(
repl: Maps the name of the symbol to the value it should be replaced with.
validate: Perform validation at the end of the function.
validate_all: Perform validation also on intermediate steps.

Note:
Because of [issue 1817](https://github.com/spcl/dace/issues/1817) in DaCe,
the function has to run `gt_simplify()`. However, this is an artefact of
the implementation and will be changed once the bug is solved.
"""

# We will use the `replace` function of the top SDFG, however, lower levels
# are handled using ConstantPropagation.
sdfg.replace_dict(repl)
# Ideally this function would just call `ConstantPropagation` with the replacement
# `dict` and be done. However, because of [issue 1817](https://github.com/spcl/dace/issues/1817)
# in DaCe this is not possible and we have to do it in this awkward way.
# TODO(phimuell): Fix this strange behaviour.

# First we do replacement on the top level SDFG only. However, we have to filter
# out all names that refers to data descriptors, because the replacement function
# can not handle them. We leave this to `ConstantPropagation`.
arrays = sdfg.arrays
sdfg.replace_dict({sym: value for sym, value in repl.items() if sym not in arrays})
const_prop = dace_passes.ConstantPropagation()
const_prop.recursive = True
const_prop.progress = False

const_prop.apply_pass(
sdfg=sdfg,
initial_symbols=repl,
_=None,
)

# To handle some bugs in `ConstantPropagation` we now call simplify.
# TODO(phimuell): Once the bug in DaCe is fixed remove this.
gt_simplify(
sdfg=sdfg,
validate=validate,
sdfg,
validate=False,
validate_all=validate_all,
)
dace.sdfg.propagation.propagate_memlets_sdfg(sdfg)
if validate_all:
sdfg.validate()


def gt_reduce_distributed_buffering(
Expand Down
Loading
Loading