[RLlib, Offline RL] Add user-defined schemas for data loading. (#46738)
simonsays1980 authored Jul 23, 2024
1 parent 7874da9 commit 301de59
Showing 3 changed files with 133 additions and 36 deletions.
16 changes: 15 additions & 1 deletion rllib/algorithms/algorithm_config.py
@@ -432,6 +432,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.input_ = "sampler"
self.input_read_method = "read_parquet"
self.input_read_method_kwargs = {}
self.input_read_schema = {}
self.prelearner_module_synch_period = 10
self.dataset_num_iters_per_learner = None
self.input_config = {}
@@ -2382,6 +2383,7 @@ def offline_data(
input_=NotProvided,
input_read_method=NotProvided,
input_read_method_kwargs=NotProvided,
input_read_schema=NotProvided,
prelearner_module_synch_period=NotProvided,
dataset_num_iters_per_learner=NotProvided,
input_config=NotProvided,
@@ -2413,7 +2415,17 @@ def offline_data(
See https://docs.ray.io/en/latest/data/api/input_output.html for more
info about available read methods in `ray.data`.
input_read_method_kwargs: kwargs for the `input_read_method`. These will be
passed into the read method without checking.
passed into the read method without checking. If no arguments are
passed in, the default argument
`{'override_num_blocks': max(num_learners * 2, 2)}` is used.
input_read_schema: Table schema for converting offline data to episodes.
This schema maps the offline data columns to
`ray.rllib.core.columns.Columns`, e.g.
`{Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}`. Columns in the
dataset that are not mapped via this schema are sorted into the
episodes' `extra_model_outputs`. If no schema is passed in, the default
schema `ray.rllib.offline.offline_data.SCHEMA` is used. If your dataset
already contains the names in this schema, no `input_read_schema` is
needed.
prelearner_module_synch_period: The period (number of batches converted)
after which the `RLModule` held by the `PreLearner` should sync weights.
The `PreLearner` is used to preprocess batches for the learners. The
@@ -2467,6 +2479,8 @@ def offline_data(
self.input_read_method = input_read_method
if input_read_method_kwargs is not NotProvided:
self.input_read_method_kwargs = input_read_method_kwargs
if input_read_schema is not NotProvided:
self.input_read_schema = input_read_schema
if prelearner_module_synch_period is not NotProvided:
self.prelearner_module_synch_period = prelearner_module_synch_period
if dataset_num_iters_per_learner is not NotProvided:
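For reference, a minimal usage sketch (not part of this diff; the data path and column names are hypothetical) showing how the new `input_read_schema` parameter can be set through `AlgorithmConfig.offline_data()`:

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.columns import Columns

config = AlgorithmConfig().offline_data(
    input_=["/tmp/my_offline_data"],  # hypothetical path to parquet files
    # Map this dataset's column names to RLlib's canonical `Columns`.
    # Any key not listed here falls back to the default SCHEMA mapping.
    input_read_schema={
        Columns.OBS: "o_t",
        Columns.ACTIONS: "a_t",
        Columns.REWARDS: "r_t",
        Columns.NEXT_OBS: "o_tp1",
        Columns.EPS_ID: "episode_id",
        Columns.TERMINATEDS: "d_t",
    },
)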
84 changes: 50 additions & 34 deletions rllib/offline/offline_data.py
@@ -18,25 +18,29 @@

logger = logging.getLogger(__name__)

# TODO (simon): Implement schema mapping for users, i.e. user define
# which row name to map to which default schema name below.
SCHEMA = [
Columns.EPS_ID,
Columns.AGENT_ID,
Columns.MODULE_ID,
Columns.OBS,
Columns.ACTIONS,
Columns.REWARDS,
Columns.INFOS,
Columns.NEXT_OBS,
Columns.TERMINATEDS,
Columns.TRUNCATEDS,
Columns.T,
# This is the default schema used if no `input_read_schema` is set in
# the config. If a user passes a schema into `input_read_schema`, its
# keys have to comply with the keys of `SCHEMA`, while its values name
# the corresponding columns in the user's dataset. Note that only the
# user-defined mappings override the defaults; all other values from
# `SCHEMA` remain as defined here.
SCHEMA = {
Columns.EPS_ID: Columns.EPS_ID,
Columns.AGENT_ID: Columns.AGENT_ID,
Columns.MODULE_ID: Columns.MODULE_ID,
Columns.OBS: Columns.OBS,
Columns.ACTIONS: Columns.ACTIONS,
Columns.REWARDS: Columns.REWARDS,
Columns.INFOS: Columns.INFOS,
Columns.NEXT_OBS: Columns.NEXT_OBS,
Columns.TERMINATEDS: Columns.TERMINATEDS,
Columns.TRUNCATEDS: Columns.TRUNCATEDS,
Columns.T: Columns.T,
# TODO (simon): Remove as soon as we are new stack only.
"agent_index",
"dones",
"unroll_id",
]
"agent_index": "agent_index",
"dones": "dones",
"unroll_id": "unroll_id",
}


class OfflineData:
@@ -203,12 +207,12 @@ def __init__(

def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]:
# Map the batch to episodes.
episodes = self._map_to_episodes(self._is_multi_agent, batch)
episodes = self._map_to_episodes(
self._is_multi_agent, batch, schema=SCHEMA | self.config.input_read_schema
)
# TODO (simon): Make syncing work. Right now this becomes blocking or never
# receives weights. Learners appear to be not accessible via other actors.
# Increase the counter for updating the module.
# IDEA: Put the module state into the object store. From there, any actor
# has access.
# self.iter_since_last_module_update += 1

# if self._future:
@@ -275,23 +279,29 @@ def _should_module_be_updated(self, module_id, multi_agent_batch=None):

@staticmethod
def _map_to_episodes(
is_multi_agent: bool, batch: Dict[str, np.ndarray]
is_multi_agent: bool,
batch: Dict[str, np.ndarray],
schema: Dict[str, str] = SCHEMA,
) -> Dict[str, List[EpisodeType]]:
"""Maps a batch of data to episodes."""

episodes = []
# TODO (simon): Give users possibility to provide a custom schema.
for i, obs in enumerate(batch["obs"]):
for i, obs in enumerate(batch[schema[Columns.OBS]]):

# If multi-agent we need to extract the agent ID.
# TODO (simon): Check what happens with the module ID.
if is_multi_agent:
agent_id = (
batch[Columns.AGENT_ID][i]
batch[schema[Columns.AGENT_ID]][i]
if Columns.AGENT_ID in batch
# The old stack uses "agent_index" instead of "agent_id".
# TODO (simon): Remove this as soon as we are new stack only.
else (batch["agent_index"][i] if "agent_index" in batch else None)
else (
batch[schema["agent_index"]][i]
if schema["agent_index"] in batch
else None
)
)
else:
agent_id = None
@@ -302,30 +312,36 @@ def _map_to_episodes(
else:
# Build a single-agent episode with a single row of the batch.
episode = SingleAgentEpisode(
id_=batch[Columns.EPS_ID][i],
id_=batch[schema[Columns.EPS_ID]][i],
agent_id=agent_id,
observations=[
unpack_if_needed(obs),
unpack_if_needed(batch[Columns.NEXT_OBS][i]),
unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
],
infos=[
{},
batch[Columns.INFOS][i] if Columns.INFOS in batch else {},
batch[schema[Columns.INFOS]][i]
if schema[Columns.INFOS] in batch
else {},
],
actions=[batch[Columns.ACTIONS][i]],
rewards=[batch[Columns.REWARDS][i]],
actions=[batch[schema[Columns.ACTIONS]][i]],
rewards=[batch[schema[Columns.REWARDS]][i]],
terminated=batch[
Columns.TERMINATEDS if Columns.TERMINATEDS in batch else "dones"
schema[Columns.TERMINATEDS]
if schema[Columns.TERMINATEDS] in batch
else "dones"
][i],
truncated=batch[Columns.TRUNCATEDS][i]
if Columns.TRUNCATEDS in batch
truncated=batch[schema[Columns.TRUNCATEDS]][i]
if schema[Columns.TRUNCATEDS] in batch
else False,
# TODO (simon): Results in zero-length episodes in connector.
# t_started=batch[Columns.T if Columns.T in batch else
# "unroll_id"][i][0],
# TODO (simon): Single-dimensional columns are not supported.
extra_model_outputs={
k: [v[i]] for k, v in batch.items() if k not in SCHEMA
k: [v[i]]
for k, v in batch.items()
if (k not in schema and k not in schema.values())
},
len_lookback_buffer=0,
)
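A quick sketch of the merge semantics used in `__call__` above, assuming plain Python 3.9+ dict-union behavior and the names from this diff: the user-defined schema overrides only the keys it defines, and everything else falls back to the default `SCHEMA`.

from ray.rllib.core.columns import Columns
from ray.rllib.offline.offline_data import SCHEMA

user_schema = {Columns.OBS: "o_t", Columns.ACTIONS: "a_t"}
# `SCHEMA | user_schema` (PEP 584 dict union): on key conflicts the
# right-hand operand wins, so only user-defined mappings override.
effective_schema = SCHEMA | user_schema
assert effective_schema[Columns.OBS] == "o_t"  # overridden by the user
assert effective_schema[Columns.REWARDS] == Columns.REWARDS  # default kept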
69 changes: 68 additions & 1 deletion rllib/offline/tests/test_offline_data.py
@@ -1,12 +1,15 @@
import functools
import gymnasium as gym
import ray
import shutil
import unittest

from pathlib import Path

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.offline.offline_data import OfflineData, OfflinePreLearner
from ray.rllib.offline.offline_data import OfflineData, OfflinePreLearner, SCHEMA


class TestOfflineData(unittest.TestCase):
@@ -63,6 +66,70 @@ def test_sample(self):
self.assertTrue("episodes" in batch)
self.assertTrue(isinstance(batch["episodes"][0], SingleAgentEpisode))

def test_offline_data_with_schema(self):

# Create some data with a different schema.
env = gym.make("CartPole-v1")
obs, _ = env.reset()
eps_id = 12345
experiences = []
for i in range(100):
action = env.action_space.sample()
next_obs, reward, terminated, truncated, _ = env.step(action)
experience = {
"o_t": obs,
"a_t": action,
"r_t": reward,
"o_tp1": next_obs,
"d_t": terminated or truncated,
"episode_id": eps_id,
}
experiences.append(experience)
if terminated or truncated:
obs, info = env.reset()
eps_id = eps_id + i
obs = next_obs

# Convert to `Dataset`.
ds = ray.data.from_items(experiences)
# Store under the temporary directory.
dir_path = "/tmp/ray/tests/data/test_offline_data_with_schema/test_data"
ds.write_parquet(dir_path)

# Define a config.
config = AlgorithmConfig()
config.input_ = [dir_path]
# Explicitly request to use a different schema.
config.input_read_schema = {
Columns.OBS: "o_t",
Columns.ACTIONS: "a_t",
Columns.REWARDS: "r_t",
Columns.NEXT_OBS: "o_tp1",
Columns.EPS_ID: "episode_id",
Columns.TERMINATEDS: "d_t",
}
# Create the `OfflineData` instance. Note, this tests reading
# the files.
offline_data = OfflineData(config)
# Ensure that the data could be loaded.
self.assertTrue(hasattr(offline_data, "data"))
# Take a small batch.
batch = offline_data.data.take_batch(10)
self.assertTrue("o_t" in batch.keys())
self.assertTrue("a_t" in batch.keys())
self.assertTrue("r_t" in batch.keys())
self.assertTrue("o_tp1" in batch.keys())
self.assertTrue("d_t" in batch.keys())
self.assertTrue("episode_id" in batch.keys())
# Preprocess the batch to episodes. Note, here we test that the
# user schema is used.
episodes = OfflinePreLearner._map_to_episodes(
is_multi_agent=False, batch=batch, schema=SCHEMA | config.input_read_schema
)
self.assertEqual(len(episodes["episodes"]), batch["o_t"].shape[0])
# Finally, remove the files and folders.
shutil.rmtree(dir_path)


if __name__ == "__main__":
import sys
