Skip to content

Commit

Permalink
Add support for LSTM timing mode=False (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrianLundell authored Feb 8, 2024
1 parent 601d96c commit 2a999a2
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions Source/LSTMFunctions/arm_lstm_unidirectional_s8.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
* Title: arm_lstm_unidirectional_s8.c
* Description: S8 LSTM function with S16 gate output
*
* $Date: 19 January 2024
* $Revision: V.1.0.0
* $Date: 08 February 2024
* $Revision: V.1.1.0
*
* Target Processor: Cortex-M processors
*
Expand Down Expand Up @@ -54,24 +54,39 @@ arm_cmsis_nn_status arm_lstm_unidirectional_s8(const int8_t *input,

int8_t *hidden_in = NULL;
memset(buffers->cell_state, 0, params->batch_size * params->hidden_size * sizeof(int16_t));
const int32_t batch_offset = (params->time_major) ? 1 : params->time_steps;

for (int t = 0; t < params->time_steps; t++)
if (params->time_major)
{
const int8_t *data_in = input + (t * params->batch_size * params->input_size);
int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size);

arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, batch_offset);

if (status != ARM_CMSIS_NN_SUCCESS)
// First dimension is time, input/output for each time step is stored continously in memory
for (int t = 0; t < params->time_steps; t++)
{
return status;
const int8_t *data_in = input + (t * params->batch_size * params->input_size);
int8_t *hidden_out = output + (t * params->batch_size * params->hidden_size);
arm_cmsis_nn_status status = arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS)
{
return status;
}
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}
}
else
{
// First dimension is time, add batch_offset to jump in memory for each batch
for (int t = 0; t < params->time_steps; t++)
{
const int8_t *data_in = input + (t * params->input_size);
int8_t *hidden_out = output + (t * params->hidden_size);
arm_cmsis_nn_status status =
arm_nn_lstm_step_s8(data_in, hidden_in, hidden_out, params, buffers, params->time_steps);
if (status != ARM_CMSIS_NN_SUCCESS)
{
return status;
}
// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}

// Output is used as recurrent input/hidden state for the next timestep.
hidden_in = &hidden_out[0];
}

return ARM_CMSIS_NN_SUCCESS;
}

Expand Down

0 comments on commit 2a999a2

Please sign in to comment.