-
Notifications
You must be signed in to change notification settings - Fork 99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
move dot-based GEMM out of TPL CUBLAS.. #1050
Changes from 2 commits
40410ec
4be3c65
a9f944c
639d483
81b359d
dcd8f36
282a4ef
fdad674
6d6ee66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
/* | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 3.0 | ||
// Copyright (2020) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are | ||
// met: | ||
// | ||
// 1. Redistributions of source code must retain the above copyright | ||
// notice, this list of conditions and the following disclaimer. | ||
// | ||
// 2. Redistributions in binary form must reproduce the above copyright | ||
// notice, this list of conditions and the following disclaimer in the | ||
// documentation and/or other materials provided with the distribution. | ||
// | ||
// 3. Neither the name of the Corporation nor the names of the | ||
// contributors may be used to endorse or promote products derived from | ||
// this software without specific prior written permission. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY | ||
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE | ||
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF | ||
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING | ||
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS | ||
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
// | ||
// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) | ||
// | ||
// ************************************************************************ | ||
//@HEADER | ||
*/ | ||
|
||
#ifndef KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_ | ||
#define KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_ | ||
|
||
namespace KokkosBlas { | ||
namespace Impl { | ||
|
||
|
||
// DotBasedGEMM implements the optimization for C = beta*C + alpha*A^T*B | ||
// with A and B matrices both being tall and skinny. The C matrix is assumed to be | ||
// small, so each entry of C is computed by performing the dot product of the | ||
// respective columns of the A and B matrices. Note that the dot products are | ||
// performed on very long vectors, so each dot product is distributed among | ||
// numDivPerDot teams. | ||
|
||
struct TagZero{}; // Init tag for beta == 0: selects the kernel that overwrites C with zeros
struct TagInit{}; // Init tag for beta != 0 and beta != 1: selects the kernel that scales C by beta
struct TagMult{}; // Multiplication tag for transposed A
struct TagMultCT{}; // Multiplication tag for conjugate-transposed A
template<class ExecSpace, class AV, class BV, class CV> | ||
struct DotBasedGEMM{ | ||
|
||
const AV A; | ||
const BV B; | ||
CV C; | ||
|
||
using scalar_A = typename AV::non_const_value_type; | ||
using size_A = typename AV::size_type; | ||
using scalar_C = typename CV::non_const_value_type; | ||
using size_C = typename CV::size_type; | ||
using AVT = Kokkos::Details::ArithTraits<scalar_A>; | ||
using CVT = Kokkos::Details::ArithTraits<scalar_C>; | ||
|
||
const scalar_A alpha; | ||
const scalar_C beta; | ||
|
||
// The following types (especially dotSize) could have simply been int, | ||
const size_C numCrows; | ||
const size_C numCcols; | ||
|
||
size_C numDivPerDot; // number of teams collectively performing a dot product | ||
size_C numTeams; // total number of teams | ||
|
||
const size_A dotSize; // the length of the vectors in the dot products | ||
size_A chunkSize; // the local length of each team's share on the dot product | ||
|
||
|
||
DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_):A(A_),B(B_),C(C_),alpha(alpha_),beta(beta_),numCrows(C.extent(0)),numCcols(C.extent(1)),dotSize(A.extent(0)) | ||
{ } | ||
|
||
void run(bool conjugateTranspose) { | ||
|
||
constexpr size_C workPerTeam = 4096; // Amount of work per team | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a reason why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
const size_C ndots = numCrows * numCcols; // Number of dot products | ||
size_C appxNumTeams = (dotSize * ndots) / workPerTeam; // Estimation for appxNumTeams | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what is |
||
|
||
// Adjust appxNumTeams in case it is too small or too large | ||
if(appxNumTeams < 1) | ||
appxNumTeams = 1; | ||
if(appxNumTeams > 1024) | ||
appxNumTeams = 1024; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as above for |
||
|
||
// If there are more dot products than the number of teams, | ||
// then set the number of teams to be number of dot products | ||
// and each team will perform only one dot product. | ||
// We don't want a team to perform more than one dot product. | ||
if(ndots >= appxNumTeams) { | ||
numTeams = ndots; | ||
numDivPerDot = 1; | ||
} | ||
// If there are more teams than dot products, each dot product can | ||
// potentially be performed by multiple teams. First, compute | ||
// numDivPerDot as an integer (take the floor, not ceiling), then, | ||
// compute actual number of teams by using this factor. | ||
else{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
numDivPerDot = appxNumTeams / ndots; | ||
numTeams = ndots * numDivPerDot; | ||
} | ||
|
||
// Determine the local length for the dot product | ||
chunkSize = dotSize / numDivPerDot; | ||
if(numDivPerDot > 1) | ||
chunkSize++; | ||
|
||
// Initialize C matrix if beta != 1 | ||
if(beta == CVT::zero()) { | ||
Kokkos::MDRangePolicy<TagZero, ExecSpace, Kokkos::Rank<2>> policyInit({0,0}, {numCrows, numCcols}); | ||
Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); | ||
} | ||
else if(beta != CVT::one()) { | ||
Kokkos::MDRangePolicy<TagInit, ExecSpace, Kokkos::Rank<2>> policyInit({0,0}, {numCrows, numCcols}); | ||
Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); | ||
} | ||
|
||
// Multiply alpha*A^TB and add it to beta*C | ||
if(conjugateTranspose) { | ||
Kokkos::TeamPolicy<TagMultCT, ExecSpace> policyMult(numTeams, Kokkos::AUTO); | ||
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); | ||
} | ||
else{ | ||
Kokkos::TeamPolicy<TagMult, ExecSpace> policyMult(numTeams, Kokkos::AUTO); | ||
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); | ||
} | ||
} | ||
|
||
KOKKOS_INLINE_FUNCTION | ||
void operator() (const TagZero&, const size_C &rowId, const size_C &colId ) const { | ||
C(rowId, colId) = CVT::zero(); | ||
} | ||
|
||
KOKKOS_INLINE_FUNCTION | ||
void operator() (const TagInit&, const size_C &rowId, const size_C &colId ) const { | ||
C(rowId, colId) = beta * C(rowId, colId); | ||
} | ||
|
||
KOKKOS_INLINE_FUNCTION | ||
void operator() (const TagMult&, const typename Kokkos::TeamPolicy<ExecSpace>::member_type& teamMember) const { | ||
|
||
const size_C globalRank = teamMember.league_rank(); | ||
const size_C localRank = globalRank % numDivPerDot; | ||
const size_C i = globalRank / numDivPerDot; | ||
const size_C rowId = i / numCcols; | ||
const size_C colId = i % numCcols; | ||
|
||
scalar_C result = CVT::zero(); | ||
const size_A baseInd = chunkSize*localRank; | ||
Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { | ||
if(baseInd + k < dotSize) | ||
update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId); | ||
}, result ); | ||
|
||
Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { | ||
Kokkos::atomic_add(&C(rowId, colId), result); | ||
}); | ||
} | ||
|
||
KOKKOS_INLINE_FUNCTION | ||
void operator() (const TagMultCT&, const typename Kokkos::TeamPolicy<ExecSpace>::member_type& teamMember) const { | ||
|
||
const size_C globalRank = teamMember.league_rank(); | ||
const size_C localRank = globalRank % numDivPerDot; | ||
const size_C i = globalRank / numDivPerDot; | ||
const size_C rowId = i / numCcols; | ||
const size_C colId = i % numCcols; | ||
|
||
scalar_C result = CVT::zero(); | ||
const size_A baseInd = chunkSize*localRank; | ||
Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { | ||
if(baseInd + k < dotSize) | ||
update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId); | ||
}, result ); | ||
|
||
Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { | ||
Kokkos::atomic_add(&C(rowId, colId), result); | ||
}); | ||
} | ||
|
||
}; | ||
|
||
} | ||
} | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -50,6 +50,7 @@ | |||
|
||||
#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY | ||||
#include<KokkosBlas3_gemm_impl.hpp> | ||||
#include<KokkosBlas3_gemm_dotbased_impl.hpp> | ||||
#endif | ||||
|
||||
namespace KokkosBlas { | ||||
|
@@ -135,73 +136,98 @@ struct GEMM { | |||
typedef typename BViewType::non_const_value_type ScalarB; | ||||
typedef typename CViewType::non_const_value_type ScalarC; | ||||
|
||||
// Define Blocking sizes (this will be used for scratch spaces) | ||||
static constexpr int blockA0 = 24; | ||||
static constexpr int blockB1 = 64; | ||||
static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 : | ||||
(sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 : | ||||
(sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ; | ||||
static constexpr int vector_length = blockB1/4; | ||||
|
||||
// Compute scratch space size | ||||
typedef KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm_dummy_type; | ||||
const int scratch_memory_size = | ||||
gemm_dummy_type::ViewTypeAScratch::required_allocation_size() + | ||||
gemm_dummy_type::ViewTypeBScratch::required_allocation_size() + | ||||
gemm_dummy_type::ViewTypeCScratch::required_allocation_size(); | ||||
const int scratch_level = scratch_memory_size < 24000 ? 0 : 1; | ||||
|
||||
// Figure out Team Sizes | ||||
int team_size = 1; | ||||
#if defined(KOKKOS_ENABLE_CUDA) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::Cuda>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
#if defined(KOKKOS_ENABLE_HIP) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::Experimental::HIP>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
#if defined(KOKKOS_ENABLE_ROCM) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::ROCm>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
|
||||
// Call the correct kernel | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
|
||||
#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
typedef typename CViewType::execution_space ExecSpace; | ||||
|
||||
// Figure out whether we should use the DotBased implementation | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo: |
||||
const int M = static_cast<int> (C.extent(0)); | ||||
const int N = static_cast<int> (C.extent(1)); | ||||
|
||||
const bool A_is_lr = std::is_same<Kokkos::LayoutRight, typename AViewType::array_layout>::value; | ||||
const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); | ||||
const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); | ||||
|
||||
constexpr int numDotsLayoutLeftThreshold = 1600; | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider adding a comment explaining |
||||
constexpr int numDotsLayoutRightThreshold = 100; | ||||
if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK. As long as the performance of DotBasedGEMM is better than GEMMImpl for these cases. |
||||
|| ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) { | ||||
// call dot-based GEMM, only for C := beta * C + alpha * A^T * B | ||||
bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c')); | ||||
DotBasedGEMM<ExecSpace, AViewType, BViewType, CViewType> dotBasedGemm(alpha, A, B, beta, C); | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: dotBasedGemm -> dotBasedGEMM |
||||
dotBasedGemm.run(A_is_conj ? true : false); | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am pretty sure that the tertiary operator can be replaced by |
||||
} else | ||||
#endif | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
{ | ||||
|
||||
// Define Blocking sizes (this will be used for scratch spaces) | ||||
static constexpr int blockA0 = 24; | ||||
static constexpr int blockB1 = 64; | ||||
static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 : | ||||
(sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 : | ||||
(sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ; | ||||
static constexpr int vector_length = blockB1/4; | ||||
|
||||
// Compute scratch space size | ||||
typedef KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm_dummy_type; | ||||
const int scratch_memory_size = | ||||
gemm_dummy_type::ViewTypeAScratch::required_allocation_size() + | ||||
gemm_dummy_type::ViewTypeBScratch::required_allocation_size() + | ||||
gemm_dummy_type::ViewTypeCScratch::required_allocation_size(); | ||||
const int scratch_level = scratch_memory_size < 24000 ? 0 : 1; | ||||
|
||||
// Figure out Team Sizes | ||||
int team_size = 1; | ||||
#if defined(KOKKOS_ENABLE_CUDA) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::Cuda>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
#if defined(KOKKOS_ENABLE_HIP) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::Experimental::HIP>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
#if defined(KOKKOS_ENABLE_ROCM) | ||||
if(std::is_same<typename CViewType::execution_space,Kokkos::ROCm>::value) | ||||
team_size = blockA0; | ||||
#endif | ||||
|
||||
// Call the correct kernel | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,0> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,1> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) { | ||||
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,2> gemm(alpha,A,B,beta,C); | ||||
gemm.run(team_size,vector_length,scratch_level); | ||||
} | ||||
} | ||||
Kokkos::Profiling::popRegion(); | ||||
} | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider removing the comment and changing the variable types to int