diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 8e2abee837..259443f8e3 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -590,6 +590,7 @@ ti._divide_result_type, ti._divide, _divide_docstring_, + binary_inplace_fn=ti._divide_inplace, acceptance_fn=_acceptance_fn_divide, ) @@ -720,6 +721,7 @@ ti._floor_divide_result_type, ti._floor_divide, _floor_divide_docstring_, + binary_inplace_fn=ti._floor_divide_inplace, ) # B11: ==== GREATER (x1, x2) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp index ad75924070..025d7e8bc4 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp @@ -57,12 +57,7 @@ struct FloorDivideFunctor resT operator()(const argT1 &in1, const argT2 &in2) const { - if constexpr (std::is_same_v && - std::is_same_v) { - return (in2) ? static_cast(in1) : resT(0); - } - else if constexpr (std::is_integral_v || - std::is_integral_v) { + if constexpr (std::is_integral_v || std::is_integral_v) { if (in2 == argT2(0)) { return resT(0); } @@ -87,16 +82,7 @@ struct FloorDivideFunctor operator()(const sycl::vec &in1, const sycl::vec &in2) const { - if constexpr (std::is_same_v && - std::is_same_v) { - sycl::vec res; -#pragma unroll - for (int i = 0; i < vec_sz; ++i) { - res[i] = (in2[i]) ? static_cast(in1[i]) : resT(0); - } - return res; - } - else if constexpr (std::is_integral_v) { + if constexpr (std::is_integral_v) { sycl::vec res; #pragma unroll for (int i = 0; i < vec_sz; ++i) { @@ -165,7 +151,6 @@ template struct FloorDivideOutputType { using value_type = typename std::disjunction< // disjunction is C++17 // feature, supported by DPC++ - td_ns::BinaryTypeMapResultEntry, td_ns::BinaryTypeMapResultEntry struct FloorDivideInplaceFunctor +{ + using supports_sg_loadstore = std::true_type; + using supports_vec = std::true_type; + + void operator()(resT &in1, const argT &in2) const + { + if constexpr (std::is_integral_v) { + if (in2 == argT(0)) { + in1 = 0; + return; + } + if constexpr (std::is_signed_v) { + auto tmp = in1; + in1 /= in2; + auto mod = tmp % in2; + auto corr = (mod != 0 && l_xor(mod < 0, in2 < 0)); + in1 -= corr; + } + else { + in1 /= in2; + } + } + else { + in1 /= in2; + if (in1 == resT(0)) { + return; + } + in1 = std::floor(in1); + } + } + + template + void operator()(sycl::vec &in1, + const sycl::vec &in2) const + { + if constexpr (std::is_integral_v) { +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + if (in2[i] == argT(0)) { + in1[i] = 0; + } + else { + if constexpr (std::is_signed_v) { + auto tmp = in1[i]; + in1[i] /= in2[i]; + auto mod = tmp % in2[i]; + auto corr = (mod != 0 && l_xor(mod < 0, in2[i] < 0)); + in1[i] -= corr; + } + else { + in1[i] /= in2[i]; + } + } + } + } + else { + in1 /= in2; +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + if (in2[i] != argT(0)) { + in1[i] = std::floor(in1[i]); + } + } + } + } + +private: + bool l_xor(bool b1, bool b2) const + { + return (b1 != b2); + } +}; + +template +using FloorDivideInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + FloorDivideInplaceFunctor, + vec_sz, + n_vecs>; + +template +using FloorDivideInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + FloorDivideInplaceFunctor>; + +template +class floor_divide_inplace_contig_kernel; + +template +sycl::event +floor_divide_inplace_contig_impl(sycl::queue &exec_q, + size_t nelems, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, FloorDivideInplaceContigFunctor, + floor_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, + res_p, res_offset, depends); +} + +template +struct FloorDivideInplaceContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename FloorDivideOutputType::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = floor_divide_inplace_contig_impl; + return fn; + } + } +}; + +template +class floor_divide_inplace_strided_kernel; + +template +sycl::event floor_divide_inplace_strided_impl( + sycl::queue &exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, FloorDivideInplaceStridedFunctor, + floor_divide_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct FloorDivideInplaceStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename FloorDivideOutputType::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = floor_divide_inplace_strided_impl; + return fn; + } + } +}; + } // namespace floor_divide } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp index 9f488e6598..138f7a3f91 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/true_divide.hpp @@ -370,6 +370,233 @@ struct TrueDivideContigRowContigMatrixBroadcastFactory } }; +template struct TrueDivideInplaceFunctor +{ + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::negation< + std::disjunction, tu_ns::is_complex>>; + + void operator()(resT &res, const argT &in) + { + res /= in; + } + + template + void operator()(sycl::vec &res, + const sycl::vec &in) + { + res /= in; + } +}; + +// cannot use the out of place table, as it permits real lhs and complex rhs +// T1 corresponds to the type of the rhs, while T2 corresponds to the lhs +// the type of the result must be the same as T2 +template struct TrueDivideInplaceOutputType +{ + using value_type = typename std::disjunction< // disjunction is C++17 + // feature, supported by DPC++ + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + std::complex>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +struct TrueDivideInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of divide(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + using rT = typename TrueDivideInplaceOutputType::value_type; + static_assert(std::is_same_v || std::is_same_v); + return td_ns::GetTypeid{}.get(); + } +}; + +template +using TrueDivideInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + TrueDivideInplaceFunctor, + vec_sz, + n_vecs>; + +template +using TrueDivideInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + TrueDivideInplaceFunctor>; + +template +class true_divide_inplace_contig_kernel; + +template +sycl::event +true_divide_inplace_contig_impl(sycl::queue &exec_q, + size_t nelems, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends = {}) +{ + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, TrueDivideInplaceContigFunctor, + true_divide_inplace_contig_kernel>(exec_q, nelems, arg_p, arg_offset, + res_p, res_offset, depends); +} + +template +struct TrueDivideInplaceContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = true_divide_inplace_contig_impl; + return fn; + } + } +}; + +template +class true_divide_inplace_strided_kernel; + +template +sycl::event true_divide_inplace_strided_impl( + sycl::queue &exec_q, + size_t nelems, + int nd, + const py::ssize_t *shape_and_strides, + const char *arg_p, + py::ssize_t arg_offset, + char *res_p, + py::ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, TrueDivideInplaceStridedFunctor, + true_divide_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct TrueDivideInplaceStridedFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = true_divide_inplace_strided_impl; + return fn; + } + } +}; + +template +class true_divide_inplace_row_matrix_broadcast_sg_krn; + +template +using TrueDivideInplaceRowMatrixBroadcastingFunctor = + elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor< + argT, + resT, + TrueDivideInplaceFunctor>; + +template +sycl::event true_divide_inplace_row_matrix_broadcast_impl( + sycl::queue &exec_q, + std::vector &host_tasks, + size_t n0, + size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + py::ssize_t vec_offset, + char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + py::ssize_t mat_offset, + const std::vector &depends = {}) +{ + return elementwise_common::binary_inplace_row_matrix_broadcast_impl< + argT, resT, TrueDivideInplaceRowMatrixBroadcastingFunctor, + true_divide_inplace_row_matrix_broadcast_sg_krn>( + exec_q, host_tasks, n0, n1, vec_p, vec_offset, mat_p, mat_offset, + depends); +} + +template +struct TrueDivideInplaceRowMatrixBroadcastFactory +{ + fnT get() + { + using resT = typename TrueDivideInplaceOutputType::value_type; + if constexpr (!std::is_same_v) { + fnT fn = nullptr; + return fn; + } + else { + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = true_divide_inplace_row_matrix_broadcast_impl; + return fn; + } + } + } +}; + } // namespace true_divide } // namespace kernels } // namespace tensor diff --git a/dpctl/tensor/libtensor/source/elementwise_functions.cpp b/dpctl/tensor/libtensor/source/elementwise_functions.cpp index cca0ac7c0a..3cca479a3f 100644 --- a/dpctl/tensor/libtensor/source/elementwise_functions.cpp +++ b/dpctl/tensor/libtensor/source/elementwise_functions.cpp @@ -933,6 +933,8 @@ namespace true_divide_fn_ns = dpctl::tensor::kernels::true_divide; static binary_contig_impl_fn_ptr_t true_divide_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; static int true_divide_output_id_table[td_ns::num_types][td_ns::num_types]; +static int true_divide_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; static binary_strided_impl_fn_ptr_t true_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; @@ -947,6 +949,16 @@ static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t true_divide_contig_row_contig_matrix_broadcast_dispatch_table [td_ns::num_types][td_ns::num_types]; +static binary_inplace_contig_impl_fn_ptr_t + true_divide_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + true_divide_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + true_divide_inplace_row_matrix_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + void populate_true_divide_dispatch_tables(void) { using namespace td_ns; @@ -990,6 +1002,33 @@ void populate_true_divide_dispatch_tables(void) dtb5; dtb5.populate_dispatch_table( true_divide_contig_row_contig_matrix_broadcast_dispatch_table); + + // which input types are supported, and what is the type of the result + using fn_ns::TrueDivideInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(true_divide_inplace_output_id_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::TrueDivideInplaceStridedFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(true_divide_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::TrueDivideInplaceContigFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(true_divide_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::TrueDivideInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(true_divide_inplace_row_matrix_dispatch_table); }; } // namespace impl @@ -1151,6 +1190,13 @@ static int floor_divide_output_id_table[td_ns::num_types][td_ns::num_types]; static binary_strided_impl_fn_ptr_t floor_divide_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_contig_impl_fn_ptr_t + floor_divide_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + floor_divide_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + void populate_floor_divide_dispatch_tables(void) { using namespace td_ns; @@ -1174,6 +1220,20 @@ void populate_floor_divide_dispatch_tables(void) num_types> dtb3; dtb3.populate_dispatch_table(floor_divide_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::FloorDivideInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(floor_divide_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::FloorDivideInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(floor_divide_inplace_contig_dispatch_table); }; } // namespace impl @@ -3379,6 +3439,33 @@ void init_elementwise_functions(py::module_ m) py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); m.def("_divide_result_type", divide_result_type_pyapi, ""); + + using impl::true_divide_inplace_contig_dispatch_table; + using impl::true_divide_inplace_output_id_table; + using impl::true_divide_inplace_row_matrix_dispatch_table; + using impl::true_divide_inplace_strided_dispatch_table; + + auto divide_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, + true_divide_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + true_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + true_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + true_divide_inplace_row_matrix_dispatch_table); + }; + m.def("_divide_inplace", divide_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); } // B09: ==== EQUAL (x1, x2) @@ -3531,6 +3618,31 @@ void init_elementwise_functions(py::module_ m) py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); m.def("_floor_divide_result_type", floor_divide_result_type_pyapi, ""); + + using impl::floor_divide_inplace_contig_dispatch_table; + using impl::floor_divide_inplace_strided_dispatch_table; + + auto floor_divide_inplace_pyapi = + [&](const dpctl::tensor::usm_ndarray &src, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, floor_divide_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + floor_divide_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + floor_divide_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_floor_divide_inplace", floor_divide_inplace_pyapi, "", + py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); } // B11: ==== GREATER (x1, x2) diff --git a/dpctl/tests/elementwise/test_divide.py b/dpctl/tests/elementwise/test_divide.py index 41aac736d7..a54060792c 100644 --- a/dpctl/tests/elementwise/test_divide.py +++ b/dpctl/tests/elementwise/test_divide.py @@ -21,9 +21,16 @@ import dpctl import dpctl.tensor as dpt +from dpctl.tensor._type_utils import _can_cast from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported -from .utils import _all_dtypes, _compare_dtypes, _usm_types +from .utils import ( + _all_dtypes, + _compare_dtypes, + _complex_fp_dtypes, + _real_fp_dtypes, + _usm_types, +) @pytest.mark.parametrize("op1_dtype", _all_dtypes) @@ -187,3 +194,65 @@ def __sycl_usm_array_interface__(self): c = Canary() with pytest.raises(ValueError): dpt.divide(a, c) + + +@pytest.mark.parametrize("dtype", _real_fp_dtypes + _complex_fp_dtypes) +def test_divide_inplace_python_scalar(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q) + dt_kind = X.dtype.kind + if dt_kind == "f": + X /= float(1) + elif dt_kind == "c": + X /= complex(1) + + +@pytest.mark.parametrize("op1_dtype", _all_dtypes) +@pytest.mark.parametrize("op2_dtype", _all_dtypes) +def test_divide_inplace_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) + + dev = q.sycl_device + _fp16 = dev.has_aspect_fp16 + _fp64 = dev.has_aspect_fp64 + # out array only valid if it is inexact + if ( + _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64) + and dpt.dtype(op1_dtype).kind in "fc" + ): + ar1 /= ar2 + assert dpt.all(ar1 == 1) + + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] + ar3 /= ar4 + assert dpt.all(ar3 == 1) + else: + with pytest.raises(TypeError): + ar1 /= ar2 + dpt.divide(ar1, ar2, out=ar1) + + # out is second arg + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) + if ( + _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64) + and dpt.dtype(op2_dtype).kind in "fc" + ): + dpt.divide(ar1, ar2, out=ar2) + assert dpt.all(ar2 == 1) + + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] + dpt.divide(ar3, ar4, out=ar4) + dpt.all(ar4 == 1) + else: + with pytest.raises(TypeError): + dpt.divide(ar1, ar2, out=ar2) diff --git a/dpctl/tests/elementwise/test_floor_divide.py b/dpctl/tests/elementwise/test_floor_divide.py index c8ba5e80f1..b57c006cdf 100644 --- a/dpctl/tests/elementwise/test_floor_divide.py +++ b/dpctl/tests/elementwise/test_floor_divide.py @@ -21,13 +21,19 @@ import dpctl import dpctl.tensor as dpt +from dpctl.tensor._type_utils import _can_cast from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported -from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types +from .utils import ( + _compare_dtypes, + _integral_dtypes, + _no_complex_dtypes, + _usm_types, +) -@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes) -@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes) +@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:]) +@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:]) def test_floor_divide_dtype_matrix(op1_dtype, op2_dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(op1_dtype, q) @@ -133,7 +139,7 @@ def test_floor_divide_broadcasting(): assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all() -@pytest.mark.parametrize("arr_dt", _no_complex_dtypes) +@pytest.mark.parametrize("arr_dt", _no_complex_dtypes[1:]) def test_floor_divide_python_scalar(arr_dt): q = get_queue_or_skip() skip_if_dtype_not_supported(arr_dt, q) @@ -204,7 +210,7 @@ def test_floor_divide_gh_1247(): ) -@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9]) +@pytest.mark.parametrize("dtype", _integral_dtypes) def test_floor_divide_integer_zero(dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -255,3 +261,59 @@ def test_floor_divide_special_cases(): res = dpt.floor_divide(x, y) res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y)) np.testing.assert_array_equal(dpt.asnumpy(res), res_np) + + +@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:]) +def test_divide_inplace_python_scalar(dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(dtype, q) + X = dpt.zeros((10, 10), dtype=dtype, sycl_queue=q) + dt_kind = X.dtype.kind + if dt_kind in "ui": + X //= int(1) + elif dt_kind == "f": + X //= float(1) + + +@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:]) +@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:]) +def test_floor_divide_inplace_dtype_matrix(op1_dtype, op2_dtype): + q = get_queue_or_skip() + skip_if_dtype_not_supported(op1_dtype, q) + skip_if_dtype_not_supported(op2_dtype, q) + + sz = 127 + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) + + dev = q.sycl_device + _fp16 = dev.has_aspect_fp16 + _fp64 = dev.has_aspect_fp64 + # out array only valid if it is inexact + if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64): + ar1 //= ar2 + assert dpt.all(ar1 == 1) + + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] + ar3 //= ar4 + assert dpt.all(ar3 == 1) + else: + with pytest.raises(TypeError): + ar1 //= ar2 + dpt.floor_divide(ar1, ar2, out=ar1) + + # out is second arg + ar1 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q) + ar2 = dpt.ones_like(ar1, dtype=op2_dtype, sycl_queue=q) + if _can_cast(ar1.dtype, ar2.dtype, _fp16, _fp64): + dpt.floor_divide(ar1, ar2, out=ar2) + assert dpt.all(ar2 == 1) + + ar3 = dpt.ones(sz, dtype=op1_dtype, sycl_queue=q)[::-1] + ar4 = dpt.ones(2 * sz, dtype=op2_dtype, sycl_queue=q)[::2] + dpt.floor_divide(ar3, ar4, out=ar4) + dpt.all(ar4 == 1) + else: + with pytest.raises(TypeError): + dpt.floor_divide(ar1, ar2, out=ar2)