From 2a999a2fd887c98042353accac77479f00b5f99d Mon Sep 17 00:00:00 2001 From: Adrian Lundell <36153706+AdrianLundell@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:51:49 +0100 Subject: [PATCH] Add support for LSTM timing mode=False (#104) --- .../arm_lstm_unidirectional_s8.c | 47 ++++++++++++------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c index c2868b3a..86f404e9 100644 --- a/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c +++ b/Source/LSTMFunctions/arm_lstm_unidirectional_s8.c @@ -21,8 +21,8 @@ * Title: arm_lstm_unidirectional_s8.c * Description: S8 LSTM function with S16 gate output * - * $Date: 19 January 2024 - * $Revision: V.1.0.0 + * $Date: 08 February 2024 + * $Revision: V.1.1.0 * * Target Processor: Cortex-M processors * @@ -54,24 +54,39 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input, int8_t *hidden_in = NULL; memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t)); - const int32_t batch_offset = (params->time_major) ? 
1 : params->time_steps; - - for (int t = 0; t < params->time_steps; t++) + if (params->time_major) { - const int8_t *data_in = input + (t * params->batch_size * params->input_size); - int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size); - - arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, batch_offset); - - if (status != ARM_CMSIS_NN_SUCCESS) + // First dimension is time, input/output for each time step is stored contiguously in memory + for (int t = 0; t < params->time_steps; t++) { - return status; + const int8_t *data_in = input + (t * params->batch_size * params->input_size); + int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size); + arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, 1); + if (status != ARM_CMSIS_NN_SUCCESS) + { + return status; + } + // Output is used as recurrent input/hidden state for the next timestep. + hidden_in = &hidden_out[0]; + } + } + else + { + // First dimension is batch, pass time_steps as the batch offset to jump in memory for each batch + for (int t = 0; t < params->time_steps; t++) + { + const int8_t *data_in = input + (t * params->input_size); + int8_t *hidden_out = output + (t * params->hidden_size); + arm_cmsis_nn_status status = + arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps); + if (status != ARM_CMSIS_NN_SUCCESS) + { + return status; + } + // Output is used as recurrent input/hidden state for the next timestep. + hidden_in = &hidden_out[0]; } - - // Output is used as recurrent input/hidden state for the next timestep. - hidden_in = &hidden_out[0]; } - return ARM_CMSIS_NN_SUCCESS; }