
Commit

[OpenCL] Fixes scatter_nd_op_test
Rbiessy authored and Luke Iwanski committed Apr 27, 2017
1 parent 4704240 commit 5cbaab5
Showing 2 changed files with 42 additions and 22 deletions.
44 changes: 42 additions & 2 deletions tensorflow/core/kernels/scatter_nd_op.cc
@@ -27,6 +27,10 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

#ifdef TENSORFLOW_USE_SYCL
#include "tensorflow/core/common_runtime/sycl/sycl_util.h"
#endif // TENSORFLOW_USE_SYCL

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -141,6 +145,40 @@ static void PrepareAndValidateInputs(OpKernelContext* c,
  *num_updates = indices_shape.num_elements() / safe_slice_dim;
}

template <typename Device, typename Index>
class IndexFlattener {
 public:
  inline typename TTypes<Index, 2>::ConstTensor
  operator()(OpKernelContext*, const Tensor& indices) {
    return indices.flat_inner_dims<Index>();
  }
};

#ifdef TENSORFLOW_USE_SYCL
template <typename Index>
class IndexFlattener<SYCLDevice, Index> {
 public:
  IndexFlattener() { indices_host_ = nullptr; }
  ~IndexFlattener() { delete[] indices_host_; }

  inline typename TTypes<Index, 2>::ConstTensor
  operator()(OpKernelContext* c, const Tensor& indices) {
    size_t num_indices = indices.NumElements();
    indices_host_ = new Index[num_indices];
    auto device = c->eigen_sycl_device();
    auto size = sizeof(Index) * num_indices;
    auto src_ptr = GetBase(&indices);
    device.memcpyDeviceToHost(indices_host_, static_cast<const Index*>(src_ptr),
                              size);
    return typename TTypes<Index, 2>::ConstTensor(indices_host_,
                                                  indices.shape().AsEigenDSizes<2>());
  }

 private:
  Index* indices_host_;
};
#endif

template <typename Device, typename T, typename Index>
class ScatterNdOp : public OpKernel {
 public:
@@ -169,7 +207,8 @@ class ScatterNdOp : public OpKernel {
        &num_updates, &slice_size);
    if (!c->status().ok()) return;

    auto indices_flat = indices.flat_inner_dims<Index>();
    IndexFlattener<Device, Index> index_flattener;
    auto indices_flat = index_flattener(c, indices);
    auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});

    Tensor* out = nullptr;
@@ -265,7 +304,8 @@ class ScatterNdUpdateOp : public OpKernel {
        &slice_dim, &num_updates, &slice_size);
    if (!c->status().ok()) return;

    auto indices_flat = indices.flat_inner_dims<Index>();
    IndexFlattener<Device, Index> index_flattener;
    auto indices_flat = index_flattener(c, indices);
    auto updates_flat = updates.shaped<T, 2>({num_updates, slice_size});
    auto params_matrix = params.template shaped<T, 2>(
        {params_shape.num_elements() / slice_size, slice_size});
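
The IndexFlattener added above is the core of the fix: on CPU and GPU the generic template simply returns the tensor's own flattened view, while the SYCL specialization first copies the indices into a temporary host buffer (via the Eigen SYCL device's memcpyDeviceToHost, as in the diff) and hands back a view over that buffer, freeing it when the flattener is destroyed; the test file below then only needs to drop its OpenCL skips. The snippet that follows is a minimal, self-contained sketch of that ownership pattern in plain C++, with no TensorFlow or SYCL dependency; the names HostIndexView, CopyingIndexView and fake_device_copy are illustrative stand-ins, not anything from the commit.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for a device-to-host copy (the role memcpyDeviceToHost plays in the
// kernel); here it is a plain host-side copy so the sketch runs anywhere.
template <typename T>
void fake_device_copy(T* dst, const T* src, std::size_t n) {
  std::copy(src, src + n, dst);
}

// Generic case: the data is already host-visible, so just return a view of it
// (the generic IndexFlattener does the equivalent with flat_inner_dims).
template <typename Index>
struct HostIndexView {
  const Index* operator()(const std::vector<Index>& indices) {
    return indices.data();
  }
};

// Device case: copy into a host buffer owned by this object and release it in
// the destructor (the role of IndexFlattener<SYCLDevice, Index>).
template <typename Index>
class CopyingIndexView {
 public:
  ~CopyingIndexView() { delete[] host_; }

  const Index* operator()(const std::vector<Index>& device_indices) {
    delete[] host_;  // tolerate repeated calls in this sketch
    host_ = new Index[device_indices.size()];
    fake_device_copy(host_, device_indices.data(), device_indices.size());
    return host_;
  }

 private:
  Index* host_ = nullptr;
};

int main() {
  std::vector<int> indices = {4, 3, 1, 7};

  HostIndexView<int> view;        // no copy: points at the original storage
  CopyingIndexView<int> flatten;  // copies into an owned host buffer
  const int* direct = view(indices);
  const int* copied = flatten(indices);

  for (std::size_t i = 0; i < indices.size(); ++i) {
    std::cout << direct[i] << "/" << copied[i] << " ";
  }
  std::cout << "\n";  // prints: 4/4 3/3 1/1 7/7
  return 0;
}
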
20 changes: 0 additions & 20 deletions tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -95,10 +95,6 @@ def _VariableRankTest(self,
                        vtype,
                        itype,
                        repeat_indices=False):
    # Currently not implemented for OpenCL
    if test_util.is_sycl_enabled():
      return

    np.random.seed(8)
    ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
    indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
@@ -149,10 +145,6 @@ def _VariableRankTests(self, np_scatter, tf_scatter):
        self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)

  def testSimple(self):
    # Currently not implemented for OpenCL
    if test_util.is_sycl_enabled():
      return

    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
@@ -166,10 +158,6 @@ def testSimple(self):
      self.assertAllClose(result, expected)

  def testSimple2(self):
    # Currently not implemented for OpenCL
    if test_util.is_sycl_enabled():
      return

    indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
    updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
    ref = variables.Variable(
@@ -184,10 +172,6 @@ def testSimple2(self):
      self.assertAllClose(result, expected)

  def testSimple3(self):
    # Currently not implemented for OpenCL
    if test_util.is_sycl_enabled():
      return

    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
@@ -420,10 +404,6 @@ def testGradientsRank3SliceUpdate(self):
      self.assertAllEqual(expected_grads, grads.eval())

  def testConcurrentUpdates(self):
    # Currently not implemented for OpenCL
    if test_util.is_sycl_enabled():
      return

    num_updates = 10000
    update_values = np.random.rand(num_updates)
    ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
