Skip to content

Commit

Permalink
[RLlib] RLlib deprecation Notices Part 2 (models/tf, models/torch, ba…
Browse files Browse the repository at this point in the history
…se_mode, catalog, modelv2, models/temp_spec_classes, policy/) (ray-project#36840)

Signed-off-by: Avnish <avnishnarayan@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
avnishn authored and arvind-chandra committed Aug 31, 2023
1 parent 6ab2809 commit 1246ef9
Show file tree
Hide file tree
Showing 56 changed files with 371 additions and 68 deletions.
1 change: 1 addition & 0 deletions rllib/algorithms/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ def on_train_result(self, *, algorithm=None, result: dict, **kwargs) -> None:

# This Callback is used by the RE3 exploration strategy.
# See rllib/examples/re3_exploration.py for details.
@Deprecated(error=False)
class RE3UpdateCallbacks(DefaultCallbacks):
"""Update input callbacks to mutate batch with states entropy rewards."""

Expand Down
9 changes: 9 additions & 0 deletions rllib/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
override,
ExperimentalAPI,
)
from ray.rllib.utils.deprecation import deprecation_warning, Deprecated
from ray.util import log_once


ForwardOutputType = TensorDict
Expand Down Expand Up @@ -56,6 +58,10 @@ class RecurrentModel(abc.ABC):
"""

def __init__(self, name: Optional[str] = None):
if log_once("recurrent_model_deprecation"):
deprecation_warning(
old="ray.rllib.models.base_model.RecurrentModel",
)
self._name = name or self.__class__.__name__

@property
Expand Down Expand Up @@ -201,6 +207,7 @@ def _update_outputs_and_next_state(
return outputs, next_state


@Deprecated(error=False)
class Model(RecurrentModel):
"""A RecurrentModel made non-recurrent by ignoring
the input/output states.
Expand Down Expand Up @@ -299,6 +306,8 @@ class ModelIO(abc.ABC):
"""

def __init__(self, config: ModelConfig) -> None:
if log_once("rllib_base_model_io_deprecation"):
deprecation_warning(old="ray.rllib.models.base_model.ModelIO")
self._config = config

@DeveloperAPI
Expand Down
3 changes: 2 additions & 1 deletion rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ray.rllib.utils.deprecation import (
DEPRECATED_VALUE,
deprecation_warning,
Deprecated,
)
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
Expand Down Expand Up @@ -200,7 +201,7 @@
# fmt: on


@PublicAPI
@Deprecated(old="rllib.models.catalog.ModelCatalog", error=False)
class ModelCatalog:
"""Registry of models, preprocessors, and action distributions for envs.
Expand Down
19 changes: 16 additions & 3 deletions rllib/models/tf/attention_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType, List
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand Down Expand Up @@ -58,6 +60,10 @@ def __init__(
self._output_layer = tf.keras.layers.Dense(
out_dim, activation=output_activation
)
if log_once("positionwise_feedforward_tf"):
deprecation_warning(
old="rllib.models.tf.attention_net.PositionwiseFeedforward",
)

def call(self, inputs: TensorType, **kwargs) -> TensorType:
del kwargs
Expand Down Expand Up @@ -98,7 +104,10 @@ def __init__(
first of the two layers within the PositionwiseFeedforward. The
second layer always has size=`attention_dim`.
"""

if log_once("trxl_net_tf"):
deprecation_warning(
old="rllib.models.tf.attention_net.TrXLNet",
)
super().__init__(
observation_space, action_space, num_outputs, model_config, name
)
Expand Down Expand Up @@ -233,7 +242,8 @@ def __init__(
(two GRUs per Transformer unit, one after the MHA, one after
the position-wise MLP).
"""

if log_once("gtrxl_net_tf"):
deprecation_warning(old="ray.rllib.models.tf.attention_net.GTrXLNet")
super().__init__(
observation_space, action_space, num_outputs, model_config, name
)
Expand Down Expand Up @@ -383,7 +393,10 @@ def __init__(
model_config: ModelConfigDict,
name: str,
):

if log_once("attention_wrapper_tf_deprecation"):
deprecation_warning(
old="ray.rllib.models.tf.attention_net.AttentionWrapper"
)
super().__init__(obs_space, action_space, None, model_config, name)

self.use_n_prev_actions = model_config["attention_use_n_prev_actions"]
Expand Down
5 changes: 5 additions & 0 deletions rllib/models/tf/complex_input_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.tf_utils import one_hot
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand All @@ -32,6 +34,9 @@ class ComplexInputNetwork(TFModelV2):
"""

def __init__(self, obs_space, action_space, num_outputs, model_config, name):
if log_once("rllib_tf_complex_input_net_deprecation"):
deprecation_warning(old="rllib.models.tf.ComplexInputNetwork")

self.original_space = (
obs_space.original_space
if hasattr(obs_space, "original_space")
Expand Down
4 changes: 4 additions & 0 deletions rllib/models/tf/fcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType, List, ModelConfigDict
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand All @@ -25,6 +27,8 @@ def __init__(
model_config: ModelConfigDict,
name: str,
):
if log_once("rllib_models_fcnet_deprecation"):
deprecation_warning(old="ray.rllib.models.tf.fcnet.FullyConnectedNetwork")
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name
)
Expand Down
6 changes: 6 additions & 0 deletions rllib/models/tf/layers/gru_gate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType, TensorShape
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand All @@ -8,6 +10,10 @@ class GRUGate(tf.keras.layers.Layer if tf else object):
def __init__(self, init_bias: float = 0.0, **kwargs):
super().__init__(**kwargs)
self._init_bias = init_bias
if log_once("gru_gate"):
deprecation_warning(
old="rllib.models.tf.layers.GRUGate",
)

def build(self, input_shape: TensorShape):
h_shape, x_shape = input_shape
Expand Down
6 changes: 6 additions & 0 deletions rllib/models/tf/layers/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand All @@ -24,6 +26,10 @@ def __init__(self, out_dim: int, num_heads: int, head_dim: int, **kwargs):
self._linear_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Dense(out_dim, use_bias=False)
)
if log_once("multi_head_attention"):
deprecation_warning(
old="rllib.models.tf.layers.MultiHeadAttention",
)

def call(self, inputs: TensorType) -> TensorType:
L = tf.shape(inputs)[1] # length of segment
Expand Down
6 changes: 6 additions & 0 deletions rllib/models/tf/layers/noisy_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
TensorType,
TensorShape,
)
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand Down Expand Up @@ -47,6 +49,10 @@ def __init__(
self.b = None # Biases.
self.sigma_w = None # Noise for weight matrix
self.sigma_b = None # Noise for biases.
if log_once("noisy_layer"):
deprecation_warning(
old="rllib.models.tf.layers.NoisyLayer",
)

def build(self, input_shape: TensorShape):
in_size = int(input_shape[1])
Expand Down
6 changes: 6 additions & 0 deletions rllib/models/tf/layers/relative_multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand Down Expand Up @@ -36,6 +38,10 @@ def __init__(
activation function. Should be relu for GTrXL.
**kwargs:
"""
if log_once("relative_multi_head_attention"):
deprecation_warning(
old="rllib.models.tf.layers.RelativeMultiHeadAttention",
)
super().__init__(**kwargs)

# No bias or non-linearity.
Expand Down
6 changes: 6 additions & 0 deletions rllib/models/tf/layers/skip_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()

Expand All @@ -22,6 +24,10 @@ def __init__(self, layer: Any, fan_in_layer: Optional[Any] = None, **kwargs):
layer taking two inputs: The original input and the output
of `layer`.
"""
if log_once("skip_connection"):
deprecation_warning(
old="rllib.models.tf.layers.SkipConnection",
)
super().__init__(**kwargs)
self._layer = layer
self._fan_in_layer = fan_in_layer
Expand Down
12 changes: 12 additions & 0 deletions rllib/models/tf/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util import log_once

tf1, tf, tfv = try_import_tf()


@DeveloperAPI
def normc_initializer(std: float = 1.0) -> Any:
if log_once("rllib_models_normc_initializer_tf_deprecation"):
deprecation_warning(old="ray.rllib.models.tf.misc.normc_initializer")

def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(
dtype.name if hasattr(dtype, "name") else dtype or np.float32
Expand All @@ -31,6 +36,9 @@ def conv2d(
dtype: Optional[Any] = None,
collections: Optional[Any] = None,
) -> TensorType:
if log_once("rllib_models_conv2d_tf_deprecation"):
deprecation_warning(old="ray.rllib.models.tf.misc.conv2d")

if dtype is None:
dtype = tf.float32

Expand Down Expand Up @@ -76,6 +84,8 @@ def linear(
initializer: Optional[Any] = None,
bias_init: float = 0.0,
) -> TensorType:
if log_once("rllib_models_linear_tf_deprecation"):
deprecation_warning(old="ray.rllib.models.tf.misc.linear")
w = tf1.get_variable(name + "/w", [x.get_shape()[1], size], initializer=initializer)
b = tf1.get_variable(
name + "/b", [size], initializer=tf1.constant_initializer(bias_init)
Expand All @@ -85,4 +95,6 @@ def linear(

@DeveloperAPI
def flatten(x: TensorType) -> TensorType:
if log_once("rllib_models_flatten_tf_deprecation"):
deprecation_warning(old="ray.rllib.models.tf.misc.flatten")
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
2 changes: 2 additions & 0 deletions rllib/models/tf/noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.deprecation import Deprecated

_, tf, _ = try_import_tf()


@Deprecated(error=False)
class NoopModel(TFModelV2):
"""Trivial model that just returns the obs flattened.
Expand Down
3 changes: 3 additions & 0 deletions rllib/models/tf/primitives.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import List
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.deprecation import Deprecated

_, tf, _ = try_import_tf()

# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done.


@Deprecated(error=False)
class FCNet(tf.keras.Model):
"""A simple fully connected network.
Expand Down Expand Up @@ -47,6 +49,7 @@ def call(self, inputs, training=None, mask=None):
return self.network(inputs)


@Deprecated(error=False)
class IdentityNetwork(tf.keras.Model):
"""A network that returns the input as the output."""

Expand Down
13 changes: 12 additions & 1 deletion rllib/models/tf/recurrent_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.tf_utils import flatten_inputs_to_1d_tensor, one_hot
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.utils.deprecation import deprecation_warning
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,6 +69,14 @@ def forward(
"""Adds time dimension to batch before sending inputs to forward_rnn().
You should implement forward_rnn() in your subclass."""
# Creating a __init__ function that acts as a passthrough and adding the warning
# there led to errors probably due to the multiple inheritance. We encountered
# the same error if we add the Deprecated decorator. We therefore add the
# deprecation warning here.
if log_once("recurrent_network_tf"):
deprecation_warning(
old="ray.rllib.models.tf.recurrent_net.RecurrentNetwork"
)
assert seq_lens is not None
flat_inputs = input_dict["obs_flat"]
inputs = add_time_dimension(
Expand Down Expand Up @@ -131,7 +141,8 @@ def __init__(
model_config: ModelConfigDict,
name: str,
):

if log_once("lstm_wrapper_tf"):
deprecation_warning(old="ray.rllib.models.tf.recurrent_net.LSTMWrapper")
super(LSTMWrapper, self).__init__(
obs_space, action_space, None, model_config, name
)
Expand Down
Loading

0 comments on commit 1246ef9

Please sign in to comment.