Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide old default kwargs to Atari environments #2405

Merged
merged 1 commit into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,19 @@ def _merge(a, b):
# mark it as nondeterministic.
nondeterministic = True

default_kwargs = {
"game": game,
"obs_type": obs_type,
"repeat_action_probability": 0.0,
"full_action_space": False,
"frameskip": (2, 5),
}

register(
id="{}-v0".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={
"game": game,
"obs_type": obs_type,
**default_kwargs,
"repeat_action_probability": 0.25,
},
max_episode_steps=10000,
Expand All @@ -706,7 +713,7 @@ def _merge(a, b):
register(
id="{}-v4".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={"game": game, "obs_type": obs_type},
kwargs={**default_kwargs},
max_episode_steps=100000,
nondeterministic=nondeterministic,
)
Expand All @@ -722,8 +729,7 @@ def _merge(a, b):
id="{}Deterministic-v0".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={
"game": game,
"obs_type": obs_type,
**default_kwargs,
"frameskip": frameskip,
"repeat_action_probability": 0.25,
},
Expand All @@ -734,7 +740,7 @@ def _merge(a, b):
register(
id="{}Deterministic-v4".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={"game": game, "obs_type": obs_type, "frameskip": frameskip},
kwargs={**default_kwargs, "frameskip": frameskip},
max_episode_steps=100000,
nondeterministic=nondeterministic,
)
Expand All @@ -743,8 +749,7 @@ def _merge(a, b):
id="{}NoFrameskip-v0".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={
"game": game,
"obs_type": obs_type,
**default_kwargs,
"frameskip": 1,
"repeat_action_probability": 0.25,
}, # A frameskip of 1 means we get every frame
Expand All @@ -758,8 +763,7 @@ def _merge(a, b):
id="{}NoFrameskip-v4".format(name),
entry_point="ale_py.gym:ALGymEnv",
kwargs={
"game": game,
"obs_type": obs_type,
**default_kwargs,
"frameskip": 1,
}, # A frameskip of 1 means we get every frame
max_episode_steps=frameskip * 100000,
Expand Down
132 changes: 132 additions & 0 deletions gym/envs/tests/test_atari_env_specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from gym.envs.registration import registry

from itertools import product


def test_ale_legacy_env_specs():
versions = ["-v0", "-v4"]
suffixes = ["", "NoFrameskip", "Deterministic"]
obs_types = ["", "-ram"]
games = [
"adventure",
"air_raid",
"alien",
"amidar",
"assault",
"asterix",
"asteroids",
"atlantis",
"bank_heist",
"battle_zone",
"beam_rider",
"berzerk",
"bowling",
"boxing",
"breakout",
"carnival",
"centipede",
"chopper_command",
"crazy_climber",
"defender",
"demon_attack",
"double_dunk",
"elevator_action",
"enduro",
"fishing_derby",
"freeway",
"frostbite",
"gopher",
"gravitar",
"hero",
"ice_hockey",
"jamesbond",
"journey_escape",
"kangaroo",
"krull",
"kung_fu_master",
"montezuma_revenge",
"ms_pacman",
"name_this_game",
"phoenix",
"pitfall",
"pong",
"pooyan",
"private_eye",
"qbert",
"riverraid",
"road_runner",
"robotank",
"seaquest",
"skiing",
"solaris",
"space_invaders",
"star_gunner",
"tennis",
"time_pilot",
"tutankham",
"up_n_down",
"venture",
"video_pinball",
"wizard_of_wor",
"yars_revenge",
"zaxxon",
]

# Convert snake case to camel case
games = list(map(lambda x: x.title().replace("_", ""), games))
specs = list(map("".join, product(games, obs_types, suffixes, versions)))

"""
defaults:
repeat_action_probability = 0.0
full_action_space = False
frameskip = (2, 5)
game = "Pong"
obs_type = "ram"
mode = None
difficulty = None

v0: repeat_action_probability = 0.25
v4: inherits defaults

-NoFrameskip: frameskip = 1
-Deterministic: frameskip = 4 or 3 for space_invaders
"""
for spec in specs:
assert spec in registry.env_specs
kwargs = registry.env_specs[spec]._kwargs

# Assert necessary parameters are set
assert "frameskip" in kwargs
assert "game" in kwargs
assert "obs_type" in kwargs
assert "repeat_action_probability" in kwargs
assert "full_action_space" in kwargs

# Common defaults
assert kwargs["full_action_space"] is False
assert "mode" not in kwargs
assert "difficulty" not in kwargs

if "-ram" in spec:
assert kwargs["obs_type"] == "ram"
else:
assert kwargs["obs_type"] == "rgb"

if "NoFrameskip" in spec:
assert kwargs["frameskip"] == 1
elif "Deterministic" in spec:
assert isinstance(kwargs["frameskip"], int)
frameskip = 3 if "SpaceInvaders" in spec else 4
assert kwargs["frameskip"] == frameskip
else:
assert isinstance(kwargs["frameskip"], tuple) and kwargs["frameskip"] == (
2,
5,
)

assert spec.endswith("v0") or spec.endswith("v4")
if spec.endswith("v0"):
assert kwargs["repeat_action_probability"] == 0.25
elif spec.endswith("v4"):
assert kwargs["repeat_action_probability"] == 0.0