Skip to content

Commit

Permalink
Functional dpnp.arange for int and double types
Browse files Browse the repository at this point in the history
  • Loading branch information
chudur-budur committed Oct 31, 2023
1 parent 7baa79e commit ed8cb70
Show file tree
Hide file tree
Showing 5 changed files with 429 additions and 83 deletions.
18 changes: 16 additions & 2 deletions numba_dpex/core/runtime/kernels/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,35 @@
#include "dpctl_capi.h"
#include "dpctl_sycl_interface.h"

#pragma once

#ifdef __cplusplus
extern "C"
{
#endif

// Dispatch vector initializer functions.
//
// Both entry points take no arguments and return nothing.
// NOTE(review): presumably each one fills the per-type function table that the
// matching populate_* call below indexes by type id, and therefore has to run
// before that call is used -- confirm against the definitions.
void NUMBA_DPEX_SYCL_KERNEL_init_sequence_step_dispatch_vectors();
void NUMBA_DPEX_SYCL_KERNEL_init_affine_sequence_dispatch_vectors();

// Call linear sequences dispatch functions.
//
// Fills the array described by `dst` with a sequence built from a start value
// and a step, by dispatching to a type-specific kernel.
//
// Parameters (as used by the definition in sequences.cpp):
//   start, dt        - type-erased pointers to the first value and the step;
//                      they are forwarded to the kernel selected by
//                      `dst_typeid`, so the pointees are expected to have the
//                      element type that id stands for.
//   dst              - destination array struct; its `nitems` and `data`
//                      members are read.
//   ndim             - must be 1; any other value raises std::logic_error.
//   is_c_contiguous  - NOTE(review): the definition carries an "array must be
//                      c-contiguous" error, so a zero value is presumably
//                      rejected -- confirm.
//   dst_typeid       - index into the sequence-step dispatch vector
//                      (the definition notes 7 = int64_t, 10 = float,
//                      11 = double).
//   exec_q           - queue reference; reinterpreted as a sycl::queue* to
//                      submit the kernel.
//
// Returns 0 when `dst` has no items or when the submitted event reports
// completion after being waited on, and 1 otherwise.
// NOTE(review): the definition throws C++ exceptions through this extern "C"
// boundary -- confirm callers are prepared for that.
uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence(
void *start,
void *dt,
arystruct_t *dst,
int ndim,
u_int8_t is_c_contiguous,
int dst_typeid,
const DPCTLSyclQueueRef exec_q);

// Call linear affine sequences dispatch functions.
//
// NOTE(review): the definition is not visible next to this declaration, so the
// parameter notes below are inferred from the names only -- confirm each one.
//   start, end        - presumably type-erased pointers to the first and last
//                       values of the sequence.
//   dst               - destination array struct to populate.
//   include_endpoint  - presumably non-zero when `end` itself is part of the
//                       generated sequence.
//   ndim              - number of dimensions of `dst`.
//   is_c_contiguous   - presumably non-zero when `dst` is C-contiguous.
//   exec_q            - queue reference used to run the kernel.
// Unlike the step-sequence entry point, this one takes no `dst_typeid`.
uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_affine_sequence(
void *start,
void *end,
arystruct_t *dst,
u_int8_t include_endpoint,
int ndim,
u_int8_t is_c_contiguous,
const DPCTLSyclQueueRef exec_q);
// Disabled trailing parameter, kept for reference (an event-dependency vector):
// const DPCTLEventVectorRef depends = std::vector<DPCTLSyclEventRef>());

#ifdef __cplusplus
}
Expand Down
39 changes: 19 additions & 20 deletions numba_dpex/core/runtime/kernels/sequences.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,20 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence(
arystruct_t *dst,
int ndim,
u_int8_t is_c_contiguous,
int dst_typeid,
const DPCTLSyclQueueRef exec_q)
// const DPCTLEventVectorRef depends = std::vector<DPCTLSyclEventRef>())
{
std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:"
<< " start = " << *(reinterpret_cast<int *>(start)) << std::endl;
<< " start = "
<< ndpx::runtime::kernel::types::caste_using_typeid(start,
dst_typeid)
<< std::endl;

std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:"
<< " dt = " << *(reinterpret_cast<int *>(dt)) << std::endl;
<< " dt = "
<< ndpx::runtime::kernel::types::caste_using_typeid(dt,
dst_typeid)
<< std::endl;

if (ndim != 1) {
throw std::logic_error(
Expand All @@ -60,37 +66,30 @@ extern "C" uint NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence(
"populate_arystruct_linseq(): array must be c-contiguous.");
}

/**
auto array_types = td_ns::usm_ndarray_types();
int dst_typenum = dst.get_typenum();
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
*/

size_t len = static_cast<size_t>(dst->nitems); // dst.get_shape(0);
if (len == 0) {
// nothing to do
// return std::make_pair(sycl::event{}, sycl::event{});
size_t len = static_cast<size_t>(dst->nitems);
if (len == 0)
return 0;
}
std::cout << "NUMBA_DPEX_SYCL_KERNEL_populate_arystruct_sequence:"
<< " len = " << len << std::endl;

char *dst_data = reinterpret_cast<char *>(dst->data);

const int dst_typeid = 7; // 7 = int64_t, 10 = float, 11 = double
// int64_t *_start = reinterpret_cast<int64_t *>(start);
// int64_t *_dt = reinterpret_cast<int64_t *>(dt);
// int dst_typeid = 7; // 7 = int64_t, 10 = float, 11 = double
auto fn = sequence_step_dispatch_vector[dst_typeid];

sycl::queue *queue = reinterpret_cast<sycl::queue *>(exec_q);
std::vector<sycl::event> depends = std::vector<sycl::event>();
sycl::event linspace_step_event =
fn(*queue, len, start, dt, dst_data, depends);

/*return std::make_pair(keep_args_alive(exec_q, {dst},
{linspace_step_event}), linspace_step_event);*/
linspace_step_event.wait_and_throw();

return 1;
if (linspace_step_event
.get_info<sycl::info::event::command_execution_status>() ==
sycl::info::event_command_status::complete)
return 0;
else
return 1;
}

// uint ndpx::runtime::kernel::tensor::populate_arystruct_affine_sequence(
Expand Down
26 changes: 16 additions & 10 deletions numba_dpex/core/runtime/kernels/sequences.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <iostream>
#include <exception>
#include <complex>
#include <typeinfo>

#include <Python.h>
#include <numpy/npy_common.h>
Expand Down Expand Up @@ -111,7 +112,17 @@ sycl::event sequence_step_kernel(sycl::queue exec_q,
char *array_data,
const std::vector<sycl::event> &depends)
{
std::cout << "sequqnce_step_kernel<"
<< ndpx::runtime::kernel::types::demangle<T>()
<< ">(): nelems = " << nelems << ", start_v = " << start_v
<< ", step_v = " << step_v << std::endl;

ndpx::runtime::kernel::types::validate_type_for_device<T>(exec_q);

std::cout << "sequqnce_step_kernel<"
<< ndpx::runtime::kernel::types::demangle<T>()
<< ">(): validate_type_for_device<T>(exec_q) = done" << std::endl;

sycl::event seq_step_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(depends);
cgh.parallel_for<ndpx_sequence_step_kernel<T>>(
Expand Down Expand Up @@ -170,6 +181,11 @@ sycl::event sequence_step(sycl::queue &exec_q,
std::cerr << e.what() << std::endl;
}

std::cout << "sequqnce_step()<"
<< ndpx::runtime::kernel::types::demangle<T>()
<< ">: nelems = " << nelems << ", *start_v = " << (*start_v)
<< ", *step_v = " << (*step_v) << std::endl;

auto sequence_step_event = sequence_step_kernel<T>(
exec_q, nelems, *start_v, *step_v, array_data, depends);

Expand Down Expand Up @@ -232,16 +248,6 @@ typedef sycl::event (*affine_sequence_ptr_t)(sycl::queue &,
bool, // include_endpoint
char *, // dst_data_ptr
const std::vector<sycl::event> &);

// uint populate_arystruct_affine_sequence(void *start,
// void *end,
// arystruct_t *dst,
// int include_endpoint,
// int ndim,
// int is_c_contiguous,
// const DPCTLSyclQueueRef exec_q,
// const DPCTLEventVectorRef depends);

} // namespace tensor
} // namespace kernel
} // namespace runtime
Expand Down
64 changes: 64 additions & 0 deletions numba_dpex/core/runtime/kernels/types.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#ifndef __TYPES_HPP__
#define __TYPES_HPP__

#include <cstdlib>
#include <complex>
#include <exception>
#include <utility>
#include <string>
#include <cxxabi.h>
#include <CL/sycl.hpp>

namespace ndpx
Expand Down Expand Up @@ -43,6 +46,58 @@ enum class typenum_t : int

constexpr int num_types = 14; // number of elements in typenum_t

/**
 * @brief Returns a human-readable name for the type ``T``.
 *
 * The implementation-defined name reported by ``typeid(T).name()`` is passed
 * through ``abi::__cxa_demangle``. When demangling fails (a null pointer is
 * returned), the raw name from ``typeid`` is returned unchanged.
 *
 * @tparam T Type whose name is requested.
 * @return The demangled type name, or the mangled one as a fallback.
 */
template <typename T> std::string demangle()
{
    const char *mangled = typeid(T).name();

    // With a null output buffer, __cxa_demangle allocates the result with
    // malloc(); ownership passes to us and it must be released with free().
    int status = 0;
    char *demangled = abi::__cxa_demangle(mangled, nullptr, nullptr, &status);
    if (demangled == nullptr) {
        // Nothing was allocated, so there is nothing to free here.
        return mangled;
    }

    std::string res(demangled);
    std::free(demangled);
    return res;
}

/**
 * @brief Formats the value behind a type-erased pointer as a string.
 *
 * ``value`` is reinterpreted as a pointer to the element type selected by
 * ``_typeid``, dereferenced, and converted with ``std::to_string``.
 *
 * The function is declared ``inline`` because it is defined in a header:
 * without it, every translation unit that includes this header would emit its
 * own external definition and the program would fail to link.
 *
 * @param value   Pointer to the value to format; it is dereferenced
 *                unconditionally, so it must point to an object of the type
 *                selected by ``_typeid``.
 * @param _typeid Numeric type id: 0 = bool, 1 = int8, 2 = uint8, 3 = int16,
 *                4 = uint16, 5 = int32, 6 = uint32, 7 = int64, 8 = uint64,
 *                9 = sycl::half, 10 = float, 11 = double.
 * @return The decimal string representation of the pointed-to value.
 * @throws std::runtime_error if ``_typeid`` is not one of the ids above.
 */
inline std::string caste_using_typeid(void *value, int _typeid)
{
    switch (_typeid) {
    case 0:
        return std::to_string(*(reinterpret_cast<bool *>(value)));
    case 1:
        return std::to_string(*(reinterpret_cast<int8_t *>(value)));
    case 2:
        return std::to_string(*(reinterpret_cast<u_int8_t *>(value)));
    case 3:
        return std::to_string(*(reinterpret_cast<int16_t *>(value)));
    case 4:
        return std::to_string(*(reinterpret_cast<u_int16_t *>(value)));
    case 5:
        return std::to_string(*(reinterpret_cast<int32_t *>(value)));
    case 6:
        return std::to_string(*(reinterpret_cast<u_int32_t *>(value)));
    case 7:
        return std::to_string(*(reinterpret_cast<int64_t *>(value)));
    case 8:
        return std::to_string(*(reinterpret_cast<u_int64_t *>(value)));
    case 9:
        return std::to_string(*(reinterpret_cast<sycl::half *>(value)));
    case 10:
        return std::to_string(*(reinterpret_cast<float *>(value)));
    case 11:
        return std::to_string(*(reinterpret_cast<double *>(value)));
    default:
        // Ids above 11 (e.g. complex types) have no to_string mapping here.
        throw std::runtime_error(std::to_string(_typeid) +
                                 " could't be mapped to valid data type.");
    }
}

template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
{
if constexpr (std::is_same<dstTy, srcTy>::value) {
Expand Down Expand Up @@ -75,20 +130,29 @@ template <typename dstTy, typename srcTy> dstTy convert_impl(const srcTy &v)
template <typename T> void validate_type_for_device(const sycl::device &d)
{
if constexpr (std::is_same_v<T, double>) {
std::cout
<< "ndpx::runtime::kernel::types::validate_type_for_device(): here0"
<< std::endl;
if (!d.has(sycl::aspect::fp64)) {
throw std::runtime_error("Device " +
d.get_info<sycl::info::device::name>() +
" does not support type 'float64'");
}
}
else if constexpr (std::is_same_v<T, std::complex<double>>) {
std::cout
<< "ndpx::runtime::kernel::types::validate_type_for_device(): here1"
<< std::endl;
if (!d.has(sycl::aspect::fp64)) {
throw std::runtime_error("Device " +
d.get_info<sycl::info::device::name>() +
" does not support type 'complex128'");
}
}
else if constexpr (std::is_same_v<T, sycl::half>) {
std::cout
<< "ndpx::runtime::kernel::types::validate_type_for_device(): here2"
<< std::endl;
if (!d.has(sycl::aspect::fp16)) {
throw std::runtime_error("Device " +
d.get_info<sycl::info::device::name>() +
Expand Down
Loading

0 comments on commit ed8cb70

Please sign in to comment.