Skip to content

Commit

Permalink
Merge pull request #158 from sony/fix/20190524-lstm-memory
Browse files Browse the repository at this point in the history
[Fix] LSTM/GRU Output Shape
  • Loading branch information
TakuyaNarihira authored May 28, 2019
2 parents 4af072c + 1e5fae1 commit f5a5403
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/gru.cu
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ void GRUCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
}
Expand Down
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/lstm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ void LSTMCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
outputs[2]->reshape(inputs[2]->shape(), true);
Expand Down
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/rnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ void RNNCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
}
Expand Down

0 comments on commit f5a5403

Please sign in to comment.