fix int overflow.
zheng-da committed Nov 1, 2018
1 parent 0bea50e commit 2c3d9a3
Showing 6 changed files with 69 additions and 66 deletions.
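
The changes below widen kernel loop indices and offset arithmetic from 32-bit `int` to `int64_t`: a 32-bit index overflows once a tensor holds more than 2^31 - 1 elements. A minimal standalone sketch of the failure mode (illustrative, not MXNet code):

#include <cstdint>
#include <iostream>

int main() {
  // A tensor with ~3.2 billion elements exceeds INT_MAX (2147483647).
  const int64_t N = 3LL * 1024 * 1024 * 1024;
  // With an `int i` counter, the comparison i < N promotes i to 64 bits,
  // so the loop runs i past INT_MAX -- signed overflow, which is undefined
  // behavior. Casting N down to 32 bits shows the wrapped value such code
  // ends up working with:
  std::cout << static_cast<int32_t>(N) << "\n";  // -1073741824
  std::cout << N << "\n";                        // 3221225472
}
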
44 changes: 22 additions & 22 deletions src/operator/mxnet_op.h
@@ -433,51 +433,51 @@ struct op_with_req {

/*! \brief input is one tensor */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i]));
}

/*! \brief inputs are two tensors */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, const DType *rhs) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *lhs, const DType *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief input is tensor and a scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, const DType value) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
}

/*! \brief input is tensor and two scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
+MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType *in,
const DType value_1, const DType value_2) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value_1, value_2));
}

/*! \brief No inputs (ie fill to constant value) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out) {
KERNEL_ASSIGN(out[i], req, OP::Map());
}

/*! \brief input is single scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType value) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(value));
}

/*! \brief inputs are two tensors and a scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out,
+MSHADOW_XINLINE static void Map(int64_t i, DType *out,
const DType *input_1, const DType *input_2, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], value));
}

/*! \brief inputs are three tensors (ie backward grad with binary grad function) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out,
+MSHADOW_XINLINE static void Map(int64_t i, DType *out,
const DType *input_1,
const DType *input_2,
const DType *input_3) {
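
These `Map` overloads are the per-element bodies that `Kernel<op_with_req<OP, req>, xpu>::Launch` calls once per index, so widening `i` lets a single launch cover more than 2^31 elements. A simplified, self-contained sketch of the calling pattern (the `launch` helper and `square_op` are illustrative stand-ins, not MXNet APIs):

#include <cstdint>
#include <vector>

// Simplified stand-in for mxnet_op::Kernel<OP, cpu>::Launch: it calls
// OP::Map(i, ...) for every i in [0, N) with a 64-bit index.
template<typename OP, typename ...Args>
void launch(int64_t N, Args... args) {
  for (int64_t i = 0; i < N; ++i) OP::Map(i, args...);
}

// Same shape as the one-tensor Map overload above: out[i] = in[i]^2.
struct square_op {
  template<typename DType>
  static void Map(int64_t i, DType *out, const DType *in) {
    out[i] = in[i] * in[i];
  }
};

int main() {
  std::vector<float> in = {1.f, 2.f, 3.f};
  std::vector<float> out(in.size());
  launch<square_op>(static_cast<int64_t>(in.size()), out.data(), in.data());
  return 0;
}
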
@@ -503,21 +503,21 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the OP::Map() function
*/
template<typename ...Args>
-inline static bool Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
+inline static bool Launch(mshadow::Stream<cpu> *, const int64_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2) {
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
} else {
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
}
#else
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
#endif
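
One detail worth noting: `#pragma omp parallel for` requires a signed integral induction variable, and modern OpenMP implementations accept any signed integer type (explicitly so since OpenMP 3.0), so the widened counter stays parallelizable. A minimal sketch (not MXNet code):

#include <cstdint>

void fill(double *out, int64_t N, double value) {
  // A 64-bit signed induction variable keeps the parallel loop valid and
  // prevents the counter from wrapping when N > INT_MAX.
  #pragma omp parallel for
  for (int64_t i = 0; i < N; ++i) {
    out[i] = value;
  }
}
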
@@ -536,22 +536,22 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the OP::Map() function
*/
template<typename PRIMITIVE_OP, typename DType, typename ...Args>
-static void LaunchTuned(mshadow::Stream<cpu> *, const int N, Args... args) {
+static void LaunchTuned(mshadow::Stream<cpu> *, const int64_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2 || !tuned_op<PRIMITIVE_OP, DType>::UseOMP(
static_cast<size_t>(N), static_cast<size_t>(omp_threads))) {
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
} else {
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
}
#else
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
#endif
@@ -565,15 +565,15 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions
*/
template<typename ...Args>
-inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
+inline static void LaunchEx(mshadow::Stream<cpu> *s, const int64_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2) {
OP::Map(0, N, args...);
} else {
-const int length = (N + omp_threads - 1) / omp_threads;
+const int64_t length = (N + omp_threads - 1) / omp_threads;
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; i += length) {
+for (int64_t i = 0; i < N; i += length) {
OP::Map(i, i + length > N ? N - i : length, args...);
}
}
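
`LaunchEx` hands each OpenMP thread one contiguous chunk: `length` is the ceiling of N / omp_threads, and the final chunk is clamped to `N - i`. A standalone sketch of the same arithmetic (illustrative, not MXNet code):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = 10;       // elements
  const int omp_threads = 4;  // worker count
  // Ceiling division: a chunk length such that omp_threads chunks cover N.
  const int64_t length = (N + omp_threads - 1) / omp_threads;  // 3
  for (int64_t i = 0; i < N; i += length) {
    // The last chunk is clamped because i + length may overshoot N.
    const int64_t len = std::min(length, N - i);
    std::cout << "chunk at " << i << " length " << len << "\n";
  }
  // Prints chunks (0,3) (3,3) (6,3) (9,1).
}
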
@@ -595,7 +595,7 @@ struct Kernel<OP, cpu> {
template<typename DType, typename T = OP, typename ...Args>
static MSHADOW_CINLINE
typename std::enable_if<std::is_base_of<tunable, T>::value, bool>::type
-Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+Launch(mshadow::Stream<cpu> *s, const int64_t N, DType *dest, Args... args) {
LaunchTuned<T, DType>(s, N, dest, args...);
return true;
}
@@ -613,7 +613,7 @@ struct Kernel<OP, cpu> {
template<typename DType, typename T = OP, typename ...Args>
static MSHADOW_CINLINE
typename std::enable_if<std::is_base_of<tunable, typename T::Operation>::value, bool>::type
-Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+Launch(mshadow::Stream<cpu> *s, const int64_t N, DType *dest, Args... args) {
LaunchTuned<typename T::Operation, DType>(s, N, dest, args...);
return true;
}
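
These two `Launch` overloads are selected by SFINAE: the first applies when `OP` itself derives from `tunable`, the second when `OP` wraps a tunable `Operation`. A reduced, self-contained sketch of the same dispatch pattern (type names are illustrative):

#include <iostream>
#include <type_traits>

struct tunable {};

struct direct_op : tunable {};
struct wrapper_op { using Operation = direct_op; };

// Chosen when T itself derives from tunable.
template<typename T>
typename std::enable_if<std::is_base_of<tunable, T>::value>::type
launch() { std::cout << "tuned directly\n"; }

// Chosen when T wraps a tunable Operation; substitution of T::Operation
// fails harmlessly (SFINAE) for types without that member.
template<typename T>
typename std::enable_if<std::is_base_of<tunable, typename T::Operation>::value>::type
launch() { std::cout << "tuned via Operation\n"; }

int main() {
  launch<direct_op>();   // first overload
  launch<wrapper_op>();  // second overload
}
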
@@ -669,7 +669,7 @@ template<int val>
struct set_to_int : public tunable {
// mxnet_op version (when used directly with Kernel<>::Launch()) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out) {
+MSHADOW_XINLINE static void Map(int64_t i, DType *out) {
out[i] = DType(val);
}
// mshadow_op version (when used with op_with_req<>)
29 changes: 15 additions & 14 deletions src/operator/random/sampler.h
@@ -43,32 +43,33 @@ namespace op {
template<typename OP, typename xpu, typename GType, typename ...Args>
inline static void LaunchRNG(mshadow::Stream<xpu> *s,
common::random::RandGenerator<xpu, GType> *gen,
-const int N, Args... args) {
+const int64_t N, Args... args) {
// minimal check to avoid division by zero, below.
// if `N` is zero the map operation is a no-op in any case.
if (N <= 0) {
return;
}
-const int nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
+const int64_t nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
RandGenerator<xpu>::kMinNumRandomPerThread;
-const int nthread = std::min(nloop, RandGenerator<xpu>::kNumRandomStates);
-const int step = (N + nthread - 1) / nthread;
+const int64_t nthread = std::min(nloop,
+static_cast<int64_t>(RandGenerator<xpu>::kNumRandomStates));
+const int64_t step = (N + nthread - 1) / nthread;
Kernel<OP, xpu>::Launch(s, nthread, *gen, N, step, args...);
}

#define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...) \
-const int start = thread_id * step; \
-const int end = start + step; \
+const int64_t start = thread_id * step; \
+const int64_t end = start + step; \
typename RandGenerator<xpu, GType>::Impl genImpl(&gen, thread_id); \
-for (int i = start; i < end && i < N; ++i) { \
+for (int64_t i = start; i < end && i < N; ++i) { \
{__VA_ARGS__} \
}
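
Taken together, `LaunchRNG` caps the number of kernel threads at the number of preallocated generator states and gives each thread a `step`-sized slice, which `RNG_KERNEL_LOOP` then walks as `[id*step, min(id*step + step, N))`. A standalone sketch of the partitioning arithmetic (the two constants are illustrative stand-ins for the `RandGenerator` members):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t kMinNumRandomPerThread = 64;  // illustrative values
  const int64_t kNumRandomStates = 1024;

  const int64_t N = 100000;  // samples requested
  const int64_t nloop = (N + kMinNumRandomPerThread - 1) / kMinNumRandomPerThread;
  const int64_t nthread = std::min(nloop, kNumRandomStates);  // 1024
  const int64_t step = (N + nthread - 1) / nthread;           // 98

  int64_t covered = 0;
  for (int64_t id = 0; id < nthread; ++id) {
    const int64_t start = id * step;
    const int64_t end = std::min(start + step, N);  // clamped, as in the macro
    covered += std::max<int64_t>(0, end - start);   // trailing slices may be empty
  }
  std::cout << covered << "\n";  // 100000: every sample assigned exactly once
}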

template<typename xpu>
struct SampleUniformKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *lower, const IType *upper, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
Expand Down Expand Up @@ -96,7 +97,7 @@ template<typename xpu>
struct SampleNormalKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *mean, const IType *std, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -123,7 +124,7 @@ template<typename xpu>
struct SampleExponentialKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *lambda, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
Expand Down Expand Up @@ -171,7 +172,7 @@ template<typename xpu>
struct SampleGammaKernel {
template<typename IType, typename OType, typename FType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, FType> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *alpha, const IType *beta, OType *out) {
RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, {
Expand Down Expand Up @@ -233,7 +234,7 @@ template<typename xpu>
struct SamplePoissonKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *lambda, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
@@ -260,7 +261,7 @@ template<typename xpu>
struct SampleNegativeBinomialKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *k, const IType *p, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
Expand Down Expand Up @@ -292,7 +293,7 @@ template<typename xpu>
struct SampleGeneralizedNegativeBinomialKernel {
template<typename IType, typename OType>
MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+const int64_t N, const int64_t step,
index_t nParm, index_t nSample,
const IType *mu, const IType *alpha, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
24 changes: 12 additions & 12 deletions src/operator/tensor/indexing_op.cc
@@ -307,19 +307,19 @@ inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const bool deterministic,

template<typename DType, typename IType>
inline typename std::enable_if<(!std::is_same<DType, mshadow::half::half_t>::value), void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
#pragma omp parallel for
-for (int i = 0; i < N; i++) {
-int offset = 0;
-for (int j = 0; j < M; ++j) {
-offset += strides[j] * static_cast<int>(indices[j*N + i]);
+for (int64_t i = 0; i < N; i++) {
+int64_t offset = 0;
+for (int64_t j = 0; j < M; ++j) {
+offset += strides[j] * static_cast<int64_t>(indices[j*N + i]);
}
-for (int j = 0; j < K; ++j) {
+for (int64_t j = 0; j < K; ++j) {
#pragma omp atomic
out[offset + j] += data[i * K + j];
}
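
The same widening applies to the flattened destination `offset`, which accumulates `strides[j] * indices[j*N + i]` over the index dimensions; that product can exceed INT_MAX even when each individual index fits comfortably in 32 bits. A standalone sketch (shapes are illustrative):

#include <cstdint>
#include <iostream>

int main() {
  // Illustrative 2-D case: output shape (200000, 30000), row-major strides.
  const int64_t strides[2] = {30000, 1};
  const int32_t idx[2] = {150000, 20000};  // each index fits easily in int32

  int64_t offset = 0;
  for (int64_t j = 0; j < 2; ++j) {
    // Casting the index to int64_t (as in the patch) keeps the product
    // and the running sum in 64 bits.
    offset += strides[j] * static_cast<int64_t>(idx[j]);
  }
  std::cout << offset << "\n";  // 4500020000, which overflows int32
}
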
@@ -328,18 +328,18 @@ GatherNDBackwardImpl(int N, int M, int K,

template<typename DType, typename IType>
inline typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value, void>::type
-GatherNDBackwardImpl(int N, int M, int K,
+GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,
const IType* indices,
mshadow::Stream<cpu> *s) {
-for (int i = 0; i < N; i++) {
-int offset = 0;
-for (int j = 0; j < M; ++j) {
-offset += strides[j] * static_cast<int>(indices[j*N + i]);
+for (int64_t i = 0; i < N; i++) {
+int64_t offset = 0;
+for (int64_t j = 0; j < M; ++j) {
+offset += strides[j] * static_cast<int64_t>(indices[j*N + i]);
}
-for (int j = 0; j < K; ++j) {
+for (int64_t j = 0; j < K; ++j) {
out[offset + j] += data[i * K + j];
}
}
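
Note that this `half_t` overload stays serial, presumably because `#pragma omp atomic` cannot be applied to the non-scalar `mshadow::half::half_t` accumulation; only the index and offset types change here.
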
2 changes: 1 addition & 1 deletion src/operator/tensor/indexing_op.cu
@@ -405,7 +405,7 @@ struct backward_gather_nd_gpu {
};

template<typename DType, typename IType>
-inline void GatherNDBackwardImpl(int N, int M, int K,
+inline void GatherNDBackwardImpl(int64_t N, int64_t M, int64_t K,
const mshadow::Shape<10> strides,
DType* out,
const DType* data,