
[RLlib] Fix performance and functionality flaws in attention nets (via Trajectory view API). #11729

Closed

Conversation

sven1977 (Contributor) commented Oct 31, 2020

  • So far, RLlib's attention nets (GTrXL) have been forced to run "inside" RLlib's RNN API (previous internal states are passed in as new state-ins at subsequent timesteps). This is not a good fit for attention nets, which need different handling and time-slicing of past states (the attention net's memory). The trajectory view API allows specifying the exact time-step ranges needed for forward passes and batched train passes through attention nets (see the sketch below this list).
  • Besides the above, the handling of the attention nets' tau-memory was also not correct. This PR fixes these existing bugs.
  • In a follow-up PR, the torch version of GTrXL will be fully included in testing as well (to make sure it's 100% on par with the tf version).
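
For orientation only (this is not the PR's code; the helper below and tau=4 are hypothetical), a minimal NumPy sketch of the kind of per-timestep range view an attention net needs, as opposed to an RNN-style single state passed forward:

```python
import numpy as np

# Hypothetical sketch: an attention net with memory length tau wants, at each
# timestep t, the *range* of the last tau memory outputs [t - tau, t - 1],
# zero-padded right after a reset, instead of one recurrent state-in.
tau = 4
episode = np.arange(10)  # stand-in for per-timestep memory outputs o_0 .. o_9

def memory_view(t: int) -> np.ndarray:
    """Return the last `tau` outputs before timestep t (zero-padded at episode start)."""
    start = t - tau
    pad = max(0, -start)
    window = episode[max(0, start):t]
    return np.concatenate([np.zeros(pad, dtype=episode.dtype), window])

print(memory_view(0))  # [0 0 0 0]  <- all padding right after reset
print(memory_view(2))  # [0 0 0 1]
print(memory_view(6))  # [2 3 4 5]
```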

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/ppo/ppo_tf_policy.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/policy/tf_policy_template.py
#	rllib/utils/tf_ops.py
…on_nets

# Conflicts:
#	rllib/agents/ppo/ppo_tf_policy.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/policy.py
#	rllib/policy/sample_batch.py
#	rllib/policy/view_requirement.py
def __init__(self, shift_before: int = 0):
    self.shift_before = max(shift_before, 1)
def __init__(self, view_reqs):
    self.shift_before = -min(
Contributor:

Might want to add a comment to describe what this code does!

Contributor Author (sven1977):

done

def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
                 env_id: EnvID, init_obs: TensorType,
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
                 env_id: EnvID, t: int, init_obs: TensorType,
                 view_requirements: Dict[str, ViewRequirement]) -> None:
    """Adds an initial observation (after reset) to the Agent's trajectory.
Contributor:

Change the description. It adds more than a single observation.

Contributor Author (sven1977):

No, it doesn't; it really just adds a single one. Same as it used to work with SampleBatchBuilder.

            / view_req.batch_repeat_value))
repeat_count = (view_req.data_rel_pos_to -
                view_req.data_rel_pos_from + 1)
data = np.asarray([
Contributor:

Same as above, a bit confused. Add comments on what these lines of code do.

Contributor Author (sven1977):

done

Contributor Author (sven1977):

Provided an example.

shift = view_req.data_rel_pos + obs_shift
# Shift is exactly 0: Use trajectory as is.
if shift == 0:
    data = np_data[data_col][self.shift_before:]
Contributor:

Same as before

Contributor Author (sven1977):

Provided an example.
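
For orientation, a minimal sketch of the relative-shift idea with hypothetical values (not the author's actual example): a view column with shift 0 sees obs[t], a shift of -1 sees obs[t - 1], and so on.

```python
import numpy as np

obs = np.arange(5)  # obs[t] for t = 0..4
# shift -1: at each t, look one step back; pad with zero at t = 0.
prev_obs = np.concatenate([np.zeros(1, dtype=obs.dtype), obs[:-1]])
print(prev_obs)  # [0 0 1 2 3]
```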

@@ -208,6 +269,43 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
[np.zeros(shape=shape, dtype=dtype)
for _ in range(shift)]

def _get_input_dict(self, view_reqs, abs_pos: int = -1) -> \
Contributor:

Add a description of what this method does.

batch = SampleBatch(self.buffers)
assert SampleBatch.UNROLL_ID in batch.data
batch = SampleBatch(
    self.buffers, _seq_lens=self.seq_lens, _dont_check_lens=True)
Contributor:

What is _dont_check_lens?

Contributor Author (sven1977):

added explanation.

self.count = 0
if self.seq_lens is not None:
    self.seq_lens = []
return batch


class _PolicyCollectorGroup:
Contributor:

Probably add comments on what this is

for i, seq_len in enumerate(self.seq_lens):
    count += seq_len
    if count >= end:
        data["state_in_0"] = self.data["state_in_0"][state_start:
Contributor:

Add comment on what this does

# Range of indices on time-axis, make sure to create
if view_req.data_rel_pos_from is not None:
    ret[view_col] = np.zeros_like([[
        view_req.space.sample()
Contributor:

Same, add a comment here.

@@ -121,3 +125,14 @@ def rel_shift(x: TensorType) -> TensorType:
    x = tf.reshape(x, x_size)

    return x


class PositionalEmbedding(tf.keras.layers.Layer):
Contributor:

Add comments on what this does (how it initializes embedding per position based on cos/sin something)
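
For reference, the standard Transformer-style sin/cos positional embedding looks roughly like the sketch below (plain NumPy, illustrative only; the actual PositionalEmbedding layer in this PR may differ in details such as relative positions):

```python
import numpy as np

def sinusoidal_positional_embedding(num_positions: int, dim: int) -> np.ndarray:
    """Illustrative sin/cos positional embedding (dim assumed even)."""
    positions = np.arange(num_positions)[:, None]              # (P, 1)
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))   # (dim / 2,)
    angles = positions * inv_freq[None, :]                     # (P, dim / 2)
    emb = np.zeros((num_positions, dim))
    emb[:, 0::2] = np.sin(angles)  # even feature indices: sine
    emb[:, 1::2] = np.cos(angles)  # odd feature indices: cosine
    return emb

print(sinusoidal_positional_embedding(5, 8).shape)  # (5, 8)
```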

…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/trainer.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/models/tf/attention_net.py
#	rllib/policy/policy.py
#	rllib/policy/torch_policy_template.py
#	rllib/policy/view_requirement.py
#	src/ray/raylet/node_manager.cc
…ectory_view_api_attention_nets

# Conflicts:
#	rllib/agents/ppo/appo_tf_policy.py
#	rllib/agents/ppo/ppo_torch_policy.py
#	rllib/agents/qmix/model.py
#	rllib/evaluation/collectors/simple_list_collector.py
#	rllib/evaluation/rollout_worker.py
#	rllib/evaluation/tests/test_trajectory_view_api.py
#	rllib/models/modelv2.py
#	rllib/policy/dynamic_tf_policy.py
#	rllib/policy/policy.py
#	rllib/policy/sample_batch.py
#	rllib/policy/view_requirement.py
@sven1977 sven1977 closed this Dec 10, 2020
sven1977 (Contributor Author):

Moved here:
#12753


Successfully merging this pull request may close these issues.

[RLlib Trajectory View API] Trajectory View API works with our Attention Nets.