From 3cbb44d3c3200b43d0600f21efe71e7721d6c9b1 Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Thu, 3 Dec 2020 20:50:59 -0800 Subject: [PATCH] [Topi] Fix GPU Dynamic Topk by Improving Dynamic Strided Slice in Topi (#7018) * Fix GPU dynamic Topk * Fix style * Minor fix * Simplfy dynamic checking * Fix lint * More improvements * Disable test any topk --- include/tvm/topi/detail/constant_utils.h | 15 +++++++ include/tvm/topi/nn.h | 2 +- include/tvm/topi/transform.h | 43 +++++++++++++++---- python/tvm/topi/cuda/sort.py | 20 +++++---- src/relay/op/tensor/transform.cc | 19 +++++--- .../relay/dyn/test_dynamic_op_level6.py | 4 +- tests/python/relay/test_any.py | 10 ++--- 7 files changed, 80 insertions(+), 33 deletions(-) diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 412c79330ca9..49ce21b5732e 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -47,6 +47,21 @@ using namespace tvm::te; */ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } +/*! + * \brief Test whether the given Array has every element as constant integer + * + * \param array the array to query + * + * \return true if every element in array is constant int or uint, false otherwise. + */ +inline bool IsConstIntArray(Array array) { + bool is_const_int = true; + for (auto const& elem : array) { + is_const_int &= elem->IsInstance(); + } + return is_const_int; +} + /*! * \brief Get the value of the given constant integer expression. An error * is logged if the given expression is not a constant integer. diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index f958048f13c3..71944071a7ce 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -614,7 +614,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index c2a4843dedd0..a04762f28feb 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -598,17 +598,42 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b * * \return A Tensor whose op member is the split operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, - const Array& strides, std::string slice_mode = "end", - std::string name = "T_strided_slice", std::string tag = kInjective) { +inline Tensor strided_slice(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + std::string slice_mode = "end", std::string name = "T_strided_slice", + std::string tag = kInjective) { size_t src_tensor_dim = static_cast(x->shape.size()); + // Quick path for dynamic shape strided slice. + // This is for ease of use to dynamice strided slice in topi. + bool is_static = IsConstIntArray(x->shape); + is_static &= IsConstIntArray(begin); + is_static &= IsConstIntArray(end); + is_static &= IsConstIntArray(strides); + + Array out_shape; + if (!is_static) { + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); + } + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides[i] + begin[i]); + } + return x(real_indices); + }, + name, tag); + } + // Setup the ranges. // NOTE: this code duplicates the shape inference logic relay.op // Consider to refactor in the future. std::vector stride_vec(src_tensor_dim, 1); for (size_t i = 0; i < strides.size(); ++i) { ICHECK(strides[i].defined()); - stride_vec[i] = strides[i]->value; + stride_vec[i] = GetConstInt(strides[i]); } const int64_t max_range = std::numeric_limits::max(); @@ -619,7 +644,7 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const // value=None begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); } else { - begin_vec.push_back(begin[i]->value); + begin_vec.push_back(GetConstInt(begin[i])); } } for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { @@ -633,20 +658,20 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, const if (!end[i].defined()) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); } else if (slice_mode == "size") { - if (end[i]->value < 0) { + int64_t end_val = GetConstInt(end[i]); + if (end_val < 0) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); } else { - end_vec.push_back(begin_vec[i] + end[i]->value); + end_vec.push_back(begin_vec[i] + end_val); } } else { - end_vec.push_back(end[i]->value); + end_vec.push_back(GetConstInt(end[i])); } } for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) { end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); } // Compute - Array out_shape; Array begin_expr; Array strides_expr; diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 465299a5bc8f..ac14f5aae779 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -479,27 +479,28 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): name="topk_gpu", tag="topk_gpu", ) - if k < 1: + if isinstance(k, int) and k < 1: if ret_type == "indices": return output[1] return output beg = [0] * ndim end = [] + strides = [1] * ndim for i in range(ndim): if i == axis: - end.append(k) + end.append(k if isinstance(k, int) else tvm.te.size_var("dim")) else: end.append(data.shape[i]) if ret_type == "both": values_out, indices_out = output - values_out = strided_slice(values_out, beg, end) - indices_out = strided_slice(indices_out, beg, end) + values_out = strided_slice(values_out, beg, end, strides) + indices_out = strided_slice(indices_out, beg, end, strides) output = [values_out, indices_out] elif ret_type == "values": - output = [strided_slice(output, beg, end)] + output = [strided_slice(output, beg, end, strides)] else: # ret_type == "indices" indices_out = output[1] - output = [strided_slice(indices_out, beg, end)] + output = [strided_slice(indices_out, beg, end, strides)] return output @@ -561,10 +562,11 @@ def topk_thrust(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int tag="topk_gpu", ) - if k > 0: + if not isinstance(k, int) or k > 0: beg = [0] * ndim - end = data.shape[:-1] + [k] - out = [strided_slice(o, beg, end) for o in out] + end = data.shape[:-1] + [k if isinstance(k, int) else tvm.te.size_var("dim")] + strides = [1] * ndim + out = [strided_slice(o, beg, end, strides) for o in out] if axis != ndim - 1: axes = swap(list(range(ndim)), axis) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a3a9280be59c..640943eac805 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2380,6 +2380,7 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); Array begin, end, strides; + Array begin_expr, end_expr, strides_expr; begin = param->begin.value(); end = param->end.value(); strides = param->strides.value(); @@ -2392,8 +2393,6 @@ Array StridedSliceCompute(const Attrs& attrs, const Array begin_expr; - Array strides_expr; for (size_t i = 0; i < src_tensor_dim; ++i) { int64_t begin_i = begin[i]->value; if (begin_i < 0) { @@ -2414,8 +2413,19 @@ Array StridedSliceCompute(const Attrs& attrs, const Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; + return Array{ + topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)}; } // Positional relay function to create StridedSlice operator used by frontend FFI. @@ -2731,8 +2741,7 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& << topi::GetConstInt(src_shape[axis]); } } - return Array{topi::strided_slice(inputs[0], GetIntArray(begin_idx), - GetIntArray(end_idx), GetIntArray(strides), "end")}; + return Array{topi::strided_slice(inputs[0], begin_idx, end_idx, strides, "end")}; } TVM_REGISTER_GLOBAL("relay.op._make.slice_like").set_body_typed(MakeSliceLike); diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py index aeed8db7c1b6..52abbe2a15b6 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level6.py +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -22,8 +22,8 @@ from tvm import relay import tvm.testing -# TODO(mbrookhart): Enable when we can get it working -# @tvm.testing.uses_gpu + +@tvm.testing.uses_gpu def test_dynamic_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index eec6aa21c69b..ee67e67b282f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -815,15 +815,11 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): else: ref_out = sorted[0:kval] - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*in_vals) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) - - # TODO(@zhiics) Fix topk cuda schedule for dynamic inputs - # check_result(in_vals, mod, ref_out) + check_result(in_vals, mod, ref_out) +# TODO(kevinthesun): enable this test when Thrust is available in ci. +# @tvm.testing.uses_gpu def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") verify_any_topk(any_dims(2), 2, (6, 3), "int32")