Fix/bernoulli #186

Merged · 4 commits · Jan 12, 2024
sheeprl/algos/dreamer_v3/dreamer_v3.py (7 changes: 4 additions & 3 deletions)

```diff
@@ -17,7 +17,7 @@
 from lightning.fabric import Fabric
 from lightning.fabric.wrappers import _FabricModule
 from torch import Tensor
-from torch.distributions import Bernoulli, Distribution, Independent
+from torch.distributions import Distribution, Independent
 from torch.optim import Optimizer
 from torchmetrics import SumMetric
@@ -27,6 +27,7 @@
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.envs.wrappers import RestartOnException
 from sheeprl.utils.distribution import (
+    BernoulliSafeMode,
     MSEDistribution,
     OneHotCategoricalValidateArgs,
     SymlogDistribution,
@@ -145,7 +146,7 @@ def train(
 
     # Compute the distribution over the terminal steps, if required
     pc = Independent(
-        Bernoulli(logits=world_model.continue_model(latent_states), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(latent_states), validate_args=validate_args),
         1,
         validate_args=validate_args,
     )
@@ -229,7 +230,7 @@ def train(
     predicted_values = TwoHotEncodingDistribution(critic(imagined_trajectories), dims=1).mean
     predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
```
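Why the swap matters here: torch's stock `Bernoulli` defines its `mode` as NaN wherever `probs == 0.5`, and a continue head producing zero logits sits exactly at that tie (sigmoid(0) = 0.5), so calling `.mode` on the continues distribution can feed NaNs into the imagined rollouts. A minimal repro of the stock behavior, plain PyTorch only (illustrative, not part of the PR):

```python
import torch
from torch.distributions import Bernoulli

# Stock torch Bernoulli: mode is (probs >= 0.5), except that ties at
# exactly 0.5 are mapped to NaN, which then propagates downstream.
probs = torch.tensor([0.3, 0.5, 0.9])
print(Bernoulli(probs=probs).mode)  # tensor([0., nan, 1.])
```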
sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py (9 changes: 5 additions & 4 deletions)

```diff
@@ -12,7 +12,7 @@
 from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
 from omegaconf import DictConfig
 from torch import Tensor, nn
-from torch.distributions import Bernoulli, Distribution, Independent
+from torch.distributions import Distribution, Independent
 from torchmetrics import SumMetric
 
 from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel
@@ -21,6 +21,7 @@
 from sheeprl.algos.p2e_dv3.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.distribution import (
+    BernoulliSafeMode,
     MSEDistribution,
     OneHotCategoricalValidateArgs,
     SymlogDistribution,
@@ -161,7 +162,7 @@ def train(
 
     # Compute the distribution over the terminal steps, if required
     pc = Independent(
-        Bernoulli(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(latent_states.detach()), validate_args=validate_args),
         1,
         validate_args=validate_args,
     )
@@ -268,7 +269,7 @@ def train(
     # Predict values and continues
     predicted_values = TwoHotEncodingDistribution(critic["module"](imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
@@ -412,7 +413,7 @@ def train(
     predicted_values = TwoHotEncodingDistribution(critic_task(imagined_trajectories), dims=1).mean
     predicted_rewards = TwoHotEncodingDistribution(world_model.reward_model(imagined_trajectories), dims=1).mean
     continues = Independent(
-        Bernoulli(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
+        BernoulliSafeMode(logits=world_model.continue_model(imagined_trajectories), validate_args=validate_args),
         1,
         validate_args=validate_args,
     ).mode
```
sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml (11 changes: 5 additions & 6 deletions)

```diff
@@ -10,11 +10,11 @@ seed: 5
 
 # Environment
 env:
-  num_envs: 1
-  max_episode_steps: 1000
+  num_envs: 4
+  max_episode_steps: -1
   id: walker_walk
   wrapper:
-    from_vectors: True
+    from_vectors: False
   from_pixels: True
 
@@ -34,9 +34,8 @@ algo:
     encoder:
       - rgb
   mlp_keys:
-    encoder:
-      - state
-  learning_starts: 8000
+    encoder: []
+  learning_starts: 1024
   train_every: 2
   dense_units: 512
   mlp_layers: 2
```
sheeprl/utils/distribution.py (12 changes: 11 additions & 1 deletion)

```diff
@@ -7,7 +7,7 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from torch.distributions import Categorical, Distribution, constraints
+from torch.distributions import Bernoulli, Categorical, Distribution, constraints
 from torch.distributions.kl import _kl_categorical_categorical, register_kl
 from torch.distributions.utils import broadcast_all
 
@@ -402,3 +402,13 @@ def rsample(self, sample_shape=torch.Size()):
 @register_kl(OneHotCategoricalValidateArgs, OneHotCategoricalValidateArgs)
 def _kl_onehotcategoricalvalidateargs_onehotcategoricalvalidateargs(p, q):
     return _kl_categorical_categorical(p._categorical, q._categorical)
+
+
+class BernoulliSafeMode(Bernoulli):
+    def __init__(self, probs=None, logits=None, validate_args=None):
+        super().__init__(probs, logits, validate_args)
+
+    @property
+    def mode(self):
+        mode = (self.probs > 0.5).to(self.probs)
+        return mode
```
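A quick usage sketch of the new class (illustrative only; the zero logits below stand in for a continue model's output and are not sheeprl code):

```python
import torch
from torch.distributions import Independent

from sheeprl.utils.distribution import BernoulliSafeMode

# Zero logits -> probs exactly 0.5, the case that produced NaNs before.
logits = torch.zeros(16, 64, 1)  # e.g. (sequence, batch, 1) continue logits
continues = Independent(BernoulliSafeMode(logits=logits), 1).mode
print(torch.isnan(continues).any())  # tensor(False)
print(continues.unique())            # tensor([0.]): ties at 0.5 resolve to 0
```

The only behavioral change relative to torch's `Bernoulli` is the strict `>` in `mode`: ties at p = 0.5 now resolve deterministically to 0 instead of NaN, while sampling, `log_prob`, and entropy are inherited unchanged.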