WIP
tae898 committed Jul 27, 2023
1 parent 2ea1fd6 commit cae05ca
Showing 3 changed files with 37 additions and 21 deletions.
21 changes: 17 additions & 4 deletions train.py
@@ -430,11 +430,24 @@ def td_loss(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
         )
 
         with torch.no_grad():
-            next_state_values = self.target_net(next_states).max(1)[0]
-            next_state_values[dones] = 0.0
-            next_state_values = next_state_values.detach()
+            if self.hparams.dqn_type.lower() == "single":
+                next_state_action_values = self.target_net(next_states).max(1)[0]
+            elif self.hparams.dqn_type.lower() == "double":
+                actions_ = self.net(next_states).argmax(dim=1)
+                next_state_action_values = (
+                    self.target_net(next_states)
+                    .gather(1, actions_.long().unsqueeze(-1))
+                    .squeeze(-1)
+                )
+            else:
+                raise ValueError
+
+            next_state_action_values[dones] = 0.0
+            next_state_action_values = next_state_action_values.detach()
 
-        expected_state_action_values = next_state_values * self.hparams.gamma + rewards
+        expected_state_action_values = (
+            next_state_action_values * self.hparams.gamma + rewards
+        )
 
         if self.hparams.loss_function.lower() == "mse":
             return nn.MSELoss()(state_action_values, expected_state_action_values)
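Note: the new "double" branch implements double DQN (van Hasselt et al., 2016): the online network selects the greedy next action and the target network evaluates it, which reduces the value overestimation of single DQN. A minimal standalone sketch of the same target computation; the helper name and signature are illustrative, not part of this commit:

import torch
from torch import Tensor


def td_targets(
    net, target_net, next_states: Tensor, rewards: Tensor,
    dones: Tensor, gamma: float, dqn_type: str = "double",
) -> Tensor:
    """Illustrative helper mirroring the diff above; not part of the repo."""
    with torch.no_grad():
        if dqn_type == "single":
            # Single DQN: the target network both selects and evaluates.
            next_q = target_net(next_states).max(1)[0]
        else:
            # Double DQN: the online net selects, the target net evaluates.
            best_actions = net(next_states).argmax(dim=1)
            next_q = (
                target_net(next_states)
                .gather(1, best_actions.unsqueeze(-1))
                .squeeze(-1)
            )
        next_q[dones] = 0.0  # no bootstrapping from terminal states
    return rewards + gamma * next_q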
1 change: 1 addition & 0 deletions train.yaml
@@ -39,6 +39,7 @@ log_every_n_steps: 1
 early_stopping_patience: 1000 # number of epochs, not episodes! Atm, I don't do this.
 precision: 32
 accelerator: cpu
+dqn_type: double
 # num_steps_per_epoch = ceil(epoch_length / batch_size)
 # total_number_of_steps = num_steps_per_epoch * max_epochs
 # total_number_of_episodes = total_number_of_steps / last_time_step
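Note: train.py reads this new key as self.hparams.dqn_type (see the hunk above). The commit does not show how the YAML reaches hparams; a plausible sketch, assuming a standard yaml.safe_load loader:

import yaml

with open("train.yaml") as f:
    train_config = yaml.safe_load(f)

# td_loss accepts "single" or "double" (case-insensitive) and raises otherwise.
assert train_config["dqn_type"].lower() in ("single", "double")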
36 changes: 19 additions & 17 deletions train_multiple.py
@@ -48,32 +48,34 @@
     "early_stopping_patience": 1000,
     "precision": 32,
     "accelerator": "cpu",
+    "dqn_type": "double",
 }
 
 commands = []
-num_parallel = 4
+num_parallel = 2
 reverse = False
 os.makedirs("./junks", exist_ok=True)
 
-for capacity in [2, 4, 8, 16, 32, 64]:
-    for pretrain_semantic in [False, True]:
-        for seed in [0, 1, 2, 3, 4]:
-            train_config["question_prob"] = 0.5
-            train_config["capacity"] = {
-                "episodic": capacity // 2,
-                "semantic": capacity // 2,
-                "short": 1,
-            }
-            train_config["pretrain_semantic"] = pretrain_semantic
-            train_config["seed"] = seed
+# for capacity in [2, 4, 8, 16, 32, 64]:
+#     for pretrain_semantic in [False, True]:
+#         for seed in [0, 1, 2, 3, 4]:
+#             train_config["question_prob"] = 0.5
+#             train_config["capacity"] = {
+#                 "episodic": capacity // 2,
+#                 "semantic": capacity // 2,
+#                 "short": 1,
+#             }
+#             train_config["pretrain_semantic"] = pretrain_semantic
+#             train_config["seed"] = seed
 
-            config_file_name = (
-                f"./junks/{str(datetime.datetime.now()).replace(' ', '-')}.yaml"
-            )
+#             config_file_name = (
+#                 f"./junks/{str(datetime.datetime.now()).replace(' ', '-')}.yaml"
+#             )
 
-            write_yaml(train_config, config_file_name)
+#             write_yaml(train_config, config_file_name)
 
-            # commands.append(f"python train.py --config {config_file_name}")
 
+commands.append(f"python train.py --config {config_file_name}")
+
 for capacity in [2, 4, 8, 16, 32, 64]:
     for pretrain_semantic in [False, True]:
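Note: this hunk does not show how commands and num_parallel are consumed (the runner sits below the fold). A minimal sketch of one common pattern; the thread-pool executor is an assumption, and only the names commands and num_parallel come from the file:

import subprocess
from concurrent.futures import ThreadPoolExecutor


def run(command: str) -> int:
    # Each entry is a full shell command, e.g. "python train.py --config <file>".
    return subprocess.run(command, shell=True).returncode


# num_parallel is 2 after this commit; commands is built by the loops above.
with ThreadPoolExecutor(max_workers=num_parallel) as pool:
    exit_codes = list(pool.map(run, commands))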
