From fc7ba825f0a692102f91e067f0adab630dae00b4 Mon Sep 17 00:00:00 2001 From: Luc Berger-Vergiat Date: Thu, 15 Feb 2024 11:47:34 -0700 Subject: [PATCH] Lapack - SVD: fix for unit-test when MKL is enabled This is really a problem with our implementation of the BLAS interface when MKL is enabled since MKL redefines the function signatures of blas functions using MKL_INT instead of int... --- blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp | 18 ++--- blas/tpls/KokkosBlas_Host_tpl.cpp | 73 +++++++++++--------- blas/tpls/KokkosBlas_Host_tpl.hpp | 17 +++-- lapack/unit_test/Test_Lapack_svd.hpp | 4 ++ 4 files changed, 65 insertions(+), 47 deletions(-) diff --git a/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp index 66177e28a6..68bf2708ec 100644 --- a/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp @@ -60,20 +60,20 @@ namespace Impl { Kokkos::Profiling::pushRegion("KokkosBlas::gemm[TPL_BLAS," #SCALAR_TYPE \ "]"); \ const bool A_t = (transA[0] != 'N') && (transA[0] != 'n'); \ - const int M = C.extent(0); \ - const int N = C.extent(1); \ - const int K = A.extent(A_t ? 0 : 1); \ + const KK_INT M = C.extent(0); \ + const KK_INT N = C.extent(1); \ + const KK_INT K = A.extent(A_t ? 0 : 1); \ \ bool A_is_lr = std::is_same::value; \ bool B_is_lr = std::is_same::value; \ bool C_is_lr = std::is_same::value; \ \ - const int AST = A_is_lr ? A.stride(0) : A.stride(1), \ - LDA = AST == 0 ? 1 : AST; \ - const int BST = B_is_lr ? B.stride(0) : B.stride(1), \ - LDB = BST == 0 ? 1 : BST; \ - const int CST = C_is_lr ? C.stride(0) : C.stride(1), \ - LDC = CST == 0 ? 1 : CST; \ + const KK_INT AST = A_is_lr ? A.stride(0) : A.stride(1), \ + LDA = AST == 0 ? 1 : AST; \ + const KK_INT BST = B_is_lr ? B.stride(0) : B.stride(1), \ + LDB = BST == 0 ? 1 : BST; \ + const KK_INT CST = C_is_lr ? C.stride(0) : C.stride(1), \ + LDC = CST == 0 ? 
1 : CST; \ \ const BASE_SCALAR_TYPE alpha_val = alpha, beta_val = beta; \ if (!A_is_lr && !B_is_lr && !C_is_lr) \ diff --git a/blas/tpls/KokkosBlas_Host_tpl.cpp b/blas/tpls/KokkosBlas_Host_tpl.cpp index ec739aa98a..68f2810907 100644 --- a/blas/tpls/KokkosBlas_Host_tpl.cpp +++ b/blas/tpls/KokkosBlas_Host_tpl.cpp @@ -22,6 +22,8 @@ #if defined(KOKKOSKERNELS_ENABLE_TPL_BLAS) +using KokkosBlas::Impl::KK_INT; + /// Fortran headers extern "C" { @@ -339,26 +341,27 @@ void F77_BLAS_MANGLE(ztrsv, ZTRSV)(const char*, const char*, const char*, int*, /// Gemm /// -void F77_BLAS_MANGLE(sgemm, SGEMM)(const char*, const char*, int*, int*, int*, - const float*, const float*, int*, - const float*, int*, const float*, - /* */ float*, int*); -void F77_BLAS_MANGLE(dgemm, DGEMM)(const char*, const char*, int*, int*, int*, - const double*, const double*, int*, - const double*, int*, const double*, - /* */ double*, int*); -void F77_BLAS_MANGLE(cgemm, CGEMM)(const char*, const char*, int*, int*, int*, - const std::complex*, - const std::complex*, int*, - const std::complex*, int*, +void F77_BLAS_MANGLE(sgemm, SGEMM)(const char*, const char*, KK_INT*, KK_INT*, + KK_INT*, const float*, const float*, KK_INT*, + const float*, KK_INT*, const float*, + /* */ float*, KK_INT*); +void F77_BLAS_MANGLE(dgemm, DGEMM)(const char*, const char*, KK_INT*, KK_INT*, + KK_INT*, const double*, const double*, + KK_INT*, const double*, KK_INT*, + const double*, + /* */ double*, KK_INT*); +void F77_BLAS_MANGLE(cgemm, CGEMM)(const char*, const char*, KK_INT*, KK_INT*, + KK_INT*, const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, const std::complex*, - /* */ std::complex*, int*); -void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char*, const char*, int*, int*, int*, + /* */ std::complex*, KK_INT*); +void F77_BLAS_MANGLE(zgemm, ZGEMM)(const char*, const char*, KK_INT*, KK_INT*, + KK_INT*, const std::complex*, + const std::complex*, KK_INT*, + const std::complex*, KK_INT*, const 
std::complex*, - const std::complex*, int*, - const std::complex*, int*, - const std::complex*, - /* */ std::complex*, int*); + /* */ std::complex*, KK_INT*); /// /// Herk @@ -632,10 +635,11 @@ void HostBlas::trsv(const char uplo, const char transa, const char diag, F77_FUNC_STRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); } template <> -void HostBlas::gemm(const char transa, const char transb, int m, int n, - int k, const float alpha, const float* a, int lda, - const float* b, int ldb, const float beta, - /* */ float* c, int ldc) { +void HostBlas::gemm(const char transa, const char transb, KK_INT m, + KK_INT n, KK_INT k, const float alpha, + const float* a, KK_INT lda, const float* b, + KK_INT ldb, const float beta, + /* */ float* c, KK_INT ldc) { F77_FUNC_SGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } @@ -750,10 +754,11 @@ void HostBlas::trsv(const char uplo, const char transa, const char diag, F77_FUNC_DTRSV(&uplo, &transa, &diag, &m, a, &lda, b, &ldb); } template <> -void HostBlas::gemm(const char transa, const char transb, int m, int n, - int k, const double alpha, const double* a, int lda, - const double* b, int ldb, const double beta, - /* */ double* c, int ldc) { +void HostBlas::gemm(const char transa, const char transb, KK_INT m, + KK_INT n, KK_INT k, const double alpha, + const double* a, KK_INT lda, const double* b, + KK_INT ldb, const double beta, + /* */ double* c, KK_INT ldc) { F77_FUNC_DGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc); } @@ -906,10 +911,10 @@ void HostBlas >::trsv(const char uplo, const char transa, } template <> void HostBlas >::gemm( - const char transa, const char transb, int m, int n, int k, - const std::complex alpha, const std::complex* a, int lda, - const std::complex* b, int ldb, const std::complex beta, - /* */ std::complex* c, int ldc) { + const char transa, const char transb, KK_INT m, KK_INT n, KK_INT k, + const std::complex alpha, const std::complex* a, 
KK_INT lda, + const std::complex* b, KK_INT ldb, const std::complex beta, + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_CGEMM(&transa, &transb, &m, &n, &k, &alpha, (const std::complex*)a, &lda, (const std::complex*)b, &ldb, &beta, @@ -1081,10 +1086,10 @@ void HostBlas >::trsv(const char uplo, const char transa, template <> void HostBlas >::gemm( - const char transa, const char transb, int m, int n, int k, - const std::complex alpha, const std::complex* a, int lda, - const std::complex* b, int ldb, const std::complex beta, - /* */ std::complex* c, int ldc) { + const char transa, const char transb, KK_INT m, KK_INT n, KK_INT k, + const std::complex alpha, const std::complex* a, KK_INT lda, + const std::complex* b, KK_INT ldb, const std::complex beta, + /* */ std::complex* c, KK_INT ldc) { F77_FUNC_ZGEMM(&transa, &transb, &m, &n, &k, &alpha, (const std::complex*)a, &lda, (const std::complex*)b, &ldb, &beta, diff --git a/blas/tpls/KokkosBlas_Host_tpl.hpp b/blas/tpls/KokkosBlas_Host_tpl.hpp index 29afff4d62..8e8781bfcf 100644 --- a/blas/tpls/KokkosBlas_Host_tpl.hpp +++ b/blas/tpls/KokkosBlas_Host_tpl.hpp @@ -25,10 +25,19 @@ #include "Kokkos_ArithTraits.hpp" #if defined(KOKKOSKERNELS_ENABLE_TPL_BLAS) +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) +#include "mkl_types.h" +#endif namespace KokkosBlas { namespace Impl { +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) +using KK_INT = MKL_INT; +#else +using KK_INT = int; +#endif + template struct HostBlas { typedef Kokkos::ArithTraits ats; @@ -97,10 +106,10 @@ struct HostBlas { const T *a, int lda, /* */ T *b, int ldb); - static void gemm(const char transa, const char transb, int m, int n, int k, - const T alpha, const T *a, int lda, const T *b, int ldb, - const T beta, - /* */ T *c, int ldc); + static void gemm(const char transa, const char transb, KK_INT m, KK_INT n, + KK_INT k, const T alpha, const T *a, KK_INT lda, const T *b, + KK_INT ldb, const T beta, + /* */ T *c, KK_INT ldc); static void herk(const char transa, const char 
transb, int n, int k, const T alpha, const T *a, int lda, const T beta, diff --git a/lapack/unit_test/Test_Lapack_svd.hpp b/lapack/unit_test/Test_Lapack_svd.hpp index 3b2cd3d8d5..6cf161fd3b 100644 --- a/lapack/unit_test/Test_Lapack_svd.hpp +++ b/lapack/unit_test/Test_Lapack_svd.hpp @@ -529,8 +529,10 @@ int test_svd() { Kokkos::View; ret = Test::impl_analytic_2x2_svd(); + EXPECT_EQ(ret, 0); ret = Test::impl_analytic_2x3_svd(); + EXPECT_EQ(ret, 0); ret = Test::impl_test_svd(0, 0); EXPECT_EQ(ret, 0); @@ -558,8 +560,10 @@ int test_svd() { Kokkos::View; ret = Test::impl_analytic_2x2_svd(); + EXPECT_EQ(ret, 0); ret = Test::impl_analytic_2x3_svd(); + EXPECT_EQ(ret, 0); ret = Test::impl_test_svd(0, 0); EXPECT_EQ(ret, 0);