From 0f01a3908af9fea66b7d29df2848c0a6a7fe18ea Mon Sep 17 00:00:00 2001
From: Jean-Luc Fattebert
Date: Fri, 17 Jan 2020 17:18:29 -0500
Subject: [PATCH] Use xsmm library

---
 CMakeLists.txt                                |  9 +++
 cmake/FindXSMM.cmake                          | 42 ++++++++++++++
 src/C-interface/blas.h                        | 32 +++++++++++
 .../ellblock/bml_multiply_ellblock_typed.c    |  8 +--
 src/internal-blas/bml_gemm.c                  | 37 ++++++++++++
 src/internal-blas/bml_gemm.h                  | 57 +++++++++++++++++++
 src/typed.h                                   |  5 ++
 7 files changed, 186 insertions(+), 4 deletions(-)
 create mode 100644 cmake/FindXSMM.cmake

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0be94dfde..646da87cb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -353,6 +353,15 @@ if(NOT HAVE_FABS)
 endif()
 list(APPEND LINK_LIBRARIES -lm)
 
+find_package(XSMM)
+
+if(XSMM_FOUND)
+  message(STATUS "Use XSMM")
+  add_definitions(-DBML_USE_XSMM)
+  include_directories(${XSMM_INCLUDE_DIRS})
+  list(APPEND LINK_LIBRARIES ${XSMM_LIBRARIES})
+endif()
+
 if(NOT (BLAS_FOUND OR NOBLAS))
   message(FATAL_ERROR "Could not find BLAS library.")
 endif()
diff --git a/cmake/FindXSMM.cmake b/cmake/FindXSMM.cmake
new file mode 100644
index 000000000..15843d20e
--- /dev/null
+++ b/cmake/FindXSMM.cmake
@@ -0,0 +1,42 @@
+# - Find the XSMM library
+#
+# Usage:
+#   find_package(XSMM [REQUIRED] [QUIET] )
+#
+# It sets the following variables:
+#   XSMM_FOUND        ... true if the XSMM library is found under XSMM_ROOT
+#   XSMM_LIBRARY_DIRS ... XSMM library directory (XSMM_ROOT/lib)
+#   XSMM_INCLUDE_DIRS ... XSMM include directory (XSMM_ROOT/include)
+#   XSMM_LIBRARIES    ... full path to the XSMM library file
+#
+# The following variables will be checked by the function
+#   XSMM_USE_STATIC_LIBS ... if true, only static libraries are found
+#                            (NOTE: documented but not yet honored below)
+#   XSMM_ROOT            ... if set, the libraries are exclusively searched
+#                            under this path
+
+# If environment variable XSMM_ROOT is specified, it has same effect as XSMM_ROOT
+if(NOT XSMM_ROOT AND NOT "$ENV{XSMM_ROOT}" STREQUAL "")
+  set(XSMM_ROOT "$ENV{XSMM_ROOT}")
+endif()
+
+set(XSMM_FOUND FALSE)
+if(XSMM_ROOT)
+  # set library directories
+  set(XSMM_LIBRARY_DIRS ${XSMM_ROOT}/lib)
+  # set include directories
+  set(XSMM_INCLUDE_DIRS ${XSMM_ROOT}/include)
+  # set libraries; find_library adds the "lib" prefix and suffix itself
+  find_library(
+    XSMM_LIBRARIES
+    NAMES xsmm
+    PATHS ${XSMM_ROOT}
+    PATH_SUFFIXES lib
+    NO_DEFAULT_PATH
+  )
+  # report success only when the library file was actually located
+  if(XSMM_LIBRARIES)
+    set(XSMM_FOUND TRUE)
+  endif()
+endif()
+
diff --git a/src/C-interface/blas.h b/src/C-interface/blas.h
index 7a7c3b3db..dd0076041 100644
--- a/src/C-interface/blas.h
+++ b/src/C-interface/blas.h
@@ -3,6 +3,8 @@
 
 #include <complex.h>
 
+#include "../typed.h"
+
 void C_SSCAL(
     const int *n,
     const float *a,
@@ -108,4 +110,34 @@ void C_ZAXPY(
     double complex * y,
     const int *incy);
 
+void XSMM(
+    C_SGEMM) (
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const float *alpha,
+    const float *a,
+    const int *lda,
+    const float *b,
+    const int *ldb,
+    const float *beta,
+    float *c,
+    const int *ldc);
+void XSMM(
+    C_DGEMM) (
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const double *alpha,
+    const double *a,
+    const int *lda,
+    const double *b,
+    const int *ldb,
+    const double *beta,
+    double *c,
+    const int *ldc);
 #endif
diff --git a/src/C-interface/ellblock/bml_multiply_ellblock_typed.c b/src/C-interface/ellblock/bml_multiply_ellblock_typed.c
index 7e1a6eb9e..133c77fdd 100644
--- a/src/C-interface/ellblock/bml_multiply_ellblock_typed.c
+++ b/src/C-interface/ellblock/bml_multiply_ellblock_typed.c
@@ -165,10 +165,10 @@ void *TYPED_FUNC(
 #if 1
                         REAL_T alpha = (REAL_T) 1.;
                         REAL_T beta = (REAL_T) 1.;
-                        TYPED_FUNC(bml_gemm) ("N", "N", &bsize[kb], &bsize[ib],
-                                              &bsize[jb], &alpha, X_value_right,
-                                              &bsize[kb], X_value_left, &bsize[jb],
-                                              &beta, x, &bsize[kb]);
+                        TYPED_FUNC(bml_xsmm_gemm) ("N", "N", &bsize[kb], &bsize[ib],
+                                                   &bsize[jb], &alpha, X_value_right,
+                                                   &bsize[kb], X_value_left,
+                                                   &bsize[jb], &beta, x, &bsize[kb]);
 #else
                         for (int ii = 0; ii < bsize[ib]; ii++)
                             for (int jj = 0; jj < bsize[kb]; jj++)
diff --git a/src/internal-blas/bml_gemm.c b/src/internal-blas/bml_gemm.c
index 3062fe6d6..63f15b47d 100644
--- a/src/internal-blas/bml_gemm.c
+++ b/src/internal-blas/bml_gemm.c
@@ -282,3 +282,40 @@ void TYPED_FUNC(
 #endif
 #endif
 }
+
+/** Multiply two matrices, using XSMM when the build provides it.
+ *
+ * The arguments are the ones passed to bml_gemm. Dispatch:
+ *  - BML_INTERNAL_GEMM defined: bml_gemm_internal
+ *  - BML_USE_XSMM defined: the XSMM GEMM routine
+ *  - otherwise: bml_gemm, so that callers keep working in builds
+ *    configured without XSMM
+ */
+void TYPED_FUNC(
+    bml_xsmm_gemm) (
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const REAL_T * alpha,
+    const REAL_T * a,
+    const int *lda,
+    const REAL_T * b,
+    const int *ldb,
+    const REAL_T * beta,
+    REAL_T * c,
+    const int *ldc)
+{
+#ifdef BML_INTERNAL_GEMM
+    TYPED_FUNC(bml_gemm_internal) (transa, transb, m, n, k, alpha, a,
+                                   lda, b, ldb, beta, c, ldc);
+#elif defined(BML_USE_XSMM)
+    XSMM(C_BLAS(GEMM)) (transa, transb, m, n, k, alpha, a,
+                        lda, b, ldb, beta, c, ldc);
+#else
+    /* XSMM is optional: fall back to the regular GEMM wrapper. */
+    TYPED_FUNC(bml_gemm) (transa, transb, m, n, k, alpha, a,
+                          lda, b, ldb, beta, c, ldc);
+#endif
+}
diff --git a/src/internal-blas/bml_gemm.h b/src/internal-blas/bml_gemm.h
index fa8ce9cbe..f54674c8d 100644
--- a/src/internal-blas/bml_gemm.h
+++ b/src/internal-blas/bml_gemm.h
@@ -60,4 +60,61 @@ void bml_gemm_double_complex(
     double complex * c,
     const int *ldc);
 
+void bml_xsmm_gemm_single_real(
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const float *alpha,
+    const float *a,
+    const int *lda,
+    const float *b,
+    const int *ldb,
+    const float *beta,
+    float *c,
+    const int *ldc);
+void bml_xsmm_gemm_double_real(
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const double *alpha,
+    const double *a,
+    const int *lda,
+    const double *b,
+    const int *ldb,
+    const double *beta,
+    double *c,
+    const int *ldc);
+void bml_xsmm_gemm_single_complex(
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const float complex * alpha,
+    const float complex * a,
+    const int *lda,
+    const float complex * b,
+    const int *ldb,
+    const float complex * beta,
+    float complex * c,
+    const int *ldc);
+void bml_xsmm_gemm_double_complex(
+    const char *transa,
+    const char *transb,
+    const int *m,
+    const int *n,
+    const int *k,
+    const double complex * alpha,
+    const double complex * a,
+    const int *lda,
+    const double complex * b,
+    const int *ldb,
+    const double complex * beta,
+    double complex * c,
+    const int *ldc);
+
 #endif
diff --git a/src/typed.h b/src/typed.h
index 257d21ef1..b2fd41745 100644
--- a/src/typed.h
+++ b/src/typed.h
@@ -21,6 +21,7 @@
 #define MATRIX_PRECISION single_real
 #define BLAS_PREFIX S
 #define MAGMA_PREFIX s
+#define XSMM_PREFIX libxsmm_
 #define REAL_PART(x) (x)
 #define IMAGINARY_PART(x) (0)
 #define COMPLEX_CONJUGATE(x) (x)
@@ -32,6 +33,7 @@
 #define MATRIX_PRECISION double_real
 #define BLAS_PREFIX D
 #define MAGMA_PREFIX d
+#define XSMM_PREFIX libxsmm_
 #define REAL_PART(x) (x)
 #define IMAGINARY_PART(x) (0)
 #define COMPLEX_CONJUGATE(x) (x)
@@ -43,6 +45,7 @@
 #define MATRIX_PRECISION single_complex
 #define BLAS_PREFIX C
 #define MAGMA_PREFIX c
+#define XSMM_PREFIX
 #define REAL_PART(x) (crealf(x))
 #define IMAGINARY_PART(x) (cimagf(x))
 #define COMPLEX_CONJUGATE(x) (conjf(x))
@@ -54,6 +57,7 @@
 #define MATRIX_PRECISION double_complex
 #define BLAS_PREFIX Z
 #define MAGMA_PREFIX z
+#define XSMM_PREFIX
 #define REAL_PART(x) (creal(x))
 #define IMAGINARY_PART(x) (cimag(x))
 #define COMPLEX_CONJUGATE(x) (conj(x))
@@ -71,6 +75,7 @@
 
 #define TYPED_FUNC(a) CONCAT_(a, FUNC_SUFFIX)
 #define C_BLAS(a) CONCAT_(C, CONCAT(BLAS_PREFIX , a))
+#define XSMM(a) CONCAT(XSMM_PREFIX , a)
 #define MAGMACOMPLEX(a) CONCAT_(MAGMA, CONCAT_(BLAS_PREFIX, a))
 #define MAGMA(a) CONCAT_(magma, CONCAT(MAGMA_PREFIX , a))
 #define MAGMAGPU(a) CONCAT_(magma, CONCAT(MAGMA_PREFIX , CONCAT_(a, gpu)))