
[RLlib] Agent collector time complexity reduction #31693

Merged
1 change: 0 additions & 1 deletion rllib/connectors/action/pipeline.py
@@ -28,7 +28,6 @@ def __call__(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
timer = self.timers[str(c)]
with timer:
ac_data = c(ac_data)
timer.push_units_processed(1)
Member
wait, we shouldn't get rid of these? same below.

Contributor Author
I realized we actually don't need these after kourosh asked me about them -> #31693 (comment)

I just ran the code from this PR and took the following screenshot to make sure the timer still works as expected.
When calling .mean(), the timer does not care about processed units, so we don't need them.
[Screenshot 2023-01-18 at 15 54 33]

Member

oh that's right, it's part of the with statements.
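
To make the reasoning above concrete, here is a minimal stand-in timer (an illustrative sketch, not the actual `ray.util.timer` implementation): the `with` block alone records the duration, and `push_units_processed()` only feeds a separate throughput-style statistic, so dropping it does not affect `.mean()`.

```python
import time


class SketchTimer:
    """Illustrative stand-in for the pipeline timer (hypothetical, simplified)."""

    def __init__(self):
        self._durations = []
        self._units = 0

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Leaving the with-block is what records the duration.
        self._durations.append(time.perf_counter() - self._start)

    def push_units_processed(self, n):
        # Only relevant for throughput-style metrics, not for mean().
        self._units += n

    def mean(self):
        return sum(self._durations) / len(self._durations)


timer = SketchTimer()
with timer:
    time.sleep(0.01)

# mean() is fully determined by the with-block; no units were pushed.
print(timer.mean())
```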

return ac_data

def to_state(self):
1 change: 0 additions & 1 deletion rllib/connectors/agent/pipeline.py
@@ -39,7 +39,6 @@ def __call__(
timer = self.timers[str(c)]
with timer:
ret = c(ret)
timer.push_units_processed(1)
return ret

def to_state(self):
55 changes: 33 additions & 22 deletions rllib/evaluation/collectors/agent_collector.py
@@ -1,14 +1,14 @@
import copy
import logging

from copy import deepcopy
from gymnasium.spaces import Space
import math
from typing import Any, Dict, List, Optional

import numpy as np
import tree # pip install dm_tree
from typing import Any, Dict, List, Optional
from gymnasium.spaces import Space

from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.space_utils import (
flatten_to_single_ndarray,
@@ -20,7 +20,6 @@
TensorType,
ViewRequirementsDict,
)

from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)
@@ -38,6 +37,17 @@ def _to_float_np_array(v: List[Any]) -> np.ndarray:
return arr


def _get_buffered_slice_with_paddings(d, inds):
element_at_t = []
for index in inds:
if index < len(d):
element_at_t.append(d[index])
else:
# zero pad similar to the last element.
element_at_t.append(tree.map_structure(np.zeros_like, d[-1]))
return element_at_t
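
A quick illustration (hypothetical buffer contents) of what the helper above returns when some indices run past the end of the buffer: out-of-range positions come back as zeros shaped like the last element.

```python
import numpy as np

buffer = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]  # two timesteps in the buffer
padded = _get_buffered_slice_with_paddings(buffer, [0, 1, 2, 3])

# Indices 2 and 3 are out of range, so they are returned as zero arrays.
print(np.stack(padded))
# [[1. 2.]
#  [3. 4.]
#  [0. 0.]
#  [0. 0.]]
```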


@PublicAPI
class AgentCollector:
"""Collects samples for one agent in one trajectory (episode).
@@ -230,8 +240,7 @@ def add_action_reward_next_obs(self, input_values: Dict[str, TensorType]) -> None:
AgentCollector._next_unroll_id += 1

# Next obs -> obs.
# TODO @kourosh: remove the in-place operations and get rid of this deepcopy.
values = deepcopy(input_values)
values = copy.copy(input_values)
assert SampleBatch.OBS not in values
values[SampleBatch.OBS] = values[SampleBatch.NEXT_OBS]
del values[SampleBatch.NEXT_OBS]
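
Sketch of why the shallow copy is enough here (illustrative keys, not the actual `SampleBatch` constants): only the outer dict is re-keyed from next-obs to obs, the caller's dict is left untouched, and the potentially large observation arrays are not duplicated the way `deepcopy` would duplicate them.

```python
import copy

import numpy as np

input_values = {"new_obs": np.zeros((84, 84, 3)), "rewards": 1.0}

values = copy.copy(input_values)       # shallow copy: new outer dict, shared values
values["obs"] = values.pop("new_obs")  # re-key without touching the caller's dict

print("new_obs" in input_values)                  # True: caller's dict unchanged
print(values["obs"] is input_values["new_obs"])   # True: array not duplicated
```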
@@ -356,8 +365,10 @@ def build_for_inference(self) -> SampleBatch:
# before the last one (len(d) - 2) and so on.
element_at_t = d[view_req.shift_arr + len(d) - 1]
if element_at_t.shape[0] == 1:
# squeeze to remove the T dimension if it is 1.
element_at_t = element_at_t.squeeze(0)
# We'd normally squeeze here to remove the time dim, but we'll
# simply use the time dim as the batch dim.
data.append(element_at_t)
continue
# add the batch dimension with [None]
data.append(element_at_t[None])
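
As a sanity check on the inference path above, a small numpy sketch (made-up shapes) of why the squeeze is no longer needed: when the shifted slice holds a single timestep, its leading time dimension of size 1 can double as the batch dimension, so `squeeze(0)` followed by `[None]` was redundant work.

```python
import numpy as np

d = np.arange(10.0).reshape(5, 2)         # buffer: 5 timesteps, feature dim 2
shift_arr = np.array([0])                 # view requirement: current step only

element_at_t = d[shift_arr + len(d) - 1]  # shape (1, 2): time dim of size 1

batched_old = element_at_t.squeeze(0)[None]  # old path: drop the dim, then re-add it
batched_new = element_at_t                   # new path: reuse the time dim as batch dim

assert np.array_equal(batched_old, batched_new)
```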

@@ -458,20 +469,20 @@ def build_for_training(
# handle the case where the inds are out of bounds from the end.
# if during the indexing any of the indices are out of bounds, we
# need to use padding on the end to fill in the missing indices.
element_at_t = []
for index in inds:
if index < len(d):
element_at_t.append(d[index])
else:
# zero pad similar to the last element.
element_at_t.append(
tree.map_structure(np.zeros_like, d[-1])
)
element_at_t = np.stack(element_at_t)
# Create padding first time we encounter data
if max(inds) < len(d):
# Simple case where we can simply pick slices from buffer
element_at_t = d[inds]
else:
# Case in which we have to pad because the buffer has insufficient
# length. This branch takes more time than simply picking
# slices, so we try to avoid it.
element_at_t = _get_buffered_slice_with_paddings(d, inds)
element_at_t = np.stack(element_at_t)

if element_at_t.shape[0] == 1:
# squeeze to remove the T dimension if it is 1.
element_at_t = element_at_t.squeeze(0)
# Remove the T dimension if it is 1.
element_at_t = element_at_t[0]
shifted_data.append(element_at_t)

# in some multi-agent cases shifted_data may be an empty list.
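
The core of the speedup, sketched with made-up data and assuming the buffer is already a numpy array: when every requested index is in range, `d[inds]` is a single vectorized gather, and the slower padded branch (using the helper introduced above) only runs when indices fall past the end of the buffer.

```python
import numpy as np

d = np.arange(12.0).reshape(6, 2)   # buffer: 6 timesteps, feature dim 2
inds = np.array([3, 4, 5])          # indices requested by the view requirement

if max(inds) < len(d):
    # Fast path: one vectorized slice instead of a per-index Python loop.
    element_at_t = d[inds]
else:
    # Padded fallback, as in _get_buffered_slice_with_paddings above.
    element_at_t = np.stack(_get_buffered_slice_with_paddings(d, inds))

print(element_at_t.shape)  # (3, 2)
```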
3 changes: 2 additions & 1 deletion rllib/evaluation/episode_v2.py
@@ -1,5 +1,6 @@
import random
from collections import defaultdict
import numpy as np
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
@@ -295,7 +296,7 @@ def postprocess_episode(

if (
not pre_batch.is_single_trajectory()
or len(set(pre_batch[SampleBatch.EPS_ID])) > 1
or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1
Contributor Author

Because of the above changes, EPS_IDs turn out to be np arrays as well, so set does not work here anymore.
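
A tiny example (hypothetical data) of the replacement check: `np.unique` counts the distinct episode IDs directly on the ndarray column, which is all the single-trajectory guard above needs.

```python
import numpy as np

eps_ids = np.array([7, 7, 7, 7])    # EPS_ID column collected as an ndarray

# The guard fires if more than one distinct episode ID is present.
print(len(np.unique(eps_ids)) > 1)  # False: single trajectory
```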

):
raise ValueError(
"Batches sent to postprocessing must only contain steps "