Skip to content

Commit

Permalink
fix lstm memory shape
Browse files Browse the repository at this point in the history
  • Loading branch information
TE-andrewshin committed May 24, 2019
1 parent 3cb68ed commit 1e5fae1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/gru.cu
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ void GRUCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
}
Expand Down
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/lstm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ void LSTMCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
outputs[2]->reshape(inputs[2]->shape(), true);
Expand Down
2 changes: 1 addition & 1 deletion src/nbla/cuda/cudnn/function/generic/rnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ void RNNCudaCudnn<T>::setup_impl(const Variables &inputs,
}

// Set output shapes
outputs[0]->reshape({seq_len_, batch_size, hidden_size_, num_directions_},
outputs[0]->reshape({seq_len_, batch_size, num_directions_ * hidden_size_},
true);
outputs[1]->reshape(inputs[1]->shape(), true);
}
Expand Down

0 comments on commit 1e5fae1

Please sign in to comment.