diff --git a/gym/envs/__init__.py b/gym/envs/__init__.py index d0e93cceb24..b2e9ee118af 100644 --- a/gym/envs/__init__.py +++ b/gym/envs/__init__.py @@ -1,4 +1,14 @@ -from gym.envs.registration import registry, register, make, spec +from gym.envs.registration import ( + registry, + register, + make, + spec, + load_plugins as _load_plugins, +) + +# Hook to load plugins from entry points +_load_plugins() + # Classic # ---------------------------------------- diff --git a/gym/envs/registration.py b/gym/envs/registration.py index b7ba24337d8..a29bfd4f07b 100644 --- a/gym/envs/registration.py +++ b/gym/envs/registration.py @@ -1,7 +1,15 @@ import re +import sys import copy import importlib +from contextlib import contextmanager + +if sys.version_info < (3, 8): + import importlib_metadata as metadata +else: + import importlib.metadata as metadata + from gym import error, logger # This format is true today, but it's *not* an official spec. @@ -11,6 +19,9 @@ # to include an optional username. env_id_re = re.compile(r"^(?:[\w:-]+\/)?([\w:.-]+)-v(\d+)$") +# Whitelist of plugins which can hook into the `gym.envs.internal` entry point. +plugin_internal_whitelist = {"ale_py.gym"} + def load(name): mod_name, attr_name = name.split(":") @@ -95,6 +106,7 @@ class EnvRegistry(object): def __init__(self): self.env_specs = {} + self._ns = None def make(self, path, **kwargs): if len(kwargs) > 0: @@ -183,10 +195,25 @@ def spec(self, path): raise error.UnregisteredEnv("No registered env with id: {}".format(id)) def register(self, id, **kwargs): + if self._ns is not None: + if "/" in id: + namespace, id = id.split("/") + logger.warn( + f"Custom namespace '{namespace}' is being overrode by namespace '{self._ns}'. " + "If you are developing a plugin you shouldn't specify a namespace in `register` calls. " + "The namespace is specified through the entry point key." + ) + id = f"{self._ns}/{id}" if id in self.env_specs: logger.warn("Overriding environment {}".format(id)) self.env_specs[id] = EnvSpec(id, **kwargs) + @contextmanager + def namespace(self, ns): + self._ns = ns + yield + self._ns = None + # Have a global registry registry = EnvRegistry() @@ -202,3 +229,35 @@ def make(id, **kwargs): def spec(id): return registry.spec(id) + + +@contextmanager +def namespace(ns): + with registry.namespace(ns): + yield + + +def load_plugins( + third_party_entry_point="gym.envs", internal_entry_point="gym.envs.internal" +): + # Load third-party environments + for external in metadata.entry_points().get(third_party_entry_point, []): + if external.attr is not None: + raise error.Error( + "Gym environment plugins must specify a root module to load, not a function" + ) + # Force namespace on all `register` calls for third-party envs + with namespace(external.name): + external.load() + + # Load plugins which hook into `gym.envs.internal` + # These plugins must be in the whitelist defined at the top of this file + # We don't force a namespace on register calls in this module + for internal in metadata.entry_points().get(internal_entry_point, []): + if internal.module not in plugin_internal_whitelist: + continue + if external.attr is not None: + raise error.Error( + "Gym environment plugins must specify a root module to load, not a function" + ) + internal.load() diff --git a/setup.py b/setup.py index da98e4f627f..4e1aa09a3b8 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ install_requires=[ "numpy>=1.18.0", "cloudpickle>=1.2.0", + "importlib_metadata>=4.8.1; python_version < '3.8'", ], extras_require=extras, package_data={