From 5cbaab5f079fffd829a2ab5fbfc42958935b54d5 Mon Sep 17 00:00:00 2001
From: biessy_r
Date: Wed, 26 Apr 2017 17:56:44 +0100
Subject: [PATCH] [OpenCL] Fixes scatter_nd_op_test

---
 tensorflow/core/kernels/scatter_nd_op.cc      | 44 ++++++++++++++++++-
 .../kernel_tests/scatter_nd_ops_test.py       | 20 ---------
 2 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index b177edf461c938..48565d8cb97e95 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -27,6 +27,10 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/util.h"
 
+#ifdef TENSORFLOW_USE_SYCL
+#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
+#endif // TENSORFLOW_USE_SYCL
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -141,6 +145,40 @@ static void PrepareAndValidateInputs(OpKernelContext* c,
   *num_updates = indices_shape.num_elements() / safe_slice_dim;
 }
 
+template <typename Device, typename Index>
+class IndexFlattener {
+public:
+  inline typename TTypes<Index, 2>::ConstTensor
+  operator()(OpKernelContext*, const Tensor& indices) {
+    return indices.flat_inner_dims<Index>();
+  }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Index>
+class IndexFlattener<SYCLDevice, Index> {
+public:
+  IndexFlattener() { indices_host_ = nullptr; }
+  ~IndexFlattener() { delete[] indices_host_; }
+
+  inline typename TTypes<Index, 2>::ConstTensor
+  operator()(OpKernelContext* c, const Tensor& indices) {
+    size_t num_indices = indices.NumElements();
+    indices_host_ = new Index[num_indices];
+    auto device = c->eigen_sycl_device();
+    auto size = sizeof(Index) * num_indices;
+    auto src_ptr = GetBase(&indices);
+    device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
+                              size);
+    return typename TTypes<Index, 2>::ConstTensor(indices_host_,
+                                       indices.shape().AsEigenDSizes<2>());
+  }
+
+private:
+  Index* indices_host_;
+};
+#endif
+
 template <typename Device, typename T, typename Index>
 class ScatterNdOp : public OpKernel {
  public:
@@ -169,7 +207,8 @@ class ScatterNdOp : public OpKernel {
                                     &num_updates, &slice_size);
     if (!c->status().ok()) return;
 
-    auto indices_flat = indices.flat_inner_dims<Index>();
+    IndexFlattener<Device, Index> index_flattener;
+    auto indices_flat = index_flattener(c, indices);
     auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
 
     Tensor* out = nullptr;
@@ -265,7 +304,8 @@ class ScatterNdUpdateOp : public OpKernel {
                                     &slice_dim, &num_updates, &slice_size);
     if (!c->status().ok()) return;
 
-    auto indices_flat = indices.flat_inner_dims<Index>();
+    IndexFlattener<Device, Index> index_flattener;
+    auto indices_flat = index_flattener(c, indices);
     auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
     auto params_matrix = params.template shaped<T, 2>(
         {params_shape.num_elements() / slice_size, slice_size});
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index d5300a60216d26..c672449fce346d 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -95,10 +95,6 @@ def _VariableRankTest(self, vtype, itype, repeat_indices=False):
-    # Currently not implemented for OpenCL
-    if test_util.is_sycl_enabled():
-      return
-
     np.random.seed(8)
     ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
     indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
@@ -149,10 +145,6 @@ def _VariableRankTests(self, np_scatter, tf_scatter):
       self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
 
   def testSimple(self):
-    # Currently not implemented for OpenCL
-    if test_util.is_sycl_enabled():
-      return
-
     indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
     updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
     ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
@@ -166,10 +158,6 @@ def testSimple(self):
     self.assertAllClose(result, expected)
 
   def testSimple2(self):
-    # Currently not implemented for OpenCL
-    if test_util.is_sycl_enabled():
-      return
-
     indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
     updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
     ref = variables.Variable(
@@ -184,10 +172,6 @@ def testSimple2(self):
     self.assertAllClose(result, expected)
 
   def testSimple3(self):
-    # Currently not implemented for OpenCL
-    if test_util.is_sycl_enabled():
-      return
-
     indices = constant_op.constant([[1]], dtype=dtypes.int32)
     updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
     ref = variables.Variable(
@@ -420,10 +404,6 @@ def testGradientsRank3SliceUpdate(self):
     self.assertAllEqual(expected_grads, grads.eval())
 
   def testConcurrentUpdates(self):
-    # Currently not implemented for OpenCL
-    if test_util.is_sycl_enabled():
-      return
-
     num_updates = 10000
     update_values = np.random.rand(num_updates)
     ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
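
For context, the SYCL specialization above exists because the scatter kernels walk the indices on the host while, on an OpenCL device, the indices tensor lives in device memory; the specialization therefore copies the indices into a host buffer before flattening them. Below is a minimal host-only sketch of that pattern. HostIndexFlattener and CopyToHost are illustrative names, not TensorFlow or Eigen API; CopyToHost merely stands in for the device.memcpyDeviceToHost call used in the patch.

#include <cstring>
#include <iostream>
#include <vector>

// Stand-in for a device-to-host copy (the patch calls
// device.memcpyDeviceToHost(dst, src, bytes) on Eigen's SYCL device).
static void CopyToHost(int* dst, const int* src, size_t bytes) {
  std::memcpy(dst, src, bytes);
}

// Copy the (notionally device-resident) indices into a host buffer,
// then expose that buffer to the host-side scatter loop.
class HostIndexFlattener {
 public:
  const std::vector<int>& operator()(const std::vector<int>& device_indices) {
    indices_host_.resize(device_indices.size());
    CopyToHost(indices_host_.data(), device_indices.data(),
               device_indices.size() * sizeof(int));
    return indices_host_;
  }

 private:
  std::vector<int> indices_host_;
};

int main() {
  std::vector<int> device_indices = {4, 3, 1, 7};  // scatter row indices
  HostIndexFlattener flatten;
  for (int i : flatten(device_indices)) std::cout << i << ' ';
  std::cout << '\n';  // prints: 4 3 1 7
}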