[RLlib] Deprecate get_algorithm_class(). (#30053)
sven1977 authored Nov 9, 2022
1 parent f28d731 commit 3d06343
Showing 15 changed files with 69 additions and 72 deletions.
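The gist of the migration: instead of ray.rllib.algorithms.registry.get_algorithm_class(), callers now look up the registered Algorithm class via ray.tune.registry.get_trainable_cls() and fetch the default config from the class itself. A minimal sketch of the replacement pattern; the algorithm name "PPO" and the environment are illustrative and not part of this commit:

# Deprecated (pre-#30053):
#   from ray.rllib.algorithms.registry import get_algorithm_class
#   cls, config = get_algorithm_class("PPO", return_config=True)

# Replacement:
from ray.tune.registry import get_trainable_cls

cls = get_trainable_cls("PPO")             # registered Algorithm class
config = cls.get_default_config()          # default AlgorithmConfig for that class
algo = config.environment("CartPole-v1").build()  # build an Algorithm from the config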
13 changes: 7 additions & 6 deletions rllib/__init__.py
@@ -30,15 +30,16 @@ def _setup_logger():
 
 def _register_all():
     from ray.rllib.algorithms.algorithm import Algorithm
-    from ray.rllib.algorithms.registry import ALGORITHMS, get_algorithm_class
+    from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
     from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
 
-    for key in (
-        list(ALGORITHMS.keys())
-        + list(CONTRIBUTED_ALGORITHMS.keys())
-        + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]
+    for key, get_trainable_class_and_config in list(ALGORITHMS.items()) + list(
+        CONTRIBUTED_ALGORITHMS.items()
     ):
-        register_trainable(key, get_algorithm_class(key))
+        register_trainable(key, get_trainable_class_and_config()[0])
+
+    for key in ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
+        register_trainable(key, _get_algorithm_class(key))
 
     def _see_contrib(name):
         """Returns dummy agent class warning algo is in contrib/."""
2 changes: 1 addition & 1 deletion rllib/algorithms/algorithm.py
@@ -295,7 +295,7 @@ def from_state(state: Dict) -> "Algorithm":
                 "No `algorithm_class` key was found in given `state`! "
                 "Cannot create new Algorithm."
             )
-        # algo_class = get_algorithm_class(algo_class_name)
+        # algo_class = get_trainable_cls(algo_class_name)
         # Create the new algo.
         config = state.get("config")
         if not config:
4 changes: 2 additions & 2 deletions rllib/algorithms/algorithm_config.py
@@ -9,7 +9,6 @@
 from ray.tune.result import TRIAL_INFO
 from ray.util import log_once
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
@@ -40,6 +39,7 @@
     SampleBatchType,
 )
 from ray.tune.logger import Logger
+from ray.tune.registry import get_trainable_cls
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
@@ -691,7 +691,7 @@ class directly. Note that this arg can also be specified via
 
         algo_class = self.algo_class
         if isinstance(self.algo_class, str):
-            algo_class = get_algorithm_class(self.algo_class)
+            algo_class = get_trainable_cls(self.algo_class)
 
         return algo_class(
             config=self if not use_copy else copy.deepcopy(self),
29 changes: 15 additions & 14 deletions rllib/algorithms/registry.py
@@ -6,6 +6,7 @@
 from typing import Tuple, Type, TYPE_CHECKING, Union
 
 from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
+from ray.rllib.utils.deprecation import Deprecated
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm import Algorithm
@@ -247,6 +248,11 @@ def _import_td3():
 }
 
 
+@Deprecated(
+    new="ray.tune.registry.get_trainable_cls([algo name], return_config=False) and cls="
+    "ray.tune.registry.get_trainable_cls([algo name]); cls.get_default_config();",
+    error=False,
+)
 def get_algorithm_class(
     alg: str,
     return_config=False,
@@ -269,37 +275,32 @@ def get_algorithm_class(
 get_trainer_class = get_algorithm_class
 
 
-def _get_algorithm_class(alg: str, return_config=False) -> type:
+def _get_algorithm_class(alg: str) -> type:
     # This helps us get around a circular import (tune calls rllib._register_all when
     # checking if a rllib Trainable is registered)
     if alg in ALGORITHMS:
-        class_, config = ALGORITHMS[alg]()
+        return ALGORITHMS[alg]()[0]
     elif alg in CONTRIBUTED_ALGORITHMS:
-        class_, config = CONTRIBUTED_ALGORITHMS[alg]()
+        return CONTRIBUTED_ALGORITHMS[alg]()[0]
     elif alg == "script":
         from ray.tune import script_runner
 
-        class_, config = script_runner.ScriptRunner, {}
+        return script_runner.ScriptRunner
     elif alg == "__fake":
         from ray.rllib.algorithms.mock import _MockTrainer
 
-        class_, config = _MockTrainer, _MockTrainer.get_default_config()
+        return _MockTrainer
     elif alg == "__sigmoid_fake_data":
         from ray.rllib.algorithms.mock import _SigmoidFakeData
 
-        class_, config = _SigmoidFakeData, _SigmoidFakeData.get_default_config()
+        return _SigmoidFakeData
     elif alg == "__parameter_tuning":
         from ray.rllib.algorithms.mock import _ParameterTuningTrainer
 
-        class_, config = (
-            _ParameterTuningTrainer,
-            _ParameterTuningTrainer.get_default_config(),
-        )
+        return _ParameterTuningTrainer
     else:
         raise Exception("Unknown algorithm {}.".format(alg))
 
-    if return_config:
-        return class_, config
-    return class_
-
 
 # Mapping from policy name to where it is located, relative to rllib.algorithms.
 # TODO(jungong) : Finish migrating all the policies to PolicyV2, so we can list
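Because the decorator is applied with error=False, existing call sites keep working for now and only emit a warning. A hedged sketch of what that means for callers; the algorithm name "DQN" is illustrative:

from ray.rllib.algorithms.registry import get_algorithm_class

# Still returns the Algorithm class, but logs a deprecation warning pointing to
# ray.tune.registry.get_trainable_cls() (error=False means the old API warns
# instead of raising).
cls = get_algorithm_class("DQN")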
34 changes: 16 additions & 18 deletions rllib/algorithms/tests/test_algorithm_export_checkpoint.py
@@ -4,44 +4,42 @@
 import unittest
 
 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import framework_iterator
+from ray.tune.registry import get_trainable_cls
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
 
 
 def save_test(alg_name, framework="tf", multi_agent=False):
-    cls, config = get_algorithm_class(alg_name, return_config=True)
-
-    config["framework"] = framework
-
-    # Switch on saving native DL-framework (tf, torch) model files.
-    config["export_native_model_files"] = True
+    cls = get_trainable_cls(alg_name)
+    config = (
+        cls.get_default_config().framework(framework)
+        # Switch on saving native DL-framework (tf, torch) model files.
+        .checkpointing(export_native_model_files=True)
+    )
 
     if "DDPG" in alg_name or "SAC" in alg_name:
-        algo = cls(config=config, env="Pendulum-v1")
+        config.environment("Pendulum-v1")
+        algo = config.build()
         test_obs = np.array([[0.1, 0.2, 0.3]])
     else:
         if multi_agent:
-            config["multiagent"] = {
-                "policies": {"pol1", "pol2"},
-                "policy_mapping_fn": (
+            config.multi_agent(
+                policies={"pol1", "pol2"},
+                policy_mapping_fn=(
                     lambda agent_id, episode, worker, **kwargs: "pol1"
                     if agent_id == "agent1"
                     else "pol2"
                 ),
-            }
-            config["env"] = MultiAgentCartPole
-            config["env_config"] = {
-                "num_agents": 2,
-            }
+            )
+            config.environment(MultiAgentCartPole, env_config={"num_agents": 2})
         else:
-            config["env"] = "CartPole-v1"
-        algo = cls(config=config)
+            config.environment("CartPole-v1")
+        algo = config.build()
         test_obs = np.array([[0.1, 0.2, 0.3, 0.4]])
 
     export_dir = os.path.join(
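The test above also moves the multi-agent settings from config-dict keys to the fluent AlgorithmConfig API. A self-contained sketch of that pattern as standalone code; the policy names and mapping mirror the test, while the algorithm name "PPO" is illustrative:

from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.registry import get_trainable_cls

config = get_trainable_cls("PPO").get_default_config()
config.multi_agent(
    policies={"pol1", "pol2"},
    # Route "agent1" to pol1 and everyone else to pol2 (same mapping as the test).
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
        "pol1" if agent_id == "agent1" else "pol2"
    ),
)
config.environment(MultiAgentCartPole, env_config={"num_agents": 2})
algo = config.build()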
5 changes: 2 additions & 3 deletions rllib/evaluate.py
@@ -11,15 +11,13 @@
 
 import ray
 import ray.cloudpickle as cloudpickle
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.env import MultiAgentEnv
 from ray.rllib.env.base_env import _DUMMY_AGENT_ID
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
 from ray.rllib.common import CLIArguments as cli
-
 from ray.tune.utils import merge_dicts
 from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR
@@ -208,7 +206,8 @@ def run(
     # Use default config for given agent.
     if not algo:
         raise ValueError("Please provide an algorithm via `--algo`.")
-    _, config = get_algorithm_class(algo, return_config=True)
+    algo_cls = get_trainable_cls(algo)
+    config = algo_cls.get_default_config()
 
     # Make sure worker 0 has an Env.
     config["create_env_on_driver"] = True
10 changes: 6 additions & 4 deletions rllib/examples/export/cartpole_dqn_export.py
@@ -4,20 +4,22 @@
 import os
 import ray
 
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.framework import try_import_tf
+from ray.tune.registry import get_trainable_cls
 
 tf1, tf, tfv = try_import_tf()
 
 ray.init(num_cpus=10)
 
 
 def train_and_export_policy_and_model(algo_name, num_steps, model_dir, ckpt_dir):
-    cls, config = get_algorithm_class(algo_name, return_config=True)
+    cls = get_trainable_cls(algo_name)
+    config = cls.get_default_config()
     # Set exporting native (DL-framework) model files to True.
-    config["export_native_model_files"] = True
-    alg = cls(config=config, env="CartPole-v1")
+    config.export_native_model_files = True
+    config.env = "CartPole-v1"
+    alg = config.build()
     for _ in range(num_steps):
         alg.train()
@@ -12,7 +12,7 @@
 
 import ray
 from ray import air, tune
-from ray.rllib.algorithms.registry import get_algorithm_class
+from ray.rllib.algorithms.algorithm import Algorithm
 from ray.tune.registry import get_trainable_cls
 
 parser = argparse.ArgumentParser()
@@ -126,8 +126,7 @@
     # Get the last checkpoint from the above training run.
     checkpoint = results.get_best_result().checkpoint
    # Create new Trainer and restore its state from the last checkpoint.
-    algo = get_algorithm_class(args.run)(config=config)
-    algo.restore(checkpoint)
+    algo = Algorithm.from_checkpoint(checkpoint)
 
     # Create the env to do inference in.
     env = gym.make("FrozenLake-v1")
@@ -10,8 +10,8 @@
 
 import ray
 from ray import air, tune
+from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.dt import DTConfig
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.tune.utils.log import Verbosity
 
 if __name__ == "__main__":
@@ -142,8 +142,7 @@
     # Get the last checkpoint from the above training run.
     checkpoint = results.get_best_result().checkpoint
     # Create new Algorithm and restore its state from the last checkpoint.
-    algo = get_algorithm_class("DT")(config=config)
-    algo.restore(checkpoint)
+    algo = Algorithm.from_checkpoint(checkpoint)
 
     # Create the env to do inference in.
     env = gym.make("CartPole-v1")
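Both inference examples above replace the construct-then-restore() pattern with Algorithm.from_checkpoint(), which rebuilds the Algorithm (including its config) from a checkpoint in one step. A minimal sketch; the checkpoint path is illustrative:

from ray.rllib.algorithms.algorithm import Algorithm

# Old pattern removed above:
#   algo = get_algorithm_class(args.run)(config=config)
#   algo.restore(checkpoint)

# New pattern: construct and load state in one call.
algo = Algorithm.from_checkpoint("/tmp/rllib_checkpoint")  # path is illustrative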
6 changes: 2 additions & 4 deletions rllib/examples/rnnsac_stateless_cartpole.py
@@ -4,7 +4,7 @@
 
 import ray
 from ray import air, tune
-from ray.rllib.algorithms.registry import get_algorithm_class
+from ray.tune.registry import get_trainable_cls
 
 from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
 
@@ -96,9 +96,7 @@
     best_checkpoint = results.get_best_result().best_checkpoints[0][0]
     print("Loading checkpoint: {}".format(best_checkpoint))
 
-    algo = get_algorithm_class("RNNSAC")(
-        env=StatelessCartPole, config=checkpoint_config
-    )
+    algo = get_trainable_cls("RNNSAC")(env=StatelessCartPole, config=checkpoint_config)
    algo.restore(best_checkpoint)
 
     env = algo.env_creator({})
7 changes: 4 additions & 3 deletions rllib/policy/tests/test_export_checkpoint_and_model.py
@@ -6,11 +6,11 @@
 import unittest
 
 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.test_utils import framework_iterator
+from ray.tune.registry import get_trainable_cls
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -69,7 +69,8 @@ def export_test(
     multi_agent=False,
     tf_expected_to_work=True,
 ):
-    cls, config = get_algorithm_class(alg_name, return_config=True)
+    cls = get_trainable_cls(alg_name)
+    config = cls.get_default_config().to_dict()
     config["framework"] = framework
     # Switch on saving native DL-framework (tf, torch) model files.
     config["export_native_model_files"] = True
@@ -192,7 +193,7 @@ def export_test(
 class TestExportCheckpointAndModel(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
-        ray.init(num_cpus=4)
+        ray.init(num_cpus=4, local_mode=True)
 
     @classmethod
     def tearDownClass(cls) -> None:
4 changes: 2 additions & 2 deletions rllib/tests/test_algorithm_checkpoint_restore.py
@@ -5,7 +5,6 @@
 
 import ray
 
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.utils.test_utils import check, framework_iterator
 from ray.rllib.algorithms.apex_ddpg import ApexDDPGConfig
 from ray.rllib.algorithms.sac import SACConfig
@@ -16,6 +15,7 @@
 from ray.rllib.algorithms.ddpg import DDPGConfig
 from ray.rllib.algorithms.ars import ARSConfig
 from ray.rllib.algorithms.a3c import A3CConfig
+from ray.tune.registry import get_trainable_cls
 
 
 def get_mean_action(alg, obs):
@@ -89,7 +89,7 @@ def ckpt_restore_test(algo_name, tf2=False, object_store=False, replay_buffer=Fa
     for fw in framework_iterator(config, frameworks=frameworks):
         for use_object_store in [False, True] if object_store else [False]:
             print("use_object_store={}".format(use_object_store))
-            cls = get_algorithm_class(algo_name)
+            cls = get_trainable_cls(algo_name)
             if "DDPG" in algo_name or "SAC" in algo_name:
                 alg1 = cls(config=config, env="Pendulum-v1")
                 alg2 = cls(config=config, env="Pendulum-v1")
4 changes: 2 additions & 2 deletions rllib/tests/test_eager_support.py
@@ -3,8 +3,8 @@
 import ray
 from ray import air
 from ray import tune
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.utils.framework import try_import_tf
+from ray.tune.registry import get_trainable_cls
 
 tf1, tf, tfv = try_import_tf()
 
@@ -24,7 +24,7 @@ def check_support(alg, config, test_eager=False, test_trace=True):
     else:
         config["env"] = "CartPole-v1"
 
-    a = get_algorithm_class(alg)
+    a = get_trainable_cls(alg)
     if test_eager:
         print("tf-eager: alg={} cont.act={}".format(alg, cont))
         config["eager_tracing"] = False
7 changes: 3 additions & 4 deletions rllib/tests/test_supported_multi_agent.py
@@ -1,11 +1,10 @@
 import unittest
 
 import ray
-from ray.rllib.algorithms.registry import get_algorithm_class
 from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, MultiAgentMountainCar
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.utils.test_utils import check_train_results, framework_iterator
-from ray.tune import register_env
+from ray.tune.registry import get_trainable_cls, register_env
 
 
 def check_support_multiagent(alg, config):
@@ -36,9 +35,9 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs):
         if fw == "tf2" and alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
             continue
         if alg in ["DDPG", "APEX_DDPG", "SAC"]:
-            a = get_algorithm_class(alg)(config=config, env="multi_agent_mountaincar")
+            a = get_trainable_cls(alg)(config=config, env="multi_agent_mountaincar")
         else:
-            a = get_algorithm_class(alg)(config=config, env="multi_agent_cartpole")
+            a = get_trainable_cls(alg)(config=config, env="multi_agent_cartpole")
 
         results = a.train()
         check_train_results(results)