diff --git a/build-tools/code_generator/solver_types.yaml b/build-tools/code_generator/solver_types.yaml
index 765573d3e..a8ac11b02 100644
--- a/build-tools/code_generator/solver_types.yaml
+++ b/build-tools/code_generator/solver_types.yaml
@@ -1,5 +1,7 @@
 Sgd:
   float: [float]
+SgdW:
+  float: [float]
 Momentum:
   float: [float]
 Lars:
@@ -10,6 +12,8 @@ Adagrad:
   float: [float]
 Adam:
   float: [float]
+AdamW:
+  float: [float]
 AdaBound:
   float: [float]
 Adamax:
diff --git a/include/nbla/cuda/solver/adamw.hpp b/include/nbla/cuda/solver/adamw.hpp
new file mode 100644
index 000000000..fd5a5e4c4
--- /dev/null
+++ b/include/nbla/cuda/solver/adamw.hpp
@@ -0,0 +1,43 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __NBLA_CUDA_SOLVER_ADAMW_HPP__
+#define __NBLA_CUDA_SOLVER_ADAMW_HPP__
+
+#include <nbla/cuda/cuda.hpp>
+#include <nbla/solver/adamw.hpp>
+
+namespace nbla {
+
+template <typename T> class AdamWCuda : public AdamW<T> {
+public:
+  explicit AdamWCuda(const Context &ctx, float alpha, float beta1, float beta2,
+                     float eps, float wd)
+      : AdamW<T>(ctx, alpha, beta1, beta2, eps, wd) {}
+  virtual ~AdamWCuda() {}
+  virtual string name() { return "AdamWCuda"; }
+  virtual vector<string> allowed_array_classes() {
+    return SingletonManager::get<Cuda>()->array_classes();
+  }
+
+protected:
+  virtual void update_impl(const string &key, VariablePtr param);
+  NBLA_DECL_WEIGHT_DECAY();
+  NBLA_DECL_CHECK_INF_GRAD();
+  NBLA_DECL_CHECK_NAN_GRAD();
+  NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
+  NBLA_DECL_SCALE_GRAD();
+};
+}
+#endif
diff --git a/include/nbla/cuda/solver/sgdw.hpp b/include/nbla/cuda/solver/sgdw.hpp
new file mode 100644
index 000000000..9830e859b
--- /dev/null
+++ b/include/nbla/cuda/solver/sgdw.hpp
@@ -0,0 +1,42 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef __NBLA_CUDA_SOLVER_SGDW_HPP__
+#define __NBLA_CUDA_SOLVER_SGDW_HPP__
+
+#include <nbla/cuda/cuda.hpp>
+#include <nbla/solver/sgdw.hpp>
+
+namespace nbla {
+
+template <typename T> class SgdWCuda : public SgdW<T> {
+public:
+  explicit SgdWCuda(const Context &ctx, float lr, float momentum, float wd)
+      : SgdW<T>(ctx, lr, momentum, wd) {}
+  virtual ~SgdWCuda() {}
+  virtual string name() { return "SgdWCuda"; }
+  virtual vector<string> allowed_array_classes() {
+    return SingletonManager::get<Cuda>()->array_classes();
+  }
+
+protected:
+  virtual void update_impl(const string &key, VariablePtr param);
+  NBLA_DECL_WEIGHT_DECAY();
+  NBLA_DECL_CHECK_INF_GRAD();
+  NBLA_DECL_CHECK_NAN_GRAD();
+  NBLA_DECL_CHECK_INF_OR_NAN_GRAD();
+  NBLA_DECL_SCALE_GRAD();
+};
+}
+#endif
diff --git a/src/nbla/cuda/solver/generic/adamw.cu b/src/nbla/cuda/solver/generic/adamw.cu
new file mode 100644
index 000000000..a35ce09f6
--- /dev/null
+++ b/src/nbla/cuda/solver/generic/adamw.cu
@@ -0,0 +1,73 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <nbla/cuda/common.hpp>
+#include <nbla/cuda/solver/adamw.hpp>
+
+#include "./mixed_precision_training.cuh"
+#include "./weight_decay.cuh"
+
+namespace nbla {
+
+template <typename T>
+__global__ void
+kernel_adamw_update(const int num, T *theta, T *m, T *v, const T *g,
+                    const float alpha_t, const float beta1, const float beta2,
+                    const float eps, const float wd, const T eta_t) {
+  NBLA_CUDA_KERNEL_LOOP(s, num) {
+    // Updating running mean and var.
+    m[s] = beta1 * m[s] + (1 - beta1) * g[s];
+    v[s] = beta2 * v[s] + (1 - beta2) * g[s] * g[s];
+    // Update parameters.
+    theta[s] = theta[s] - alpha_t * m[s] / (std::sqrt(v[s]) + eps) -
+               eta_t * wd * theta[s];
+  }
+}
+
+template <typename T>
+void AdamWCuda<T>::update_impl(const string &key, VariablePtr param) {
+  cuda_set_device(std::stoi(this->ctx_.device_id));
+  Size_t size = param->size();
+  auto &state = this->states_.at(key);
+  uint32_t &t = state.t;
+  const T *g = param->get_grad_pointer<T>(this->ctx_);
+  shared_ptr<Variable> mean_ =
+      state.pstate["mean"]; // To prevent compile error.
+  shared_ptr<Variable> var_ = state.pstate["var"]; // To prevent compile error.
+  T *m = mean_->cast_data_and_get_pointer<T>(this->ctx_);
+  T *v = var_->cast_data_and_get_pointer<T>(this->ctx_);
+  T *theta = param->cast_data_and_get_pointer<T>(this->ctx_);
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
+  T eta_t = this->alpha_ / this->init_alpha_;
+  const T bias_correction = std::sqrt(1 - std::pow(this->beta2_, t)) /
+                            (1 - std::pow(this->beta1_, t));
+  const T alpha_t = this->alpha_ * bias_correction;
+  NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_adamw_update, size, theta, m, v, g,
+                                 alpha_t, this->beta1_, this->beta2_,
+                                 this->eps_, this->wd_, eta_t);
+}
+
+template <typename T>
+void AdamWCuda<T>::weight_decay_impl(const string &key, VariablePtr param,
+                                     float decay_rate) {
+  NBLA_CHECK(decay_rate == this->wd_, error_code::value,
+             "Decay rate should remain the same.");
+  weight_decay_cuda<T>(this->ctx_, param, decay_rate);
+}
+
+NBLA_DEF_CHECK_INF_GRAD(AdamWCuda, check_inf_grad_cuda);
+NBLA_DEF_CHECK_NAN_GRAD(AdamWCuda, check_nan_grad_cuda);
+NBLA_DEF_CHECK_INF_OR_NAN_GRAD(AdamWCuda, check_inf_or_nan_grad_cuda);
+NBLA_DEF_SCALE_GRAD(AdamWCuda, scale_grad_impl_cuda);
+}
diff --git a/src/nbla/cuda/solver/generic/sgdw.cu b/src/nbla/cuda/solver/generic/sgdw.cu
new file mode 100644
index 000000000..6822d1359
--- /dev/null
+++ b/src/nbla/cuda/solver/generic/sgdw.cu
@@ -0,0 +1,61 @@
+// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <nbla/cuda/common.hpp>
+#include <nbla/cuda/solver/sgdw.hpp>
+
+#include "./mixed_precision_training.cuh"
+#include "./weight_decay.cuh"
+
+namespace nbla {
+
+template <typename T>
+__global__ void kernel_update(const int num, T *data, const T *grad, T *v,
+                              const float lr, const float momentum,
+                              const float wd, T eta_t) {
+  NBLA_CUDA_KERNEL_LOOP(idx, num) {
+    v[idx] = momentum * v[idx] + lr * grad[idx] - (eta_t * wd * v[idx]);
+    data[idx] -= v[idx];
+  }
+}
+
+template <typename T>
+void SgdWCuda<T>::update_impl(const string &key, VariablePtr param) {
+  cuda_set_device(std::stoi(this->ctx_.device_id));
+  Size_t size = param->size();
+  auto &state = this->states_.at(key);
+  VariablePtr r_ = state.pstate["m"];
+  const T *grad = param->get_grad_pointer<T>(this->ctx_);
+  T *v = r_->cast_data_and_get_pointer<T>(this->ctx_);
+  T *data = param->cast_data_and_get_pointer<T>(this->ctx_);
+  T eta_t = this->lr_ / this->init_lr_;
+  NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_update, size, data, grad, v, this->lr_,
+                                 this->momentum_, this->wd_, eta_t);
+  auto &t = state.t;
+  t = std::min(t + 1, std::numeric_limits<uint32_t>::max() - 1);
+}
+
+template <typename T>
+void SgdWCuda<T>::weight_decay_impl(const string &key, VariablePtr param,
+                                    float decay_rate) {
+  NBLA_CHECK(decay_rate == this->wd_, error_code::value,
+             "Decay rate should remain the same.");
+  weight_decay_cuda<T>(this->ctx_, param, decay_rate);
+}
+
+NBLA_DEF_CHECK_INF_GRAD(SgdWCuda, check_inf_grad_cuda);
+NBLA_DEF_CHECK_NAN_GRAD(SgdWCuda, check_nan_grad_cuda);
+NBLA_DEF_CHECK_INF_OR_NAN_GRAD(SgdWCuda, check_inf_or_nan_grad_cuda);
+NBLA_DEF_SCALE_GRAD(SgdWCuda, scale_grad_impl_cuda);
+}
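
Note (not part of the patch): below is a minimal host-side C++ sketch of the decoupled weight-decay updates that kernel_update and kernel_adamw_update implement, which can serve as a CPU reference when checking the CUDA kernels. The helper names sgdw_step and adamw_step are hypothetical and do not exist in nnabla; the formulas mirror the kernels above (eta_t is the learning-rate schedule multiplier lr/lr0 resp. alpha/alpha0).

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// SgdW reference: v <- momentum*v + lr*g - eta_t*wd*v, then data <- data - v
// (same update as kernel_update in sgdw.cu).
void sgdw_step(std::vector<float> &data, std::vector<float> &v,
               const std::vector<float> &grad, float lr, float momentum,
               float wd, float eta_t) {
  for (std::size_t i = 0; i < data.size(); ++i) {
    v[i] = momentum * v[i] + lr * grad[i] - eta_t * wd * v[i];
    data[i] -= v[i];
  }
}

// AdamW reference: bias-corrected Adam step plus a decoupled weight-decay term,
//   theta <- theta - alpha_t*m/(sqrt(v)+eps) - eta_t*wd*theta,
// with alpha_t = alpha*sqrt(1-beta2^t)/(1-beta1^t) and eta_t = alpha/alpha0
// (same formulas as AdamWCuda::update_impl and kernel_adamw_update in adamw.cu).
void adamw_step(std::vector<float> &theta, std::vector<float> &m,
                std::vector<float> &v, const std::vector<float> &grad,
                float alpha, float alpha0, float beta1, float beta2, float eps,
                float wd, std::uint32_t t) {
  const float eta_t = alpha / alpha0;
  const float alpha_t =
      alpha * std::sqrt(1.0f - std::pow(beta2, static_cast<float>(t))) /
      (1.0f - std::pow(beta1, static_cast<float>(t)));
  for (std::size_t i = 0; i < theta.size(); ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];
    theta[i] -=
        alpha_t * m[i] / (std::sqrt(v[i]) + eps) + eta_t * wd * theta[i];
  }
}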