From 2c9c0b39814d62bf213d205b2d14707dd595d10b Mon Sep 17 00:00:00 2001
From: Federico Belotti
Date: Fri, 12 Jan 2024 09:35:33 +0100
Subject: [PATCH] Fix/bernoulli (#186)

* TF-like Bernoulli mode

* pre-commit

* Default dmc config
---
 sheeprl/algos/dreamer_v3/dreamer_v3.py              |  7 ++++---
 sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py        |  9 +++++----
 sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml | 11 +++++------
 sheeprl/utils/distribution.py                       | 12 +++++++++++-
 4 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index 77af7974..964bb98d 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -17,7 +17,7 @@
 from lightning.fabric import Fabric
 from lightning.fabric.wrappers import _FabricModule
 from torch import Tensor
-from torch.distributions import Bernoulli, Distribution, Independent
+from torch.distributions import Distribution, Independent
 from torch.optim import Optimizer
 from torchmetrics import SumMetric
 
@@ -27,6 +27,7 @@
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.envs.wrappers import RestartOnException
 from sheeprl.utils.distribution import (
+    BernoulliSafeMode,
     MSEDistribution,
     OneHotCategoricalValidateArgs,
     SymlogDistribution,
@@ -145,7 +146,7 @@ def train(
 
     # Compute the distribution over the terminal steps, if required
     pc = Independent(
-        Bernoulli(logits=world_model.continue_model(latent_states), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(latent_states), validate_args=validate_args),
         1,
         validate_args=validate_args,
     )
@@ -229,7 +230,7 @@ def train(
     predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean
     predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
index 3cc37d17..3a20fd6b 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -12,7 +12,7 @@
 from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
 from omegaconf import DictConfig
 from torch import Tensor, nn
-from torch.distributions import Bernoulli, Distribution, Independent
+from torch.distributions import Distribution, Independent
 from torchmetrics import SumMetric
 
 from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel
@@ -21,6 +21,7 @@
 from sheeprl.algos.p2e_dv3.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.distribution import (
+    BernoulliSafeMode,
     MSEDistribution,
     OneHotCategoricalValidateArgs,
     SymlogDistribution,
@@ -161,7 +162,7 @@ def train(
 
     # Compute the distribution over the terminal steps, if required
     pc = Independent(
-        Bernoulli(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
         1,
         validate_args=validate_args,
     )
@@ -268,7 +269,7 @@ def train(
     # Predict values and continues
     predicted_values = TwoHotEncodingDistribution(critic["module"](imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
@@ -412,7 +413,7 @@ def train(
     predicted_values = TwoHotEncodingDistribution(critic_task(imagined_trajectories), dims=1).mean
     predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
diff --git a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
index b97262b9..89347c81 100644
--- a/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
+++ b/sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
@@ -10,11 +10,11 @@ seed: 5
 
 # Environment
 env:
-  num_envs: 1
-  max_episode_steps: 1000
+  num_envs: 4
+  max_episode_steps: -1
   id: walker_walk
   wrapper:
-    from_vectors: True
+    from_vectors: False
     from_pixels: True
 
 # Checkpoint
@@ -34,9 +34,8 @@ algo:
     encoder:
       - rgb
   mlp_keys:
-    encoder:
-      - state
-  learning_starts: 8000
+    encoder: []
+  learning_starts: 1024
   train_every: 2
   dense_units: 512
   mlp_layers: 2
diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py
index cfc1bcc5..31765bb6 100644
--- a/sheeprl/utils/distribution.py
+++ b/sheeprl/utils/distribution.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.distributions import Categorical, Distribution, constraints
+from torch.distributions import Bernoulli, Categorical, Distribution, constraints
 from torch.distributions.kl import _kl_categorical_categorical, register_kl
 from torch.distributions.utils import broadcast_all
 
@@ -402,3 +402,13 @@ def rsample(self, sample_shape=torch.Size()):
 @register_kl(OneHotCategoricalValidateArgs, OneHotCategoricalValidateArgs)
 def _kl_onehotcategoricalvalidateargs_onehotcategoricalvalidateargs(p, q):
     return _kl_categorical_categorical(p._categorical, q._categorical)
+
+
+class BernoulliSafeMode(Bernoulli):
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        super().__init__(probs, logits, validate_args)
+
+    @property
+    def mode(self):
+        mode = (self.probs > 0.5).to(self.probs)
+        return mode
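
Note on the motivation behind BernoulliSafeMode (explanatory sketch, not part of the patch itself): PyTorch's torch.distributions.Bernoulli defines mode as NaN wherever probs == 0.5, so taking .mode on the continue model's output could inject NaNs into the "continues" tensor that discounts imagined trajectories. The TF-like mode from the first commit bullet instead applies a hard threshold at 0.5 and never produces NaN. A minimal sketch of the difference, assuming sheeprl is installed so the new class is importable:

    import torch
    from torch.distributions import Bernoulli

    from sheeprl.utils.distribution import BernoulliSafeMode

    probs = torch.tensor([0.2, 0.5, 0.9])

    # Stock PyTorch Bernoulli: the mode is NaN at exactly probs == 0.5.
    print(Bernoulli(probs=probs).mode)          # tensor([0., nan, 1.])

    # Patched distribution: hard threshold at 0.5, so 0.5 maps to 0.0.
    print(BernoulliSafeMode(probs=probs).mode)  # tensor([0., 0., 1.])

Since only the mode property is overridden, sampling, log_prob, and entropy keep the stock Bernoulli behavior.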