Skip to content

Commit

Permalink
Migrating to scikit-build-core (#12)
Browse files Browse the repository at this point in the history
* migrating to scikit-build-core

* install jaxlib

* dockerfile updates

* verbose builds

* update actions versions

* update deprecated config call

* Updating terminology in README
  • Loading branch information
dfm authored Nov 3, 2023
1 parent 388c8ba commit c338696
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 188 deletions.
4 changes: 2 additions & 2 deletions .github/action/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt-get update && \
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake
DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip

RUN pip install --upgrade pip && \
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

COPY entrypoint.sh /entrypoint.sh

Expand Down
2 changes: 1 addition & 1 deletion .github/action/entrypoint.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/sh -l

cd /github/workspace
KEPLER_JAX_CUDA=yes python3 -m pip install .
KEPLER_JAX_CUDA=yes python3 -m pip install -v .
python3 -c 'import kepler_jax;print(kepler_jax.__version__)'
python3 -c 'import kepler_jax.gpu_ops'
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
os: [ubuntu-latest, macos-latest]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0

Expand All @@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install .[test]
python -m pip install -v .[test]
- name: Run tests
run: python -m pytest -v tests
Expand All @@ -37,7 +37,7 @@ jobs:
name: CUDA
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: ./.github/action
20 changes: 12 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
cmake_minimum_required(VERSION 3.12...3.18)
project(kepler_jax LANGUAGES CXX)
cmake_minimum_required(VERSION 3.15...3.26)
project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX)
message(STATUS "Using CMake version: " ${CMAKE_VERSION})

message(STATUS "Using CMake version " ${CMAKE_VERSION})

find_package(Python COMPONENTS Interpreter Development REQUIRED)
# Find pybind11
set(PYBIND11_NEWPYTHON ON)
find_package(pybind11 CONFIG REQUIRED)

include_directories(${CMAKE_CURRENT_LIST_DIR}/lib)

# CPU op library
pybind11_add_module(cpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/cpu_ops.cc)
install(TARGETS cpu_ops DESTINATION kepler_jax)
install(TARGETS cpu_ops LIBRARY DESTINATION .)

# Include the CUDA extensions if possible
include(CheckLanguage)
check_language(CUDA)

if (KEPLER_JAX_CUDA)
if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
pybind11_add_module(
gpu_ops
${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu
${CMAKE_CURRENT_LIST_DIR}/lib/gpu_ops.cc)
install(TARGETS gpu_ops DESTINATION kepler_jax)
install(TARGETS gpu_ops LIBRARY DESTINATION .)
else()
message(STATUS "Building without CUDA")
endif()
59 changes: 19 additions & 40 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,8 @@ wanted to emphasize a few points to consider.
The files in this repo come in three categories:

1. In the root directory, there are the standard packaging files like a
`setup.py` and `pyproject.toml`. Most of this setup is pretty standard, but
I'll highlight some of the unique elements in the packaging section below.
For example, we'll use a slightly strange combination of PEP-517/518 and
CMake to build the extensions. This isn't strictly necessary, but it's the
easiest packaging setup that I've been able to put together.
`pyproject.toml`. Most of this setup is pretty standard, but
I'll highlight some unique elements in the packaging section below.

2. Next, the `src/kepler_jax` directory is a Python module with the definition
of our JAX primitive roughly following the JAX [How primitives
Expand Down Expand Up @@ -260,17 +257,12 @@ pybind11 since that's what I'm most familiar with. The [LAPACK ops in
jaxlib][jaxlib-lapack] are implemented using Cython if you'd like to see an
example of how to do that.

Another choice that I've made is to use [CMake](https://cmake.org) to build the
extensions. It would be totally possible (and perhaps preferable if you only
support CPU usage) to stick to just using setuptools directly, but setuptools
doesn't seem to have great support for compiling CUDA extensions so that's why I
settled on CMake. In the end, it's not too painful since CMake can be included
as a build dependency in `pyproject.toml` so users won't have to install it
separately. Another build option would be to use [bazel](https://bazel.build) to
compile the code, like the JAX project, but I don't have any experience with it
so I decided to stick with what I know. _The key point is that we're just
compiling a regular old Python module so you can use whatever infrastructure
you're familiar with!_
Another choice that I've made is to use [scikit-build-core](scikit-build-core)
and [CMake](https://cmake.org) to build the extensions. Another build option
would be to use [bazel](https://bazel.build) to compile the code, like the JAX
project, but I don't have any experience with it, so I decided to stick with
what I know. _The key point is that we're just compiling a regular old Python
module, so you can use whatever infrastructure you're familiar with!_

With these choices out of the way, the boilerplate code required to define the
interface is, using the `cpu_kepler` function defined in the previous section as
Expand Down Expand Up @@ -305,20 +297,13 @@ this.
With that out of the way, the actual build routine is defined in the following
files:
- In `./pyproject.toml`, we specify that `pybind11` and `cmake` are required
build dependencies and that we'll use `setuptools.build_meta` as the build
backend.
- In `./pyproject.toml`, we specify that `pybind11` and `scikit-build-core` are
required build dependencies and that we'll use `scikit-build-core` as the
build backend.
- `setup.py` is a pretty typical setup file with a custom class for building the
extensions that executes CMake for the actual compilation step. This does
include some extra configuration arguments for CMake to make sure that it uses
the correct Python libraries and installs the compiled objects to the right
place. It might be possible to use something like [scikit-build][scikit-build]
to replace this step, but I struggled to get it working.
- Finally, `CMakeLists.txt` defines the build process for CMake using
[pybind11's support for CMake builds][pybind11-cmake]. This will also,
optionally, build the GPU ops as discussed below.
- Then, `CMakeLists.txt` defines the build process for CMake using [pybind11's
support for CMake builds][pybind11-cmake]. This will also, optionally, build
the GPU ops as discussed below.
With these files in place, we can now compile our XLA custom call ops using
Expand Down Expand Up @@ -605,18 +590,12 @@ __global__ void kepler_kernel(
## Building & packaging for the GPU

Since we're already using CMake to build our project, it's not too hard to add
support for CUDA. I've chosen to enable GPU builds by the environment variable
`KEPLER_JAX_CUDA=yes` that you'll see in both `setup.py` and `CMakeLists.txt`.
Other than conditionally adding an `Extension` in `setup.py`, everything else on
the Python side is the same. In `CMakeLists.txt`, we also add a conditional:
support for CUDA. I've chosen to enable GPU builds whenever CMake can detect
CUDA support using `CheckLanguage` in `CMakelists.txt`:

```cmake
if (KEPLER_JAX_CUDA)
enable_language(CUDA)
# ...
else()
message(STATUS "Building without CUDA")
endif()
include(CheckLanguage)
check_language(CUDA)
```

Then, to expose this to JAX, we need to update the translation rule from above as follows:
Expand Down Expand Up @@ -700,6 +679,6 @@ Colab:
[kepler-h]: https://github.com/dfm/extending-jax/blob/main/lib/kepler.h
[capsule]: https://docs.python.org/3/c-api/capsule.html "Capsules"
[jaxlib-lapack]: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx "jax/lapack.pyx"
[scikit-build]: https://scikit-build.readthedocs.io/ "scikit-build"
[scikit-build-core]: https://github.com/scikit-build/scikit-build-core "scikit-build-core"
[pybind11-cmake]: https://pybind11.readthedocs.io/en/stable/compiling.html#building-with-cmake "Building with CMake"
[exoplanet-tutorial]: https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets "A more realistic example: radial velocity exoplanets"
25 changes: 23 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
[project]
name = "kepler_jax"
description = "A simple demonstration of how you can extend JAX with custom C++ and CUDA ops"
readme = "README.md"
authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }]
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { Homepage = "https://github.com/dfm/extending-jax" }
dependencies = ["jax>=0.4.16", "jaxlib>=0.4.16"]
dynamic = ["version"]

[project.optional-dependencies]
test = ["pytest"]

[build-system]
requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4", "pybind11>=2.6", "cmake"]
build-backend = "setuptools.build_meta"
requires = ["pybind11>=2.6", "scikit-build-core>=0.5"]
build-backend = "scikit_build_core.build"

[tool.scikit-build]
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
sdist.include = ["src/kepler_jax/kepler_jax_version.py"]
wheel.install-dir = "kepler_jax"
minimum-version = "0.5"
build-dir = "build/{wheel_tag}"

[tool.setuptools_scm]
write_to = "src/kepler_jax/kepler_jax_version.py"
130 changes: 0 additions & 130 deletions setup.py

This file was deleted.

3 changes: 1 addition & 2 deletions tests/test_kepler_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import pytest

import jax
from jax.config import config
from jax.test_util import check_grads

from kepler_jax import kepler


config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


@pytest.fixture(params=[np.float32, np.float64])
Expand Down

0 comments on commit c338696

Please sign in to comment.