From 63594a0654fcad0c1338ef675347dc9185d46eef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ren=C3=A9=20Widera?=
Date: Thu, 8 Dec 2022 10:02:27 +0100
Subject: [PATCH] add HIP support

- add function attribute to define macros for HIP
- add unroll support for HIP
- fix missing function defines for function friend declarations
---
 include/llama/RecordRef.hpp | 7 +++++--
 include/llama/macros.hpp    | 9 +++++----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/include/llama/RecordRef.hpp b/include/llama/RecordRef.hpp
index e52396d3a4..97c0b2737e 100644
--- a/include/llama/RecordRef.hpp
+++ b/include/llama/RecordRef.hpp
@@ -742,9 +742,12 @@ namespace llama
         // to find subsequent elements. This is not a great design for now and the SIMD load/store functions should
         // probably take iterators to records.
         template
-        friend void internal::loadSimdRecord(const T& srcRef, Simd& dstSimd, RecordCoord rc);
+        friend LLAMA_FN_HOST_ACC_INLINE void internal::loadSimdRecord(const T& srcRef, Simd& dstSimd, RecordCoord rc);
         template
-        friend void internal::storeSimdRecord(const Simd& srcSimd, T&& dstRef, RecordCoord rc);
+        friend LLAMA_FN_HOST_ACC_INLINE void internal::storeSimdRecord(
+            const Simd& srcSimd,
+            T&& dstRef,
+            RecordCoord rc);
     };
 
     // swap for heterogeneous RecordRef
diff --git a/include/llama/macros.hpp b/include/llama/macros.hpp
index 0b2c01803b..67c0f62049 100644
--- a/include/llama/macros.hpp
+++ b/include/llama/macros.hpp
@@ -34,7 +34,7 @@
 #endif
 
 #ifndef LLAMA_FORCE_INLINE
-#    if defined(__NVCC__)
+#    if defined(__NVCC__) || defined(__HIP__)
 #        define LLAMA_FORCE_INLINE __forceinline__
 #    elif defined(__GNUC__) || defined(__clang__)
 #        define LLAMA_FORCE_INLINE inline __attribute__((always_inline))
@@ -52,7 +52,8 @@
 #endif
 
 #ifndef LLAMA_UNROLL
-#    if defined(__NVCC__) || defined(__NVCOMPILER) || defined(__clang__) || defined(__INTEL_LLVM_COMPILER)
+#    if defined(__HIP__) || defined(__NVCC__) || defined(__NVCOMPILER) || defined(__clang__) \
+        || defined(__INTEL_LLVM_COMPILER)
 #        define LLAMA_UNROLL(...) LLAMA_PRAGMA(unroll __VA_ARGS__)
 #    elif defined(__GNUG__)
 #        define LLAMA_UNROLL(...) LLAMA_PRAGMA(GCC unroll __VA_ARGS__)
@@ -68,7 +69,7 @@
 #endif
 
 #ifndef LLAMA_ACC
-#    if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+#    if defined(__HIP__) || defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
 #        define LLAMA_ACC __device__
 #    elif defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER)
 #        define LLAMA_ACC
@@ -79,7 +80,7 @@
 #endif
 
 #ifndef LLAMA_HOST_ACC
-#    if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+#    if defined(__HIP__) || defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
 #        define LLAMA_HOST_ACC __host__ __device__
 #    elif defined(__GNUC__) || defined(__clang__) || defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER)
 #        define LLAMA_HOST_ACC