Skip to content

Commit

Permalink
Merge pull request #173 from sony/feature/20190624-stream-event-handler
Browse files Browse the repository at this point in the history
Add CUDA stream handler class in Python (as an unsupported feature so far)
  • Loading branch information
TakuyaNarihira authored Jul 11, 2019
2 parents 6f62be5 + c214e6d commit a16372b
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 2 deletions.
21 changes: 21 additions & 0 deletions include/nbla/cuda/init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

#include <nbla/cuda/defs.hpp>

#include <memory>
#include <string>
#include <vector>

namespace nbla {

using std::vector;
using std::string;
using std::shared_ptr;

/**
Initialize CUDA features.
Expand Down Expand Up @@ -53,5 +55,24 @@ NBLA_CUDA_API int cuda_get_device_count();
/** get available devices.
*/
NBLA_CUDA_API vector<string> cuda_get_devices();

/** cudaStream wrapper functions.

These functions pass a cudaStream_t around as an opaque shared_ptr<void> so
that callers do not need the CUDA headers. Every `stream` argument below is
expected to be a pointer obtained from cuda_create_stream.
*/
// Selects `device_id` through cuda_set_device, then creates a stream with the
// cudaStreamNonBlocking flag. The returned pointer's deleter calls
// cudaStreamDestroy when the last reference is released.
// NOTE(review): the meaning of the default -1 is decided by cuda_set_device,
// which is not visible here -- presumably "current/default device"; confirm.
NBLA_CUDA_API shared_ptr<void> cuda_create_stream(int device_id = -1);

// Returns the raw cudaStream_t handle stored in `stream`, cast to void*.
NBLA_CUDA_API void *cuda_stream_shared_to_void(shared_ptr<void> stream);
// Prints "flags: <value>" (from cudaStreamGetFlags) with printf.
NBLA_CUDA_API void print_stream_flag(shared_ptr<void> stream);
// Prints "priority: <value>" (from cudaStreamGetPriority) with printf.
NBLA_CUDA_API void print_stream_priority(shared_ptr<void> stream);
// Calls cudaStreamSynchronize on the wrapped stream.
NBLA_CUDA_API void cuda_stream_synchronize(shared_ptr<void> stream);
// Calls cudaStreamSynchronize on stream handle 0 (the null stream).
NBLA_CUDA_API void cuda_nullstream_synchronize();
// Calls cudaStreamDestroy on the wrapped stream.
// NOTE(review): the deleter installed by cuda_create_stream also calls
// cudaStreamDestroy on the same handle when the shared_ptr is released, so
// using this on such a pointer looks like a double destroy -- confirm.
NBLA_CUDA_API void cuda_stream_destroy(shared_ptr<void> stream);

/** cudaEvent wrapper functions.

A cudaEvent_t is passed around as an opaque shared_ptr<void>. Every `event`
argument below is expected to be a pointer obtained from cuda_create_event.
*/
// Selects `device_id` through cuda_set_device, then creates an event with the
// cudaEventDisableTiming flag. The returned pointer's deleter calls
// cudaEventDestroy when the last reference is released.
NBLA_CUDA_API shared_ptr<void> cuda_create_event(int device_id = -1);
// Calls cudaEventRecord on the event without naming a stream.
// NOTE(review): per the function name this targets the default stream.
NBLA_CUDA_API void cuda_default_stream_event(shared_ptr<void> event);
// Calls cudaStreamWaitEvent(stream, event, 0) on the wrapped handles.
NBLA_CUDA_API void cuda_stream_wait_event(shared_ptr<void> stream,
shared_ptr<void> event);
// Calls cudaEventSynchronize on the wrapped event.
NBLA_CUDA_API void cuda_event_synchronize(shared_ptr<void> event);
}
#endif
3 changes: 2 additions & 1 deletion python/src/nnabla_ext/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
array_classes,
device_synchronize,
get_device_count,
get_devices)
get_devices,
StreamEventHandler)
except:
print('Please install CUDA version {}.'.format(__cuda_version__))
print(' and CUDNN version {}.'.format(__cudnn_version__))
Expand Down
66 changes: 65 additions & 1 deletion python/src/nnabla_ext/cuda/init.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ from nnabla import add_available_context
import nnabla._init as cpu_init
from libcpp.vector cimport vector
from libcpp.string cimport string
from libcpp.memory cimport shared_ptr
from libc.stdint cimport uintptr_t
from libcpp cimport bool

cdef extern from "nbla/cuda/init.hpp" namespace "nbla":
void init_cuda() except+
Expand All @@ -27,7 +30,17 @@ cdef extern from "nbla/cuda/init.hpp" namespace "nbla":
void cuda_device_synchronize(const string & device) except +
int cuda_get_device_count() except +
vector[string] cuda_get_devices() except +

# cudaStream wrappers. Streams travel as opaque shared_ptr[void] handles.
# `nogil` is written after the exception clause: placing it before `except`
# is deprecated by newer Cython releases.
shared_ptr[void] cuda_create_stream(int device_id) except +
void* cuda_stream_shared_to_void(shared_ptr[void]) except +
void print_stream_flag(shared_ptr[void]) except +
void print_stream_priority(shared_ptr[void]) except +
void cuda_stream_synchronize(shared_ptr[void]) except + nogil
void cuda_nullstream_synchronize() except + nogil
void cuda_stream_destroy(shared_ptr[void]) except +
# cudaEvent wrappers. Events travel as opaque shared_ptr[void] handles.
shared_ptr[void] cuda_create_event(int device_id) except +
void cuda_default_stream_event(shared_ptr[void]) except +
void cuda_stream_wait_event(shared_ptr[void], shared_ptr[void]) except +
void cuda_event_synchronize(shared_ptr[void]) except + nogil

logger.info('Initializing CUDA extension...')
try:
Expand Down Expand Up @@ -110,3 +123,54 @@ def get_devices():
"""
return cuda_get_devices()
###############################################################################

cdef class StreamEventHandler:
    """Owns one CUDA stream and one CUDA event and exposes basic
    synchronization primitives on them.

    On construction a stream and an event are created for ``device_id`` and
    the event is recorded through ``cuda_default_stream_event``.

    Args:
        device_id (int): Device index forwarded to ``cuda_create_stream`` and
            ``cuda_create_event``. Defaults to -1.

    Attributes:
        value: Integer value of the raw ``cudaStream_t`` handle of the stream
            created most recently by ``stream_create``.
        device_id (int): The ``device_id`` given at construction.
    """
    # Opaque handles; the C++ deleters attached by cuda_create_stream /
    # cuda_create_event destroy the CUDA objects when these are released.
    cdef shared_ptr[void] stream
    cdef shared_ptr[void] event
    cdef public object value
    cdef public int device_id
    # True while no live stream is held. Declared with `cdef`: `cpdef` is not
    # a valid modifier for attributes.
    cdef bool is_stream_destroy

    def __cinit__(self, int device_id=-1):
        self.is_stream_destroy = True
        self.device_id = device_id

    def __init__(self, int device_id=-1):
        self.stream_create(device_id)
        self.event = cuda_create_event(device_id)
        self.add_default_stream_event()

    def stream_wait_event(self):
        """Make the held stream wait for the held event (no-op when the
        stream has been destroyed)."""
        if not self.is_stream_destroy:
            cuda_stream_wait_event(self.stream, self.event)

    def add_default_stream_event(self):
        """Record the held event via ``cuda_default_stream_event``."""
        cuda_default_stream_event(self.event)

    def event_synchronize(self):
        """Wait for the held event, releasing the GIL while waiting."""
        with nogil:
            cuda_event_synchronize(self.event)

    def stream_destroy(self):
        """Destroy the held stream. Calling it again is a no-op."""
        if self.is_stream_destroy:
            return
        # Dropping the shared_ptr runs the deleter installed by
        # cuda_create_stream, which calls cudaStreamDestroy exactly once.
        # Calling cuda_stream_destroy here as well would destroy the same
        # handle a second time when the pointer is later released.
        self.stream.reset()
        self.is_stream_destroy = True
        # NOTE: `value` still holds the integer of the destroyed handle.

    def stream_create(self, device_id):
        """Create a new stream for ``device_id``, destroying the previously
        held stream first, and publish its raw handle in ``value``."""
        cdef void* stream_vp

        if not self.is_stream_destroy:
            self.stream_destroy()

        self.stream = cuda_create_stream(device_id)

        stream_vp = cuda_stream_shared_to_void(self.stream)
        self.value = <uintptr_t>stream_vp

        self.is_stream_destroy = False

    def stream_synchronize(self):
        """Wait for the held stream, releasing the GIL while waiting (no-op
        when the stream has been destroyed)."""
        if not self.is_stream_destroy:
            with nogil:
                cuda_stream_synchronize(self.stream)

    def default_stream_synchronize(self):
        """Wait for the null stream, releasing the GIL while waiting."""
        with nogil:
            cuda_nullstream_synchronize()
92 changes: 92 additions & 0 deletions src/nbla/cuda/init.cpp.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,97 @@ vector<string> cuda_get_devices() {
}
return ret;
}

/** Create a non-blocking CUDA stream wrapped in an opaque shared pointer.

    The device is selected with cuda_set_device(device_id) before the stream
    is created with the cudaStreamNonBlocking flag.

    @param device_id Device index forwarded to cuda_set_device.
    @return Shared pointer to a heap-allocated cudaStream_t. Releasing the
            last reference destroys the stream and frees the handle storage.
*/
shared_ptr<void> cuda_create_stream(int device_id) {
  cuda_set_device(device_id);

  // Destroys the CUDA stream first, then frees the slot holding its handle.
  auto deleter = [](cudaStream_t *ptr) {
    NBLA_CUDA_CHECK(cudaStreamDestroy(*ptr));

    delete ptr;
  };

  // Keep the handle slot in a plain unique_ptr while the stream is being
  // created. If creation fails, only the memory is released; the destroying
  // deleter is attached only once a real stream exists, so cudaStreamDestroy
  // is never invoked on a handle that was not created.
  std::unique_ptr<cudaStream_t> stream(new cudaStream_t());

  NBLA_CUDA_CHECK(cudaStreamCreateWithFlags(stream.get(), cudaStreamNonBlocking));

  return shared_ptr<cudaStream_t>(stream.release(), deleter);
}

/** Return the raw cudaStream_t handle held by `stream` as a void pointer.

    `stream` is expected to wrap a cudaStream_t, as produced by
    cuda_create_stream.
*/
void *cuda_stream_shared_to_void(shared_ptr<void> stream) {
  // The shared pointer points at the handle; copy the handle itself out.
  const cudaStream_t handle = *static_cast<cudaStream_t *>(stream.get());
  return static_cast<void *>(handle);
}

/** Query the creation flags of `stream` and print them as "flags: <n>".

    `stream` is expected to wrap a cudaStream_t, as produced by
    cuda_create_stream.
*/
void print_stream_flag(shared_ptr<void> stream) {
  unsigned int flags;
  cudaStream_t *s = static_cast<cudaStream_t *>(stream.get());

  NBLA_CUDA_CHECK(cudaStreamGetFlags(*s, &flags));
  printf("flags: %u\n", flags);
}

/** Query the priority of `stream` and print it as "priority: <n>".

    `stream` is expected to wrap a cudaStream_t, as produced by
    cuda_create_stream.
*/
void print_stream_priority(shared_ptr<void> stream) {
  int p;
  cudaStream_t *s = static_cast<cudaStream_t *>(stream.get());

  NBLA_CUDA_CHECK(cudaStreamGetPriority(*s, &p));
  printf("priority: %d\n", p);
}

/** Synchronize with the null stream.

    Calls cudaStreamSynchronize with stream handle 0 and passes the result
    to NBLA_CUDA_CHECK.
*/
void cuda_nullstream_synchronize() {
// Handle 0 designates the null (default) stream.
NBLA_CUDA_CHECK(cudaStreamSynchronize(0));
}

/** Call cudaStreamSynchronize on the stream wrapped by `stream`.

    `stream` is expected to wrap a cudaStream_t, as produced by
    cuda_create_stream.
*/
void cuda_stream_synchronize(shared_ptr<void> stream) {
  cudaStream_t *s = static_cast<cudaStream_t *>(stream.get());

  NBLA_CUDA_CHECK(cudaStreamSynchronize(*s));
}

/** Call cudaStreamDestroy on the stream wrapped by `stream`.

    NOTE(review): a pointer returned by cuda_create_stream carries a deleter
    that also calls cudaStreamDestroy on the same handle when the last
    reference goes away. Calling this function on such a pointer therefore
    looks like it leads to a second destroy later -- confirm with callers.
*/
void cuda_stream_destroy(shared_ptr<void> stream) {
  cudaStream_t *s = static_cast<cudaStream_t *>(stream.get());

  NBLA_CUDA_CHECK(cudaStreamDestroy(*s));
}

std::shared_ptr<void> cuda_create_event(int device_id) {
cuda_set_device(device_id);

std::default_delete<cudaEvent_t> default_deleter;
auto deleter = [default_deleter](cudaEvent_t* ptr) {
NBLA_CUDA_CHECK(cudaEventDestroy(*ptr));

default_deleter(ptr);
};

auto event = shared_ptr<cudaEvent_t>(new cudaEvent_t(), deleter);

NBLA_CUDA_CHECK(cudaEventCreateWithFlags(event.get(), cudaEventDisableTiming));

return event;
}

/** Record the event wrapped by `event` with cudaEventRecord.

    No stream is passed to cudaEventRecord, so its default stream argument
    applies. `event` is expected to wrap a cudaEvent_t, as produced by
    cuda_create_event.
*/
void cuda_default_stream_event(shared_ptr<void> event) {
  cudaEvent_t *e = static_cast<cudaEvent_t *>(event.get());

  NBLA_CUDA_CHECK(cudaEventRecord(*e));
}
/** Call cudaStreamWaitEvent(stream, event, 0) on the wrapped handles.

    `stream` is expected to come from cuda_create_stream and `event` from
    cuda_create_event.
*/
void cuda_stream_wait_event(shared_ptr<void> stream, shared_ptr<void> event) {
  cudaEvent_t *e = static_cast<cudaEvent_t *>(event.get());
  cudaStream_t *s = static_cast<cudaStream_t *>(stream.get());

  NBLA_CUDA_CHECK(cudaStreamWaitEvent(*s, *e, 0));
}

/** Call cudaEventSynchronize on the event wrapped by `event`.

    `event` is expected to wrap a cudaEvent_t, as produced by
    cuda_create_event.
*/
void cuda_event_synchronize(shared_ptr<void> event) {
  cudaEvent_t *e = static_cast<cudaEvent_t *>(event.get());

  NBLA_CUDA_CHECK(cudaEventSynchronize(*e));
}

}

0 comments on commit a16372b

Please sign in to comment.