From 99480e849ad95d68020e7c67f65582210a8227ff Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 30 Dec 2024 20:04:00 +0000 Subject: [PATCH] update --- .gitmodules | 6 ++++++ CMakeLists.txt | 9 +++++++++ pyg_lib/csrc/ops/cuda/sampled_kernel.cu | 4 +++- third_party/cccl | 1 + third_party/cuCollections | 1 + 5 files changed, 20 insertions(+), 1 deletion(-) create mode 160000 third_party/cccl create mode 160000 third_party/cuCollections diff --git a/.gitmodules b/.gitmodules index 685f5ab7c..1b7a5b22a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,9 @@ [submodule "third_party/METIS"] path = third_party/METIS url = https://github.com/KarypisLab/METIS.git +[submodule "third_party/cuCollections"] + path = third_party/cuCollections + url = https://github.com/NVIDIA/cuCollections.git +[submodule "third_party/cccl"] + path = third_party/cccl + url = https://github.com/NVIDIA/cccl.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 5c2d27185..10afda777 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,8 @@ if(WITH_CUDA) set(CUTLASS_UTIL_DIR third_party/cutlass/tools/util/include) include_directories(${CUTLASS_UTIL_DIR}) endif() + set(CUCOLLECTIONS_DIR third_party/cuCollections/include) + include_directories(${CUCOLLECTIONS_DIR}) endif() set(CSRC pyg_lib/csrc) @@ -105,6 +107,13 @@ if (USE_PYTHON) target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python) endif() +if(WITH_CUDA) + target_include_directories(${PROJECT_NAME} PRIVATE + third_party/cccl/thrust + third_party/cccl/cub + third_party/cccl/libcudacxx/include) +endif() + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0") if(BUILD_TEST) diff --git a/pyg_lib/csrc/ops/cuda/sampled_kernel.cu b/pyg_lib/csrc/ops/cuda/sampled_kernel.cu index f28ab47b5..a73951f37 100644 --- a/pyg_lib/csrc/ops/cuda/sampled_kernel.cu +++ b/pyg_lib/csrc/ops/cuda/sampled_kernel.cu @@ -2,13 +2,15 @@ #include #include +#include + namespace pyg { namespace ops { namespace { #define THREADS 1024 -#define CDIV(N, M) ((N) + (M) - 1) / (M) +#define CDIV(N, M) ((N) + (M)-1) / (M) enum FnType { ADD, SUB, MUL, DIV }; const std::map to_fn_type = { diff --git a/third_party/cccl b/third_party/cccl new file mode 160000 index 000000000..faca86cc0 --- /dev/null +++ b/third_party/cccl @@ -0,0 +1 @@ +Subproject commit faca86cc08941b25799da1be74b36ee18ae436df diff --git a/third_party/cuCollections b/third_party/cuCollections new file mode 160000 index 000000000..e79787be2 --- /dev/null +++ b/third_party/cuCollections @@ -0,0 +1 @@ +Subproject commit e79787be2cb3de1b12e90d56355612e47395cce5