[RLlib] DatasetReader action normalization #27356

Merged
88 changes: 88 additions & 0 deletions rllib/evaluation/tests/test_rollout_worker.py
@@ -1,9 +1,11 @@
from collections import Counter
import gym
from gym.spaces import Box, Discrete
import json
import numpy as np
import os
import random
import tempfile
import time
import unittest

@@ -22,6 +24,8 @@
)
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
@@ -358,6 +362,90 @@ def test_action_normalization(self):
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_action_normalization_offline_dataset(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Create the environment.
            env = gym.make("Pendulum-v1")

            # Create temp data with actions at min and max.
            data = {
                "type": "SampleBatch",
                "actions": [[2.0], [-2.0]],
                "dones": [0.0, 0.0],
                "rewards": [0.0, 0.0],
                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            }

            data_file = os.path.join(tmp_dir, "data.json")

            with open(data_file, "w") as f:
                json.dump(data, f)

            # Create input reader functions.
            def dataset_reader_creator(ioctx):
                config = {
                    "input": "dataset",
                    "input_config": {"format": "json", "paths": data_file},
                }
                _, shards = get_dataset_and_shards(config, num_workers=0)
                return DatasetReader(shards[0], ioctx)

            def json_reader_creator(ioctx):
                return JsonReader(data_file, ioctx)

            input_creators = [dataset_reader_creator, json_reader_creator]

            # Check that the input is normalized if
            # `actions_in_input_normalized` is False.
            for input_creator in input_creators:
                ev = RolloutWorker(
                    env_creator=lambda _: env,
                    policy_spec=MockPolicy,
                    policy_config=dict(
                        actions_in_input_normalized=False,
                        normalize_actions=True,
                        clip_actions=False,
                        offline_sampling=True,
                        train_batch_size=1,
                    ),
                    rollout_fragment_length=1,
                    input_creator=input_creator,
                )

                sample = ev.sample()

                # Check that the samples from the dataset are normalized properly.
                self.assertLessEqual(np.max(sample["actions"]), 1.0)
                self.assertGreaterEqual(np.min(sample["actions"]), -1.0)

                ev.stop()

            # Check that the input is not normalized if
            # `actions_in_input_normalized` is True.
            for input_creator in input_creators:
                ev = RolloutWorker(
                    env_creator=lambda _: env,
                    policy_spec=MockPolicy,
                    policy_config=dict(
                        actions_in_input_normalized=True,
                        normalize_actions=True,
                        clip_actions=False,
                        offline_sampling=True,
                        train_batch_size=1,
                    ),
                    rollout_fragment_length=1,
                    input_creator=input_creator,
                )

                sample = ev.sample()

                # Check that the samples from the dataset are not normalized.
                self.assertGreater(np.max(sample["actions"]), 1.0)
                self.assertLess(np.min(sample["actions"]), -1.0)

                ev.stop()

    def test_action_immutability(self):
        from ray.rllib.examples.env.random_env import RandomEnv

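For context, a minimal sketch (not part of the PR) of the same knobs in a plain RLlib config dict; the env and file path are placeholders modeled on the test above:

config = {
    "env": "Pendulum-v1",
    "input": "dataset",
    # Placeholder path; any JSON offline file written by RLlib works here.
    "input_config": {"format": "json", "paths": "/tmp/pendulum/data.json"},
    # False: the file stores env-scale actions, so RLlib re-normalizes them
    # to [-1, 1] on read. True: the stored actions are already normalized
    # and are passed through untouched.
    "actions_in_input_normalized": False,
    "normalize_actions": True,
}
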
3 changes: 2 additions & 1 deletion rllib/offline/dataset_reader.py
@@ -8,7 +8,7 @@
import ray.data
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.json_reader import from_json_data
from ray.rllib.offline.json_reader import from_json_data, postprocess_actions
Member (review comment): nice

from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.typing import SampleBatchType, AlgorithmConfigDict
@@ -251,6 +251,7 @@ def next(self) -> SampleBatchType:
            d = next(self._iter).as_pydict()
            # Columns like obs are compressed when written by DatasetWriter.
            d = from_json_data(d, self._ioctx.worker)
            d = postprocess_actions(d, self._ioctx)
            count += d.count
            ret.append(self._postprocess_if_needed(d))
        ret = concat_samples(ret)
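
A short usage sketch of the patched read path, mirroring dataset_reader_creator from the test above (the ioctx would normally be built by a RolloutWorker, and data_file is a placeholder pointing at the two-row Pendulum batch from the test):

config = {
    "input": "dataset",
    "input_config": {"format": "json", "paths": data_file},
}
_, shards = get_dataset_and_shards(config, num_workers=0)
reader = DatasetReader(shards[0], ioctx)
batch = reader.next()
# With actions_in_input_normalized=False and normalize_actions=True in the
# worker config, the raw Pendulum actions [[2.0], [-2.0]] now come back
# within [-1.0, 1.0]; before this change, DatasetReader.next() skipped the
# action postprocessing that JsonReader already applied.
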
127 changes: 73 additions & 54 deletions rllib/offline/json_reader.py
@@ -83,6 +83,78 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    return json_data


@DeveloperAPI
def postprocess_actions(batch: SampleBatchType, ioctx: IOContext) -> SampleBatchType:
    """Clips and/or re-normalizes the actions of an offline batch, per config."""
    # Clip actions (from any values into env's bounds), if necessary.
    cfg = ioctx.config
    # TODO(jungong) : we should not clip_action in input reader.
    # Use connector to handle this.
    if cfg.get("clip_actions"):
        if ioctx.worker is None:
            raise ValueError(
                "clip_actions is True but cannot clip actions since no workers exist"
            )

        if isinstance(batch, SampleBatch):
            default_policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            batch[SampleBatch.ACTIONS] = clip_action(
                batch[SampleBatch.ACTIONS], default_policy.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                b[SampleBatch.ACTIONS] = clip_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )
    # Re-normalize actions (from env's bounds to zero-centered), if
    # necessary.
    if (
        cfg.get("actions_in_input_normalized") is False
        and cfg.get("normalize_actions") is True
    ):
        if ioctx.worker is None:
            raise ValueError(
                "actions_in_input_normalized is False but "
                "cannot normalize actions since no workers exist"
            )

        # If we have a complex action space and actions were flattened
        # and we have to normalize -> Error.
        error_msg = (
            "Normalization of offline actions that are flattened is not "
            "supported! Make sure that you record actions into offline "
            "file with the `_disable_action_flattening=True` flag OR "
            "as already normalized (between -1.0 and 1.0) values. "
            "Also, when reading already normalized action values from "
            "offline files, make sure to set "
            "`actions_in_input_normalized=True` so that RLlib will not "
            "perform normalization on top."
        )

        if isinstance(batch, SampleBatch):
            pol = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID)
            if isinstance(
                pol.action_space_struct, (tuple, dict)
            ) and not pol.config.get("_disable_action_flattening"):
                raise ValueError(error_msg)
            batch[SampleBatch.ACTIONS] = normalize_action(
                batch[SampleBatch.ACTIONS], pol.action_space_struct
            )
        else:
            for pid, b in batch.policy_batches.items():
                pol = ioctx.worker.policy_map[pid]
                if isinstance(
                    pol.action_space_struct, (tuple, dict)
                ) and not pol.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)
                b[SampleBatch.ACTIONS] = normalize_action(
                    b[SampleBatch.ACTIONS],
                    ioctx.worker.policy_map[pid].action_space_struct,
                )

    return batch
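
To make the "zero-centered" remark above concrete, here is a small, self-contained sketch of the rescaling this performs for a plain Box space (to_normalized is a hypothetical helper for illustration only, assuming a linear map from [low, high] to [-1, 1]; RLlib's actual normalize_action also walks nested Tuple/Dict action structs):

import numpy as np

def to_normalized(actions, low, high):
    # Linear map from the env's [low, high] bounds to the zero-centered
    # [-1, 1] range that the policy is trained in.
    actions = np.asarray(actions, dtype=np.float32)
    return 2.0 * (actions - low) / (high - low) - 1.0

# Pendulum-v1 actions live in Box(-2.0, 2.0, (1,)), so the raw actions from
# the test file map exactly onto the ends of the normalized range:
print(to_normalized([[2.0], [-2.0]], low=-2.0, high=2.0))  # [[ 1.], [-1.]]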


@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
@@ -290,61 +362,8 @@ def _try_parse(self, line: str) -> Optional[SampleBatchType]:
            )
            return None

        # Clip actions (from any values into env's bounds), if necessary.
        cfg = self.ioctx.config
        # TODO(jungong) : we should not clip_action in input reader.
        # Use connector to handle this.
        if cfg.get("clip_actions") and self.ioctx.worker is not None:
            if isinstance(batch, SampleBatch):
                batch[SampleBatch.ACTIONS] = clip_action(
                    batch[SampleBatch.ACTIONS], self.default_policy.action_space_struct
                )
            else:
                for pid, b in batch.policy_batches.items():
                    b[SampleBatch.ACTIONS] = clip_action(
                        b[SampleBatch.ACTIONS],
                        self.ioctx.worker.policy_map[pid].action_space_struct,
                    )
        # Re-normalize actions (from env's bounds to zero-centered), if
        # necessary.
        if (
            cfg.get("actions_in_input_normalized") is False
            and self.ioctx.worker is not None
        ):
        batch = postprocess_actions(batch, self.ioctx)

            # If we have a complex action space and actions were flattened
            # and we have to normalize -> Error.
            error_msg = (
                "Normalization of offline actions that are flattened is not "
                "supported! Make sure that you record actions into offline "
                "file with the `_disable_action_flattening=True` flag OR "
                "as already normalized (between -1.0 and 1.0) values. "
                "Also, when reading already normalized action values from "
                "offline files, make sure to set "
                "`actions_in_input_normalized=True` so that RLlib will not "
                "perform normalization on top."
            )

            if isinstance(batch, SampleBatch):
                pol = self.default_policy
                if isinstance(
                    pol.action_space_struct, (tuple, dict)
                ) and not pol.config.get("_disable_action_flattening"):
                    raise ValueError(error_msg)
                batch[SampleBatch.ACTIONS] = normalize_action(
                    batch[SampleBatch.ACTIONS], pol.action_space_struct
                )
            else:
                for pid, b in batch.policy_batches.items():
                    pol = self.policy_map[pid]
                    if isinstance(
                        pol.action_space_struct, (tuple, dict)
                    ) and not pol.config.get("_disable_action_flattening"):
                        raise ValueError(error_msg)
                    b[SampleBatch.ACTIONS] = normalize_action(
                        b[SampleBatch.ACTIONS],
                        self.ioctx.worker.policy_map[pid].action_space_struct,
                    )
        return batch

    def _next_line(self) -> str:
2 changes: 1 addition & 1 deletion rllib/tuned_examples/crr/pendulum-v1-crr.yaml
@@ -20,7 +20,7 @@ pendulum_crr:
actor_hidden_activation: 'relu'
actor_hiddens: [256, 256]
actor_lr: 0.0003
actions_in_input_normalized: True
actions_in_input_normalized: False
clip_actions: True
# Q function update setting
twin_q: True
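
In practice, actions_in_input_normalized: False tells the reader that the offline Pendulum actions are stored in env scale, so postprocess_actions re-normalizes them to [-1, 1] on load; with True they would be treated as already normalized and passed through unchanged.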