[RLlib; Offline RL] 1. Fix multi-learner issue. #49194

Merged
1 change: 1 addition & 0 deletions rllib/algorithms/marwil/marwil.py
@@ -476,6 +476,7 @@ class (multi-/single-learner setup) and evaluation on
batch=batch_or_iterator,
minibatch_size=self.config.train_batch_size_per_learner,
num_iters=self.config.dataset_num_iters_per_learner,
**self.offline_data.iter_batches_kwargs,
)

# Log training results.
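With the added keyword forwarding above, any user-supplied `iter_batches_kwargs` now reach the learner-side `iter_batches()` call. Below is a minimal sketch (not RLlib source) of how these kwargs are merged and forwarded; `default_iter_batches_kwargs` and `user_kwargs` are illustrative stand-ins (the real defaults live in `OfflineData`), and `prefetch_batches` is just one example of a `ray.data` `iter_batches()` option.

# Minimal sketch: merging default and user-provided `iter_batches` kwargs.
default_iter_batches_kwargs = {"prefetch_batches": 1}
user_kwargs = {"prefetch_batches": 2}  # e.g. from `config.iter_batches_kwargs`

# Dict union: user-provided kwargs override the defaults (as in the diff below
# for `offline_data.py`).
iter_batches_kwargs = default_iter_batches_kwargs | user_kwargs

# The training step can then forward these to the learner update call, which
# is what the `**self.offline_data.iter_batches_kwargs` line above does.
print(iter_batches_kwargs)  # {'prefetch_batches': 2}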
15 changes: 12 additions & 3 deletions rllib/core/learner/learner.py
@@ -21,6 +21,7 @@
import tree # pip install dm_tree

import ray
from ray.data.iterator import DataIterator
from ray.rllib.connectors.learner.learner_connector_pipeline import (
LearnerConnectorPipeline,
)
@@ -270,6 +271,10 @@ def __init__(
# and return the resulting (reduced) dict.
self.metrics = MetricsLogger()

# In case of offline learning and multiple learners, each learner receives a
# repeatable iterator that iterates over a split of the streamed data.
self.iterator: DataIterator = None

# TODO (sven): Do we really need this API? It seems like LearnerGroup constructs
# all Learner workers and then immediately builds them any ways? Seems to make
# thing more complicated. Unless there is a reason related to Train worker group
@@ -956,6 +961,7 @@ def update_from_batch(
shuffle_batch_per_epoch: bool = False,
# Deprecated args.
num_iters=DEPRECATED_VALUE,
**kwargs,
) -> ResultDict:
"""Run `num_epochs` epochs over the given train batch.

@@ -1088,6 +1094,9 @@ def update_from_iterator(
"`num_iters` instead."
)

if not self.iterator:
Contributor:
Dumb question: What if self.iterator is already set (to a previously incoming DataIterator)? Would the now-incoming iterator be thrown away?

Collaborator Author:
Yes, it would, and that is intended: the incoming iterator might differ from the one already assigned to the learner, but only because `ray.get` cannot guarantee calling the learners in order, i.e. the learners might otherwise get different streaming splits every training iteration, and we do not want this to happen. One Learner, same split.

self.iterator = iterator

self._check_is_built()

# Call `before_gradient_based_update` to allow for non-gradient based
@@ -1101,8 +1110,8 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
return {"batch": self._set_slicing_by_batch_id(batch, value=True)}

i = 0
logger.debug(f"===> [Learner {id(self)}]: SLooping through batches ... ")
for batch in iterator.iter_batches(
logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ")
for batch in self.iterator.iter_batches(
# Note, this needs to be one b/c data is already mapped to
# `MultiAgentBatch`es of `minibatch_size`.
batch_size=1,
@@ -1145,7 +1154,7 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
if num_iters and i == num_iters:
break

logger.info(
logger.debug(
f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
)

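The new `self.iterator` attribute together with the `if not self.iterator:` guard implements the "one Learner, same split" rule from the review thread above: a learner latches onto the first `DataIterator` it receives and ignores later ones. Below is a minimal, self-contained sketch of that guard; `DummyLearner` and the two splits are hypothetical stand-ins, not RLlib classes.

# Minimal sketch: each learner keeps the first iterator (split) it was given.
from typing import Iterable, List, Optional


class DummyLearner:
    def __init__(self) -> None:
        # Set once, on the first update call, and kept for the rest of
        # training so the learner always consumes the same streaming split.
        self.iterator: Optional[Iterable] = None

    def update_from_iterator(self, iterator: Iterable) -> None:
        # Later calls may pass a *different* split (calls are not guaranteed
        # to reach the learners in the same order every iteration), so only
        # the first assignment sticks.
        if not self.iterator:
            self.iterator = iterator


learners: List[DummyLearner] = [DummyLearner(), DummyLearner()]
split_a, split_b = iter([1, 2]), iter([3, 4])

# First training iteration: each learner latches onto one split.
learners[0].update_from_iterator(split_a)
learners[1].update_from_iterator(split_b)

# Second iteration: splits arrive in a different order, but are ignored.
learners[0].update_from_iterator(split_b)
learners[1].update_from_iterator(split_a)
assert learners[0].iterator is split_a and learners[1].iterator is split_b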
66 changes: 34 additions & 32 deletions rllib/offline/offline_data.py
@@ -96,13 +96,14 @@ def __init__(self, config: AlgorithmConfig):
except Exception as e:
logger.error(e)
# Avoids reinstantiating the batch iterator each time we sample.
self.batch_iterator = None
self.batch_iterators = None
self.map_batches_kwargs = (
self.default_map_batches_kwargs | self.config.map_batches_kwargs
)
self.iter_batches_kwargs = (
self.default_iter_batches_kwargs | self.config.iter_batches_kwargs
)
self.returned_streaming_split = False
# Defines the prelearner class. Note, this could be user-defined.
self.prelearner_class = self.config.prelearner_class or OfflinePreLearner
# For remote learner setups.
@@ -164,56 +165,57 @@ def sample(
# If the user wants to materialize the data in memory.
if self.materialize_mapped_data:
self.data = self.data.materialize()
# Build an iterator, if necessary.
if (not self.batch_iterator and (not return_iterator or num_shards <= 1)) or (
return_iterator and isinstance(self.batch_iterator, types.GeneratorType)
# Build an iterator, if necessary. Note, if an iterator should be returned
# now but we have already generated from the existing iterator, i.e.
# `isinstance(self.batch_iterators, types.GeneratorType)` is `True`, we need
# to create a new iterator here.
if not self.batch_iterators or (
return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
):
# If no iterator should be returned, or if we want to return a single
# batch iterator, we instantiate the batch iterator once, here.
# TODO (simon, sven): The iterator depends on the `num_samples`, i.e.abs
# sampling later with a different batch size would need a
# reinstantiation of the iterator.
self.batch_iterator = self.data.iter_batches(
# This is important. The batch size is now 1, because the data
# is already run through the `OfflinePreLearner` and a single
# instance is a single `MultiAgentBatch` of size `num_samples`.
batch_size=1,
**self.iter_batches_kwargs,
)

if not return_iterator:
self.batch_iterator = iter(self.batch_iterator)

# Do we want to return an iterator or a single batch?
if return_iterator:
# In case of multiple shards, we return multiple
# `StreamingSplitIterator` instances.
# If we have more than one learner create an iterator for each of them
# by splitting the data stream.
if num_shards > 1:
# TODO (simon): Check, if we should use `iter_batches_kwargs` here
# as well.
logger.debug("===> [OfflineData]: Return streaming_split ... ")
return self.data.streaming_split(
# In case of multiple shards, we return multiple
# `StreamingSplitIterator` instances.
self.batch_iterators = self.data.streaming_split(
n=num_shards,
# Note, `equal` must be `True`, i.e. the batch size must
# be the same for all batches b/c otherwise remote learners
# could block each other.
equal=True,
locality_hints=self.locality_hints,
)

# Otherwise, we return a simple batch `DataIterator`.
# Otherwise we create a simple iterator and - if necessary - initialize
# it here.
else:
return self.batch_iterator
# If no iterator should be returned, or if we want to return a single
# batch iterator, we instantiate the batch iterator once, here.
self.batch_iterators = self.data.iter_batches(
# This is important. The batch size is now 1, because the data
# is already run through the `OfflinePreLearner` and a single
# instance is a single `MultiAgentBatch` of size `num_samples`.
batch_size=1,
**self.iter_batches_kwargs,
)

# If single batches (instead of an iterator) should be returned.
if not return_iterator:
self.batch_iterators = iter(self.batch_iterators)

# Do we want to return an iterator or a single batch?
if return_iterator:
return self.batch_iterators
else:
# Return a single batch from the iterator.
try:
return next(self.batch_iterator)["batch"][0]
return next(self.batch_iterators)["batch"][0]
except StopIteration:
# If the batch iterator is exhausted, instantiate a new one.
logger.debug(
"===> [OfflineData]: Batch iterator exhausted. Reinitiating ..."
)
self.batch_iterator = None
self.batch_iterators = None
return self.sample(
num_samples=num_samples,
return_iterator=return_iterator,
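Taken together, `sample()` now keeps a single `batch_iterators` handle for both setups: multiple streaming splits for multi-learner runs, and a plain batch iterator otherwise. Below is a minimal sketch of that branching against the public `ray.data` API; `ds` and `num_shards` are illustrative stand-ins, not the actual `OfflineData` attributes.

# Minimal sketch: split the stream per learner, or iterate single batches.
import ray

ds = ray.data.range(8)
num_shards = 2  # e.g. the number of remote learners

if num_shards > 1:
    # One equally sized streaming split per remote learner; each learner
    # keeps its split for the rest of training (see learner.py above).
    batch_iterators = ds.streaming_split(n=num_shards, equal=True)
else:
    # Single-learner case: a generator over single batches.
    batch_iterators = iter(ds.iter_batches(batch_size=1))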