[RLlib] Fix performance and functionality flaws in attention nets (via Trajectory view API). #11729
Conversation
def __init__(self, shift_before: int = 0):
    self.shift_before = max(shift_before, 1)

def __init__(self, view_reqs):
    self.shift_before = -min(
Might want to add comment to describe what this code does!
done
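For readers skimming this thread, here is a minimal sketch of the idea behind that initializer; the helper name and the `shift` attribute layout are assumptions for illustration, not the merged code:

```python
# Hypothetical sketch: derive how many timesteps of "history" a
# collector must keep in front of t=0, from the most negative shift
# that any view requirement asks for (e.g. prev. actions at shift=-1,
# or an attention memory reaching back to shift=-50).
def smallest_required_history(view_reqs):
    shifts = []
    for view_req in view_reqs.values():
        s = view_req.shift  # assumed: an int or a list of ints
        shifts.extend(s if isinstance(s, (list, tuple)) else [s])
    # Negate the minimum (clipped at 0) so the result is a
    # non-negative buffer length, analogous to `shift_before` above.
    return -min(0, *shifts)
```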
def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                 env_id: EnvID, init_obs: TensorType,

def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                 env_id: EnvID, t: int, init_obs: TensorType,
                 view_requirements: Dict[str, ViewRequirement]) -> None:
    """Adds an initial observation (after reset) to the Agent's trajectory.
Change the description. It adds more than a single observation.
No, it doesn't; it really just adds a single one. Same as it used to work with SampleBatchBuilder.
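For context, a hedged usage sketch of the new signature shown above; the surrounding objects (`env`, `collector`, `episode`, `policy`) and their attributes are assumptions for illustration:

```python
# Hypothetical sketch: after env.reset(), register exactly one initial
# observation so the first action at t=0 can already be computed from
# a (zero-padded) trajectory view.
obs = env.reset()
collector.add_init_obs(
    episode_id=episode.episode_id,  # assumed attribute
    agent_index=0,
    env_id=0,
    t=-1,  # the initial obs sits one step before the first action
    init_obs=obs,
    view_requirements=policy.view_requirements,  # assumed attribute
)
```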
                / view_req.batch_repeat_value))
repeat_count = (view_req.data_rel_pos_to -
                view_req.data_rel_pos_from + 1)
data = np.asarray([
Same as above, a bit confused. Add comments on what these lines of code do.
done
Provided an example.
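To make the range-view logic concrete, a small self-contained example (toy data; the `data_rel_pos_from/to` semantics are inferred from the snippet):

```python
import numpy as np

# A view requirement asking for a *range* of timesteps per row, e.g.
# the last 3 observations: data_rel_pos_from=-2 .. data_rel_pos_to=0.
traj = np.arange(6)                    # toy trajectory: [0 1 2 3 4 5]
rel_from, rel_to = -2, 0
repeat_count = rel_to - rel_from + 1   # 3 timesteps per sampled row

# One stacked row per timestep t (starting where the full range exists).
data = np.asarray([
    traj[t + rel_from:t + rel_to + 1]
    for t in range(-rel_from, len(traj))
])
print(data)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]]
```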
shift = view_req.data_rel_pos + obs_shift
# Shift is exactly 0: Use trajectory as is.
if shift == 0:
    data = np_data[data_col][self.shift_before:]
Same as before
Provided an example.
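Concretely, a hedged example of how a `shift_before`-padded buffer serves different shifts by re-slicing (toy data, not the merged code):

```python
import numpy as np

# The buffer keeps `shift_before` extra steps at the front so that
# negative shifts (e.g. "previous obs") can be served without copying.
shift_before = 1
buf = np.arange(-shift_before, 5)         # [-1  0  1  2  3  4]

data_shift_0 = buf[shift_before:]         # shift == 0: [0 1 2 3 4]
data_shift_m1 = buf[shift_before - 1:-1]  # shift == -1: [-1 0 1 2 3]
```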
@@ -208,6 +269,43 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
    [np.zeros(shape=shape, dtype=dtype)
     for _ in range(shift)]

def _get_input_dict(self, view_reqs, abs_pos: int = -1) -> \
Add a description of what this method does.
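A hedged reading of what a method like this typically builds (the `data_col`/`shift` attributes and the buffer layout are assumptions for illustration):

```python
# Hypothetical sketch: assemble the model's input dict for a single
# absolute position in the buffered trajectory, one entry per view
# requirement (obs, prev. actions, attention memory, ...).
def _get_input_dict(self, view_reqs, abs_pos: int = -1):
    input_dict = {}
    for view_col, view_req in view_reqs.items():
        data_col = view_req.data_col or view_col  # assumed attribute
        buf = self.buffers[data_col]
        # Serve the value at `abs_pos` plus the requirement's shift;
        # negative indices count from the end of the buffer.
        input_dict[view_col] = buf[abs_pos + view_req.shift]
    return input_dict
```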
batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data

batch = SampleBatch(
    self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
What is _dont_check_lens?
added explanation.
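For thread readers, the likely intent (hedged, inferred only from the flag name and the call site above) can be illustrated as:

```python
# Hypothetical sketch: with _dont_check_lens=True, a SampleBatch would
# skip asserting that sum(seq_lens) equals the number of buffered rows,
# which can legitimately differ while a rollout is still mid-episode.
if not _dont_check_lens and seq_lens is not None:
    num_rows = len(next(iter(buffers.values())))
    assert sum(seq_lens) == num_rows
```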
self.count = 0
if self.seq_lens is not None:
    self.seq_lens = []
return batch


class _PolicyCollectorGroup:
Probably add comments on what this is
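A hedged sketch of what a class with this name plausibly is (the constructor and attributes are assumptions, not the merged code):

```python
# Hypothetical sketch: a thin container holding one collector per
# policy, so a multi-agent rollout can route each agent's data to the
# collector of the policy controlling that agent.
class _PolicyCollectorGroup:
    def __init__(self, policy_map):
        self.policy_collectors = {
            pid: _PolicyCollector()  # assumed per-policy collector
            for pid in policy_map
        }
        # Total env steps accumulated across all policies in this group.
        self.count = 0
```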
rllib/policy/sample_batch.py
for i, seq_len in enumerate(self.seq_lens):
    count += seq_len
    if count >= end:
        data["state_in_0"] = self.data["state_in_0"][state_start:
Add a comment on what this does.
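To unpack what that loop is doing, a self-contained example (toy numbers; the variable names mirror the snippet):

```python
# RNN/attention state columns are stored once per *sequence*, not per
# timestep, so a timestep slice [start, end) must be translated into
# sequence indices by walking seq_lens until the range is covered.
seq_lens = [4, 3, 5]           # three sequences, 12 timesteps total
start, end = 5, 10             # timestep range to slice
count, state_start, state_end = 0, None, None
for i, seq_len in enumerate(seq_lens):
    count += seq_len
    if state_start is None and count > start:
        state_start = i        # first sequence overlapping the slice
    if count >= end:
        state_end = i + 1      # one past the last overlapping sequence
        break
# state rows to take: state_in_0[state_start:state_end]  -> [1:3]
```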
# Range of indices on time-axis, make sure to create
if view_req.data_rel_pos_from is not None:
    ret[view_col] = np.zeros_like([[
        view_req.space.sample()
Same; add a comment here.
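Similarly, a small hedged example of what the zero-filled dummy construction above produces (a gym Box space is used purely for illustration):

```python
import numpy as np
import gym

# Build a dummy batch of zeros shaped [batch=1, time=range_len] plus
# the space's own shape, by sampling once per timestep and zeroing.
space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
range_len = 4  # e.g. data_rel_pos_to - data_rel_pos_from + 1
dummy = np.zeros_like([[space.sample() for _ in range(range_len)]])
print(dummy.shape)  # (1, 4, 3)
```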
@@ -121,3 +125,14 @@ def rel_shift(x: TensorType) -> TensorType:
    x = tf.reshape(x, x_size)

    return x


class PositionalEmbedding(tf.keras.layers.Layer):
Add comments on what this does (how it initializes embedding per position based on cos/sin something)
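For reference, the standard sinusoidal scheme such a layer typically implements (a hedged sketch in the Transformer-XL style, not necessarily the merged code):

```python
import numpy as np
import tensorflow as tf

class PositionalEmbedding(tf.keras.layers.Layer):
    """Sinusoidal positional embedding (Vaswani et al., 2017).

    Each position gets a fixed vector: sin(pos / 10000^(2i/d)) in the
    first half of the channels, cos(...) in the second half, so that
    relative offsets are expressible as linear functions.
    """

    def __init__(self, out_dim: int, **kwargs):
        super().__init__(**kwargs)
        # One inverse frequency per pair of channels.
        self.inverse_freq = tf.constant(
            1.0 / (10000 ** (np.arange(0, out_dim, 2.0) / out_dim)),
            dtype=tf.float32)

    def call(self, seq_len):
        positions = tf.range(seq_len, dtype=tf.float32)
        # angles: [seq_len, out_dim // 2]
        angles = positions[:, None] * self.inverse_freq[None, :]
        # -> [seq_len, out_dim]
        return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)
```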
Moved here:
Why are these changes needed?
Related issue number
Checks
- I've run scripts/format.sh to lint the changes in this PR.