diff --git a/csrc/cuda/utils.cuh b/csrc/cuda/utils.cuh index ba4f3a11..747a8e2c 100644 --- a/csrc/cuda/utils.cuh +++ b/csrc/cuda/utils.cuh @@ -6,9 +6,10 @@ AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") -__device__ __inline__ at::Half -__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) { - return __shfl_sync(mask, var.operator __half(), srcLane); +__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, + const at::Half var, + const unsigned int delta) { + return __shfl_up_sync(mask, var.operator __half(), delta); } __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, @@ -17,6 +18,27 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, return __shfl_down_sync(mask, var.operator __half(), delta); } +__device__ __inline__ at::Half __shfl_sync(const unsigned mask, + const at::Half var, + const int delta) { + return __shfl_sync(mask, var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_up(const at::Half var, + const unsigned int delta) { + return __shfl_up(var.operator __half(), delta); +} + +__device__ __inline__ at::Half __shfl_down(const at::Half var, + const unsigned int delta) { + return __shfl_down(var.operator __half(), delta); +} + +__device__ __inline__ at::Half +__shfl(const at::Half var, const int delta) { + return __shfl(var.operator __half(), delta); +} + #ifdef USE_ROCM __device__ __inline__ at::Half __ldg(const at::Half* ptr) { return __ldg(reinterpret_cast(ptr));