diff --git a/src/nbla/cuda/cudnn/function/generic/gru.cu b/src/nbla/cuda/cudnn/function/generic/gru.cu index 94eadcacd..8732c6b9f 100644 --- a/src/nbla/cuda/cudnn/function/generic/gru.cu +++ b/src/nbla/cuda/cudnn/function/generic/gru.cu @@ -463,7 +463,7 @@ void GRUCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); } diff --git a/src/nbla/cuda/cudnn/function/generic/lstm.cu b/src/nbla/cuda/cudnn/function/generic/lstm.cu index a63b08c32..128d9df3a 100644 --- a/src/nbla/cuda/cudnn/function/generic/lstm.cu +++ b/src/nbla/cuda/cudnn/function/generic/lstm.cu @@ -450,7 +450,7 @@ void LSTMCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); outputs[2]->reshape(inputs[2]->shape(), true); diff --git a/src/nbla/cuda/cudnn/function/generic/rnn.cu b/src/nbla/cuda/cudnn/function/generic/rnn.cu index ffe933ae2..3ce66ae08 100644 --- a/src/nbla/cuda/cudnn/function/generic/rnn.cu +++ b/src/nbla/cuda/cudnn/function/generic/rnn.cu @@ -433,7 +433,7 @@ void RNNCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); }