From 05718f5f1576d64615a1d621b4bf3831eb8b15b8 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 3 Aug 2018 20:17:35 +0000 Subject: [PATCH 1/2] Fix reduce_kernel_M1 --- src/operator/tensor/broadcast_reduce-inl.cuh | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index be3d1f9223f4..33bf72798fd6 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -268,7 +268,11 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { Shape coord = unravel(idx, sshape); int j = ravel(coord, bshape); - assign(&small[idx], addto, OP::Map(big[j])); + DType val, residual; + Reducer::SetInitValue(val, residual); + Reducer::Reduce(val, OP::Map(big[j]), residual); + Reducer::Finalize(val, residual); + assign(&small[idx], addto, val); } } @@ -287,7 +291,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, int idx_big = ravel(coord, big_shape); int idx_lhs = ravel(coord, lhs_shape); int idx_rhs = ravel(coord, rhs_shape); - DType val = OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])); + DType val, residual; + Reducer::SetInitValue(val, residual); + Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); + Reducer::Finalize(val, residual); assign(&small[idx], addto, val); } } From 873ba00769f73c3bb4e9be75d651492031477104 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 3 Aug 2018 21:03:23 +0000 Subject: [PATCH 2/2] Improve test_norm --- tests/python/unittest/test_ndarray.py | 34 ++++++++++++++++----------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index e55fa1af90e8..ac6ee1561c47 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1308,25 +1308,31 @@ def test_norm(ctx=default_context()): def l1norm(input_data, axis=0, keepdims=False): return np.sum(abs(input_data), axis=axis, keepdims=keepdims) - def l2norm(input_data, axis=0, keepdims=False): + def l2norm(input_data, axis=0, keepdims=False): return sp_norm(input_data, axis=axis, keepdims=keepdims) in_data_dim = random_sample([4,5,6], 1)[0] - in_data_shape = rand_shape_nd(in_data_dim) - np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32) - mx_arr = mx.nd.array(np_arr, ctx=ctx) - for ord in [1,2]: - for keep_dims in [True, False]: - for i in range(4): - npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims) - mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims) - assert npy_out.shape == mx_out.shape - mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) - if (i < 3): - npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims) - mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims) + for force_reduce_dim1 in [True, False]: + in_data_shape = rand_shape_nd(in_data_dim) + if force_reduce_dim1: + in_data_shape = in_data_shape[:3] + (1, ) + in_data_shape[4:] + np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32) + mx_arr = mx.nd.array(np_arr, ctx=ctx) + for ord in [1, 2]: + for keep_dims in [True, False]: + for i in range(4): + npy_out = l1norm(np_arr, i, keep_dims) if ord == 1 else l2norm( + np_arr, i, keep_dims) + mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims) assert npy_out.shape == mx_out.shape mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + if (i < 3): + npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm( + np_arr, (i, i + 1), keep_dims) + mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims) + assert npy_out.shape == mx_out.shape + mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + @with_seed() def test_ndarray_cpu_shared_ctx():