diff --git a/.github/action/Dockerfile b/.github/action/Dockerfile index 9279506..2b9f297 100644 --- a/.github/action/Dockerfile +++ b/.github/action/Dockerfile @@ -1,10 +1,10 @@ FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 RUN apt-get update && \ - DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip cmake + DEBIAN_FRONTEND=noninteractive apt-get install -y git python3-pip RUN pip install --upgrade pip && \ - pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html COPY entrypoint.sh /entrypoint.sh diff --git a/.github/action/entrypoint.sh b/.github/action/entrypoint.sh index 6f03702..45aff67 100755 --- a/.github/action/entrypoint.sh +++ b/.github/action/entrypoint.sh @@ -1,6 +1,6 @@ #!/bin/sh -l cd /github/workspace -KEPLER_JAX_CUDA=yes python3 -m pip install . +KEPLER_JAX_CUDA=yes python3 -m pip install -v . python3 -c 'import kepler_jax;print(kepler_jax.__version__)' python3 -c 'import kepler_jax.gpu_ops' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 266a102..668f95b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,7 +16,7 @@ jobs: os: [ubuntu-latest, macos-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install -U pip - python -m pip install .[test] + python -m pip install -v .[test] - name: Run tests run: python -m pytest -v tests @@ -37,7 +37,7 @@ jobs: name: CUDA runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - uses: ./.github/action diff --git a/CMakeLists.txt b/CMakeLists.txt index ba09e84..9b75854 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,25 +1,29 @@ -cmake_minimum_required(VERSION 3.12...3.18) -project(kepler_jax LANGUAGES CXX) +cmake_minimum_required(VERSION 3.15...3.26) +project(${SKBUILD_PROJECT_NAME} LANGUAGES C CXX) +message(STATUS "Using CMake version: " ${CMAKE_VERSION}) -message(STATUS "Using CMake version " ${CMAKE_VERSION}) - -find_package(Python COMPONENTS Interpreter Development REQUIRED) +# Find pybind11 +set(PYBIND11_NEWPYTHON ON) find_package(pybind11 CONFIG REQUIRED) include_directories(${CMAKE_CURRENT_LIST_DIR}/lib) # CPU op library pybind11_add_module(cpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/cpu_ops.cc) -install(TARGETS cpu_ops DESTINATION kepler_jax) +install(TARGETS cpu_ops LIBRARY DESTINATION .) + +# Include the CUDA extensions if possible +include(CheckLanguage) +check_language(CUDA) -if (KEPLER_JAX_CUDA) +if(CMAKE_CUDA_COMPILER) enable_language(CUDA) include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) pybind11_add_module( gpu_ops ${CMAKE_CURRENT_LIST_DIR}/lib/kernels.cc.cu ${CMAKE_CURRENT_LIST_DIR}/lib/gpu_ops.cc) - install(TARGETS gpu_ops DESTINATION kepler_jax) + install(TARGETS gpu_ops LIBRARY DESTINATION .) else() message(STATUS "Building without CUDA") endif() diff --git a/README.md b/README.md index 601e6ef..19fab54 100644 --- a/README.md +++ b/README.md @@ -142,11 +142,8 @@ wanted to emphasize a few points to consider. The files in this repo come in three categories: 1. In the root directory, there are the standard packaging files like a - `setup.py` and `pyproject.toml`. Most of this setup is pretty standard, but - I'll highlight some of the unique elements in the packaging section below. - For example, we'll use a slightly strange combination of PEP-517/518 and - CMake to build the extensions. This isn't strictly necessary, but it's the - easiest packaging setup that I've been able to put together. + `pyproject.toml`. Most of this setup is pretty standard, but + I'll highlight some unique elements in the packaging section below. 2. Next, the `src/kepler_jax` directory is a Python module with the definition of our JAX primitive roughly following the JAX [How primitives @@ -260,17 +257,12 @@ pybind11 since that's what I'm most familiar with. The [LAPACK ops in jaxlib][jaxlib-lapack] are implemented using Cython if you'd like to see an example of how to do that. -Another choice that I've made is to use [CMake](https://cmake.org) to build the -extensions. It would be totally possible (and perhaps preferable if you only -support CPU usage) to stick to just using setuptools directly, but setuptools -doesn't seem to have great support for compiling CUDA extensions so that's why I -settled on CMake. In the end, it's not too painful since CMake can be included -as a build dependency in `pyproject.toml` so users won't have to install it -separately. Another build option would be to use [bazel](https://bazel.build) to -compile the code, like the JAX project, but I don't have any experience with it -so I decided to stick with what I know. _The key point is that we're just -compiling a regular old Python module so you can use whatever infrastructure -you're familiar with!_ +Another choice that I've made is to use [scikit-build-core](scikit-build-core) +and [CMake](https://cmake.org) to build the extensions. Another build option +would be to use [bazel](https://bazel.build) to compile the code, like the JAX +project, but I don't have any experience with it, so I decided to stick with +what I know. _The key point is that we're just compiling a regular old Python +module, so you can use whatever infrastructure you're familiar with!_ With these choices out of the way, the boilerplate code required to define the interface is, using the `cpu_kepler` function defined in the previous section as @@ -305,20 +297,13 @@ this. With that out of the way, the actual build routine is defined in the following files: -- In `./pyproject.toml`, we specify that `pybind11` and `cmake` are required - build dependencies and that we'll use `setuptools.build_meta` as the build - backend. +- In `./pyproject.toml`, we specify that `pybind11` and `scikit-build-core` are + required build dependencies and that we'll use `scikit-build-core` as the + build backend. -- `setup.py` is a pretty typical setup file with a custom class for building the - extensions that executes CMake for the actual compilation step. This does - include some extra configuration arguments for CMake to make sure that it uses - the correct Python libraries and installs the compiled objects to the right - place. It might be possible to use something like [scikit-build][scikit-build] - to replace this step, but I struggled to get it working. - -- Finally, `CMakeLists.txt` defines the build process for CMake using - [pybind11's support for CMake builds][pybind11-cmake]. This will also, - optionally, build the GPU ops as discussed below. +- Then, `CMakeLists.txt` defines the build process for CMake using [pybind11's + support for CMake builds][pybind11-cmake]. This will also, optionally, build + the GPU ops as discussed below. With these files in place, we can now compile our XLA custom call ops using @@ -605,18 +590,12 @@ __global__ void kepler_kernel( ## Building & packaging for the GPU Since we're already using CMake to build our project, it's not too hard to add -support for CUDA. I've chosen to enable GPU builds by the environment variable -`KEPLER_JAX_CUDA=yes` that you'll see in both `setup.py` and `CMakeLists.txt`. -Other than conditionally adding an `Extension` in `setup.py`, everything else on -the Python side is the same. In `CMakeLists.txt`, we also add a conditional: +support for CUDA. I've chosen to enable GPU builds whenever CMake can detect +CUDA support using `CheckLanguage` in `CMakelists.txt`: ```cmake -if (KEPLER_JAX_CUDA) - enable_language(CUDA) - # ... -else() - message(STATUS "Building without CUDA") -endif() +include(CheckLanguage) +check_language(CUDA) ``` Then, to expose this to JAX, we need to update the translation rule from above as follows: @@ -700,6 +679,6 @@ Colab: [kepler-h]: https://github.com/dfm/extending-jax/blob/main/lib/kepler.h [capsule]: https://docs.python.org/3/c-api/capsule.html "Capsules" [jaxlib-lapack]: https://github.com/google/jax/blob/master/jaxlib/lapack.pyx "jax/lapack.pyx" -[scikit-build]: https://scikit-build.readthedocs.io/ "scikit-build" +[scikit-build-core]: https://github.com/scikit-build/scikit-build-core "scikit-build-core" [pybind11-cmake]: https://pybind11.readthedocs.io/en/stable/compiling.html#building-with-cmake "Building with CMake" [exoplanet-tutorial]: https://docs.exoplanet.codes/en/stable/tutorials/intro-to-pymc3/#A-more-realistic-example:-radial-velocity-exoplanets "A more realistic example: radial velocity exoplanets" diff --git a/pyproject.toml b/pyproject.toml index a82478d..38f83b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,27 @@ +[project] +name = "kepler_jax" +description = "A simple demonstration of how you can extend JAX with custom C++ and CUDA ops" +readme = "README.md" +authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] +requires-python = ">=3.9" +license = { file = "LICENSE" } +urls = { Homepage = "https://github.com/dfm/extending-jax" } +dependencies = ["jax>=0.4.16", "jaxlib>=0.4.16"] +dynamic = ["version"] + +[project.optional-dependencies] +test = ["pytest"] + [build-system] -requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4", "pybind11>=2.6", "cmake"] -build-backend = "setuptools.build_meta" +requires = ["pybind11>=2.6", "scikit-build-core>=0.5"] +build-backend = "scikit_build_core.build" + +[tool.scikit-build] +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +sdist.include = ["src/kepler_jax/kepler_jax_version.py"] +wheel.install-dir = "kepler_jax" +minimum-version = "0.5" +build-dir = "build/{wheel_tag}" [tool.setuptools_scm] write_to = "src/kepler_jax/kepler_jax_version.py" diff --git a/setup.py b/setup.py deleted file mode 100644 index 7e9a341..0000000 --- a/setup.py +++ /dev/null @@ -1,130 +0,0 @@ -#!/usr/bin/env python - -import codecs -import os -import subprocess - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - -HERE = os.path.dirname(os.path.realpath(__file__)) - - -def read(*parts): - with codecs.open(os.path.join(HERE, *parts), "rb", "utf-8") as f: - return f.read() - - -# This custom class for building the extensions uses CMake to compile. You -# don't have to use CMake for this task, but I found it to be the easiest when -# compiling ops with GPU support since setuptools doesn't have great CUDA -# support. -class CMakeBuildExt(build_ext): - def build_extensions(self): - # First: configure CMake build - import platform - import sys - import distutils.sysconfig - - import pybind11 - - # Work out the relevant Python paths to pass to CMake, adapted from the - # PyTorch build system - if platform.system() == "Windows": - cmake_python_library = "{}/libs/python{}.lib".format( - distutils.sysconfig.get_config_var("prefix"), - distutils.sysconfig.get_config_var("VERSION"), - ) - if not os.path.exists(cmake_python_library): - cmake_python_library = "{}/libs/python{}.lib".format( - sys.base_prefix, - distutils.sysconfig.get_config_var("VERSION"), - ) - else: - cmake_python_library = "{}/{}".format( - distutils.sysconfig.get_config_var("LIBDIR"), - distutils.sysconfig.get_config_var("INSTSONAME"), - ) - cmake_python_include_dir = distutils.sysconfig.get_python_inc() - - install_dir = os.path.abspath( - os.path.dirname(self.get_ext_fullpath("dummy")) - ) - os.makedirs(install_dir, exist_ok=True) - cmake_args = [ - "-DCMAKE_INSTALL_PREFIX={}".format(install_dir), - "-DPython_EXECUTABLE={}".format(sys.executable), - "-DPython_LIBRARIES={}".format(cmake_python_library), - "-DPython_INCLUDE_DIRS={}".format(cmake_python_include_dir), - "-DCMAKE_BUILD_TYPE={}".format( - "Debug" if self.debug else "Release" - ), - "-DCMAKE_PREFIX_PATH={}".format(pybind11.get_cmake_dir()), - ] - if os.environ.get("KEPLER_JAX_CUDA", "no").lower() == "yes": - cmake_args.append("-DKEPLER_JAX_CUDA=yes") - - os.makedirs(self.build_temp, exist_ok=True) - subprocess.check_call( - ["cmake", HERE] + cmake_args, cwd=self.build_temp - ) - - # Build all the extensions - super().build_extensions() - - # Finally run install - subprocess.check_call( - ["cmake", "--build", ".", "--target", "install"], - cwd=self.build_temp, - ) - - def build_extension(self, ext): - target_name = ext.name.split(".")[-1] - subprocess.check_call( - ["cmake", "--build", ".", "--target", target_name], - cwd=self.build_temp, - ) - - -extensions = [ - Extension( - "kepler_jax.cpu_ops", - ["src/kepler_jax/src/cpu_ops.cc"], - ), -] - -if os.environ.get("KEPLER_JAX_CUDA", "no").lower() == "yes": - extensions.append( - Extension( - "kepler_jax.gpu_ops", - [ - "src/kepler_jax/src/gpu_ops.cc", - "src/kepler_jax/src/cuda_kernels.cc.cu", - ], - ) - ) - - -setup( - name="kepler_jax", - author="Dan Foreman-Mackey", - author_email="foreman.mackey@gmail.com", - url="https://github.com/dfm/extending-jax", - license="MIT", - description=( - "A simple demonstration of how you can extend JAX with custom C++ and " - "CUDA ops" - ), - long_description=read("README.md"), - long_description_content_type="text/markdown", - packages=find_packages("src"), - package_dir={"": "src"}, - include_package_data=True, - install_requires=[ - "jax>=0.4.16", - "jaxlib>=0.4.16" - ], - extras_require={"test": "pytest"}, - ext_modules=extensions, - cmdclass={"build_ext": CMakeBuildExt}, -) diff --git a/tests/test_kepler_jax.py b/tests/test_kepler_jax.py index 6703bb8..41a0ada 100644 --- a/tests/test_kepler_jax.py +++ b/tests/test_kepler_jax.py @@ -4,13 +4,12 @@ import pytest import jax -from jax.config import config from jax.test_util import check_grads from kepler_jax import kepler -config.update("jax_enable_x64", True) +jax.config.update("jax_enable_x64", True) @pytest.fixture(params=[np.float32, np.float64])