[RLlib] Attention Net prep PR #3. (ray-project#12450)
Authored by sven1977 on Dec 7, 2020
Parent: 401d342 · Commit: 99c81c6
Showing 32 changed files with 355 additions and 248 deletions.
29 changes: 6 additions & 23 deletions rllib/agents/ppo/appo_tf_policy.py
@@ -13,12 +13,12 @@
from ray.rllib.agents.impala import vtrace_tf as vtrace
from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
clip_gradients, choose_optimizer
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae
from ray.rllib.evaluation.episode import MultiAgentEpisode
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.tf.tf_action_dist import Categorical
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule, TFPolicy
from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
@@ -338,31 +338,14 @@ def postprocess_trajectory(
SampleBatch: The postprocessed, modified SampleBatch (or a new one).
"""
if not policy.config["vtrace"]:
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append([sample_batch["state_out_{}".format(i)][-1]])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
batch = compute_advantages(
sample_batch,
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"],
use_critic=policy.config["use_critic"])
else:
batch = sample_batch
sample_batch = postprocess_ppo_gae(policy, sample_batch,
other_agent_batches, episode)

# TODO: (sven) remove this del once we have trajectory view API fully in
# place.
del batch.data["new_obs"] # not used, so save some bandwidth
del sample_batch.data["new_obs"] # not used, so save some bandwidth

return batch
return sample_batch


def add_values(policy):
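The net effect of this file's change: APPO's non-vtrace postprocessing no longer computes the bootstrap value and advantages inline, but defers to PPO's shared postprocess_ppo_gae helper (hence the import swap above). A rough sketch of the resulting function body, reconstructed from the diff rather than copied from the file:

    def postprocess_trajectory(policy, sample_batch, other_agent_batches=None,
                               episode=None):
        if not policy.config["vtrace"]:
            # Bootstrap-value and GAE/advantage computation now live in the
            # shared PPO helper.
            sample_batch = postprocess_ppo_gae(policy, sample_batch,
                                               other_agent_batches, episode)
        # Not used downstream, so save some bandwidth.
        del sample_batch.data["new_obs"]
        return sample_batch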
2 changes: 1 addition & 1 deletion rllib/agents/ppo/ppo.py
@@ -38,7 +38,7 @@
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
"use_gae": True,
# The GAE(lambda) parameter.
# The GAE (lambda) parameter.
"lambda": 1.0,
# Initial coefficient for KL divergence.
"kl_coeff": 0.2,
70 changes: 48 additions & 22 deletions rllib/agents/ppo/ppo_tf_policy.py
@@ -193,13 +193,22 @@ def postprocess_ppo_gae(
last_r = 0.0
# Trajectory has been truncated -> last r=VF estimate of last obs.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if policy.config["_use_trajectory_view_api"]:
# Create an input dict according to the Model's requirements.
input_dict = policy.model.get_input_dict(sample_batch, index=-1)
last_r = policy._value(**input_dict)
# TODO: (sven) Remove once trajectory view API is all-algo default.
else:
next_state = []
for i in range(policy.num_state_tensors()):
next_state.append(sample_batch["state_out_{}".format(i)][-1])
last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
sample_batch[SampleBatch.ACTIONS][-1],
sample_batch[SampleBatch.REWARDS][-1],
*next_state)

# Adds the policy logits, VF preds, and advantages to the batch,
# using GAE ("generalized advantage estimation") or not.
@@ -208,7 +217,9 @@ def postprocess_ppo_gae(
last_r,
policy.config["gamma"],
policy.config["lambda"],
use_gae=policy.config["use_gae"])
use_gae=policy.config["use_gae"],
use_critic=policy.config.get("use_critic", True))

return batch
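Put differently: under the trajectory view API the policy no longer hand-assembles (obs, prev_action, prev_reward, *state) for the value bootstrap; the model declares which views it needs. Condensed from the diff above:

    if sample_batch["dones"][-1]:
        last_r = 0.0  # episode terminated: nothing to bootstrap
    else:
        # Single-timestep input dict (last index of the trajectory), built
        # from the model's own view requirements.
        input_dict = policy.model.get_input_dict(sample_batch, index=-1)
        last_r = policy._value(**input_dict)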


@@ -292,25 +303,40 @@ def __init__(self, obs_space, action_space, config):
# observation.
if config["use_gae"]:

@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
[prev_reward]),
"is_training": tf.convert_to_tensor([False]),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if config["_use_trajectory_view_api"]:

@make_tf_callable(self.get_session())
def value(**input_dict):
model_out, _ = self.model.from_batch(
input_dict, is_training=False)
# [0] = remove the batch dim.
return self.model.value_function()[0]

# TODO: (sven) Remove once trajectory view API is all-algo default.
else:

@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
[prev_action]),
SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
[prev_reward]),
"is_training": tf.convert_to_tensor([False]),
}, [tf.convert_to_tensor([s]) for s in state],
tf.convert_to_tensor([1]))
# [0] = remove the batch dim.
return self.model.value_function()[0]

# When not doing GAE, we do not require the value function's output.
else:

@make_tf_callable(self.get_session())
def value(ob, prev_action, prev_reward, *state):
def value(*args, **kwargs):
return tf.constant(0.0)

self._value = value
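Note also the widened signature of the no-op value function in the non-GAE branch: callers may now pass either the legacy positional arguments or the new keyword-style input dict, so the dummy has to accept both. A minimal illustration:

    @make_tf_callable(self.get_session())
    def value(*args, **kwargs):
        # Return value is unused when use_gae=False; accept any calling
        # convention.
        return tf.constant(0.0)

    # Legacy call:           policy._value(ob, prev_action, prev_reward, *state)
    # Trajectory-view call:  policy._value(**input_dict)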
46 changes: 30 additions & 16 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -210,22 +210,36 @@ def __init__(self, obs_space, action_space, config):
# When doing GAE, we need the value function estimate on the
# observation.
if config["use_gae"]:

def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob]), self.device),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action]), self.device),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward]), self.device),
"is_training": False,
}, [
convert_to_torch_tensor(np.asarray([s]), self.device)
for s in state
], convert_to_torch_tensor(np.asarray([1]), self.device))
# [0] = remove the batch dim.
return self.model.value_function()[0]
# Input dict is provided to us automatically via the Model's
# requirements. It's a single-timestep (last one in trajectory)
# input_dict.
if config["_use_trajectory_view_api"]:

def value(**input_dict):
model_out, _ = self.model.from_batch(
convert_to_torch_tensor(input_dict, self.device),
is_training=False)
# [0] = remove the batch dim.
return self.model.value_function()[0]

# TODO: (sven) Remove once trajectory view API is all-algo default.
else:

def value(ob, prev_action, prev_reward, *state):
model_out, _ = self.model({
SampleBatch.CUR_OBS: convert_to_torch_tensor(
np.asarray([ob]), self.device),
SampleBatch.PREV_ACTIONS: convert_to_torch_tensor(
np.asarray([prev_action]), self.device),
SampleBatch.PREV_REWARDS: convert_to_torch_tensor(
np.asarray([prev_reward]), self.device),
"is_training": False,
}, [
convert_to_torch_tensor(np.asarray([s]), self.device)
for s in state
], convert_to_torch_tensor(np.asarray([1]), self.device))
# [0] = remove the batch dim.
return self.model.value_function()[0]

# When not doing GAE, we do not require the value function's output.
else:
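The torch policy mirrors the TF change; the notable difference is that the collected views arrive as numpy arrays and are moved onto the policy's device first (and no make_tf_callable wrapping is needed). Condensed from the diff above:

    def value(**input_dict):
        input_dict = convert_to_torch_tensor(input_dict, self.device)
        model_out, _ = self.model.from_batch(input_dict, is_training=False)
        return self.model.value_function()[0]  # [0] = drop the batch dim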
15 changes: 4 additions & 11 deletions rllib/agents/qmix/model.py
@@ -1,9 +1,6 @@
from gym.spaces import Box

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

@@ -25,17 +22,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.n_agents = model_config["n_agents"]

self.inference_view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
data_rel_pos=-1,
space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim)))
})

@override(ModelV2)
def get_initial_state(self):
# Place hidden states on same device as model.
return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
return [
self.fc1.weight.new(self.n_agents,
self.rnn_hidden_dim).zero_().squeeze(0)
]

@override(ModelV2)
def forward(self, input_dict, hidden_state, seq_lens):
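Two related simplifications in this model: the hand-written ViewRequirement for "state_in_0" is gone (presumably because the state view and its space can now be derived from the model's initial state), and get_initial_state returns one hidden vector per agent in the group rather than a single flat vector. A small shape illustration, assuming n_agents=3 and rnn_hidden_dim=64:

    import torch

    fc1_weight = torch.zeros(64, 30)                      # stand-in for self.fc1.weight
    old_state = fc1_weight.new(1, 64).zero_().squeeze(0)  # shape (64,)
    new_state = fc1_weight.new(3, 64).zero_().squeeze(0)  # shape (3, 64); squeeze(0) is a no-op here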
3 changes: 0 additions & 3 deletions rllib/agents/qmix/qmix_policy.py
@@ -215,9 +215,6 @@ def __init__(self, obs_space, action_space, config):
name="target_model",
default_model=RNNModel).to(self.device)

# Combine view_requirements for Model and Policy.
self.view_requirements.update(self.model.inference_view_requirements)

self.exploration = self._create_exploration()

# Setup the mixer network.
5 changes: 4 additions & 1 deletion rllib/contrib/maddpg/maddpg_policy.py
@@ -28,7 +28,7 @@ def postprocess_trajectory(self,
other_agent_batches=None,
episode=None):
# FIXME: Get done from info is required since agentwise done is not
# supported now.
# supported now.
sample_batch.data[SampleBatch.DONES] = self.get_done_from_info(
sample_batch.data[SampleBatch.INFOS])

@@ -251,6 +251,9 @@ def _make_loss_inputs(placeholders):
loss_inputs=loss_inputs,
dist_inputs=actor_feature)

del self.view_requirements["prev_actions"]
del self.view_requirements["prev_rewards"]

self.sess.run(tf1.global_variables_initializer())

# Hard initial update
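The two del statements above drop view requirements for data this policy does not consume, so the sample collector will not build or ship those columns. A hedged sketch of the same idea written defensively (pop instead of del), in case the keys are absent:

    for unused in ("prev_actions", "prev_rewards"):
        # Skip collecting views the model/loss never reads.
        self.view_requirements.pop(unused, None)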
2 changes: 1 addition & 1 deletion rllib/evaluation/collectors/sample_collector.py
@@ -191,7 +191,7 @@ def try_build_truncated_episode_multi_agent_batch(self) -> \
postprocessor.
This is usually called to collect samples for policy training.
If not enough data has been collected yet (`rollout_fragment_length`),
returns None.
returns an empty list.
Returns:
List[Union[MultiAgentBatch, SampleBatch]]: Returns a (possibly
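A minimal, hypothetical usage sketch matching the corrected contract (the collector variable and the consuming function are illustrative, not from the codebase):

    batches = collector.try_build_truncated_episode_multi_agent_batch()
    # An empty list (rather than None) signals: not enough data collected yet.
    for train_batch in batches:
        process(train_batch)  # placeholder for whatever consumes the batch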