Skip to content

Commit

Permalink
Dedicated kernels for in-place dpt.divide and dpt.floor_divide (#1431)
Browse files Browse the repository at this point in the history

* Implements dedicated kernels for in-place division

Includes floor division and true division

* Adds tests for inplace division behavior

* Adds a `static_assert` check to TrueDivideInplaceTypeMapFactory

Checks that the result type is either the same as the third template parameter, or none

Adds a comment to TrueDivideInplaceOutputType
  • Loading branch information
ndgrigorian authored Oct 11, 2023
1 parent 39e0700 commit e885838
Show file tree
Hide file tree
Showing 6 changed files with 657 additions and 23 deletions.
2 changes: 2 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@
ti._divide_result_type,
ti._divide,
_divide_docstring_,
binary_inplace_fn=ti._divide_inplace,
acceptance_fn=_acceptance_fn_divide,
)

Expand Down Expand Up @@ -720,6 +721,7 @@
ti._floor_divide_result_type,
ti._floor_divide,
_floor_divide_docstring_,
binary_inplace_fn=ti._floor_divide_inplace,
)

# B11: ==== GREATER (x1, x2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ struct FloorDivideFunctor

resT operator()(const argT1 &in1, const argT2 &in2) const
{
if constexpr (std::is_same_v<argT1, bool> &&
std::is_same_v<argT2, bool>) {
return (in2) ? static_cast<resT>(in1) : resT(0);
}
else if constexpr (std::is_integral_v<argT1> ||
std::is_integral_v<argT2>) {
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
if (in2 == argT2(0)) {
return resT(0);
}
Expand All @@ -87,16 +82,7 @@ struct FloorDivideFunctor
operator()(const sycl::vec<argT1, vec_sz> &in1,
const sycl::vec<argT2, vec_sz> &in2) const
{
if constexpr (std::is_same_v<argT1, bool> &&
std::is_same_v<argT2, bool>) {
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
res[i] = (in2[i]) ? static_cast<resT>(in1[i]) : resT(0);
}
return res;
}
else if constexpr (std::is_integral_v<resT>) {
if constexpr (std::is_integral_v<resT>) {
sycl::vec<resT, vec_sz> res;
#pragma unroll
for (int i = 0; i < vec_sz; ++i) {
Expand Down Expand Up @@ -165,7 +151,6 @@ template <typename T1, typename T2> struct FloorDivideOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, std::int8_t>,
td_ns::BinaryTypeMapResultEntry<T1,
std::uint8_t,
T2,
Expand Down Expand Up @@ -315,6 +300,183 @@ struct FloorDivideStridedFactory
}
};

/*!
 * @brief Element-wise functor for in-place floor division: the first operand
 * is overwritten with the floor of the quotient `in1 / in2`.
 *
 * Integral result types: a zero divisor stores 0 in the destination. For
 * signed types the truncated quotient is reduced by one whenever the
 * remainder is non-zero and its sign differs from the sign of the divisor,
 * which turns truncation toward zero into rounding toward negative infinity.
 *
 * Non-integral result types: the quotient is formed with `/=` and then passed
 * through `std::floor`.
 *
 * @tparam argT Type of the divisor (second operand, read-only).
 * @tparam resT Type of the dividend, which also receives the result.
 */
template <typename argT, typename resT> struct FloorDivideInplaceFunctor
{
    using supports_sg_loadstore = std::true_type;
    using supports_vec = std::true_type;

    /*! @brief Scalar overload: replaces `in1` with floor(`in1` / `in2`). */
    void operator()(resT &in1, const argT &in2) const
    {
        if constexpr (!std::is_integral_v<resT>) {
            in1 /= in2;
            // A quotient that compares equal to zero is stored exactly as
            // the division produced it; everything else is floored.
            if (!(in1 == resT(0))) {
                in1 = std::floor(in1);
            }
        }
        else {
            if (in2 == argT(0)) {
                // Integer division by zero is defined here to yield 0.
                in1 = 0;
            }
            else if constexpr (std::is_signed_v<resT>) {
                // Take the remainder before the dividend is overwritten.
                const auto rem = in1 % in2;
                in1 /= in2;
                // Truncation rounded up iff the division was inexact and
                // remainder and divisor have opposite signs.
                const bool round_down =
                    (rem != 0) && ((rem < 0) != (in2 < 0));
                in1 -= round_down;
            }
            else {
                in1 /= in2;
            }
        }
    }

    /*!
     * @brief `sycl::vec` overload: applies the same operation to each of the
     * `vec_sz` elements of `in1`, using the matching element of `in2`.
     */
    template <int vec_sz>
    void operator()(sycl::vec<resT, vec_sz> &in1,
                    const sycl::vec<argT, vec_sz> &in2) const
    {
        if constexpr (!std::is_integral_v<resT>) {
            // Divide the whole vector at once, then floor element by element.
            in1 /= in2;
#pragma unroll
            for (int idx = 0; idx < vec_sz; ++idx) {
                // Elements whose divisor is zero keep the raw quotient.
                if (in2[idx] != argT(0)) {
                    in1[idx] = std::floor(in1[idx]);
                }
            }
        }
        else {
#pragma unroll
            for (int idx = 0; idx < vec_sz; ++idx) {
                if (in2[idx] == argT(0)) {
                    // Integer division by zero is defined here to yield 0.
                    in1[idx] = 0;
                    continue;
                }
                if constexpr (std::is_signed_v<resT>) {
                    // Take the remainder before the element is overwritten.
                    const auto rem = in1[idx] % in2[idx];
                    in1[idx] /= in2[idx];
                    const bool round_down =
                        (rem != 0) && ((rem < 0) != (in2[idx] < 0));
                    in1[idx] -= round_down;
                }
                else {
                    in1[idx] /= in2[idx];
                }
            }
        }
    }
};

/*!
 * @brief Alias of elementwise_common::BinaryInplaceContigFunctor with
 * FloorDivideInplaceFunctor<argT, resT> plugged in as the element-wise
 * operation.
 *
 * @tparam argT   Forwarded to the wrapper and used as the functor's divisor
 *                type.
 * @tparam resT   Forwarded to the wrapper and used as the functor's
 *                dividend/result type.
 * @tparam vec_sz Forwarded unchanged to the wrapper; defaults to 4.
 *                (Presumably the sycl::vec width used per work-item --
 *                confirm against BinaryInplaceContigFunctor.)
 * @tparam n_vecs Forwarded unchanged to the wrapper; defaults to 2.
 *                (Presumably the number of vectors handled per work-item --
 *                confirm against BinaryInplaceContigFunctor.)
 */
template <typename argT,
typename resT,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2>
using FloorDivideInplaceContigFunctor =
elementwise_common::BinaryInplaceContigFunctor<
argT,
resT,
FloorDivideInplaceFunctor<argT, resT>,
vec_sz,
n_vecs>;

/*!
 * @brief Alias of elementwise_common::BinaryInplaceStridedFunctor with
 * FloorDivideInplaceFunctor<argT, resT> plugged in as the element-wise
 * operation.
 *
 * @tparam argT     Forwarded to the wrapper and used as the functor's divisor
 *                  type.
 * @tparam resT     Forwarded to the wrapper and used as the functor's
 *                  dividend/result type.
 * @tparam IndexerT Forwarded unchanged to the wrapper. (Presumably the type
 *                  that maps a linear index to element offsets -- confirm
 *                  against BinaryInplaceStridedFunctor.)
 */
template <typename argT, typename resT, typename IndexerT>
using FloorDivideInplaceStridedFunctor =
elementwise_common::BinaryInplaceStridedFunctor<
argT,
resT,
IndexerT,
FloorDivideInplaceFunctor<argT, resT>>;

// Declared but never defined in this section: the template is only passed as
// a template argument to elementwise_common::binary_inplace_contig_impl by
// floor_divide_inplace_contig_impl. (Presumably it serves as the SYCL kernel
// name type -- confirm against binary_inplace_contig_impl.)
template <typename argT,
typename resT,
unsigned int vec_sz,
unsigned int n_vecs>
class floor_divide_inplace_contig_kernel;

/*!
 * @brief Entry point for in-place floor division on contiguous data.
 *
 * Thin wrapper: every argument is forwarded, in order, to
 * elementwise_common::binary_inplace_contig_impl, instantiated with
 * FloorDivideInplaceContigFunctor and floor_divide_inplace_contig_kernel.
 *
 * @tparam argTy Element type of the divisor data.
 * @tparam resTy Element type of the data that is divided and overwritten.
 *
 * @param exec_q     Queue handed to the generic implementation.
 * @param nelems     Element count handed to the generic implementation.
 * @param arg_p      Untyped pointer to the divisor data.
 * @param arg_offset Offset associated with `arg_p` (units are defined by the
 *                   generic implementation).
 * @param res_p      Untyped pointer to the data modified in place.
 * @param res_offset Offset associated with `res_p`.
 * @param depends    Events handed to the generic implementation; empty by
 *                   default.
 * @return The event returned by the generic implementation.
 */
template <typename argTy, typename resTy>
sycl::event
floor_divide_inplace_contig_impl(sycl::queue &exec_q,
                                 size_t nelems,
                                 const char *arg_p,
                                 py::ssize_t arg_offset,
                                 char *res_p,
                                 py::ssize_t res_offset,
                                 const std::vector<sycl::event> &depends = {})
{
    namespace impl_ns = elementwise_common;

    return impl_ns::binary_inplace_contig_impl<
        argTy,
        resTy,
        FloorDivideInplaceContigFunctor,
        floor_divide_inplace_contig_kernel>(
        exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends);
}

template <typename fnT, typename T1, typename T2>
struct FloorDivideInplaceContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename FloorDivideOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = floor_divide_inplace_contig_impl<T1, T2>;
return fn;
}
}
};

// Declared but never defined in this section: the template is only passed as
// a template argument to elementwise_common::binary_inplace_strided_impl by
// floor_divide_inplace_strided_impl. (Presumably it serves as the SYCL kernel
// name type -- confirm against binary_inplace_strided_impl.)
// NOTE(review): the parameter names here are ordered (resT, argT), whereas
// floor_divide_inplace_contig_kernel uses (argT, resT) and the strided impl
// supplies its types as <argTy, resTy>. The names are unused in a forward
// declaration, so this has no effect on behavior, but the ordering looks
// unintentional -- confirm and align.
template <typename resT, typename argT, typename IndexerT>
class floor_divide_inplace_strided_kernel;

/*!
 * @brief Entry point for in-place floor division on strided data.
 *
 * Thin wrapper: every argument is forwarded, in order, to
 * elementwise_common::binary_inplace_strided_impl, instantiated with
 * FloorDivideInplaceStridedFunctor and floor_divide_inplace_strided_kernel.
 *
 * @tparam argTy Element type of the divisor data.
 * @tparam resTy Element type of the data that is divided and overwritten.
 *
 * @param exec_q             Queue handed to the generic implementation.
 * @param nelems             Element count handed to the generic
 *                           implementation.
 * @param nd                 Dimension count handed to the generic
 *                           implementation.
 * @param shape_and_strides  Pointer to shape/stride data; its layout is
 *                           defined by the generic implementation.
 * @param arg_p              Untyped pointer to the divisor data.
 * @param arg_offset         Offset associated with `arg_p`.
 * @param res_p              Untyped pointer to the data modified in place.
 * @param res_offset         Offset associated with `res_p`.
 * @param depends            Events handed to the generic implementation.
 * @param additional_depends Further events handed to the generic
 *                           implementation.
 * @return The event returned by the generic implementation.
 */
template <typename argTy, typename resTy>
sycl::event floor_divide_inplace_strided_impl(
    sycl::queue &exec_q,
    size_t nelems,
    int nd,
    const py::ssize_t *shape_and_strides,
    const char *arg_p,
    py::ssize_t arg_offset,
    char *res_p,
    py::ssize_t res_offset,
    const std::vector<sycl::event> &depends,
    const std::vector<sycl::event> &additional_depends)
{
    namespace impl_ns = elementwise_common;

    return impl_ns::binary_inplace_strided_impl<
        argTy,
        resTy,
        FloorDivideInplaceStridedFunctor,
        floor_divide_inplace_strided_kernel>(exec_q,
                                             nelems,
                                             nd,
                                             shape_and_strides,
                                             arg_p,
                                             arg_offset,
                                             res_p,
                                             res_offset,
                                             depends,
                                             additional_depends);
}

template <typename fnT, typename T1, typename T2>
struct FloorDivideInplaceStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename FloorDivideOutputType<T1, T2>::value_type,
void>)
{
fnT fn = nullptr;
return fn;
}
else {
fnT fn = floor_divide_inplace_strided_impl<T1, T2>;
return fn;
}
}
};

} // namespace floor_divide
} // namespace kernels
} // namespace tensor
Expand Down
Loading

0 comments on commit e885838

Please sign in to comment.