Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/crafter #103

Merged
merged 8 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ The environments supported by sheeprl are:
| MineRL | `pip install -e .[minerl]` | [how_to/minerl](./howto/learn_in_minerl.md) | :heavy_check_mark: |
| MineDojo | `pip install -e .[minedojo]` | [how_to/minedojo](./howto/learn_in_minedojo.md) | :heavy_check_mark: |
| DIAMBRA | `pip install -e .[diambra]` | [how_to/diambra](./howto/learn_in_diambra.md) | :heavy_check_mark: |
| Crafter | `pip install -e .[crafter]` | [crafter](https://github.com/danijar/crafter) | :heavy_check_mark: |


## Why
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ atari = [
minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"]
minerl = ["minerl==0.4.4"]
diambra = ["wheel==0.38.4", "setuptools<=66.0.0", "gym==0.21.0", "diambra==0.0.16", "diambra-arena==2.1.2"]
crafter = ["crafter==1.8.1"]

[tool.ruff]
line-length = 120
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/ppo/ppo_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def player(
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
Expand Down
4 changes: 2 additions & 2 deletions sheeprl/algos/sac_ae/sac_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,8 @@ def main(fabric: Fabric, cfg: DictConfig):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: `cnn_keys.encoder=[rgb]` "
"or `mlp_keys.encoder=[state]` "
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
if (
len(set(cfg.cnn_keys.encoder).intersection(set(cfg.cnn_keys.decoder))) == 0
Expand Down
16 changes: 16 additions & 0 deletions sheeprl/configs/env/crafter.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
defaults:
- default
- _self_

# Override from `default` config
id: reward
action_repeat: 1
capture_video: False
reward_as_observation: True

# Wrapper to be instantiated
wrapper:
_target_: sheeprl.envs.crafter.CrafterWrapper
id: ${env.id}
screen_size: ${env.screen_size}
seed: ${seed}
50 changes: 50 additions & 0 deletions sheeprl/configs/exp/dreamer_v3_XL_crafter.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# @package _global_

defaults:
- dreamer_v3
- override /env: crafter
- _self_

# Experiment
seed: 5
total_steps: 1000000

# Environment
env:
num_envs: 1
id: reward

# Checkpoint
checkpoint:
every: 100000

# Buffer
buffer:
checkpoint: True

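# By default the decoder keys mirror the encoder keys; here they are set explicitly so that
# `rgb` is both encoded and decoded, while `reward` is encoded only (`mlp_keys.decoder` is empty)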
cnn_keys:
encoder:
- rgb
decoder:
- rgb
mlp_keys:
encoder:
- reward
decoder: []

# Algorithm
algo:
train_every: 2
learning_starts: 1024
dense_units: 1024
mlp_layers: 5
world_model:
encoder:
cnn_channels_multiplier: 96
recurrent_model:
recurrent_state_size: 4096
transition_model:
hidden_size: 1024
representation_model:
hidden_size: 1024
60 changes: 60 additions & 0 deletions sheeprl/envs/crafter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from sheeprl.utils.imports import _IS_CRAFTER_AVAILABLE

# Fail at import time when the optional `crafter` dependency is not available,
# before the `import crafter` statement below is reached. The availability
# object itself is passed as the exception argument.
if not _IS_CRAFTER_AVAILABLE:
    raise ModuleNotFoundError(_IS_CRAFTER_AVAILABLE)

from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union

import crafter
import numpy as np
from gymnasium import core, spaces
from gymnasium.core import RenderFrame


class CrafterWrapper(core.Env):
    """Expose a ``crafter.Env`` through the gymnasium ``Env`` interface.

    The wrapped environment's array observation is re-published as a dictionary
    observation with the single key ``"rgb"``, and its 4-tuple ``step`` result is
    converted to the 5-tuple ``(obs, reward, terminated, truncated, info)``.

    Args:
        id: which Crafter variant to build, either ``"reward"`` or ``"nonreward"``.
            It is forwarded to ``crafter.Env`` as ``reward=(id == "reward")``.
        screen_size: the ``size`` passed to ``crafter.Env``. A single ``int`` is
            expanded to a ``(size, size)`` tuple.
        seed: seed forwarded to ``crafter.Env`` and used to seed the observation
            and action spaces.

    Raises:
        ValueError: if ``id`` is neither ``"reward"`` nor ``"nonreward"``.
    """

    def __init__(self, id: str, screen_size: Union[int, Tuple[int, int]] = 64, seed: Optional[int] = None) -> None:
        # Validate explicitly rather than with `assert`, which is removed when
        # Python runs with `-O` and would let an invalid id through silently.
        if id not in {"reward", "nonreward"}:
            raise ValueError(f"Unknown Crafter id, expected 'reward' or 'nonreward', got: {id!r}")
        if isinstance(screen_size, int):
            screen_size = (screen_size,) * 2

        self._env = crafter.Env(size=screen_size, seed=seed, reward=(id == "reward"))
        # Mirror the wrapped observation space under the "rgb" key.
        self.observation_space = spaces.Dict(
            {
                "rgb": spaces.Box(
                    self._env.observation_space.low,
                    self._env.observation_space.high,
                    self._env.observation_space.shape,
                    self._env.observation_space.dtype,
                )
            }
        )
        self.action_space = spaces.Discrete(self._env.action_space.n)
        # Fall back to an unbounded range when the wrapped env reports a falsy
        # (e.g. `None`) reward range.
        self.reward_range = self._env.reward_range or (-np.inf, np.inf)
        self.observation_space.seed(seed)
        self.action_space.seed(seed)

        # render
        self._render_mode: str = "rgb_array"

    @property
    def render_mode(self) -> str:
        """The render mode of the wrapper, fixed to ``"rgb_array"`` at construction."""
        return self._render_mode

    def _convert_obs(self, obs: np.ndarray) -> Dict[str, np.ndarray]:
        """Wrap the raw observation array into the ``{"rgb": obs}`` dictionary."""
        return {"rgb": obs}

    def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
        """Step the wrapped environment and return a gymnasium-style 5-tuple.

        The wrapped ``done`` flag is returned as ``terminated``; ``truncated`` is
        always ``False``.
        """
        # NOTE(review): `done` may also cover an episode length limit of the
        # wrapped env -- confirm whether that case should be reported as truncation.
        obs, reward, done, info = self._env.step(action)
        return self._convert_obs(obs), reward, done, False, info

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
    ) -> Tuple[Any, Dict[str, Any]]:
        """Reset the wrapped environment and return ``(observation, info)``.

        ``seed`` and ``options`` are accepted for interface compatibility but are
        not used: the wrapped environment is reset without arguments and the
        returned ``info`` dictionary is always empty.
        """
        obs = self._env.reset()
        return self._convert_obs(obs), {}

    def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
        """Return whatever the wrapped environment's ``render()`` produces."""
        return self._env.render()

    def close(self) -> None:
        """Delegate to the base ``Env.close``; the wrapped env is not closed explicitly."""
        return super().close()
4 changes: 3 additions & 1 deletion sheeprl/envs/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,9 @@ class RewardAsObservationWrapper(gym.Wrapper):
def __init__(self, env: Env) -> None:
super().__init__(env)
self._env = env
reward_range = self._env.reward_range if hasattr(self._env, "reward_range") else (-np.inf, np.inf)
reward_range = (
self._env.reward_range or (-np.inf, np.inf) if hasattr(self._env, "reward_range") else (-np.inf, np.inf)
)
# The reward is assumed to be a scalar
if isinstance(self._env.observation_space, gym.spaces.Dict):
self.observation_space = gym.spaces.Dict(
Expand Down
1 change: 1 addition & 0 deletions sheeprl/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

_IS_ATARI_AVAILABLE = RequirementCache("gymnasium[atari]")
_IS_ATARI_ROMS_AVAILABLE = RequirementCache("gymnasium[accept-rom-license]")
_IS_CRAFTER_AVAILABLE = RequirementCache("crafter")
_IS_DIAMBRA_AVAILABLE = RequirementCache("diambra")
_IS_DIAMBRA_ARENA_AVAILABLE = RequirementCache("diambra-arena")
_IS_DMC_AVAILABLE = RequirementCache("dm_control")
Expand Down