This repository has been archived by the owner on Aug 11, 2020. It is now read-only.

Improve batch gemm performance using MKL (#342)
* improve batch_dot performance by using MKL

* reduce for loop

* improve double batch gemm

* remove unnecessary reserve

* fix lint

* add MKL version check
xinyu-intel authored and piiswrong committed Jun 23, 2018
1 parent 5cf130a commit 757a91c
Showing 2 changed files with 76 additions and 0 deletions.
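
The change replaces mshadow's per-matrix GEMM loop with MKL's grouped batch interface (cblas_sgemm_batch / cblas_dgemm_batch), which hands the whole batch to MKL in a single call and lets it schedule the small multiplies together. The version guard in the diff indicates these entry points are assumed available from MKL 2016 (INTEL_MKL_VERSION >= 20160000) onward. Below is a minimal standalone sketch of the single-group form the commit relies on; the function name is hypothetical, and it assumes every multiply in the batch shares one shape and packed column-major storage.

// Standalone sketch, not from the commit: multiply a batch of pairs of
// column-major matrices with one grouped MKL call.
#include <mkl_cblas.h>
#include <vector>

void sgemm_batch_demo(int m, int n, int k, int batch,
                      const float *A, const float *B, float *C) {
  // With group_count == 1, MKL reads a single entry from each per-group
  // parameter array (trans, m, n, k, alpha, beta, ld*, group_size).
  CBLAS_TRANSPOSE trans = CblasNoTrans;
  MKL_INT mm = m, nn = n, kk = k;
  MKL_INT lda = mm, ldb = kk, ldc = mm;  // packed storage, no transpose
  MKL_INT group_size = batch;
  float alpha = 1.0f, beta = 0.0f;

  // The pointer arrays, by contrast, need one entry per matrix.
  std::vector<const float*> a_ptrs(batch), b_ptrs(batch);
  std::vector<float*> c_ptrs(batch);
  for (int i = 0; i < batch; ++i) {
    a_ptrs[i] = A + static_cast<size_t>(i) * m * k;
    b_ptrs[i] = B + static_cast<size_t>(i) * k * n;
    c_ptrs[i] = C + static_cast<size_t>(i) * m * n;
  }

  // Computes C_i = alpha * A_i * B_i + beta * C_i for every i in the batch.
  cblas_sgemm_batch(CblasColMajor, &trans, &trans, &mm, &nn, &kk,
                    &alpha, a_ptrs.data(), &lda, b_ptrs.data(), &ldb,
                    &beta, c_ptrs.data(), &ldc,
                    /*group_count=*/1, &group_size);
}

Because group_count is 1, only the first element of each per-group array is read. In the diff below, the batch_count-long vectors for m, n, k, alpha, beta, the leading dimensions, and the group size are therefore larger than MKL strictly requires, while the three pointer vectors genuinely need one entry per matrix.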
1 change: 1 addition & 0 deletions mshadow/base.h
@@ -166,6 +166,7 @@ extern "C" {
 #include <mkl_cblas.h>
 #include <mkl_vsl.h>
 #include <mkl_vsl_functions.h>
+#include <mkl_version.h>
 #endif
 
 #if MSHADOW_USE_CUDA
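
The include presumably exists so that INTEL_MKL_VERSION is actually defined where dot_engine-inl.h tests it: mkl_version.h provides the macro. A small illustration of the hazard, not from the commit:

// In a #if, an identifier that is not a defined macro evaluates to 0,
// so without <mkl_version.h> this condition is false even when a
// new-enough MKL is installed.
#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
// ... batched MKL path ...
#endif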
75 changes: 75 additions & 0 deletions mshadow/dot_engine-inl.h
@@ -7,6 +7,7 @@
 #ifndef MSHADOW_DOT_ENGINE_INL_H_
 #define MSHADOW_DOT_ENGINE_INL_H_
 
+#include <vector>
 #include "./base.h"
 #include "./extension/implicit_gemm.h"
 
@@ -291,11 +292,48 @@ struct BLASEngine<cpu, float> {
                                   const float *A, int lda, const float *B, int ldb,
                                   float beta, float *C, int ldc, int batch_count,
                                   float **workspace) {
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+    std::vector<int> p_m(batch_count, m);
+    std::vector<int> p_n(batch_count, n);
+    std::vector<int> p_k(batch_count, k);
+    std::vector<int> p_lda(batch_count, lda);
+    std::vector<int> p_ldb(batch_count, ldb);
+    std::vector<int> p_ldc(batch_count, ldc);
+    std::vector<float> p_alpha(batch_count, alpha);
+    std::vector<float> p_beta(batch_count, beta);
+    std::vector<const float*> pp_A;
+    std::vector<const float*> pp_B;
+    std::vector<float*> pp_C;
+
+    CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
+    CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
+
+    std::vector<int> p_group_sizeb(batch_count, batch_count);
+    std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
+    std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);
+
+    auto m_k = m * k;
+    auto k_n = k * n;
+    auto m_n = m * n;
+
+    for (int i = 0; i < batch_count; i++) {
+      pp_A.push_back(A + i * m_k);
+      pp_B.push_back(B + i * k_n);
+      pp_C.push_back(C + i * m_n);
+    }
+
+    cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
+                      p_m.data(), p_n.data(), p_k.data(),
+                      p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
+                      p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
+                      1, p_group_sizeb.data());
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif
   }
   inline static void gemv(Stream<cpu> *stream,
                           bool trans, int m, int n,
@@ -361,11 +399,48 @@ struct BLASEngine<cpu, double> {
                                   const double *A, int lda, const double *B, int ldb,
                                   double beta, double *C, int ldc, int batch_count,
                                   double **workspace) {
+#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000)
+    std::vector<int> p_m(batch_count, m);
+    std::vector<int> p_n(batch_count, n);
+    std::vector<int> p_k(batch_count, k);
+    std::vector<int> p_lda(batch_count, lda);
+    std::vector<int> p_ldb(batch_count, ldb);
+    std::vector<int> p_ldc(batch_count, ldc);
+    std::vector<double> p_alpha(batch_count, alpha);
+    std::vector<double> p_beta(batch_count, beta);
+    std::vector<const double*> pp_A;
+    std::vector<const double*> pp_B;
+    std::vector<double*> pp_C;
+
+    CBLAS_TRANSPOSE cblas_a_trans = GetT(transa);
+    CBLAS_TRANSPOSE cblas_b_trans = GetT(transb);
+
+    std::vector<int> p_group_sizeb(batch_count, batch_count);
+    std::vector<CBLAS_TRANSPOSE> p_transa(batch_count, cblas_a_trans);
+    std::vector<CBLAS_TRANSPOSE> p_transb(batch_count, cblas_b_trans);
+
+    auto m_k = m * k;
+    auto k_n = k * n;
+    auto m_n = m * n;
+
+    for (int i = 0; i < batch_count; i++) {
+      pp_A.push_back(A + i * m_k);
+      pp_B.push_back(B + i * k_n);
+      pp_C.push_back(C + i * m_n);
+    }
+
+    cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(),
+                      p_m.data(), p_n.data(), p_k.data(),
+                      p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(),
+                      p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(),
+                      1, p_group_sizeb.data());
+#else
     for (int i = 0; i < batch_count; ++i) {
       gemm(stream, transa, transb, m, n, k, alpha,
            A + i * m * k, lda, B + i * k * n, ldb,
            beta, C + i * m * n, ldc);
     }
+#endif
   }
   inline static void gemv(Stream<cpu> *stream,
                           bool trans, int m, int n, double alpha,
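
Since both overloads keep the plain per-matrix loop in the #else branch, a quick equivalence check of the two paths is straightforward. The harness below is hypothetical (not part of the commit) and assumes an MKL build; it fills a small batch with random data and compares the looped GEMMs against one grouped call.

// Hypothetical equivalence check, not from the commit: the grouped MKL
// call should match a per-matrix loop of plain GEMMs, which is what the
// #else branch of batched_gemm falls back to.
#include <mkl_cblas.h>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  const int m = 4, n = 3, k = 5, batch = 8;
  std::vector<float> A(batch * m * k), B(batch * k * n);
  std::vector<float> C_loop(batch * m * n), C_batch(batch * m * n);
  for (auto &x : A) x = std::rand() / static_cast<float>(RAND_MAX);
  for (auto &x : B) x = std::rand() / static_cast<float>(RAND_MAX);

  // Fallback path: one column-major GEMM per batch element.
  for (int i = 0; i < batch; ++i)
    cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f,
                A.data() + i * m * k, m, B.data() + i * k * n, k,
                0.0f, C_loop.data() + i * m * n, m);

  // MKL path: one grouped call covering the whole batch.
  CBLAS_TRANSPOSE tr = CblasNoTrans;
  MKL_INT mm = m, nn = n, kk = k, lda = m, ldb = k, ldc = m, gsize = batch;
  float alpha = 1.0f, beta = 0.0f;
  std::vector<const float*> pa(batch), pb(batch);
  std::vector<float*> pc(batch);
  for (int i = 0; i < batch; ++i) {
    pa[i] = A.data() + i * m * k;
    pb[i] = B.data() + i * k * n;
    pc[i] = C_batch.data() + i * m * n;
  }
  cblas_sgemm_batch(CblasColMajor, &tr, &tr, &mm, &nn, &kk, &alpha,
                    pa.data(), &lda, pb.data(), &ldb, &beta,
                    pc.data(), &ldc, /*group_count=*/1, &gsize);

  // Report the largest elementwise discrepancy between the two paths.
  float max_diff = 0.0f;
  for (size_t i = 0; i < C_loop.size(); ++i)
    max_diff = std::max(max_diff, std::fabs(C_loop[i] - C_batch[i]));
  std::printf("max |loop - batch| = %g\n", max_diff);
  return max_diff < 1e-5f ? 0 : 1;
}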
