Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): add options to use TensorFlow C library to build the JAX backend #4357

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/install/easy-install.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ pip install deepmd-kit[jax]

:::::

To generate a SavedModel and use [the LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md),
you need to install the TensorFlow.
Switch to the TensorFlow {{ tensorflow_icon }} tab for more information.

::::::

:::::::
Expand Down
4 changes: 2 additions & 2 deletions doc/install/install-from-c-library.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Install from pre-compiled C library {{ tensorflow_icon }}
# Install from pre-compiled C library {{ tensorflow_icon }}, JAX {{ jax_icon }}

:::{note}
**Supported backends**: TensorFlow {{ tensorflow_icon }}
**Supported backends**: TensorFlow {{ tensorflow_icon }}, JAX {{ jax_icon }}
:::

DeePMD-kit provides pre-compiled C library package (`libdeepmd_c.tar.gz`) in each [release](https://github.com/deepmodeling/deepmd-kit/releases). It can be used to build the [LAMMPS plugin](./install-lammps.md) and [GROMACS patch](./install-gromacs.md), as well as many [third-party software packages](../third-party/out-of-deepmd-kit.md), without building TensorFlow and DeePMD-kit on one's own.
Expand Down
31 changes: 31 additions & 0 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,15 @@ You can also download libtorch prebuilt library from the [PyTorch website](https

:::

:::{tab-item} JAX {{ jax_icon }}

The JAX backend only depends on the TensorFlow C API, which is included in both TensorFlow C++ library and [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
If you want to use the TensorFlow C++ library, just enable the TensorFlow backend (which depends on the TensorFlow C++ library) and nothing else needs to do.
If you want to use the TensorFlow C library and disable the TensorFlow backend,
download the TensorFlow C library from [this page](https://www.tensorflow.org/install/lang_c#download_and_extract).

:::

::::

### Install DeePMD-kit's C++ interface
Expand Down Expand Up @@ -369,6 +378,17 @@ cmake -DENABLE_PYTORCH=TRUE -DUSE_PT_PYTHON_LIBS=TRUE -DCMAKE_INSTALL_PREFIX=$de

:::

:::{tab-item} JAX {{ jax_icon }}

If you want to use the TensorFlow C++ library, just enable the TensorFlow backend and nothing else needs to do.
If you want to use the TensorFlow C library and disable the TensorFlow backend, set {cmake:variable}`ENABLE_JAX` to `ON` and `CMAKE_PREFIX_PATH` to the root directory of the [TensorFlow C library](https://www.tensorflow.org/install/lang_c).

```bash
cmake -DENABLE_JAX=ON -D CMAKE_PREFIX_PATH=${tensorflow_c_root} ..
```

:::

::::

One may add the following CMake variables to `cmake` using the [`-D <var>=<value>` option](https://cmake.org/cmake/help/latest/manual/cmake.1.html#cmdoption-cmake-D):
Expand All @@ -378,6 +398,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value
**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
Setting this option to `ON` will also set {cmake:variable}`ENABLE_JAX` to `ON`.

:::

Expand All @@ -389,6 +410,16 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

:::

:::{cmake:variable} ENABLE_JAX

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ jax_icon }} Build the JAX backend.
If {cmake:variable}`ENABLE_TENSORFLOW` is `ON`, the TensorFlow C++ library is used to build the JAX backend;
If {cmake:variable}`ENABLE_TENSORFLOW` is `OFF`, the TensorFlow C library is used to build the JAX backend.

:::

:::{cmake:variable} TENSORFLOW_ROOT

**Type**: `PATH`
Expand Down
25 changes: 25 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ project(DeePMD)

option(ENABLE_TENSORFLOW "Enable TensorFlow interface" OFF)
option(ENABLE_PYTORCH "Enable PyTorch interface" OFF)
option(ENABLE_JAX "Enable JAX interface" OFF)
if(ENABLE_TENSORFLOW)
# JAX requires TF C interface, contained in TF C++ library
set(ENABLE_JAX ON)
endif()
njzjz marked this conversation as resolved.
Show resolved Hide resolved
option(BUILD_TESTING "Build test and enable coverage" OFF)
set(DEEPMD_C_ROOT
""
Expand Down Expand Up @@ -246,6 +251,22 @@ if(ENABLE_PYTORCH AND NOT DEEPMD_C_ROOT)
list(APPEND BACKEND_LIBRARY_PATH ${PyTorch_LIBRARY_PATH})
list(APPEND BACKEND_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS})
endif()
if(ENABLE_JAX
AND BUILD_CPP_IF
AND NOT DEEPMD_C_ROOT)
# no way to find it using Python
find_package(TensorFlowC REQUIRED MODULE)
if(DEFINED TENSORFLOWC_LIBRARY)
list(APPEND BACKEND_LIBRARY_PATH ${TENSORFLOWC_LIBRARY})
endif()
if(DEFINED TENSORFLOWC_INCLUDE_DIR)
list(APPEND BACKEND_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
endif()
endif()
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if(NOT DEFINED OP_CXX_ABI)
# prevent setting an empty value; this is default on GCC>=5
set(OP_CXX_ABI 1)
endif()
# log enabled backends
if(NOT DEEPMD_C_ROOT)
message(STATUS "Enabled backends:")
Expand All @@ -255,8 +276,12 @@ if(NOT DEEPMD_C_ROOT)
if(ENABLE_PYTORCH)
message(STATUS "- PyTorch")
endif()
if(ENABLE_JAX)
message(STATUS "- JAX")
endif()
if(NOT ENABLE_TENSORFLOW
AND NOT ENABLE_PYTORCH
AND NOT ENABLE_JAX
AND NOT BUILD_PY_IF)
message(FATAL_ERROR "No backend is enabled.")
endif()
Expand Down
4 changes: 4 additions & 0 deletions source/api_cc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ if(ENABLE_PYTORCH
target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}")
target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH)
endif()
if(ENABLE_JAX)
target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_c)
target_compile_definitions(${libname} PRIVATE BUILD_JAX)
endif()

target_include_directories(
${libname}
Expand Down
6 changes: 4 additions & 2 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
#include "AtomMap.h"
#include "common.h"
#ifdef BUILD_TENSORFLOW
#include "DeepPotJAX.h"
#include "DeepPotTF.h"
#endif
#ifdef BUILD_PYTORCH
#include "DeepPotPT.h"
#endif
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
#include "DeepPotJAX.h"
#endif
#include "device.h"

using namespace deepmd;
Expand Down Expand Up @@ -63,7 +65,7 @@ void DeepPot::init(const std::string& model,
} else if (deepmd::DPBackend::Paddle == backend) {
throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
} else if (deepmd::DPBackend::JAX == backend) {
#ifdef BUILD_TENSORFLOW
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
#else
throw deepmd::deepmd_exception(
Expand Down
2 changes: 1 addition & 1 deletion source/api_cc/src/DeepPotJAX.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#ifdef BUILD_TENSORFLOW
#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
njzjz marked this conversation as resolved.
Show resolved Hide resolved

#include "DeepPotJAX.h"

Expand Down
40 changes: 40 additions & 0 deletions source/cmake/FindTensorFlowC.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Find TensorFlow C library (libtensorflow) Define target
# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY

if(TARGET TensorFlow::tensorflow_cc)
# since tensorflow_cc contain tensorflow_c, just use it
add_library(TensorFlow::tensorflow_c ALIAS TensorFlow::tensorflow_cc)
set(TensorFlowC_FOUND TRUE)
endif()

if(NOT TensorFlowC_FOUND)
find_path(
TENSORFLOWC_INCLUDE_DIR
NAMES tensorflow/c/c_api.h
PATH_SUFFIXES include
DOC "Path to TensorFlow C include directory")

find_library(
TENSORFLOWC_LIBRARY
NAMES tensorflow
PATH_SUFFIXES lib
DOC "Path to TensorFlow C library")

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
TensorFlowC REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)

if(TensorFlowC_FOUND)
set(TensorFlowC_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
set(TensorFlowC_LIBRARIES ${TENSORFLOWC_LIBRARY})
endif()

add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION
${TENSORFLOWC_LIBRARY})
target_include_directories(TensorFlow::tensorflow_c
INTERFACE ${TENSORFLOWC_INCLUDE_DIR})

mark_as_advanced(TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)
endif()