Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slice support for std::array shapes #363

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions include/matx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1521,9 +1521,9 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
*
*/
template <int N = RANK>
__MATX_INLINE__ auto Slice([[maybe_unused]] const typename Desc::shape_type (&firsts)[RANK],
[[maybe_unused]] const typename Desc::shape_type (&ends)[RANK],
[[maybe_unused]] const typename Desc::stride_type (&strides)[RANK]) const
__MATX_INLINE__ auto Slice([[maybe_unused]] const std::array<typename Desc::shape_type, RANK> &firsts,
[[maybe_unused]] const std::array<typename Desc::shape_type, RANK> &ends,
[[maybe_unused]] const std::array<typename Desc::stride_type, RANK> &strides) const
{
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");

Expand Down Expand Up @@ -1578,6 +1578,14 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
tensor_desc_t<decltype(n), decltype(s), N> new_desc{std::move(n), std::move(s)};
return tensor_t<T, N, Storage, decltype(new_desc)>{storage_, std::move(new_desc), data};
}

template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK],
const typename Desc::stride_type (&strides)[RANK]) const
{
return Slice<N>(detail::to_array(firsts), detail::to_array(ends), detail::to_array(strides));
}

/**
* Slice a tensor either within the same dimension or to a lower dimension
Expand All @@ -1604,17 +1612,25 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
*
*/
template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK]) const
__MATX_INLINE__ auto Slice(const std::array<typename Desc::shape_type, RANK> &firsts,
const std::array<typename Desc::shape_type, RANK> &ends) const
{
static_assert(N <= RANK && RANK > 0, "Must slice to a rank the same or less than current rank.");

MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

const std::array<typename Desc::stride_type, RANK> strides = {-1};

const typename Desc::stride_type strides[RANK] = {-1};
return Slice<N>(firsts, ends, strides);
}

template <int N = RANK>
__MATX_INLINE__ auto Slice(const typename Desc::shape_type (&firsts)[RANK],
const typename Desc::shape_type (&ends)[RANK]) const
{
return Slice<N>(detail::to_array(firsts), detail::to_array(ends));
}


/**
* Print a value
Expand Down
74 changes: 59 additions & 15 deletions include/matx/operators/slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ namespace matx

__MATX_INLINE__ std::string str() const { return "slice(" + op_.str() + ")"; }

__MATX_INLINE__ SliceOp(T op, const shape_type (&starts)[T::Rank()], const shape_type (&ends)[T::Rank()], const shape_type (&strides)[T::Rank()]) : op_(op) {
__MATX_INLINE__ SliceOp(T op, const std::array<shape_type, T::Rank()> &starts,
const std::array<shape_type, T::Rank()> &ends,
const std::array<shape_type, T::Rank()> &strides) : op_(op) {
int32_t d = 0;
for(int32_t i = 0; i < T::Rank(); i++) {
shape_type start = starts[i];
Expand Down Expand Up @@ -169,9 +171,9 @@ namespace matx
*/
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends,
const std::array<index_t, OpType::Rank()> &strides)
{
if constexpr (is_tensor_view_v<OpType>) {
return op.Slice(starts, ends, strides);
Expand All @@ -180,6 +182,18 @@ namespace matx
}
}

template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
{
return slice(op,
detail::to_array(starts),
detail::to_array(ends),
detail::to_array(strides));
}

/**
* @brief Operator to logically slice a tensor or operator.
*
Expand All @@ -194,14 +208,23 @@ namespace matx
* @return sliced operator
*/
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends)
{
std::array<index_t, OpType::Rank()> strides;
strides.fill(1);

return slice(op, starts, ends, strides);
}
template <typename OpType>
__MATX_INLINE__ auto slice( const OpType &op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()])
{
index_t strides[OpType::Rank()];
for(int i = 0; i < OpType::Rank(); i++)
strides[i] = 1;
return slice(op, starts, ends, strides);
return slice(op,
detail::to_array(starts),
detail::to_array(ends));
}

/**
Expand All @@ -223,9 +246,9 @@ namespace matx
*/
template <int N, typename OpType>
__MATX_INLINE__ auto slice( const OpType op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends,
const std::array<index_t, OpType::Rank()> &strides)
{
if constexpr (is_tensor_view_v<OpType>) {
return op.template Slice<N>(starts, ends, strides);
Expand All @@ -234,6 +257,18 @@ namespace matx
}
}

template <int N, typename OpType>
__MATX_INLINE__ auto slice( const OpType op,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()],
const index_t (&strides)[OpType::Rank()])
{
return slice<N,OpType>(op,
detail::to_array(starts),
detail::to_array(ends),
detail::to_array(strides));
}

/**
* @brief Operator to logically slice a tensor or operator.
*
Expand All @@ -250,14 +285,23 @@ namespace matx
* @param ends the last element (exclusive) of each dimension of the input operator. matxDrop Dim removes that dimension. matxEnd deontes all remaining elements in that dimension.
* @return sliced operator
*/
template <int N, typename OpType>
__MATX_INLINE__ auto slice (const OpType opIn,
const std::array<index_t, OpType::Rank()> &starts,
const std::array<index_t, OpType::Rank()> &ends)
{
std::array<index_t, OpType::Rank()> strides;
strides.fill(1);
return slice<N,OpType>(opIn, starts, ends, strides);
}

template <int N, typename OpType>
__MATX_INLINE__ auto slice (const OpType opIn,
const index_t (&starts)[OpType::Rank()],
const index_t (&ends)[OpType::Rank()])
{
typename OpType::shape_type strides[OpType::Rank()];
for (int i = 0; i < OpType::Rank(); i++)
strides[i] = 1;
return slice<N, OpType>(opIn, starts, ends, strides);
return slice<N,OpType>(opIn,
detail::to_array(starts),
detail::to_array(ends));
}
} // end namespace matx