diff --git a/include/darknet.h b/include/darknet.h
index 00b49921f52..55f94ac5224 100644
--- a/include/darknet.h
+++ b/include/darknet.h
@@ -102,7 +102,7 @@ typedef struct tree {
 
 // activations.h
 typedef enum {
-    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH
+    LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH
 }ACTIVATION;
 
 // parser.h
@@ -347,7 +347,7 @@ struct layer {
     float *col_image;
     float * delta;
     float * output;
-    float * output_sigmoid;
+    float * activation_input;
     int delta_pinned;
     int output_pinned;
     float * loss;
@@ -532,7 +532,7 @@ struct layer {
     float * input_antialiasing_gpu;
 
     float * output_gpu;
-    float * output_sigmoid_gpu;
+    float * activation_input_gpu;
     float * loss_gpu;
     float * delta_gpu;
     float * rand_gpu;
diff --git a/src/activation_kernels.cu b/src/activation_kernels.cu
index 24563c69d6e..846c586fada 100644
--- a/src/activation_kernels.cu
+++ b/src/activation_kernels.cu
@@ -199,6 +199,16 @@ __global__ void activate_array_swish_kernel(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
     }
 }
 
+__global__ void activate_array_mish_kernel(float *x, int n, float *activation_input, float *output_gpu)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float x_val = x[i];
+        activation_input[i] = x_val;    // store value before activation
+        output_gpu[i] = x_val * tanh_activate_kernel(log(1 + expf(x_val)));
+    }
+}
+
 __global__ void activate_array_leaky_kernel(float *x, int n)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
@@ -263,6 +273,18 @@ __global__ void gradient_array_swish_kernel(float *x, int n, float *sigmoid_gpu, float *delta)
     }
 }
 
+__global__ void gradient_array_mish_kernel(int n, float *activation_input, float *delta)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i < n) {
+        float x = activation_input[i];
+        float d = 2 * expf(x) + expf(2 * x) + 2;
+        float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6);
+        float derivative = expf(x) * w / (d * d);
+        delta[i] *= derivative;
+    }
+}
+
 __global__ void gradient_array_leaky_kernel(float *x, int n, float *delta)
 {
     int index = blockIdx.x*blockDim.x + threadIdx.x;
@@ -333,6 +355,13 @@ extern "C" void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu)
     CHECK_CUDA(cudaPeekAtLastError());
 }
 
+extern "C" void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    activate_array_mish_kernel <<<num_blocks, BLOCK, 0, get_cuda_stream()>>> (x, n, activation_input_gpu, output_gpu);
+    CHECK_CUDA(cudaPeekAtLastError());
+}
+
 extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta)
 {
     const int num_blocks = get_number_of_blocks(n, BLOCK);
@@ -354,4 +383,11 @@ extern "C" void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta)
     const int num_blocks = get_number_of_blocks(n, BLOCK);
     gradient_array_swish_kernel <<<num_blocks, BLOCK, 0, get_cuda_stream()>>> (x, n, sigmoid_gpu, delta);
     CHECK_CUDA(cudaPeekAtLastError());
+}
+
+extern "C" void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta)
+{
+    const int num_blocks = get_number_of_blocks(n, BLOCK);
+    gradient_array_mish_kernel <<<num_blocks, BLOCK, 0, get_cuda_stream()>>> (n, activation_input_gpu, delta);
+    CHECK_CUDA(cudaPeekAtLastError());
 }
\ No newline at end of file
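For reference, `gradient_array_mish_kernel` above applies the closed-form derivative of mish(x) = x · tanh(softplus(x)) rather than backpropagating through `tanh` and `softplus` separately. The `d` and `w` temporaries in the kernel are the δ and ω of the standard identity:

```latex
\operatorname{mish}(x)=x\tanh\!\big(\ln(1+e^{x})\big),\qquad
\operatorname{mish}'(x)=\frac{e^{x}\,\omega}{\delta^{2}},\qquad
\omega=4(x+1)+4e^{2x}+e^{3x}+e^{x}(4x+6),\qquad
\delta=2e^{x}+e^{2x}+2.
```

One caveat: `log(1 + expf(x))` in the forward kernel and the `expf(2*x)`/`expf(3*x)` terms in the gradient can overflow single-precision `expf` (it saturates to `inf` above x ≈ 88, so `expf(3*x)` already overflows near x ≈ 30), producing `inf`/`NaN` for unusually large pre-activations. A thresholded softplus is a common guard; a minimal sketch, with the helper name and threshold purely illustrative and not part of this patch:

```c
/* Illustrative only -- not part of this patch. A thresholded softplus that
 * avoids expf() overflow for large x and uses log(1 + e^x) ~= e^x for very
 * negative x, where the direct formula loses precision. */
static inline float softplus_stable(float x, float threshold)
{
    if (x > threshold) return x;              /* log(1+e^x) ~= x   */
    else if (x < -threshold) return expf(x);  /* log(1+e^x) ~= e^x */
    return logf(expf(x) + 1);
}
```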
(strcmp(s, "logistic")==0) return LOGISTIC; if (strcmp(s, "swish") == 0) return SWISH; + if (strcmp(s, "mish") == 0) return MISH; if (strcmp(s, "loggy")==0) return LOGGY; if (strcmp(s, "relu")==0) return RELU; if (strcmp(s, "elu")==0) return ELU; @@ -133,6 +134,17 @@ void activate_array_swish(float *x, const int n, float * output_sigmoid, float * } } +void activate_array_mish(float *x, const int n, float * activation_input, float * output) +{ + int i; + #pragma omp parallel for + for (i = 0; i < n; ++i) { + float x_val = x[i]; + activation_input[i] = x_val; // store value before activation + output[i] = x_val * tanh_activate(log(1 + expf(x_val))); + } +} + float gradient(float x, ACTIVATION a) { switch(a){ @@ -187,3 +199,16 @@ void gradient_array_swish(const float *x, const int n, const float * sigmoid, fl delta[i] *= swish + sigmoid[i]*(1 - swish); } } + +void gradient_array_mish(const int n, const float * activation_input, float * delta) +{ + int i; + #pragma omp parallel for + for (i = 0; i < n; ++i) { + float x = activation_input[i]; + float d = 2 * expf(x) + expf(2 * x) + 2; + float w = 4 * (x + 1) + 4 * expf(2 * x) + expf(3 * x) + expf(x)*(4 * x + 6); + float derivative = expf(x) * w / (d * d); + delta[i] *= derivative; + } +} diff --git a/src/activations.h b/src/activations.h index 19f3822c8e3..bba5ca8d10a 100644 --- a/src/activations.h +++ b/src/activations.h @@ -5,7 +5,7 @@ #include "math.h" //typedef enum{ -// LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU +// LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH //}ACTIVATION; #ifdef __cplusplus @@ -18,13 +18,17 @@ float activate(float x, ACTIVATION a); float gradient(float x, ACTIVATION a); void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta); void gradient_array_swish(const float *x, const int n, const float * sigmoid, float * delta); +void gradient_array_mish(const int n, const float * activation_input, float * delta); void activate_array(float *x, const int n, const ACTIVATION a); void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output); +void activate_array_mish(float *x, const int n, float * activation_input, float * output); #ifdef GPU void activate_array_ongpu(float *x, int n, ACTIVATION a); void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu); +void activate_array_mish_ongpu(float *x, int n, float *activation_input_gpu, float *output_gpu); void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta); void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta); +void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta); #endif static inline float stair_activate(float x) diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu index 0b94dd29db4..a73f277ee92 100644 --- a/src/convolutional_kernels.cu +++ b/src/convolutional_kernels.cu @@ -392,7 +392,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state) */ //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h); - if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu); + if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu); + else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, 
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 0b94dd29db4..a73f277ee92 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -392,7 +392,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
         */
         //add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
 
-        if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+        if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
+        else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
         else if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
         //if(l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
         //if (l.binary || l.xnor) swap_binary(&l);
@@ -596,7 +597,8 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 //#ifndef CUDNN_HALF
 //#endif // no CUDNN_HALF
 
-    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.output_gpu);
+    if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
+    else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
     else if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
     if(l.binary || l.xnor) swap_binary(&l);
@@ -639,7 +641,8 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 
     if(state.net.try_fix_nan) constrain_ongpu(l.outputs*l.batch, 1, l.delta_gpu, 1);
 
-    if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.output_sigmoid_gpu, l.delta_gpu);
+    if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
+    else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
     else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
 
     if (!l.batch_normalize)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index bf5beac7972..b76d7ee735f 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -473,10 +473,10 @@ convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w,
         l.scale_v = (float*)calloc(n, sizeof(float));
     }
 
-    if(l.activation == SWISH) l.output_sigmoid = (float*)calloc(total_batch*l.outputs, sizeof(float));
+    if (l.activation == SWISH || l.activation == MISH) l.activation_input = (float*)calloc(total_batch*l.outputs, sizeof(float));
 
 #ifdef GPU
-    if (l.activation == SWISH) l.output_sigmoid_gpu = cuda_make_array(l.output_sigmoid, total_batch*out_h*out_w*n);
+    if (l.activation == SWISH || l.activation == MISH) l.activation_input_gpu = cuda_make_array(l.activation_input, total_batch*out_h*out_w*n);
 
     l.forward_gpu = forward_convolutional_layer_gpu;
     l.backward_gpu = backward_convolutional_layer_gpu;
@@ -1100,7 +1100,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
             add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
             //activate_array(l.output, m*n*l.batch, l.activation);
-            if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+            if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
+            else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
             else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
 
             return;
@@ -1139,7 +1140,8 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
     //activate_array(l.output, m*n*l.batch, l.activation);
-    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.output);
+    if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
+    else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
     else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
 
     if(l.binary || l.xnor) swap_binary(&l);
@@ -1276,7 +1278,8 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)
     int n = l.size*l.size*l.c / l.groups;
     int k = l.out_w*l.out_h;
 
-    if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.output_sigmoid, l.delta);
+    if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.delta);
+    else if (l.activation == MISH) gradient_array_mish(l.outputs*l.batch, l.activation_input, l.delta);
     else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
 
     if (l.batch_normalize) {
diff --git a/src/layer.c b/src/layer.c
index e9ae67b5ff5..9fe4a439364 100644
--- a/src/layer.c
+++ b/src/layer.c
@@ -90,7 +90,7 @@ void free_layer(layer l)
 #endif  // GPU
     if (l.delta) free(l.delta), l.delta = NULL;
     if (l.output) free(l.output), l.output = NULL;
-    if (l.output_sigmoid) free(l.output_sigmoid), l.output_sigmoid = NULL;
+    if (l.activation_input) free(l.activation_input), l.activation_input = NULL;
     if (l.squared) free(l.squared);
     if (l.norms) free(l.norms);
     if (l.spatial_mean) free(l.spatial_mean);
@@ -176,7 +176,7 @@ void free_layer(layer l)
     if (l.scale_updates_gpu) cuda_free(l.scale_updates_gpu), l.scale_updates_gpu = NULL;
     if (l.input_antialiasing_gpu) cuda_free(l.input_antialiasing_gpu), l.input_antialiasing_gpu = NULL;
     if (l.output_gpu) cuda_free(l.output_gpu), l.output_gpu = NULL;
-    if (l.output_sigmoid_gpu) cuda_free(l.output_sigmoid_gpu), l.output_sigmoid_gpu = NULL;
+    if (l.activation_input_gpu) cuda_free(l.activation_input_gpu), l.activation_input_gpu = NULL;
     if (l.delta_gpu) cuda_free(l.delta_gpu), l.delta_gpu = NULL;
     if (l.rand_gpu) cuda_free(l.rand_gpu);
     if (l.squared_gpu) cuda_free(l.squared_gpu);
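With the `get_activation` parsing added in src/activations.c, the new activation can be selected from a network `.cfg` file like any other; no further configuration is needed, since `make_convolutional_layer` allocates `activation_input` automatically when `l.activation == MISH`. A minimal illustrative layer section (the filter count, kernel size, and use of batch normalization are example values, not mandated by the patch):

```
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=mish
```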