From 2b0e0f436a51127d819d8a69f29f4664614f56c3 Mon Sep 17 00:00:00 2001 From: taolv Date: Thu, 31 Jan 2019 15:29:06 +0800 Subject: [PATCH 1/2] merge operator opts: log_softmax, take, where, log, split --- Makefile | 2 +- src/operator/channel_op_common.h | 37 +++++++++++++ src/operator/nn/softmax-inl.h | 53 ++++++++++++++++++- src/operator/nn/softmax.cc | 37 +++++++++++++ src/operator/slice_channel-inl.h | 7 ++- src/operator/tensor/control_flow_op.h | 35 ++++++++++-- src/operator/tensor/elemwise_unary_op.h | 41 ++++++++++++++ .../tensor/elemwise_unary_op_basic.cc | 3 +- src/operator/tensor/indexing_op.cc | 14 +++-- src/operator/tensor/indexing_op.h | 32 +++++++++++ 10 files changed, 249 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 16ea59f3d585..fb53af6dabe9 100644 --- a/Makefile +++ b/Makefile @@ -87,7 +87,7 @@ endif ifeq ($(DEBUG), 1) CFLAGS += -g -O0 else - CFLAGS += -O3 -DNDEBUG=1 + CFLAGS += -O3 -DNDEBUG=1 -march=native endif CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS) LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS) diff --git a/src/operator/channel_op_common.h b/src/operator/channel_op_common.h index 1afc13ad2594..95e8531a857b 100644 --- a/src/operator/channel_op_common.h +++ b/src/operator/channel_op_common.h @@ -101,6 +101,43 @@ void Split(const mshadow::Tensor &input, split_helper(input, output, dimension, req); } } + +template +void Split_2D(const mshadow::Tensor &input, + std::vector > *output, + const int dimension, const std::vector &req) { + if (dimension != 1) { + LOG(FATAL) << "dimension (" << dimension << ") must == 1"; + } + if (dim != 3) { + LOG(FATAL) << "dimension (" << dim << ") must == 3"; + } else { + std::vector > out = *output; + size_t size = out.size(); + std::vectorslice_len; + std::vectorbegin_pos; + begin_pos.push_back(0); + + for (index_t i = 0; i < size; ++i) { + slice_len.push_back(out[i].size(dimension)); + begin_pos.push_back(begin_pos[i] + out[i].size(dimension)); + } +#if !defined(MXNET_ENABLE_CUDA_RTC) + #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) +#endif + for (int i = 0; i < input.shape_[0]; i++) { + int iRow = i*input.shape_[1]; + for (int j = 0; j < size; j++) { + int jRow = i*slice_len[j]; + int iPos = iRow + begin_pos[j]; + for (int k = 0; k < slice_len[j]; k++) { + out[j].dptr_[jRow + k] = input.dptr_[iPos + k]; + } + } + } + } +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_CHANNEL_OP_COMMON_H_ diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index c063e385f63a..1526841c3caf 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -31,6 +31,10 @@ #include "../operator_common.h" #include "../tensor/broadcast_reduce_op.h" +#if MSHADOW_USE_MKL == 1 +#include "mkl.h" +#endif + namespace mxnet { namespace op { namespace mxnet_op { @@ -42,7 +46,6 @@ struct softmax_fwd { } }; - struct log_softmax_fwd { template MSHADOW_XINLINE static DType Map(DType a, DType b) { @@ -310,6 +313,54 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, }); } +#if MSHADOW_USE_MKL == 1 +static inline int64_t prod(int64_t* n, int start, int end) { + int64_t res = 1; + for (int i = start; i < end; i++) + res *= n[i]; + return res; +} + +static inline void max_(int64_t n, float * __restrict__ in, float *dst) { + dst[0] = in[0]; + for (int64_t i = 1; i < n; i++) + dst[0] = (dst[0] < in[i]) ? in[i] : dst[0]; +} + +static inline void sub_(int64_t n, float * __restrict__ in, float b, float * __restrict__ dst) { + for (int64_t i = 0; i < n; i++) + dst[i] = in[i] - b; +} + +static inline void sum_(int64_t n, float * __restrict__ in, float * __restrict__ dst) { + dst[0] = cblas_sasum(n, in, 1); +} + +static inline void exp_(int64_t n, float *in, float *dst) { + vsExp(n, in, dst); +} + +static inline void log_softmax_parallel(TShape sh, int axis, + float * __restrict__ in, float * __restrict__ out) { + int64_t outer_size = prod(sh.data(), 0, axis); + int64_t channels = sh[axis]; + // int inner_size = prod(n, axis+1, 4); + +#pragma omp parallel for + for (int ou=0; ou < outer_size; ou++) { + float *in_dat = in + ou * channels; + float *out_dat = out + ou * channels; + float b, logsum; + + max_(channels, in_dat, &b); + sub_(channels, in_dat, b, out_dat); + exp_(channels, out_dat, out_dat); + sum_(channels, out_dat, &logsum); + logsum = b + logf(logsum); + sub_(channels, in_dat, logsum, out_dat); + } +} +#endif template void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 81e775cac526..6d4f82b1148f 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -35,6 +35,39 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(SoftmaxParam); +#if MSHADOW_USE_MKL == 1 +static inline bool SupportLogSoftmaxMKL(const TBlob &input, const SoftmaxParam ¶m) { + if (input.type_flag_ != mshadow::kFloat32) return false; + if (param.temperature.has_value()) return false; + + int axis = CheckAxis(param.axis, input.ndim()); + // channle on the last dimension + if ((input.ndim() == 4U && axis == 3U) || + (input.ndim() == 3U && axis == 2U) || + (input.ndim() == 2U && axis == 1U)) + return true; + else + return false; +} + +void LogSoftmaxComputeMKL(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) return; + CHECK_NE(req[0], kAddTo); + const SoftmaxParam& param = nnvm::get(attrs.parsed); + int axis = CheckAxis(param.axis, inputs[0].ndim()); + if (!SupportLogSoftmaxMKL(inputs[0], param)) { + // fallback + SoftmaxCompute(attrs, ctx, inputs, req, outputs); + } else { + log_softmax_parallel(inputs[0].shape_, axis, inputs[0].dptr(), outputs[0].dptr()); + } +} +#endif + #if MXNET_USE_MKLDNN == 1 static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -167,7 +200,11 @@ Examples:: )code") .set_attr_parser(ParamParser) +#if MSHADOW_USE_MKL == 1 +.set_attr("FCompute", LogSoftmaxComputeMKL) +#else .set_attr("FCompute", SoftmaxCompute) +#endif .set_attr("FGradient", ElemwiseGradUseOut{"_backward_log_softmax"}) .add_arguments(SoftmaxParam::__FIELDS__()); diff --git a/src/operator/slice_channel-inl.h b/src/operator/slice_channel-inl.h index 3b14a26ea649..5c2748026ed2 100644 --- a/src/operator/slice_channel-inl.h +++ b/src/operator/slice_channel-inl.h @@ -99,7 +99,12 @@ class SliceChannelOp : public Operator { for (int i = 0; i < size_; ++i) { outputs[i] = out_data[i].get_with_shape(slice_shape, s); } - Split(data, &outputs, 1, req); + // 3D dshape and trailing==1, split_2d can be used to speedup + if (trailing == 1 && std::is_same::value) { + Split_2D(data, &outputs, 1, req); + } else { + Split(data, &outputs, 1, req); + } } virtual void Backward(const OpContext &ctx, diff --git a/src/operator/tensor/control_flow_op.h b/src/operator/tensor/control_flow_op.h index 07252963c874..7569c4c22eac 100644 --- a/src/operator/tensor/control_flow_op.h +++ b/src/operator/tensor/control_flow_op.h @@ -80,6 +80,33 @@ struct where_csr { } }; +#define MIN(a, b) ((a < b) ? a : b) + +template +void where_batch_func(DType* out, const CType* cond, const DType* x, const DType* y, + const int M, const int N) { + static int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const int blk_size = 64/sizeof(DType); + if (M == 1) { +#pragma omp parallel for num_threads(omp_threads) + for (int blk = 0; blk < N; blk += blk_size) { + int blk_bound = MIN((blk + blk_size), N); + for (int i = blk; i < blk_bound; i++) { + out[i] = (cond[i] != 0) ? x[i] : y[i]; + } + } + } else { + const int len = M * sizeof (DType); +#pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) { + if (cond[i] != 0) { + memcpy(reinterpret_cast(out + i * M), reinterpret_cast(x + i * M), len); + } else { + memcpy(reinterpret_cast(out + i * M), reinterpret_cast(y + i * M), len); + } + } + } +} /*! \brief Choose elements from x or y depending on condition * The condition is a vector whose size is the same as the @@ -295,9 +322,11 @@ void WhereOpForward(const nnvm::NodeAttrs& attrs, cond.dptr(), x.dptr(), y.dptr()); } else { - Kernel, xpu>::Launch(s, out.Size(), out.dptr(), - cond.dptr(), x.dptr(), - y.dptr(), x.Size()/cond.Size()); + where_batch_func(out.dptr(), cond.dptr(), x.dptr(), + y.dptr(), x.Size()/cond.Size(), cond.Size()); + // Kernel, xpu>::Launch(s, out.Size(), out.dptr(), + // cond.dptr(), x.dptr(), + // y.dptr(), x.Size()/cond.Size()); } }); }); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 83b86bf1d94c..8d5ad055b118 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -29,11 +29,15 @@ #include #include #include +#include #include "./cast_storage-inl.h" #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../elemwise_op_common.h" #include "../../ndarray/ndarray_function.h" +#if MSHADOW_USE_MKL == 1 +#include "mkl.h" +#endif namespace mxnet { namespace op { @@ -348,6 +352,43 @@ class UnaryOp : public OpBase { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } } + +#if MSHADOW_USE_MKL == 1 + static inline void MKLLog(MKL_INT size, const float* pIn, float* pOut) { + vsLn(size, pIn, pOut); + } + + static inline void MKLLog(MKL_INT size, const double* pIn, double* pOut) { + vdLn(size, pIn, pOut); + } +#endif + + template + static void LogCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) return; + // if defined MSHADOW_USE_MKL then call mkl log when req is KWriteTo, type_flag + // is mshadow::kFloat32 or mshadow::kFloat64 and data size less than or equal MKL_INT_MAX +#if MSHADOW_USE_MKL == 1 + auto type_flag = inputs[0].type_flag_; + const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX; + size_t input_size = inputs[0].Size(); + if (req[0] == kWriteTo && + input_size <= MKL_INT_MAX && + (type_flag == mshadow::kFloat32 || type_flag == mshadow::kFloat64)) { + MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, { + MKLLog(input_size, inputs[0].dptr(), outputs[0].dptr()); + }); + } else { + Compute(attrs, ctx, inputs, req, outputs); + } +#else + Compute(attrs, ctx, inputs, req, outputs); +#endif + } }; /*! \brief Map legacy unary_bwd to backward_grad */ diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 301fc48d2128..9730d0096e58 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -940,7 +940,7 @@ The storage type of ``exp`` output is always dense .set_attr("FGradient", ElemwiseGradUseOut{"_mul"}); // log -MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(log, cpu, mshadow_op::log) +MXNET_OPERATOR_REGISTER_UNARY(log) MXNET_ADD_SPARSE_OP_ALIAS(log) .describe(R"code(Returns element-wise Natural logarithmic value of the input. @@ -949,6 +949,7 @@ The natural logarithm is logarithm in base *e*, so that ``log(exp(x)) = x`` The storage type of ``log`` output is always dense )code" ADD_FILELINE) +.set_attr("FCompute", UnaryOp::LogCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_log"}); // log10 diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 77236e068f86..8bf32410f45f 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -295,11 +295,15 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { // index data type if (actual_axis == 0) { if (param.mode == take_::kClip) { - Kernel, cpu>::Launch(s, idxshape.Size(), - outputs[take_::kOut].dptr(), - inputs[take_::kArr].dptr(), - inputs[take_::kIdx].dptr(), - oshape.Size()/idxshape.Size(), arrshape[0]); + take_axis0_clip_func(outputs[take_::kOut].dptr(), + inputs[take_::kArr].dptr(), + inputs[take_::kIdx].dptr(), + idxshape[0], oshape.Size()/idxshape.Size(), arrshape[0]); + // Kernel, cpu>::Launch(s, idxshape.Size(), + // outputs[take_::kOut].dptr(), + // inputs[take_::kArr].dptr(), + // inputs[take_::kIdx].dptr(), + // oshape.Size()/idxshape.Size(), arrshape[0]); } else { Kernel, cpu>::Launch(s, idxshape.Size(), outputs[take_::kOut].dptr(), diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 92b6e21018e5..daecfaaa7aa5 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -296,6 +296,38 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +#define MIN(a, b) ((a < b) ? a : b) + +template +void take_axis0_clip_func(DType* out_data, const DType* in_data, const IType* idx, + const int N, const int M, const int K) { + static int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const int blk_size = 64/sizeof(DType); + if (M == 1) { +#pragma omp parallel for num_threads(omp_threads) + for (int blk = 0; blk < N; blk += blk_size) { + int blk_bound = MIN(blk + blk_size, N); + for (int i = blk; i < blk_bound; i++) { + int j = static_cast(idx[i]); + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + out_data[i] = in_data[j]; + } + } + } else { + const int len = M * sizeof (DType); +#pragma omp parallel for num_threads(omp_threads) + for (int i = 0; i < N; i++) { + int j = static_cast(idx[i]); + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + memcpy(reinterpret_cast(out_data + i * M), + reinterpret_cast(in_data + j * M), + len); + } + } +} + /*! \brief name the struct Take instead of take * to avoid conflict with the take function in mshadow */ From bf92ba1d2b37fbca5ba6db316b0c58a3b001f970 Mon Sep 17 00:00:00 2001 From: taolv Date: Mon, 4 Feb 2019 23:49:30 +0800 Subject: [PATCH 2/2] fix omp num_threads --- src/operator/nn/softmax-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h index 1526841c3caf..0ff8c875826f 100644 --- a/src/operator/nn/softmax-inl.h +++ b/src/operator/nn/softmax-inl.h @@ -346,7 +346,7 @@ static inline void log_softmax_parallel(TShape sh, int axis, int64_t channels = sh[axis]; // int inner_size = prod(n, axis+1, 4); -#pragma omp parallel for +#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) for (int ou=0; ou < outer_size; ou++) { float *in_dat = in + ou * channels; float *out_dat = out + ou * channels;