feat(jax): add options to use TensorFlow C library to build the JAX backend #4357
…ackend Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 Walkthrough
The pull request updates the installation documentation for DeePMD-kit to clarify installation instructions and requirements, particularly regarding the JAX backend and TensorFlow. It expands the supported platforms to include Windows x86-64 and introduces a new CMake option for enabling JAX support. The changes also enhance backend integration within the source code, ensuring compatibility with both TensorFlow and JAX while maintaining existing functionalities.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (19)
source/cmake/FindTensorFlowC.cmake (5)
1-4: Improve module documentation format and clarity.
The documentation should follow CMake's standard format and clearly describe the module's behavior, variables, and targets.
-# Find TensorFlow C library (libtensorflow) Define target
-# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
-# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY
+#[=======================================================================[.rst:
+FindTensorFlowC
+---------------
+
+Finds the TensorFlow C library (libtensorflow).
+
+Imported Targets
+^^^^^^^^^^^^^^^
+
+This module provides the following imported targets, if found:
+
+``TensorFlow::tensorflow_c``
+  The TensorFlow C library
+
+Result Variables
+^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``TensorFlowC_FOUND``
+  True if the system has the TensorFlow C library.
+``TENSORFLOWC_INCLUDE_DIR``
+  Include directory for TensorFlow C headers (when tensorflow_cc not found).
+``TENSORFLOWC_LIBRARY``
+  The path to TensorFlow C library (when tensorflow_cc not found).
+#]=======================================================================]
5-11: Add validation and improve documentation for tensorflow_cc usage.
While the logic is correct, it would be beneficial to:
- Document why tensorflow_cc contains tensorflow_c functionality
- Validate that tensorflow_cc actually provides C API functionality
 if(TARGET TensorFlow::tensorflow_cc)
-  # since tensorflow_cc contain tensorflow_c, just use it
+  # TensorFlow C++ library (tensorflow_cc) includes C API functionality
+  # Verify C API header is available in tensorflow_cc
+  get_target_property(_tf_cc_include TensorFlow::tensorflow_cc INTERFACE_INCLUDE_DIRECTORIES)
+  if (EXISTS "${_tf_cc_include}/tensorflow/c/c_api.h")
   add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
   target_link_libraries(TensorFlow::tensorflow_c INTERFACE TensorFlow::tensorflow_cc)
   set(TensorFlowC_FOUND TRUE)
+  else()
+    message(STATUS "TensorFlow::tensorflow_cc found but missing C API headers")
+  endif()
 endif()
13-33: Enhance library discovery with version checking and additional search paths.
Consider adding:
- Version detection and compatibility check
- Common installation paths for different platforms
- Environment variable support for custom installations
 if(NOT TensorFlowC_FOUND)
+  # Support custom install locations via environment variables
+  set(_TF_SEARCH_PATHS
+    $ENV{TENSORFLOW_ROOT}
+    $ENV{TENSORFLOW_HOME}
+    /usr/local
+    /usr
+  )
+
   find_path(
     TENSORFLOWC_INCLUDE_DIR
     NAMES tensorflow/c/c_api.h
+    PATHS ${_TF_SEARCH_PATHS}
     PATH_SUFFIXES include
     DOC "Path to TensorFlow C include directory")
   find_library(
     TENSORFLOWC_LIBRARY
     NAMES tensorflow
+    PATHS ${_TF_SEARCH_PATHS}
     PATH_SUFFIXES lib lib64
     DOC "Path to TensorFlow C library")
+  if(TENSORFLOWC_INCLUDE_DIR)
+    # Extract version from TensorFlow header
+    file(STRINGS "${TENSORFLOWC_INCLUDE_DIR}/tensorflow/c/c_api.h" TF_VERSION_STR
+         REGEX "^#define TF_VERSION_STRING \"[^\"]*\"$")
+    string(REGEX REPLACE "^#define TF_VERSION_STRING \"([0-9.]+).*\"$" "\\1"
+           TensorFlowC_VERSION "${TF_VERSION_STR}")
+  endif()
   include(FindPackageHandleStandardArgs)
   find_package_handle_standard_args(
-    TensorFlowC REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR)
+    TensorFlowC
+    REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR
+    VERSION_VAR TensorFlowC_VERSION)
35-40: Enhance target configuration with additional properties and platform support.
Consider adding:
- Platform-specific library naming (e.g., .dll on Windows)
- Additional target properties for better integration
 add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
+
+if(WIN32)
+  set_target_properties(TensorFlow::tensorflow_c PROPERTIES
+    IMPORTED_IMPLIB "${TENSORFLOWC_LIBRARY}"
+    IMPORTED_LOCATION "${TENSORFLOWC_RUNTIME_LIBRARY}")
+else()
 set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION ${TENSORFLOWC_LIBRARY})
+endif()
+
+set_target_properties(TensorFlow::tensorflow_c PROPERTIES
+  INTERFACE_COMPILE_DEFINITIONS "TENSORFLOW_C_API"
+  VERSION "${TensorFlowC_VERSION}")
+
 target_include_directories(TensorFlow::tensorflow_c INTERFACE ${TENSORFLOWC_INCLUDE_DIR})
1-42: Consider enhancing module robustness and maintainability.
While the implementation is solid, consider these architectural improvements:
- Add a component-based find_package support (COMPONENTS)
- Support for static linking configuration
- Add pkg-config fallback for Unix-like systems
- Add proper dependency handling for CUDA/ROCm if needed by TensorFlow
Would you like assistance in implementing any of these architectural improvements?
doc/install/install-from-c-library.md (1)
Line range hint 1-24: Consider adding JAX-specific requirements or limitations.
While the document clearly indicates JAX support, it would be helpful to specify if there are any JAX-specific requirements, limitations, or configuration differences when using the pre-compiled library with JAX backend.
Consider adding a section that addresses:
- Any additional dependencies needed for JAX support
- Whether the same pre-compiled library works for both backends
- Any backend-specific configuration steps during installation
🧰 Tools
🪛 LanguageTool
[uncategorized] ~7-~7: You might be missing the article “a” here.
Context: ...{{ jax_icon }} ::: DeePMD-kit provides pre-compiled C library package (`libdeepmd_...(AI_EN_LECTOR_MISSING_DETERMINER_A)
source/api_cc/CMakeLists.txt (1)
26-29: Consider documenting backend compatibility matrix.
As the codebase now supports multiple backends (TensorFlow, PyTorch, and JAX), it would be valuable to:
- Document which backends can be enabled simultaneously
- Consider adding CMake checks for incompatible backend combinations
- Add comments explaining the rationale behind the backend-specific compile definitions
Consider adding a comment above the JAX block:
+# JAX backend requires TensorFlow C library for operation
 if(ENABLE_JAX)
   target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_c)
   target_compile_definitions(${libname} PRIVATE BUILD_JAX)
 endif()
doc/install/easy-install.md (2)
207-210: Consider making the TensorFlow requirement more prominent and specific.
The TensorFlow requirement note is crucial information but might be overlooked in its current location under the JAX tab. Consider:
- Moving this note to a more prominent location, perhaps at the beginning of the installation section
- Specifying the minimum required TensorFlow version
- Adding a brief explanation of why TensorFlow is needed for these specific features
+ :::{important}
+ To generate a SavedModel and use [the LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md),
+ you need to install TensorFlow (version X.Y or above).
+ This requirement exists because these features rely on TensorFlow's SavedModel format for model serialization.
+ See the TensorFlow {{ tensorflow_icon }} tab for installation instructions.
+ :::
207-210: Enhance JAX installation instructions with version and compatibility details.
The JAX installation tab would benefit from additional information to help users make informed decisions:
- Specify supported JAX versions
- Add a compatibility matrix for JAX-CUDA combinations
- Consider adding CUDA 11 support similar to other backends
 ::::::{tab-item} JAX {{ jax_icon }}
+:::{note}
+Supported JAX versions: X.Y.Z or above
+
+| DeePMD-kit | JAX    | CUDA Support |
+|------------|--------|--------------|
+| 3.0.0      | 0.4.13 | 11.8, 12.0   |
+:::
+
 :::::{tab-set}
source/CMakeLists.txt (1)
254-262: Consider adding error handling for missing paths
The JAX backend configuration is well-structured and follows the same pattern as other backends. However, consider adding explicit error messages when TENSORFLOWC_LIBRARY or TENSORFLOWC_INCLUDE_DIR are not defined.
 if(ENABLE_JAX AND NOT DEEPMD_C_ROOT)
   find_package(TensorFlowC REQUIRED MODULE)
+  if(NOT DEFINED TENSORFLOWC_LIBRARY)
+    message(FATAL_ERROR "TensorFlow C library not found but required for JAX support")
+  endif()
+  if(NOT DEFINED TENSORFLOWC_INCLUDE_DIR)
+    message(FATAL_ERROR "TensorFlow C include directory not found but required for JAX support")
+  endif()
   if(DEFINED TENSORFLOWC_LIBRARY)
     list(APPEND BACKEND_LIBRARY_PATH ${TENSORFLOWC_LIBRARY})
   endif()
   if(DEFINED TENSORFLOWC_INCLUDE_DIR)
     list(APPEND BACKEND_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR})
   endif()
 endif()
doc/install/install-from-source.md (3)
319-326: Enhance clarity and completeness of JAX backend installation instructions
The JAX backend installation section could be improved for better clarity and completeness:
Consider this revised structure:
 :::{tab-item} JAX {{ jax_icon }}
-The JAX backend only depends on the TensorFlow C API, which is included in both TensorFlow C++ library and [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
-If you want to use the TensorFlow C++ library, just enable the TensorFlow backend (which depends on the TensorFlow C++ library) and nothing else needs to do.
-If you want to use the TensorFlow C library and disable the TensorFlow backend,
-download the TensorFlow C library from [this page](https://www.tensorflow.org/install/lang_c#download_and_extract).
+The JAX backend requires the TensorFlow C API, which can be obtained through either:
+
+1. TensorFlow C++ library (recommended if you're also using the TensorFlow backend)
+   - Automatically set up when you enable the TensorFlow backend
+   - No additional steps required
+
+2. TensorFlow C library (minimal option)
+   - Suitable when you only need the JAX backend
+   - Download from the [TensorFlow C library page](https://www.tensorflow.org/install/lang_c#download_and_extract)
+   - Ensure version compatibility with your JAX installation
🧰 Tools
🪛 LanguageTool
[style] ~323-~323: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...y) and nothing else needs to do. If you want to use the TensorFlow C library and disabl...(REP_WANT_TO_VB)
381-390: Improve CMake configuration instructions for JAX backend
The CMake configuration section could be enhanced with more detailed instructions:
Consider this revised structure:
 :::{tab-item} JAX {{ jax_icon }}
-If you want to use the TensorFlow C++ library, just enable the TensorFlow backend and nothing else needs to do.
-If you want to use the TensorFlow C library and disable the TensorFlow backend, set {cmake:variable}`ENABLE_JAX` to `ON` and `CMAKE_PREFIX_PATH` to the root directory of the [TensorFlow C library](https://www.tensorflow.org/install/lang_c).
+Choose one of the following CMake configurations based on your setup:
-```bash
-cmake -DENABLE_JAX=ON -D CMAKE_PREFIX_PATH=${tensorflow_c_root} ..
-```
+1. Using TensorFlow C++ library:
+   ```bash
+   # Enable both TensorFlow and JAX backends
+   cmake -DENABLE_TENSORFLOW=ON ..
+   ```
+
+2. Using TensorFlow C library:
+   ```bash
+   # Enable only JAX backend with TF C library
+   cmake -DENABLE_JAX=ON -DCMAKE_PREFIX_PATH=${tensorflow_c_root} ..
+   ```
+
+Note: Ensure that the TensorFlow C library version is compatible with your JAX installation.
🧰 Tools
🪛 LanguageTool
[style] ~384-~384: You have already used this phrasing in nearby sentences. Consider replacing it to add variety to your writing.
Context: ...nd and nothing else needs to do. If you want to use the TensorFlow C library and disabl...(REP_WANT_TO_VB)
413-421: Enhance ENABLE_JAX CMake variable documentation
The CMake variable documentation could be more comprehensive:
Consider this revised structure:
 :::{cmake:variable} ENABLE_JAX
 **Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`
 {{ jax_icon }} Build the JAX backend.
-If {cmake:variable}`ENABLE_TENSORFLOW` is `ON`, the TensorFlow C++ library is used to build the JAX backend;
-If {cmake:variable}`ENABLE_TENSORFLOW` is `OFF`, the TensorFlow C library is used to build the JAX backend.
+
+Dependencies:
+- When {cmake:variable}`ENABLE_TENSORFLOW` is `ON`:
+  - Uses TensorFlow C++ library automatically
+  - No additional configuration needed
+- When {cmake:variable}`ENABLE_TENSORFLOW` is `OFF`:
+  - Requires TensorFlow C library
+  - Set `CMAKE_PREFIX_PATH` to the TensorFlow C library root
+
+See also:
+- {cmake:variable}`ENABLE_TENSORFLOW`
+- {cmake:variable}`CMAKE_PREFIX_PATH`
source/api_cc/src/DeepPot.cc (1)
Line range hint 68-74: Document TensorFlow dependency for JAX backend
The error message indicates that TensorFlow is required for loading JAX2TF SavedModels. Consider adding this dependency requirement to the documentation.
Consider adding a comment in the code explaining why TensorFlow is needed for JAX models:
 #if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX)
+    // TensorFlow is required for loading JAX2TF SavedModels as JAX models are converted to TensorFlow format
     dp = std::make_shared<deepmd::DeepPotJAX>(model, gpu_rank, file_content);
 #else
     throw deepmd::deepmd_exception(
         "TensorFlow backend is not built, which is used to load JAX2TF "
         "SavedModels");
 #endif
source/api_cc/src/DeepPotJAX.cc (5)
Line range hint 166-191: Handle potential memory leaks when exceptions are thrown
In the `init` method, if an exception is thrown after allocating resources (e.g., creating sessions, graphs), these resources may not be properly deallocated, leading to memory leaks.
Consider using RAII (Resource Acquisition Is Initialization) patterns or smart pointers to ensure that all allocated resources are automatically freed, even when exceptions occur.
Alternatively, ensure that all allocated resources are properly deallocated in a `catch` block or by restructuring the code to prevent leaks.
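A minimal sketch of the RAII suggestion, under stated assumptions: a C-style handle is wrapped in `std::unique_ptr` with a custom deleter, so an exception thrown partway through an `init`-like routine still releases everything acquired so far. `Handle`, `handle_new`, `handle_delete`, `init_like`, and `init_throws` are hypothetical stand-ins, not TensorFlow C API or DeePMD-kit names; in the real code the deleter would presumably call the matching `TF_Delete*` function.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in for a C-style handle (think TF_Graph or TF_Status).
struct Handle {
  int id;
};

// Counts live handles so the cleanup behaviour can be observed.
inline int& live_handles() {
  static int count = 0;
  return count;
}

inline Handle* handle_new(int id) {
  ++live_handles();
  return new Handle{id};
}

inline void handle_delete(Handle* h) {
  --live_handles();
  delete h;
}

// Custom deleter: unique_ptr calls this when the owner goes out of scope,
// including during stack unwinding after an exception.
struct HandleDeleter {
  void operator()(Handle* h) const noexcept { handle_delete(h); }
};
using HandlePtr = std::unique_ptr<Handle, HandleDeleter>;

// Mimics an init routine that acquires two resources and may then fail.
inline void init_like(bool fail) {
  HandlePtr graph(handle_new(1));
  HandlePtr session(handle_new(2));
  if (fail) {
    throw std::runtime_error("model loading failed");
  }
  assert(graph && session);
}

// Returns true if init_like(fail) threw.
inline bool init_throws(bool fail) {
  try {
    init_like(fail);
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```

After `init_throws(true)` returns, `live_handles()` is back to zero without any explicit `catch`-side cleanup, which is the property the comment asks for.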
Line range hint 203-366: Improve code modularity by refactoring duplicate code
The `compute` methods for `DeepPotJAX` with and without neighbor lists have significant duplicate code. Refactoring common code into helper functions can improve maintainability.
Consider extracting shared logic into private helper functions. This will reduce code duplication and make future updates easier.
Line range hint 276-279: Avoid hardcoding the padding factor
The padding factor is defined as `#define PADDING_FACTOR 1.05`. Using a floating-point value in a macro can lead to precision issues.
Consider defining `PADDING_FACTOR` as a `constexpr` variable:
constexpr double PADDING_FACTOR = 1.05;
This approach provides type safety and avoids potential issues with macros.
Line range hint 280-286: Ensure padding handles all cases correctly
In the code that adjusts `padding_to_nall`, there is potential for an infinite loop if `nall_real` remains larger than `padding_to_nall` after multiplication:
while (padding_to_nall < nall_real) {
  padding_to_nall *= PADDING_FACTOR;
}
Ensure that `padding_to_nall` will eventually exceed `nall_real` to prevent an infinite loop. Additionally, consider setting an upper limit or using a different strategy to calculate padding.
Line range hint 418-448: Implement missing methods or remove unused declarations
The methods `computew_mixed_type` throw a `not implemented` exception:
throw deepmd::deepmd_exception("not implemented");
If these methods are not intended to be implemented soon, consider removing them to avoid confusion. If they are planned for future development, consider adding comments or `TODO` notes.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (8)
doc/install/easy-install.md (1 hunks)
doc/install/install-from-c-library.md (1 hunks)
doc/install/install-from-source.md (4 hunks)
source/CMakeLists.txt (3 hunks)
source/api_cc/CMakeLists.txt (1 hunks)
source/api_cc/src/DeepPot.cc (2 hunks)
source/api_cc/src/DeepPotJAX.cc (1 hunks)
source/cmake/FindTensorFlowC.cmake (1 hunks)
🔇 Additional comments (12)
doc/install/install-from-c-library.md (2)
1-1: LGTM! Clear indication of JAX support.
The title and note section accurately reflect the addition of JAX as a supported backend alongside TensorFlow.
Also applies to: 4-4
Line range hint 9-24: Verify CUDA compatibility with JAX backend.
The document mentions CUDA 12.2 and 11.8 compatibility. Since JAX has its own CUDA requirements, it would be important to verify and document the CUDA version compatibility matrix for both backends.
source/api_cc/CMakeLists.txt (1)
26-29: LGTM! Clean integration of JAX backend support.
The implementation follows the established pattern for backend integration and correctly uses PRIVATE visibility for backend-specific dependencies.
Let's verify the TensorFlow C library configuration:
✅ Verification successful
Let me check the other CMake files to understand the complete JAX configuration:
TensorFlow C library configuration is properly set up for JAX backend
The verification confirms that:
- TensorFlow C library target (`TensorFlow::tensorflow_c`) is properly defined in `FindTensorFlowC.cmake`
- JAX backend is correctly configured in the root CMakeLists.txt with proper dependency handling:
  - Auto-enables when TensorFlow is enabled
  - Properly sets up library paths and includes
  - Includes appropriate error handling when no backend is enabled
- The implementation in `source/api_cc/CMakeLists.txt` correctly uses the target
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify TensorFlow C library target definition and any related configuration
# Check for TensorFlow find/config files
fd --type f "TensorFlow.*cmake" --exec cat {} \;
# Look for any other JAX-related CMake configuration
rg -l "ENABLE_JAX|tensorflow_c" --type cmake
Length of output: 1791
Script:
#!/bin/bash
# Check JAX-related configuration in other CMake files
cat source/CMakeLists.txt | grep -A 5 -B 5 "ENABLE_JAX"
# Check if tensorflow_c is properly used in other components
cat source/op/tf/CMakeLists.txt | grep -A 3 "tensorflow"
cat source/api_c/CMakeLists.txt | grep -A 3 "tensorflow"
Length of output: 2258
doc/install/easy-install.md (1)
207-210: 🛠️ Refactor suggestion
Add Windows-specific requirements and limitations.
With the addition of Windows x86-64 support, consider documenting:
- Windows-specific prerequisites (e.g., Visual Studio version, Windows SDK)
- Any feature limitations on Windows
- Known issues or workarounds for Windows users
The supported platform includes Linux x86-64 and aarch64 with GNU C Library 2.28 or above, macOS x86-64 and arm64, and Windows x86-64.
+:::{note}
+Windows x86-64 support requires:
+- Windows 10 or later
+- Microsoft Visual Studio 2019 or later
+- Windows SDK 10.0.18362.0 or later
+
+Known limitations on Windows:
+- Feature X is not supported
+- Feature Y requires additional setup
+:::
+
:::{Warning}
Likely invalid or redundant comment.
source/CMakeLists.txt (3)
263-266: LGTM! Good safeguard for CXX ABI
The default value and safeguard for OP_CXX_ABI are correctly implemented with a helpful comment explaining the purpose.
276-278: LGTM! Backend status messages are consistent
The JAX backend is properly integrated into the status messages and backend validation check, maintaining consistency with other backends.
Also applies to: 281-281
7-11: LGTM! Verify documentation updates
The JAX option and its dependency on TensorFlow are correctly implemented. The auto-enabling behavior when TensorFlow is enabled makes sense as JAX requires the TensorFlow C interface.
Let's verify that this dependency is properly documented:
✅ Verification successful
Documentation correctly reflects the JAX-TensorFlow dependency
The dependency between JAX and TensorFlow is well-documented across the codebase:
- Installation docs clearly explain that JAX backend relies on TensorFlow C API
- CMake options documentation accurately describes that enabling TensorFlow will auto-enable JAX
- Backend docs properly detail the SavedModel format requirements and C++ inference dependencies
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if the TensorFlow-JAX dependency is documented
rg -l "JAX.*TensorFlow|TensorFlow.*JAX" doc/
Length of output: 358
Script:
#!/bin/bash
# Check the actual documentation content about JAX-TensorFlow dependency
rg -A 5 -B 5 "JAX.*TensorFlow|TensorFlow.*JAX" doc/install/install-from-source.md doc/backend.md
Length of output: 6583
source/api_cc/src/DeepPot.cc (2)
15-17: LGTM: Header inclusion properly handles JAX backend
The conditional compilation directive correctly includes DeepPotJAX.h when either TensorFlow or JAX is built, which aligns with the new JAX backend support.
Line range hint 42-47: Verify consistent model file extensions
The backend determination uses ".savedmodel" extension for JAX models. Let's verify this is consistently used across the codebase.
✅ Verification successful
Let me gather more specific information about the model extension handling across these files.
Model file extensions are consistently defined across the codebase
The verification shows that model file extensions are consistently defined and handled:
- JAX backend uses `.savedmodel` extension in both C++ (`DeepPot.cc`) and Python (`backend/jax.py`)
- PyTorch backend uses `.pth` and `.pt` extensions
- TensorFlow backend uses `.pb` extension
Each backend implementation properly declares its supported extensions in their respective backend classes, and the file extension checks in `DeepPot.cc` align with these declarations.
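The extension-to-backend mapping described above can be sketched as a suffix check. This is an illustration only, not the code in `DeepPot.cc`: the `Backend` enum, `ends_with`, and `guess_backend` are hypothetical names, and returning `Unknown` for unrecognized files is an assumption about how such a check might fall through.

```cpp
#include <string>

// Hypothetical enum; the real backend type in DeePMD-kit may differ.
enum class Backend { TensorFlow, PyTorch, JAX, Unknown };

// True if `s` ends with `suffix`.
inline bool ends_with(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Maps a model file name to a backend using the extensions listed in the
// review: .savedmodel (JAX), .pth/.pt (PyTorch), .pb (TensorFlow).
inline Backend guess_backend(const std::string& model) {
  if (ends_with(model, ".savedmodel")) {
    return Backend::JAX;
  }
  if (ends_with(model, ".pth") || ends_with(model, ".pt")) {
    return Backend::PyTorch;
  }
  if (ends_with(model, ".pb")) {
    return Backend::TensorFlow;
  }
  return Backend::Unknown;  // assumption: caller reports an error
}
```

Checking the longest, most specific suffix first keeps the mapping unambiguous if more extensions are added later.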
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for consistent model file extension usage
# Test: Search for savedmodel extension usage and related file handling code
# Check for savedmodel extension handling
rg -l "\.savedmodel"
# Check for model file extension handling in general
ast-grep --pattern 'if ($file$_.substr($file$_.length() - $_) == $_) {
$$$
}'
Length of output: 457
Script:
#!/bin/bash
# Check the actual implementation of model extension handling
rg -A 5 "\.savedmodel" source/api_cc/src/DeepPot.cc deepmd/jax/infer/deep_eval.py deepmd/backend/jax.py
# Check for any other model extension checks
rg -A 5 "\.(pb|pth)" source/api_cc/src/DeepPot.cc deepmd/backend/
Length of output: 2598
source/api_cc/src/DeepPotJAX.cc (3)
Line range hint 194-201: Check for proper deallocation in destructor
In the destructor `~DeepPotJAX`, while you are deleting several members, ensure that all dynamically allocated resources are properly deallocated:
if (inited) {
TF_DeleteSession(session, status);
TF_DeleteGraph(graph);
TF_DeleteSessionOptions(sessionopts);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
TFE_DeleteContextOptions(ctx_opts);
for (size_t i = 0; i < func_vector.size(); i++) {
TF_DeleteFunction(func_vector[i]);
}
}
Verify that there are no other resources that need to be deallocated, especially if new members were added to support JAX.
Line range hint 324-345: Handle possible mismatches in tensor output order
Comments indicate uncertainty about the order of tensor outputs:
// it seems the order is the alphabet order?
// not sure whether it is safe to assume the order
Relying on implicit ordering can lead to bugs if the underlying library changes its behavior.
Consider explicitly retrieving tensors by name or confirming the order through documentation.
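A rough sketch of name-based retrieval, assuming each output can be paired with its name before use. It does not use the TensorFlow C API; `NamedOutput`, `index_by_name`, and `require_output` are hypothetical helpers, and the output names in the usage note are placeholders rather than the names the SavedModel actually exposes.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// An output value paired with its name (values simplified to vector<double>).
using NamedOutput = std::pair<std::string, std::vector<double>>;
using OutputIndex = std::map<std::string, std::vector<double>>;

// Builds a name -> value index so later code never depends on the order
// in which the runtime happened to return the outputs.
inline OutputIndex index_by_name(const std::vector<NamedOutput>& outputs) {
  OutputIndex by_name;
  for (const auto& out : outputs) {
    if (!by_name.emplace(out.first, out.second).second) {
      throw std::invalid_argument("duplicate output name: " + out.first);
    }
  }
  return by_name;
}

// Looks up one output and fails loudly if it is missing.
inline const std::vector<double>& require_output(const OutputIndex& by_name,
                                                 const std::string& name) {
  const auto it = by_name.find(name);
  if (it == by_name.end()) {
    throw std::out_of_range("missing output: " + name);
  }
  return it->second;
}
```

With this shape, a caller asks for `require_output(index, "energy")` (placeholder name) and gets an exception instead of silently reading the wrong tensor if the ordering ever changes.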
Line range hint 368-416: Ensure type consistency in template specializations
In the template specializations for `compute<double>` and `compute<float>`, ensure that the data types for inputs and outputs are consistent and correctly handled for both double and float types.
Check that all necessary type conversions are performed and that precision is maintained where required.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (5)
source/cmake/FindTensorFlowC.cmake (4)
1-4: Improve documentation formatting and clarity.
The header comments could be more readable with proper formatting and grammar:
-# Find TensorFlow C library (libtensorflow) Define target
-# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
-# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY
+# Finds the TensorFlow C library (libtensorflow) and defines the following:
+#
+# Always defined:
+# - Target: TensorFlow::tensorflow_c
+#
+# Additionally defined when TensorFlow::tensorflow_cc is not found:
+# - TENSORFLOWC_INCLUDE_DIR
+# - TENSORFLOWC_LIBRARY
5-11: Enhance comment to explain the tensorflow_cc dependency.
The comment should better explain why using tensorflow_cc is valid.
-  # since tensorflow_cc contain tensorflow_c, just use it
+  # TensorFlow C++ library (tensorflow_cc) includes the C library functionality,
+  # so we can safely create an interface dependency to avoid duplicate linking
13-33: Consider adding version checking and search hints.
The library finding logic could be enhanced with:
- Version checking using `tensorflow/c/version.h`
- Search hints for common TensorFlow installation paths
 if(NOT TensorFlowC_FOUND)
+  # Common installation prefixes
+  set(_TensorFlowC_PATHS
+    ${TENSORFLOW_ROOT}
+    $ENV{TENSORFLOW_ROOT}
+    /usr/local
+    /usr
+  )
+
   find_path(
     TENSORFLOWC_INCLUDE_DIR
     NAMES tensorflow/c/c_api.h
+    HINTS ${_TensorFlowC_PATHS}
     PATH_SUFFIXES include
     DOC "Path to TensorFlow C include directory")
   find_library(
     TENSORFLOWC_LIBRARY
     NAMES tensorflow
+    HINTS ${_TensorFlowC_PATHS}
     PATH_SUFFIXES lib lib64
     DOC "Path to TensorFlow C library")
+  if(TENSORFLOWC_INCLUDE_DIR)
+    file(STRINGS "${TENSORFLOWC_INCLUDE_DIR}/tensorflow/c/version.h"
+         TensorFlowC_VERSION_STRING REGEX "^#define TF_VERSION_STRING \"[^\"]*\"")
+    string(REGEX REPLACE "^#define TF_VERSION_STRING \"([0-9.]+).*\"$" "\\1"
+           TensorFlowC_VERSION_STRING "${TensorFlowC_VERSION_STRING}")
+  endif()
35-41: Add library type validation.
The target is created as SHARED, but we should validate the actual library type:
+  get_filename_component(_lib_ext "${TENSORFLOWC_LIBRARY}" EXT)
+  if(_lib_ext MATCHES "${CMAKE_SHARED_LIBRARY_SUFFIX}")
+    set(_lib_type SHARED)
+  else()
+    set(_lib_type STATIC)
+  endif()
-  add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
+  add_library(TensorFlow::tensorflow_c ${_lib_type} IMPORTED GLOBAL)
source/CMakeLists.txt (1)
279-281: Enhance backend status logging
Consider adding version information to the backend status logging to help with debugging and verification.
 if(ENABLE_JAX)
-  message(STATUS "- JAX")
+  message(STATUS "- JAX (using TensorFlow C library ${TensorFlowC_VERSION})")
 endif()
Also applies to: 284-284
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
source/CMakeLists.txt (3 hunks)
source/cmake/FindTensorFlowC.cmake (1 hunks)
🔇 Additional comments (1)
source/CMakeLists.txt (1)
266-269: LGTM: Good practice for CXX ABI handling
Setting a default value for `OP_CXX_ABI` prevents potential issues with empty values and follows GCC>=5 defaults.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
source/cmake/FindTensorFlowC.cmake (3)
1-4: Improve module documentation clarity.
The current documentation is fragmented and could be clearer. Consider restructuring it as follows:
-# Find TensorFlow C library (libtensorflow) Define target
-# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
-# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY
-
+# FindTensorFlowC
+#
+# This module finds the TensorFlow C library (libtensorflow) and defines the following:
+#
+# Targets:
+#   TensorFlow::tensorflow_c - Interface or imported target for TensorFlow C library
+#
+# If TensorFlow::tensorflow_cc is not found, also defines:
+#   TENSORFLOWC_INCLUDE_DIR - Path to TensorFlow C headers
+#   TENSORFLOWC_LIBRARY - Path to TensorFlow C library
+#   TensorFlowC_FOUND - True if the TensorFlow C library is found
13-24: Enhance library discovery robustness.
Consider improving the library discovery process:
- Add more search paths (e.g., standard system paths)
- Support environment variables for custom installations
- Add version checking capability
 find_path(
   TENSORFLOWC_INCLUDE_DIR
   NAMES tensorflow/c/c_api.h
-  PATH_SUFFIXES include
+  HINTS
+    $ENV{TENSORFLOW_ROOT}
+    $ENV{TENSORFLOW_HOME}
+  PATH_SUFFIXES
+    include
+    include/tensorflow
   DOC "Path to TensorFlow C include directory")
 find_library(
   TENSORFLOWC_LIBRARY
   NAMES tensorflow
-  PATH_SUFFIXES lib
+  HINTS
+    $ENV{TENSORFLOW_ROOT}
+    $ENV{TENSORFLOW_HOME}
+  PATH_SUFFIXES
+    lib
+    lib64
+    lib/x86_64-linux-gnu
   DOC "Path to TensorFlow C library")
35-41: Enhance imported target properties.
Consider setting additional target properties for better integration:
 add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
 set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION ${TENSORFLOWC_LIBRARY})
+set_property(TARGET TensorFlow::tensorflow_c PROPERTY
+  INTERFACE_INCLUDE_DIRECTORIES ${TENSORFLOWC_INCLUDE_DIR})
+# If version information is available
+set_property(TARGET TensorFlow::tensorflow_c PROPERTY
+  IMPORTED_SONAME "libtensorflow${CMAKE_SHARED_LIBRARY_SUFFIX}")
+
 target_include_directories(TensorFlow::tensorflow_c INTERFACE ${TENSORFLOWC_INCLUDE_DIR})
Also, consider adding version information propagation if available from the library.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/cmake/FindTensorFlowC.cmake (1 hunks)
🔇 Additional comments (2)
source/cmake/FindTensorFlowC.cmake (2)
5-11: LGTM! Efficient handling of tensorflow_cc dependency.
The code efficiently reuses the existing tensorflow_cc target when available, avoiding duplicate linking through an interface library.
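The exact lines under review are not reproduced in this comment. As a rough sketch of the pattern described, assuming the module forwards to the C++ target through an interface library, it could look like the following:

```cmake
if(TARGET TensorFlow::tensorflow_cc)
  # Forward to the existing C++ target instead of importing libtensorflow
  # a second time; consumers link tensorflow_c and get tensorflow_cc.
  add_library(TensorFlow::tensorflow_c INTERFACE IMPORTED GLOBAL)
  set_property(
    TARGET TensorFlow::tensorflow_c
    PROPERTY INTERFACE_LINK_LIBRARIES TensorFlow::tensorflow_cc)
  set(TensorFlowC_FOUND TRUE)
endif()
```

Setting `INTERFACE_LINK_LIBRARIES` via `set_property` is used here because it works on imported interface targets across a wide range of CMake versions.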
26-33: LGTM! Follows CMake best practices.
The code correctly uses FindPackageHandleStandardArgs and sets standard CMake variables.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (3)
source/cmake/FindTensorFlowC.cmake (3)
1-4: Improve module documentation readability and completeness.
The documentation header could be more readable and comprehensive. Consider this format:
-# Find TensorFlow C library (libtensorflow) Define target
-# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also
-# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY
+# FindTensorFlowC
+#
+# Find TensorFlow C library (libtensorflow)
+#
+# This module defines the following targets:
+#   TensorFlow::tensorflow_c - TensorFlow C library target
+#
+# If TensorFlow::tensorflow_cc is not found, it also defines:
+#   TENSORFLOWC_INCLUDE_DIR - Include directory for TensorFlow C headers
+#   TENSORFLOWC_LIBRARY - Path to TensorFlow C library
+#   TensorFlowC_FOUND - True if TensorFlow C library is found
11-31: Enhance library discovery robustness.
The current implementation follows CMake best practices but could be more robust:
- Add version detection:
if(NOT TensorFlowC_FOUND)
  find_path(
    TENSORFLOWC_INCLUDE_DIR
    NAMES tensorflow/c/c_api.h
    PATH_SUFFIXES include
    DOC "Path to TensorFlow C include directory")
  if(TENSORFLOWC_INCLUDE_DIR)
    file(READ "${TENSORFLOWC_INCLUDE_DIR}/tensorflow/c/c_api.h" _tf_header)
    string(REGEX MATCH "TF_VERSION_STRING \"([0-9.]+)\"" _tf_version "${_tf_header}")
    set(TENSORFLOWC_VERSION "${CMAKE_MATCH_1}")
  endif()
  find_library(...)
  find_package_handle_standard_args(
    TensorFlowC
    REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR
    VERSION_VAR TENSORFLOWC_VERSION)
  ...
endif()
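A slightly more defensive variant of the version parsing step is sketched below. It only sets `TENSORFLOWC_VERSION` when the pattern actually matches, so a header without the expected line leaves the version unset rather than relying on a stale `CMAKE_MATCH_1`. Whether `c_api.h` carries a literal `TF_VERSION_STRING "x.y.z"` definition is an assumption inherited from the suggestion above and has not been verified here.

```cmake
# Assumption: TENSORFLOWC_INCLUDE_DIR was already set by find_path().
set(_tf_c_api_header "${TENSORFLOWC_INCLUDE_DIR}/tensorflow/c/c_api.h")
if(EXISTS "${_tf_c_api_header}")
  # Read only the lines that mention the version macro.
  file(STRINGS "${_tf_c_api_header}" _tf_version_line
       REGEX "TF_VERSION_STRING")
  # Set the version only on a real match; otherwise leave it undefined.
  if(_tf_version_line MATCHES "TF_VERSION_STRING \"([0-9.]+)\"")
    set(TENSORFLOWC_VERSION "${CMAKE_MATCH_1}")
  endif()
endif()
unset(_tf_c_api_header)
unset(_tf_version_line)
```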
- Add hints from environment variables:
find_path(
  TENSORFLOWC_INCLUDE_DIR
  NAMES tensorflow/c/c_api.h
  HINTS ENV TENSORFLOW_ROOT ENV TENSORFLOW_INSTALL_PREFIX
  PATH_SUFFIXES include
  DOC "Path to TensorFlow C include directory")
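The snippet above covers only the header search. A matching sketch for the library search might look like this, reusing the same environment variable names from the suggestion; the `lib64` suffix is an assumption for distributions that install there.

```cmake
# Library counterpart of the find_path() call above (a sketch, not PR code).
find_library(
  TENSORFLOWC_LIBRARY
  NAMES tensorflow
  # Same environment hints as for the headers, so both lookups agree.
  HINTS ENV TENSORFLOW_ROOT ENV TENSORFLOW_INSTALL_PREFIX
  # lib64 is an assumption; some platforms use it instead of lib.
  PATH_SUFFIXES lib lib64
  DOC "Path to TensorFlow C library")
```

Keeping the hints identical in both calls avoids picking headers from one installation and the library from another.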
33-37: Add additional target properties for better integration.
The target configuration could be enhanced with additional properties:
 add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL)
 set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION
                                                       ${TENSORFLOWC_LIBRARY})
+if(WIN32)
+  set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_IMPLIB
+                                                        ${TENSORFLOWC_LIBRARY})
+endif()
+set_property(TARGET TensorFlow::tensorflow_c PROPERTY
+  INTERFACE_COMPILE_DEFINITIONS TENSORFLOW_C_API)
 target_include_directories(TensorFlow::tensorflow_c
                            INTERFACE ${TENSORFLOWC_INCLUDE_DIR})
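An alternative worth weighing, offered only as a sketch: declaring the target as `UNKNOWN IMPORTED` lets `IMPORTED_LOCATION` hold whatever file `find_library` returned (a shared object on Linux, an import library on Windows), which may remove the need for a `WIN32` branch. This is a commonly used find-module pattern, not what the PR does.

```cmake
# Sketch: UNKNOWN lets CMake accept either a shared library or an
# import library in IMPORTED_LOCATION, depending on the platform.
add_library(TensorFlow::tensorflow_c UNKNOWN IMPORTED GLOBAL)
set_target_properties(
  TensorFlow::tensorflow_c
  PROPERTIES IMPORTED_LOCATION "${TENSORFLOWC_LIBRARY}"
             INTERFACE_INCLUDE_DIRECTORIES "${TENSORFLOWC_INCLUDE_DIR}")
```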
🔇 Additional comments (2)
source/cmake/FindTensorFlowC.cmake (2)
39-39: LGTM! The advanced variables are correctly marked.
The syntax error from the previous review has been fixed, and the implementation follows CMake best practices.
5-9: Consider additional validation for tensorflow_cc target.
While creating an alias is correct, it would be beneficial to verify that tensorflow_cc actually includes the C API functionality.
Consider adding a check:
if(TARGET TensorFlow::tensorflow_cc)
  # Verify C API is available
  get_target_property(_tf_cc_include TensorFlow::tensorflow_cc INTERFACE_INCLUDE_DIRECTORIES)
  if(EXISTS "${_tf_cc_include}/tensorflow/c/c_api.h")
    add_library(TensorFlow::tensorflow_c ALIAS TensorFlow::tensorflow_cc)
    set(TensorFlowC_FOUND TRUE)
  endif()
endif()
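One caveat with the check above: `INTERFACE_INCLUDE_DIRECTORIES` can hold several directories, in which case the single `EXISTS` test on the joined string would fail even when the header is present. A sketch that tests each entry separately is shown below; the `_tf_dir` loop variable is a name introduced here for illustration.

```cmake
if(TARGET TensorFlow::tensorflow_cc)
  get_target_property(_tf_cc_include TensorFlow::tensorflow_cc
                      INTERFACE_INCLUDE_DIRECTORIES)
  # The property may be a list, so look for the C API header in each entry.
  foreach(_tf_dir IN LISTS _tf_cc_include)
    if(EXISTS "${_tf_dir}/tensorflow/c/c_api.h")
      add_library(TensorFlow::tensorflow_c ALIAS TensorFlow::tensorflow_cc)
      set(TensorFlowC_FOUND TRUE)
      break()
    endif()
  endforeach()
endif()
```

Entries written as generator expressions would not pass the `EXISTS` test, so this remains a best-effort check rather than a guarantee.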
Summary by CodeRabbit
Release Notes
New Features
Documentation
Bug Fixes