diff --git a/build-tools/code_generator/function_types_cudnn.yaml b/build-tools/code_generator/function_types_cudnn.yaml
index af8aaa9ef..85cee5d68 100644
--- a/build-tools/code_generator/function_types_cudnn.yaml
+++ b/build-tools/code_generator/function_types_cudnn.yaml
@@ -44,7 +44,9 @@ ReLU:
 # float: [float]
 Softmax:
   float: [float]
-  half: [Half]
+LogSoftmax:
+  float: [float]
+# half: [Half]
 # ELU:
 # float: [float]
 # SELU:
diff --git a/include/nbla/cuda/cudnn/cudnn.hpp b/include/nbla/cuda/cudnn/cudnn.hpp
index 9ca365575..358c2af0e 100644
--- a/include/nbla/cuda/cudnn/cudnn.hpp
+++ b/include/nbla/cuda/cudnn/cudnn.hpp
@@ -277,6 +277,28 @@ class CudnnPooling {
                 const void *beta, void *dx) const;
 };
 
+/**
+   CUDNN softmax function wrapper
+ */
+class CudnnSoftmax {
+  CudnnTensorDescriptor input_desc_;
+  CudnnTensorDescriptor output_desc_;
+  cudnnSoftmaxAlgorithm_t algo_;
+  int device_;
+
+public:
+  typedef shared_ptr<CudnnSoftmax> Ptr;
+  CudnnSoftmax(const Shape_t &inshape, int axis, cudnnSoftmaxAlgorithm_t algo,
+               cudnnDataType_t dtype, int device);
+  static Ptr create(const Shape_t &inshape, int axis,
+                    cudnnSoftmaxAlgorithm_t algo, cudnnDataType_t dtype,
+                    int device);
+  void forward(const void *alpha, const void *x, const void *beta,
+               void *y) const;
+  void backward(const void *alpha, const void *y, const void *dy,
+                const void *beta, void *dx) const;
+};
+
 /** cuDNN Convolution resource cache.
  */
 struct NBLA_CUDA_API CudnnConvResource {
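The CudnnSoftmax wrapper declared above follows the CudnnPooling pattern: construct it once from the input shape, axis, algorithm and data type, then drive it with cuDNN's alpha/beta scaling arguments. A minimal usage sketch, not part of the patch; the shape, axis, device id and the device buffers d_x/d_y are made-up assumptions for illustration:

    #include <nbla/cuda/cudnn/cudnn.hpp>

    using namespace nbla;

    void run_softmax_example(const float *d_x, float *d_y) {
      // Softmax over axis 1 of an (8, 10) float input on device 0.
      Shape_t inshape = {8, 10};
      CudnnSoftmax::Ptr sm = CudnnSoftmax::create(
          inshape, /*axis=*/1, CUDNN_SOFTMAX_ACCURATE, CUDNN_DATA_FLOAT,
          /*device=*/0);
      // cuDNN scaling convention: y = alpha * softmax(x) + beta * y.
      float alpha = 1.0f, beta = 0.0f;
      sm->forward(&alpha, d_x, &beta, d_y);
    }
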
diff --git a/include/nbla/cuda/cudnn/function/log_softmax.hpp b/include/nbla/cuda/cudnn/function/log_softmax.hpp
new file mode 100644
index 000000000..778f726d0
--- /dev/null
+++ b/include/nbla/cuda/cudnn/function/log_softmax.hpp
@@ -0,0 +1,51 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __NBLA_CUDA_CUDNN_FUNCTION_LOG_SOFTMAX_HPP__
+#define __NBLA_CUDA_CUDNN_FUNCTION_LOG_SOFTMAX_HPP__
+
+#include <nbla/cuda/common.hpp>
+#include <nbla/cuda/cuda.hpp>
+#include <nbla/cuda/cudnn/cudnn.hpp>
+#include <nbla/function/log_softmax.hpp>
+
+namespace nbla {
+
+/** @copydoc LogSoftmax
+
+@note The default algorithm is set as ACCURATE. TODO: Set an algorithm by
+  context.
+*/
+template <typename T> class LogSoftmaxCudaCudnn : public LogSoftmax<T> {
+public:
+  typedef typename CudaType<T>::type Tw;
+
+  explicit LogSoftmaxCudaCudnn(const Context &ctx, int axis)
+      : LogSoftmax<T>(ctx, axis), device_(std::stoi(ctx.device_id)) {}
+  virtual string name() { return "LogSoftmaxCudaCudnn"; }
+  virtual vector<string> allowed_array_classes() {
+    return SingletonManager::get<Cuda>()->array_classes();
+  }
+
+protected:
+  int device_;
+  CudnnSoftmax::Ptr cudnn_softmax_;
+  virtual void setup_impl(const Variables &inputs, const Variables &outputs);
+  virtual void forward_impl(const Variables &inputs, const Variables &outputs);
+  virtual void backward_impl(const Variables &inputs, const Variables &outputs,
+                             const vector<bool> &propagate_down,
+                             const vector<bool> &accum);
+};
+}
+#endif
diff --git a/include/nbla/cuda/cudnn/function/softmax.hpp b/include/nbla/cuda/cudnn/function/softmax.hpp
index bc6453a19..67b297642 100644
--- a/include/nbla/cuda/cudnn/function/softmax.hpp
+++ b/include/nbla/cuda/cudnn/function/softmax.hpp
@@ -24,34 +24,23 @@ namespace nbla {
 
 /** @copydoc Softmax
 
-@note The default algorithm is set as ACCURATE. TODO: Set an algorithm by
-  context.
+@note The default algorithm is set as ACCURATE.
 */
 template <typename T> class SoftmaxCudaCudnn : public Softmax<T> {
 public:
   typedef typename CudaType<T>::type Tw;
 
   explicit SoftmaxCudaCudnn(const Context &ctx, int axis)
-      : Softmax<T>(ctx, axis), device_(std::stoi(ctx.device_id)) {
-    NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc_));
-    NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc_));
-  }
-  virtual ~SoftmaxCudaCudnn() {
-    NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
-    NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
-  }
+      : Softmax<T>(ctx, axis), device_(std::stoi(ctx.device_id)) {}
   virtual string name() { return "SoftmaxCudaCudnn"; }
   virtual vector<string> allowed_array_classes() {
     return SingletonManager::get<Cuda>()->array_classes();
   }
-  void set_cudnn_softmax_algorithm(std::string algorithm);
 
 protected:
   int device_;
-  cudnnHandle_t cudnn_handle_;
-  cudnnTensorDescriptor_t input_desc_;
-  cudnnTensorDescriptor_t output_desc_;
-  cudnnSoftmaxAlgorithm_t algorithm_;
+  CudnnSoftmax::Ptr cudnn_softmax_;
+
   virtual void setup_impl(const Variables &inputs, const Variables &outputs);
   virtual void forward_impl(const Variables &inputs, const Variables &outputs);
   virtual void backward_impl(const Variables &inputs, const Variables &outputs,
diff --git a/src/nbla/cuda/cudnn/cudnn.cpp b/src/nbla/cuda/cudnn/cudnn.cpp
index 262834cca..f6494c589 100644
--- a/src/nbla/cuda/cudnn/cudnn.cpp
+++ b/src/nbla/cuda/cudnn/cudnn.cpp
@@ -574,6 +574,51 @@ void CudnnPooling::backward(const void *alpha, const void *y, const void *dy,
       output_desc_.desc, dy, input_desc_.desc, x, beta, input_desc_.desc, dx));
 }
 
+//////////////////////////////
+// CUDNN Softmax wrapper
+//////////////////////////////
+CudnnSoftmax::CudnnSoftmax(const Shape_t &inshape, int axis,
+                           cudnnSoftmaxAlgorithm_t algo, cudnnDataType_t dtype,
+                           int device)
+    : algo_(algo), device_(device) {
+  const size_t size = std::accumulate(inshape.cbegin(), inshape.cend(),
+                                      (size_t)1, std::multiplies<size_t>());
+  const size_t size_axis = ndi::inner_size(inshape, axis);
+  const int N = size / size_axis; // Batch size.
+  const int C = inshape[axis];    // Size of specified axis.
+  const int H = size / (N * C);   // Size of rest.
+  const int W = 1;
+  const int stride_w = 1;
+  const int stride_h = W * stride_w;
+  const int stride_c = H * stride_h;
+  const int stride_n = C * stride_c;
+  NBLA_CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(input_desc_.desc, dtype, N, C,
+                                                H, W, stride_n, stride_c,
+                                                stride_h, stride_w));
+  NBLA_CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(output_desc_.desc, dtype, N, C,
+                                                H, W, stride_n, stride_c,
+                                                stride_h, stride_w));
+}
+CudnnSoftmax::Ptr CudnnSoftmax::create(const Shape_t &inshape, int axis,
+                                       cudnnSoftmaxAlgorithm_t algo,
+                                       cudnnDataType_t dtype, int device) {
+  return make_shared<CudnnSoftmax>(inshape, axis, algo, dtype, device);
+}
+void CudnnSoftmax::forward(const void *alpha, const void *x, const void *beta,
+                           void *y) const {
+  auto handle = SingletonManager::get<CudnnHandleManager>()->handle(device_);
+  NBLA_CUDNN_CHECK(
+      cudnnSoftmaxForward(handle, algo_, CUDNN_SOFTMAX_MODE_CHANNEL, alpha,
+                          input_desc_.desc, x, beta, output_desc_.desc, y));
+}
+void CudnnSoftmax::backward(const void *alpha, const void *y, const void *dy,
+                            const void *beta, void *dx) const {
+  auto handle = SingletonManager::get<CudnnHandleManager>()->handle(device_);
+  NBLA_CUDNN_CHECK(cudnnSoftmaxBackward(
+      handle, algo_, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, output_desc_.desc, y,
+      output_desc_.desc, dy, beta, input_desc_.desc, dx));
+}
+
 //////////////////////////////
 // cuDNN Handle implementation
 //////////////////////////////
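The constructor above folds a softmax over an arbitrary axis of an N-d array into cuDNN's 4-d channel softmax: dimensions before the axis become N, the axis itself becomes C, dimensions after it become H, and W is fixed to 1, so CUDNN_SOFTMAX_MODE_CHANNEL normalizes exactly over the chosen axis. A standalone arithmetic check of that folding, assuming ndi::inner_size(inshape, axis) is the product of dimensions from axis to the end (sketch only):

    #include <cassert>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main() {
      // inshape = (2, 3, 4, 5), softmax over axis = 1.
      std::vector<size_t> inshape = {2, 3, 4, 5};
      const int axis = 1;
      const size_t size = std::accumulate(inshape.cbegin(), inshape.cend(),
                                          (size_t)1, std::multiplies<size_t>());
      // Stand-in for ndi::inner_size(inshape, axis): product from axis onwards.
      const size_t size_axis =
          std::accumulate(inshape.cbegin() + axis, inshape.cend(), (size_t)1,
                          std::multiplies<size_t>());
      const int N = size / size_axis; // 2  (dims before the axis)
      const int C = inshape[axis];    // 3  (the softmax axis)
      const int H = size / (N * C);   // 20 (dims after the axis)
      assert(N == 2 && C == 3 && H == 20);
      return 0;
    }
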
diff --git a/src/nbla/cuda/cudnn/function/generic/log_softmax.cu b/src/nbla/cuda/cudnn/function/generic/log_softmax.cu
new file mode 100644
index 000000000..80eaab680
--- /dev/null
+++ b/src/nbla/cuda/cudnn/function/generic/log_softmax.cu
@@ -0,0 +1,62 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// log_softmax.cu
+
+#include <nbla/array.hpp>
+#include <nbla/cuda/common.hpp>
+#include <nbla/cuda/cudnn/cudnn.hpp>
+#include <nbla/cuda/cudnn/function/log_softmax.hpp>
+#include <nbla/singleton_manager.hpp>
+#include <nbla/variable.hpp>
+
+namespace nbla {
+
+template <typename T>
+void LogSoftmaxCudaCudnn<T>::setup_impl(const Variables &inputs,
+                                        const Variables &outputs) {
+  LogSoftmax<T>::setup_impl(inputs, outputs);
+  auto dtype = cudnn_data_type<Tw>::type();
+  cudnn_softmax_ = CudnnSoftmax::create(
+      inputs[0]->shape(), this->axis_, CUDNN_SOFTMAX_LOG, dtype, this->device_);
+}
+
+template <typename T>
+void LogSoftmaxCudaCudnn<T>::forward_impl(const Variables &inputs,
+                                          const Variables &outputs) {
+  NBLA_CHECK(cudnn_softmax_, error_code::value, "setup not called.");
+  auto x = inputs[0]->get_data_pointer<Tw>(this->ctx_);
+  auto y = outputs[0]->cast_data_and_get_pointer<Tw>(this->ctx_, true);
+  auto alpha = get_cudnn_scalar_arg<Tw>(1);
+  auto beta = get_cudnn_scalar_arg<Tw>(0);
+  cudnn_softmax_->forward(&alpha, x, &beta, y);
+}
+
+template <typename T>
+void LogSoftmaxCudaCudnn<T>::backward_impl(const Variables &inputs,
+                                           const Variables &outputs,
+                                           const vector<bool> &propagate_down,
+                                           const vector<bool> &accum) {
+  if (!propagate_down[0]) {
+    return;
+  }
+  NBLA_CHECK(cudnn_softmax_, error_code::value, "setup not called.");
+  auto y = outputs[0]->get_data_pointer<Tw>(this->ctx_);
+  auto dy = outputs[0]->get_grad_pointer<Tw>(this->ctx_);
+  auto dx = inputs[0]->cast_grad_and_get_pointer<Tw>(this->ctx_, !accum[0]);
+  auto alpha = get_cudnn_scalar_arg<Tw>(1);
+  auto beta = get_cudnn_scalar_arg<Tw>(accum[0] ? 1 : 0);
+  cudnn_softmax_->backward(&alpha, y, dy, &beta, dx);
+}
+} // namespace nbla
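For reference, the cuDNN calls driven above with CUDNN_SOFTMAX_LOG are expected to compute, per slice along the chosen axis, y_i = x_i - log(sum_j exp(x_j)) in the forward pass and dx_i = dy_i - exp(y_i) * sum_j dy_j in the backward pass. A single-vector CPU sketch of those formulas (illustration only, not the code path used by this patch):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    void log_softmax_reference(const std::vector<float> &x,
                               const std::vector<float> &dy,
                               std::vector<float> &y, std::vector<float> &dx) {
      // Forward: y_i = x_i - logsumexp(x), with the max subtracted for stability.
      const float max_x = *std::max_element(x.begin(), x.end());
      float sum_exp = 0.0f;
      for (float xi : x)
        sum_exp += std::exp(xi - max_x);
      const float log_sum = max_x + std::log(sum_exp);
      y.resize(x.size());
      dx.resize(x.size());
      for (size_t i = 0; i < x.size(); ++i)
        y[i] = x[i] - log_sum;
      // Backward: dx_i = dy_i - softmax_i * sum_j dy_j, with softmax_i = exp(y_i).
      float dy_sum = 0.0f;
      for (float d : dy)
        dy_sum += d;
      for (size_t i = 0; i < x.size(); ++i)
        dx[i] = dy[i] - std::exp(y[i]) * dy_sum;
    }
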
diff --git a/src/nbla/cuda/cudnn/function/generic/softmax.cu b/src/nbla/cuda/cudnn/function/generic/softmax.cu
index f471249ab..c76fd7d5a 100644
--- a/src/nbla/cuda/cudnn/function/generic/softmax.cu
+++ b/src/nbla/cuda/cudnn/function/generic/softmax.cu
@@ -27,37 +27,21 @@ template <typename T>
 void SoftmaxCudaCudnn<T>::setup_impl(const Variables &inputs,
                                      const Variables &outputs) {
   Softmax<T>::setup_impl(inputs, outputs);
-  cudnn_handle_ = SingletonManager::get<CudnnHandleManager>()->handle(device_);
-  int N = this->size0_;
-  int C = this->size1_;
-  int H = this->size2_;
-  int W = 1;
-  const int stride_w = 1;
-  const int stride_h = W * stride_w;
-  const int stride_c = H * stride_h;
-  const int stride_n = C * stride_c;
-  NBLA_CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(
-      input_desc_, cudnn_data_type<Tw>::type(), N, C, H, W, stride_n, stride_c,
-      stride_h, stride_w));
-  NBLA_CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(
-      output_desc_, cudnn_data_type<Tw>::type(), N, C, H, W, stride_n, stride_c,
-      stride_h, stride_w));
-  // default algorithm setting.
-  // TODO: set by context.
-  set_cudnn_softmax_algorithm("ACCURATE");
+  auto dtype = cudnn_data_type<Tw>::type();
+  cudnn_softmax_ =
+      CudnnSoftmax::create(inputs[0]->shape(), this->axis_,
+                           CUDNN_SOFTMAX_ACCURATE, dtype, this->device_);
 }
 
 template <typename T>
 void SoftmaxCudaCudnn<T>::forward_impl(const Variables &inputs,
                                        const Variables &outputs) {
-  cuda_set_device(std::stoi(this->ctx_.device_id));
-  const Tw *x = inputs[0]->get_data_pointer<Tw>(this->ctx_);
-  Tw *y = outputs[0]->cast_data_and_get_pointer<Tw>(this->ctx_, true);
+  NBLA_CHECK(cudnn_softmax_, error_code::value, "setup not called.");
+  auto x = inputs[0]->get_data_pointer<Tw>(this->ctx_);
+  auto y = outputs[0]->cast_data_and_get_pointer<Tw>(this->ctx_, true);
   auto alpha = get_cudnn_scalar_arg<Tw>(1);
   auto beta = get_cudnn_scalar_arg<Tw>(0);
-  NBLA_CUDNN_CHECK(cudnnSoftmaxForward(cudnn_handle_, algorithm_,
-                                       CUDNN_SOFTMAX_MODE_CHANNEL, &alpha,
-                                       input_desc_, x, &beta, output_desc_, y));
+  cudnn_softmax_->forward(&alpha, x, &beta, y);
 }
 
 template <typename T>
@@ -68,25 +52,12 @@ void SoftmaxCudaCudnn<T>::backward_impl(const Variables &inputs,
   if (!propagate_down[0]) {
     return;
   }
-  cuda_set_device(std::stoi(this->ctx_.device_id));
-  const Tw *y = outputs[0]->get_data_pointer<Tw>(this->ctx_);
-  const Tw *dy = outputs[0]->get_grad_pointer<Tw>(this->ctx_);
-  Tw *dx = inputs[0]->cast_grad_and_get_pointer<Tw>(this->ctx_, !accum[0]);
+  NBLA_CHECK(cudnn_softmax_, error_code::value, "setup not called.");
+  auto y = outputs[0]->get_data_pointer<Tw>(this->ctx_);
+  auto dy = outputs[0]->get_grad_pointer<Tw>(this->ctx_);
+  auto dx = inputs[0]->cast_grad_and_get_pointer<Tw>(this->ctx_, !accum[0]);
   auto alpha = get_cudnn_scalar_arg<Tw>(1);
   auto beta = get_cudnn_scalar_arg<Tw>(accum[0] ? 1 : 0);
-  NBLA_CUDNN_CHECK(cudnnSoftmaxBackward(
-      cudnn_handle_, algorithm_, CUDNN_SOFTMAX_MODE_CHANNEL, &alpha,
-      output_desc_, y, output_desc_, dy, &beta, input_desc_, dx));
-}
-
-template <typename T>
-void SoftmaxCudaCudnn<T>::set_cudnn_softmax_algorithm(std::string algorithm) {
-  if (algorithm == "FAST") {
-    algorithm_ = CUDNN_SOFTMAX_FAST;
-  } else if (algorithm == "ACCURATE") {
-    algorithm_ = CUDNN_SOFTMAX_ACCURATE;
-  } else {
-    NBLA_ERROR(error_code::target_specific, "Specified unsupported algorithm");
-  }
-}
+  cudnn_softmax_->backward(&alpha, y, dy, &beta, dx);
 }
+} // namespace nbla
diff --git a/src/nbla/cuda/function/generic/softmax.cu b/src/nbla/cuda/function/generic/softmax.cu
index e6b6bded5..d9904462a 100644
--- a/src/nbla/cuda/function/generic/softmax.cu
+++ b/src/nbla/cuda/function/generic/softmax.cu
@@ -26,20 +26,21 @@ namespace nbla {
 template <typename T>
 __global__ void kernel_softmax_forward(const int size0x2_, const int size1_,
                                        const int size2_, const T *x, T *y) {
+  typedef typename CudaTypeForceFloat<T>::type AccumType;
   NBLA_CUDA_KERNEL_LOOP(idx, size0x2_) {
     const int i0 = idx / size2_;
     const int i2 = idx % size2_;
     // compute maximum
-    T max_x = nbla::numeric_limits_cuda<T>::min();
+    AccumType max_x = -nbla::numeric_limits_cuda<T>::max();
     for (int i1 = 0; i1 < size1_; ++i1) {
       const int k = (i0 * size1_ + i1) * size2_ + i2;
       max_x = max(max_x, x[k]);
     }
     // Compute exponential and sum
-    T exp_sum = T(0);
+    AccumType exp_sum = T(0);
    for (int i1 = 0; i1 < size1_; ++i1) {
       const int k = (i0 * size1_ + i1) * size2_ + i2;
-      const T tmp = std::exp(x[k] - max_x);
+      const AccumType tmp = std::exp(x[k] - max_x);
       y[k] = tmp;
       exp_sum += tmp;
     }
@@ -55,11 +56,12 @@ template <typename T>
 __global__ void kernel_softmax_backward(const int size0x2_, const int size1_,
                                         const int size2_, const T *y,
                                         const T *dy, T *dx) {
+  typedef typename CudaTypeForceFloat<T>::type AccumType;
   NBLA_CUDA_KERNEL_LOOP(idx, size0x2_) {
     const int i0 = idx / size2_;
     const int i2 = idx % size2_;
     // compute sum of dy * y
-    T dyy_sum = T(0);
+    AccumType dyy_sum = T(0);
     for (int i1 = 0; i1 < size1_; ++i1) {
       const int k = (i0 * size1_ + i1) * size2_ + i2;
       dyy_sum += dy[k] * y[k];
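Two numerical details in the kernel changes just above: for floating-point types, numeric_limits<...>::min() is the smallest positive value rather than the most negative one, so -max() is the correct initializer for a running maximum; and the max/sum reductions now accumulate in float (via CudaTypeForceFloat) even when T is half, which limits rounding error in the sums. A small host-side illustration of the initializer point (sketch only):

    #include <algorithm>
    #include <cassert>
    #include <limits>

    int main() {
      // min() for float is ~1.18e-38 (positive), not the most negative value.
      const float bad_init = std::numeric_limits<float>::min();
      const float good_init = -std::numeric_limits<float>::max();
      const float all_negative[] = {-3.0f, -1.0f, -2.0f};
      float bad_max = bad_init, good_max = good_init;
      for (float v : all_negative) {
        bad_max = std::max(bad_max, v);
        good_max = std::max(good_max, v);
      }
      assert(good_max == -1.0f);   // correct maximum
      assert(bad_max == bad_init); // stuck at the positive initializer
      return 0;
    }
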
diff --git a/src/nbla/cuda/function/generic/softmax_cross_entropy.cu b/src/nbla/cuda/function/generic/softmax_cross_entropy.cu
index 455bfce1c..dfa682d19 100644
--- a/src/nbla/cuda/function/generic/softmax_cross_entropy.cu
+++ b/src/nbla/cuda/function/generic/softmax_cross_entropy.cu
@@ -26,23 +26,23 @@ namespace nbla {
 template <typename T, typename Tl>
 __global__ void
 kernel_softmax_cross_entropy_forward(const int size0x2_, const int size1_,
-                                     const int size2_, const T *p, const Tl *l,
-                                     T *y) {
+                                     const int size2_, const T *log_p,
+                                     const Tl *l, T *y) {
   NBLA_CUDA_KERNEL_LOOP(idx, size0x2_) {
     const int i0 = idx / size2_;
     const int i2 = idx % size2_;
     const int j = i0 * size2_ + i2;
     Tl label = l[j];
     const int k = i0 * size1_ * size2_ + label * size2_ + i2;
-    y[j] = -std::log(max(p[k], numeric_limits_cuda<T>::min()));
+    y[j] = -log_p[k];
   }
 }
 
 template <typename T, typename Tl, bool accum>
 __global__ void
 kernel_softmax_cross_entropy_backward(const int size0x2_, const int size1_,
-                                      const int size2_, const T *p, const T *dy,
-                                      const Tl *l, T *dx) {
+                                      const int size2_, const T *log_p,
+                                      const T *dy, const Tl *l, T *dx) {
   NBLA_CUDA_KERNEL_LOOP(idx, size0x2_) {
     const int i0 = idx / size2_;
     const int i2 = idx % size2_;
@@ -52,7 +52,7 @@ kernel_softmax_cross_entropy_backward(const int size0x2_, const int size1_,
     for (int i1 = 0; i1 < size1_; ++i1) {
       const int k = i0 * size1_ * size2_ + i1 * size2_ + i2;
       dx[k] = (accum ? dx[k] : (T)0) +
-              grad * (p[k] - static_cast<T>(label == i1));
+              grad * (std::exp(log_p[k]) - static_cast<T>(label == i1));
     }
   }
 }
@@ -67,15 +67,15 @@ template <typename T, typename Tl>
 void SoftmaxCrossEntropyCuda<T, Tl>::forward_impl(const Variables &inputs,
                                                   const Variables &outputs) {
   cuda_set_device(std::stoi(this->ctx_.device_id));
-  Variable &tso = this->softmax_output_;
-  this->softmax_->forward(Variables{inputs[0]}, Variables{&tso});
+  Variable &tso = this->log_softmax_output_;
+  this->log_softmax_->forward(Variables{inputs[0]}, Variables{&tso});
   // Setting up variables
-  const Tc *p = tso.get_data_pointer<Tc>(this->ctx_);
+  const Tc *log_p = tso.get_data_pointer<Tc>(this->ctx_);
   const Tl *l = inputs[1]->get_data_pointer<Tl>(this->ctx_);
   Tc *y = outputs[0]->cast_data_and_get_pointer<Tc>(this->ctx_, true);
   NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_softmax_cross_entropy_forward,
                                  this->size0_ * this->size2_, this->size1_,
-                                 this->size2_, p, l, y);
+                                 this->size2_, log_p, l, y);
 }
 
 template <typename T, typename Tl>
@@ -88,19 +88,21 @@ void SoftmaxCrossEntropyCuda<T, Tl>::backward_impl(
     return;
 
   cuda_set_device(std::stoi(this->ctx_.device_id));
-  Variable &tso = this->softmax_output_;
-  const Tc *p = tso.get_data_pointer<Tc>(this->ctx_);
+  Variable &tso = this->log_softmax_output_;
+  const Tc *log_p = tso.get_data_pointer<Tc>(this->ctx_);
   const Tc *dy = outputs[0]->get_grad_pointer<Tc>(this->ctx_);
   const Tl *l = inputs[1]->get_data_pointer<Tl>(this->ctx_);
   Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(this->ctx_, !accum[0]);
   if (accum[0]) {
     NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(
         (kernel_softmax_cross_entropy_backward<Tc, Tl, true>),
-        this->size0_ * this->size2_, this->size1_, this->size2_, p, dy, l, dx);
+        this->size0_ * this->size2_, this->size1_, this->size2_, log_p, dy, l,
+        dx);
   } else {
     NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(
         (kernel_softmax_cross_entropy_backward<Tc, Tl, false>),
-        this->size0_ * this->size2_, this->size1_, this->size2_, p, dy, l, dx);
+        this->size0_ * this->size2_, this->size1_, this->size2_, log_p, dy, l,
+        dx);
   }
 }
 }
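Switching the cross-entropy kernels from softmax probabilities to log-softmax values removes the -log(max(p, min())) clamp: -log_p is exact even where p would underflow to zero, and the backward pass recovers p with exp(log_p), which cannot overflow because log_p <= 0. A small illustration of what the old clamp did to a very unlikely label; the gap of 1000 between the maximum logit and the label's logit is made up, and the logsumexp correction is ignored for simplicity:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      const float logit_gap = 1000.0f;      // max logit minus the label's logit
      const float p = std::exp(-logit_gap); // underflows to 0.0f
      // Old path: clamp to FLT_MIN before the log, saturating around 87.3.
      const float old_loss =
          -std::log(std::max(p, std::numeric_limits<float>::min()));
      // New path: -log_p, where log_p ~= -logit_gap, keeps the true magnitude.
      const float new_loss = logit_gap;
      std::printf("old: %.1f  new: %.1f\n", old_loss, new_loss); // 87.3 vs 1000.0
      return 0;
    }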