Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use cppyy for JIT #2306

Merged
merged 18 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/awkward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import awkward._connect.numpy
import awkward._connect.numexpr
import awkward.numba
import awkward.cppyy
import awkward.jax
import awkward.typetracer
import awkward._typetracer # todo: remove this after "deprecation" period
Expand Down
33 changes: 33 additions & 0 deletions src/awkward/cppyy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import awkward as ak

# Set to True after the installed cppyy version has passed the minimum-version
# check once, so that later calls skip the comparison.
_has_checked_version = False


def register_and_check():
    """
    Ensure that `cppyy` is importable and recent enough for Awkward Array.

    The import is attempted on every call. The version comparison against
    3.0.1 only runs until it has succeeded once; the module-level flag
    `_has_checked_version` records that success.

    Raises:
        ImportError: if `cppyy` cannot be imported (the message contains
            installation instructions), or if the installed version is
            older than 3.0.1.
    """
    global _has_checked_version

    try:
        import cppyy
    except ImportError as exc:
        install_hint = """install the 'cppyy' package with:

pip install cppyy

or

conda install -c conda-forge cppyy

Note that this must be in a different venv or conda environment from ROOT, if you have installed ROOT.
"""
        raise ImportError(install_hint) from exc

    # A previous call already validated the version; nothing more to do.
    if _has_checked_version:
        return

    installed = ak._util.parse_version(cppyy.__version__)
    minimum = ak._util.parse_version("3.0.1")
    if installed < minimum:
        raise ImportError(
            "Awkward Array can only work with cppyy 3.0.1 or later "
            f"(you have version {cppyy.__version__})"
        )

    _has_checked_version = True
40 changes: 40 additions & 0 deletions src/awkward/highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
check_valid=False,
backend=None,
):
self._cpptype = None
if isinstance(data, ak.contents.Content):
layout = data

Expand Down Expand Up @@ -1460,6 +1461,45 @@ def __bool__(self):
)
)

@property
def cpptype(self):
    """
    The C++ type of this Array when it is used in cppyy.

    cpptype (str): Name of the generated C++ class, prefixed with the
        `awkward::` namespace. It is built on first access and cached in
        `self._cpptype`; later accesses return the cached string.

    On first access this also stores the code generator in `self._generator`
    and the lookup built from the layout in `self._lookup`, and passes the
    generated code to `cppyy.cppdef` so that the returned type name can be
    used in C++ source handed to cppyy.

    Raises ImportError (from `ak.cppyy.register_and_check`) if `cppyy` is
    missing or older than the supported version.

    See [cppyy documentation](https://cppyy.readthedocs.io/en/latest/index.html)
    on types and signatures.
    """
    ak.cppyy.register_and_check()

    if self._cpptype is None:
        # Imported only after register_and_check() so that a missing cppyy
        # is reported with the installation instructions.
        import cppyy

        # FIXME: see where and if to keep the lookup
        self._generator = ak._connect.cling.togenerator(
            self.layout.form, flatlist_as_rvec=False
        )
        self._lookup = ak._lookup.Lookup(self.layout)
        # The class name is only meaningful to the C++ interpreter once the
        # generated code has been compiled; without this step a signature
        # that mentions the returned type refers to an undeclared class.
        # NOTE(review): assumes generate() tolerates a form whose class was
        # already declared (e.g. by another Array) -- confirm.
        self._generator.generate(cppyy.cppdef)
        self._cpptype = f"awkward::{self._generator.class_type()}"

    return self._cpptype

def __cast_cpp__(self):
    """
    Build the C++ object that stands for this `ak.Array` inside cppyy.

    cppyy calls `__cast_cpp__` to determine a C++ type of an `ak.Array`. The
    class named by `self._cpptype` is looked up on `cppyy.gbl` and
    instantiated with the length of this Array and the pointers held by
    `self._lookup`. If the type name has not been computed yet, it is
    obtained from the `cpptype` property first.
    """
    type_name = self._cpptype
    if type_name is None:
        # Reading the property builds the type name (and the lookup).
        type_name = self._cpptype = self.cpptype

    import cppyy

    dataset_class = getattr(cppyy.gbl, type_name)
    # NOTE(review): the three literal zeros are passed positionally around the
    # length and the pointer array; their meaning is defined by the generated
    # class constructor -- confirm against the generator.
    return dataset_class(0, len(self), 0, self._lookup.arrayptrs, 0)


class Record(NDArrayOperatorsMixin):
"""
Expand Down
118 changes: 118 additions & 0 deletions tests/test_2306_cppyy_git.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE


import pytest
import setuptools

import awkward as ak
import awkward._connect.cling

cppyy = pytest.importorskip("cppyy")


def test_array_as_generated_dataset():
    """Sum the nested `y` field in C++ through an explicitly generated dataset.

    The generator's code is compiled with `cppyy.cppdef` before the function
    that uses `generator.dataset()` is defined, and the C++ result is checked
    against `ak.sum` over the same field.
    """
    array = ak.Array(
        [
            [{"x": 1, "y": [1.1]}, {"x": 2, "y": [2.2, 0.2]}],
            [],
            [{"x": 3, "y": [3.0, 0.3, 3.3]}],
        ]
    )

    generator = ak._connect.cling.togenerator(array.layout.form, flatlist_as_rvec=False)
    lookup = ak._lookup.Lookup(array.layout)

    source_code = f"""
double go_fast(ssize_t length, ssize_t* ptrs) {{
auto awkward_array = {generator.dataset()};
double out = 0.0;

for (auto list : awkward_array) {{
for (auto record : list) {{
for (auto item : record.y()) {{
out += item;
}}
}}
}}

return out;
}}
"""

    generator.generate(cppyy.cppdef)
    cppyy.cppdef(source_code)
    out = cppyy.gbl.go_fast(len(array), lookup.arrayptrs)
    # The two sums are accumulated by different implementations, so compare
    # with a tolerance instead of requiring bit-identical doubles.
    assert out == pytest.approx(ak.sum(array["y"]))


# The version check uses the same helper as ak.cppyy.register_and_check rather
# than setuptools' private vendored `packaging`, which is not available in
# every setuptools release.
@pytest.mark.skipif(
    ak._util.parse_version(cppyy.__version__) < ak._util.parse_version("3.0.1"),
    reason="Awkward Array can only work with cppyy 3.0.1 or later.",
)
def test_array_as_type():
    """Pass an `ak.Array` to a C++ function whose parameter type is `array.cpptype`."""
    array = ak.Array(
        [
            [{"x": 1, "y": [1.1]}, {"x": 2, "y": [2.2, 0.2]}],
            [],
            [{"x": 3, "y": [3.0, 0.3, 3.3]}],
        ]
    )

    source_code_cpp = f"""
double go_fast_cpp({array.cpptype} awkward_array) {{
double out = 0.0;

for (auto list : awkward_array) {{
for (auto record : list) {{
for (auto item : record.y()) {{
out += item;
}}
}}
}}

return out;
}}
"""

    cppyy.cppdef(source_code_cpp)

    out = cppyy.gbl.go_fast_cpp(array)
    # The two sums are accumulated by different implementations, so compare
    # with a tolerance instead of requiring bit-identical doubles.
    assert out == pytest.approx(ak.sum(array["y"]))


# The version check uses the same helper as ak.cppyy.register_and_check rather
# than setuptools' private vendored `packaging`, which is not available in
# every setuptools release.
@pytest.mark.skipif(
    ak._util.parse_version(cppyy.__version__) < ak._util.parse_version("3.0.1"),
    reason="Awkward Array can only work with cppyy 3.0.1 or later.",
)
def test_array_as_templated_type():
    """Instantiate a C++ function template with `array.cpptype` and call it with the Array."""
    array = ak.Array(
        [
            [{"x": 1, "y": [1.1]}, {"x": 2, "y": [2.2, 0.2]}],
            [],
            [{"x": 3, "y": [3.0, 0.3, 3.3]}],
        ]
    )

    source_code_cpp = """
template<typename T>
double go_fast_cpp_2(T& awkward_array) {
double out = 0.0;

for (auto list : awkward_array) {
for (auto record : list) {
for (auto item : record.y()) {
out += item;
}
}
}

return out;
}
"""

    cppyy.cppdef(source_code_cpp)

    out = cppyy.gbl.go_fast_cpp_2[array.cpptype](array)
    # The two sums are accumulated by different implementations, so compare
    # with a tolerance instead of requiring bit-identical doubles.
    assert out == pytest.approx(ak.sum(array["y"]))