[RLlib; Offline RL] Make data pipeline better configurable and tuneable for users. #46777

Merged
39 changes: 35 additions & 4 deletions rllib/algorithms/algorithm_config.py
@@ -433,6 +433,9 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.map_batches_kwargs = {}
self.iter_batches_kwargs = {}
self.prelearner_class = None
self.prelearner_module_synch_period = 10
self.dataset_num_iters_per_learner = None
self.input_config = {}
@@ -2373,6 +2376,9 @@ def offline_data(
input_read_method=NotProvided,
input_read_method_kwargs=NotProvided,
input_read_schema=NotProvided,
map_batches_kwargs=NotProvided,
iter_batches_kwargs=NotProvided,
prelearner_class=NotProvided,
prelearner_module_synch_period=NotProvided,
dataset_num_iters_per_learner=NotProvided,
input_config=NotProvided,
@@ -2403,10 +2409,12 @@ def offline_data(
                offline data from `input_`. The default is `read_parquet` for Parquet
See https://docs.ray.io/en/latest/data/api/input_output.html for more
info about available read methods in `ray.data`.
            input_read_method_kwargs: `kwargs` for the `input_read_method`. These will
                be passed into the read method without checking. If no arguments are
                passed in, the default argument `{'override_num_blocks':
                max(num_learners * 2, 2)}` is used. Use these `kwargs` together with
                the `map_batches_kwargs` and `iter_batches_kwargs` to tune the
                performance of the data pipeline.
input_read_schema: Table schema for converting offline data to episodes.
This schema maps the offline data columns to `ray.rllib.core.columns.
Columns`: {Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}. Columns in
@@ -2415,6 +2423,27 @@ def offline_data(
                schema used is `ray.rllib.offline.offline_data.SCHEMA`. If your dataset
                already contains the names in this schema, no `input_read_schema` is
                needed.
            map_batches_kwargs: `kwargs` for the `map_batches` method. These will be
                passed into the `ray.data.Dataset.map_batches` method when sampling
                without checking. If no arguments are passed in, the default arguments
                `{'concurrency': max(2, num_learners), 'zero_copy_batch': True}` are
                used. Use these `kwargs` together with the `input_read_method_kwargs`
                and `iter_batches_kwargs` to tune the performance of the data pipeline.
            iter_batches_kwargs: `kwargs` for the `iter_batches` method. These will be
                passed into the `ray.data.Dataset.iter_batches` method when sampling
                without checking. If no arguments are passed in, the default arguments
                `{'prefetch_batches': 2, 'local_shuffle_buffer_size':
                train_batch_size_per_learner * 4}` are used. Use these `kwargs`
                together with the `input_read_method_kwargs` and `map_batches_kwargs`
                to tune the performance of the data pipeline.
            prelearner_class: An optional `OfflinePreLearner` class that is used in
                `ray.data.map_batches` inside the `OfflineData` class to transform
                data from columns to batches that can be used in the `Learner`'s
                `update` methods. Override the `OfflinePreLearner` class and pass your
                derived class in here if you need further transformations specific to
                your data or loss. The default is `None`, which uses the base
                `OfflinePreLearner` defined in `ray.rllib.offline.offline_prelearner`.
prelearner_module_synch_period: The period (number of batches converted)
after which the `RLModule` held by the `PreLearner` should sync weights.
The `PreLearner` is used to preprocess batches for the learners. The
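Taken together, the three `kwargs` dicts above tune the read, transform, and iteration stages of the offline pipeline. As a minimal sketch of how the documented defaults relate to the learner setup (this helper function is hypothetical, not part of RLlib; it only mirrors the defaults quoted in the docstring):

```python
def default_pipeline_kwargs(num_learners: int, train_batch_size_per_learner: int) -> dict:
    """Build the documented default kwargs for the offline data pipeline.

    Hypothetical helper: mirrors the defaults described in
    `AlgorithmConfig.offline_data`, not RLlib's actual code.
    """
    return {
        # Passed to the `ray.data` read method (e.g. `read_parquet`).
        "input_read_method_kwargs": {
            "override_num_blocks": max(num_learners * 2, 2),
        },
        # Passed to `ray.data.Dataset.map_batches` during sampling.
        "map_batches_kwargs": {
            "concurrency": max(2, num_learners),
            "zero_copy_batch": True,
        },
        # Passed to `ray.data.Dataset.iter_batches` during sampling.
        "iter_batches_kwargs": {
            "prefetch_batches": 2,
            "local_shuffle_buffer_size": train_batch_size_per_learner * 4,
        },
    }
```

Any of these keys can be overridden individually via `config.offline_data(...)` to trade off memory against throughput.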
@@ -2470,6 +2499,8 @@ def offline_data(
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if prelearner_class is not NotProvided:
self.prelearner_class = prelearner_class
Contributor:

Wait, where do these get assigned?
input_read_method_kwargs
map_batches_kwargs

Collaborator (Author):

Up in the file, where all attributes get default values.

Contributor:

Ah, sorry, didn't see this. I think b/c it hadn't been changed in this PR. All good.

if prelearner_module_synch_period is not NotProvided:
self.prelearner_module_synch_period = prelearner_module_synch_period
if dataset_num_iters_per_learner is not NotProvided:
5 changes: 4 additions & 1 deletion rllib/algorithms/bc/bc.py
@@ -188,6 +188,9 @@ class (multi-/single-learner setup) and evaluation on
batch,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
**self.offline_data.iter_batches_kwargs
if self.config.num_learners > 1
else {},
)

# Log training results.
@@ -217,7 +220,7 @@ class (multi-/single-learner setup) and evaluation on
# Update weights - after learning on the local worker -
# on all remote workers.
with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
self.workers.sync_weights(
self.env_runner_group.sync_weights(
# Sync weights from learner_group to all EnvRunners.
from_worker_or_learner_group=self.learner_group,
policies=modules_to_update,
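The BC change above forwards `iter_batches_kwargs` into the learner update only in the multi-learner setup; with a single (local) learner it passes nothing extra. The gating logic can be isolated as a small sketch (the helper name is hypothetical, not RLlib API):

```python
def build_update_kwargs(iter_batches_kwargs: dict, num_learners: int) -> dict:
    """Mirror the BC pattern: forward `iter_batches_kwargs` only when
    more than one learner is configured, otherwise pass no extra kwargs.

    Hypothetical helper illustrating the conditional `**`-splat in
    `BC.training_step`; not part of RLlib.
    """
    return dict(iter_batches_kwargs) if num_learners > 1 else {}
```

In the diff this appears inline as `**self.offline_data.iter_batches_kwargs if self.config.num_learners > 1 else {}`, which Python parses as `**(... if ... else {})`.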
6 changes: 3 additions & 3 deletions rllib/core/learner/learner.py
@@ -1143,7 +1143,8 @@ def update_from_iterator(
*,
timesteps: Optional[Dict[str, Any]] = None,
minibatch_size: Optional[int] = None,
num_iters: int = 1,
        num_iters: Optional[int] = None,
**kwargs,
):
self._check_is_built()
minibatch_size = minibatch_size or 32
@@ -1162,8 +1163,7 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
for batch in iterator.iter_batches(
batch_size=minibatch_size,
_finalize_fn=_finalize_fn,
prefetch_batches=2,
local_shuffle_buffer_size=minibatch_size * 10,
**kwargs,
):
# Update the iteration counter.
i += 1
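The `learner.py` change replaces the hard-coded `prefetch_batches`/`local_shuffle_buffer_size` arguments with forwarded `**kwargs` and makes `num_iters` optional. A simplified, self-contained sketch of that loop (hypothetical names; the real method iterates `iterator.iter_batches(...)` from `ray.data`):

```python
def update_from_iterator_sketch(batches, minibatch_size=32, num_iters=None, **kwargs):
    """Iterate minibatches, counting iterations.

    `kwargs` stands in for the forwarded `iter_batches` arguments
    (e.g. `prefetch_batches`, `local_shuffle_buffer_size`), which the
    real method passes through to `iterator.iter_batches` unchecked.
    """
    i = 0
    processed = []
    # Stand-in for: iterator.iter_batches(batch_size=minibatch_size, **kwargs)
    for batch in batches:
        i += 1
        processed.append(len(batch))
        # With num_iters=None the loop consumes the full iterator;
        # otherwise it stops after `num_iters` minibatches.
        if num_iters is not None and i >= num_iters:
            break
    return processed
```

This makes the iteration behavior tunable from `iter_batches_kwargs` instead of being fixed inside the `Learner`.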