diff --git a/include/mscclpp/device.hpp b/include/mscclpp/device.hpp
index 73073062d..9dae40422 100644
--- a/include/mscclpp/device.hpp
+++ b/include/mscclpp/device.hpp
@@ -4,20 +4,20 @@
 #ifndef MSCCLPP_DEVICE_HPP_
 #define MSCCLPP_DEVICE_HPP_
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
 #include
-#endif  // defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#endif  // defined(__HIP_PLATFORM_AMD__)
 #if (defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
 #define MSCCLPP_DEVICE_COMPILE
 #define MSCCLPP_DEVICE_INLINE __forceinline__ __device__
 #define MSCCLPP_HOST_DEVICE_INLINE __forceinline__ __host__ __device__
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
 #define MSCCLPP_DEVICE_HIP
-#else  // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
+#else  // !(defined(__HIP_PLATFORM_AMD__))
 #define MSCCLPP_DEVICE_CUDA
-#endif  // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
+#endif  // !(defined(__HIP_PLATFORM_AMD__))
 #else  // !(defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp
index a40a6d106..d3d48ce1f 100644
--- a/include/mscclpp/gpu.hpp
+++ b/include/mscclpp/gpu.hpp
@@ -4,7 +4,7 @@
 #ifndef MSCCLPP_GPU_HPP_
 #define MSCCLPP_GPU_HPP_
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
 #include
diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp
index 154f87723..e0cd7c3da 100644
--- a/include/mscclpp/gpu_utils.hpp
+++ b/include/mscclpp/gpu_utils.hpp
@@ -72,7 +72,7 @@ T* cudaExtCalloc(size_t nelem) {
   AvoidCudaGraphCaptureGuard cgcGuard;
   T* ptr;
   CudaStreamWithFlags stream(cudaStreamNonBlocking);
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
   MSCCLPP_CUDATHROW(hipExtMallocWithFlags((void**)&ptr, nelem * sizeof(T), hipDeviceMallocUncached));
 #else
   MSCCLPP_CUDATHROW(cudaMalloc(&ptr, nelem * sizeof(T)));
diff --git a/test/allgather_test_cpp.cu b/test/allgather_test_cpp.cu
index 08b4f6bff..2f56b221d 100644
--- a/test/allgather_test_cpp.cu
+++ b/test/allgather_test_cpp.cu
@@ -74,7 +74,7 @@ __device__ void localAllGather(DeviceHandle proxyCh
     if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
       if ((threadIdx.x % 32) == 0) proxyChan.wait();
     }
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
     // NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
     __syncthreads();
 #else
diff --git a/test/mscclpp-test/allgather_test.cu b/test/mscclpp-test/allgather_test.cu
index 529826f32..4b2eff78f 100644
--- a/test/mscclpp-test/allgather_test.cu
+++ b/test/mscclpp-test/allgather_test.cu
@@ -7,7 +7,7 @@
 #include "common.hpp"
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
 #define WARP_SIZE 64
 #else
 #define WARP_SIZE 32
@@ -65,7 +65,7 @@ __device__ void localAllGather(DeviceHandle proxyCh
     if ((remoteRank % nRanksPerNode) == ((rank - i + nRanksPerNode) % nRanksPerNode)) {
      if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
     }
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
     // NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
     __syncthreads();
 #else
diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu
index 06264c7b3..2748681b4 100644
--- a/test/mscclpp-test/allreduce_test.cu
+++ b/test/mscclpp-test/allreduce_test.cu
@@ -956,7 +956,7 @@ __global__ void allreduce4(int* buff, int* scratch, int rank, int nRanksPerNode,
 }
 __global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
-#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+#if defined(__HIP_PLATFORM_AMD__)
   localReduceScatterSm3(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
   deviceSyncer.sync(gridDim.x);
   localRingAllGatherSm2(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
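
For context, a minimal sketch of how device code can consume the macros that include/mscclpp/device.hpp defines after this change. The include path, the WARP_SIZE selection, and the laneId() helper are illustrative assumptions, not part of this patch:

#include <mscclpp/device.hpp>

// Select the hardware warp size at compile time, mirroring the WARP_SIZE
// choice made in test/mscclpp-test/allgather_test.cu above, but keyed on the
// MSCCLPP_DEVICE_HIP macro that device.hpp now sets from defined(__HIP_PLATFORM_AMD__).
#if defined(MSCCLPP_DEVICE_HIP)
#define WARP_SIZE 64  // AMD wavefront size
#else
#define WARP_SIZE 32  // NVIDIA warp size
#endif

#if defined(MSCCLPP_DEVICE_COMPILE)
// Hypothetical helper: lane index of the calling thread within its warp/wavefront.
MSCCLPP_DEVICE_INLINE int laneId() { return threadIdx.x % WARP_SIZE; }
#endif  // MSCCLPP_DEVICE_COMPILE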