Fix/terminated truncated #252

Merged
merged 49 commits on Apr 2, 2024

Commits (49)
3711d05
Decoupled RSSM for DV3 agent
belerico Feb 8, 2024
e80e9d5
Initialize posterior with prior if is_first is True
belerico Feb 8, 2024
b23112a
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 12, 2024
f47b8f9
Fix PlayerDV3 creation in evaluation
belerico Feb 12, 2024
e42c83d
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 26, 2024
2ec4fbb
Fix representation_model
belerico Feb 26, 2024
3a5380b
Fix compute first prior state with a zero posterior
belerico Feb 27, 2024
42d9433
DV3 replay ratio conversion
belerico Feb 29, 2024
750f671
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 29, 2024
b06433b
Removed expl parameters dependent on old per_Rank_gradient_steps
belerico Feb 29, 2024
20cc43e
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Mar 4, 2024
37d0e86
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 18, 2024
704b0ce
feat: update repeats computation
michele-milesi Mar 18, 2024
20905f0
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 28, 2024
e1290ee
feat: update learning starts in config
michele-milesi Mar 28, 2024
1f0c0ef
fix: remove files
michele-milesi Mar 28, 2024
cd4a4c4
feat: update repeats
michele-milesi Mar 28, 2024
b17d451
Let Dv3 compute bootstrap correctly
belerico Mar 28, 2024
e8c9049
feat: added replay ratio and update exploration
michele-milesi Mar 28, 2024
88c6968
Fix exploration actions computation on DV1
belerico Mar 28, 2024
a5c957c
Fix naming
belerico Mar 28, 2024
c36577d
Add replay-ratio to SAC
belerico Mar 28, 2024
0bc9f07
feat: added replay ratio to p2e algos
michele-milesi Mar 28, 2024
b5fbe5d
feat: update configs and utils of p2e algos
michele-milesi Mar 28, 2024
24c9352
Add replay-ratio to SAC-AE
belerico Mar 28, 2024
a11b558
Merge branch 'feature/replay-ratio' of https://github.com/Eclectic-Sh…
belerico Mar 28, 2024
32b89b4
Add DrOQ replay ratio
belerico Mar 29, 2024
d057886
Fix tests
belerico Mar 29, 2024
b9044a3
Fix mispelled
belerico Mar 29, 2024
5bd7d75
Fix wrong attribute accesing
belerico Mar 29, 2024
8d94f68
FIx naming and configs
belerico Mar 29, 2024
cae85a3
Merge branch 'fix/dv3-continue-on-terminated' of github.com:Eclectic-…
michele-milesi Mar 29, 2024
e5dd8fd
feat: add terminated and truncated to dreamer, p2e and ppo algos
michele-milesi Mar 29, 2024
fdd4579
fix: dmc wrapper
michele-milesi Mar 29, 2024
a2a2690
feat: update algos to split terminated from truncated
michele-milesi Mar 29, 2024
74bfb6b
fix: crafter and diambra wrappers
michele-milesi Mar 29, 2024
3d1f2c9
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into fix/ter…
michele-milesi Mar 30, 2024
05e4370
feat: replace done with truncated key in when the buffer is added to …
michele-milesi Mar 30, 2024
87c9098
feat: added truncated/terminated to minedojo environment
michele-milesi Mar 30, 2024
e137a38
feat: added terminated/truncated to minerl and super mario bros envs
michele-milesi Apr 2, 2024
b557835
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into fix/ter…
michele-milesi Apr 2, 2024
64d3c81
docs: update howto
michele-milesi Apr 2, 2024
2e156f3
fix: minedojo wrapper
michele-milesi Apr 2, 2024
0167fd5
docs: update
michele-milesi Apr 2, 2024
09e051e
fix: minedojo
michele-milesi Apr 2, 2024
dacd425
update dependencies
michele-milesi Apr 2, 2024
f2557a3
fix: minedojo
michele-milesi Apr 2, 2024
5bf50dd
fix: dv3 small configs
michele-milesi Apr 2, 2024
f58a3c2
fix: episode buffer and tests
michele-milesi Apr 2, 2024
5 changes: 5 additions & 0 deletions howto/learn_in_minedojo.md
@@ -62,9 +62,14 @@ Moreover, we restrict the look-up/down actions between `min_pitch` and `max_pitch` …
In addition, we added the forward action when the agent selects one of the following actions: `jump`, `sprint`, and `sneak`.
Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` cli arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps.

> [!NOTE]
>
> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`.

For more information about the MineDojo action space, check [here](https://docs.minedojo.org/sections/core_api/action_space.html).

> [!NOTE]
>
> Since the MineDojo environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously.
>
> The action repeat in the Minecraft environments is set to 1: it makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action).
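
As an illustration of the sticky-action mechanism described above, here is a minimal sketch of how a wrapper could repeat `jump` and `attack` for a fixed number of steps in a multi-discrete action space. The class name, the action indices (2 for jump, 5 for attack), and the attack value `3` are assumptions made for the example, not SheepRL's actual implementation.

```python
import numpy as np

class StickyActionsSketch:
    """Illustrative sticky-action logic for a multi-discrete action vector."""

    def __init__(self, sticky_jump: int = 10, sticky_attack: int = 30):
        self.sticky_jump = sticky_jump      # how many steps to keep jumping
        self.sticky_attack = sticky_attack  # how many steps to keep attacking
        self._jump_left = 0
        self._attack_left = 0

    def apply(self, action: np.ndarray) -> np.ndarray:
        action = action.copy()
        # Restart the counters whenever the agent itself selects the action
        if self.sticky_jump > 0 and action[2] == 1:
            self._jump_left = self.sticky_jump
        if self.sticky_attack > 0 and action[5] == 3:
            self._attack_left = self.sticky_attack
        # While a counter is active, force the action even if it was not selected;
        # the multi-discrete space lets it compose with the rest of the chosen action.
        if self._jump_left > 0:
            action[2] = 1
            self._jump_left -= 1
        if self._attack_left > 0:
            action[5] = 3
            self._attack_left -= 1
        return action
```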
5 changes: 5 additions & 0 deletions howto/learn_in_minerl.md
@@ -47,10 +47,15 @@ In addition, we added the forward action when the agent selects one of the following …
Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps.

> [!NOTE]
>
> Since the MineRL environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously.
>
> The action repeat in the Minecraft environments is set to 1: it makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action).

> [!NOTE]
>
> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`.

## Headless machines

If you work on a headless machine, you need a software renderer. We recommend adopting one of the following solutions:
1 change: 0 additions & 1 deletion howto/logs_and_checkpoints.md
@@ -122,7 +122,6 @@ AGGREGATOR_KEYS = {
"State/post_entropy",
"State/prior_entropy",
"State/kl",
"Params/exploration_amount",
"Grads/world_model",
"Grads/actor",
"Grads/critic",
1 change: 1 addition & 0 deletions howto/select_observations.md
@@ -8,6 +8,7 @@ In the first case, the observations are returned in the form of python dictionaries …

### Both observations
The algorithms that can work with both image and vector observations are specified in [Table 1](../README.md) in the README, and are reported here:
* A2C
* PPO
* PPO Recurrent
* SAC-AE
10 changes: 6 additions & 4 deletions howto/work_with_steps.md
@@ -22,11 +22,13 @@ The hyper-parameters that refer to the *policy steps* are:
* `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms.
* `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached, a `terminated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to `env_steps = max_steps * action_repeat` (see the worked example after this list).
* `learning_starts`: how many policy steps the agent has to perform before starting the training.
* `train_every`: how many policy steps the agent has to perform between one training and the next.
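
A quick worked example of the `env_steps = max_steps * action_repeat` relation above, with purely hypothetical numbers:

```python
# Hypothetical values, only to make the formula above concrete.
max_episode_steps = 1000  # policy steps an episode can last
action_repeat = 4         # each policy step is repeated 4 times in the environment

env_steps = max_episode_steps * action_repeat
print(env_steps)  # 4000 low-level environment steps before the episode is cut
```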

## Gradient steps
A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The number of gradient steps is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * gradient_steps` calls to the *train* method will be executed.
A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The number of gradient steps is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method will be executed.

The hyper-parameters which refer to the *gradient steps* are:
* `algo.per_rank_gradient_steps`: the number of gradient steps per rank to perform in a single iteration.
* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration.

> [!NOTE]
>
> The `replay_ratio` is the ratio between the gradient steps and the policy steps played by the agent.
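
To make the note above concrete, here is a rough sketch, under assumed names (`replay_ratio`, `policy_steps`, and `world_size` are illustrative, not necessarily the actual configuration keys), of how a replay ratio could be turned into per-rank gradient steps:

```python
def per_rank_gradient_steps(replay_ratio: float, policy_steps: int, world_size: int) -> int:
    """Illustrative only: gradient steps each rank should run so that, summed over
    ranks, gradient_steps / policy_steps is roughly equal to the replay ratio."""
    total_gradient_steps = replay_ratio * policy_steps
    return max(1, round(total_gradient_steps / world_size))

# Example: replay_ratio=0.5, 128 policy steps per iteration, 2 processes
# -> about 32 train() calls per rank per iteration (64 in total).
print(per_rank_gradient_steps(0.5, 128, 2))
```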
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -81,8 +81,8 @@ atari = [
"gymnasium[accept-rom-license]==0.29.*",
"gymnasium[other]==0.29.*",
]
minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"]
minerl = ["setuptools==66.0.0", "minerl==0.4.4"]
minedojo = ["minedojo==0.1", "importlib_resources==5.12.0", "gym==0.21.0"]
minerl = ["setuptools==66.0.0", "minerl==0.4.4", "gym==0.19.0"]
diambra = ["diambra==0.0.17", "diambra-arena==2.2.6"]
crafter = ["crafter==1.8.3"]
mlflow = ["mlflow==2.11.1"]
8 changes: 3 additions & 5 deletions sheeprl/algos/a2c/a2c.py
@@ -243,7 +243,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
actions = torch.cat(actions, -1).cpu().numpy()

# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))
obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))
truncated_envs = np.nonzero(truncated)[0]
if len(truncated_envs) > 0:
real_next_obs = {
@@ -266,10 +266,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape(
rewards[truncated_envs].shape
)

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
rewards = rewards.reshape(cfg.env.num_envs, -1)
dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
25 changes: 15 additions & 10 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -176,9 +176,9 @@ def train(
# compute predictions for terminal steps, if required
if cfg.algo.world_model.use_continues and world_model.continue_model:
qc = Independent(Bernoulli(logits=world_model.continue_model(latent_states)), 1)
continue_targets = (1 - data["dones"]) * cfg.algo.gamma
continues_targets = (1 - data["terminated"]) * cfg.algo.gamma
else:
qc = continue_targets = None
qc = continues_targets = None

# compute the distributions of the states (posteriors and priors)
# it is necessary an Independent distribution because
@@ -200,7 +200,7 @@
cfg.algo.world_model.kl_free_nats,
cfg.algo.world_model.kl_regularizer,
qc,
continue_targets,
continues_targets,
cfg.algo.world_model.continue_scale_factor,
)
fabric.backward(rec_loss)
@@ -554,7 +554,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if k in cfg.algo.cnn_keys.encoder:
obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:])
step_data[k] = obs[k][np.newaxis]
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
rb.add(step_data, validate_args=cfg.buffer.validate_args)
@@ -601,8 +602,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
real_actions = (
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)
next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape))
dones = np.logical_or(dones, truncated).astype(np.uint8)
next_obs, rewards, terminated, truncated, infos = envs.step(
real_actions.reshape(envs.action_space.shape)
)
dones = np.logical_or(terminated, truncated).astype(np.uint8)

if cfg.metric.log_level > 0 and "final_info" in infos:
for i, agent_ep_info in enumerate(infos["final_info"]):
@@ -631,7 +634,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# next_obs becomes the new obs
obs = next_obs

step_data["dones"] = dones[np.newaxis]
step_data["terminated"] = terminated[np.newaxis]
step_data["truncated"] = truncated[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis]
rb.add(step_data, validate_args=cfg.buffer.validate_args)
@@ -643,13 +647,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data = {}
for k in obs_keys:
reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis]
reset_data["dones"] = np.zeros((1, reset_envs, 1))
reset_data["terminated"] = np.zeros((1, reset_envs, 1))
reset_data["truncated"] = np.zeros((1, reset_envs, 1))
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d])
step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d])
# Reset internal agent states
player.init_states(reset_envs=dones_idxes)
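
Across the Dreamer hunks the recurring pattern is that the continue/discount targets are computed from `terminated` only, while `terminated` and `truncated` are stored separately in the buffer and their logical OR is used just to decide which environments to reset. A tiny, self-contained sketch of that distinction (hypothetical values, not the training loop itself):

```python
import numpy as np

# Hypothetical flags for 4 parallel environments after one step.
terminated = np.array([0, 1, 0, 0], dtype=np.uint8)  # real environment termination
truncated = np.array([0, 0, 1, 0], dtype=np.uint8)   # time-limit cut, episode could go on
gamma = 0.997

# Resets (and `is_first`) care about either flag ...
dones = np.logical_or(terminated, truncated).astype(np.uint8)

# ... but the continue model's targets only care about real termination:
# a truncated episode should still be treated as continuing for discounting.
continues_targets = (1 - terminated) * gamma

print(dones)             # [0 1 1 0]
print(continues_targets) # [0.997 0.    0.997 0.997]
```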

37 changes: 22 additions & 15 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -168,9 +168,9 @@
# Compute the distribution over the terminal steps, if required
if cfg.algo.world_model.use_continues and world_model.continue_model:
pc = Independent(Bernoulli(logits=world_model.continue_model(latent_states)), 1)
continue_targets = (1 - data["dones"]) * cfg.algo.gamma
continues_targets = (1 - data["terminated"]) * cfg.algo.gamma
else:
pc = continue_targets = None
pc = continues_targets = None

# Reshape posterior and prior logits to shape [T, B, 32, 32]
priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size)
@@ -190,7 +190,7 @@
cfg.algo.world_model.kl_free_avg,
cfg.algo.world_model.kl_regularizer,
pc,
continue_targets,
continues_targets,
cfg.algo.world_model.discount_scale_factor,
)
fabric.backward(rec_loss)
@@ -264,8 +264,8 @@ def train(
predicted_rewards = world_model.reward_model(imagined_trajectories)
if cfg.algo.world_model.use_continues and world_model.continue_model:
continues = logits_to_probs(world_model.continue_model(imagined_trajectories), is_binary=True)
true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma
continues = torch.cat((true_done, continues[1:]))
true_continue = (1 - data["terminated"]).reshape(1, -1, 1) * cfg.algo.gamma
continues = torch.cat((true_continue, continues[1:]))
else:
continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma

@@ -576,12 +576,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
obs = envs.reset(seed=cfg.seed)[0]
for k in obs_keys:
step_data[k] = obs[k][np.newaxis]
step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1))
if cfg.dry_run:
step_data["dones"] = step_data["dones"] + 1
step_data["truncated"] = step_data["truncated"] + 1
step_data["terminated"] = step_data["terminated"] + 1
step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1))
step_data["is_first"] = np.ones_like(step_data["dones"])
step_data["is_first"] = np.ones_like(step_data["terminated"])
rb.add(step_data, validate_args=cfg.buffer.validate_args)
player.init_states()

@@ -627,9 +629,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()
)

step_data["is_first"] = copy.deepcopy(step_data["dones"])
next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape))
dones = np.logical_or(dones, truncated).astype(np.uint8)
step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"]))
next_obs, rewards, terminated, truncated, infos = envs.step(
real_actions.reshape(envs.action_space.shape)
)
dones = np.logical_or(terminated, truncated).astype(np.uint8)
if cfg.dry_run and buffer_type == "episode":
dones = np.ones_like(dones)

@@ -657,7 +661,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Next_obs becomes the new obs
obs = next_obs

step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1))
step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1))
step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1))
step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1))
step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1))
rb.add(step_data, validate_args=cfg.buffer.validate_args)
@@ -669,14 +674,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
reset_data = {}
for k in obs_keys:
reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis]
reset_data["dones"] = np.zeros((1, reset_envs, 1))
reset_data["terminated"] = np.zeros((1, reset_envs, 1))
reset_data["truncated"] = np.zeros((1, reset_envs, 1))
reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim)))
reset_data["rewards"] = np.zeros((1, reset_envs, 1))
reset_data["is_first"] = np.ones_like(reset_data["dones"])
reset_data["is_first"] = np.ones_like(reset_data["terminated"])
rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args)
# Reset dones so that `is_first` is updated
for d in dones_idxes:
step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d])
step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d])
step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d])
# Reset internal agent states
player.init_states(dones_idxes)
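
The reset bookkeeping in the hunk above follows a simple rule: the record appended for a freshly reset environment has `is_first = 1` and zeroed flags, and the flags of that environment in `step_data` are cleared so that the next `is_first = terminated OR truncated` evaluates to `False` again. A condensed sketch of just that bookkeeping (illustrative shapes and indices, not the actual loop):

```python
import numpy as np

num_envs = 4
dones_idxes = [1, 3]  # hypothetical: environments 1 and 3 just finished
step_data = {
    "terminated": np.array([[[0], [1], [0], [0]]], dtype=np.uint8),
    "truncated": np.array([[[0], [0], [0], [1]]], dtype=np.uint8),
}

# Record appended for the reset environments: a new episode starts here.
reset_data = {
    "terminated": np.zeros((1, len(dones_idxes), 1)),
    "truncated": np.zeros((1, len(dones_idxes), 1)),
}
reset_data["is_first"] = np.ones_like(reset_data["terminated"])

# Clear the flags of the reset environments so that, on the next step,
# `is_first = terminated OR truncated` is False again for them.
for d in dones_idxes:
    step_data["terminated"][0, d] = 0
    step_data["truncated"][0, d] = 0
```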
