diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 6bfae6fdbbde..00857e8cbb00 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -430,7 +430,7 @@ void BatchNormCompute_CPU(const nnvm::NodeAttrs &attrs,
 
     switch (inputs[0].dtype()) {
       case mshadow::kFloat32:
-        MKLDNNBatchNorm_Forward<float>(ctx, param, in_data, req, outputs, aux_states);
+        MKLDNNBatchNormForward<float>(ctx, param, in_data, req, outputs, aux_states);
         return;
     }
   }
@@ -472,8 +472,8 @@ void BatchNormGradCompute_CPU(const nnvm::NodeAttrs &attrs,
 
   std::vector<NDArray> in_grad(outputs.begin(), outputs.begin() + 3);
   if (inputs[0].dtype() == mshadow::kFloat32) {
-    MKLDNNBatchNorm_Backward<float>(ctx, param, out_grad, in_data,
-                                    out_data, req, in_grad, aux_states);
+    MKLDNNBatchNormBackward<float>(ctx, param, out_grad, in_data,
+                                   out_data, req, in_grad, aux_states);
     return;
   }
 }
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 145619b4ea65..035092780eb5 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -49,8 +49,8 @@ using mkldnn::forward_training;
 using mkldnn::forward_inference;
 
 inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
-          const std::vector<NDArray> &aux_states,
-          const BatchNormParam &param, bool is_train) {
+                                 const std::vector<NDArray> &aux_states,
+                                 const BatchNormParam &param, bool is_train) {
   unsigned flags = 0U;
   if (in_data.size() == 3U) {
     flags |= use_scale_shift;
@@ -65,8 +65,10 @@ inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
 }
 
 template <typename DType>
-inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem, bool is_train,
-                                   DType eps, unsigned flags) {
+inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem,
+                                   bool is_train,
+                                   DType eps,
+                                   unsigned flags) {
   auto data_mpd = data_mem.get_primitive_desc();
   auto data_md = data_mpd.desc();
   auto engine = CpuEngine::Get()->get_engine();
@@ -81,8 +83,10 @@ inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem, bool is_train
 }
 
 template <typename DType>
-inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, const mkldnn::memory &diff_mem,
-                                   DType eps, unsigned flags) {
+inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
+                                   const mkldnn::memory &diff_mem,
+                                   DType eps,
+                                   unsigned flags) {
   auto data_mpd = data_mem.get_primitive_desc();
   auto data_md = data_mpd.desc();
   auto diff_mpd = diff_mem.get_primitive_desc();
@@ -94,11 +98,11 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem, const mkldnn:
 }
 
 template <typename DType>
-void MKLDNNBatchNorm_Forward(const OpContext &ctx, const BatchNormParam &param,
-                             const std::vector<NDArray> &in_data,
-                             const std::vector<OpReqType> &req,
-                             const std::vector<NDArray> &out_data,
-                             const std::vector<NDArray> &aux_states) {
+void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
+                            const std::vector<NDArray> &in_data,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<NDArray> &out_data,
+                            const std::vector<NDArray> &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
   const NDArray &data = in_data[batchnorm::kData];
@@ -194,13 +198,13 @@ void MKLDNNBatchNorm_Forward(const OpContext &ctx, const BatchNormParam &param,
 }
 
 template <typename DType>
-void MKLDNNBatchNorm_Backward(const OpContext &ctx, const BatchNormParam &param,
-                              const std::vector<NDArray> &out_grad,
-                              const std::vector<NDArray> &in_data,
-                              const std::vector<NDArray> &out_data,
-                              const std::vector<OpReqType> &req,
-                              const std::vector<NDArray> &in_grad,
-                              const std::vector<NDArray> &aux_states) {
+void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
+                             const std::vector<NDArray> &out_grad,
+                             const std::vector<NDArray> &in_data,
+                             const std::vector<NDArray> &out_data,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &in_grad,
+                             const std::vector<NDArray> &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U);
   CHECK_EQ(in_data.size(), 3U);
@@ -262,12 +266,12 @@ void MKLDNNBatchNorm_Backward(const OpContext &ctx, const BatchNormParam &param,
 
     DType minus_mom = (1.0f - param.momentum);
     for (int i = 0; i < channels_; i++) {
-      moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum
-          + out_mean_ptr[i] * minus_mom;
+      moving_mean_ptr[i] = moving_mean_ptr[i] * param.momentum
+                           + out_mean_ptr[i] * minus_mom;
       float variance = INVSTD_TO_VARIANCE(out_var_ptr[i], param.eps);
       tmp_var_ptr[i] = variance;
-      moving_var_ptr[i] = moving_var_ptr[i] * param.momentum
-          + variance * minus_mom;
+      moving_var_ptr[i] = moving_var_ptr[i] * param.momentum
+                          + variance * minus_mom;
     }
 
     std::shared_ptr<const mkldnn::memory> out_mean_mem(
@@ -276,13 +280,13 @@ void MKLDNNBatchNorm_Backward(const OpContext &ctx, const BatchNormParam &param,
         new mkldnn::memory(bwd_pd.variance_primitive_desc(), out_var_ptr));
 
     auto bn_bwd = mkldnn::batch_normalization_backward(bwd_pd,
-            *data_mem,
-            mkldnn::primitive::at(*out_mean_mem),
-            mkldnn::primitive::at(var_mem),
-            *diff_mem,
-            *weight_mem,
-            *gradi_mem,
-            *gradw_mem);
+                                                       *data_mem,
+                                                       mkldnn::primitive::at(*out_mean_mem),
+                                                       mkldnn::primitive::at(var_mem),
+                                                       *diff_mem,
+                                                       *weight_mem,
+                                                       *gradi_mem,
+                                                       *gradw_mem);
     MKLDNNStream::Get()->RegisterPrim(bn_bwd);
     MKLDNNStream::Get()->Submit();
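
For context on what the renamed helpers wrap: _GetFwd builds an MKLDNN batch-normalization forward primitive descriptor from the input memory. Below is a minimal sketch of that setup, assuming the MKLDNN 0.x C++ API this header targets; the name GetFwdSketch and the directly constructed CPU engine are illustrative only (in the header, t_bn_f_pdesc is the typedef for this primitive_desc type and the engine comes from MXNet's CpuEngine singleton).

#include <mkldnn.hpp>

// Sketch only, not from the patch: assembles a batch-norm forward
// primitive_desc the way _GetFwd does in mkldnn_batch_norm-inl.h.
static mkldnn::batch_normalization_forward::primitive_desc
GetFwdSketch(const mkldnn::memory &data_mem, bool is_train,
             float eps, unsigned flags) {
  // Memory descriptor of the input, in whatever layout MKLDNN picked.
  auto data_md = data_mem.get_primitive_desc().desc();
  // Illustrative engine; the real code uses CpuEngine::Get()->get_engine().
  mkldnn::engine engine(mkldnn::engine::cpu, 0);
  // forward_training computes batch mean/variance as extra outputs;
  // forward_inference consumes precomputed statistics instead.
  auto kind = is_train ? mkldnn::forward_training : mkldnn::forward_inference;
  mkldnn::batch_normalization_forward::desc fwd_desc(kind, data_md, eps, flags);
  return mkldnn::batch_normalization_forward::primitive_desc(fwd_desc, engine);
}

Note that the call sites in batch_norm.cc pass the element type explicitly (MKLDNNBatchNormForward<float>(...)): DType appears only as the type of eps inside the helpers, not in the NDArray argument lists, so it cannot be deduced.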