From 2c3d9a3a491d33497c2b37897e73796a0c28e19d Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Tue, 30 Oct 2018 00:20:11 +0000 Subject: [PATCH] fix int overflow. --- src/operator/mxnet_op.h | 44 +++++++++++++++--------------- src/operator/random/sampler.h | 29 ++++++++++---------- src/operator/tensor/indexing_op.cc | 24 ++++++++-------- src/operator/tensor/indexing_op.cu | 2 +- src/operator/tensor/indexing_op.h | 34 ++++++++++++----------- src/operator/tensor/init_op.h | 2 +- 6 files changed, 69 insertions(+), 66 deletions(-) diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index e77569671ebb..f3061e859035 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -433,51 +433,51 @@ struct op_with_req { /*! \brief input is one tensor */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i])); } /*! \brief inputs are two tensors */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, const DType *rhs) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *lhs, const DType *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } /*! \brief input is tensor and a scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, const DType value) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); } /*! \brief input is tensor and two scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in, const DType value_1, const DType value_2) { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value_1, value_2)); } /*! \brief No inputs (ie fill to constant value) */ template - MSHADOW_XINLINE static void Map(int i, DType *out) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out) { KERNEL_ASSIGN(out[i], req, OP::Map()); } /*! \brief input is single scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, const DType value) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(value)); } /*! \brief inputs are two tensors and a scalar value */ template - MSHADOW_XINLINE static void Map(int i, DType *out, + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *input_1, const DType *input_2, const DType value) { KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], value)); } /*! \brief inputs are three tensors (ie backward grad with binary grad function) */ template - MSHADOW_XINLINE static void Map(int i, DType *out, + MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *input_1, const DType *input_2, const DType *input_3) { @@ -503,21 +503,21 @@ struct Kernel { * \param args Varargs to eventually pass to the OP::Map() functoion */ template - inline static bool Launch(mshadow::Stream *, const int N, Args... args) { + inline static bool Launch(mshadow::Stream *, const int64_t N, Args... 
args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2) { - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } } else { #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } } #else - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } #endif @@ -536,22 +536,22 @@ struct Kernel { * \param args Varargs to eventually pass to the OP::Map() functoion */ template - static void LaunchTuned(mshadow::Stream *, const int N, Args... args) { + static void LaunchTuned(mshadow::Stream *, const int64_t N, Args... args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2 || !tuned_op::UseOMP( static_cast(N), static_cast(omp_threads))) { - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } } else { #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } } #else - for (int i = 0; i < N; ++i) { + for (int64_t i = 0; i < N; ++i) { OP::Map(i, args...); } #endif @@ -565,15 +565,15 @@ struct Kernel { * \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions */ template - inline static void LaunchEx(mshadow::Stream *s, const int N, Args... args) { + inline static void LaunchEx(mshadow::Stream *s, const int64_t N, Args... args) { #ifdef _OPENMP const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); if (omp_threads < 2) { OP::Map(0, N, args...); } else { - const int length = (N + omp_threads - 1) / omp_threads; + const int64_t length = (N + omp_threads - 1) / omp_threads; #pragma omp parallel for num_threads(omp_threads) - for (int i = 0; i < N; i += length) { + for (int64_t i = 0; i < N; i += length) { OP::Map(i, i + length > N ? N - i : length, args...); } } @@ -595,7 +595,7 @@ struct Kernel { template static MSHADOW_CINLINE typename std::enable_if::value, bool>::type - Launch(mshadow::Stream *s, const int N, DType *dest, Args... args) { + Launch(mshadow::Stream *s, const int64_t N, DType *dest, Args... args) { LaunchTuned(s, N, dest, args...); return true; } @@ -613,7 +613,7 @@ struct Kernel { template static MSHADOW_CINLINE typename std::enable_if::value, bool>::type - Launch(mshadow::Stream *s, const int N, DType *dest, Args... args) { + Launch(mshadow::Stream *s, const int64_t N, DType *dest, Args... args) { LaunchTuned(s, N, dest, args...); return true; } @@ -669,7 +669,7 @@ template struct set_to_int : public tunable { // mxnet_op version (when used directly with Kernel<>::Launch()) */ template - MSHADOW_XINLINE static void Map(int i, DType *out) { + MSHADOW_XINLINE static void Map(int64_t i, DType *out) { out[i] = DType(val); } // mshadow_op version (when used with op_with_req<>) diff --git a/src/operator/random/sampler.h b/src/operator/random/sampler.h index 44f80ab56254..57a83b69927d 100644 --- a/src/operator/random/sampler.h +++ b/src/operator/random/sampler.h @@ -43,24 +43,25 @@ namespace op { template inline static void LaunchRNG(mshadow::Stream *s, common::random::RandGenerator *gen, - const int N, Args... args) { + const int64_t N, Args... args) { // minimal check to avoid division by zero, below. // if `N` is zero the map operation is a no-op in any case. 
if (N <= 0) { return; } - const int nloop = (N + RandGenerator::kMinNumRandomPerThread - 1) / + const int64_t nloop = (N + RandGenerator::kMinNumRandomPerThread - 1) / RandGenerator::kMinNumRandomPerThread; - const int nthread = std::min(nloop, RandGenerator::kNumRandomStates); - const int step = (N + nthread - 1) / nthread; + const int64_t nthread = std::min(nloop, + static_cast(RandGenerator::kNumRandomStates)); + const int64_t step = (N + nthread - 1) / nthread; Kernel::Launch(s, nthread, *gen, N, step, args...); } #define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...) \ - const int start = thread_id * step; \ - const int end = start + step; \ + const int64_t start = thread_id * step; \ + const int64_t end = start + step; \ typename RandGenerator::Impl genImpl(&gen, thread_id); \ - for (int i = start; i < end && i < N; ++i) { \ + for (int64_t i = start; i < end && i < N; ++i) { \ {__VA_ARGS__} \ } @@ -68,7 +69,7 @@ template struct SampleUniformKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *lower, const IType *upper, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -96,7 +97,7 @@ template struct SampleNormalKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *mean, const IType *std, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -123,7 +124,7 @@ template struct SampleExponentialKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *lambda, OType *out) { RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, { @@ -171,7 +172,7 @@ template struct SampleGammaKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *alpha, const IType *beta, OType *out) { RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, { @@ -233,7 +234,7 @@ template struct SamplePoissonKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *lambda, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { @@ -260,7 +261,7 @@ template struct SampleNegativeBinomialKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *k, const IType *p, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { @@ -292,7 +293,7 @@ template struct SampleGeneralizedNegativeBinomialKernel { template MSHADOW_XINLINE static void Map(int id, RandGenerator gen, - const int N, const int step, + const int64_t N, const int64_t step, index_t nParm, index_t nSample, const IType *mu, const IType *alpha, OType *out) { RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, { diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index de0ede3a7427..e99322b454bd 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -307,19 +307,19 @@ inline void SparseEmbeddingOpBackwardRspImpl(const bool deterministic, template inline typename 
std::enable_if<(!std::is_same::value), void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices, mshadow::Stream *s) { #pragma omp parallel for - for (int i = 0; i < N; i++) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + for (int64_t i = 0; i < N; i++) { + int64_t offset = 0; + for (int64_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (int64_t j = 0; j < K; ++j) { #pragma omp atomic out[offset + j] += data[i * K + j]; } @@ -328,18 +328,18 @@ GatherNDBackwardImpl(int N, int M, int K, template inline typename std::enable_if::value, void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices, mshadow::Stream *s) { - for (int i = 0; i < N; i++) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + for (int64_t i = 0; i < N; i++) { + int64_t offset = 0; + for (int64_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (int64_t j = 0; j < K; ++j) { out[offset + j] += data[i * K + j]; } } diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 16cc697626ee..f355fce5ad39 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -405,7 +405,7 @@ struct backward_gather_nd_gpu { }; template -inline void GatherNDBackwardImpl(int N, int M, int K, +inline void GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 2a419e7f6b0e..378c666175e1 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -1490,15 +1490,15 @@ inline bool ScatterNDType(const nnvm::NodeAttrs& attrs, struct scatter_nd { template - MSHADOW_XINLINE static void Map(int i, OpReqType req, int N, int M, int K, + MSHADOW_XINLINE static void Map(int64_t i, OpReqType req, int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + int64_t offset = 0; + for (int64_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (int64_t j = 0; j < K; ++j) { KERNEL_ASSIGN(out[offset+j], req, data[i*K + j]); } } @@ -1511,17 +1511,18 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; + using nnvm::dim_t; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (req[0] == kNullOp) return; mshadow::Stream *s = ctx.get_stream(); const TShape& oshape = outputs[0].shape_; const TShape& ishape = inputs[1].shape_; - int M = ishape[0]; - int N = ishape.Size() / M; - int K = oshape.ProdShape(M, oshape.ndim()); + dim_t M = ishape[0]; + dim_t N = ishape.Size() / M; + dim_t K = oshape.ProdShape(M, oshape.ndim()); mshadow::Shape<10> strides; - for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; + for (dim_t i = M-1, stride = K; i >= 0; stride *= 
oshape[i], --i) strides[i] = stride; if (kWriteTo == req[0]) { Fill(s, outputs[0], req[0], 0); } @@ -1536,7 +1537,7 @@ void ScatterNDForward(const nnvm::NodeAttrs& attrs, template inline typename std::enable_if<(!std::is_same::value), void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1545,7 +1546,7 @@ GatherNDBackwardImpl(int N, int M, int K, template inline typename std::enable_if::value, void>::type -GatherNDBackwardImpl(int N, int M, int K, +GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1553,7 +1554,7 @@ GatherNDBackwardImpl(int N, int M, int K, mshadow::Stream *s); template -inline void GatherNDBackwardImpl(int N, int M, int K, +inline void GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, @@ -1567,17 +1568,18 @@ void GatherNDBackward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { using namespace mshadow; + using nnvm::dim_t; CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (req[0] == kNullOp) return; mshadow::Stream *s = ctx.get_stream(); const TShape& oshape = outputs[0].shape_; const TShape& ishape = inputs[1].shape_; - int M = ishape[0]; - int N = ishape.Size() / M; - int K = oshape.ProdShape(M, oshape.ndim()); + dim_t M = ishape[0]; + dim_t N = ishape.Size() / M; + dim_t K = oshape.ProdShape(M, oshape.ndim()); mshadow::Shape<10> strides; - for (int i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; + for (dim_t i = M-1, stride = K; i >= 0; stride *= oshape[i], --i) strides[i] = stride; if (kWriteTo == req[0]) { Fill(s, outputs[0], req[0], 0); } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 4e52b087f10a..aff40dcdc6a6 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -453,7 +453,7 @@ void EyeFill(const nnvm::NodeAttrs& attrs, struct range_fwd { template - MSHADOW_XINLINE static void Map(int i, int repeat, DType start, DType step, + MSHADOW_XINLINE static void Map(int64_t i, int repeat, DType start, DType step, int req, DType* out) { KERNEL_ASSIGN(out[i], req, start + (i/repeat) * step); }
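
Note on the failure mode this patch addresses (illustrative only, not part of the diff): the kernel loop counters, element counts (N), and offset arithmetic move from int to int64_t so that tensors with more than 2^31 - 1 elements are indexed correctly. The sketch below shows the wrap-around that 32-bit arithmetic produces for such sizes; the names rows/cols are hypothetical and do not appear in the MXNet sources, and it assumes a platform with a 32-bit int.

    // overflow_sketch.cc -- minimal demonstration, assumes 32-bit int.
    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t rows = 70000, cols = 40000;   // 2.8e9 elements > INT32_MAX
      const int64_t n = rows * cols;              // 64-bit product: 2800000000
      // Truncating the element count (or an offset such as i*K + j) to int loses
      // the high bits and typically yields a negative, wrapped value -- the class
      // of bug the patch removes by using int64_t for N, i, and offsets.
      const int truncated = static_cast<int>(n);
      std::cout << "int64_t count: " << n << "\n"
                << "int count:     " << truncated << "\n";
      return 0;
    }

Compiled with a 32-bit int (the common case on x86-64), the second line prints a negative number, which is why loops of the form "for (int i = 0; i < N; ++i)" and int-typed offsets misbehave once N exceeds 2^31 - 1.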