
recurrent examples fails with "TypeError: pad_sequence(): sequences must be tuple of Tensors, not Tensor" #154

Closed
apersonnaz opened this issue Aug 24, 2021 · 9 comments · Fixed by #163


@apersonnaz

Hello,

I am trying to implement a recurrent network for a DDQL implementation, so I did the same as the DRQL example, but my implementation, like the example, crashes with:

File "/home/aurelien/.local/lib/python3.8/site-packages/torch/nn/utils/rnn.py", line 363, in pad_sequence
return torch._C._nn.pad_sequence(sequences, batch_first, padding_value) TypeError: pad_sequence(): argument 'sequences' (position 1) must be tuple of Tensors, not Tensor

In batch_states.py, line 33, default_collate returns a tensor (it is called first on a NumPy array and then recursively on the returned tensor).

The full trace is:

File "train_deep_agent.py", line 184, in
action = agent.act(observation)
File "/home/aurelien/.local/lib/python3.8/site-packages/pfrl/agent.py", line 161, in act
return self.batch_act([obs])[0]
File "/home/aurelien/.local/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 490, in batch_act
batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
File "/home/aurelien/.local/lib/python3.8/site-packages/pfrl/agents/dqn.py", line 477, in _evaluate_model_and_update_recurrent_states
batch_av, self.train_recurrent_states = one_step_forward(
File "/home/aurelien/.local/lib/python3.8/site-packages/pfrl/utils/recurrent.py", line 149, in one_step_forward
pack = pack_one_step_batch_as_sequences(batch_input)
File "/home/aurelien/.local/lib/python3.8/site-packages/pfrl/utils/recurrent.py", line 125, in pack_one_step_batch_as_sequences
return nn.utils.rnn.pack_sequence(xs[:, None])
File "/home/aurelien/.local/lib/python3.8/site-packages/torch/nn/utils/rnn.py", line 398, in pack_sequence
return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
File "/home/aurelien/.local/lib/python3.8/site-packages/torch/nn/utils/rnn.py", line 363, in pad_sequence
return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
TypeError: pad_sequence(): argument 'sequences' (position 1) must be tuple of Tensors, not Tensor
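
A minimal sketch that reproduces the same TypeError outside PFRL, assuming torch>=1.9 (the tensor shapes are arbitrary):

import torch
import torch.nn as nn

# A batched observation tensor, like the one default_collate returns in batch_states.py.
xs = torch.zeros(4, 8)

# xs[:, None] has shape (4, 1, 8) but is still a single Tensor, not a list of
# Tensors, so torch>=1.9's pad_sequence (called inside pack_sequence) rejects it.
nn.utils.rnn.pack_sequence(xs[:, None])
# TypeError: pad_sequence(): argument 'sequences' (position 1) must be tuple of Tensors, not Tensor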

I have checked, and the PPO example fails with the same error when started with --recurrent.

Am I missing something?

@apersonnaz
Author

I am using Python 3.8.10, torch 1.9.0, and pfrl 0.3.0.

@muupan
Member

muupan commented Aug 24, 2021

Thanks for reporting this. I was able to reproduce it. python examples/atari/train_ppo_ale.py --recurrent works fine with torch==1.8.1 but fails with torch==1.9.0. This should be fixed, but as a workaround you can use torch<1.9.0.

$ python examples/atari/train_ppo_ale.py --recurrent
Output files are saved in results/7b0c7e938ba2c0c56a941c766c68635d0dad43c8-00000000-7d123117
Observation space Box(0, 255, (4, 84, 84), uint8)
Action space Discrete(4)
INFO:pfrl.experiments.train_agent_batch:Saved the agent to results/7b0c7e938ba2c0c56a941c766c68635d0dad43c8-00000000-7d123117/0_except
Traceback (most recent call last):
  File "/home/fujita/pfrl/examples/atari/train_ppo_ale.py", line 333, in <module>
    main()
  File "/home/fujita/pfrl/examples/atari/train_ppo_ale.py", line 316, in main
    experiments.train_agent_batch_with_evaluation(
  File "/home/fujita/pfrl/pfrl/experiments/train_agent_batch.py", line 247, in train_agent_batch_with_evaluation
    eval_stats_history = train_agent_batch(
  File "/home/fujita/pfrl/pfrl/experiments/train_agent_batch.py", line 67, in train_agent_batch
    actions = agent.batch_act(obss)
  File "/home/fujita/pfrl/pfrl/agents/ppo.py", line 678, in batch_act
    return self._batch_act_train(batch_obs)
  File "/home/fujita/pfrl/pfrl/agents/ppo.py", line 731, in _batch_act_train
    ) = one_step_forward(
  File "/home/fujita/pfrl/pfrl/utils/recurrent.py", line 149, in one_step_forward
    pack = pack_one_step_batch_as_sequences(batch_input)
  File "/home/fujita/pfrl/pfrl/utils/recurrent.py", line 125, in pack_one_step_batch_as_sequences
    return nn.utils.rnn.pack_sequence(xs[:, None])
  File "/home/fujita/.local/share/virtualenvs/pfrl-pL7Y2GAq/lib/python3.9/site-packages/torch/nn/utils/rnn.py", line 398, in pack_sequence
    return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
  File "/home/fujita/.local/share/virtualenvs/pfrl-pL7Y2GAq/lib/python3.9/site-packages/torch/nn/utils/rnn.py", line 363, in pad_sequence
    return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
TypeError: pad_sequence(): argument 'sequences' (position 1) must be tuple of Tensors, not Tensor

@apersonnaz
Author

Thank you so much! It is indeed working fine with torch==1.8.1.

BTW, on another point: it seems to me that we always have only two stacked states in train_recurrent_states in dqn.py. I could not find where to configure the number of states stacked for the recurrent layers. Is it possible to have more?

@muupan
Member

muupan commented Aug 25, 2021

self.train_recurrent_states contains the information needed for inference at the next step t+1: h_t and c_t in the case of LSTM. DQN does not store older states, as they are no longer needed for inference.

If you need past recurrent states, e.g. for updating the model, you can use the recurrent states stored in the replay buffer (pfrl/pfrl/agents/dqn.py, lines 532 to 542 at 7b0c7e9):

if self.recurrent:
    transition["recurrent_state"] = recurrent_state_as_numpy(
        get_recurrent_state_at(
            self.train_prev_recurrent_states, i, detach=True
        )
    )
    transition["next_recurrent_state"] = recurrent_state_as_numpy(
        get_recurrent_state_at(
            self.train_recurrent_states, i, detach=True
        )
    )
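
A minimal sketch of what that state pair looks like for a plain nn.LSTM (the sizes are arbitrary, not PFRL's):

import torch
import torch.nn as nn

# For nn.LSTM the recurrent state is (h_n, c_n), each of shape
# (num_layers, batch_size, hidden_size).
lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1)
x = torch.randn(5, 3, 16)  # (seq_len, batch, input_size)
out, (h_n, c_n) = lstm(x)
print(h_n.shape, c_n.shape)  # torch.Size([1, 3, 32]) twice

# The next step continues from (h_n, c_n) alone; older h_t/c_t are not needed.
x_next = torch.randn(1, 3, 16)
out_next, (h_next, c_next) = lstm(x_next, (h_n, c_n))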

@elbamos

elbamos commented Nov 16, 2021

I'm seeing this as well with torch 1.10 - is it planned to be fixed?

@douglascvas

Any fixes? Problem is still happening :(

@muupan
Member

muupan commented Jan 16, 2022

Hopefully #163 will address the TypeError issue.

@weixians

weixians commented Feb 7, 2022

I found this problem can be fixed by replacing "nn.utils.rnn.pack_sequence(xs[:, None])" with "nn.utils.rnn.pack_sequence([xs])".
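
For reference, a variant that keeps the original slicing's per-element intent (batch_size length-1 sequences rather than one length-batch_size sequence) would be to materialize the slices as a list. A sketch assuming the shape of pack_one_step_batch_as_sequences in pfrl/utils/recurrent.py, not necessarily the fix merged in #163:

import torch.nn as nn

def pack_one_step_batch_as_sequences(xs):
    # xs[:, None] has shape (batch_size, 1, ...); list(...) unbinds it into
    # batch_size tensors of shape (1, ...), which satisfies torch>=1.9's
    # requirement that pack_sequence receive a list or tuple of Tensors.
    return nn.utils.rnn.pack_sequence(list(xs[:, None]))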

@wansongying

> I found this problem can be fixed by replacing "nn.utils.rnn.pack_sequence(xs[:, None])" with "nn.utils.rnn.pack_sequence([xs])".

So coooool! It works!
