From fdd3a84ae8a4ed3e21b60c3d7b08b4d2f821af79 Mon Sep 17 00:00:00 2001 From: Michele Milesi <74559684+michele-milesi@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:09:51 +0200 Subject: [PATCH] Fix/terminated truncated (#252) * Decoupled RSSM for DV3 agent * Initialize posterior with prior if is_first is True * Fix PlayerDV3 creation in evaluation * Fix representation_model * Fix compute first prior state with a zero posterior * DV3 replay ratio conversion * Removed expl parameters dependent on old per_Rank_gradient_steps * feat: update repeats computation * feat: update learning starts in config * fix: remove files * feat: update repeats * Let Dv3 compute bootstrap correctly * feat: added replay ratio and update exploration * Fix exploration actions computation on DV1 * Fix naming * Add replay-ratio to SAC * feat: added replay ratio to p2e algos * feat: update configs and utils of p2e algos * Add replay-ratio to SAC-AE * Add DrOQ replay ratio * Fix tests * Fix mispelled * Fix wrong attribute accesing * FIx naming and configs * feat: add terminated and truncated to dreamer, p2e and ppo algos * fix: dmc wrapper * feat: update algos to split terminated from truncated * fix: crafter and diambra wrappers * feat: replace done with truncated key in when the buffer is added to the checkpoint * feat: added truncated/terminated to minedojo environment * feat: added terminated/truncated to minerl and super mario bros envs * docs: update howto * fix: minedojo wrapper * docs: update * fix: minedojo * update dependencies * fix: minedojo * fix: dv3 small configs * fix: episode buffer and tests --------- Co-authored-by: belerico --- howto/learn_in_minedojo.md | 5 + howto/learn_in_minerl.md | 5 + howto/logs_and_checkpoints.md | 1 - howto/select_observations.md | 1 + howto/work_with_steps.md | 10 +- pyproject.toml | 4 +- sheeprl/algos/a2c/a2c.py | 8 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 25 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 37 +-- sheeprl/algos/dreamer_v3/dreamer_v3.py | 39 +-- sheeprl/algos/droq/droq.py | 8 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 25 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 19 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 41 +-- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 27 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 43 ++-- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 31 ++- sheeprl/algos/ppo/ppo.py | 4 +- sheeprl/algos/ppo/ppo_decoupled.py | 4 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 8 +- sheeprl/algos/sac/sac.py | 8 +- sheeprl/algos/sac/sac_decoupled.py | 6 +- sheeprl/algos/sac_ae/sac_ae.py | 8 +- sheeprl/configs/env/minerl.yaml | 5 +- .../configs/env/minerl_obtain_diamond.yaml | 12 + .../env/minerl_obtain_iron_pickaxe.yaml | 12 + ...reamer_v3_dmc_cartpole_swingup_sparse.yaml | 4 +- .../exp/dreamer_v3_dmc_walker_walk.yaml | 2 +- sheeprl/data/buffers.py | 24 +- sheeprl/envs/crafter.py | 2 +- sheeprl/envs/diambra.py | 4 +- sheeprl/envs/dmc.py | 5 +- sheeprl/envs/minedojo.py | 30 ++- sheeprl/envs/minerl.py | 3 +- sheeprl/envs/minerl_envs/navigate.py | 9 +- sheeprl/envs/minerl_envs/obtain.py | 17 +- sheeprl/envs/super_mario_bros.py | 3 +- sheeprl/utils/callback.py | 12 +- tests/test_data/test_episode_buffer.py | 234 +++++++++++------- 39 files changed, 460 insertions(+), 285 deletions(-) create mode 100644 sheeprl/configs/env/minerl_obtain_diamond.yaml create mode 100644 sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml diff --git a/howto/learn_in_minedojo.md b/howto/learn_in_minedojo.md index b9fabbe1..1eea5a3f 100644 --- 
a/howto/learn_in_minedojo.md +++ b/howto/learn_in_minedojo.md @@ -62,9 +62,14 @@ Moreover, we restrict the look-up/down actions between `min_pitch` and `max_pitc In addition, we added the forward action when the agent selects one of the following actions: `jump`, `sprint`, and `sneak`. Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` cli arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. +> [!NOTE] +> +> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`. + For more information about the MineDojo action space, check [here](https://docs.minedojo.org/sections/core_api/action_space.html). > [!NOTE] +> > Since the MineDojo environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously. > > The action repeat in the Minecraft environments is set to 1, indeed, It makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). diff --git a/howto/learn_in_minerl.md b/howto/learn_in_minerl.md index 80edff01..50d61a33 100644 --- a/howto/learn_in_minerl.md +++ b/howto/learn_in_minerl.md @@ -47,10 +47,15 @@ In addition, we added the forward action when the agent selects one of the follo Finally, we added sticky actions for the `jump` and `attack` actions. You can set the values of the `sticky_jump` and `sticky_attack` parameters through the `env.sticky_jump` and `env.sticky_attack` arguments, respectively. The sticky actions, if set, force the agent to repeat the selected actions for a certain number of steps. > [!NOTE] +> > Since the MineRL environments have a multi-discrete action space, the sticky actions can be easily implemented. The agent will perform the selected action and the sticky actions simultaneously. > > The action repeat in the Minecraft environments is set to 1, indeed, It makes no sense to force the agent to repeat an action such as crafting (it may not have enough material for the second action). +> [!NOTE] +> +> The `env.sticky_attack` parameter is set to `0` if the `env.break_speed_multiplier > 1`. + ## Headless machines If you work on a headless machine, you need to software renderer. 
We recommend adopting one of the following solutions: diff --git a/howto/logs_and_checkpoints.md b/howto/logs_and_checkpoints.md index 57a2d8e9..7c72d4c3 100644 --- a/howto/logs_and_checkpoints.md +++ b/howto/logs_and_checkpoints.md @@ -122,7 +122,6 @@ AGGREGATOR_KEYS = { "State/post_entropy", "State/prior_entropy", "State/kl", - "Params/exploration_amount", "Grads/world_model", "Grads/actor", "Grads/critic", diff --git a/howto/select_observations.md b/howto/select_observations.md index e984ee9f..61a6188c 100644 --- a/howto/select_observations.md +++ b/howto/select_observations.md @@ -8,6 +8,7 @@ In the first case, the observations are returned in the form of python dictionar ### Both observations The algorithms that can work with both image and vector observations are specified in [Table 1](../README.md) in the README, and are reported here: +* A2C * PPO * PPO Recurrent * SAC-AE diff --git a/howto/work_with_steps.md b/howto/work_with_steps.md index ec5a7147..5bc62a8c 100644 --- a/howto/work_with_steps.md +++ b/howto/work_with_steps.md @@ -22,11 +22,13 @@ The hyper-parameters that refer to the *policy steps* are: * `exploration_steps`: the number of policy steps in which the agent explores the environment in the P2E algorithms. * `max_episode_steps`: the maximum number of policy steps an episode can last (`max_steps`); when this number is reached, a `terminated=True` is returned by the environment. This means that if you decide to have an action repeat greater than one (`action_repeat > 1`), then the environment performs a maximum number of steps equal to: `env_steps = max_steps * action_repeat`. * `learning_starts`: how many policy steps the agent has to perform before starting the training. -* `train_every`: how many policy steps the agent has to perform between one training and the next. ## Gradient steps -A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes, indeed, if there are $n$ parallel processes, `n * gradient_steps` calls to the *train* method will be executed. +A *gradient step* consists of an update of the parameters of the agent, i.e., a call of the *train* function. The gradient step is proportional to the number of parallel processes: if there are $n$ parallel processes, `n * per_rank_gradient_steps` calls to the *train* method will be executed. The hyper-parameters which refer to the *gradient steps* are: -* `algo.per_rank_gradient_steps`: the number of gradient steps per rank to perform in a single iteration. -* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration. \ No newline at end of file +* `algo.per_rank_pretrain_steps`: the number of gradient steps per rank to perform in the first iteration. + +> [!NOTE] +> +> The `replay_ratio` is the ratio between the gradient steps and the policy steps played by the agent.
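To make the `replay_ratio` relationship concrete, here is a minimal sketch of how a replay ratio could be turned into a number of gradient steps to run; the helper name and its arguments are hypothetical and this is not SheepRL's actual implementation:

```python
# Illustrative sketch only: this hypothetical helper is not part of SheepRL's API.
def gradient_steps_to_run(replay_ratio: float, policy_steps_done: int, gradient_steps_done: int) -> int:
    """Extra gradient steps needed so that gradient_steps / policy_steps stays close to replay_ratio."""
    target_gradient_steps = int(replay_ratio * policy_steps_done)
    return max(0, target_gradient_steps - gradient_steps_done)

# Example: with replay_ratio=0.5, after 1024 policy steps and 400 gradient steps already done,
# the agent should perform 112 more gradient steps (0.5 * 1024 = 512 in total).
print(gradient_steps_to_run(0.5, 1024, 400))  # 112
```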
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b4b7362b..6577c495 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,8 +81,8 @@ atari = [ "gymnasium[accept-rom-license]==0.29.*", "gymnasium[other]==0.29.*", ] -minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"] -minerl = ["setuptools==66.0.0", "minerl==0.4.4"] +minedojo = ["minedojo==0.1", "importlib_resources==5.12.0", "gym==0.21.0"] +minerl = ["setuptools==66.0.0", "minerl==0.4.4", "gym==0.19.0"] diambra = ["diambra==0.0.17", "diambra-arena==2.2.6"] crafter = ["crafter==1.8.3"] mlflow = ["mlflow==2.11.1"] diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index bb0f7a01..8d45d0ed 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -243,7 +243,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -266,10 +266,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( rewards[truncated_envs].shape ) - - dones = np.logical_or(done, truncated) - dones = dones.reshape(cfg.env.num_envs, -1) - rewards = rewards.reshape(cfg.env.num_envs, -1) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data step_data["dones"] = dones[np.newaxis] diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index c4b431ab..079a8065 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -176,9 +176,9 @@ def train( # compute predictions for terminal steps, if required if cfg.algo.world_model.use_continues and world_model.continue_model: qc = Independent(Bernoulli(logits=world_model.continue_model(latent_states)), 1) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - qc = continue_targets = None + qc = continues_targets = None # compute the distributions of the states (posteriors and priors) # it is necessary an Independent distribution because @@ -200,7 +200,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, qc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -554,7 +554,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -601,8 +602,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - 
dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -631,7 +634,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -643,13 +647,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 4b56c740..97dccccf 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -168,9 +168,9 @@ def train( # Compute the distribution over the terminal steps, if required if cfg.algo.world_model.use_continues and world_model.continue_model: pc = Independent(Bernoulli(logits=world_model.continue_model(latent_states)), 1) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - pc = continue_targets = None + pc = continues_targets = None # Reshape posterior and prior logits to shape [T, B, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -190,7 +190,7 @@ def train( cfg.algo.world_model.kl_free_avg, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.discount_scale_factor, ) fabric.backward(rec_loss) @@ -264,8 +264,8 @@ def train( predicted_rewards = world_model.reward_model(imagined_trajectories) if cfg.algo.world_model.use_continues and world_model.continue_model: continues = logits_to_probs(world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma @@ -576,12 +576,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + 
step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -627,9 +629,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -657,7 +661,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -669,14 +674,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(dones_idxes) diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 2764d90a..d0e76a81 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -165,7 +165,7 @@ def train( # Compute the distribution over the terminal steps, if required pc = Independent(BernoulliSafeMode(logits=world_model.continue_model(latent_states)), 1) - continue_targets = 1 - data["dones"] + continues_targets = 1 - data["terminated"] # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -185,7 +185,7 @@ def train( cfg.algo.world_model.kl_free_nats, 
cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -244,8 +244,8 @@ def train( predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean continues = Independent(BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories)), 1).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) # Estimate lambda-values lambda_values = compute_lambda_values( @@ -548,9 +548,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() cumulative_per_rank_gradient_steps = 0 @@ -597,16 +598,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -638,7 +644,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -647,15 +654,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] 
= step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 186a905f..0c54d42d 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -97,7 +97,7 @@ def train( next_target_qf_value = agent.get_next_target_q_values( critic_batch_data["next_observations"], critic_batch_data["rewards"], - critic_batch_data["dones"], + critic_batch_data["terminated"], cfg.algo.gamma, ) for qf_value_idx in range(agent.num_critics): @@ -310,8 +310,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment actions, _ = actor(torch.from_numpy(obs).to(device)) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -340,7 +339,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if not cfg.buffer.sample_next_obs: step_data["next_observations"] = real_next_obs[np.newaxis] step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1) - step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) rb.add(step_data, validate_args=cfg.buffer.validate_args) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index ced372ec..632633b9 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -134,9 +134,9 @@ def train( qr = Independent(Normal(world_model.reward_model(latent_states.detach()), 1), 1) if cfg.algo.world_model.use_continues and world_model.continue_model: qc = Independent(Bernoulli(logits=world_model.continue_model(latent_states.detach())), 1) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - qc = continue_targets = None + qc = continues_targets = None posteriors_dist = Independent(Normal(posteriors_mean, posteriors_std), 1) priors_dist = Independent(Normal(priors_mean, priors_std), 1) @@ -151,7 +151,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, qc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -580,7 +580,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in 
cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -627,8 +628,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -657,7 +660,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -669,13 +673,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 1919d825..f7dfc34d 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -252,7 +252,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: obs[k] = obs[k].reshape(cfg.env.num_envs, -1, *obs[k].shape[-2:]) step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -283,8 +284,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): real_actions = ( torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - next_obs, rewards, dones, truncated, infos = 
envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -313,7 +316,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated[np.newaxis] + step_data["truncated"] = truncated[np.newaxis] step_data["actions"] = actions[np.newaxis] step_data["rewards"] = clip_rewards_fn(rewards)[np.newaxis] rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -325,13 +329,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) - # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(reset_envs=dones_idxes) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index b551b54e..4b7cc09b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -152,9 +152,9 @@ def train( # compute the distribution over the terminal steps, if required if cfg.algo.world_model.use_continues and world_model.continue_model: pc = Independent(Bernoulli(logits=world_model.continue_model(latent_states.detach())), 1) - continue_targets = (1 - data["dones"]) * cfg.algo.gamma + continues_targets = (1 - data["terminated"]) * cfg.algo.gamma else: - pc = continue_targets = None + pc = continues_targets = None # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -174,7 +174,7 @@ def train( cfg.algo.world_model.kl_free_avg, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.discount_scale_factor, ) fabric.backward(rec_loss) @@ -260,8 +260,8 @@ def train( if cfg.algo.world_model.use_continues and world_model.continue_model: continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(intrinsic_reward.detach()) * cfg.algo.gamma @@ -359,8 +359,8 @@ def train( predicted_rewards = world_model.reward_model(imagined_trajectories) if cfg.algo.world_model.use_continues and 
world_model.continue_model: continues = logits_to_probs(logits=world_model.continue_model(imagined_trajectories), is_binary=True) - true_done = (1 - data["dones"]).reshape(1, -1, 1) * cfg.algo.gamma - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).reshape(1, -1, 1) * cfg.algo.gamma + continues = torch.cat((true_continue, continues[1:])) else: continues = torch.ones_like(predicted_rewards.detach()) * cfg.algo.gamma @@ -713,12 +713,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -764,9 +766,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -794,7 +798,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -806,14 +811,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states 
player.init_states(dones_idxes) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 24b0929b..365722b1 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -269,12 +269,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) if cfg.dry_run: - step_data["dones"] = step_data["dones"] + 1 + step_data["terminated"] = step_data["terminated"] + 1 + step_data["truncated"] = step_data["truncated"] + 1 step_data["actions"] = np.zeros((1, cfg.env.num_envs, sum(actions_dim))) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) rb.add(step_data, validate_args=cfg.buffer.validate_args) player.init_states() @@ -304,9 +306,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy() ) - step_data["is_first"] = copy.deepcopy(step_data["dones"]) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + step_data["is_first"] = copy.deepcopy(np.logical_or(step_data["terminated"], step_data["truncated"])) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) if cfg.dry_run and buffer_type == "episode": dones = np.ones_like(dones) @@ -334,7 +338,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Next_obs becomes the new obs obs = next_obs - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards).reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) @@ -346,14 +351,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.zeros((1, reset_envs, 1)) + reset_data["terminated"] = np.zeros((1, reset_envs, 1)) + reset_data["truncated"] = np.zeros((1, reset_envs, 1)) reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = np.zeros((1, reset_envs, 1)) - reset_data["is_first"] = np.ones_like(reset_data["dones"]) + reset_data["is_first"] = np.ones_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset dones so that `is_first` is updated for d in dones_idxes: - step_data["dones"][0, d] = np.zeros_like(step_data["dones"][0, d]) + step_data["terminated"][0, d] = np.zeros_like(step_data["terminated"][0, d]) + step_data["truncated"][0, d] = np.zeros_like(step_data["truncated"][0, d]) # Reset internal agent states player.init_states(dones_idxes) diff --git 
a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index f179bbf6..bb622993 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -161,7 +161,7 @@ def train( # Compute the distribution over the terminal steps, if required pc = Independent(BernoulliSafeMode(logits=world_model.continue_model(latent_states.detach())), 1) - continue_targets = 1 - data["dones"] + continues_targets = 1 - data["terminated"] # Reshape posterior and prior logits to shape [B, T, 32, 32] priors_logits = priors_logits.view(*priors_logits.shape[:-1], stochastic_size, discrete_size) @@ -181,7 +181,7 @@ def train( cfg.algo.world_model.kl_free_nats, cfg.algo.world_model.kl_regularizer, pc, - continue_targets, + continues_targets, cfg.algo.world_model.continue_scale_factor, ) fabric.backward(rec_loss) @@ -264,8 +264,8 @@ def train( # Predict values and continues predicted_values = TwoHotEncodingDistribution(critic["module"](imagined_trajectories), dims=1).mean continues = Independent(BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories)), 1).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) if critic["reward_type"] == "intrinsic": # Predict intrinsic reward @@ -404,8 +404,8 @@ def train( predicted_values = TwoHotEncodingDistribution(critic_task(imagined_trajectories), dims=1).mean predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean continues = Independent(BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories)), 1).mode - true_done = (1 - data["dones"]).flatten().reshape(1, -1, 1) - continues = torch.cat((true_done, continues[1:])) + true_continue = (1 - data["terminated"]).flatten().reshape(1, -1, 1) + continues = torch.cat((true_continue, continues[1:])) lambda_values = compute_lambda_values( predicted_rewards[1:], @@ -789,9 +789,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() cumulative_per_rank_gradient_steps = 0 @@ -838,16 +839,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - 
rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -879,7 +885,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -888,15 +895,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 97df1ce6..e220fea8 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -264,9 +264,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = envs.reset(seed=cfg.seed)[0] for k in obs_keys: step_data[k] = obs[k][np.newaxis] - step_data["dones"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["terminated"] = np.zeros((1, cfg.env.num_envs, 1)) + step_data["truncated"] = np.zeros((1, cfg.env.num_envs, 1)) step_data["rewards"] = np.zeros((1, cfg.env.num_envs, 1)) - step_data["is_first"] = np.ones_like(step_data["dones"]) + step_data["is_first"] = np.ones_like(step_data["terminated"]) player.init_states() cumulative_per_rank_gradient_steps = 0 @@ -297,16 +298,21 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): step_data["actions"] = actions.reshape((1, cfg.env.num_envs, -1)) rb.add(step_data, validate_args=cfg.buffer.validate_args) - next_obs, rewards, dones, truncated, infos = envs.step(real_actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated).astype(np.uint8) + next_obs, rewards, terminated, truncated, infos = envs.step( + real_actions.reshape(envs.action_space.shape) + ) + dones = np.logical_or(terminated, truncated).astype(np.uint8) - step_data["is_first"] = np.zeros_like(step_data["dones"]) + 
step_data["is_first"] = np.zeros_like(step_data["terminated"]) if "restart_on_exception" in infos: for i, agent_roe in enumerate(infos["restart_on_exception"]): if agent_roe and not dones[i]: last_inserted_idx = (rb.buffer[i]._pos - 1) % rb.buffer[i].buffer_size - rb.buffer[i]["dones"][last_inserted_idx] = np.ones_like( - rb.buffer[i]["dones"][last_inserted_idx] + rb.buffer[i]["terminated"][last_inserted_idx] = np.zeros_like( + rb.buffer[i]["terminated"][last_inserted_idx] + ) + rb.buffer[i]["truncated"][last_inserted_idx] = np.ones_like( + rb.buffer[i]["truncated"][last_inserted_idx] ) rb.buffer[i]["is_first"][last_inserted_idx] = np.zeros_like( rb.buffer[i]["is_first"][last_inserted_idx] @@ -338,7 +344,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): obs = next_obs rewards = rewards.reshape((1, cfg.env.num_envs, -1)) - step_data["dones"] = dones.reshape((1, cfg.env.num_envs, -1)) + step_data["terminated"] = terminated.reshape((1, cfg.env.num_envs, -1)) + step_data["truncated"] = truncated.reshape((1, cfg.env.num_envs, -1)) step_data["rewards"] = clip_rewards_fn(rewards) dones_idxes = dones.nonzero()[0].tolist() @@ -347,15 +354,17 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): reset_data = {} for k in obs_keys: reset_data[k] = (real_next_obs[k][dones_idxes])[np.newaxis] - reset_data["dones"] = np.ones((1, reset_envs, 1)) + reset_data["terminated"] = step_data["terminated"][:, dones_idxes] + reset_data["truncated"] = step_data["truncated"][:, dones_idxes] reset_data["actions"] = np.zeros((1, reset_envs, np.sum(actions_dim))) reset_data["rewards"] = step_data["rewards"][:, dones_idxes] - reset_data["is_first"] = np.zeros_like(reset_data["dones"]) + reset_data["is_first"] = np.zeros_like(reset_data["terminated"]) rb.add(reset_data, dones_idxes, validate_args=cfg.buffer.validate_args) # Reset already inserted step data step_data["rewards"][:, dones_idxes] = np.zeros_like(reset_data["rewards"]) - step_data["dones"][:, dones_idxes] = np.zeros_like(step_data["dones"][:, dones_idxes]) + step_data["terminated"][:, dones_idxes] = np.zeros_like(step_data["terminated"][:, dones_idxes]) + step_data["truncated"][:, dones_idxes] = np.zeros_like(step_data["truncated"][:, dones_idxes]) step_data["is_first"][:, dones_idxes] = np.ones_like(step_data["is_first"][:, dones_idxes]) player.init_states(dones_idxes) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 23f634eb..ba2aa447 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -283,7 +283,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -306,7 +306,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( rewards[truncated_envs].shape ) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 42ba7569..866228d7 100644 --- 
a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -206,7 +206,7 @@ def player( actions = torch.cat(actions, -1).cpu().numpy() # Single environment step - obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + obs, rewards, terminated, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -229,7 +229,7 @@ def player( rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( rewards[truncated_envs].shape ) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) + dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) # Update the step data diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index b16a5459..4054261d 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -308,7 +308,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions = torch_actions.cpu().numpy() # Single environment step - next_obs, rewards, dones, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape)) + next_obs, rewards, terminated, truncated, info = envs.step( + real_actions.reshape(envs.action_space.shape) + ) truncated_envs = np.nonzero(truncated)[0] if len(truncated_envs) > 0: real_next_obs = { @@ -333,8 +335,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): tuple(s[:, truncated_envs, ...] for s in states), ) vals = player.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() - rewards[truncated_envs] += vals.reshape(rewards[truncated_envs].shape) - dones = np.logical_or(dones, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) + rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) + dones = np.logical_or(terminated, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 58753f9f..ec662cd2 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -45,7 +45,7 @@ def train( ): # Update the soft-critic next_target_qf_value = agent.get_next_target_q_values( - data["next_observations"], data["rewards"], data["dones"], cfg.algo.gamma + data["next_observations"], data["rewards"], data["terminated"], cfg.algo.gamma ) qf_values = agent.get_q_values(data["observations"], data["actions"]) qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics) @@ -253,9 +253,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -277,7 +276,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], 
axis=-1 ) - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["actions"] = actions[np.newaxis] step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index cf75675e..ee0a0ceb 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -190,9 +190,8 @@ def player( torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions) + next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) - dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) if cfg.metric.log_level > 0 and "final_info" in infos: @@ -214,7 +213,8 @@ def player( [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1 ) - step_data["dones"] = dones[np.newaxis] + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.uint8) step_data["actions"] = actions[np.newaxis] step_data["observations"] = obs[np.newaxis] if not cfg.buffer.sample_next_obs: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index f5a80ed6..86ae8c8f 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -62,7 +62,7 @@ def train( # Update the soft-critic next_target_qf_value = agent.get_next_target_q_values( - normalized_next_obs, data["rewards"], data["dones"], cfg.algo.gamma + normalized_next_obs, data["rewards"], data["terminated"], cfg.algo.gamma ) qf_values = agent.get_q_values(normalized_obs, data["actions"]) qf_loss = critic_loss(qf_values, next_target_qf_value, agent.num_critics) @@ -324,8 +324,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} actions, _ = actor(torch_obs) actions = actions.cpu().numpy() - next_obs, rewards, dones, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) - dones = np.logical_or(dones, truncated) + next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) if cfg.metric.log_level > 0 and "final_info" in infos: for i, agent_ep_info in enumerate(infos["final_info"]): @@ -357,7 +356,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): 1, cfg.env.num_envs, -1, *step_data[f"next_{k}"].shape[-2:] ) - step_data["dones"] = dones.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["terminated"] = terminated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) + step_data["truncated"] = truncated.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1).astype(np.float32) step_data["rewards"] = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) rb.add(step_data, validate_args=cfg.buffer.validate_args) diff --git a/sheeprl/configs/env/minerl.yaml b/sheeprl/configs/env/minerl.yaml index 74e564a6..c3849c4f 100644 --- a/sheeprl/configs/env/minerl.yaml +++ b/sheeprl/configs/env/minerl.yaml @@ -5,6 +5,9 @@ defaults: 
# Override from `minecraft` config id: custom_navigate action_repeat: 1 +max_episode_steps: 12000 +reward_as_observation: True +num_envs: 4 # Wrapper to be instantiated wrapper: @@ -17,7 +20,7 @@ wrapper: - ${env.max_pitch} seed: null break_speed_multiplier: ${env.break_speed_multiplier} - multihot_inventory: True + multihot_inventory: False sticky_attack: ${env.sticky_attack} sticky_jump: ${env.sticky_jump} dense: True diff --git a/sheeprl/configs/env/minerl_obtain_diamond.yaml b/sheeprl/configs/env/minerl_obtain_diamond.yaml new file mode 100644 index 00000000..55699873 --- /dev/null +++ b/sheeprl/configs/env/minerl_obtain_diamond.yaml @@ -0,0 +1,12 @@ +defaults: + - minerl + - _self_ + +id: custom_obtain_diamond +action_repeat: 1 +max_episode_steps: 36000 +num_envs: 16 + +wrapper: + multihot_inventory: True + dense: False \ No newline at end of file diff --git a/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml b/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml new file mode 100644 index 00000000..78125a32 --- /dev/null +++ b/sheeprl/configs/env/minerl_obtain_iron_pickaxe.yaml @@ -0,0 +1,12 @@ +defaults: + - minerl + - _self_ + +id: custom_obtain_iron_pickaxe +action_repeat: 1 +max_episode_steps: 36000 +num_envs: 16 + +wrapper: + multihot_inventory: True + dense: False \ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml index 13645795..5f1e1773 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_cartpole_swingup_sparse.yaml @@ -26,7 +26,7 @@ checkpoint: # Buffer buffer: - size: 1_000_000 + size: 500_000 checkpoint: True memmap: True @@ -38,7 +38,7 @@ algo: - rgb mlp_keys: encoder: [] - learning_starts: 1024 + learning_starts: 1300 # Metric diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml index b7d6ea09..af38aee4 100644 --- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml +++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml @@ -26,7 +26,7 @@ checkpoint: # Buffer buffer: - size: 1_000_000 + size: 500_000 checkpoint: True memmap: True diff --git a/sheeprl/data/buffers.py b/sheeprl/data/buffers.py index bbf10d5a..8c51c9b1 100644 --- a/sheeprl/data/buffers.py +++ b/sheeprl/data/buffers.py @@ -923,8 +923,10 @@ def add( last_key = current_key last_batch_shape = current_batch_shape - if "dones" not in data: - raise RuntimeError(f"The episode must contain the `dones` key, got: {data.keys()}") + if "terminated" not in data and "truncated" not in data: + raise RuntimeError( + f"The episode must contain the `terminated` and the `truncated` keys, got: {data.keys()}" + ) if env_idxes is not None and (np.array(env_idxes) >= self._n_envs).any(): raise ValueError( @@ -937,7 +939,7 @@ def add( for i, env in enumerate(env_idxes): # Take the data from a single environment env_data = {k: v[:, i] for k, v in data.items()} - done = env_data["dones"] + done = np.logical_or(env_data["terminated"], env_data["truncated"]) # Take episode ends episode_ends = done.nonzero()[0].tolist() # If there is not any done, then add the data to the respective open episode @@ -954,12 +956,15 @@ def add( episode = {k: env_data[k][start : stop + 1] for k in env_data.keys()} # If the episode length is greater than zero, then add it to the open episode # of the corresponding environment. 
- if len(episode["dones"]) > 0: + if len(np.logical_or(episode["terminated"], episode["truncated"])) > 0: self._open_episodes[env].append(episode) start = stop + 1 # If the open episode is not empty and the last element is a done, then save the episode # in the buffer and clear the open episode - if len(self._open_episodes[env]) > 0 and self._open_episodes[env][-1]["dones"][-1] == 1: + should_save = len(self._open_episodes[env]) > 0 and np.logical_or( + self._open_episodes[env][-1]["terminated"][-1], self._open_episodes[env][-1]["truncated"][-1] + ) + if should_save: self._save_episode(self._open_episodes[env]) self._open_episodes[env] = [] @@ -974,9 +979,10 @@ def _save_episode(self, episode_chunks: Sequence[Dict[str, np.ndarray | MemmapAr episode = {k: np.concatenate(v, axis=0) for k, v in episode.items()} # Control the validity of the episode - ep_len = episode["dones"].shape[0] - if len(episode["dones"].nonzero()[0]) != 1 or episode["dones"][-1] != 1: - raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(episode['dones']))}") + ends = np.logical_or(episode["terminated"], episode["truncated"]) + ep_len = ends.shape[0] + if len(ends.nonzero()[0]) != 1 or ends[-1] != 1: + raise RuntimeError(f"The episode must contain exactly one done, got: {len(np.nonzero(ends))}") if ep_len < self._minimum_episode_length: raise RuntimeError( f"Episode too short (at least {self._minimum_episode_length} steps), got: {ep_len} steps" @@ -1076,7 +1082,7 @@ def sample( samples_per_eps.update({f"next_{k}": [] for k in self._obs_keys}) for i, n in enumerate(nsample_per_eps): if n > 0: - ep_len = valid_episodes[i]["dones"].shape[0] + ep_len = np.logical_or(valid_episodes[i]["terminated"], valid_episodes[i]["truncated"]).shape[0] if sample_next_obs: ep_len -= 1 # Define the maximum index that can be sampled in the episodes diff --git a/sheeprl/envs/crafter.py b/sheeprl/envs/crafter.py index ae5a94cc..f0c6f71d 100644 --- a/sheeprl/envs/crafter.py +++ b/sheeprl/envs/crafter.py @@ -50,7 +50,7 @@ def _convert_obs(self, obs: np.ndarray) -> Dict[str, np.ndarray]: def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: obs, reward, done, info = self.env.step(action) - return self._convert_obs(obs), reward, done, False, info + return self._convert_obs(obs), reward, done and info["discount"] == 0, done and info["discount"] != 0, info def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None diff --git a/sheeprl/envs/diambra.py b/sheeprl/envs/diambra.py index 2e773ac1..002ed0ba 100644 --- a/sheeprl/envs/diambra.py +++ b/sheeprl/envs/diambra.py @@ -123,9 +123,9 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A if self._action_type == "discrete" and isinstance(action, np.ndarray): action = action.squeeze() action = action.item() - obs, reward, done, truncated, infos = self.env.step(action) + obs, reward, terminated, truncated, infos = self.env.step(action) infos["env_domain"] = "DIAMBRA" - return self._convert_obs(obs), reward, done or infos.get("env_done", False), truncated, infos + return self._convert_obs(obs), reward, terminated or infos.get("env_done", False), truncated, infos def render(self, mode: str = "rgb_array", **kwargs) -> Optional[Union[RenderFrame, List[RenderFrame]]]: return self.env.render() diff --git a/sheeprl/envs/dmc.py b/sheeprl/envs/dmc.py index c2643993..3ea0825d 100644 --- a/sheeprl/envs/dmc.py +++ b/sheeprl/envs/dmc.py @@ -220,13 +220,14 @@ def step( action = 
self._convert_action(action) time_step = self.env.step(action) reward = time_step.reward or 0.0 - done = time_step.last() obs = self._get_obs(time_step) self.current_state = _flatten_obs(time_step.observation) extra = {} extra["discount"] = time_step.discount extra["internal_state"] = self.env.physics.get_state().copy() - return obs, reward, done, False, extra + truncated = time_step.last() and time_step.discount == 1 + terminated = False if time_step.first() else time_step.last() and time_step.discount == 0 + return obs, reward, terminated, truncated, extra def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None diff --git a/sheeprl/envs/minedojo.py b/sheeprl/envs/minedojo.py index 00c0837f..17cf7b0d 100644 --- a/sheeprl/envs/minedojo.py +++ b/sheeprl/envs/minedojo.py @@ -69,7 +69,7 @@ def __init__( self._pos = kwargs.get("start_position", None) self._break_speed_multiplier = kwargs.get("break_speed_multiplier", 100) self._start_pos = copy.deepcopy(self._pos) - self._sticky_attack = sticky_attack + self._sticky_attack = 0 if self._break_speed_multiplier > 1 else sticky_attack self._sticky_jump = sticky_jump self._sticky_attack_counter = 0 self._sticky_jump_counter = 0 @@ -83,7 +83,6 @@ def __init__( task_id=id, image_size=(height, width), world_seed=seed, - generate_world_type="default", fast_reset=True, **kwargs, ) @@ -245,6 +244,9 @@ def step(self, action: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict action[3] = 12 obs, reward, done, info = self.env.step(action) + is_timelimit = info.get("TimeLimit.truncated", False) + terminated = done and not is_timelimit + truncated = done and is_timelimit self._pos = { "x": float(obs["location_stats"]["pos"][0]), "y": float(obs["location_stats"]["pos"][1]), @@ -252,17 +254,19 @@ def step(self, action: np.ndarray) -> Tuple[Any, SupportsFloat, bool, bool, Dict "pitch": float(obs["location_stats"]["pitch"].item()), "yaw": float(obs["location_stats"]["yaw"].item()), } - info = { - "life_stats": { - "life": float(obs["life_stats"]["life"].item()), - "oxygen": float(obs["life_stats"]["oxygen"].item()), - "food": float(obs["life_stats"]["food"].item()), - }, - "location_stats": copy.deepcopy(self._pos), - "action": a.tolist(), - "biomeid": float(obs["location_stats"]["biome_id"].item()), - } - return self._convert_obs(obs), reward, done, False, info + info.update( + { + "life_stats": { + "life": float(obs["life_stats"]["life"].item()), + "oxygen": float(obs["life_stats"]["oxygen"].item()), + "food": float(obs["life_stats"]["food"].item()), + }, + "location_stats": copy.deepcopy(self._pos), + "action": a.tolist(), + "biomeid": float(obs["location_stats"]["biome_id"].item()), + } + ) + return self._convert_obs(obs), reward, terminated, truncated, info def reset( self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None diff --git a/sheeprl/envs/minerl.py b/sheeprl/envs/minerl.py index 35047bdf..338d384f 100644 --- a/sheeprl/envs/minerl.py +++ b/sheeprl/envs/minerl.py @@ -85,7 +85,7 @@ def __init__( self._height = height self._width = width self._pitch_limits = pitch_limits - self._sticky_attack = sticky_attack + self._sticky_attack = 0 if break_speed_multiplier > 1 else sticky_attack self._sticky_jump = sticky_jump self._sticky_attack_counter = 0 self._sticky_jump_counter = 0 @@ -303,7 +303,6 @@ def step(self, actions: np.ndarray) -> Tuple[Dict[str, Any], SupportsFloat, bool "pitch": next_pitch, "yaw": next_yaw, } - info = {} return self._convert_obs(obs), reward, done, False, info def reset( diff 
--git a/sheeprl/envs/minerl_envs/navigate.py b/sheeprl/envs/minerl_envs/navigate.py index f02723d5..619d17e6 100644 --- a/sheeprl/envs/minerl_envs/navigate.py +++ b/sheeprl/envs/minerl_envs/navigate.py @@ -9,7 +9,6 @@ import minerl.herobraine.hero.handlers as handlers from minerl.herobraine.hero.handler import Handler -from minerl.herobraine.hero.mc import MS_PER_STEP from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec @@ -22,7 +21,11 @@ def __init__(self, dense, extreme, *args, **kwargs): suffix += "Dense" if dense else "" name = "CustomMineRLNavigate{}-v0".format(suffix) self.dense, self.extreme = dense, extreme - super().__init__(name, *args, max_episode_steps=6000, **kwargs) + + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between terminated and truncated + kwargs.pop("max_episode_steps", None) + super().__init__(name, *args, max_episode_steps=None, **kwargs) def is_from_folder(self, folder: str) -> bool: return folder == "navigateextreme" if self.extreme else folder == "navigate" @@ -60,7 +63,7 @@ def create_server_world_generators(self) -> List[Handler]: return [handlers.DefaultWorldGenerator(force_reset=True)] def create_server_quit_producers(self) -> List[Handler]: - return [handlers.ServerQuitFromTimeUp(NAVIGATE_STEPS * MS_PER_STEP), handlers.ServerQuitWhenAnyAgentFinishes()] + return [handlers.ServerQuitWhenAnyAgentFinishes()] def create_server_decorators(self) -> List[Handler]: return [ diff --git a/sheeprl/envs/minerl_envs/obtain.py b/sheeprl/envs/minerl_envs/obtain.py index ce5f9d1c..3302024e 100644 --- a/sheeprl/envs/minerl_envs/obtain.py +++ b/sheeprl/envs/minerl_envs/obtain.py @@ -9,7 +9,6 @@ from minerl.herobraine.hero import handlers from minerl.herobraine.hero.handler import Handler -from minerl.herobraine.hero.mc import MS_PER_STEP from sheeprl.envs.minerl_envs.backend import CustomSimpleEmbodimentEnvSpec @@ -28,7 +27,7 @@ def __init__( dense, reward_schedule: List[Dict[str, Union[str, int, float]]], *args, - max_episode_steps=6000, + max_episode_steps=None, **kwargs, ): # 6000 for obtain iron (5 mins) @@ -138,10 +137,7 @@ def create_server_world_generators(self) -> List[Handler]: return [handlers.DefaultWorldGenerator(force_reset=True)] def create_server_quit_producers(self) -> List[Handler]: - return [ - handlers.ServerQuitFromTimeUp(time_limit_ms=self.max_episode_steps * MS_PER_STEP), - handlers.ServerQuitWhenAnyAgentFinishes(), - ] + return [handlers.ServerQuitWhenAnyAgentFinishes()] def create_server_decorators(self) -> List[Handler]: return [] @@ -175,6 +171,9 @@ def determine_success_from_rewards(self, rewards: list) -> bool: class CustomObtainDiamond(CustomObtain): def __init__(self, dense, *args, **kwargs): + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between terminated and truncated + kwargs.pop("max_episode_steps", None) super(CustomObtainDiamond, self).__init__( *args, target_item="diamond", @@ -193,7 +192,7 @@ def __init__(self, dense, *args, **kwargs): dict(type="iron_pickaxe", amount=1, reward=256), dict(type="diamond", amount=1, reward=1024), ], - max_episode_steps=18000, + max_episode_steps=None, **kwargs, ) @@ -251,6 +250,9 @@ def get_docstring(self): class CustomObtainIronPickaxe(CustomObtain): def __init__(self, dense, *args, **kwargs): + # The time limit is handled outside because MineRL does not provide a way + # to distinguish between terminated and truncated + kwargs.pop("max_episode_steps", None) super(CustomObtainIronPickaxe, 
self).__init__( *args, target_item="iron_pickaxe", @@ -268,6 +270,7 @@ def __init__(self, dense, *args, **kwargs): dict(type="iron_ingot", amount=1, reward=128), dict(type="iron_pickaxe", amount=1, reward=256), ], + max_episode_steps=None, **kwargs, ) diff --git a/sheeprl/envs/super_mario_bros.py b/sheeprl/envs/super_mario_bros.py index 7fe41a13..0c1043fa 100644 --- a/sheeprl/envs/super_mario_bros.py +++ b/sheeprl/envs/super_mario_bros.py @@ -55,7 +55,8 @@ def step(self, action: np.ndarray | int) -> Tuple[Any, SupportsFloat, bool, bool action = action.squeeze().item() obs, reward, done, info = self.env.step(action) converted_obs = {"rgb": obs.copy()} - return converted_obs, reward, done, False, info + is_timelimit = info.get("time", False) + return converted_obs, reward, done and not is_timelimit, done and is_timelimit, info def render(self) -> RenderFrame | list[RenderFrame] | None: rendered_frame: np.ndarray | None = self.env.render(mode=self.render_mode) diff --git a/sheeprl/utils/callback.py b/sheeprl/utils/callback.py index eaa099c6..19e588a7 100644 --- a/sheeprl/utils/callback.py +++ b/sheeprl/utils/callback.py @@ -105,14 +105,14 @@ def _ckpt_rb( """ if isinstance(rb, ReplayBuffer): # clone the true done - state = rb["dones"][(rb._pos - 1) % rb.buffer_size, :].copy() + state = rb["truncated"][(rb._pos - 1) % rb.buffer_size, :].copy() # substitute the last done with all True values (all the environment are truncated) - rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = True + rb["truncated"][(rb._pos - 1) % rb.buffer_size, :] = 1 elif isinstance(rb, EnvIndependentReplayBuffer): state = [] for b in rb.buffer: - state.append(b["dones"][(b._pos - 1) % b.buffer_size, :].copy()) - b["dones"][(b._pos - 1) % b.buffer_size, :] = True + state.append(b["truncated"][(b._pos - 1) % b.buffer_size, :].copy()) + b["truncated"][(b._pos - 1) % b.buffer_size, :] = 1 elif isinstance(rb, EpisodeBuffer): # remove open episodes from the buffer because the state of the environment is not saved state = rb._open_episodes @@ -133,10 +133,10 @@ def _experiment_consistent_rb( """ if isinstance(rb, ReplayBuffer): # reinsert the true dones in the buffer - rb["dones"][(rb._pos - 1) % rb.buffer_size, :] = state + rb["truncated"][(rb._pos - 1) % rb.buffer_size, :] = state elif isinstance(rb, EnvIndependentReplayBuffer): for i, b in enumerate(rb.buffer): - b["dones"][(b._pos - 1) % b.buffer_size, :] = state[i] + b["truncated"][(b._pos - 1) % b.buffer_size, :] = state[i] elif isinstance(rb, EpisodeBuffer): # reinsert the open episodes to continue the training rb._open_episodes = state diff --git a/tests/test_data/test_episode_buffer.py b/tests/test_data/test_episode_buffer.py index 967abfbf..12fdba37 100644 --- a/tests/test_data/test_episode_buffer.py +++ b/tests/test_data/test_episode_buffer.py @@ -35,51 +35,51 @@ def test_episode_buffer_add_episodes(): buf_size = 30 sl = 5 n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated",) rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep2 = {"dones": np.zeros((sl + 5, n_envs, 1))} - ep3 = {"dones": np.zeros((sl + 10, n_envs, 1))} - ep4 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 - ep4["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep2 = {"terminated": np.zeros((sl + 5, n_envs, 1)), "truncated": np.zeros((sl + 5, n_envs, 1))} + ep3 = {"terminated": np.zeros((sl + 10, n_envs, 1)), 
"truncated": np.zeros((sl + 10, n_envs, 1))} + ep4 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["terminated"][-1] = 1 + ep2["truncated"][-1] = 1 + ep3["terminated"][-1] = 1 + ep4["truncated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) rb.add(ep4) assert rb.full - assert (rb._buf[-1]["dones"] == ep4["dones"][:, 0]).all() - assert (rb._buf[0]["dones"] == ep2["dones"][:, 0]).all() + assert (rb._buf[-1]["terminated"] == ep4["terminated"][:, 0]).all() + assert (rb._buf[0]["terminated"] == ep2["terminated"][:, 0]).all() def test_episode_buffer_add_single_dict(): buf_size = 5 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["truncated"][-1] = 1 rb.add(ep1) assert rb.full for env in range(n_envs): - assert (rb._buf[0]["dones"] == ep1["dones"][:, env]).all() + assert (rb._buf[0]["terminated"] == ep1["terminated"][:, env]).all() def test_episode_buffer_error_add(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated",) rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) ep1 = torch.zeros(sl, n_envs, 1) with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays, but `data` is of type"): rb.add(ep1, validate_args=True) - ep2 = {"dones": torch.zeros((sl, n_envs, 1))} + ep2 = {"terminated": torch.zeros((sl, n_envs, 1)), "truncated": torch.zeros((sl, n_envs, 1))} with pytest.raises(ValueError, match="`data` must be a dictionary containing Numpy arrays. Found key"): rb.add(ep2, validate_args=True) @@ -87,22 +87,22 @@ def test_episode_buffer_error_add(): with pytest.raises(ValueError, match="The `data` replay buffer must be not None"): rb.add(ep3, validate_args=True) - ep4 = {"dones": np.zeros((1,))} + ep4 = {"terminated": np.zeros((1,)), "truncated": np.zeros((1,))} with pytest.raises(RuntimeError, match=r"`data` must have at least 2: \[sequence_length, n_envs"): rb.add(ep4, validate_args=True) - obs_keys = ("dones", "obs") + obs_keys = ("terminated", "obs") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep5 = {"dones": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))} + ep5 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1)), "obs": np.zeros((sl, 1, 6))} with pytest.raises(RuntimeError, match="Every array in `data` must be congruent in the first 2 dimensions"): rb.add(ep5, validate_args=True) ep6 = {"obs": np.zeros((sl, 1, 6))} - with pytest.raises(RuntimeError, match="The episode must contain the `dones` key"): + with pytest.raises(RuntimeError, match="The episode must contain the `terminated`"): rb.add(ep6, validate_args=True) - ep7 = {"dones": np.zeros((sl, 1, 1))} - ep7["dones"][-1] = 1 + ep7 = {"terminated": np.zeros((sl, 1, 1)), "truncated": np.zeros((sl, 1, 1))} + ep7["terminated"][-1] = 1 with pytest.raises(ValueError, match="The indices of the environment must be integers in"): rb.add(ep7, validate_args=True, env_idxes=[10]) @@ -111,9 +111,9 @@ def test_add_only_for_some_envs(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs - 2, 1))} + ep1 = {"terminated": np.zeros((sl, n_envs - 2, 1)), "truncated": 
np.zeros((sl, n_envs - 2, 1))} rb.add(ep1, env_idxes=[0, 3]) assert len(rb._open_episodes[0]) > 0 assert len(rb._open_episodes[1]) == 0 @@ -125,16 +125,24 @@ def test_save_episode(): buf_size = 100 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append( + { + "terminated": np.zeros(chunk_dim), + "truncated": np.zeros(chunk_dim), + } + ) + ep_chunks[-1]["terminated"][-1] = 1 rb._save_episode(ep_chunks) assert len(rb._buf) == 1 assert ( - np.concatenate([e["dones"] for e in rb.buffer], axis=0) - == np.concatenate([c["dones"] for c in ep_chunks], axis=0) + np.concatenate([e["terminated"] for e in rb.buffer], axis=0) + == np.concatenate([c["terminated"] for c in ep_chunks], axis=0) ).all() @@ -142,29 +150,35 @@ def test_save_episode_errors(): buf_size = 100 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) with pytest.raises(RuntimeError, match="Invalid episode, an empty sequence is given"): rb._save_episode([]) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 - ep_chunks[0]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append({"terminated": np.zeros(chunk_dim), "truncated": np.zeros(chunk_dim)}) + ep_chunks[-1]["terminated"][-1] = 1 + ep_chunks[0]["truncated"][-1] = 1 with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.zeros((np.random.randint(1, 8, (1,)).item(), 1))} for _ in range(8)] - ep_chunks[0]["dones"][-1] = 1 + ep_chunks = [] + for _ in range(8): + chunk_dim = (np.random.randint(1, 8, (1,)).item(), 1) + ep_chunks.append({"terminated": np.zeros(chunk_dim), "truncated": np.zeros(chunk_dim)}) + ep_chunks[0]["terminated"][-1] = 1 with pytest.raises(RuntimeError, match="The episode must contain exactly one done"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.ones((1, 1))}] + ep_chunks = [{"terminated": np.ones((1, 1)), "truncated": np.zeros((1, 1))}] with pytest.raises(RuntimeError, match="Episode too short"): rb._save_episode(ep_chunks) - ep_chunks = [{"dones": np.zeros((110, 1))} for _ in range(8)] - ep_chunks[-1]["dones"][-1] = 1 + ep_chunks = [{"terminated": np.zeros((110, 1)), "truncated": np.zeros((110, 1))} for _ in range(8)] + ep_chunks[-1]["truncated"][-1] = 1 with pytest.raises(RuntimeError, match="Episode too long"): rb._save_episode(ep_chunks) @@ -173,14 +187,19 @@ def test_episode_buffer_sample_one_element(): buf_size = 5 sl = 5 n_envs = 1 - obs_keys = ("dones", "a") + obs_keys = ("terminated", "truncated", "a") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep = {"dones": np.zeros((sl, n_envs, 1)), "a": np.random.rand(sl, n_envs, 1)} - ep["dones"][-1] = 1 + ep = { + "terminated": np.zeros((sl, n_envs, 1)), + "truncated": np.zeros((sl, n_envs, 1)), + "a": np.random.rand(sl, n_envs, 1), + } + ep["terminated"][-1] = 1 rb.add(ep) sample = rb.sample(1, n_samples=1, sequence_length=sl) assert rb.full - assert (sample["dones"][0, :, 0] == ep["dones"][:, 0]).all() + assert 
(sample["terminated"][0, :, 0] == ep["terminated"][:, 0]).all() + assert (sample["truncated"][0, :, 0] == ep["truncated"][:, 0]).all() assert (sample["a"][0, :, 0] == ep["a"][:, 0]).all() @@ -188,42 +207,58 @@ def test_episode_buffer_sample_shapes(): buf_size = 30 sl = 2 n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep = {"dones": np.zeros((sl, n_envs, 1))} - ep["dones"][-1] = 1 + ep = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep["truncated"][-1] = 1 rb.add(ep) sample = rb.sample(3, n_samples=2, sequence_length=sl) - assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) + assert sample["terminated"].shape[:-1] == tuple([2, sl, 3]) + assert sample["truncated"].shape[:-1] == tuple([2, sl, 3]) sample = rb.sample(3, n_samples=2, sequence_length=sl, clone=True) - assert sample["dones"].shape[:-1] == tuple([2, sl, 3]) + assert sample["terminated"].shape[:-1] == tuple([2, sl, 3]) + assert sample["truncated"].shape[:-1] == tuple([2, sl, 3]) def test_episode_buffer_sample_more_episodes(): buf_size = 100 sl = 15 n_envs = 1 - obs_keys = ("dones", "a") + obs_keys = ("terminated", "truncated", "a") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((40, n_envs, 1)), "a": np.ones((40, n_envs, 1)) * -1} - ep2 = {"dones": np.zeros((45, n_envs, 1)), "a": np.ones((45, n_envs, 1)) * -2} - ep3 = {"dones": np.zeros((50, n_envs, 1)), "a": np.ones((50, n_envs, 1)) * -3} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 + ep1 = { + "terminated": np.zeros((40, n_envs, 1)), + "a": np.ones((40, n_envs, 1)) * -1, + "truncated": np.zeros((40, n_envs, 1)), + } + ep2 = { + "terminated": np.zeros((45, n_envs, 1)), + "a": np.ones((45, n_envs, 1)) * -2, + "truncated": np.zeros((45, n_envs, 1)), + } + ep3 = { + "terminated": np.zeros((50, n_envs, 1)), + "a": np.ones((50, n_envs, 1)) * -3, + "truncated": np.zeros((50, n_envs, 1)), + } + ep1["terminated"][-1] = 1 + ep2["truncated"][-1] = 1 + ep3["terminated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) samples = rb.sample(50, n_samples=5, sequence_length=sl) - assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) + assert samples["terminated"].shape[:-1] == tuple([5, sl, 50]) + assert samples["truncated"].shape[:-1] == tuple([5, sl, 50]) samples = {k: np.moveaxis(samples[k], 2, 1).reshape(-1, sl, 1) for k in obs_keys} - for i in range(len(samples["dones"])): + for i in range(len(samples["terminated"])): assert ( np.isin(samples["a"][i], -1).all() or np.isin(samples["a"][i], -2).all() or np.isin(samples["a"][i], -3).all() ) - assert len(samples["dones"][i].nonzero()[0]) == 0 or samples["dones"][i][-1] == 1 + assert len(samples["terminated"][i].nonzero()[0]) == 0 or samples["terminated"][i][-1] == 1 + assert len(samples["truncated"][i].nonzero()[0]) == 0 or samples["truncated"][i][-1] == 1 def test_episode_buffer_error_sample(): @@ -236,7 +271,7 @@ def test_episode_buffer_error_sample(): rb.sample(-1, n_samples=2) with pytest.raises(ValueError, match="The number of samples must be greater than 0"): rb.sample(2, n_samples=-1) - ep1 = {"dones": np.zeros((15, 1, 1))} + ep1 = {"terminated": np.zeros((15, 1, 1)), "truncated": np.zeros((15, 1, 1))} rb.add(ep1) with pytest.raises(RuntimeError, match="No valid episodes has been added to the buffer"): rb.sample(2, n_samples=2, sequence_length=20) @@ -247,34 +282,38 @@ def test_episode_buffer_prioritize_ends(): buf_size = 100 sl = 15 
n_envs = 1 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, prioritize_ends=True) - ep1 = {"dones": np.zeros((15, n_envs, 1))} - ep2 = {"dones": np.zeros((25, n_envs, 1))} - ep3 = {"dones": np.zeros((30, n_envs, 1))} - ep1["dones"][-1] = 1 - ep2["dones"][-1] = 1 - ep3["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((15, n_envs, 1)), "truncated": np.zeros((15, n_envs, 1))} + ep2 = {"terminated": np.zeros((25, n_envs, 1)), "truncated": np.zeros((25, n_envs, 1))} + ep3 = {"terminated": np.zeros((30, n_envs, 1)), "truncated": np.zeros((30, n_envs, 1))} + ep1["truncated"][-1] = 1 + ep2["terminated"][-1] = 1 + ep3["truncated"][-1] = 1 rb.add(ep1) rb.add(ep2) rb.add(ep3) samples = rb.sample(50, n_samples=5, sequence_length=sl) - assert samples["dones"].shape[:-1] == tuple([5, sl, 50]) - assert np.isin(samples["dones"], 1).any() > 0 + assert samples["terminated"].shape[:-1] == tuple([5, sl, 50]) + assert samples["truncated"].shape[:-1] == tuple([5, sl, 50]) + assert np.isin(samples["terminated"], 1).any() > 0 + assert np.isin(samples["truncated"], 1).any() > 0 def test_sample_next_obs(): buf_size = 10 sl = 5 n_envs = 4 - obs_keys = ("dones",) + obs_keys = ("terminated", "truncated") rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys) - ep1 = {"dones": np.zeros((sl, n_envs, 1))} - ep1["dones"][-1] = 1 + ep1 = {"terminated": np.zeros((sl, n_envs, 1)), "truncated": np.zeros((sl, n_envs, 1))} + ep1["terminated"][-1] = 1 rb.add(ep1) sample = rb.sample(10, True, n_samples=5, sequence_length=sl - 1) - assert "next_dones" in sample - assert (sample["next_dones"][:, -1] == 1).all() + assert "next_terminated" in sample + assert "next_truncated" in sample + assert (sample["next_terminated"][:, -1] == 1).all() + assert not (sample["next_truncated"][:, -1] == 1).any() def test_memmap_episode_buffer(): @@ -282,7 +321,7 @@ def test_memmap_episode_buffer(): bs = 4 sl = 4 n_envs = 1 - obs_keys = ("dones", "observations") + obs_keys = ("terminated", "truncated", "observations") with pytest.raises( ValueError, match="The buffer is set to be memory-mapped but the `memmap_dir` attribute is None", @@ -293,11 +332,13 @@ def test_memmap_episode_buffer(): for _ in range(buf_size // bs): ep = { "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((bs, n_envs, 1)), + "terminated": np.zeros((bs, n_envs, 1)), + "truncated": np.zeros((bs, n_envs, 1)), } - ep["dones"][-1] = 1 + ep["truncated"][-1] = 1 rb.add(ep) - assert isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["terminated"], MemmapArray) + assert isinstance(rb._buf[-1]["truncated"], MemmapArray) assert isinstance(rb._buf[-1]["observations"], MemmapArray) assert rb.is_memmap del rb @@ -309,7 +350,7 @@ def test_memmap_to_file_episode_buffer(): bs = 5 sl = 4 n_envs = 1 - obs_keys = ("dones", "observations") + obs_keys = ("terminated", "truncated", "observations") memmap_dir = "test_episode_buffer" rb = EpisodeBuffer(buf_size, sl, n_envs=n_envs, obs_keys=obs_keys, memmap=True, memmap_dir=memmap_dir) for i in range(4): @@ -319,15 +360,19 @@ def test_memmap_to_file_episode_buffer(): bs = 5 ep = { "observations": np.random.randint(0, 256, (bs, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((bs, n_envs, 1)), + "terminated": np.zeros((bs, n_envs, 1)), + "truncated": np.zeros((bs, n_envs, 1)), } - ep["dones"][-1] = 1 + ep["terminated"][-1] = 1 rb.add(ep) del ep - assert 
isinstance(rb._buf[-1]["dones"], MemmapArray) + assert isinstance(rb._buf[-1]["terminated"], MemmapArray) + assert isinstance(rb._buf[-1]["truncated"], MemmapArray) assert isinstance(rb._buf[-1]["observations"], MemmapArray) - memmap_dir = os.path.dirname(rb._buf[-1]["dones"].filename) - assert os.path.exists(os.path.join(memmap_dir, "dones.memmap")) + memmap_dir = os.path.dirname(rb._buf[-1]["terminated"].filename) + memmap_dir = os.path.dirname(rb._buf[-1]["truncated"].filename) + assert os.path.exists(os.path.join(memmap_dir, "terminated.memmap")) + assert os.path.exists(os.path.join(memmap_dir, "truncated.memmap")) assert os.path.exists(os.path.join(memmap_dir, "observations.memmap")) assert rb.is_memmap for ep in rb.buffer: @@ -342,8 +387,12 @@ def test_sample_tensors(): buf_size = 10 n_envs = 1 rb = EpisodeBuffer(buf_size, n_envs) - td = {"observations": np.arange(8).reshape(-1, 1, 1), "dones": np.zeros((8, 1, 1))} - td["dones"][-1] = 1 + td = { + "observations": np.arange(8).reshape(-1, 1, 1), + "terminated": np.zeros((8, 1, 1)), + "truncated": np.zeros((8, 1, 1)), + } + td["truncated"][-1] = 1 rb.add(td) s = rb.sample_tensors(10, sample_next_obs=True, n_samples=3, sequence_length=5) assert isinstance(s["observations"], torch.Tensor) @@ -363,9 +412,10 @@ def test_sample_tensor_memmap(): rb = EpisodeBuffer(buf_size, n_envs, memmap=True, memmap_dir=memmap_dir, obs_keys=("observations")) td = { "observations": np.random.randint(0, 256, (10, n_envs, 3, 64, 64), dtype=np.uint8), - "dones": np.zeros((buf_size, n_envs, 1)), + "terminated": np.zeros((buf_size, n_envs, 1)), + "truncated": np.zeros((buf_size, n_envs, 1)), } - td["dones"][-1] = 1 + td["terminated"][-1] = 1 rb.add(td) sample = rb.sample_tensors(10, False, n_samples=3, sequence_length=5) assert isinstance(sample["observations"], torch.Tensor) @@ -378,8 +428,14 @@ def test_add_rb(): buf_size = 10 n_envs = 3 rb = ReplayBuffer(buf_size, n_envs) - rb.add({"dones": np.zeros((buf_size, n_envs, 1)), "a": np.random.rand(buf_size, n_envs, 5)}) - rb["dones"][-1] = 1 + rb.add( + { + "terminated": np.zeros((buf_size, n_envs, 1)), + "truncated": np.zeros((buf_size, n_envs, 1)), + "a": np.random.rand(buf_size, n_envs, 5), + } + ) + rb["truncated"][-1] = 1 epb = EpisodeBuffer(buf_size * n_envs, minimum_episode_length=2, n_envs=n_envs) epb.add(rb) assert (rb["a"][:, 0] == epb._buf[0]["a"]).all()
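The wrapper changes above (Crafter, DMC, Diambra, MineDojo, MineRL, Super Mario Bros) all converge on the same convention: `step` returns separate `terminated` and `truncated` flags instead of a single `done`, so a time-limit end of episode can be told apart from a real terminal state. A minimal sketch of that mapping for a legacy 4-tuple environment, assuming the time limit is signalled through the `TimeLimit.truncated` entry of `info` as in the MineDojo wrapper above (other wrappers rely on env-specific signals, e.g. `info["discount"]` for Crafter or `time_step.discount` for DMC); this is an illustration, not code from the repository:

from typing import Dict, Tuple


def split_done(done: bool, info: Dict) -> Tuple[bool, bool]:
    """Map a legacy single `done` flag to (terminated, truncated)."""
    # An episode that ended only because of a time limit is reported as truncated,
    # so the agent is still allowed to bootstrap from the final state.
    is_timelimit = bool(info.get("TimeLimit.truncated", False))
    terminated = done and not is_timelimit
    truncated = done and is_timelimit
    return terminated, truncated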
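On the algorithm side, the SAC and SAC-AE targets now receive `data["terminated"]` instead of the old `data["dones"]`: only true terminations stop the bootstrap, while truncated transitions keep propagating the value of the next state. A sketch of the resulting 1-step target, assuming the usual `(1 - terminated)` masking; the function and tensor names are illustrative, not the repository's API:

import torch


def soft_q_target(
    rewards: torch.Tensor,      # shape [batch, 1]
    next_q: torch.Tensor,       # (entropy-regularized) target value of the next observation
    terminated: torch.Tensor,   # 1.0 only on true terminations, 0.0 on truncations
    gamma: float,
) -> torch.Tensor:
    # Truncated steps have terminated == 0, so they still bootstrap from next_q.
    return rewards + gamma * (1.0 - terminated) * next_q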
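In `EpisodeBuffer`, an episode boundary is now any step where either flag is set, which is why `add` and `_save_episode` build `np.logical_or(terminated, truncated)` before looking for episode ends. A condensed sketch of that check, simplified with respect to the real `add` method:

from typing import Dict, List

import numpy as np


def episode_end_indices(env_data: Dict[str, np.ndarray]) -> List[int]:
    # `env_data` holds per-step arrays of shape [T, 1] for a single environment,
    # with `terminated` and `truncated` keys as in the buffer of this patch.
    done = np.logical_or(env_data["terminated"], env_data["truncated"])
    return done.nonzero()[0].tolist()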
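Finally, the checkpoint callback keeps its old trick but applies it to the `truncated` key: just before the buffer is added to the checkpoint, the last inserted step of every environment is forced to be truncated (so the open episode can be safely cut on resume), and the true flags are restored right afterwards. A hedged sketch of the idea on a plain NumPy array standing in for one buffer column (not the actual `ReplayBuffer` class; shapes and the write cursor are hypothetical):

import numpy as np

# Stand-in for rb["truncated"], shape [buffer_size, n_envs, 1]; `pos` is a hypothetical write cursor.
truncated = np.zeros((8, 2, 1), dtype=np.uint8)
pos = 5

# Before checkpointing: remember the true flags and mark the last written row as truncated.
last_row = (pos - 1) % truncated.shape[0]
saved = truncated[last_row, :].copy()
truncated[last_row, :] = 1

# ... the checkpoint is written here ...

# After checkpointing: put the true flags back so training continues unchanged.
truncated[last_row, :] = saved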