[RLlib] Deprecate AlgorithmConfig.framework("tfe"): Use tf2 instead. #29755

Merged 3 commits on Oct 27, 2022
Changes from all commits
7 changes: 3 additions & 4 deletions doc/source/rllib/rllib-training.rst
@@ -345,20 +345,19 @@ The following is a list of the common algorithm hyper-parameters:
# === Deep Learning Framework Settings ===
# tf: TensorFlow (static-graph)
# tf2: TensorFlow 2.x (eager or traced, if eager_tracing=True)
-# tfe: TensorFlow eager (or traced, if eager_tracing=True)
# torch: PyTorch
"framework": "tf",
# Enable tracing in eager mode. This greatly improves performance
# (speedup ~2x), but makes it slightly harder to debug since Python
# code won't be evaluated after the initial eager pass.
-# Only possible if framework=[tf2|tfe].
+# Only supported if framework=tf2.
"eager_tracing": False,
# Maximum number of tf.function re-traces before a runtime error is raised.
# This is to prevent unnoticed retraces of methods inside the
# `..._eager_traced` Policy, which could slow down execution by a
# factor of 4, without the user noticing what the root cause for this
# slowdown could be.
-# Only necessary for framework=[tf2|tfe].
+# Only supported for framework=tf2.
# Set to None to ignore the re-trace count and never throw an error.
"eager_max_retraces": 20,

@@ -1549,7 +1548,7 @@ Eager Mode

Policies built with ``build_tf_policy`` (most of the reference algorithms are)
can be run in eager mode by setting the
``"framework": "[tf2|tfe]"`` / ``"eager_tracing": true`` config options or using
``"framework": "tf2"`` / ``"eager_tracing": true`` config options or using
``rllib train --config '{"framework": "tf2"}' [--trace]``.
This will tell RLlib to execute the model forward pass, action distribution,
loss, and stats functions in eager mode.
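For illustration, a minimal sketch of the migrated usage under the builder-style `AlgorithmConfig` API (the choice of PPO and CartPole here is arbitrary, not part of this PR):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # "tf2" replaces the deprecated "tfe" value; tracing remains a separate flag.
    .framework("tf2", eager_tracing=True)
)
algo = config.build()
print(algo.train())
```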
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -2680,7 +2680,7 @@ py_test(
tags = ["team:rllib", "exclusive", "examples"],
size = "small",
srcs = ["examples/complex_struct_space.py"],
args = ["--framework=tfe"],
args = ["--framework=tf2"],
)

py_test(
13 changes: 5 additions & 8 deletions rllib/algorithms/algorithm.py
@@ -2223,7 +2223,7 @@ def validate_framework(
_tf1, _tf, _tfv = None, None, None
_torch = None
framework = config["framework"]
-tf_valid_frameworks = {"tf", "tf2", "tfe"}
+tf_valid_frameworks = {"tf", "tf2"}
if framework not in tf_valid_frameworks and framework != "torch":
return
elif framework in tf_valid_frameworks:
@@ -2257,7 +2257,7 @@ def check_if_correct_nn_framework_installed():
def resolve_tf_settings():
"""Check and resolve tf settings."""

-if _tf1 and config["framework"] in ["tf2", "tfe"]:
+if _tf1 and config["framework"] == "tf2":
if config["framework"] == "tf2" and _tfv < 2:
raise ValueError(
"You configured `framework`=tf2, but your installed "
@@ -2323,7 +2323,7 @@ def validate_config(
# TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this is
# ok for tf2 here.
# Remove this hacky check, once we have fully moved to the RLTrainer API.
if framework in ["tfe", "tf2"] and type(self).__name__ != "AlphaStar":
if framework == "tf2" and type(self).__name__ != "AlphaStar":
raise ValueError(
"`num_gpus` > 1 not supported yet for "
"framework={}!".format(framework)
@@ -2378,7 +2378,7 @@ def validate_config(

# User manually set simple-optimizer to False -> Error if tf-eager.
elif simple_optim_setting is False:
if framework in ["tfe", "tf2"]:
if framework == "tf2":
raise ValueError(
"`simple_optimizer=False` not supported for "
"config.framework({})!".format(framework)
@@ -2776,10 +2776,7 @@ def _run_one_training_iteration(self) -> Tuple[ResultDict, "TrainIterCtx"]:
# In case we are training (in a thread) parallel to evaluation,
# we may have to re-enable eager mode here (gets disabled in the
# thread).
-if (
-self.config.get("framework") in ["tf2", "tfe"]
-and not tf.executing_eagerly()
-):
+if self.config.get("framework") == "tf2" and not tf.executing_eagerly():
tf1.enable_eager_execution()

results = None
8 changes: 7 additions & 1 deletion rllib/algorithms/algorithm_config.py
@@ -588,7 +588,7 @@ def framework(
methods inside the `..._eager_traced` Policy, which could slow down
execution by a factor of 4, without the user noticing what the root
cause for this slowdown could be.
-Only necessary for framework=[tf2|tfe].
+Only necessary for framework=tf2.
Set to None to ignore the re-trace count and never throw an error.
tf_session_args: Configures TF for single-process operation by default.
local_tf_session_args: Override the following tf session args on the local
@@ -598,6 +598,12 @@ def framework(
This updated AlgorithmConfig object.
"""
if framework is not None:
if framework == "tfe":
raise deprecation_warning(
old="AlgorithmConfig.framework('tfe')",
new="AlgorithmConfig.framework('tf2')",
error=True,
)
self.framework_str = framework
if eager_tracing is not None:
self.eager_tracing = eager_tracing
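A rough sketch of what the new guard means for callers; the exact exception type is an assumption (the deprecation helper raises when `error=True`), and PPO is used only as an example algorithm:

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig()
try:
    # The old eager alias no longer maps silently to tf2; it now errors out.
    config.framework("tfe")
except Exception as exc:  # raised via deprecation_warning(..., error=True)
    print(f"Deprecated: {exc}")

# The supported spelling going forward:
config.framework("tf2", eager_tracing=True)
```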
11 changes: 6 additions & 5 deletions rllib/algorithms/ars/ars_tf_policy.py
@@ -46,12 +46,13 @@ def __init__(self, obs_space, action_space, config):
tf1.enable_eager_execution()
self.sess = self.inputs = None
if config.get("seed") is not None:
-# Tf2.x.
+# Non-static-graph TF.
if config.get("framework") == "tf2":
-tf.random.set_seed(config["seed"])
-# Tf-eager.
-elif tf1 and config.get("framework") == "tfe":
-tf1.set_random_seed(config["seed"])
+# Tf1.x.
+if tf1:
+tf1.set_random_seed(config["seed"])
+else:
+tf.random.set_seed(config["seed"])

# Policy network.
self.dist_class, dist_dim = ModelCatalog.get_action_dist(
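The reworked seeding branch keys off which TF API is importable rather than the (now two-valued) framework string. A standalone sketch of that logic, assuming RLlib's `try_import_tf` helper:

```python
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

def seed_tf(seed: int) -> None:
    # With the v1 API available, use the static-graph seeding call;
    # otherwise fall back to the TF2 global seed.
    if tf1:
        tf1.set_random_seed(seed)
    else:
        tf.random.set_seed(seed)
```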
2 changes: 1 addition & 1 deletion rllib/algorithms/cql/cql.py
@@ -151,7 +151,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
if config["simple_optimizer"] is not True and config["framework"] == "torch":
config["simple_optimizer"] = True

if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
if config["framework"] in ["tf", "tf2"] and tfp is None:
logger.warning(
"You need `tensorflow_probability` in order to run CQL! "
"Install it via `pip install tensorflow_probability`. Your "
6 changes: 3 additions & 3 deletions rllib/algorithms/cql/cql_tf_policy.py
@@ -299,7 +299,7 @@ def __init__(self, config):
super().__init__(config)
if config["lagrangian"]:
# Eager mode.
if config["framework"] in ["tf2", "tfe"]:
if config["framework"] == "tf2":
self._alpha_prime_optimizer = tf.keras.optimizers.Adam(
learning_rate=config["optimization"]["critic_learning_rate"]
)
@@ -354,7 +354,7 @@ def compute_gradients_fn(
if policy.config["lagrangian"]:
# Eager: Use GradientTape (which is a property of the `optimizer`
# object (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
tape = optimizer.tape
log_alpha_prime = [policy.model.log_alpha_prime]
alpha_prime_grads_and_vars = list(
@@ -391,7 +391,7 @@ def apply_gradients_fn(policy, optimizer, grads_and_vars):

if policy.config["lagrangian"]:
# Eager mode -> Just apply and return None.
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
policy._alpha_prime_optimizer.apply_gradients(
policy._alpha_prime_grads_and_vars
)
6 changes: 3 additions & 3 deletions rllib/algorithms/ddpg/ddpg_tf_policy.py
@@ -120,7 +120,7 @@ def optimizer(
self,
) -> List["tf.keras.optimizers.Optimizer"]:
"""Create separate optimizers for actor & critic losses."""
if self.config["framework"] in ["tf2", "tfe"]:
if self.config["framework"] == "tf2":
self.global_step = get_variable(0, tf_name="global_step")
self._actor_optimizer = tf.keras.optimizers.Adam(
learning_rate=self.config["actor_lr"]
@@ -143,7 +143,7 @@ def optimizer(
def compute_gradients_fn(
self, optimizer: LocalOptimizer, loss: TensorType
) -> ModelGradients:
if self.config["framework"] in ["tf2", "tfe"]:
if self.config["framework"] == "tf2":
tape = optimizer.tape
pol_weights = self.model.policy_variables()
actor_grads_and_vars = list(
@@ -203,7 +203,7 @@ def make_apply_op():
self._critic_grads_and_vars
)
# Increment global step & apply ops.
if self.config["framework"] in ["tf2", "tfe"]:
if self.config["framework"] == "tf2":
self.global_step.assign_add(1)
return tf.no_op()
else:
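These framework == "tf2" branches select the eager gradient path, where a `GradientTape` (exposed through the optimizer wrapper's `tape` attribute in RLlib) replaces graph-mode `compute_gradients`. A self-contained sketch of that eager path using plain Keras objects; the clip value of 40.0 is an arbitrary illustrative choice:

```python
import tensorflow as tf

def clipped_grads_and_vars(loss_fn, variables, clip_norm=40.0):
    # Eager path: record the forward pass on a tape, then differentiate.
    with tf.GradientTape() as tape:
        loss = loss_fn()
    grads = tape.gradient(loss, variables)
    grads, _ = tf.clip_by_global_norm(grads, clip_norm)
    return list(zip(grads, variables))

# Toy usage:
w = tf.Variable([1.0, 2.0])
grads_and_vars = clipped_grads_and_vars(lambda: tf.reduce_sum(w * w), [w])
tf.keras.optimizers.Adam(learning_rate=1e-3).apply_gradients(grads_and_vars)
```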
2 changes: 1 addition & 1 deletion rllib/algorithms/dqn/dqn_tf_policy.py
@@ -339,7 +339,7 @@ def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> Tensor
def adam_optimizer(
policy: Policy, config: AlgorithmConfigDict
) -> "tf.keras.optimizers.Optimizer":
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
return tf.keras.optimizers.Adam(
learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]
)
2 changes: 1 addition & 1 deletion rllib/algorithms/dqn/learner_thread.py
@@ -36,7 +36,7 @@ def __init__(self, local_worker):

def run(self):
# Switch on eager mode if configured.
-if self.local_worker.policy_config.get("framework") in ["tf2", "tfe"]:
+if self.local_worker.policy_config.get("framework") == "tf2":
tf1.enable_eager_execution()
while not self.stopped:
self.step()
6 changes: 3 additions & 3 deletions rllib/algorithms/es/es_tf_policy.py
@@ -109,10 +109,10 @@ def __init__(self, obs_space, action_space, config):
self.sess = self.inputs = None
if config.get("seed") is not None:
# Tf2.x.
if config.get("framework") == "tf2":
if tfv == 2:
tf.random.set_seed(config["seed"])
-# Tf-eager.
-elif tf1 and config.get("framework") == "tfe":
+# Tf1.x.
+else:
tf1.set_random_seed(config["seed"])

# Policy network.
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala.py
@@ -524,7 +524,7 @@ def validate_config(self, config):
# TODO(sven): Need to change APPO|IMPALATorchPolicies (and the
# models to return separate sets of weights in order to create
# the different torch optimizers).
if config["framework"] not in ["tf", "tf2", "tfe"]:
if config["framework"] not in ["tf", "tf2"]:
raise ValueError(
"`_separate_vf_optimizer` only supported to tf so far!"
)
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala_tf_policy.py
@@ -226,7 +226,7 @@ def optimizer(
) -> Union["tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"]]:
config = self.config
if config["opt_type"] == "adam":
if config["framework"] in ["tf2", "tfe"]:
if config["framework"] == "tf2":
optim = tf.keras.optimizers.Adam(self.cur_lr)
if config["_separate_vf_optimizer"]:
return optim, tf.keras.optimizers.Adam(config["_lr_vf"])
2 changes: 1 addition & 1 deletion rllib/algorithms/marwil/marwil_tf_policy.py
@@ -98,7 +98,7 @@ def __init__(

# Update averaged advantage norm.
# Eager.
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
update_term = adv_squared - policy._moving_average_sqd_adv_norm
policy._moving_average_sqd_adv_norm.assign_add(rate * update_term)

2 changes: 1 addition & 1 deletion rllib/algorithms/ppo/tests/test_ppo.py
@@ -308,7 +308,7 @@ def test_ppo_loss_function(self):
check(train_batch[Postprocessing.VALUE_TARGETS], [0.50005, -0.505, 0.5])

# Calculate actual PPO loss.
if fw in ["tf2", "tfe"]:
if fw == "tf2":
PPOTF2Policy.loss(policy, policy.model, Categorical, train_batch)
elif fw == "torch":
PPOTorchPolicy.loss(
2 changes: 1 addition & 1 deletion rllib/algorithms/sac/sac.py
@@ -335,7 +335,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None:
if config["grad_clip"] is not None and config["grad_clip"] <= 0.0:
raise ValueError("`grad_clip` value must be > 0.0!")

if config["framework"] in ["tf", "tf2", "tfe"] and tfp is None:
if config["framework"] in ["tf", "tf2"] and tfp is None:
logger.warning(
"You need `tensorflow_probability` in order to run SAC! "
"Install it via `pip install tensorflow_probability`. Your "
6 changes: 3 additions & 3 deletions rllib/algorithms/sac/sac_tf_policy.py
@@ -457,7 +457,7 @@ def compute_and_clip_gradients(
"""
# Eager: Use GradientTape (which is a property of the `optimizer` object
# (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py).
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
tape = optimizer.tape
pol_weights = policy.model.policy_variables()
actor_grads_and_vars = list(
@@ -563,7 +563,7 @@ def apply_gradients(
critic_apply_ops = [policy._critic_optimizer[0].apply_gradients(cgrads)]

# Eager mode -> Just apply and return None.
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
policy._alpha_optimizer.apply_gradients(policy._alpha_grads_and_vars)
return
# Tf static graph -> Return op.
@@ -607,7 +607,7 @@ class ActorCriticOptimizerMixin:

def __init__(self, config):
# Eager mode.
if config["framework"] in ["tf2", "tfe"]:
if config["framework"] == "tf2":
self.global_step = get_variable(0, tf_name="global_step")
self._actor_optimizer = tf.keras.optimizers.Adam(
learning_rate=config["optimization"]["actor_learning_rate"]
12 changes: 6 additions & 6 deletions rllib/algorithms/sac/tests/test_sac.py
@@ -260,7 +260,7 @@ def test_sac_loss_function(self):
# Set all weights (of all nets) to fixed values.
if weights_dict is None:
# Start with the tf vars-dict.
assert fw in ["tf2", "tf", "tfe"]
assert fw in ["tf2", "tf"]

weights_dict_list = (
policy.model.variables() + policy.target_model.variables()
@@ -271,9 +271,9 @@
)
weights_dict = collector.get_weights()

if fw == "tfe":
if fw == "tf2":
log_alpha = weights_dict[10]
-weights_dict = self._translate_tfe_weights(weights_dict, map_)
+weights_dict = self._translate_tf2_weights(weights_dict, map_)
else:
assert fw == "torch" # Then transfer that to torch Model.
model_dict = self._translate_weights_to_torch(weights_dict, map_)
@@ -339,7 +339,7 @@ def test_sac_loss_function(self):
tf_a_grads = [g for g, v in tf_a_grads]
tf_e_grads = [g for g, v in tf_e_grads]

elif fw == "tfe":
elif fw == "tf2":
with tf.GradientTape() as tape:
tf_loss(policy, policy.model, None, input_)
c, a, e, t = (
@@ -680,7 +680,7 @@ def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma, sess)
framework=fw,
)
else:
assert fw == "tfe"
assert fw == "tf2"
q_tp1 = fc(
relu(
fc(
@@ -733,7 +733,7 @@ def _translate_weights_to_torch(self, weights_dict, map_):

return model_dict

-def _translate_tfe_weights(self, weights_dict, map_):
+def _translate_tf2_weights(self, weights_dict, map_):
model_dict = {
"default_policy/log_alpha": None,
"default_policy/log_alpha_target": None,
2 changes: 1 addition & 1 deletion rllib/algorithms/slateq/slateq_tf_policy.py
@@ -344,7 +344,7 @@ def setup_late_mixins(
def rmsprop_optimizer(
policy: Policy, config: AlgorithmConfigDict
) -> "tf.keras.optimizers.Optimizer":
if policy.config["framework"] in ["tf2", "tfe"]:
if policy.config["framework"] == "tf2":
return tf.keras.optimizers.RMSprop(
learning_rate=policy.cur_lr,
epsilon=config["rmsprop_epsilon"],
6 changes: 2 additions & 4 deletions rllib/evaluation/rollout_worker.py
@@ -487,9 +487,7 @@ def gen_rollouts():

if (
tf1
-and (
-config.framework_str in ["tf2", "tfe"] or config.enable_tf1_exec_eagerly
-)
+and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
# This eager check is necessary for certain all-framework tests
# that use tf's eager_mode() context generator.
and not tf1.executing_eagerly()
@@ -667,7 +665,7 @@ def wrap(env):
):

devices = []
-if self.config.framework_str in ["tf2", "tf", "tfe"]:
+if self.config.framework_str in ["tf2", "tf"]:
devices = get_tf_gpu_devices()
elif self.config.framework_str == "torch":
devices = list(range(torch.cuda.device_count()))
4 changes: 2 additions & 2 deletions rllib/evaluation/sampler.py
@@ -451,10 +451,10 @@ def run(self):
raise e

def _run(self):
-# We are in a thread: Switch on eager execution mode, iff framework==tf2|tfe.
+# We are in a thread: Switch on eager execution mode, iff framework==tf2.
if (
tf1
-and self.worker.config.framework_str in ["tf2", "tfe"]
+and self.worker.config.framework_str == "tf2"
and not tf1.executing_eagerly()
):
tf1.enable_eager_execution()
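The same pattern recurs in `algorithm.py`, `learner_thread.py`, and here: eager execution can end up disabled in freshly spawned worker threads, so it is re-enabled whenever the config asks for the eager (`tf2`) framework. A hedged sketch of that guard:

```python
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

def ensure_eager(framework_str: str) -> None:
    # Only needed for the eager ("tf2") framework, and only if this thread
    # is not already executing eagerly.
    if tf1 and framework_str == "tf2" and not tf1.executing_eagerly():
        tf1.enable_eager_execution()
```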