Skip to content

Commit

Permalink
Add plugin system for third-party environments (openai#2383)
Browse files Browse the repository at this point in the history
  • Loading branch information
JesseFarebro authored Sep 15, 2021
1 parent 590db96 commit e212043
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
12 changes: 11 additions & 1 deletion gym/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from gym.envs.registration import registry, register, make, spec
from gym.envs.registration import (
registry,
register,
make,
spec,
load_plugins as _load_plugins,
)

# Hook to load plugins from entry points
_load_plugins()


# Classic
# ----------------------------------------
Expand Down
59 changes: 59 additions & 0 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import re
import sys
import copy
import importlib

from contextlib import contextmanager

if sys.version_info < (3, 8):
import importlib_metadata as metadata
else:
import importlib.metadata as metadata

from gym import error, logger

# This format is true today, but it's *not* an official spec.
Expand All @@ -11,6 +19,9 @@
# to include an optional username.
env_id_re = re.compile(r"^(?:[\w:-]+\/)?([\w:.-]+)-v(\d+)$")

# Whitelist of plugins which can hook into the `gym.envs.internal` entry point.
plugin_internal_whitelist = {"ale_py.gym"}


def load(name):
mod_name, attr_name = name.split(":")
Expand Down Expand Up @@ -95,6 +106,7 @@ class EnvRegistry(object):

def __init__(self):
self.env_specs = {}
self._ns = None

def make(self, path, **kwargs):
if len(kwargs) > 0:
Expand Down Expand Up @@ -183,10 +195,25 @@ def spec(self, path):
raise error.UnregisteredEnv("No registered env with id: {}".format(id))

def register(self, id, **kwargs):
if self._ns is not None:
if "/" in id:
namespace, id = id.split("/")
logger.warn(
f"Custom namespace '{namespace}' is being overrode by namespace '{self._ns}'. "
"If you are developing a plugin you shouldn't specify a namespace in `register` calls. "
"The namespace is specified through the entry point key."
)
id = f"{self._ns}/{id}"
if id in self.env_specs:
logger.warn("Overriding environment {}".format(id))
self.env_specs[id] = EnvSpec(id, **kwargs)

@contextmanager
def namespace(self, ns):
self._ns = ns
yield
self._ns = None


# Have a global registry
registry = EnvRegistry()
Expand All @@ -202,3 +229,35 @@ def make(id, **kwargs):

def spec(id):
return registry.spec(id)


@contextmanager
def namespace(ns):
with registry.namespace(ns):
yield


def load_plugins(
third_party_entry_point="gym.envs", internal_entry_point="gym.envs.internal"
):
# Load third-party environments
for external in metadata.entry_points().get(third_party_entry_point, []):
if external.attr is not None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
)
# Force namespace on all `register` calls for third-party envs
with namespace(external.name):
external.load()

# Load plugins which hook into `gym.envs.internal`
# These plugins must be in the whitelist defined at the top of this file
# We don't force a namespace on register calls in this module
for internal in metadata.entry_points().get(internal_entry_point, []):
if internal.module not in plugin_internal_whitelist:
continue
if external.attr is not None:
raise error.Error(
"Gym environment plugins must specify a root module to load, not a function"
)
internal.load()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
install_requires=[
"numpy>=1.18.0",
"cloudpickle>=1.2.0",
"importlib_metadata>=4.8.1; python_version < '3.8'",
],
extras_require=extras,
package_data={
Expand Down

0 comments on commit e212043

Please sign in to comment.