Skip to content

Commit

Permalink
Merge pull request zheng-da#11 from TaoLv/opt-1.4.0rc2
Browse files Browse the repository at this point in the history
Opt 1.4.0rc2
  • Loading branch information
Hao Li authored Feb 13, 2019
2 parents 9e1e943 + f8dc84a commit db4d865
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ endif
ifeq ($(DEBUG), 1)
CFLAGS += -g -O0
else
CFLAGS += -O3 -DNDEBUG=1
CFLAGS += -O3 -DNDEBUG=1 -march=native
endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
Expand Down
37 changes: 37 additions & 0 deletions src/operator/channel_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,43 @@ void Split(const mshadow::Tensor<xpu, dim, DType> &input,
split_helper<xpu, dim, dim-1>(input, output, dimension, req);
}
}

template<typename xpu, int dim, typename DType>
void Split_2D(const mshadow::Tensor<xpu, dim, DType> &input,
std::vector<mshadow::Tensor<xpu, dim, DType> > *output,
const int dimension, const std::vector<OpReqType> &req) {
if (dimension != 1) {
LOG(FATAL) << "dimension (" << dimension << ") must == 1";
}
if (dim != 3) {
LOG(FATAL) << "dimension (" << dim << ") must == 3";
} else {
std::vector<mshadow::Tensor<xpu, dim, DType> > out = *output;
size_t size = out.size();
std::vector<int>slice_len;
std::vector<int>begin_pos;
begin_pos.push_back(0);

for (index_t i = 0; i < size; ++i) {
slice_len.push_back(out[i].size(dimension));
begin_pos.push_back(begin_pos[i] + out[i].size(dimension));
}
#if !defined(MXNET_ENABLE_CUDA_RTC)
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
#endif
for (int i = 0; i < input.shape_[0]; i++) {
int iRow = i*input.shape_[1];
for (int j = 0; j < size; j++) {
int jRow = i*slice_len[j];
int iPos = iRow + begin_pos[j];
for (int k = 0; k < slice_len[j]; k++) {
out[j].dptr_[jRow + k] = input.dptr_[iPos + k];
}
}
}
}
}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_CHANNEL_OP_COMMON_H_
53 changes: 52 additions & 1 deletion src/operator/nn/softmax-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#include "../operator_common.h"
#include "../tensor/broadcast_reduce_op.h"

#if MSHADOW_USE_MKL == 1
#include "mkl.h"
#endif

namespace mxnet {
namespace op {
namespace mxnet_op {
Expand All @@ -42,7 +46,6 @@ struct softmax_fwd {
}
};


struct log_softmax_fwd {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
Expand Down Expand Up @@ -310,6 +313,54 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
});
}

#if MSHADOW_USE_MKL == 1
static inline int64_t prod(int64_t* n, int start, int end) {
int64_t res = 1;
for (int i = start; i < end; i++)
res *= n[i];
return res;
}

static inline void max_(int64_t n, float * __restrict__ in, float *dst) {
dst[0] = in[0];
for (int64_t i = 1; i < n; i++)
dst[0] = (dst[0] < in[i]) ? in[i] : dst[0];
}

static inline void sub_(int64_t n, float * __restrict__ in, float b, float * __restrict__ dst) {
for (int64_t i = 0; i < n; i++)
dst[i] = in[i] - b;
}

static inline void sum_(int64_t n, float * __restrict__ in, float * __restrict__ dst) {
dst[0] = cblas_sasum(n, in, 1);
}

static inline void exp_(int64_t n, float *in, float *dst) {
vsExp(n, in, dst);
}

static inline void log_softmax_parallel(TShape sh, int axis,
float * __restrict__ in, float * __restrict__ out) {
int64_t outer_size = prod(sh.data(), 0, axis);
int64_t channels = sh[axis];
// int inner_size = prod(n, axis+1, 4);

#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int ou=0; ou < outer_size; ou++) {
float *in_dat = in + ou * channels;
float *out_dat = out + ou * channels;
float b, logsum;

max_(channels, in_dat, &b);
sub_(channels, in_dat, b, out_dat);
exp_(channels, out_dat, out_dat);
sum_(channels, out_dat, &logsum);
logsum = b + logf(logsum);
sub_(channels, in_dat, logsum, out_dat);
}
}
#endif

template<typename xpu, typename OP1, typename OP2, bool negate = false>
void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
Expand Down
37 changes: 37 additions & 0 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@ namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(SoftmaxParam);

#if MSHADOW_USE_MKL == 1
static inline bool SupportLogSoftmaxMKL(const TBlob &input, const SoftmaxParam &param) {
if (input.type_flag_ != mshadow::kFloat32) return false;
if (param.temperature.has_value()) return false;

int axis = CheckAxis(param.axis, input.ndim());
// channle on the last dimension
if ((input.ndim() == 4U && axis == 3U) ||
(input.ndim() == 3U && axis == 2U) ||
(input.ndim() == 2U && axis == 1U))
return true;
else
return false;
}

void LogSoftmaxComputeMKL(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
if (!SupportLogSoftmaxMKL(inputs[0], param)) {
// fallback
SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>(attrs, ctx, inputs, req, outputs);
} else {
log_softmax_parallel(inputs[0].shape_, axis, inputs[0].dptr<float>(), outputs[0].dptr<float>());
}
}
#endif

#if MXNET_USE_MKLDNN == 1
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down Expand Up @@ -167,7 +200,11 @@ Examples::
)code")
.set_attr_parser(ParamParser<SoftmaxParam>)
#if MSHADOW_USE_MKL == 1
.set_attr<FCompute>("FCompute<cpu>", LogSoftmaxComputeMKL)
#else
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::log_softmax_fwd>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"})
.add_arguments(SoftmaxParam::__FIELDS__());

Expand Down
7 changes: 6 additions & 1 deletion src/operator/slice_channel-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,12 @@ class SliceChannelOp : public Operator {
for (int i = 0; i < size_; ++i) {
outputs[i] = out_data[i].get_with_shape<xpu, 3, DType>(slice_shape, s);
}
Split(data, &outputs, 1, req);
// 3D dshape and trailing==1, split_2d can be used to speedup
if (trailing == 1 && std::is_same<xpu, cpu>::value) {
Split_2D(data, &outputs, 1, req);
} else {
Split(data, &outputs, 1, req);
}
}

virtual void Backward(const OpContext &ctx,
Expand Down
35 changes: 32 additions & 3 deletions src/operator/tensor/control_flow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,33 @@ struct where_csr {
}
};

#define MIN(a, b) ((a < b) ? a : b)

template<typename DType, typename CType>
void where_batch_func(DType* out, const CType* cond, const DType* x, const DType* y,
const int M, const int N) {
static int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const int blk_size = 64/sizeof(DType);
if (M == 1) {
#pragma omp parallel for num_threads(omp_threads)
for (int blk = 0; blk < N; blk += blk_size) {
int blk_bound = MIN((blk + blk_size), N);
for (int i = blk; i < blk_bound; i++) {
out[i] = (cond[i] != 0) ? x[i] : y[i];
}
}
} else {
const int len = M * sizeof (DType);
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; i++) {
if (cond[i] != 0) {
memcpy(reinterpret_cast<void*>(out + i * M), reinterpret_cast<const void*>(x + i * M), len);
} else {
memcpy(reinterpret_cast<void*>(out + i * M), reinterpret_cast<const void*>(y + i * M), len);
}
}
}
}

/*! \brief Choose elements from x or y depending on condition
* The condition is a vector whose size is the same as the
Expand Down Expand Up @@ -295,9 +322,11 @@ void WhereOpForward(const nnvm::NodeAttrs& attrs,
cond.dptr<CType>(), x.dptr<DType>(),
y.dptr<DType>());
} else {
Kernel<where_batch<req_type>, xpu>::Launch(s, out.Size(), out.dptr<DType>(),
cond.dptr<CType>(), x.dptr<DType>(),
y.dptr<DType>(), x.Size()/cond.Size());
where_batch_func<DType, CType>(out.dptr<DType>(), cond.dptr<CType>(), x.dptr<DType>(),
y.dptr<DType>(), x.Size()/cond.Size(), cond.Size());
// Kernel<where_batch<req_type>, xpu>::Launch(s, out.Size(), out.dptr<DType>(),
// cond.dptr<CType>(), x.dptr<DType>(),
// y.dptr<DType>(), x.Size()/cond.Size());
}
});
});
Expand Down
41 changes: 41 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@
#include <vector>
#include <utility>
#include <algorithm>
#include <climits>
#include "./cast_storage-inl.h"
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../elemwise_op_common.h"
#include "../../ndarray/ndarray_function.h"
#if MSHADOW_USE_MKL == 1
#include "mkl.h"
#endif

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -348,6 +352,43 @@ class UnaryOp : public OpBase {
LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
}
}

#if MSHADOW_USE_MKL == 1
static inline void MKLLog(MKL_INT size, const float* pIn, float* pOut) {
vsLn(size, pIn, pOut);
}

static inline void MKLLog(MKL_INT size, const double* pIn, double* pOut) {
vdLn(size, pIn, pOut);
}
#endif

template<typename xpu, typename OP>
static void LogCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (req[0] == kNullOp) return;
// if defined MSHADOW_USE_MKL then call mkl log when req is KWriteTo, type_flag
// is mshadow::kFloat32 or mshadow::kFloat64 and data size less than or equal MKL_INT_MAX
#if MSHADOW_USE_MKL == 1
auto type_flag = inputs[0].type_flag_;
const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX;
size_t input_size = inputs[0].Size();
if (req[0] == kWriteTo &&
input_size <= MKL_INT_MAX &&
(type_flag == mshadow::kFloat32 || type_flag == mshadow::kFloat64)) {
MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
MKLLog(input_size, inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
});
} else {
Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
}
#else
Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
#endif
}
};

/*! \brief Map legacy unary_bwd to backward_grad */
Expand Down
3 changes: 2 additions & 1 deletion src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ The storage type of ``exp`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_mul"});

// log
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(log, cpu, mshadow_op::log)
MXNET_OPERATOR_REGISTER_UNARY(log)
MXNET_ADD_SPARSE_OP_ALIAS(log)
.describe(R"code(Returns element-wise Natural logarithmic value of the input.
Expand All @@ -949,6 +949,7 @@ The natural logarithm is logarithm in base *e*, so that ``log(exp(x)) = x``
The storage type of ``log`` output is always dense
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::LogCompute<cpu, mshadow_op::log>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log"});

// log10
Expand Down
14 changes: 9 additions & 5 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,15 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type
if (actual_axis == 0) {
if (param.mode == take_::kClip) {
Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
oshape.Size()/idxshape.Size(), arrshape[0]);
take_axis0_clip_func(outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
idxshape[0], oshape.Size()/idxshape.Size(), arrshape[0]);
// Kernel<TakeCPU<true>, cpu>::Launch(s, idxshape.Size(),
// outputs[take_::kOut].dptr<DType>(),
// inputs[take_::kArr].dptr<DType>(),
// inputs[take_::kIdx].dptr<IType>(),
// oshape.Size()/idxshape.Size(), arrshape[0]);
} else {
Kernel<TakeCPU<false>, cpu>::Launch(s, idxshape.Size(),
outputs[take_::kOut].dptr<DType>(),
Expand Down
32 changes: 32 additions & 0 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,38 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}

#define MIN(a, b) ((a < b) ? a : b)

template<typename DType, typename IType>
void take_axis0_clip_func(DType* out_data, const DType* in_data, const IType* idx,
const int N, const int M, const int K) {
static int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const int blk_size = 64/sizeof(DType);
if (M == 1) {
#pragma omp parallel for num_threads(omp_threads)
for (int blk = 0; blk < N; blk += blk_size) {
int blk_bound = MIN(blk + blk_size, N);
for (int i = blk; i < blk_bound; i++) {
int j = static_cast<int>(idx[i]);
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
out_data[i] = in_data[j];
}
}
} else {
const int len = M * sizeof (DType);
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < N; i++) {
int j = static_cast<int>(idx[i]);
if (j <= 0) j = 0;
else if (j >= K) j = K - 1;
memcpy(reinterpret_cast<void*>(out_data + i * M),
reinterpret_cast<const void*>(in_data + j * M),
len);
}
}
}

/*! \brief name the struct Take instead of take
* to avoid conflict with the take function in mshadow
*/
Expand Down

0 comments on commit db4d865

Please sign in to comment.