From a6ca66670e9cde981a807f8a640c01e4685295e3 Mon Sep 17 00:00:00 2001
From: Charles Sun
Date: Mon, 1 Aug 2022 18:10:23 -0700
Subject: [PATCH 1/4] dataset reader normalization and test

Signed-off-by: Charles Sun
---
 rllib/evaluation/tests/test_rollout_worker.py |  88 ++++++++++++
 rllib/offline/dataset_reader.py               |   3 +-
 rllib/offline/json_reader.py                  | 127 ++++++++++--------
 rllib/tuned_examples/crr/pendulum-v1-crr.yaml |   2 +-
 4 files changed, 164 insertions(+), 56 deletions(-)

diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index c431c43ad0b50..104e0f7fac278 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -6,6 +6,8 @@
 import random
 import time
 import unittest
+import tempfile
+import json
 
 import ray
 from ray.rllib.algorithms.a2c import A2C
@@ -22,6 +24,8 @@
 )
 from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
 from ray.rllib.examples.policy.random_policy import RandomPolicy
+from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
+from ray.rllib.offline.json_reader import JsonReader
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
@@ -358,6 +362,90 @@ def test_action_normalization(self):
             self.assertLess(np.min(sample["actions"]), action_space.low[0])
         ev.stop()
 
+    def test_action_normalization_offline_dataset(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # create environment
+            env = gym.make("Pendulum-v1")
+
+            # create temp data with actions at min and max
+            data = {
+                "type": "SampleBatch",
+                "actions": [[2.0], [-2.0]],
+                "dones": [0.0, 0.0],
+                "rewards": [0.0, 0.0],
+                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+            }
+
+            data_file = os.path.join(tmp_dir, "data.json")
+
+            with open(data_file, "w") as f:
+                json.dump(data, f)
+
+            # create input reader functions
+            def dataset_reader_creator(ioctx):
+                config = {
+                    "input": "dataset",
+                    "input_config": {"format": "json", "paths": data_file},
+                }
+                _, shards = get_dataset_and_shards(config, num_workers=0)
+                return DatasetReader(shards[0], ioctx)
+
+            def json_reader_creator(ioctx):
+                return JsonReader(data_file, ioctx)
+
+            input_creators = [dataset_reader_creator, json_reader_creator]
+
+            # check that if actions_in_input_normalized is False
+            # it will normalize input
+            for input_creator in input_creators:
+                ev = RolloutWorker(
+                    env_creator=lambda _: env,
+                    policy_spec=MockPolicy,
+                    policy_config=dict(
+                        actions_in_input_normalized=False,
+                        normalize_actions=True,
+                        clip_actions=False,
+                        offline_sampling=True,
+                        train_batch_size=1,
+                    ),
+                    rollout_fragment_length=1,
+                    input_creator=input_creator,
+                )
+
+                sample = ev.sample()
+
+                # check if the samples from dataset are normalized properly
+                self.assertLessEqual(np.max(sample["actions"]), 1.0)
+                self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
+
+                ev.stop()
+
+            # check that if actions_in_input_normalized is True
+            # it will not normalize input
+            for input_creator in input_creators:
+                ev = RolloutWorker(
+                    env_creator=lambda _: env,
+                    policy_spec=MockPolicy,
+                    policy_config=dict(
+                        actions_in_input_normalized=True,
+                        normalize_actions=True,
+                        clip_actions=False,
+                        offline_sampling=True,
+                        train_batch_size=1,
+                    ),
+                    rollout_fragment_length=1,
+                    input_creator=input_creator,
+                )
+
+                sample = ev.sample()
+
+                # check if the samples from dataset are not normalized
+                self.assertGreater(np.max(sample["actions"]), 1.0)
+                self.assertLess(np.min(sample["actions"]), -1.0)
+
+                ev.stop()
+
     def test_action_immutability(self):
         from ray.rllib.examples.env.random_env import RandomEnv
 
diff --git a/rllib/offline/dataset_reader.py b/rllib/offline/dataset_reader.py
index f8fd2d4e9afd0..40ef36f3c33a2 100644
--- a/rllib/offline/dataset_reader.py
+++ b/rllib/offline/dataset_reader.py
@@ -8,7 +8,7 @@
 import ray.data
 from ray.rllib.offline.input_reader import InputReader
 from ray.rllib.offline.io_context import IOContext
-from ray.rllib.offline.json_reader import from_json_data
+from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
 from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
 from ray.rllib.utils.annotations import override, PublicAPI
 from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
@@ -251,6 +251,7 @@ def next(self) -> SampleBatchType:
             d = next(self._iter).as_pydict()
             # Columns like obs are compressed when written by DatasetWriter.
             d = from_json_data(d, self._ioctx.worker)
+            d = postprocess_actions(d, self._ioctx)
             count += d.count
             ret.append(self._postprocess_if_needed(d))
         ret = concat_samples(ret)
diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py
index e763e1167632d..fca5801eb8bbd 100644
--- a/rllib/offline/json_reader.py
+++ b/rllib/offline/json_reader.py
@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
     return json_data
 
 
+@DeveloperAPI
+def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
+    # Clip actions (from any values into env's bounds), if necessary.
+    cfg = ioctx.config
+    # TODO(jungong) : we should not clip_action in input reader.
+    # Use connector to handle this.
+    if cfg.get("clip_actions"):
+        if ioctx.worker is None:
+            raise ValueError(
+                "clip_actions is True but cannot clip actions since no workers exist"
+            )
+
+        if isinstance(batch, SampleBatch):
+            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            batch[SampleBatch.ACTIONS] = clip_action(
+                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                b[SampleBatch.ACTIONS] = clip_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+    # Re-normalize actions (from env's bounds to zero-centered), if
+    # necessary.
+    if (
+        cfg.get("actions_in_input_normalized") is False
+        and cfg.get("normalize_actions") is True
+    ):
+        if ioctx.worker is None:
+            raise ValueError(
+                "actions_in_input_normalized is False but "
+                "cannot normalize actions since no workers exist"
+            )
+
+        # If we have a complex action space and actions were flattened
+        # and we have to normalize -> Error.
+        error_msg = (
+            "Normalization of offline actions that are flattened is not "
+            "supported! Make sure that you record actions into offline "
+            "file with the `_disable_action_flattening=True` flag OR "
+            "as already normalized (between -1.0 and 1.0) values. "
+            "Also, when reading already normalized action values from "
+            "offline files, make sure to set "
+            "`actions_in_input_normalized=True` so that RLlib will not "
+            "perform normalization on top."
+        )
+
+        if isinstance(batch, SampleBatch):
+            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
+            if isinstance(
+                pol.action_space_struct, (tuple, dict)
+            ) and not pol.config.get("_disable_action_flattening"):
+                raise ValueError(error_msg)
+            batch[SampleBatch.ACTIONS] = normalize_action(
+                batch[SampleBatch.ACTIONS], pol.action_space_struct
+            )
+        else:
+            for pid, b in batch.policy_batches.items():
+                pol = ioctx.worker.policy_map[pid]
+                if isinstance(
+                    pol.action_space_struct, (tuple, dict)
+                ) and not pol.config.get("_disable_action_flattening"):
+                    raise ValueError(error_msg)
+                b[SampleBatch.ACTIONS] = normalize_action(
+                    b[SampleBatch.ACTIONS],
+                    ioctx.worker.policy_map[pid].action_space_struct,
+                )
+
+    return batch
+
+
 @DeveloperAPI
 def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
     # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
@@ -290,61 +362,8 @@ def _try_parse(self, line: str) -> Optional[SampleBatchType]:
             )
             return None
 
-        # Clip actions (from any values into env's bounds), if necessary.
-        cfg = self.ioctx.config
-        # TODO(jungong) : we should not clip_action in input reader.
-        # Use connector to handle this.
-        if cfg.get("clip_actions") and self.ioctx.worker is not None:
-            if isinstance(batch, SampleBatch):
-                batch[SampleBatch.ACTIONS] = clip_action(
-                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    b[SampleBatch.ACTIONS] = clip_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
-        # Re-normalize actions (from env's bounds to zero-centered), if
-        # necessary.
-        if (
-            cfg.get("actions_in_input_normalized") is False
-            and self.ioctx.worker is not None
-        ):
+        batch = postprocess_actions(batch, self.ioctx)
 
-            # If we have a complex action space and actions were flattened
-            # and we have to normalize -> Error.
-            error_msg = (
-                "Normalization of offline actions that are flattened is not "
-                "supported! Make sure that you record actions into offline "
-                "file with the `_disable_action_flattening=True` flag OR "
-                "as already normalized (between -1.0 and 1.0) values. "
-                "Also, when reading already normalized action values from "
-                "offline files, make sure to set "
-                "`actions_in_input_normalized=True` so that RLlib will not "
-                "perform normalization on top."
-            )
-
-            if isinstance(batch, SampleBatch):
-                pol = self.default_policy
-                if isinstance(
-                    pol.action_space_struct, (tuple, dict)
-                ) and not pol.config.get("_disable_action_flattening"):
-                    raise ValueError(error_msg)
-                batch[SampleBatch.ACTIONS] = normalize_action(
-                    batch[SampleBatch.ACTIONS], pol.action_space_struct
-                )
-            else:
-                for pid, b in batch.policy_batches.items():
-                    pol = self.policy_map[pid]
-                    if isinstance(
-                        pol.action_space_struct, (tuple, dict)
-                    ) and not pol.config.get("_disable_action_flattening"):
-                        raise ValueError(error_msg)
-                    b[SampleBatch.ACTIONS] = normalize_action(
-                        b[SampleBatch.ACTIONS],
-                        self.ioctx.worker.policy_map[pid].action_space_struct,
-                    )
         return batch
 
     def _next_line(self) -> str:
diff --git a/rllib/tuned_examples/crr/pendulum-v1-crr.yaml b/rllib/tuned_examples/crr/pendulum-v1-crr.yaml
index ce19723f9b2c7..539b3ee886b48 100644
--- a/rllib/tuned_examples/crr/pendulum-v1-crr.yaml
+++ b/rllib/tuned_examples/crr/pendulum-v1-crr.yaml
@@ -20,7 +20,7 @@ pendulum_crr:
         actor_hidden_activation: 'relu'
         actor_hiddens: [256, 256]
         actor_lr: 0.0003
-        actions_in_input_normalized: True
+        actions_in_input_normalized: False
         clip_actions: True
         # Q function update setting
         twin_q: True

From d43d2798d5d39ab4bd00ecf77b3e4495e41a08b0 Mon Sep 17 00:00:00 2001
From: Charles Sun
Date: Mon, 1 Aug 2022 18:18:31 -0700
Subject: [PATCH 2/4] lint

Signed-off-by: Charles Sun
---
 rllib/evaluation/tests/test_rollout_worker.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index 104e0f7fac278..ae98fac66b26d 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -1,13 +1,13 @@
 from collections import Counter
 import gym
 from gym.spaces import Box, Discrete
+import json
 import numpy as np
 import os
 import random
+import tempfile
 import time
 import unittest
-import tempfile
-import json
 
 import ray
 from ray.rllib.algorithms.a2c import A2C

From 1cf078045a11bfdd1541531ca33eb29cd83c4714 Mon Sep 17 00:00:00 2001
From: Charles Sun
Date: Sun, 7 Aug 2022 23:30:49 -0700
Subject: [PATCH 3/4] update test

Signed-off-by: Charles Sun
---
 rllib/evaluation/tests/test_rollout_worker.py | 80 ++++++++-----------
 1 file changed, 35 insertions(+), 45 deletions(-)

diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py
index ae98fac66b26d..bdb98924a4ad4 100644
--- a/rllib/evaluation/tests/test_rollout_worker.py
+++ b/rllib/evaluation/tests/test_rollout_worker.py
@@ -396,55 +396,45 @@ def json_reader_creator(ioctx):
 
             input_creators = [dataset_reader_creator, json_reader_creator]
 
-            # check that if actions_in_input_normalized is False
-            # it will normalize input
+            # actions_in_input_normalized, normalize_actions
+            parameters = [
+                (True, True),
+                (True, False),
+                (False, True),
+                (False, False),
+            ]
+
+            # check that samples from dataset will be normalized if and only if
+            # actions_in_input_normalized == False and
+            # normalize_actions == True
             for input_creator in input_creators:
-                ev = RolloutWorker(
-                    env_creator=lambda _: env,
-                    policy_spec=MockPolicy,
-                    policy_config=dict(
-                        actions_in_input_normalized=False,
-                        normalize_actions=True,
-                        clip_actions=False,
-                        offline_sampling=True,
-                        train_batch_size=1,
-                    ),
-                    rollout_fragment_length=1,
-                    input_creator=input_creator,
-                )
-
-                sample = ev.sample()
-
-                # check if the samples from dataset are normalized properly
-                self.assertLessEqual(np.max(sample["actions"]), 1.0)
-                self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
-
-                ev.stop()
-
-            # check that if actions_in_input_normalized is True
-            # it will not normalize input
-            for input_creator in input_creators:
-                ev = RolloutWorker(
-                    env_creator=lambda _: env,
-                    policy_spec=MockPolicy,
-                    policy_config=dict(
-                        actions_in_input_normalized=True,
-                        normalize_actions=True,
-                        clip_actions=False,
-                        offline_sampling=True,
-                        train_batch_size=1,
-                    ),
-                    rollout_fragment_length=1,
-                    input_creator=input_creator,
-                )
+                for actions_in_input_normalized, normalize_actions in parameters:
+                    ev = RolloutWorker(
+                        env_creator=lambda _: env,
+                        policy_spec=MockPolicy,
+                        policy_config=dict(
+                            actions_in_input_normalized=actions_in_input_normalized,
+                            normalize_actions=normalize_actions,
+                            clip_actions=False,
+                            offline_sampling=True,
+                            train_batch_size=1,
+                        ),
+                        rollout_fragment_length=1,
+                        input_creator=input_creator,
+                    )
 
-                sample = ev.sample()
+                    sample = ev.sample()
 
-                # check if the samples from dataset are not normalized
-                self.assertGreater(np.max(sample["actions"]), 1.0)
-                self.assertLess(np.min(sample["actions"]), -1.0)
+                    if normalize_actions and not actions_in_input_normalized:
+                        # check if the samples from dataset are normalized properly
+                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
+                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
+                    else:
+                        # check if the samples from dataset are not normalized
+                        self.assertGreater(np.max(sample["actions"]), 1.5)
+                        self.assertLess(np.min(sample["actions"]), -1.5)
 
-                ev.stop()
+                    ev.stop()
 
     def test_action_immutability(self):
         from ray.rllib.examples.env.random_env import RandomEnv

From 50f886185bb9b485505758a4ada5eb59b5bf8f65 Mon Sep 17 00:00:00 2001
From: Charles Sun
Date: Mon, 8 Aug 2022 09:35:03 -0700
Subject: [PATCH 4/4] run ci

Signed-off-by: Charles Sun