Skip to content

Commit

Permalink
Merge pull request #165 from sony/feature/20190611-fix-mathtype-cudnn…
Browse files Browse the repository at this point in the history
…-conv

[fix] Set mathType properly in CUDNN Convolution
  • Loading branch information
AkioHayakawa-sony authored Jun 25, 2019
2 parents a14c14a + 9453611 commit 12c73fc
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 62 deletions.
26 changes: 18 additions & 8 deletions include/nbla/cuda/cudnn/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ struct NBLA_CUDA_API CudnnConvDesc {

std::ostream &operator<<(std::ostream &os, const CudnnConvDesc &desc);

/** CUDNN Convolution descriptor wrapper.
*/
struct CudnnConvolutionDescriptor {
cudnnConvolutionDescriptor_t desc;
CudnnConvolutionDescriptor();
~CudnnConvolutionDescriptor();
};

/**
CUDNN Pooling descriptor wrapper.
*/
Expand Down Expand Up @@ -302,14 +310,16 @@ class CudnnSoftmax {
/** cuDNN Convolution resource cache.
*/
struct NBLA_CUDA_API CudnnConvResource {
int device; ///< Device ID.
cudnnTensorDescriptor_t x_desc; ///< Input desc.
cudnnTensorDescriptor_t y_desc; ///< Output desc.
cudnnTensorDescriptor_t b_desc; ///< Bias desc.
cudnnTensorDescriptor_t b_desc_deconv; ///< Bias desc for deconvolution.
cudnnFilterDescriptor_t w_desc; ///< Weight desc.
cudnnConvolutionDescriptor_t conv_desc; ///< Conv desc.
cudnnConvolutionFwdAlgo_t fwd_algo; ///< Best forward algorithm found.
int device; ///< Device ID.
cudnnTensorDescriptor_t x_desc; ///< Input desc.
cudnnTensorDescriptor_t y_desc; ///< Output desc.
cudnnTensorDescriptor_t b_desc; ///< Bias desc.
cudnnTensorDescriptor_t b_desc_deconv; ///< Bias desc for deconvolution.
cudnnFilterDescriptor_t w_desc; ///< Weight desc.
CudnnConvolutionDescriptor conv_desc; ///< Conv desc.
CudnnConvolutionDescriptor conv_dgrad_desc; ///< Conv backward data desc.
CudnnConvolutionDescriptor conv_wgrad_desc; ///< Conv backward filter desc.
cudnnConvolutionFwdAlgo_t fwd_algo; ///< Best forward algorithm found.
cudnnConvolutionBwdFilterAlgo_t
bwd_filter_algo; ///< Best Backward filter algorithm found.
cudnnConvolutionBwdDataAlgo_t
Expand Down
106 changes: 73 additions & 33 deletions src/nbla/cuda/cudnn/cudnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,12 @@ inline void cudnn_set_convolution_nd_descriptor_force_2dim(
stride.resize(2, 1);
dilation.resize(2, 1);
}
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(
cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
#endif
NBLA_CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(
conv_desc, ndim, pad.data(), stride.data(), dilation.data(), mode,
dtype));
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(
cudnnSetConvolutionMathType(conv_desc, CUDNN_TENSOR_OP_MATH));
NBLA_CUDNN_CHECK(cudnnSetConvolutionGroupCount(conv_desc, group));
#endif
}
Expand Down Expand Up @@ -140,7 +138,7 @@ CudnnConvResource::CudnnConvResource(const CudnnConvDesc &desc) {
NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&b_desc));
NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&b_desc_deconv));
NBLA_CUDNN_CHECK(cudnnCreateFilterDescriptor(&w_desc));
NBLA_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&conv_desc));

// Set input desc
auto dims =
get_conv_dims(desc.n, desc.c, desc.group, desc.sample, channel_last);
Expand Down Expand Up @@ -204,8 +202,14 @@ CudnnConvResource::CudnnConvResource(const CudnnConvDesc &desc) {
cudnnDataType_t compute_type =
desc.dtype == CUDNN_DATA_HALF ? CUDNN_DATA_FLOAT : desc.dtype;
cudnn_set_convolution_nd_descriptor_force_2dim(
conv_desc, desc.ndim, desc.pad, desc.stride, desc.dilation, desc.group,
desc.mode, compute_type);
conv_desc.desc, desc.ndim, desc.pad, desc.stride, desc.dilation,
desc.group, desc.mode, compute_type);
cudnn_set_convolution_nd_descriptor_force_2dim(
conv_dgrad_desc.desc, desc.ndim, desc.pad, desc.stride, desc.dilation,
desc.group, desc.mode, compute_type);
cudnn_set_convolution_nd_descriptor_force_2dim(
conv_wgrad_desc.desc, desc.ndim, desc.pad, desc.stride, desc.dilation,
desc.group, desc.mode, compute_type);

// Find best algorithm
find_best_algorithms();
Expand All @@ -218,7 +222,6 @@ CudnnConvResource::~CudnnConvResource() {
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(b_desc));
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(b_desc_deconv));
NBLA_CUDNN_CHECK(cudnnDestroyFilterDescriptor(w_desc));
NBLA_CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(conv_desc));
}

inline bool check_workspace_limit(int workspace_limit, size_t used_memory) {
Expand All @@ -244,9 +247,9 @@ void CudnnConvResource::find_forward_algorithm(int workspace_limit,
std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf_results{
new cudnnConvolutionFwdAlgoPerf_t[max_results]};

NBLA_CUDNN_CHECK(find_algorithm(cudnn_handle, this->x_desc, this->w_desc,
this->conv_desc, this->y_desc, max_results,
&num_results, perf_results.get()));
NBLA_CUDNN_CHECK(find_algorithm(
cudnn_handle, this->x_desc, this->w_desc, this->conv_desc.desc,
this->y_desc, max_results, &num_results, perf_results.get()));
#if 0
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
Expand All @@ -260,13 +263,21 @@ void CudnnConvResource::find_forward_algorithm(int workspace_limit,
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
if (CUDNN_STATUS_SUCCESS == perf_result.status) {
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc.desc,
perf_result.mathType));
#endif
NBLA_CUDNN_CHECK(get_workspace(cudnn_handle, this->x_desc, this->w_desc,
this->conv_desc, this->y_desc,
this->conv_desc.desc, this->y_desc,
perf_result.algo, &workspace_size));
if (check_workspace_limit(workspace_limit, workspace_size)) {
if (check_determinism(deterministic, perf_result.determinism)) {
this->fwd_algo = perf_result.algo;
this->fwd_workspace_size = workspace_size;
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc.desc,
perf_result.mathType));
#endif
return;
}
}
Expand Down Expand Up @@ -294,9 +305,9 @@ void CudnnConvResource::find_backward_data_algorithm(int workspace_limit,
std::unique_ptr<cudnnConvolutionBwdDataAlgoPerf_t[]> perf_results{
new cudnnConvolutionBwdDataAlgoPerf_t[max_results]};

NBLA_CUDNN_CHECK(find_algorithm(cudnn_handle, this->w_desc, this->y_desc,
this->conv_desc, this->x_desc, max_results,
&num_results, perf_results.get()));
NBLA_CUDNN_CHECK(find_algorithm(
cudnn_handle, this->w_desc, this->y_desc, this->conv_dgrad_desc.desc,
this->x_desc, max_results, &num_results, perf_results.get()));
#if 0
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
Expand All @@ -310,13 +321,21 @@ void CudnnConvResource::find_backward_data_algorithm(int workspace_limit,
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
if (CUDNN_STATUS_SUCCESS == perf_result.status) {
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(conv_dgrad_desc.desc,
perf_result.mathType));
#endif
NBLA_CUDNN_CHECK(get_workspace(cudnn_handle, this->w_desc, this->y_desc,
this->conv_desc, this->x_desc,
this->conv_dgrad_desc.desc, this->x_desc,
perf_result.algo, &workspace_size));
if (check_workspace_limit(workspace_limit, workspace_size)) {
if (check_determinism(deterministic, perf_result.determinism)) {
this->bwd_data_algo = perf_result.algo;
this->bwd_data_workspace_size = workspace_size;
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(conv_dgrad_desc.desc,
perf_result.mathType));
#endif
return;
}
}
Expand Down Expand Up @@ -344,9 +363,9 @@ void CudnnConvResource::find_backward_filter_algorithm(int workspace_limit,
std::unique_ptr<cudnnConvolutionBwdFilterAlgoPerf_t[]> perf_results{
new cudnnConvolutionBwdFilterAlgoPerf_t[max_results]};

NBLA_CUDNN_CHECK(find_algorithm(cudnn_handle, this->x_desc, this->y_desc,
this->conv_desc, this->w_desc, max_results,
&num_results, perf_results.get()));
NBLA_CUDNN_CHECK(find_algorithm(
cudnn_handle, this->x_desc, this->y_desc, this->conv_wgrad_desc.desc,
this->w_desc, max_results, &num_results, perf_results.get()));
#if 0
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
Expand All @@ -360,13 +379,21 @@ void CudnnConvResource::find_backward_filter_algorithm(int workspace_limit,
for (int i = 0; i < num_results; i++) {
auto &perf_result = perf_results[i];
if (CUDNN_STATUS_SUCCESS == perf_result.status) {
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(conv_wgrad_desc.desc,
perf_result.mathType));
#endif
NBLA_CUDNN_CHECK(get_workspace(cudnn_handle, this->x_desc, this->y_desc,
this->conv_desc, this->w_desc,
this->conv_wgrad_desc.desc, this->w_desc,
perf_result.algo, &workspace_size));
if (check_workspace_limit(workspace_limit, workspace_size)) {
if (check_determinism(deterministic, perf_result.determinism)) {
this->bwd_filter_algo = perf_result.algo;
this->bwd_filter_workspace_size = workspace_size;
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnSetConvolutionMathType(conv_wgrad_desc.desc,
perf_result.mathType));
#endif
return;
}
}
Expand All @@ -392,11 +419,11 @@ void CudnnConvResource::get_forward_algorithm(int workspace_limit) {
preference = CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT;

NBLA_CUDNN_CHECK(get_algorithm(cudnn_handle, this->x_desc, this->w_desc,
this->conv_desc, this->y_desc, preference,
this->conv_desc.desc, this->y_desc, preference,
workspace_limit, &this->fwd_algo));
if (workspace_limit != 0) {
NBLA_CUDNN_CHECK(get_workspace(cudnn_handle, this->x_desc, this->w_desc,
this->conv_desc, this->y_desc,
this->conv_desc.desc, this->y_desc,
this->fwd_algo, &this->fwd_workspace_size));
} else {
this->fwd_workspace_size = 0;
Expand All @@ -415,13 +442,13 @@ void CudnnConvResource::get_backward_data_algorithm(int workspace_limit) {
else if (workspace_limit > 0)
preference = CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT;

NBLA_CUDNN_CHECK(get_algorithm(cudnn_handle, this->w_desc, this->y_desc,
this->conv_desc, this->x_desc, preference,
workspace_limit, &this->bwd_data_algo));
NBLA_CUDNN_CHECK(get_algorithm(
cudnn_handle, this->w_desc, this->y_desc, this->conv_dgrad_desc.desc,
this->x_desc, preference, workspace_limit, &this->bwd_data_algo));
if (workspace_limit != 0) {
NBLA_CUDNN_CHECK(get_workspace(
cudnn_handle, this->w_desc, this->y_desc, this->conv_desc, this->x_desc,
this->bwd_data_algo, &this->bwd_data_workspace_size));
cudnn_handle, this->w_desc, this->y_desc, this->conv_dgrad_desc.desc,
this->x_desc, this->bwd_data_algo, &this->bwd_data_workspace_size));
} else {
this->bwd_data_workspace_size = 0;
}
Expand All @@ -439,13 +466,13 @@ void CudnnConvResource::get_backward_filter_algorithm(int workspace_limit) {
else if (workspace_limit > 0)
preference = CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT;

NBLA_CUDNN_CHECK(get_algorithm(cudnn_handle, this->x_desc, this->y_desc,
this->conv_desc, this->w_desc, preference,
workspace_limit, &this->bwd_filter_algo));
NBLA_CUDNN_CHECK(get_algorithm(
cudnn_handle, this->x_desc, this->y_desc, this->conv_wgrad_desc.desc,
this->w_desc, preference, workspace_limit, &this->bwd_filter_algo));
if (workspace_limit != 0) {
NBLA_CUDNN_CHECK(get_workspace(
cudnn_handle, this->x_desc, this->y_desc, this->conv_desc, this->w_desc,
this->bwd_filter_algo, &this->bwd_filter_workspace_size));
cudnn_handle, this->x_desc, this->y_desc, this->conv_wgrad_desc.desc,
this->w_desc, this->bwd_filter_algo, &this->bwd_filter_workspace_size));
} else {
this->bwd_filter_workspace_size = 0;
}
Expand Down Expand Up @@ -487,6 +514,16 @@ size_t CudnnConvResource::workspace_size() const {
std::max(bwd_filter_workspace_size, bwd_data_workspace_size));
}

////////////////////////////////////////
// Cudnn Convolution Wrapper
////////////////////////////////////////
CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() {
NBLA_CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&desc));
}
CudnnConvolutionDescriptor::~CudnnConvolutionDescriptor() {
NBLA_CUDNN_CHECK(cudnnDestroyConvolutionDescriptor(desc));
}

////////////////////////////////////////
// Cudnn activation descriptor Wrapper
////////////////////////////////////////
Expand All @@ -498,14 +535,17 @@ CudnnActivationDescriptor::~CudnnActivationDescriptor() {
}

////////////////////////////////////////
// Cudnn Pooling Wrapper
// Cudnn Tensor Descriptor Wrapper
////////////////////////////////////////
CudnnTensorDescriptor::CudnnTensorDescriptor() {
NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
}
CudnnTensorDescriptor::~CudnnTensorDescriptor() {
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}
////////////////////////////////////////
// Cudnn Pooling Wrapper
////////////////////////////////////////
CudnnPoolingDescriptor::CudnnPoolingDescriptor() {
NBLA_CUDNN_CHECK(cudnnCreatePoolingDescriptor(&desc));
}
Expand Down
20 changes: 10 additions & 10 deletions src/nbla/cuda/cudnn/function/generic/convolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ void ConvolutionCudaCudnn<T>::forward_impl(const Variables &inputs,
}
#if CUDNN_VERSION >= 7000
NBLA_CUDNN_CHECK(cudnnConvolutionForward(
cudnn_handle_, &alpha, rsc_->x_desc, x, rsc_->w_desc, w, rsc_->conv_desc,
rsc_->fwd_algo, workspace, rsc_->fwd_workspace_size, &beta, rsc_->y_desc,
y));
cudnn_handle_, &alpha, rsc_->x_desc, x, rsc_->w_desc, w,
rsc_->conv_desc.desc, rsc_->fwd_algo, workspace, rsc_->fwd_workspace_size,
&beta, rsc_->y_desc, y));
if (inputs.size() == 3) {
NBLA_CUDNN_CHECK(cudnnAddTensor(cudnn_handle_, &alpha, rsc_->b_desc, b,
&alpha, rsc_->y_desc, y));
Expand All @@ -109,7 +109,7 @@ void ConvolutionCudaCudnn<T>::forward_impl(const Variables &inputs,
for (int g = 0; g < this->group_; ++g) {
NBLA_CUDNN_CHECK(cudnnConvolutionForward(
cudnn_handle_, &alpha, rsc_->x_desc, x + x_offset_ * g, rsc_->w_desc,
w + w_offset_ * g, rsc_->conv_desc, rsc_->fwd_algo, workspace,
w + w_offset_ * g, rsc_->conv_desc.desc, rsc_->fwd_algo, workspace,
rsc_->fwd_workspace_size, &beta, rsc_->y_desc, y + y_offset_ * g));
if (inputs.size() == 3) {
// TODO: Bias addition should be outside of the loop. In that case,
Expand Down Expand Up @@ -162,14 +162,14 @@ void ConvolutionCudaCudnn<T>::backward_impl(const Variables &inputs,
auto beta = get_cudnn_scalar_arg<T>(accum[0] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardData(
cudnn_handle_, &alpha, rsc_->w_desc, w, rsc_->y_desc, dy,
rsc_->conv_desc, rsc_->bwd_data_algo, workspace,
rsc_->conv_dgrad_desc.desc, rsc_->bwd_data_algo, workspace,
rsc_->bwd_data_workspace_size, &beta, rsc_->x_desc, dx));
}
if (propagate_down[1]) {
auto beta = get_cudnn_scalar_arg<T>(accum[1] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
cudnn_handle_, &alpha, rsc_->x_desc, x, rsc_->y_desc, dy,
rsc_->conv_desc, rsc_->bwd_filter_algo, workspace,
rsc_->conv_wgrad_desc.desc, rsc_->bwd_filter_algo, workspace,
rsc_->bwd_filter_workspace_size, &beta, rsc_->w_desc, dw));
}
if (inputs.size() == 3 && propagate_down[2]) {
Expand All @@ -183,16 +183,16 @@ void ConvolutionCudaCudnn<T>::backward_impl(const Variables &inputs,
auto beta = get_cudnn_scalar_arg<T>(accum[0] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardData(
cudnn_handle_, &alpha, rsc_->w_desc, w + w_offset_ * g, rsc_->y_desc,
dy + y_offset_ * g, rsc_->conv_desc, rsc_->bwd_data_algo, workspace,
rsc_->bwd_data_workspace_size, &beta, rsc_->x_desc,
dy + y_offset_ * g, rsc_->conv_dgrad_desc.desc, rsc_->bwd_data_algo,
workspace, rsc_->bwd_data_workspace_size, &beta, rsc_->x_desc,
dx + x_offset_ * g));
}
if (propagate_down[1]) {
auto beta = get_cudnn_scalar_arg<T>(accum[1] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
cudnn_handle_, &alpha, rsc_->x_desc, x + x_offset_ * g, rsc_->y_desc,
dy + y_offset_ * g, rsc_->conv_desc, rsc_->bwd_filter_algo, workspace,
rsc_->bwd_filter_workspace_size, &beta, rsc_->w_desc,
dy + y_offset_ * g, rsc_->conv_wgrad_desc.desc, rsc_->bwd_filter_algo,
workspace, rsc_->bwd_filter_workspace_size, &beta, rsc_->w_desc,
dw + w_offset_ * g));
}
if (inputs.size() == 3 && propagate_down[2]) {
Expand Down
Loading

0 comments on commit 12c73fc

Please sign in to comment.