Skip to content

Commit

Permalink
Env plugins, remove gym.envs.internal and replace with __internal__ k…
Browse files Browse the repository at this point in the history
…ey (openai#2409)
  • Loading branch information
JesseFarebro authored Sep 16, 2021
1 parent 2754d97 commit 5a4709b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
4 changes: 2 additions & 2 deletions gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
register,
make,
spec,
load_plugins as _load_plugins,
load_env_plugins as _load_env_plugins,
)

# Hook to load plugins from entry points
_load_plugins()
_load_env_plugins()


# Classic
Expand Down
48 changes: 23 additions & 25 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import sys
import copy
import importlib

from contextlib import contextmanager
import contextlib

if sys.version_info < (3, 8):
import importlib_metadata as metadata
Expand Down Expand Up @@ -217,7 +216,7 @@ def register(self, id, **kwargs):
logger.warn("Overriding environment {}".format(id))
self.env_specs[id] = EnvSpec(id, **kwargs)

@contextmanager
@contextlib.contextmanager
def namespace(self, ns):
self._ns = ns
yield
Expand All @@ -240,33 +239,32 @@ def spec(id):
return registry.spec(id)


@contextmanager
@contextlib.contextmanager
def namespace(ns):
with registry.namespace(ns):
yield


def load_plugins(
third_party_entry_point="gym.envs", internal_entry_point="gym.envs.internal"
):
def load_env_plugins(entry_point="gym.envs"):
# Load third-party environments
for external in metadata.entry_points().get(third_party_entry_point, []):
if external.attr is not None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
)
# Force namespace on all `register` calls for third-party envs
with namespace(external.name):
external.load()

# Load plugins which hook into `gym.envs.internal`
# These plugins must be in the whitelist defined at the top of this file
# We don't force a namespace on register calls in this module
for internal in metadata.entry_points().get(internal_entry_point, []):
if internal.module not in plugin_internal_whitelist:
continue
if external.attr is not None:
for plugin in metadata.entry_points().get(entry_point, []):
if plugin.attr is None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
f"Gym environment plugin `{plugin.module}` must specify a function to execute, not a root module"
)
internal.load()

context = namespace(plugin.name)
if plugin.name == "__internal__":
if plugin.module in plugin_internal_whitelist:
context = contextlib.nullcontext()
else:
logger.warn(
f"Trying to register an internal environment when `{plugin.module}` is not in the whitelist"
)

with context:
fn = plugin.load()
try:
fn()
except Exception as e:
logger.warn(str(e))

0 comments on commit 5a4709b

Please sign in to comment.