[RLlib] Deprecate get_algorithm_class(). #30053

Merged: 8 commits, Nov 9, 2022
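
This PR retires RLlib's private `get_algorithm_class()` lookup in favor of Tune's public registry. A minimal before/after sketch of the migration (the algorithm name "PPO" is just an example):

    # Old (now deprecated):
    # from ray.rllib.algorithms.registry import get_algorithm_class
    # cls, config = get_algorithm_class("PPO", return_config=True)

    # New:
    from ray.tune.registry import get_trainable_cls

    cls = get_trainable_cls("PPO")       # the registered Algorithm subclass
    config = cls.get_default_config()    # replaces return_config=True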
13 changes: 7 additions & 6 deletions rllib/__init__.py
@@ -30,15 +30,16 @@ def _setup_logger():

 def _register_all():
     from ray.rllib.algorithms.algorithm import Algorithm
-    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
+    from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
     from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS

-    for key in (
-        list(ALGORITHMS.keys())
-        + list(CONTRIBUTED_ALGORITHMS.keys())
-        + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
+    for key, get_trainable_class_and_config in list(ALGORITHMS.items()) + list(
+        CONTRIBUTED_ALGORITHMS.items()
     ):
-        register_trainable(key, get_algorithm_class(key))
+        register_trainable(key, get_trainable_class_and_config()[0])
+
+    for key in ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
+        register_trainable(key, _get_algorithm_class(key))

     def _see_contrib(name):
         """Returns dummy agent class warning algo is in contrib/."""
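
The loop above is what makes string lookups work: every built-in algorithm name is registered with Tune's trainable registry when RLlib is imported. A small sketch of the same mechanism for a user-defined algorithm (the `MyPPO` class and its name are hypothetical):

    from ray.rllib.algorithms.ppo import PPO
    from ray.tune.registry import register_trainable, get_trainable_cls


    class MyPPO(PPO):
        """Hypothetical PPO subclass, used only to illustrate registration."""
        pass


    register_trainable("MyPPO", MyPPO)   # same call _register_all() uses per algo
    print(get_trainable_cls("MyPPO"))    # the name now resolves to the class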
2 changes: 1 addition & 1 deletion rllib/algorithms/algorithm.py
@@ -295,7 +295,7 @@ def from_state(state: Dict) -> "Algorithm":
                 "No `algorithm_class` key was found in given `state`! "
                 "Cannot create new Algorithm."
             )
-        # algo_class = get_algorithm_class(algo_class_name)
+        # algo_class = get_trainable_cls(algo_class_name)
         # Create the new algo.
         config = state.get("config")
         if not config:
4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -9,7 +9,6 @@
 from ray.tune.result import TRIAL_INFO
 from ray.util import log_once
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
@@ -40,6 +39,7 @@
     SampleBatchType,
 )
 from ray.tune.logger import Logger
+from ray.tune.registry import get_trainable_cls

 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
@@ -691,7 +691,7 @@ class directly. Note that this arg can also be specified via

         algo_class = self.algo_class
         if isinstance(self.algo_class, str):
-            algo_class = get_algorithm_class(self.algo_class)
+            algo_class = get_trainable_cls(self.algo_class)

         return algo_class(
             config=self if not use_copy else copy.deepcopy(self),
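
`AlgorithmConfig.build()` now resolves a string `algo_class` through Tune's registry instead of RLlib's deprecated helper. A sketch mirroring the updated logic (the name "PPO" and the env are illustrative):

    from ray.tune.registry import get_trainable_cls

    algo_class = "PPO"  # e.g., an algo_class stored as a registered name
    if isinstance(algo_class, str):
        algo_class = get_trainable_cls(algo_class)

    algo = algo_class(config=algo_class.get_default_config().environment("CartPole-v1"))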
29 changes: 15 additions & 14 deletions rllib/algorithms/registry.py
@@ -6,6 +6,7 @@
 from typing import Tuple, Type, TYPE_CHECKING, Union

 from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
+from ray.rllib.utils.deprecation import Deprecated

 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
@@ -247,6 +248,11 @@ def _import_td3():
 }


+@Deprecated(
+    new="ray.tune.registry.get_trainable_cls([algo name], return_config=False) and cls="
+    "ray.tune.registry.get_trainable_cls([algo name]); cls.get_default_config();",
+    error=False,
+)
 def get_algorithm_class(
     alg: str,
     return_config=False,
@@ -269,37 +275,32 @@ def get_algorithm_class(
 get_trainer_class = get_algorithm_class


-def _get_algorithm_class(alg: str, return_config=False) -> type:
+def _get_algorithm_class(alg: str) -> type:
     # This helps us get around a circular import (tune calls rllib._register_all when
     # checking if a rllib Trainable is registered)
     if alg in ALGORITHMS:
-        class_, config = ALGORITHMS[alg]()
+        return ALGORITHMS[alg]()[0]
     elif alg in CONTRIBUTED_ALGORITHMS:
-        class_, config = CONTRIBUTED_ALGORITHMS[alg]()
+        return CONTRIBUTED_ALGORITHMS[alg]()[0]
     elif alg == "script":
         from ray.tune import script_runner

-        class_, config = script_runner.ScriptRunner, {}
+        return script_runner.ScriptRunner
     elif alg == "__fake":
         from ray.rllib.algorithms.mock import _MockTrainer

-        class_, config = _MockTrainer, _MockTrainer.get_default_config()
+        return _MockTrainer
     elif alg == "__sigmoid_fake_data":
         from ray.rllib.algorithms.mock import _SigmoidFakeData

-        class_, config = _SigmoidFakeData, _SigmoidFakeData.get_default_config()
+        return _SigmoidFakeData
     elif alg == "__parameter_tuning":
         from ray.rllib.algorithms.mock import _ParameterTuningTrainer

-        class_, config = (
-            _ParameterTuningTrainer,
-            _ParameterTuningTrainer.get_default_config(),
-        )
+        return _ParameterTuningTrainer
     else:
         raise Exception("Unknown algorithm {}.".format(alg))

-    if return_config:
-        return class_, config
-    return class_


 # Mapping from policy name to where it is located, relative to rllib.algorithms.
 # TODO(jungong) : Finish migrating all the policies to PolicyV2, so we can list
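
Because `error=False`, the decorated `get_algorithm_class()` keeps working but logs a deprecation warning pointing at the replacement named in `new=`. A quick sketch of the expected caller-side behavior (inferred from the decorator arguments above, not verified output):

    from ray.rllib.algorithms.registry import get_algorithm_class
    from ray.tune.registry import get_trainable_cls

    legacy_cls = get_algorithm_class("PPO")  # still works, logs a deprecation warning
    new_cls = get_trainable_cls("PPO")       # recommended replacement
    print(legacy_cls, new_cls)               # both resolve to the PPO Algorithm class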
34 changes: 16 additions & 18 deletions rllib/algorithms/tests/test_algorithm_export_checkpoint.py
@@ -4,44 +4,42 @@
 import unittest

 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import framework_iterator
+from ray.tune.registry import get_trainable_cls

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()


 def save_test(alg_name, framework="tf", multi_agent=False):
-    cls, config = get_algorithm_class(alg_name, return_config=True)
-
-    config["framework"] = framework
-
-    # Switch on saving native DL-framework (tf, torch) model files.
-    config["export_native_model_files"] = True
+    cls = get_trainable_cls(alg_name)
+    config = (
+        cls.get_default_config().framework(framework)
+        # Switch on saving native DL-framework (tf, torch) model files.
+        .checkpointing(export_native_model_files=True)
+    )

     if "DDPG" in alg_name or "SAC" in alg_name:
-        algo = cls(config=config, env="Pendulum-v1")
+        config.environment("Pendulum-v1")
+        algo = config.build()
         test_obs = np.array([[0.1, 0.2, 0.3]])
     else:
         if multi_agent:
-            config["multiagent"] = {
-                "policies": {"pol1", "pol2"},
-                "policy_mapping_fn": (
+            config.multi_agent(
+                policies={"pol1", "pol2"},
+                policy_mapping_fn=(
                     lambda agent_id, episode, worker, **kwargs: "pol1"
                     if agent_id == "agent1"
                     else "pol2"
                 ),
-            }
-            config["env"] = MultiAgentCartPole
-            config["env_config"] = {
-                "num_agents": 2,
-            }
+            )
+            config.environment(MultiAgentCartPole, env_config={"num_agents": 2})
         else:
-            config["env"] = "CartPole-v1"
-        algo = cls(config=config)
+            config.environment("CartPole-v1")
+        algo = config.build()
         test_obs = np.array([[0.1, 0.2, 0.3, 0.4]])

     export_dir = os.path.join(
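
The rewritten test builds its algorithm through the fluent `AlgorithmConfig` API instead of mutating a config dict. A standalone sketch of the same pattern (algorithm name and env are illustrative):

    from ray.tune.registry import get_trainable_cls

    config = (
        get_trainable_cls("PPO")
        .get_default_config()
        .framework("torch")
        .environment("CartPole-v1")
        # Also save native (torch/tf) model files with each checkpoint.
        .checkpointing(export_native_model_files=True)
    )
    algo = config.build()
    algo.train()
    algo.stop()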
5 changes: 2 additions & 3 deletions rllib/evaluate.py
@@ -11,15 +11,13 @@

 import ray
 import ray.cloudpickle as cloudpickle
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.env import MultiAgentEnv
 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
 from ray.rllib.common import CLIArguments as cli
-
 from ray.tune.utils import merge_dicts
 from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR

@@ -208,7 +206,8 @@ def run(
     # Use default config for given agent.
     if not algo:
         raise ValueError("Please provide an algorithm via `--algo`.")
-    _, config = get_algorithm_class(algo, return_config=True)
+    algo_cls = get_trainable_cls(algo)
+    config = algo_cls.get_default_config()

     # Make sure worker 0 has an Env.
     config["create_env_on_driver"] = True
10 changes: 6 additions & 4 deletions rllib/examples/export/cartpole_dqn_export.py
@@ -4,20 +4,22 @@
 import os
 import ray

-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.framework import try_import_tf
+from ray.tune.registry import get_trainable_cls

 tf1, tf, tfv = try_import_tf()

 ray.init(num_cpus=10)


 def train_and_export_policy_and_model(algo_name, num_steps, model_dir, ckpt_dir):
-    cls, config = get_algorithm_class(algo_name, return_config=True)
+    cls = get_trainable_cls(algo_name)
+    config = cls.get_default_config()
     # Set exporting native (DL-framework) model files to True.
-    config["export_native_model_files"] = True
-    alg = cls(config=config, env="CartPole-v1")
+    config.export_native_model_files = True
+    config.env = "CartPole-v1"
+    alg = config.build()
     for _ in range(num_steps):
         alg.train()
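
This example sets config fields by attribute assignment rather than through the fluent setters; both styles end up writing the same attributes on the `AlgorithmConfig` object. A sketch assuming "DQN" as the algorithm name, as in the script above:

    from ray.tune.registry import get_trainable_cls

    cls = get_trainable_cls("DQN")
    config = cls.get_default_config()
    config.export_native_model_files = True  # roughly .checkpointing(export_native_model_files=True)
    config.env = "CartPole-v1"               # roughly .environment("CartPole-v1")
    algo = config.build()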
@@ -12,7 +12,7 @@

 import ray
 from ray import air, tune
-from ray.rllib.algorithms.registry import get_algorithm_class
+from ray.rllib.algorithms.algorithm import Algorithm
 from ray.tune.registry import get_trainable_cls

 parser = argparse.ArgumentParser()
@@ -126,8 +126,7 @@
     # Get the last checkpoint from the above training run.
     checkpoint = results.get_best_result().checkpoint
     # Create new Trainer and restore its state from the last checkpoint.
-    algo = get_algorithm_class(args.run)(config=config)
-    algo.restore(checkpoint)
+    algo = Algorithm.from_checkpoint(checkpoint)

     # Create the env to do inference in.
     env = gym.make("FrozenLake-v1")
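
`Algorithm.from_checkpoint()` reads the algorithm class and config out of the checkpoint itself, so the old construct-then-restore two-step is no longer needed. A minimal sketch (the checkpoint path is a placeholder):

    from ray.rllib.algorithms.algorithm import Algorithm

    # Old: algo = get_algorithm_class(args.run)(config=config); algo.restore(checkpoint)
    # New: one call, no algo name or config required up front.
    algo = Algorithm.from_checkpoint("/tmp/my_rllib_checkpoint")  # placeholder path
    print(type(algo).__name__)  # the concrete Algorithm subclass stored in the checkpoint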
@@ -10,8 +10,8 @@

 import ray
 from ray import air, tune
+from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.dt import DTConfig
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.tune.utils.log import Verbosity

 if __name__ == "__main__":
@@ -142,8 +142,7 @@
     # Get the last checkpoint from the above training run.
     checkpoint = results.get_best_result().checkpoint
     # Create new Algorithm and restore its state from the last checkpoint.
-    algo = get_algorithm_class("DT")(config=config)
-    algo.restore(checkpoint)
+    algo = Algorithm.from_checkpoint(checkpoint)

     # Create the env to do inference in.
     env = gym.make("CartPole-v1")
6 changes: 2 additions & 4 deletions rllib/examples/rnnsac_stateless_cartpole.py
@@ -4,7 +4,7 @@

 import ray
 from ray import air, tune
-from ray.rllib.algorithms.registry import get_algorithm_class
+from ray.tune.registry import get_trainable_cls

 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

@@ -96,9 +96,7 @@
     best_checkpoint = results.get_best_result().best_checkpoints[0][0]
     print("Loading checkpoint: {}".format(best_checkpoint))

-    algo = get_algorithm_class("RNNSAC")(
-        env=StatelessCartPole, config=checkpoint_config
-    )
+    algo = get_trainable_cls("RNNSAC")(env=StatelessCartPole, config=checkpoint_config)
     algo.restore(best_checkpoint)

     env = algo.env_creator({})
7 changes: 4 additions & 3 deletions rllib/policy/tests/test_export_checkpoint_and_model.py
@@ -6,11 +6,11 @@
 import unittest

 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import framework_iterator
+from ray.tune.registry import get_trainable_cls

 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -69,7 +69,8 @@ def export_test(
     multi_agent=False,
     tf_expected_to_work=True,
 ):
-    cls, config = get_algorithm_class(alg_name, return_config=True)
+    cls = get_trainable_cls(alg_name)
+    config = cls.get_default_config().to_dict()
     config["framework"] = framework
     # Switch on saving native DL-framework (tf, torch) model files.
     config["export_native_model_files"] = True
@@ -192,7 +193,7 @@ def export_test(
 class TestExportCheckpointAndModel(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
-        ray.init(num_cpus=4)
+        ray.init(num_cpus=4, local_mode=True)

     @classmethod
     def tearDownClass(cls) -> None:
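
Tests that still operate on dict-style configs can convert the new default `AlgorithmConfig` back to a plain dict, as the updated `export_test()` does. A short sketch (algorithm name is illustrative):

    from ray.tune.registry import get_trainable_cls

    cls = get_trainable_cls("PPO")
    config = cls.get_default_config().to_dict()  # legacy dict view of the config
    config["framework"] = "torch"
    config["export_native_model_files"] = True
    algo = cls(config=config, env="CartPole-v1")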
4 changes: 2 additions & 2 deletions rllib/tests/test_algorithm_checkpoint_restore.py
@@ -5,7 +5,6 @@

 import ray

-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.utils.test_utils import check, framework_iterator
 from ray.rllib.algorithms.apex_ddpg import ApexDDPGConfig
 from ray.rllib.algorithms.sac import SACConfig
@@ -16,6 +15,7 @@
 from ray.rllib.algorithms.ddpg import DDPGConfig
 from ray.rllib.algorithms.ars import ARSConfig
 from ray.rllib.algorithms.a3c import A3CConfig
+from ray.tune.registry import get_trainable_cls


 def get_mean_action(alg, obs):
@@ -89,7 +89,7 @@ def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=Fa
     for fw in framework_iterator(config, frameworks=frameworks):
         for use_object_store in [False, True] if object_store else [False]:
             print("use_object_store={}".format(use_object_store))
-            cls = get_algorithm_class(algo_name)
+            cls = get_trainable_cls(algo_name)
             if "DDPG" in algo_name or "SAC" in algo_name:
                 alg1 = cls(config=config, env="Pendulum-v1")
                 alg2 = cls(config=config, env="Pendulum-v1")
4 changes: 2 additions & 2 deletions rllib/tests/test_eager_support.py
@@ -3,8 +3,8 @@
 import ray
 from ray import air
 from ray import tune
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.utils.framework import try_import_tf
+from ray.tune.registry import get_trainable_cls

 tf1, tf, tfv = try_import_tf()

@@ -24,7 +24,7 @@ def check_support(alg, config, test_eager=False, test_trace=True):
     else:
         config["env"] = "CartPole-v1"

-    a = get_algorithm_class(alg)
+    a = get_trainable_cls(alg)
     if test_eager:
         print("tf-eager: alg={} cont.act={}".format(alg, cont))
         config["eager_tracing"] = False
7 changes: 3 additions & 4 deletions rllib/tests/test_supported_multi_agent.py
@@ -1,11 +1,10 @@
 import unittest

 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, MultiAgentMountainCar
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.test_utils import check_train_results, framework_iterator
-from ray.tune import register_env
+from ray.tune.registry import get_trainable_cls, register_env


 def check_support_multiagent(alg, config):
@@ -36,9 +35,9 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):
         if fw == "tf2" and alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
             continue
         if alg in ["DDPG", "APEX_DDPG", "SAC"]:
-            a = get_algorithm_class(alg)(config=config, env="multi_agent_mountaincar")
+            a = get_trainable_cls(alg)(config=config, env="multi_agent_mountaincar")
         else:
-            a = get_algorithm_class(alg)(config=config, env="multi_agent_cartpole")
+            a = get_trainable_cls(alg)(config=config, env="multi_agent_cartpole")

         results = a.train()
         check_train_results(results)