From 1e5fae14f20e2457ae4bc8573c279d7a59fa43f1 Mon Sep 17 00:00:00 2001 From: andrewshinsony Date: Fri, 24 May 2019 04:42:38 +0000 Subject: [PATCH] fix lstm memory shape --- src/nbla/cuda/cudnn/function/generic/gru.cu | 2 +- src/nbla/cuda/cudnn/function/generic/lstm.cu | 2 +- src/nbla/cuda/cudnn/function/generic/rnn.cu | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nbla/cuda/cudnn/function/generic/gru.cu b/src/nbla/cuda/cudnn/function/generic/gru.cu index 94eadcacd..8732c6b9f 100644 --- a/src/nbla/cuda/cudnn/function/generic/gru.cu +++ b/src/nbla/cuda/cudnn/function/generic/gru.cu @@ -463,7 +463,7 @@ void GRUCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); } diff --git a/src/nbla/cuda/cudnn/function/generic/lstm.cu b/src/nbla/cuda/cudnn/function/generic/lstm.cu index a63b08c32..128d9df3a 100644 --- a/src/nbla/cuda/cudnn/function/generic/lstm.cu +++ b/src/nbla/cuda/cudnn/function/generic/lstm.cu @@ -450,7 +450,7 @@ void LSTMCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); outputs[2]->reshape(inputs[2]->shape(), true); diff --git a/src/nbla/cuda/cudnn/function/generic/rnn.cu b/src/nbla/cuda/cudnn/function/generic/rnn.cu index ffe933ae2..3ce66ae08 100644 --- a/src/nbla/cuda/cudnn/function/generic/rnn.cu +++ b/src/nbla/cuda/cudnn/function/generic/rnn.cu @@ -433,7 +433,7 @@ void RNNCudaCudnn::setup_impl(const Variables &inputs, } // Set output shapes - outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_}, + outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_}, true); outputs[1]->reshape(inputs[1]->shape(), true); }