Merge pull request #171 from sony/feature/20190528-adamw-sgdw
[Solver] AdamW and SGDW
AkioHayakawa-sony authored Jul 8, 2019
2 parents e6ad8e2 + fa425a4 commit 853c965
Showing 5 changed files with 223 additions and 0 deletions.
4 changes: 4 additions & 0 deletions build-tools/code_generator/solver_types.yaml
@@ -1,5 +1,7 @@
Sgd:
float: [float]
SgdW:
float: [float]
Momentum:
float: [float]
Lars:
@@ -10,6 +12,8 @@ Adagrad:
float: [float]
Adam:
float: [float]
AdamW:
float: [float]
AdaBound:
float: [float]
Adamax:
43 changes: 43 additions & 0 deletions include/nbla/cuda/solver/adamw.hpp
@@ -0,0 +1,43 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __NBLA_CUDA_SOLVER_ADAMW_HPP__
#define __NBLA_CUDA_SOLVER_ADAMW_HPP__

#include <nbla/cuda/cuda.hpp>
#include <nbla/solver/adamw.hpp>

namespace nbla {

template <typename T> class AdamWCuda : public AdamW<T> {
public:
explicit AdamWCuda(const Context &ctx, float alpha, float beta1, float beta2,
float eps, float wd)
: AdamW<T>(ctx, alpha, beta1, beta2, eps, wd) {}
virtual ~AdamWCuda() {}
virtual string name() { return "AdamWCuda"; }
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}

protected:
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
NBLA_DECL_SCALE_GRAD();
};
}
#endif
42 changes: 42 additions & 0 deletions include/nbla/cuda/solver/sgdw.hpp
@@ -0,0 +1,42 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __NBLA_CUDA_SOLVER_SGDW_HPP__
#define __NBLA_CUDA_SOLVER_SGDW_HPP__

#include <nbla/cuda/cuda.hpp>
#include <nbla/solver/sgdw.hpp>

namespace nbla {

template <typename T> class SgdWCuda : public SgdW<T> {
public:
explicit SgdWCuda(const Context &ctx, float lr, float momentum, float wd)
: SgdW<T>(ctx, lr, momentum, wd) {}
virtual ~SgdWCuda() {}
virtual string name() { return "SgdWCuda"; }
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}

protected:
virtual void update_impl(const string &key, VariablePtr param);
NBLA_DECL_WEIGHT_DECAY();
NBLA_DECL_CHECK_INF_GRAD();
NBLA_DECL_CHECK_NAN_GRAD();
NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
NBLA_DECL_SCALE_GRAD();
};
}
#endif
73 changes: 73 additions & 0 deletions src/nbla/cuda/solver/generic/adamw.cu
@@ -0,0 +1,73 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <nbla/cuda/common.hpp>
#include <nbla/cuda/solver/adamw.hpp>

#include "./mixed_precision_training.cuh"
#include "./weight_decay.cuh"

namespace nbla {

template <typename T>
__global__ void
kernel_adamw_update(const int num, T *theta, T *m, T *v, const T *g,
const float alpha_t, const float beta1, const float beta2,
const float eps, const float wd, const T eta_t) {
NBLA_CUDA_KERNEL_LOOP(s, num) {
// Updating running mean and var.
m[s] = beta1 * m[s] + (1 - beta1) * g[s];
v[s] = beta2 * v[s] + (1 - beta2) * g[s] * g[s];
// Update parameters.
theta[s] = theta[s] - alpha_t * m[s] / (std::sqrt(v[s]) + eps) -
eta_t * wd * theta[s];
}
}

template <typename T>
void AdamWCuda<T>::update_impl(const string &key, VariablePtr param) {
cuda_set_device(std::stoi(this->ctx_.device_id));
Size_t size = param->size();
auto &state = this->states_.at(key);
uint32_t &t = state.t;
const T *g = param->get_grad_pointer<T>(this->ctx_);
shared_ptr<Variable> mean_ =
state.pstate["mean"]; // To prevent compile error.
shared_ptr<Variable> var_ = state.pstate["var"]; // To prevent compile error.
T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
T eta_t = this->alpha_ / this->init_alpha_;
const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
(1 - std::pow(this->beta1_, t));
const T alpha_t = this->alpha_ * bias_correction;
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adamw_update, size, theta, m, v, g,
alpha_t, this->beta1_, this->beta2_,
this->eps_, this->wd_, eta_t);
}

template <typename T>
void AdamWCuda<T>::weight_decay_impl(const string &key, VariablePtr param,
float decay_rate) {
NBLA_CHECK(decay_rate == this->wd_, error_code::value,
"Decay rate should remain the same.");
weight_decay_cuda<T>(this->ctx_, param, decay_rate);
}

NBLA_DEF_CHECK_INF_GRAD(AdamWCuda, check_inf_grad_cuda);
NBLA_DEF_CHECK_NAN_GRAD(AdamWCuda, check_nan_grad_cuda);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(AdamWCuda, check_inf_or_nan_grad_cuda);
NBLA_DEF_SCALE_GRAD(AdamWCuda, scale_grad_impl_cuda);
}
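
For reference, the per-element update that kernel_adamw_update above performs can be written as follows; the symbols m, v, theta, g, beta_1, beta_2, epsilon and lambda are editorial notation read off the kernel and update_impl, not part of the commit:

\begin{aligned}
m &\leftarrow \beta_1 m + (1-\beta_1)\, g \\
v &\leftarrow \beta_2 v + (1-\beta_2)\, g^2 \\
\theta &\leftarrow \theta - \alpha_t \frac{m}{\sqrt{v} + \epsilon} - \eta_t \lambda\, \theta,
\qquad
\alpha_t = \alpha \, \frac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}},
\qquad
\eta_t = \frac{\alpha}{\alpha_0}
\end{aligned}

Here lambda corresponds to wd_, alpha to the current learning rate alpha_, and alpha_0 to init_alpha_, so eta_t rescales the decoupled weight-decay term along with any learning-rate schedule, in line with the decoupled weight decay formulation of Loshchilov and Hutter.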
61 changes: 61 additions & 0 deletions src/nbla/cuda/solver/generic/sgdw.cu
@@ -0,0 +1,61 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <nbla/cuda/common.hpp>
#include <nbla/cuda/solver/sgdw.hpp>

#include "./mixed_precision_training.cuh"
#include "./weight_decay.cuh"

namespace nbla {

template <typename T>
__global__ void kernel_update(const int num, T *data, const T *grad, T *v,
const float lr, const float momentum,
const float wd, T eta_t) {
NBLA_CUDA_KERNEL_LOOP(idx, num) {
v[idx] = momentum * v[idx] + lr * grad[idx] - (eta_t * wd * v[idx]);
data[idx] -= v[idx];
}
}

template <typename T>
void SgdWCuda<T>::update_impl(const string &key, VariablePtr param) {
cuda_set_device(std::stoi(this->ctx_.device_id));
Size_t size = param->size();
auto &state = this->states_.at(key);
VariablePtr r_ = state.pstate["m"];
const T *grad = param->get_grad_pointer<T>(this->ctx_);
T *v = r_->cast_data_and_get_pointer<T>(this->ctx_);
T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
T eta_t = this->lr_ / this->init_lr_;
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_update, size, data, grad, v, this->lr_,
this->momentum_, this->wd_, eta_t);
auto &t = state.t;
t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
}

template <typename T>
void SgdWCuda<T>::weight_decay_impl(const string &key, VariablePtr param,
float decay_rate) {
NBLA_CHECK(decay_rate == this->wd_, error_code::value,
"Decay rate should remain the same.");
weight_decay_cuda<T>(this->ctx_, param, decay_rate);
}

NBLA_DEF_CHECK_INF_GRAD(SgdWCuda, check_inf_grad_cuda);
NBLA_DEF_CHECK_NAN_GRAD(SgdWCuda, check_nan_grad_cuda);
NBLA_DEF_CHECK_INF_OR_NAN_GRAD(SgdWCuda, check_inf_or_nan_grad_cuda);
NBLA_DEF_SCALE_GRAD(SgdWCuda, scale_grad_impl_cuda);
}
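
Similarly, the per-element step in kernel_update above is, in editorial notation (mu for momentum_, lambda for wd_, eta for lr_, eta_t = lr_ / init_lr_; these symbols are not part of the commit):

\begin{aligned}
v &\leftarrow \mu v + \eta\, g - \eta_t \lambda\, v \\
\theta &\leftarrow \theta - v
\end{aligned}

Note that, as written, the decoupled decay term is subtracted from the velocity v; in Loshchilov and Hutter's description of SGDW the term eta_t * lambda * theta is subtracted from the parameter itself, i.e. \(\theta \leftarrow \theta - v - \eta_t \lambda\, \theta\).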
