Skip to content

Commit

Permalink
Fix dawn4py bindings relating to IIR maps (#984)
Browse files Browse the repository at this point in the history
Enables us to better analyze the passes from the python command line
  • Loading branch information
jdahm authored May 29, 2020
1 parent bc24e15 commit f656bc3
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 54 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ _local_/*

# Ignore python editable installs
*.egg-info/

# Ignore pip files
pip-wheel-metadata/
4 changes: 2 additions & 2 deletions dawn/scripts/_dawn4py.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ PYBIND11_MODULE(_dawn4py, m) {
dawn::IIRSerializer::Format format,
dawn::codegen::Backend backend,
const dawn::codegen::Options& options) {
return dawn::codegen::run(stencilInstantiationMap, format, backend, options);
return dawn::codegen::run(stencilInstantiationMap, format, backend, options);
},
"Generate code from the stencil instantiation map.",
py::arg("stencil_instantiation_map"),
Expand All @@ -77,7 +77,7 @@ PYBIND11_MODULE(_dawn4py, m) {
m.def("compile_sir", [](const std::string& sir, dawn::SIRSerializer::Format format,
const std::list<dawn::PassGroup>& groups, const dawn::Options& optimizerOptions,
dawn::codegen::Backend backend, const dawn::codegen::Options& codegenOptions) {
return dawn::compile(sir, format, groups, optimizerOptions, backend, codegenOptions);
return dawn::compile(sir, format, groups, optimizerOptions, backend, codegenOptions);
},
"Compile the stencil IR: lower, optimize, and generate code.",
"Runs the default_pass_groups() unless the 'groups' argument is passed.",
Expand Down
52 changes: 22 additions & 30 deletions dawn/src/dawn/Compiler/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,10 @@ run(const std::map<std::string, std::shared_ptr<iir::StencilInstantiation>>&
<< "`";

if(options.SerializeIIR) {
const dawn::IIRSerializer::Format serializationKind =
options.SerializeIIR ? dawn::IIRSerializer::parseFormatString(options.IIRFormat)
: dawn::IIRSerializer::Format::Json;
dawn::IIRSerializer::serialize(instantiation->getName() + ".iir", instantiation,
serializationKind);
const IIRSerializer::Format serializationKind =
options.SerializeIIR ? IIRSerializer::parseFormatString(options.IIRFormat)
: IIRSerializer::Format::Json;
IIRSerializer::serialize(instantiation->getName() + ".iir", instantiation, serializationKind);
}

if(options.DumpStencilInstantiation) {
Expand All @@ -299,36 +298,30 @@ run(const std::map<std::string, std::shared_ptr<iir::StencilInstantiation>>&
return optimizer.getStencilInstantiationMap();
}

std::map<std::string, std::string> run(const std::string& sir, dawn::SIRSerializer::Format format,
const std::list<dawn::PassGroup>& groups,
const dawn::Options& options) {
auto stencilIR = dawn::SIRSerializer::deserializeFromString(sir, format);
auto optimizedSIM = dawn::run(stencilIR, groups, options);
std::map<std::string, std::string> run(const std::string& sir, SIRSerializer::Format format,
const std::list<PassGroup>& groups, const Options& options) {
auto stencilIR = SIRSerializer::deserializeFromString(sir, format);
auto optimizedSIM = run(stencilIR, groups, options);
std::map<std::string, std::string> instantiationStringMap;
const dawn::IIRSerializer::Format outputFormat = format == dawn::SIRSerializer::Format::Byte
? dawn::IIRSerializer::Format::Byte
: dawn::IIRSerializer::Format::Json;
for(auto [name, instantiation] : optimizedSIM) {
instantiationStringMap.insert(
std::make_pair(name, dawn::IIRSerializer::serializeToString(instantiation, outputFormat)));
instantiationStringMap.insert(std::make_pair(
name, IIRSerializer::serializeToString(instantiation, IIRSerializer::Format::Json)));
}
return instantiationStringMap;
}

std::map<std::string, std::string>
run(const std::map<std::string, std::string>& stencilInstantiationMap,
dawn::IIRSerializer::Format format, const std::list<dawn::PassGroup>& groups,
const dawn::Options& options) {
std::map<std::string, std::shared_ptr<dawn::iir::StencilInstantiation>> internalMap;
run(const std::map<std::string, std::string>& stencilInstantiationMap, IIRSerializer::Format format,
const std::list<PassGroup>& groups, const Options& options) {
std::map<std::string, std::shared_ptr<iir::StencilInstantiation>> internalMap;
for(auto [name, instStr] : stencilInstantiationMap) {
internalMap.insert(
std::make_pair(name, dawn::IIRSerializer::deserializeFromString(instStr, format)));
internalMap.insert(std::make_pair(name, IIRSerializer::deserializeFromString(instStr, format)));
}
auto optimizedSIM = dawn::run(internalMap, groups, options);
auto optimizedSIM = run(internalMap, groups, options);
std::map<std::string, std::string> instantiationStringMap;
for(auto [name, instantiation] : optimizedSIM) {
instantiationStringMap.insert(
std::make_pair(name, dawn::IIRSerializer::serializeToString(instantiation, format)));
instantiationStringMap.insert(std::make_pair(
name, IIRSerializer::serializeToString(instantiation, IIRSerializer::Format::Json)));
}
return instantiationStringMap;
}
Expand All @@ -341,12 +334,11 @@ std::unique_ptr<codegen::TranslationUnit> compile(const std::shared_ptr<SIR>& st
return codegen::run(run(stencilIR, passGroups, optimizerOptions), backend, codegenOptions);
}

std::string compile(const std::string& sir, dawn::SIRSerializer::Format format,
const std::list<dawn::PassGroup>& groups, const dawn::Options& optimizerOptions,
dawn::codegen::Backend backend, const dawn::codegen::Options& codegenOptions) {
auto stencilIR = dawn::SIRSerializer::deserializeFromString(sir, format);
return dawn::codegen::generate(
dawn::compile(stencilIR, groups, optimizerOptions, backend, codegenOptions));
std::string compile(const std::string& sir, SIRSerializer::Format format,
const std::list<PassGroup>& groups, const Options& optimizerOptions,
codegen::Backend backend, const codegen::Options& codegenOptions) {
auto stencilIR = SIRSerializer::deserializeFromString(sir, format);
return codegen::generate(compile(stencilIR, groups, optimizerOptions, backend, codegenOptions));
}

} // namespace dawn
8 changes: 8 additions & 0 deletions dawn/src/dawn/Compiler/Driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ run(const std::shared_ptr<SIR>& stencilIR, const std::list<PassGroup>& groups,
const Options& options = {});

/// @brief Lower to IIR and run groups. Use strings in place of C++ structures.
///
/// NOTE: This method always returns the stencil instantiations as json string objects, not
/// bytes, as this greatly simplifies the conversion.
/// See https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html for more details.
std::map<std::string, std::string> run(const std::string& sir, dawn::SIRSerializer::Format format,
const std::list<dawn::PassGroup>& groups,
const dawn::Options& options = {});
Expand All @@ -49,6 +53,10 @@ run(const std::map<std::string, std::shared_ptr<iir::StencilInstantiation>>&
const std::list<PassGroup>& groups, const Options& options = {});

/// @brief Run groups. Use strings in place of C++ structures.
///
/// NOTE: This method always returns the stencil instantiations as json string objects, not bytes,
/// as this greatly simplifies the conversion.
/// See https://pybind11.readthedocs.io/en/stable/advanced/cast/strings.html for more details.
std::map<std::string, std::string>
run(const std::map<std::string, std::string>& stencilInstantiationMap, IIRSerializer::Format format,
const std::list<PassGroup>& groups, const Options& options = {});
Expand Down
1 change: 0 additions & 1 deletion dawn/src/dawn/dawn-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "dawn/Serialization/IIRSerializer.h"
#include "dawn/Serialization/SIRSerializer.h"
#include "dawn/Support/FileSystem.h"
#include "dawn/Support/Json.h"
#include "dawn/Support/Logger.h"

#include <cxxopts.hpp>
Expand Down
1 change: 1 addition & 0 deletions dawn/src/dawn4py/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ _local_

### Dawn ####
_external_src/driver-includes/*
_external_src/interface/*

### Python ###
# Byte-compiled / optimized / DLL files
Expand Down
42 changes: 22 additions & 20 deletions dawn/src/dawn4py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,16 @@ def _serialize_sir(sir: Union[serialization.SIR.SIR, str, bytes]):
elif isinstance(sir, str) and sir.lstrip().startswith("{") and sir.rstrip().endswith("}"):
serializer_format = SIRSerializerFormat.Json
elif not isinstance(sir, bytes):
raise ValueError(f"Unrecognized SIR data format ({sir})")
raise ValueError("Unrecognized SIR data format")
return sir, serializer_format


def _serialize_instantiations(stencil_instantiation_map: dict):
# Determine serializer_format based on first stencil instantiation in the dict
serializer_format = IIRSerializerFormat.Byte
si = stencil_instantiation_map[0]
if len(stencil_instantiation_map) == 0:
raise ValueError("No stencil instantiations found")
si = list(stencil_instantiation_map.values())[0]
if isinstance(si, str) and si.lstrip().startswith("{") and si.rstrip().endswith("}"):
serializer_format = IIRSerializerFormat.Json

Expand All @@ -67,9 +69,13 @@ def _serialize(si, serializer_format):
else:
return si

return {
name: _serialize(si, serializer_format) for name, si in stencil_instantiation_map.items()
}
return (
{
name: _serialize(si, serializer_format)
for name, si in stencil_instantiation_map.items()
},
serializer_format,
)


_OPTIMIZER_OPTIONS = tuple(
Expand Down Expand Up @@ -146,12 +152,10 @@ def lower_and_optimize(
iir_map = _dawn4py.run_optimizer_sir(
sir, sir_format, groups, OptimizerOptions(**optimizer_options)
)
deserializeFcn = (
lambda x: serialization.from_bytes(x, serialization.IIR.StencilInstantiation)
if sir_format == SIRSerializerFormat.Byte
else serialization.from_json(x, serialization.IIR.StencilInstantiation)
)
return {name: deserializeFcn(string) for name, string in iir_map.items()}
return {
name: serialization.from_json(string, serialization.IIR.StencilInstantiation)
for name, string in iir_map.items()
}


def optimize(
Expand All @@ -176,15 +180,13 @@ def optimize(
optimizer_options = {k: v for k, v in kwargs.items() if k in _OPTIMIZER_OPTIONS}

instantiation_map, iir_format = _serialize_instantiations(instantiation_map)
optimized_instantiations = _dawn4py.run_optimizer_sir(
optimized_instantiations = _dawn4py.run_optimizer_iir(
instantiation_map, iir_format, groups, OptimizerOptions(**optimizer_options)
)
deserializeFcn = (
lambda x: serialization.from_bytes(x, serialization.IIR.StencilInstantiation)
if iir_format == IIRSerializerFormat.Byte
else serialization.from_json(x, serialization.IIR.StencilInstantiation)
)
return {name: deserializeFcn(string) for name, string in optimized_instantiations.items()}
return {
name: serialization.from_json(string, serialization.IIR.StencilInstantiation)
for name, string in optimized_instantiations.items()
}


def codegen(
Expand All @@ -197,8 +199,8 @@ def codegen(
----------
instantiation_map:
Stencil instantiation map (values in any valid serialized or non serialized form).
groups:
Optimizer pass groups [defaults to :func:`default_pass_groups()`]
backend:
Code generation backend [defaults to GridTools].
**kwargs
Optional keyword arguments with specific options for the compiler (see :class:`Options`).
Returns
Expand Down
13 changes: 13 additions & 0 deletions dawn/test/unit-test/dawn4py-tests/test_structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,17 @@ def test_compilation(grid_sir_with_reference_code):
dawn4py.CodeGenBackend.CUDA,
):
dawn4py.compile(sir, backend=backend)
dawn4py.codegen(
dawn4py.optimize(
dawn4py.lower_and_optimize(sir, groups=[]),
groups=[
dawn4py.PassGroup.SetStageName,
dawn4py.PassGroup.StageReordering,
# dawn4py.PassGroup.StageMerger,
dawn4py.PassGroup.SetCaches,
dawn4py.PassGroup.SetBlockSize,
],
),
backend=backend,
)
# TODO There was not test here...
8 changes: 7 additions & 1 deletion dawn/test/unit-test/dawn4py-tests/test_unstructured.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,11 @@ def test_sir_serialization(name):

def test_compilation(unstructured_sir_with_reference_code):
sir, reference_code = unstructured_sir_with_reference_code
code = dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)
dawn4py.compile(sir, backend=dawn4py.CodeGenBackend.CXXNaiveIco)
dawn4py.codegen(
dawn4py.optimize(
dawn4py.lower_and_optimize(sir, groups=[]), groups=dawn4py.default_pass_groups()
),
backend=dawn4py.CodeGenBackend.CXXNaiveIco,
)
# TODO There was no test here...

0 comments on commit f656bc3

Please sign in to comment.