Skip to content

Commit

Permalink
Merge pull request #154 from sony/feature/20190522-fused-bn-nhwc
Browse files Browse the repository at this point in the history
Add CUDNN Fused Batch Normalization and utilize faster CUDNN Batch Normalization
  • Loading branch information
AkioHayakawa-sony authored May 22, 2019
2 parents 2e2e9dd + a1e7a21 commit fea69db
Show file tree
Hide file tree
Showing 7 changed files with 633 additions and 106 deletions.
3 changes: 3 additions & 0 deletions build-tools/code_generator/function_types_cudnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ SyncBatchNormalization:
BatchNormalization:
float: [float]
half: [Half]
FusedBatchNormalization:
float: [float]
half: [Half]
# MeanSubtraction:
# float: [float]
Sum:
Expand Down
9 changes: 9 additions & 0 deletions include/nbla/cuda/cudnn/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,15 @@ struct CudnnTensorDescriptor {
~CudnnTensorDescriptor();
};

/**
CUDNN activation descriptor wrapper.
*/
struct CudnnActivationDescriptor {
  // Owned cuDNN activation descriptor handle; created in the constructor and
  // destroyed in the destructor (RAII, mirroring CudnnTensorDescriptor above).
  // NOTE(review): copy/move are compiler-generated, so copying this wrapper
  // would duplicate the raw handle and double-destroy it — presumably
  // instances are only ever held by value inside non-copied function objects;
  // confirm before copying.
  cudnnActivationDescriptor_t desc;
  CudnnActivationDescriptor();
  ~CudnnActivationDescriptor();
};

/**
Common CUDNN pooling function wrapper.
*/
Expand Down
36 changes: 17 additions & 19 deletions include/nbla/cuda/cudnn/function/batch_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,20 @@ template <typename T>
class BatchNormalizationCudaCudnn : public BatchNormalizationCuda<T> {
protected:
int device_;
cudnnBatchNormMode_t mode_;
cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_desc_, output_desc_;
cudnnTensorDescriptor_t bn_scale_bias_mean_var_desc_;
CudnnTensorDescriptor input_desc_, output_desc_;
CudnnTensorDescriptor bn_scale_bias_mean_var_desc_;
cudnnDataType_t derived_bn_dtype_;
double epsilon;
cudnnBatchNormMode_t mode_;
#if CUDNN_VERSION >= 7400
bool can_use_bn_ex_{false};
CudnnActivationDescriptor act_desc_;
NdArrayPtr reserve_;
cudnnBatchNormOps_t ops_{CUDNN_BATCHNORM_OPS_BN};
size_t forward_workspace_size_{0};
size_t backward_workspace_size_{0};
size_t reserve_size_{0};
#endif

public:
typedef typename CudaType<T>::type Tw;
Expand All @@ -53,23 +61,13 @@ class BatchNormalizationCudaCudnn : public BatchNormalizationCuda<T> {
this->fall_back_func_.reset(
new BatchNormalizationCuda<T>(ctx, axes, decay_rate, eps, batch_stat));
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc_));
NBLA_CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc_));
NBLA_CUDNN_CHECK(
cudnnCreateTensorDescriptor(&bn_scale_bias_mean_var_desc_));
// NOTE: epsilon should be less than CUDNN_BN_MIN_EPSILON
epsilon = std::max((double)this->eps_, CUDNN_BN_MIN_EPSILON);
NBLA_CHECK(eps >= (float)CUDNN_BN_MIN_EPSILON, error_code::value,
"eps must be greater than or equal to CUDNN_BN_MIN_EPSILON. "
"eps=%g, CUDNN_BN_MIN_EPSILON=%g",
eps, CUDNN_BN_MIN_EPSILON);
#endif
}
virtual ~BatchNormalizationCudaCudnn() {
if (this->fall_back_func_)
return;
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(input_desc_));
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(output_desc_));
NBLA_CUDNN_CHECK(
cudnnDestroyTensorDescriptor(bn_scale_bias_mean_var_desc_));
}
virtual ~BatchNormalizationCudaCudnn() {}
virtual string name() { return "BatchNormalizationCudaCudnn"; }
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
Expand Down
98 changes: 98 additions & 0 deletions include/nbla/cuda/cudnn/function/fused_batch_normalization.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/** Batch Normalization
*/
#pragma once

#include <nbla/cuda/common.hpp>
#include <nbla/cuda/cuda.hpp>
#include <nbla/cuda/cudnn/cudnn.hpp>
#include <nbla/function/fused_batch_normalization.hpp>

#include <type_traits>
#include <vector>

using std::vector;

namespace nbla {

template <typename T>
class FusedBatchNormalizationCudaCudnn : public FusedBatchNormalization<T> {
protected:
  int device_;
#if CUDNN_VERSION >= 7400
  Variable mean_;
  Variable var_;
  cudnnHandle_t cudnn_handle_;
  CudnnTensorDescriptor input_desc_, z_desc_, output_desc_;
  CudnnTensorDescriptor bn_scale_bias_mean_var_desc_;
  cudnnDataType_t derived_bn_dtype_;
  cudnnBatchNormMode_t mode_;
  CudnnActivationDescriptor act_desc_;
  // Workspace/reserve buffer shared between forward and backward of the
  // cudnnBatchNormalization*Ex API.
  NdArrayPtr reserve_;
  cudnnBatchNormOps_t ops_{CUDNN_BATCHNORM_OPS_BN};
  size_t forward_workspace_size_{0};
  size_t backward_workspace_size_{0};
  size_t reserve_size_{0};
#endif

public:
  typedef typename CudaType<T>::type Tw;

  /** Fused (BN + activation) batch normalization using the cuDNN
      cudnnBatchNormalization*Ex API.

      Falls back to the composite FusedBatchNormalization implementation
      when the fused kernel is unavailable (non-half type, or
      batch_stat == false).
   */
  FusedBatchNormalizationCudaCudnn(const Context &ctx, const vector<int> axes,
                                   float decay_rate, float eps, bool batch_stat,
                                   const string &nonlinearity)
      : FusedBatchNormalization<T>(ctx, axes, decay_rate, eps, batch_stat,
                                   nonlinearity),
        device_(std::stoi(ctx.device_id)) {
#if CUDNN_VERSION >= 7400
    // Note: The fused cuDNN path only supports half precision with batch
    // statistics. The std::is_same test is evaluated at runtime here (C++11
    // lacks `if constexpr`), which causes an unreachable-statement warning
    // during compilation.
    if (!std::is_same<Tw, HalfCuda>::value || !batch_stat) {
      this->fall_back_func_ = make_shared<FusedBatchNormalization<T>>(
          ctx, axes, decay_rate, eps, batch_stat, nonlinearity);
      return;
    }
    this->mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
    NBLA_CHECK(nonlinearity == "relu", error_code::value,
               "Currently \"relu\" only supported.");
    // NOTE: epsilon must satisfy the cuDNN lower bound.
    NBLA_CHECK(eps >= (float)CUDNN_BN_MIN_EPSILON, error_code::value,
               "eps must be greater than or equal to CUDNN_BN_MIN_EPSILON. "
               "eps=%g, CUDNN_BN_MIN_EPSILON=%g",
               eps, CUDNN_BN_MIN_EPSILON);
    // coef (last arg) is only used by clipped-relu/elu modes; pass 0.0.
    NBLA_CUDNN_CHECK(cudnnSetActivationDescriptor(this->act_desc_.desc,
                                                  CUDNN_ACTIVATION_RELU,
                                                  CUDNN_PROPAGATE_NAN, 0.0));
#endif
  }
  virtual ~FusedBatchNormalizationCudaCudnn() {}
  virtual string name() { return "FusedBatchNormalizationCudaCudnn"; }
  virtual vector<string> allowed_array_classes() {
    return SingletonManager::get<Cuda>()->array_classes();
  }
  // FusedBN CUDNN relies on output buffer.
  virtual bool grad_depends_output_data(int i, int o) const { return true; }

protected:
  // BUGFIX: was `#if CUDNN_VERSION > 7400`, which disagreed with the
  // `>= 7400` guard on the members/constructor above. At exactly
  // CUDNN_VERSION == 7400 the fused state was set up but these overrides
  // were never declared, silently falling through to the base
  // implementation. Keep both guards identical.
#if CUDNN_VERSION >= 7400
  virtual void setup_impl(const Variables &inputs, const Variables &outputs);
  virtual void forward_impl(const Variables &inputs, const Variables &outputs);
  virtual void backward_impl(const Variables &inputs, const Variables &outputs,
                             const vector<bool> &propagate_down,
                             const vector<bool> &accum);
#endif
};
} // namespace nbla
10 changes: 10 additions & 0 deletions src/nbla/cuda/cudnn/cudnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,16 @@ size_t CudnnConvResource::workspace_size() const {
std::max(bwd_filter_workspace_size, bwd_data_workspace_size));
}

////////////////////////////////////////
// Cudnn activation descriptor Wrapper
////////////////////////////////////////
// RAII: acquire the cuDNN activation descriptor handle on construction.
CudnnActivationDescriptor::CudnnActivationDescriptor() {
  NBLA_CUDNN_CHECK(cudnnCreateActivationDescriptor(&desc));
}
// Release the handle on destruction.
// NOTE(review): NBLA_CUDNN_CHECK presumably throws on failure, which would
// propagate out of a destructor; this matches the other descriptor wrappers
// in this file, but confirm that destroy failures here are acceptable.
CudnnActivationDescriptor::~CudnnActivationDescriptor() {
  NBLA_CUDNN_CHECK(cudnnDestroyActivationDescriptor(desc));
}

////////////////////////////////////////
// Cudnn Pooling Wrapper
////////////////////////////////////////
Expand Down
Loading

0 comments on commit fea69db

Please sign in to comment.