From 351753aae55894591dafa81814eaa82a59687f09 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Fri, 10 May 2019 20:36:18 -0700
Subject: [PATCH 001/118] [rllib] Remove dependency on TensorFlow (#4764)

* remove hard tf dep
* add test
* comment fix
* fix test
---
 ci/jenkins_tests/run_rllib_tests.sh           |  3 ++
 python/ray/experimental/tf_utils.py           |  4 ++-
 .../rllib/agents/a3c/a3c_tf_policy_graph.py   |  4 ++-
 python/ray/rllib/agents/ars/policies.py       |  4 ++-
 python/ray/rllib/agents/ars/utils.py          |  4 ++-
 .../rllib/agents/ddpg/ddpg_policy_graph.py    |  9 ++++--
 .../ray/rllib/agents/dqn/dqn_policy_graph.py  |  9 ++++--
 python/ray/rllib/agents/es/policies.py        |  4 ++-
 python/ray/rllib/agents/es/utils.py           |  4 ++-
 python/ray/rllib/agents/impala/vtrace.py      |  6 ++--
 .../agents/impala/vtrace_policy_graph.py      |  4 ++-
 .../agents/marwil/marwil_policy_graph.py      |  5 ++--
 python/ray/rllib/agents/pg/pg_policy_graph.py |  5 ++--
 .../ray/rllib/agents/ppo/appo_policy_graph.py |  4 ++-
 .../ray/rllib/agents/ppo/ppo_policy_graph.py  |  4 ++-
 python/ray/rllib/agents/trainer.py            | 13 +++++++--
 python/ray/rllib/evaluation/metrics.py        | 29 +++++++++++--------
 .../ray/rllib/evaluation/policy_evaluator.py  |  8 +++--
 .../ray/rllib/evaluation/tf_policy_graph.py   |  3 +-
 python/ray/rllib/models/action_dist.py        | 11 +++++--
 python/ray/rllib/models/catalog.py            |  4 ++-
 python/ray/rllib/models/fcnet.py              |  8 +++--
 python/ray/rllib/models/lstm.py               |  7 +++--
 python/ray/rllib/models/misc.py               |  9 ++++--
 python/ray/rllib/models/model.py              |  4 ++-
 python/ray/rllib/models/visionnet.py          |  8 +++--
 python/ray/rllib/offline/input_reader.py      |  4 ++-
 python/ray/rllib/optimizers/multi_gpu_impl.py |  4 ++-
 .../rllib/optimizers/multi_gpu_optimizer.py   |  4 ++-
 python/ray/rllib/tests/test_dependency.py     | 24 +++++++++++++++
 python/ray/rllib/utils/__init__.py            | 14 +++++++++
 python/ray/rllib/utils/explained_variance.py  |  4 ++-
 python/ray/rllib/utils/seed.py                |  4 ++-
 python/ray/rllib/utils/tf_run_builder.py      |  7 +++--
 python/ray/tune/logger.py                     | 12 +++++---
 35 files changed, 190 insertions(+), 64 deletions(-)
 create mode 100644 python/ray/rllib/tests/test_dependency.py

diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh
index efe30a0a7780..fa10c14b8c5a 100644
--- a/ci/jenkins_tests/run_rllib_tests.sh
+++ b/ci/jenkins_tests/run_rllib_tests.sh
@@ -289,6 +289,9 @@
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_local.py
 
+docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_dependency.py
+
 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
     /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_legacy.py
 
diff --git a/python/ray/experimental/tf_utils.py b/python/ray/experimental/tf_utils.py
index d2f1b259961c..900cc948b066 100644
--- a/python/ray/experimental/tf_utils.py
+++ b/python/ray/experimental/tf_utils.py
@@ -5,7 +5,9 @@
 from collections import deque, OrderedDict
 
 import numpy as np
-import tensorflow as tf
+from ray.rllib.utils import try_import_tf
+
+tf = try_import_tf()
 
 
 def unflatten(vector, shapes):
diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
index d4e140543e31..e6ae8d17bad3 100644
--- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
+++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py
@@
-4,7 +4,6 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf import gym import ray @@ -19,6 +18,9 @@ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class A3CLoss(object): diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index fe82be5b65dd..7fdb54b99cd8 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -7,13 +7,15 @@ import gym import numpy as np -import tensorflow as tf import ray import ray.experimental.tf_utils from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.utils.filter import get_filter from ray.rllib.models import ModelCatalog +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def rollout(policy, env, timestep_limit=None, add_noise=False, offset=0): diff --git a/python/ray/rllib/agents/ars/utils.py b/python/ray/rllib/agents/ars/utils.py index 1575e46c3837..518fd3d00634 100644 --- a/python/ray/rllib/agents/ars/utils.py +++ b/python/ray/rllib/agents/ars/utils.py @@ -6,7 +6,9 @@ from __future__ import print_function import numpy as np -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def compute_ranks(x): diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index 9304cbe0b598..52b7593f2dfa 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -4,8 +4,6 @@ from gym.spaces import Box import numpy as np -import tensorflow as tf -import tensorflow.contrib.layers as layers import ray import ray.experimental.tf_utils @@ -18,6 +16,9 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() ACTION_SCOPE = "action" POLICY_SCOPE = "policy" @@ -397,6 +398,8 @@ def set_state(self, state): self.set_pure_exploration_phase(state[2]) def _build_q_network(self, obs, obs_space, action_space, actions): + import tensorflow.contrib.layers as layers + if self.config["use_state_preprocessor"]: q_model = ModelCatalog.get_model({ "obs": obs, @@ -417,6 +420,8 @@ def _build_q_network(self, obs, obs_space, action_space, actions): return q_values, q_model def _build_policy_network(self, obs, obs_space, action_space): + import tensorflow.contrib.layers as layers + if self.config["use_state_preprocessor"]: model = ModelCatalog.get_model({ "obs": obs, diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 6a226d237461..5af38ed9e958 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -5,8 +5,6 @@ from gym.spaces import Discrete import numpy as np from scipy.stats import entropy -import tensorflow as tf -import tensorflow.contrib.layers as layers import ray from ray.rllib.evaluation.sample_batch import SampleBatch @@ -17,6 +15,9 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ LearningRateSchedule +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() Q_SCOPE = "q_func" Q_TARGET_SCOPE = "target_q_func" @@ -153,6 +154,8 @@ 
def __init__(self, v_max=10.0, sigma0=0.5, parameter_noise=False): + import tensorflow.contrib.layers as layers + self.model = model with tf.variable_scope("action_value"): if hiddens: @@ -263,6 +266,8 @@ def noisy_layer(self, prefix, action_in, out_size, sigma0, distributions and \sigma are trainable variables which are expected to vanish along the training procedure """ + import tensorflow.contrib.layers as layers + in_size = int(action_in.shape[1]) epsilon_in = tf.random_normal(shape=[in_size]) diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index 78ff29da4f86..dfc7e2deec47 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -7,13 +7,15 @@ import gym import numpy as np -import tensorflow as tf import ray import ray.experimental.tf_utils from ray.rllib.evaluation.sampler import _unbatch_tuple_actions from ray.rllib.models import ModelCatalog from ray.rllib.utils.filter import get_filter +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def rollout(policy, env, timestep_limit=None, add_noise=False): diff --git a/python/ray/rllib/agents/es/utils.py b/python/ray/rllib/agents/es/utils.py index 1575e46c3837..518fd3d00634 100644 --- a/python/ray/rllib/agents/es/utils.py +++ b/python/ray/rllib/agents/es/utils.py @@ -6,7 +6,9 @@ from __future__ import print_function import numpy as np -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def compute_ranks(x): diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py index 238b30d99355..cc560d9937e4 100644 --- a/python/ray/rllib/agents/impala/vtrace.py +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -34,9 +34,11 @@ import collections -import tensorflow as tf +from ray.rllib.utils import try_import_tf -nest = tf.contrib.framework.nest +tf = try_import_tf() +if tf: + nest = tf.contrib.framework.nest VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [ "vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs", diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index ff4cfdb4be97..702aefb50a6e 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -9,7 +9,6 @@ import gym import ray import numpy as np -import tensorflow as tf from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.policy_graph import PolicyGraph @@ -21,6 +20,9 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # Frozen logits of the policy that computed the action BEHAVIOUR_LOGITS = "behaviour_logits" diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy_graph.py index 2dd67ab5f39c..2c647db9aa96 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py +++ b/python/ray/rllib/agents/marwil/marwil_policy_graph.py @@ -2,8 +2,6 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf - import ray from ray.rllib.models import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ @@ -15,6 +13,9 @@ from 
ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() POLICY_SCOPE = "p_func" VALUE_SCOPE = "v_func" diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index 6e8abd7d4a81..a55af79b1e61 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -2,8 +2,6 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf - import ray from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ @@ -12,6 +10,9 @@ from ray.rllib.evaluation.sample_batch import SampleBatch from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class PGLoss(object): diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index 89e49153f90c..64523c60d1b3 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -7,7 +7,6 @@ from __future__ import print_function import numpy as np -import tensorflow as tf import logging import gym @@ -23,6 +22,9 @@ from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 8bede3421f6d..61aced1db740 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -3,7 +3,6 @@ from __future__ import print_function import logging -import tensorflow as tf import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ @@ -16,6 +15,9 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 0b033baf60cf..8e6db02707d8 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -10,7 +10,6 @@ import six import time import tempfile -import tensorflow as tf from types import FunctionType import ray @@ -26,12 +25,15 @@ from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.rllib.utils.memory import ray_get_and_free +from ray.rllib.utils import try_import_tf from ray.tune.registry import ENV_CREATOR, register_env, _global_registry from ray.tune.trainable import Trainable from ray.tune.trial import Resources, ExportFormat from ray.tune.logger import UnifiedLogger from ray.tune.result import DEFAULT_RESULTS_DIR +tf = try_import_tf() + logger = logging.getLogger(__name__) # Max number of times to retry a worker failure. 
We shouldn't try too many @@ -412,8 +414,13 @@ def _setup(self, config): if self.config.get("log_level"): logging.getLogger("ray.rllib").setLevel(self.config["log_level"]) - # TODO(ekl) setting the graph is unnecessary for PyTorch agents - with tf.Graph().as_default(): + def get_scope(): + if tf: + return tf.Graph().as_default() + else: + return open("/dev/null") # fake a no-op scope + + with get_scope(): self._init(self.config, self.env_creator) # Evaluation related diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index fe43257226cf..a92c64bc9e4b 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -59,18 +59,23 @@ def collect_episodes(local_evaluator=None, timeout_seconds=180): """Gathers new episodes metrics tuples from the given evaluators.""" - pending = [ - a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_evaluators - ] - collected, _ = ray.wait( - pending, num_returns=len(pending), timeout=timeout_seconds * 1.0) - num_metric_batches_dropped = len(pending) - len(collected) - if pending and len(collected) == 0: - raise ValueError( - "Timed out waiting for metrics from workers. You can configure " - "this timeout with `collect_metrics_timeout`.") - - metric_lists = ray_get_and_free(collected) + if remote_evaluators: + pending = [ + a.apply.remote(lambda ev: ev.get_metrics()) + for a in remote_evaluators + ] + collected, _ = ray.wait( + pending, num_returns=len(pending), timeout=timeout_seconds * 1.0) + num_metric_batches_dropped = len(pending) - len(collected) + if pending and len(collected) == 0: + raise ValueError( + "Timed out waiting for metrics from workers. You can " + "configure this timeout with `collect_metrics_timeout`.") + metric_lists = ray_get_and_free(collected) + else: + metric_lists = [] + num_metric_batches_dropped = 0 + if local_evaluator: metric_lists.append(local_evaluator.get_metrics()) episodes = [] diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index faf9b576d4aa..f6761122156e 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -5,7 +5,6 @@ import gym import logging import pickle -import tensorflow as tf import ray from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari @@ -32,7 +31,9 @@ summarize, enable_periodic_logging from ray.rllib.utils.filter import get_filter from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf +tf = try_import_tf() logger = logging.getLogger(__name__) # Handle to the current evaluator, which will be set to the most recently @@ -722,7 +723,10 @@ def _build_policy_map(self, policy_dict, policy_config): "Found raw Tuple|Dict space as input to policy graph. 
" "Please preprocess these observations with a " "Tuple|DictFlatteningPreprocessor.") - with tf.variable_scope(name): + if tf: + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + else: policy_map[name] = cls(obs_space, act_space, merged_conf) if self.worker_index == 0: logger.info("Built policy map: {}".format(policy_map)) diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 416de4d41089..2b1eca9e8d5b 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -5,7 +5,6 @@ import os import errno import logging -import tensorflow as tf import numpy as np import ray @@ -18,7 +17,9 @@ from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 026a6c493e5c..1cad7d3aa9ac 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -4,13 +4,18 @@ from collections import namedtuple import distutils.version -import tensorflow as tf import numpy as np from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils import try_import_tf -use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) +tf = try_import_tf() + +if tf: + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) +else: + use_tf150_api = False @DeveloperAPI diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index ce91742c3f5d..d237474480e5 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -5,7 +5,6 @@ import gym import logging import numpy as np -import tensorflow as tf from functools import partial from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \ @@ -22,6 +21,9 @@ from ray.rllib.models.visionnet import VisionNetwork from ray.rllib.models.lstm import LSTM from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 19745b9e7a3c..3cc0fbe403c5 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf -import tensorflow.contrib.slim as slim - from ray.rllib.models.model import Model from ray.rllib.models.misc import normc_initializer, get_activation_fn from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class FullyConnectedNetwork(Model): @@ -21,6 +21,8 @@ def _build_layers(self, inputs, num_outputs, options): model that processes the components separately, use _build_layers_v2(). 
""" + import tensorflow.contrib.slim as slim + hiddens = options.get("fcnet_hiddens") activation = get_activation_fn(options.get("fcnet_activation")) diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 18f141d095f9..5b9328c3c463 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -18,12 +18,13 @@ """ import numpy as np -import tensorflow as tf -import tensorflow.contrib.rnn as rnn from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class LSTM(Model): @@ -37,6 +38,8 @@ class LSTM(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): + import tensorflow.contrib.rnn as rnn + cell_size = options.get("lstm_cell_size") if options.get("lstm_use_prev_action_reward"): action_dim = int( diff --git a/python/ray/rllib/models/misc.py b/python/ray/rllib/models/misc.py index aad399c3b222..73ee1d87c6fd 100644 --- a/python/ray/rllib/models/misc.py +++ b/python/ray/rllib/models/misc.py @@ -2,8 +2,10 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf import numpy as np +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def normc_initializer(std=1.0): @@ -25,8 +27,11 @@ def conv2d(x, filter_size=(3, 3), stride=(1, 1), pad="SAME", - dtype=tf.float32, + dtype=None, collections=None): + if dtype is None: + dtype = tf.float32 + with tf.variable_scope(name): stride_shape = [1, stride[0], stride[1], 1] filter_shape = [ diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index b5664057d9a8..4996f3cdf437 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -5,11 +5,13 @@ from collections import OrderedDict import gym -import tensorflow as tf from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() @PublicAPI diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 432a3317c782..53eaf5d02c3f 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -2,12 +2,12 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf -import tensorflow.contrib.slim as slim - from ray.rllib.models.model import Model from ray.rllib.models.misc import get_activation_fn, flatten from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class VisionNetwork(Model): @@ -15,6 +15,8 @@ class VisionNetwork(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): + import tensorflow.contrib.slim as slim + inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: diff --git a/python/ray/rllib/offline/input_reader.py b/python/ray/rllib/offline/input_reader.py index bb4fe91161a2..5315773fd839 100644 --- a/python/ray/rllib/offline/input_reader.py +++ b/python/ray/rllib/offline/input_reader.py @@ -4,11 +4,13 @@ import logging import numpy as np -import tensorflow as tf import threading from ray.rllib.evaluation.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils import 
try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 7c3feb165e5b..d892dbe7dbac 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -4,9 +4,11 @@ from collections import namedtuple import logging -import tensorflow as tf from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # Variable scope in which created variables will be placed under TOWER_SCOPE_NAME = "tower" diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 23ee1833b9f0..45df865e43ff 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -6,7 +6,6 @@ import math import numpy as np from collections import defaultdict -import tensorflow as tf import ray from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY @@ -19,6 +18,9 @@ from ray.rllib.utils.timer import TimerStat from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/tests/test_dependency.py b/python/ray/rllib/tests/test_dependency.py new file mode 100644 index 000000000000..2df0b4b95937 --- /dev/null +++ b/python/ray/rllib/tests/test_dependency.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys + +os.environ["RLLIB_TEST_NO_TF_IMPORT"] = "1" + +if __name__ == "__main__": + from ray.rllib.agents.a3c import A2CTrainer + assert "tensorflow" not in sys.modules, "TF initially present" + + # note: no ray.init(), to test it works without Ray + trainer = A2CTrainer( + env="CartPole-v0", config={ + "use_pytorch": True, + "num_workers": 0 + }) + trainer.train() + + assert "tensorflow" not in sys.modules, "TF should not be imported" diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index 7aab0f2a0dfb..9ff0295690e2 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -1,4 +1,5 @@ import logging +import os from ray.rllib.utils.filter_manager import FilterManager from ray.rllib.utils.filter import Filter @@ -26,6 +27,18 @@ def __init__(self, config=None, env=None, logger_creator=None): return DeprecationWrapper +def try_import_tf(): + if "RLLIB_TEST_NO_TF_IMPORT" in os.environ: + logger.warning("Not importing TensorFlow for test purposes") + return None + + try: + import tensorflow as tf + return tf + except ImportError: + return None + + __all__ = [ "Filter", "FilterManager", @@ -34,4 +47,5 @@ def __init__(self, config=None, env=None, logger_creator=None): "merge_dicts", "deep_update", "renamed_class", + "try_import_tf", ] diff --git a/python/ray/rllib/utils/explained_variance.py b/python/ray/rllib/utils/explained_variance.py index 942f0f8f31f0..a3e9cbadbee3 100644 --- a/python/ray/rllib/utils/explained_variance.py +++ b/python/ray/rllib/utils/explained_variance.py @@ -2,7 +2,9 @@ from __future__ import division from __future__ import print_function -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def explained_variance(y, pred): diff --git a/python/ray/rllib/utils/seed.py 
b/python/ray/rllib/utils/seed.py index bec02b6ad6ec..3675fd11913d 100644 --- a/python/ray/rllib/utils/seed.py +++ b/python/ray/rllib/utils/seed.py @@ -4,7 +4,9 @@ import numpy as np import random -import tensorflow as tf +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def seed(np_seed=0, random_seed=0, tf_seed=0): diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index ef411b047d7b..4694c96c03c1 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -6,11 +6,10 @@ import os import time -import tensorflow as tf -from tensorflow.python.client import timeline - from ray.rllib.utils.debug import log_once +from ray.rllib.utils import try_import_tf +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -65,6 +64,8 @@ def get(self, to_fetch): def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): if timeline_dir: + from tensorflow.python.client import timeline + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() start = time.time() diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 6095cfb4dcbd..9d472cac36fe 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -118,10 +118,14 @@ class TFLogger(Logger): def _init(self): try: global tf, use_tf150_api - import tensorflow - tf = tensorflow - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= - distutils.version.LooseVersion("1.5.0")) + if "RLLIB_TEST_NO_TF_IMPORT" in os.environ: + logger.warning("Not importing TensorFlow for test purposes") + tf = None + else: + import tensorflow + tf = tensorflow + use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + distutils.version.LooseVersion("1.5.0")) except ImportError: logger.warning("Couldn't import TensorFlow - " "disabling TensorBoard logging.") From 004440f526f13586a1740174235f1aaffc4937b9 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Sat, 11 May 2019 05:06:04 -0700 Subject: [PATCH 002/118] Dynamic Custom Resources - create and delete resources (#3742) --- BUILD.bazel | 1 + doc/source/conf.py | 1 + doc/source/development.rst | 2 +- java/BUILD.bazel | 1 + .../java/org/ray/runtime/gcs/GcsClient.java | 17 +- python/ray/_raylet.pyx | 4 + python/ray/experimental/__init__.py | 3 +- python/ray/experimental/dynamic_resources.py | 35 ++ python/ray/experimental/state.py | 63 +- python/ray/includes/libraylet.pxd | 1 + python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/test_dynres.py | 586 ++++++++++++++++++ src/ray/gcs/client_test.cc | 4 +- src/ray/gcs/format/gcs.fbs | 13 +- src/ray/gcs/tables.cc | 121 +++- src/ray/gcs/tables.h | 26 +- src/ray/object_manager/object_directory.cc | 2 +- src/ray/raylet/format/node_manager.fbs | 11 + src/ray/raylet/monitor.cc | 3 +- src/ray/raylet/node_manager.cc | 168 ++++- src/ray/raylet/node_manager.h | 22 + src/ray/raylet/raylet_client.cc | 10 + src/ray/raylet/raylet_client.h | 8 + 23 files changed, 1041 insertions(+), 65 deletions(-) create mode 100644 python/ray/experimental/dynamic_resources.py create mode 100644 python/ray/tests/test_dynres.py diff --git a/BUILD.bazel b/BUILD.bazel index 61484ba82f6b..e2cbdd64bf51 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -495,6 +495,7 @@ flatbuffer_py_library( "ConfigTableData.py", "CustomSerializerData.py", "DriverTableData.py", + "EntryType.py", "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", diff --git a/doc/source/conf.py b/doc/source/conf.py index 
b67dbe267d4c..e0bd2c6dad4c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -26,6 +26,7 @@ "ray.core.generated.ActorCheckpointIdData", "ray.core.generated.ClientTableData", "ray.core.generated.DriverTableData", + "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", "ray.core.generated.GcsTableEntry", diff --git a/doc/source/development.rst b/doc/source/development.rst index 66e666b4d1a4..e4d50327a43a 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -81,7 +81,7 @@ API. The easiest way to do this is to start or connect to a Ray cluster with ray.worker.global_state.client_table() # Returns current information about the nodes in the cluster, such as: # [{'ClientID': '2a9d2b34ad24a37ed54e4fcd32bf19f915742f5b', - # 'IsInsertion': True, + # 'EntryType': 0, # 'NodeManagerAddress': '1.2.3.4', # 'NodeManagerPort': 43280, # 'ObjectManagerPort': 38062, diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 34799a76c78a..2d2762d837e6 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -154,6 +154,7 @@ flatbuffers_generated_files = [ "ConfigTableData.java", "CustomSerializerData.java", "DriverTableData.java", + "EntryType.java", "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index a627f200a0e6..647b77e336b4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -13,6 +13,7 @@ import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; +import org.ray.runtime.generated.EntryType; import org.ray.runtime.generated.TablePrefix; import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; @@ -63,7 +64,7 @@ public List getAllNodeInfo() { ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); - if (data.isInsertion()) { + if (data.entryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. @@ -72,12 +73,24 @@ public List getAllNodeInfo() { for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } - NodeInfo nodeInfo = new NodeInfo( clientId, data.nodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); + } else if (data.entryType() == EntryType.RES_CREATEUPDATE){ + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + } + } else if (data.entryType() == EntryType.RES_DELETE){ + Preconditions.checkState(clients.containsKey(clientId)); + NodeInfo nodeInfo = clients.get(clientId); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + } } else { // Code path of node deletion. 
+ Preconditions.checkState(data.entryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 31937837c780..bae62f9b1c88 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -32,6 +32,7 @@ from ray.includes.libraylet cimport ( from ray.includes.unique_ids cimport ( CActorCheckpointID, CObjectID, + CClientID, ) from ray.includes.task cimport CTaskSpecification from ray.includes.ray_config cimport RayConfig @@ -368,6 +369,9 @@ cdef class RayletClient: check_status(self.client.get().NotifyActorResumedFromCheckpoint( actor_id.native(), checkpoint_id.native())) + def set_resource(self, basestring resource_name, double capacity, ClientID client_id): + self.client.get().SetResource(resource_name.encode("ascii"), capacity, CClientID.from_binary(client_id.binary())) + @property def language(self): return Language.from_native(self.client.get().GetLanguage()) diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 425ff2d932fc..5b811ff0ffb2 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -10,6 +10,7 @@ SimpleGcsFlushPolicy) from .named_actors import get_actor, register_actor from .api import get, wait +from .dynamic_resources import set_resource def TensorFlowVariables(*args, **kwargs): @@ -24,5 +25,5 @@ def TensorFlowVariables(*args, **kwargs): "flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard", "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor", "get", "wait", "set_flushing_policy", "GcsFlushPolicy", - "SimpleGcsFlushPolicy" + "SimpleGcsFlushPolicy", "set_resource" ] diff --git a/python/ray/experimental/dynamic_resources.py b/python/ray/experimental/dynamic_resources.py new file mode 100644 index 000000000000..34b2b99e65a2 --- /dev/null +++ b/python/ray/experimental/dynamic_resources.py @@ -0,0 +1,35 @@ +import ray + + +def set_resource(resource_name, capacity, client_id=None): + """ Set a resource to a specified capacity. + + This creates, updates or deletes a custom resource for a target clientId. + If the resource already exists, it's capacity is updated to the new value. + If the capacity is set to 0, the resource is deleted. + If ClientID is not specified or set to None, + the resource is created on the local client where the actor is running. + + Args: + resource_name (str): Name of the resource to be created + capacity (int): Capacity of the new resource. Resource is deleted if + capacity is 0. + client_id (str): The ClientId of the node where the resource is to be + set. + + Returns: + None + + Raises: + ValueError: This exception is raised when a non-negative capacity is + specified. 
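+
+    Example (an illustrative sketch added for clarity; the resource name
+    "my_res" is made up and not part of the original patch):
+
+        # Create or update a custom resource on the local node.
+        ray.experimental.set_resource("my_res", 4)
+        # Delete it again by setting its capacity to 0.
+        ray.experimental.set_resource("my_res", 0)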
+ """ + if client_id is not None: + client_id_obj = ray.ClientID(ray.utils.hex_to_binary(client_id)) + else: + client_id_obj = ray.ClientID.nil() + if (capacity < 0) or (capacity != int(capacity)): + raise ValueError( + "Capacity {} must be a non-negative integer.".format(capacity)) + return ray.worker.global_worker.raylet_client.set_resource( + resource_name, capacity, client_id_obj) diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 31d4b77c64e6..51b36dc83fc7 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -13,6 +13,7 @@ from ray.ray_constants import ID_SIZE from ray import services +from ray.core.generated.EntryType import EntryType from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -54,29 +55,43 @@ def parse_client_table(redis_client): } client_id = ray.utils.binary_to_hex(client.ClientId()) - # If this client is being removed, then it must + if client.EntryType() == EntryType.INSERTION: + ordered_client_ids.append(client_id) + node_info[client_id] = { + "ClientID": client_id, + "EntryType": client.EntryType(), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), + "Resources": resources + } + + # If this client is being updated, then it must # have previously been inserted, and # it cannot have previously been removed. - if not client.IsInsertion(): - assert client_id in node_info, "Client removed not found!" - assert node_info[client_id]["IsInsertion"], ( - "Unexpected duplicate removal of client.") else: - ordered_client_ids.append(client_id) - - node_info[client_id] = { - "ClientID": client_id, - "IsInsertion": client.IsInsertion(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), - "Resources": resources - } + assert client_id in node_info, "Client not found!" + assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( + "Unexpected updation of deleted client.") + res_map = node_info[client_id]["Resources"] + if client.EntryType() == EntryType.RES_CREATEUPDATE: + for res in resources: + res_map[res] = resources[res] + elif client.EntryType() == EntryType.RES_DELETE: + for res in resources: + res_map.pop(res, None) + elif client.EntryType() == EntryType.DELETION: + pass # Do nothing with the resmap if client deletion + else: + raise RuntimeError("Unexpected EntryType {}".format( + client.EntryType())) + node_info[client_id]["Resources"] = res_map + node_info[client_id]["EntryType"] = client.EntryType() # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -757,18 +772,18 @@ def cluster_resources(self): resources = defaultdict(int) clients = self.client_table() for client in clients: - # Only count resources from live clients. - if client["IsInsertion"]: + # Only count resources from latest entries of live clients. 
+ if client["EntryType"] != EntryType.DELETION: for key, value in client["Resources"].items(): resources[key] += value - return dict(resources) def _live_client_ids(self): """Returns a set of client IDs corresponding to clients still alive.""" return { client["ClientID"] - for client in self.client_table() if client["IsInsertion"] + for client in self.client_table() + if (client["EntryType"] != EntryType.DELETION) } def available_resources(self): diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index be74b06e5729..1b4c5e3cd037 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -72,6 +72,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CActorCheckpointID &checkpoint_id) CRayStatus NotifyActorResumedFromCheckpoint( const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) + CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const CClientID GetClientID() const CDriverID GetDriverID() const diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 0a7984d69740..a7ed3e14a89a 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,6 +8,7 @@ import redis import ray +from ray.core.generated.EntryType import EntryType logger = logging.getLogger(__name__) @@ -175,7 +176,8 @@ def wait_for_nodes(self, timeout=30): while time.time() - start_time < timeout: clients = ray.experimental.state.parse_client_table(redis_client) live_clients = [ - client for client in clients if client["IsInsertion"] + client for client in clients + if client["EntryType"] == EntryType.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py new file mode 100644 index 000000000000..6f39839301c9 --- /dev/null +++ b/python/ray/tests/test_dynres.py @@ -0,0 +1,586 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import time + +import ray +import ray.tests.cluster_utils +import ray.tests.utils + +logger = logging.getLogger(__name__) + + +def test_dynamic_res_creation(ray_start_regular): + # This test creates a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + ray.get(set_res.remote(res_name, res_capacity)) + + available_res = ray.global_state.available_resources() + cluster_res = ray.global_state.cluster_resources() + + assert available_res[res_name] == res_capacity + assert cluster_res[res_name] == res_capacity + + +def test_dynamic_res_deletion(shutdown_only): + # This test deletes a resource locally (without specifying the client_id) + res_name = "test_res" + res_capacity = 1.0 + + ray.init(num_cpus=1, resources={res_name: res_capacity}) + + @ray.remote + def delete_res(resource_name): + ray.experimental.set_resource(resource_name, 0) + + ray.get(delete_res.remote(res_name)) + + available_res = ray.global_state.available_resources() + cluster_res = ray.global_state.cluster_resources() + + assert res_name not in available_res + assert res_name not in cluster_res + + +def test_dynamic_res_infeasible_rescheduling(ray_start_regular): + # This test launches an infeasible task and then creates a + # resource to make the task feasible. 
This tests if the + # infeasible tasks get rescheduled when resources are + # created at runtime. + res_name = "test_res" + res_capacity = 1.0 + + @ray.remote + def set_res(resource_name, resource_capacity): + ray.experimental.set_resource(resource_name, resource_capacity) + + def f(): + return 1 + + remote_task = ray.remote(resources={res_name: res_capacity})(f) + oid = remote_task.remote() # This is infeasible + ray.get(set_res.remote(res_name, res_capacity)) # Now should be feasible + + available_res = ray.global_state.available_resources() + assert available_res[res_name] == res_capacity + + successful, unsuccessful = ray.wait([oid], timeout=1) + assert successful # The task completed + + +def test_dynamic_res_updation_clientid(ray_start_cluster): + # This test does a simple resource capacity update + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=client_id) + + # Create resource + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Update resource + new_capacity = res_capacity + 1 + ray.get(set_res.remote(res_name, new_capacity, target_clientid)) + + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == new_capacity + + +def test_dynamic_res_creation_clientid(ray_start_cluster): + # Creates a resource on a specific client and verifies creation. 
+ cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + + assert res_name in resources + assert resources[res_name] == res_capacity + + +def test_dynamic_res_creation_clientid_multiple(ray_start_cluster): + # This test creates resources on multiple clients using the clientid + # specifier + cluster = ray_start_cluster + + TIMEOUT = 5 + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 3 + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + target_clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + results = [] + for cid in target_clientids: + results.append(set_res.remote(res_name, res_capacity, cid)) + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + resources_created = [] + for cid in target_clientids: + target_client = next(client + for client in ray.global_state.client_table() + if client["ClientID"] == cid) + resources = target_client["Resources"] + resources_created.append(resources[res_name] == res_capacity) + success = all(resources_created) + assert success + + +def test_dynamic_res_deletion_clientid(ray_start_cluster): + # This test deletes a resource on a given client id + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + # Create resource on all nodes, but later we'll delete it from a + # target node + cluster.add_node(resources={res_name: res_capacity}) + + ray.init(redis_address=cluster.redis_address) + + target_clientid = ray.global_state.client_table()[1]["ClientID"] + + # Launch the delete task + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + ray.get(delete_res.remote(res_name, target_clientid)) + + target_client = next(client for client in ray.global_state.client_table() + if client["ClientID"] == target_clientid) + resources = target_client["Resources"] + print(ray.global_state.cluster_resources()) + assert res_name not in resources + + +def test_dynamic_res_creation_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually created and the state is + # consistent in the scheduler + # by launching a task which requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + 
resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + + # Define a task which requires this resource + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = ray.wait([result], timeout=5) + assert successful # The task completed + + +def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster): + # This makes sure the resource is actually deleted and the state is + # consistent in the scheduler by launching an infeasible task which + # requests the created resource + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 1.0 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node1 + target_clientid = clientids[1] + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Define a task which requires this resource. This should not run + @ray.remote(resources={res_name: res_capacity}) + def test_func(): + return 1 + + result = test_func.remote() + successful, unsuccessful = ray.wait([result], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + + +def test_dynamic_res_concurrent_res_increment(ray_start_cluster): + # This test makes sure resource capacity is updated (increment) correctly + # when a task has already acquired some of the resource. 
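+    # Concretely: a long-running task first acquires 4 of the 5 available
+    # units, the driver raises the capacity to 10 while that task still holds
+    # its share, and a follow-up task requesting all 10 units must then be
+    # schedulable (while one requesting 11 must remain infeasible).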
+ + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 10 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Update the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.global_state.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): + # This test makes sure resource capacity is updated (decremented) + # correctly when a task has already acquired some + # of the resource. 
+ + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + updated_capacity = 2 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Decrease the resource capacity + ray.get(set_res.remote(res_name, updated_capacity, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + task_2 = test_func._remote(args=[], resources={res_name: updated_capacity}) + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert successful # The task completed + + # Check if scheduler state is consistent by launching a task requiring + # updated capacity + 1. 
This should not execute + task_3 = test_func._remote( + args=[], resources={res_name: updated_capacity + 1 + }) # This should be infeasible + successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert ray.global_state.available_resources()[res_name] == updated_capacity + + +def test_dynamic_res_concurrent_res_delete(ray_start_cluster): + # This test makes sure resource gets deleted correctly when a task has + # already acquired the resource + + cluster = ray_start_cluster + + res_name = "test_res" + res_capacity = 5 + num_nodes = 5 + TIMEOUT_DURATION = 1 + + # Create a object ID to have the task wait on + WAIT_OBJECT_ID_STR = ("a" * 20).encode("ascii") + + # Create a object ID to signal that the task is running + TASK_RUNNING_OBJECT_ID_STR = ("b" * 20).encode("ascii") + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + # Create the resource on node 1 + ray.get(set_res.remote(res_name, res_capacity, target_clientid)) + assert ray.global_state.cluster_resources()[res_name] == res_capacity + + # Task to hold the resource till the driver signals to finish + @ray.remote + def wait_func(running_oid, wait_oid): + # Signal that the task is running + ray.worker.global_worker.put_object(ray.ObjectID(running_oid), 1) + # Make the task wait till signalled by driver + ray.get(ray.ObjectID(wait_oid)) + + @ray.remote + def test_func(): + return 1 + + # Launch the task with resource requirement of 4, thus the new available + # capacity becomes 1 + task = wait_func._remote( + args=[TASK_RUNNING_OBJECT_ID_STR, WAIT_OBJECT_ID_STR], + resources={res_name: 4}) + # Wait till wait_func is launched before updating resource + ray.get(ray.ObjectID(TASK_RUNNING_OBJECT_ID_STR)) + + # Delete the resource + ray.get(delete_res.remote(res_name, target_clientid)) + + # Signal task to complete + ray.worker.global_worker.put_object(ray.ObjectID(WAIT_OBJECT_ID_STR), 1) + ray.get(task) + + # Check if scheduler state is consistent by launching a task requiring + # the deleted resource This should not execute + task_2 = test_func._remote( + args=[], resources={res_name: 1}) # This should be infeasible + successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) + assert unsuccessful # The task did not complete because it's infeasible + assert res_name not in ray.global_state.available_resources() + + +def test_dynamic_res_creation_stress(ray_start_cluster): + # This stress tests creates many resources simultaneously on the same + # client and then checks if the final state is consistent + + cluster = ray_start_cluster + + TIMEOUT = 5 + res_capacity = 1 + num_nodes = 5 + NUM_RES_TO_CREATE = 500 + + for i in range(num_nodes): + cluster.add_node() + + ray.init(redis_address=cluster.redis_address) + + clientids = [ + client["ClientID"] for client in ray.global_state.client_table() + ] + target_clientid = clientids[1] + + @ray.remote + def set_res(resource_name, resource_capacity, res_client_id): + ray.experimental.set_resource( + 
resource_name, resource_capacity, client_id=res_client_id) + + @ray.remote + def delete_res(resource_name, res_client_id): + ray.experimental.set_resource( + resource_name, 0, client_id=res_client_id) + + results = [ + set_res.remote(str(i), res_capacity, target_clientid) + for i in range(0, NUM_RES_TO_CREATE) + ] + ray.get(results) + + success = False + start_time = time.time() + + while time.time() - start_time < TIMEOUT and not success: + resources = ray.global_state.cluster_resources() + all_resources_created = [] + for i in range(0, NUM_RES_TO_CREATE): + all_resources_created.append(str(i) in resources) + success = all(all_resources_created) + assert success diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index d2d225c0a687..f7e25a4873ab 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -1188,12 +1188,12 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client ASSERT_EQ(client_id, added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); - ASSERT_EQ(data.is_insertion, is_insertion); + ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); ClientTableDataT cached_client; client->client_table().GetClient(added_id, cached_client); ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.is_insertion, is_insertion); + ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7acb24d27bd6..7cf250247461 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -39,6 +39,14 @@ enum TablePubsub:int { DRIVER, } +// Enum for the entry type in the ClientTable +enum EntryType:int { + INSERTION = 0, + DELETION, + RES_CREATEUPDATE, + RES_DELETE, +} + table Arg { // Object ID for pass-by-reference arguments. Normally there is only one // object ID in this list which represents the object that is being passed. @@ -267,9 +275,8 @@ table ClientTableData { // The port at which the client's object manager is listening for TCP // connections from other object managers. object_manager_port: int; - // True if the message is about the addition of a client and false if it is - // about the deletion of a client. - is_insertion: bool; + // Enum to store the entry type in the log + entry_type: EntryType = INSERTION; resources_total_label: [string]; resources_total_capacity: [double]; } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index e0876aa73e3e..dbd39349caf7 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -363,7 +363,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && entry.second.is_insertion) { + if (!entry.first.is_nil() && (entry.second.entry_type == EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -373,55 +373,136 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. 
for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && !entry.second.is_insertion) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } } +void ClientTable::RegisterResourceCreateUpdatedCallback( + const ClientTableCallback &callback) { + resource_createupdated_callback_ = callback; + // Call the callback for any clients that are cached. + for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && + (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + resource_createupdated_callback_(client_, entry.first, entry.second); + } + } +} + +void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &callback) { + resource_deleted_callback_ = callback; + // Call the callback for any clients that are cached. + for (const auto &entry : client_cache_) { + if (!entry.first.is_nil() && entry.second.entry_type == EntryType::RES_DELETE) { + resource_deleted_callback_(client_, entry.first, entry.second); + } + } +} + void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientTableDataT &data) { ClientID client_id = ClientID::from_binary(data.client_id); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); - bool is_new; + bool is_notif_new; if (entry == client_cache_.end()) { // If the entry is not in the cache, then the notification is new. - is_new = true; + is_notif_new = true; } else { // If the entry is in the cache, then the notification is new if the client - // was alive and is now dead. - bool was_inserted = entry->second.is_insertion; - bool is_deleted = !data.is_insertion; - is_new = (was_inserted && is_deleted); + // was alive and is now dead or resources have been updated. + bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); + bool is_deleted = (data.entry_type == EntryType::DELETION); + bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)); + is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. - if (!entry->second.is_insertion) { - RAY_CHECK(!data.is_insertion) + if (entry->second.entry_type == EntryType::DELETION) { + RAY_CHECK((data.entry_type == EntryType::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } } // Add the notification to our cache. Notifications are idempotent. - client_cache_[client_id] = data; + // If it is a new client or a client removal, add as is + if ((data.entry_type == EntryType::INSERTION) || + (data.entry_type == EntryType::DELETION)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". Setting the client cache to data."; + client_cache_[client_id] = data; + } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)) { + RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " + "notification for client id " + << client_id << ". EntryType: " << int(data.entry_type) + << ". 
Updating the client cache with the delta from the log."; + + ClientTableDataT &cache_data = client_cache_[client_id]; + // Iterate over all resources in the new create/update notification + for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { + auto const &resource_name = data.resources_total_label[i]; + auto const &capacity = data.resources_total_capacity[i]; + + // If resource exists in the ClientTableData, update it, else create it + auto existing_resource_label = + std::find(cache_data.resources_total_label.begin(), + cache_data.resources_total_label.end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label.end()) { + auto index = std::distance(cache_data.resources_total_label.begin(), + existing_resource_label); + // Resource already exists, set capacity if updation call.. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_capacity[index] = capacity; + } + // .. delete if deletion call. + else if (data.entry_type == EntryType::RES_DELETE) { + cache_data.resources_total_label.erase( + cache_data.resources_total_label.begin() + index); + cache_data.resources_total_capacity.erase( + cache_data.resources_total_capacity.begin() + index); + } + } else { + // Resource does not exist, create resource and add capacity if it was a resource + // create call. + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_label.push_back(resource_name); + cache_data.resources_total_capacity.push_back(capacity); + } + } + } + } // If the notification is new, call any registered callbacks. - if (is_new) { - if (data.is_insertion) { + ClientTableDataT &cache_data = client_cache_[client_id]; + if (is_notif_new) { + if (data.entry_type == EntryType::INSERTION) { if (client_added_callback_ != nullptr) { - client_added_callback_(client, client_id, data); + client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else { + } else if (data.entry_type == EntryType::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. removed_clients_.insert(client_id); if (client_removed_callback_ != nullptr) { - client_removed_callback_(client, client_id, data); + client_removed_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + if (resource_createupdated_callback_ != nullptr) { + resource_createupdated_callback_(client, client_id, cache_data); + } + } else if (data.entry_type == EntryType::RES_DELETE) { + if (resource_deleted_callback_ != nullptr) { + resource_deleted_callback_(client, client_id, cache_data); } } } @@ -449,7 +530,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Construct the data to add to the client table. auto data = std::make_shared(local_client_); - data->is_insertion = true; + data->entry_type = EntryType::INSERTION; // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, @@ -467,7 +548,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. 
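Conceptually, the RES_CREATEUPDATE and RES_DELETE entries handled above carry per-resource deltas rather than a node's full resource state, and the cached ClientTableData is patched in place. The merge rule can be summarised by this small sketch, written in Python purely for illustration (the authoritative logic is the C++ above):

    def apply_resource_delta(cached, labels, capacities, is_delete):
        """cached: dict mapping resource name -> total capacity for one client."""
        for name, capacity in zip(labels, capacities):
            if is_delete:
                cached.pop(name, None)    # RES_DELETE removes the resource entirely
            else:
                cached[name] = capacity   # RES_CREATEUPDATE inserts or overwrites
        return cached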
- if (notification.is_insertion) { + if (notification.entry_type != EntryType::DELETION) { connected_nodes.emplace(notification.client_id, notification); } else { auto iter = connected_nodes.find(notification.client_id); @@ -498,7 +579,7 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto data = std::make_shared(local_client_); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { HandleConnected(client, data); @@ -516,7 +597,7 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); data->client_id = dead_client_id.binary(); - data->is_insertion = false; + data->entry_type = EntryType::DELETION; return Append(DriverID::nil(), client_log_key_, data, nullptr); } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index b229108328f9..056bf7b97ec7 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -677,7 +677,7 @@ using ConfigTable = Table; /// it should append an entry to the log indicating that it is dead. A client /// that is marked as dead should never again be marked as alive; if it needs /// to reconnect, it must connect with a different ClientID. -class ClientTable : private Log { +class ClientTable : public Log { public: using ClientTableCallback = std::function; @@ -729,6 +729,16 @@ class ClientTable : private Log { /// \param callback The callback to register. void RegisterClientRemovedCallback(const ClientTableCallback &callback); + /// Register a callback to call when a resource is created or updated. + /// + /// \param callback The callback to register. + void RegisterResourceCreateUpdatedCallback(const ClientTableCallback &callback); + + /// Register a callback to call when a resource is deleted. + /// + /// \param callback The callback to register. + void RegisterResourceDeletedCallback(const ClientTableCallback &callback); + /// Get a client's information from the cache. The cache only contains /// information for clients that we've heard a notification for. /// @@ -772,16 +782,16 @@ class ClientTable : private Log { /// \return string. std::string DebugString() const; + /// The key at which the log of client information is stored. This key must + /// be kept the same across all instances of the ClientTable, so that all + /// clients append and read from the same key. + ClientID client_log_key_; + private: /// Handle a client table notification. void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); /// Handle this client's successful connection to the GCS. void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); - - /// The key at which the log of client information is stored. This key must - /// be kept the same across all instances of the ClientTable, so that all - /// clients append and read from the same key. - ClientID client_log_key_; /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. @@ -792,6 +802,10 @@ class ClientTable : private Log { ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. ClientTableCallback client_removed_callback_; + /// The callback to call when a resource is created or updated. 
+ ClientTableCallback resource_createupdated_callback_; + /// The callback to call when a resource is deleted. + ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. std::unordered_map client_cache_; /// The set of removed clients. diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 99ed0851cfb9..85157abcdbe9 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -108,7 +108,7 @@ void ObjectDirectory::LookupRemoteConnectionInfo( ClientID result_client_id = ClientID::from_binary(client_data.client_id); if (!result_client_id.is_nil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.is_insertion) { + if (client_data.entry_type == EntryType::INSERTION) { connection_info.ip = client_data.node_manager_address; connection_info.port = static_cast(client_data.object_manager_port); } diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index f673e2251548..a5b041f29cae 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -77,6 +77,8 @@ enum MessageType:int { NotifyActorResumedFromCheckpoint, // A node manager requests to connect to another node manager. ConnectClient, + // Set dynamic custom resource + SetResourceRequest, } table TaskExecutionSpecification { @@ -234,3 +236,12 @@ table ConnectClient { // ID of the connecting client. client_id: string; } + +table SetResourceRequest{ + // Name of the resource to be set + resource_name: string; + // Capacity of the resource to be set + capacity: double; + // Client ID where this resource will be set + client_id: string; +} diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 51e035b1b1e6..1e20fe3f4131 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -52,7 +52,8 @@ void Monitor::Tick() { const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.binary() == data.client_id && !data.is_insertion) { + if (client_id.binary() == data.client_id && + data.entry_type == EntryType::DELETION) { // The node has been marked dead by itself. marked = true; } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index f901c6800eb5..fa4ad4868aa5 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -185,6 +185,22 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); + // Register a callback on the client table for resource create/update requests + auto node_manager_resource_createupdated = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceCreateUpdated(data); + }; + gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( + node_manager_resource_createupdated); + + // Register a callback on the client table for resource delete requests + auto node_manager_resource_deleted = [this]( + gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { + ResourceDeleted(data); + }; + gcs_client_->client_table().RegisterResourceDeletedCallback( + node_manager_resource_deleted); + // Subscribe to heartbeat batches from the monitor. 
const auto &heartbeat_batch_added = [this]( gcs::AsyncGcsClient *client, const ClientID &id, @@ -461,6 +477,92 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { object_directory_->HandleClientRemoved(client_id); } +void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " + << client_id << ". Updating resource map."; + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] The difference in the resource map is " + << difference_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : difference_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + const double &new_resource_capacity = resource_pair.second; + + cluster_schedres.UpdateResource(resource_label, new_resource_capacity); + if (client_id == local_client_id) { + local_available_resources_.AddOrUpdateResource(resource_label, + new_resource_capacity); + } + } + RAY_LOG(DEBUG) << "[ResourceCreateUpdated] Updated cluster_resource_map."; + + if (client_id == local_client_id) { + // The resource update is on the local node, check if we can reschedule tasks. + TryLocalInfeasibleTaskScheduling(); + } + return; +} + +void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); + RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id + << " with new resources: " << new_res_set.ToString() + << ". 
Updating resource map."; + + const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); + ResourceSet deleted_set = old_res_set.FindDeletedResources(new_res_set); + RAY_LOG(DEBUG) << "[ResourceDeleted] The difference in the resource map is " + << deleted_set.ToString(); + + SchedulingResources &cluster_schedres = cluster_resource_map_[client_id]; + + // Update local_available_resources_ and SchedulingResources + for (const auto &resource_pair : deleted_set.GetResourceMap()) { + const std::string &resource_label = resource_pair.first; + + cluster_schedres.DeleteResource(resource_label); + if (client_id == local_client_id) { + local_available_resources_.DeleteResource(resource_label); + } + } + RAY_LOG(DEBUG) << "[ResourceDeleted] Updated cluster_resource_map."; + return; +} + +void NodeManager::TryLocalInfeasibleTaskScheduling() { + RAY_LOG(DEBUG) << "[LocalResourceUpdateRescheduler] The resource update is on the " + "local node, check if we can reschedule tasks"; + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + SchedulingResources &new_local_resources = cluster_resource_map_[local_client_id]; + + // SpillOver locally to figure out which infeasible tasks can be placed now + std::vector decision = scheduling_policy_.SpillOver(new_local_resources); + + std::unordered_set local_task_ids(decision.begin(), decision.end()); + + // Transition locally placed tasks to waiting or ready for dispatch. + if (local_task_ids.size() > 0) { + std::vector tasks = local_queues_.RemoveTasks(local_task_ids); + for (const auto &t : tasks) { + EnqueuePlaceableTask(t); + } + } +} + void NodeManager::HeartbeatAdded(const ClientID &client_id, const HeartbeatTableDataT &heartbeat_data) { // Locate the client id in remote client table and update available resources based on @@ -718,6 +820,9 @@ void NodeManager::ProcessClientMessage( case protocol::MessageType::SubmitTask: { ProcessSubmitTaskMessage(message_data); } break; + case protocol::MessageType::SetResourceRequest: { + ProcessSetResourceRequest(client, message_data); + } break; case protocol::MessageType::FetchOrReconstruct: { ProcessFetchOrReconstructMessage(client, message_data); } break; @@ -931,12 +1036,14 @@ void NodeManager::ProcessDisconnectClientMessage( // Return the resources that were being used by this worker. 
auto const &task_resources = worker->GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(task_resources.ToResourceSet()); worker->ResetTaskResourceIds(); auto const &lifetime_resources = worker->GetLifetimeResourceIds(); - local_available_resources_.Release(lifetime_resources); + local_available_resources_.ReleaseConstrained( + lifetime_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[client_id].Release(lifetime_resources.ToResourceSet()); worker->ResetLifetimeResourceIds(); @@ -1170,6 +1277,59 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl node_manager_client.ProcessMessages(); } +void NodeManager::ProcessSetResourceRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the SetResource message + auto message = flatbuffers::GetRoot(message_data); + + auto const &resource_name = string_from_flatbuf(*message->resource_name()); + double const &capacity = message->capacity(); + bool is_deletion = capacity <= 0; + + ClientID client_id = from_flatbuf(*message->client_id()); + + // If the python arg was null, set client_id to the local client + if (client_id.is_nil()) { + client_id = gcs_client_->client_table().GetLocalClientId(); + } + + if (is_deletion && + cluster_resource_map_[client_id].GetTotalResources().GetResourceMap().count( + resource_name) == 0) { + // Resource does not exist in the cluster resource map, thus nothing to delete. + // Return.. + RAY_LOG(INFO) << "[ProcessDeleteResourceRequest] Trying to delete resource " + << resource_name << ", but it does not exist. Doing nothing.."; + return; + } + + // Add the new resource to a skeleton ClientTableDataT object + ClientTableDataT data; + gcs_client_->client_table().GetClient(client_id, data); + // Replace the resource vectors with the resource deltas from the message. + // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in + // the resources + data.resources_total_label = std::vector{resource_name}; + data.resources_total_capacity = std::vector{capacity}; + // Set the correct flag for entry_type + if (is_deletion) { + data.entry_type = EntryType::RES_DELETE; + } else { + data.entry_type = EntryType::RES_CREATEUPDATE; + } + + // Submit to the client table. This calls the ResourceCreateUpdated callback, which + // updates cluster_resource_map_. + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (not worker) { + worker = worker_pool_.GetRegisteredDriver(client); + } + auto data_shared_ptr = std::make_shared(data); + auto client_table = gcs_client_->client_table(); + RAY_CHECK_OK(gcs_client_->client_table().Append( + DriverID::nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); +} + void NodeManager::ScheduleTasks( std::unordered_map &resource_map) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -1761,7 +1921,9 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // Release task's resources. The worker's lifetime resources are still held. 
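The ReleaseConstrained calls introduced here replace the plain Release so that resources returned by a finishing or disconnecting worker are bounded by the node's current totals; without that, capacity deleted or shrunk while a task held it would be re-added. A rough sketch of the idea, in Python for illustration only (the names do not correspond to the actual C++ API):

    def release_constrained(available, acquired, total):
        """available/total: dicts of resource name -> capacity; acquired: amounts held."""
        for name, amount in acquired.items():
            if name not in total:
                continue  # the resource was deleted while the task held it
            # Never release beyond what the (possibly reduced) total now allows.
            available[name] = min(available.get(name, 0) + amount, total[name])
        return available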
auto const &task_resources = worker.GetTaskResourceIds(); - local_available_resources_.Release(task_resources); + const ClientID &client_id = gcs_client_->client_table().GetLocalClientId(); + local_available_resources_.ReleaseConstrained( + task_resources, cluster_resource_map_[client_id].GetTotalResources()); cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Release( task_resources.ToResourceSet()); worker.ResetTaskResourceIds(); diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index edd456dbecce..8c5973c73fac 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -120,6 +120,21 @@ class NodeManager { /// \return Void. void ClientRemoved(const ClientTableDataT &client_data); + /// Handler for the addition or updation of a resource in the GCS + /// \param client_data Data associated with the new client. + /// \return Void. + void ResourceCreateUpdated(const ClientTableDataT &client_data); + + /// Handler for the deletion of a resource in the GCS + /// \param client_data Data associated with the new client. + /// \return Void. + void ResourceDeleted(const ClientTableDataT &client_data); + + /// Evaluates the local infeasible queue to check if any tasks can be scheduled. + /// This is called whenever there's an update to the resources on the local client. + /// \return Void. + void TryLocalInfeasibleTaskScheduling(); + /// Send heartbeats to the GCS. void Heartbeat(); @@ -413,6 +428,13 @@ class NodeManager { /// \param task The task that just finished. void UpdateActorFrontier(const Task &task); + /// Process client message of SetResourceRequest + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessSetResourceRequest(const std::shared_ptr &client, + const uint8_t *message_data); + /// Handle the case where an actor is disconnected, determine whether this /// actor needs to be reconstructed and then update actor table. 
/// This function needs to be called either when actor process dies or when diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 09e9b5fed5e2..0f488089e6d0 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -386,3 +386,13 @@ ray::Status RayletClient::NotifyActorResumedFromCheckpoint( return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); } + +ray::Status RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ray::ClientID &client_Id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateSetResourceRequest( + fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index ff66ff4621b0..0bdd076b5577 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -165,6 +165,14 @@ class RayletClient { ray::Status NotifyActorResumedFromCheckpoint(const ActorID &actor_id, const ActorCheckpointID &checkpoint_id); + /// Sets a resource with the specified capacity and client id + /// \param resource_name Name of the resource to be set + /// \param capacity Capacity of the resource + /// \param client_Id ClientID where the resource is to be set + /// \return ray::Status + ray::Status SetResource(const std::string &resource_name, const double capacity, + const ray::ClientID &client_Id); + Language GetLanguage() const { return language_; } ClientID GetClientID() const { return client_id_; } From f3b8b9093d98da2990a8c020aa12cfb9379e87e7 Mon Sep 17 00:00:00 2001 From: Adi Zimmerman Date: Sun, 12 May 2019 15:08:47 -0700 Subject: [PATCH 003/118] Update tutorial link in doc (#4777) --- doc/source/tune.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/source/tune.rst b/doc/source/tune.rst index bfeb729e6469..2674f5b064a7 100644 --- a/doc/source/tune.rst +++ b/doc/source/tune.rst @@ -7,9 +7,9 @@ Tune: Scalable Hyperparameter Search Tune is a scalable framework for hyperparameter search with a focus on deep learning and deep reinforcement learning. -You can find the code for Tune `here on GitHub `__. To get started with Tune, try going through `our tutorial of using Tune with Keras `__. +You can find the code for Tune `here on GitHub `__. To get started with Tune, try going through `our tutorial of using Tune with Keras `__. -(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. +(Experimental): You can try out `the above tutorial on a free hosted server via Binder `__. 
Features From 69352e3302d1a8eeff594f4ee73c858a874df2b4 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 12 May 2019 21:29:58 -0700 Subject: [PATCH 004/118] [rllib] Implement learn_on_batch() in torch policy graph --- .../ray/rllib/evaluation/torch_policy_graph.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 35220dc54570..fb5c879a1ab8 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -85,6 +85,24 @@ def compute_actions(self, [h.cpu().numpy() for h in state], self.extra_action_out(model_out)) + @override(PolicyGraph) + def learn_on_batch(self, postprocessed_batch): + with self.lock: + loss_in = [] + for key in self._loss_inputs: + loss_in.append( + torch.from_numpy(postprocessed_batch[key]).to(self.device)) + loss_out = self._loss(self._model, *loss_in) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + self._optimizer.step() + + grad_info = self.extra_grad_info() + grad_info.update(grad_process_info) + return {LEARNER_STATS_KEY: grad_info} + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): with self.lock: From 62c949bbd502a62a4fdcc40c7d5384a2f44d05c4 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Mon, 13 May 2019 14:53:10 +0800 Subject: [PATCH 005/118] Fix `ray stop` by killing raylet before plasma (#4778) --- python/ray/scripts/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d870c655a4a9..01c8e267b85a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -380,9 +380,11 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, @cli.command() def stop(): + # Note that raylet needs to exit before object store, otherwise + # it cannot exit gracefully. 
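Looking back at the torch learn_on_batch() patch earlier in this series: it follows the standard torch training step of converting the sample batch to tensors, computing the loss, and applying one optimizer update before returning learner statistics. The standalone sketch below shows the same sequence outside RLlib; the toy model, batch keys, and loss are made-up placeholders.

    import numpy as np
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    batch = {"obs": np.random.rand(32, 4).astype(np.float32),
             "targets": np.random.rand(32, 2).astype(np.float32)}

    # Mirror learn_on_batch(): numpy -> tensors, forward loss, zero_grad/backward/step.
    loss_in = [torch.from_numpy(batch[k]) for k in ("obs", "targets")]
    loss = torch.nn.functional.mse_loss(model(loss_in[0]), loss_in[1])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    grad_info = {"loss": loss.item()}  # the real method returns such stats to the caller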
processes_to_kill = [ - "plasma_store_server", "raylet", + "plasma_store_server", "raylet_monitor", "monitor.py", "redis-server", From 1622fc21fc50fd2decda93c328aab03f38c120a9 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Mon, 13 May 2019 11:59:12 -0700 Subject: [PATCH 006/118] Fatal check if object store dies (#4763) --- .../object_manager/object_store_notification_manager.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index 746f4d622d5a..5245a94ace3a 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -42,6 +42,11 @@ void ObjectStoreNotificationManager::NotificationWait() { void ObjectStoreNotificationManager::ProcessStoreLength( const boost::system::error_code &error) { notification_.resize(length_); + if (error) { + RAY_LOG(FATAL) + << "Problem communicating with the object store from raylet, check logs or " + << "dmesg for previous errors: " << boost_to_ray_status(error).ToString(); + } boost::asio::async_read( socket_, boost::asio::buffer(notification_), boost::bind(&ObjectStoreNotificationManager::ProcessStoreNotification, this, @@ -50,7 +55,7 @@ void ObjectStoreNotificationManager::ProcessStoreLength( void ObjectStoreNotificationManager::ProcessStoreNotification( const boost::system::error_code &error) { - if (error.value() != boost::system::errc::success) { + if (error) { RAY_LOG(FATAL) << "Problem communicating with the object store from raylet, check logs or " << "dmesg for previous errors: " << boost_to_ray_status(error).ToString(); From c5161a2c4d0e0a46c87dba82e6f3dfafcdd57da0 Mon Sep 17 00:00:00 2001 From: Jones Wong Date: Tue, 14 May 2019 06:39:25 +0800 Subject: [PATCH 007/118] [rllib] fix clip by value issue as TF upgraded (#4697) * fix clip_by_value issue * fix typo --- python/ray/rllib/agents/ddpg/ddpg_policy_graph.py | 14 +++++++++----- python/ray/rllib/utils/tf_run_builder.py | 2 ++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index 52b7593f2dfa..6c4917ad853f 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -166,8 +166,9 @@ def __init__(self, observation_space, action_space, config): stddev=self.config["target_noise"]), -target_noise_clip, target_noise_clip) policy_tp1_smoothed = tf.clip_by_value( - policy_tp1 + clipped_normal_sample, action_space.low, - action_space.high) + policy_tp1 + clipped_normal_sample, + action_space.low * tf.ones_like(policy_tp1), + action_space.high * tf.ones_like(policy_tp1)) else: # no smoothing, just use deterministic actions policy_tp1_smoothed = policy_tp1 @@ -473,8 +474,9 @@ def make_noisy_actions(): tf.shape(deterministic_actions), stddev=self.config["exploration_gaussian_sigma"]) stochastic_actions = tf.clip_by_value( - deterministic_actions + normal_sample, action_low, - action_high) + deterministic_actions + normal_sample, + action_low * tf.ones_like(deterministic_actions), + action_high * tf.ones_like(deterministic_actions)) elif noise_type == "ou": # add OU noise for exploration, DDPG-style zero_acts = action_low.size * [.0] @@ -494,7 +496,9 @@ def make_noisy_actions(): noise = noise_scale * base_scale \ * exploration_value * action_range stochastic_actions = tf.clip_by_value( - deterministic_actions 
+ noise, action_low, action_high) + deterministic_actions + noise, + action_low * tf.ones_like(deterministic_actions), + action_high * tf.ones_like(deterministic_actions)) else: raise ValueError( "Unknown noise type '%s' (try 'ou' or 'gaussian')" % diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py index 4694c96c03c1..ed4525ddfa79 100644 --- a/python/ray/rllib/utils/tf_run_builder.py +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -47,6 +47,8 @@ def get(self, to_fetch): self.session, self.fetches, self.debug_name, self.feed_dict, os.environ.get("TF_TIMELINE_DIR")) except Exception: + logger.exception("Error fetching: {}, feed_dict={}".format( + self.fetches, self.feed_dict)) raise ValueError("Error fetching: {}, feed_dict={}".format( self.fetches, self.feed_dict)) if isinstance(to_fetch, int): From 3bbafc710530d8fa0be201706f7b4df8e1941f7d Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 14 May 2019 19:52:28 -0700 Subject: [PATCH 008/118] [autoscaler] Fix submit (#4782) --- python/ray/scripts/scripts.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 01c8e267b85a..1951af208573 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -555,7 +555,7 @@ def rsync_up(cluster_config_file, source, target, cluster_name): rsync(cluster_config_file, source, target, cluster_name, down=False) -@cli.command() +@cli.command(context_settings={"ignore_unknown_options": True}) @click.argument("cluster_config_file", required=True, type=str) @click.option( "--docker", @@ -588,14 +588,17 @@ def rsync_up(cluster_config_file, source, target, cluster_name): @click.option( "--port-forward", required=False, type=int, help="Port to forward.") @click.argument("script", required=True, type=str) -@click.argument("script_args", required=False, type=str, nargs=-1) +@click.option("--args", required=False, type=str, help="Script args.") def submit(cluster_config_file, docker, screen, tmux, stop, start, - cluster_name, port_forward, script, script_args): + cluster_name, port_forward, script, args): """Uploads and runs a script on the specified cluster. The script is automatically synced to the following location: os.path.join("~", os.path.basename(script)) + + Example: + >>> ray submit [CLUSTER.YAML] experiment.py --args="--smoke-test" """ assert not (screen and tmux), "Can specify only one of `screen` or `tmux`." 
@@ -606,7 +609,7 @@ def submit(cluster_config_file, docker, screen, tmux, stop, start, target = os.path.join("~", os.path.basename(script)) rsync(cluster_config_file, script, target, cluster_name, down=False) - cmd = " ".join(["python", target] + list(script_args)) + cmd = " ".join(["python", target, args]) exec_cluster(cluster_config_file, cmd, docker, screen, tmux, stop, False, cluster_name, port_forward) From cb1a195ca2914288c06a65b3f4d20ea9499091aa Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 15 May 2019 10:23:25 -0700 Subject: [PATCH 009/118] Queue tasks in the raylet in between async callbacks (#4766) * Add a SWAP TaskQueue so that we can keep track of tasks that are temporarily dequeued * Fix bug where tasks that fail to be forwarded don't appear to be local by adding them to SWAP queue * cleanups * updates * updates --- src/ray/raylet/node_manager.cc | 53 ++++++++++++++++++++++-------- src/ray/raylet/node_manager.h | 5 +-- src/ray/raylet/scheduling_queue.cc | 40 ++++++++++++++-------- src/ray/raylet/scheduling_queue.h | 9 ++++- 4 files changed, 77 insertions(+), 30 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fa4ad4868aa5..efd190ba5b27 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -774,7 +774,10 @@ void NodeManager::DispatchTasks( } } } - local_queues_.RemoveTasks(removed_task_ids); + // Move the ASSIGNED task to the SWAP queue so that we remember that we have + // it queued locally. Once the GetTaskReply has been sent, the task will get + // re-queued, depending on whether the message succeeded or not. + local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP); } void NodeManager::ProcessClientMessage( @@ -1825,11 +1828,15 @@ bool NodeManager::AssignTask(const Task &task) { auto message = protocol::CreateGetTaskReply(fbb, spec.ToFlatbuffer(fbb), fbb.CreateVector(resource_id_set_flatbuf)); fbb.Finish(message); - // Give the callback a copy of the task so it can modify it. - Task assigned_task(task); + const auto &task_id = spec.TaskId(); worker->Connection()->WriteMessageAsync( static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), - fbb.GetBufferPointer(), [this, worker, assigned_task](ray::Status status) mutable { + fbb.GetBufferPointer(), [this, worker, task_id](ray::Status status) { + // Remove the ASSIGNED task from the SWAP queue. + TaskState state; + auto assigned_task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + if (status.ok()) { auto spec = assigned_task.GetTaskSpecification(); // We successfully assigned the task to the worker. @@ -2212,9 +2219,9 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? // Attempt to forward the task. 
- ForwardTask(task, node_manager_id, [this, task, node_manager_id](ray::Status error) { + ForwardTask(task, node_manager_id, [this, node_manager_id](ray::Status error, + const Task &task) { const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " << node_manager_id; @@ -2236,14 +2243,22 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, RayConfig::instance().node_manager_forward_task_retry_timeout_milliseconds()); retry_timer->expires_from_now(retry_duration); retry_timer->async_wait( - [this, task, task_id, retry_timer](const boost::system::error_code &error) { + [this, task_id, retry_timer](const boost::system::error_code &error) { // Timer killing will receive the boost::asio::error::operation_aborted, // we only handle the timeout event. RAY_CHECK(!error); RAY_LOG(INFO) << "Resubmitting task " << task_id << " because ForwardTask failed."; + // Remove the RESUBMITTED task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + // Submit the task again. SubmitTask(task, Lineage()); }); + // Temporarily move the RESUBMITTED task to the SWAP queue while the + // timer is active. + local_queues_.QueueTasks({task}, TaskState::SWAP); // Remove the task from the lineage cache. The task will get added back // once it is resubmitted. lineage_cache_.RemoveWaitingTask(task_id); @@ -2256,8 +2271,9 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, }); } -void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, - const std::function &on_error) { +void NodeManager::ForwardTask( + const Task &task, const ClientID &node_id, + const std::function &on_error) { const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -2291,16 +2307,25 @@ void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, if (it == remote_server_connections_.end()) { // TODO(atumanov): caller must handle failure to ensure tasks are not lost. RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - on_error(ray::Status::IOError("NodeManager connection not found")); + on_error(ray::Status::IOError("NodeManager connection not found"), task); return; } - auto &server_conn = it->second; + // Move the FORWARDING task to the SWAP queue so that we remember that we + // have it queued locally. Once the ForwardTaskRequest has been sent, the + // task will get re-queued, depending on whether the message succeeded or + // not. + local_queues_.QueueTasks({task}, TaskState::SWAP); server_conn->WriteMessageAsync( static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer(), - [this, on_error, task_id, node_id, spec](ray::Status status) { + fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) { + // Remove the FORWARDING task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + if (status.ok()) { + const auto &spec = task.GetTaskSpecification(); // If we were able to forward the task, remove the forwarded task from the // lineage cache since the receiving node is now responsible for writing // the task to the GCS. 
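The SWAP queue added by this patch is a holding state for tasks that have been dequeued while an asynchronous write is in flight, so they still appear locally queued and can be re-queued once the callback runs. A schematic sketch of the pattern, in Python for illustration only (the real data structure is the C++ SchedulingQueue below):

    def assign_task(queues, task, send_async):
        # Move READY -> SWAP before the async send, so the task stays tracked locally.
        queues["ready"].remove(task)
        queues["swap"].append(task)

        def on_reply(ok):
            # The callback first takes the task back out of SWAP ...
            queues["swap"].remove(task)
            # ... then re-queues it according to the outcome of the send.
            queues["running" if ok else "ready"].append(task)

        send_async(task, on_reply)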
@@ -2335,7 +2360,7 @@ void NodeManager::ForwardTask(const Task &task, const ClientID &node_id, } } } else { - on_error(status); + on_error(status, task); } }); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 8c5973c73fac..576ffbc23f72 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -246,8 +246,9 @@ class NodeManager { /// \param task The task to forward. /// \param node_id The ID of the node to forward the task to. /// \param on_error Callback on run on non-ok status. - void ForwardTask(const Task &task, const ClientID &node_id, - const std::function &on_error); + void ForwardTask( + const Task &task, const ClientID &node_id, + const std::function &on_error); /// Dispatch locally scheduled tasks. This attempts the transition from "scheduled" to /// "running" task state. diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 986ede199d38..29af345b8391 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -9,7 +9,8 @@ namespace { static constexpr const char *task_state_strings[] = { "placeable", "waiting", "ready", - "running", "infeasible", "waiting_for_actor_creation"}; + "running", "infeasible", "waiting for actor creation", + "swap"}; static_assert(sizeof(task_state_strings) / sizeof(const char *) == static_cast(ray::raylet::TaskState::kNumTaskQueues), "Must specify a TaskState name for every task queue"); @@ -172,6 +173,9 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, case TaskState::INFEASIBLE: FilterStateFromQueue(task_ids, TaskState::INFEASIBLE); break; + case TaskState::SWAP: + FilterStateFromQueue(task_ids, TaskState::SWAP); + break; case TaskState::BLOCKED: { const auto blocked_ids = GetBlockedTaskIds(); for (auto it = task_ids.begin(); it != task_ids.end();) { @@ -230,7 +234,7 @@ std::vector SchedulingQueue::RemoveTasks(std::unordered_set &task_ // Try to find the tasks to remove from the queues. for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_ids, &removed_tasks); } @@ -245,7 +249,7 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ // Try to find the task to remove in the queues. for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_id_set, &removed_tasks); if (task_id_set.empty()) { @@ -260,7 +264,7 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ } // Make sure we got the removed task. 
- RAY_CHECK(removed_tasks.size() == 1); + RAY_CHECK(removed_tasks.size() == 1) << task_id; const auto &task = removed_tasks.front(); RAY_CHECK(task.GetTaskSpecification().TaskId() == task_id); return task; @@ -287,6 +291,9 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::INFEASIBLE: RemoveTasksFromQueue(TaskState::INFEASIBLE, task_ids, &removed_tasks); break; + case TaskState::SWAP: + RemoveTasksFromQueue(TaskState::SWAP, task_ids, &removed_tasks); + break; default: RAY_LOG(FATAL) << "Attempting to move tasks from unrecognized state " << static_cast::type>(src_state); @@ -312,6 +319,9 @@ void SchedulingQueue::MoveTasks(std::unordered_set &task_ids, TaskState case TaskState::INFEASIBLE: QueueTasks(removed_tasks, TaskState::INFEASIBLE); break; + case TaskState::SWAP: + QueueTasks(removed_tasks, TaskState::SWAP); + break; default: RAY_LOG(FATAL) << "Attempting to move tasks to unrecognized state " << static_cast::type>(dst_state); @@ -348,8 +358,16 @@ std::unordered_set SchedulingQueue::GetTaskIdsForDriver( std::unordered_set SchedulingQueue::GetTaskIdsForActor( const ActorID &actor_id) const { std::unordered_set task_ids; + int swap = static_cast(TaskState::SWAP); + int i = 0; for (const auto &task_queue : task_queues_) { - GetActorTasksFromQueue(*task_queue, actor_id, task_ids); + // This is a hack to make sure that we don't remove tasks from the SWAP + // queue, since these are always guaranteed to be removed and eventually + // resubmitted if necessary by the node manager. + if (i != swap) { + GetActorTasksFromQueue(*task_queue, actor_id, task_ids); + } + i++; } return task_ids; } @@ -385,10 +403,8 @@ const std::unordered_set &SchedulingQueue::GetDriverTaskIds() const { std::string SchedulingQueue::DebugString() const { std::stringstream result; result << "SchedulingQueue:"; - for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, - }) { + for (size_t i = 0; i < static_cast(ray::raylet::TaskState::kNumTaskQueues); i++) { + TaskState task_state = static_cast(i); result << "\n- num " << GetTaskStateString(task_state) << " tasks: " << GetTaskQueue(task_state)->GetTasks().size(); } @@ -397,10 +413,8 @@ std::string SchedulingQueue::DebugString() const { } void SchedulingQueue::RecordMetrics() const { - for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, - }) { + for (size_t i = 0; i < static_cast(ray::raylet::TaskState::kNumTaskQueues); i++) { + TaskState task_state = static_cast(i); stats::SchedulingQueueStats().Record( static_cast(GetTaskQueue(task_state)->GetTasks().size()), {{stats::ValueTypeKey, diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 3f7bab1233cb..4fd07e5ca606 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -33,6 +33,13 @@ enum class TaskState { // The task is an actor method and is waiting to learn where the actor was // created. WAITING_FOR_ACTOR_CREATION, + // Swap queue for tasks that are in between states. This can happen when a + // task is removed from one queue, and an async callback is responsible for + // re-queuing the task. For example, a READY task that has just been assigned + // to a worker will get moved to the SWAP queue while waiting for a response + // from the worker. 
If the worker accepts the task, the task will be added to + // the RUNNING queue, else it will be returned to READY. + SWAP, // The number of task queues. All states that precede this enum must have an // associated TaskQueue in SchedulingQueue. All states that succeed // this enum do not have an associated TaskQueue, since the tasks @@ -144,7 +151,7 @@ class SchedulingQueue { for (const auto &task_state : { TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, TaskState::INFEASIBLE, - TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, }) { if (task_state == TaskState::READY) { task_queues_[static_cast(task_state)] = ready_queue_; From 643f62dc43929ace27789f1634d49a26e888c340 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 16 May 2019 11:19:31 +0800 Subject: [PATCH 010/118] [Java][Bazel] Refine auto-generated pom files (#4780) --- bazel/ray.bzl | 5 +- java/BUILD.bazel | 51 ++++++------------ java/api/pom.xml | 42 +++++++++------ java/api/pom_template.xml | 3 +- java/pom.xml | 72 ------------------------- java/runtime/pom.xml | 94 ++++++++++++++++++--------------- java/runtime/pom_template.xml | 3 +- java/streaming/pom.xml | 28 +++++----- java/streaming/pom_template.xml | 32 +++++++++++ java/test/pom.xml | 42 +++++++++------ java/test/pom_template.xml | 6 +-- java/tutorial/pom.xml | 6 +++ java/tutorial/pom_template.xml | 3 +- 13 files changed, 186 insertions(+), 201 deletions(-) create mode 100644 java/streaming/pom_template.xml diff --git a/bazel/ray.bzl b/bazel/ray.bzl index 4ba637f3cdd4..e26428bafa26 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -53,12 +53,13 @@ def define_java_module(name, additional_srcs = [], additional_resources = [], de size = "small", tags = ["checkstyle"], ) - -def gen_java_pom_file(name): pom_file( name = "org_ray_ray_" + name + "_pom", targets = [ ":org_ray_ray_" + name, ], template_file = name + "/pom_template.xml", + substitutions = { + "{auto_gen_header}": "", + }, ) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 2d2762d837e6..f86df8d40f96 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,4 @@ -load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module", "gen_java_pom_file") +load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") exports_files([ "testng.xml", @@ -7,27 +7,29 @@ exports_files([ "streaming/testng.xml", ]) +all_modules = [ + "api", + "runtime", + "test", + "tutorial", + "streaming", +] + java_import( name = "all_modules", jars = [ - "liborg_ray_ray_api.jar", - "liborg_ray_ray_api-src.jar", - "liborg_ray_ray_runtime.jar", - "liborg_ray_ray_runtime-src.jar", - "liborg_ray_ray_tutorial.jar", - "liborg_ray_ray_tutorial-src.jar", - "liborg_ray_ray_streaming.jar", - "liborg_ray_ray_streaming-src.jar", + "liborg_ray_ray_" + module + ".jar" for module in all_modules + ] + [ + "liborg_ray_ray_" + module + "-src.jar" for module in all_modules + ] + [ "all_tests_deploy.jar", "all_tests_deploy-src.jar", "streaming_tests_deploy.jar", "streaming_tests_deploy-src.jar", ], deps = [ - ":org_ray_ray_api", - ":org_ray_ray_runtime", - ":org_ray_ray_tutorial", - ":org_ray_ray_streaming", + ":org_ray_ray_" + module for module in all_modules + ] + [ ":all_tests", ":streaming_tests", ], @@ -247,30 +249,10 @@ genrule( local = 1, ) -# generate pom.xml file for maven compile -gen_java_pom_file( - name = "api", -) - -gen_java_pom_file( - name = "runtime", -) - -gen_java_pom_file( - name = "tutorial", -) - -gen_java_pom_file( - name = "test", -) 
- genrule( name = "copy_pom_file", srcs = [ - "//java:org_ray_ray_api_pom", - "//java:org_ray_ray_runtime_pom", - "//java:org_ray_ray_tutorial_pom", - "//java:org_ray_ray_test_pom", + "//java:org_ray_ray_" + module + "_pom" for module in all_modules ], outs = ["copy_pom_file.out"], cmd = """ @@ -280,6 +262,7 @@ genrule( cp -f $(location //java:org_ray_ray_runtime_pom) $$WORK_DIR/java/runtime/pom.xml cp -f $(location //java:org_ray_ray_tutorial_pom) $$WORK_DIR/java/tutorial/pom.xml cp -f $(location //java:org_ray_ray_test_pom) $$WORK_DIR/java/test/pom.xml + cp -f $(location //java:org_ray_ray_streaming_pom) $$WORK_DIR/java/streaming/pom.xml echo $$(date) > $@ """, local = 1, diff --git a/java/api/pom.xml b/java/api/pom.xml index c7a910cd989f..792e54f6c433 100644 --- a/java/api/pom.xml +++ b/java/api/pom.xml @@ -1,4 +1,5 @@ + @@ -16,21 +17,30 @@ jar - - org.slf4j - slf4j-log4j12 - - - javax.xml.bind - jaxb-api - - - com.sun.xml.bind - jaxb-core - - - com.sun.xml.bind - jaxb-impl - + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + javax.xml.bind + jaxb-api + 2.3.0 + + + log4j + log4j + 1.2.17 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + diff --git a/java/api/pom_template.xml b/java/api/pom_template.xml index ae37175a812a..67088f9584cb 100644 --- a/java/api/pom_template.xml +++ b/java/api/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} @@ -16,6 +17,6 @@ jar - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/java/pom.xml b/java/pom.xml index ce5ffa2faa29..bf7a41229b9b 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -20,8 +20,6 @@ 1.8 UTF-8 0.1-SNAPSHOT - 1.7.25 - 2.3.0 @@ -31,76 +29,6 @@ arrow-plasma 0.13.0-SNAPSHOT - - de.ruedigermoeller - fst - 2.47 - - - org.ow2.asm - asm - 6.0 - - - com.github.davidmoten - flatbuffers-java - 1.9.0.1 - - - com.beust - jcommander - 1.72 - - - redis.clients - jedis - 2.8.0 - - - commons-io - commons-io - 2.5 - - - org.apache.commons - commons-lang3 - 3.4 - - - com.google.guava - guava - 19.0 - - - org.slf4j - slf4j-log4j12 - ${slf4j.version} - - - com.typesafe - config - 1.3.2 - - - org.testng - testng - 6.9.9 - - - javax.xml.bind - jaxb-api - ${jaxb.version} - - - com.sun.xml.bind - jaxb-core - ${jaxb.version} - - - com.sun.xml.bind - jaxb-impl - ${jaxb.version} - diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 4b2cc7d50373..c7e1730c9004 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -1,4 +1,5 @@ + @@ -21,53 +22,60 @@ ray-api ${project.version} - - com.typesafe - config - - - org.apache.commons - commons-lang3 - - - de.ruedigermoeller - fst - - - com.github.davidmoten - flatbuffers-java - - - redis.clients - jedis - org.apache.arrow arrow-plasma - - commons-io - commons-io - - - com.google.guava - guava - - - org.slf4j - slf4j-log4j12 - - - org.ow2.asm - asm - - - - - org.testng - testng - test - + + com.github.davidmoten + flatbuffers-java + 1.9.0.1 + + + com.google.guava + guava + 27.0.1-jre + + + com.typesafe + config + 1.3.2 + + + commons-io + commons-io + 2.5 + + + de.ruedigermoeller + fst + 2.47 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.ow2.asm + asm + 6.0 + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + + + redis.clients + jedis + 2.8.0 + diff --git a/java/runtime/pom_template.xml b/java/runtime/pom_template.xml index fc75efe70398..9200bd6c6003 100644 --- a/java/runtime/pom_template.xml +++ b/java/runtime/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} @@ -25,7 +26,7 @@ org.apache.arrow arrow-plasma - {generated_bzl_deps} 
+{generated_bzl_deps} diff --git a/java/streaming/pom.xml b/java/streaming/pom.xml index c95976373d3c..3ee6e89c401e 100644 --- a/java/streaming/pom.xml +++ b/java/streaming/pom.xml @@ -1,4 +1,5 @@ + @@ -26,17 +27,20 @@ ray-runtime ${project.version} - - org.slf4j - slf4j-log4j12 - - - com.google.guava - guava - - - org.testng - testng - + + com.google.guava + guava + 27.0.1-jre + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + diff --git a/java/streaming/pom_template.xml b/java/streaming/pom_template.xml new file mode 100644 index 000000000000..3551e7443e5c --- /dev/null +++ b/java/streaming/pom_template.xml @@ -0,0 +1,32 @@ + +{auto_gen_header} + + + org.ray + ray-superpom + 0.1-SNAPSHOT + + 4.0.0 + + streaming + ray streaming + ray streaming + + jar + + + + org.ray + ray-api + ${project.version} + + + org.ray + ray-runtime + ${project.version} + +{generated_bzl_deps} + + diff --git a/java/test/pom.xml b/java/test/pom.xml index afb8da564293..10f7ea4b3313 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -1,5 +1,5 @@ - + @@ -22,34 +22,46 @@ ray-api ${project.version} - org.ray ray-runtime ${project.version} - - - org.testng - testng - - - - com.google.guava - guava - + + com.google.guava + guava + 27.0.1-jre + + + commons-io + commons-io + 2.5 + + + org.apache.commons + commons-lang3 + 3.4 + + + org.slf4j + slf4j-api + 1.7.25 + + + org.testng + testng + 6.9.9 + org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M3 + 2.21.0 - false false ${basedir}/src/main/java/ - ${basedir}/src/main/resources/ ${project.build.directory}/classes/ diff --git a/java/test/pom_template.xml b/java/test/pom_template.xml index f67e735a5b80..9b8b3684f297 100644 --- a/java/test/pom_template.xml +++ b/java/test/pom_template.xml @@ -1,5 +1,5 @@ - +{auto_gen_header} @@ -22,14 +22,12 @@ ray-api ${project.version} - org.ray ray-runtime ${project.version} - - {generated_bzl_deps} +{generated_bzl_deps} diff --git a/java/tutorial/pom.xml b/java/tutorial/pom.xml index 48a03dc1ca8e..b0e78b40e15e 100644 --- a/java/tutorial/pom.xml +++ b/java/tutorial/pom.xml @@ -1,4 +1,5 @@ + ray-runtime ${project.version} + + com.google.guava + guava + 27.0.1-jre + diff --git a/java/tutorial/pom_template.xml b/java/tutorial/pom_template.xml index 3ced33cf3ac2..0f7b2fdf4693 100644 --- a/java/tutorial/pom_template.xml +++ b/java/tutorial/pom_template.xml @@ -1,4 +1,5 @@ +{auto_gen_header} ray-runtime ${project.version} - {generated_bzl_deps} +{generated_bzl_deps} From 1490a98a71351119013b7478172fea9b452c1082 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Wed, 15 May 2019 22:55:21 -0700 Subject: [PATCH 011/118] Bump version to 0.7.0 (#4791) --- python/ray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 1e382d1b9c2c..b15fb13cbf29 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -95,7 +95,7 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. -__version__ = "0.7.0.dev3" +__version__ = "0.7.0" __all__ = [ "LOCAL_MODE", From 98dd0331793402cd8ed88f94e056abd9f5348c91 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Thu, 16 May 2019 19:53:15 +0800 Subject: [PATCH 012/118] [JAVA] setDefaultUncaughtExceptionHandler to log uncaught exception in user thread. 
(#4798) * Add WorkerUncaughtExceptionHandler * Fix * revert bazel and pom --- java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java | 4 ++-- .../java/org/ray/runtime/runner/worker/DefaultWorker.java | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 647b77e336b4..7439dfa430f8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -76,13 +76,13 @@ public List getAllNodeInfo() { NodeInfo nodeInfo = new NodeInfo( clientId, data.nodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE){ + } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE){ + } else if (data.entryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index 6fd3ea0e76f9..211411906fdc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -15,6 +15,9 @@ public class DefaultWorker { public static void main(String[] args) { try { System.setProperty("ray.worker.mode", "WORKER"); + Thread.setDefaultUncaughtExceptionHandler((Thread t, Throwable e) -> { + LOGGER.error("Uncaught worker exception in thread {}: {}", t, e); + }); Ray.init(); LOGGER.info("Worker started."); ((AbstractRayRuntime)Ray.internal()).loop(); From 9f2645d6ea7167fe9309c55b9f1ef58817fd42ab Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 16 May 2019 13:50:03 -0700 Subject: [PATCH 013/118] [tune] Fix CLI test (#4801) --- python/ray/tune/tests/test_commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index ab27cc65d56c..f55dc83362c3 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -33,7 +33,7 @@ def __exit__(self, *args): @pytest.fixture def start_ray(): - ray.init() + ray.init(log_to_driver=False) _register_all() yield ray.shutdown() From ffd596d5bbb969006da9287d0e28613bc50d6e03 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 17 May 2019 10:56:39 +0800 Subject: [PATCH 014/118] Fix pom file generation (#4800) --- bazel/ray.bzl | 12 +++++++----- java/runtime/pom.xml | 10 ++++++++++ java/streaming/pom.xml | 10 ++++++++++ java/test.sh | 2 +- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/bazel/ray.bzl b/bazel/ray.bzl index e26428bafa26..750b90a21aec 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -25,8 +25,10 @@ def flatbuffer_java_library(name, srcs, outs, out_prefix, includes = [], include ) def define_java_module(name, additional_srcs = [], additional_resources = [], define_test_lib = False, test_deps = [], **kwargs): + lib_name = "org_ray_ray_" 
+ name + pom_file_targets = [lib_name] native.java_library( - name = "org_ray_ray_" + name, + name = lib_name, srcs = additional_srcs + native.glob([name + "/src/main/java/**/*.java"]), resources = native.glob([name + "/src/main/resources/**"]) + additional_resources, **kwargs @@ -40,8 +42,10 @@ def define_java_module(name, additional_srcs = [], additional_resources = [], de tags = ["checkstyle"], ) if define_test_lib: + test_lib_name = "org_ray_ray_" + name + "_test" + pom_file_targets.append(test_lib_name) native.java_library( - name = "org_ray_ray_" + name + "_test", + name = test_lib_name, srcs = native.glob([name + "/src/test/java/**/*.java"]), deps = test_deps, ) @@ -55,9 +59,7 @@ def define_java_module(name, additional_srcs = [], additional_resources = [], de ) pom_file( name = "org_ray_ray_" + name + "_pom", - targets = [ - ":org_ray_ray_" + name, - ], + targets = pom_file_targets, template_file = name + "/pom_template.xml", substitutions = { "{auto_gen_header}": "", diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c7e1730c9004..1ce51971c03e 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -26,6 +26,11 @@ org.apache.arrow arrow-plasma + + com.beust + jcommander + 1.72 + com.github.davidmoten flatbuffers-java @@ -71,6 +76,11 @@ slf4j-log4j12 1.7.25 + + org.testng + testng + 6.9.9 + redis.clients jedis diff --git a/java/streaming/pom.xml b/java/streaming/pom.xml index 3ee6e89c401e..382233fb02af 100644 --- a/java/streaming/pom.xml +++ b/java/streaming/pom.xml @@ -27,6 +27,11 @@ ray-runtime ${project.version} + + com.beust + jcommander + 1.72 + com.google.guava guava @@ -41,6 +46,11 @@ org.slf4j slf4j-log4j12 1.7.25 + + + org.testng + testng + 6.9.9 diff --git a/java/test.sh b/java/test.sh index 48242f39888b..ba728f14bf38 100755 --- a/java/test.sh +++ b/java/test.sh @@ -38,5 +38,5 @@ popd pushd $ROOT_DIR echo "Testing maven install." -mvn clean install -Dmaven.test.skip +mvn clean install -DskipTests popd From 7d5ef6d99c4f784bb50efad90b814ded3e46176b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 16 May 2019 22:05:07 -0700 Subject: [PATCH 015/118] [rllib] Support continuous action distributions in IMPALA/APPO (#4771) --- doc/source/rllib-algorithms.rst | 2 +- doc/source/rllib-env.rst | 2 +- python/ray/rllib/agents/impala/vtrace.py | 46 ++++++++++++------- .../agents/impala/vtrace_policy_graph.py | 25 +++++----- .../ray/rllib/agents/ppo/appo_policy_graph.py | 17 ++++--- .../pendulum-appo-vtrace.yaml | 12 +++++ 6 files changed, 64 insertions(+), 40 deletions(-) create mode 100644 python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 9ee6108476a8..2f1a74b2458b 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -95,7 +95,7 @@ Asynchronous Proximal Policy Optimization (APPO) `[implementation] `__ We include an asynchronous variant of Proximal Policy Optimization (PPO) based on the IMPALA architecture. This is similar to IMPALA but using a surrogate policy loss with clipping. Compared to synchronous PPO, APPO is more efficient in wall-clock time due to its use of asynchronous sampling. Using a clipped loss also allows for multiple SGD passes, and therefore the potential for better sample efficiency compared to IMPALA. V-trace can also be enabled to correct for off-policy samples. -This implementation is currently *experimental*. Consider also using `PPO `__ or `IMPALA `__. 
+APPO is not always more efficient; it is often better to simply use `PPO `__ or `IMPALA `__. Tuned examples: `PongNoFrameskip-v4 `__ diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 056b7c3fc791..1373c450d7ab 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -13,7 +13,7 @@ Algorithm Discrete Actions Continuous Actions Multi-Agent Recurre A2C, A3C **Yes** `+parametric`_ **Yes** **Yes** **Yes** PPO, APPO **Yes** `+parametric`_ **Yes** **Yes** **Yes** PG **Yes** `+parametric`_ **Yes** **Yes** **Yes** -IMPALA **Yes** `+parametric`_ No **Yes** **Yes** +IMPALA **Yes** `+parametric`_ **Yes** **Yes** **Yes** DQN, Rainbow **Yes** `+parametric`_ No **Yes** No DDPG, TD3 No **Yes** **Yes** No APEX-DQN **Yes** `+parametric`_ No **Yes** No diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py index cc560d9937e4..2319c633ea89 100644 --- a/python/ray/rllib/agents/impala/vtrace.py +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -34,6 +34,7 @@ import collections +from ray.rllib.models.action_dist import Categorical from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -48,12 +49,15 @@ VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages") -def log_probs_from_logits_and_actions(policy_logits, actions): - return multi_log_probs_from_logits_and_actions([policy_logits], - [actions])[0] +def log_probs_from_logits_and_actions(policy_logits, + actions, + dist_class=Categorical): + return multi_log_probs_from_logits_and_actions([policy_logits], [actions], + dist_class)[0] -def multi_log_probs_from_logits_and_actions(policy_logits, actions): +def multi_log_probs_from_logits_and_actions(policy_logits, actions, + dist_class): """Computes action log-probs from policy logits and actions. In the notation used throughout documentation and comments, T refers to the @@ -68,11 +72,11 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions): ..., [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities parameterizing a softmax policy. - actions: A list with length of ACTION_SPACE of int32 + actions: A list with length of ACTION_SPACE of tensors of shapes - [T, B], + [T, B, ...], ..., - [T, B] + [T, B, ...] with actions. 
Returns: @@ -87,8 +91,16 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions): log_probs = [] for i in range(len(policy_logits)): - log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=policy_logits[i], labels=actions[i])) + p_shape = tf.shape(policy_logits[i]) + a_shape = tf.shape(actions[i]) + policy_logits_flat = tf.reshape(policy_logits[i], + tf.concat([[-1], p_shape[2:]], axis=0)) + actions_flat = tf.reshape(actions[i], + tf.concat([[-1], a_shape[2:]], axis=0)) + log_probs.append( + tf.reshape( + dist_class(policy_logits_flat).logp(actions_flat), + a_shape[:2])) return log_probs @@ -100,6 +112,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class=Categorical, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -111,6 +124,7 @@ def from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold, name=name) @@ -133,6 +147,7 @@ def multi_from_logits(behaviour_policy_logits, rewards, values, bootstrap_value, + dist_class, clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0, name="vtrace_from_logits"): @@ -168,11 +183,11 @@ def multi_from_logits(behaviour_policy_logits, [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities parameterizing the softmax target policy. - actions: A list with length of ACTION_SPACE of int32 + actions: A list with length of ACTION_SPACE of tensors of shapes - [T, B], + [T, B, ...], ..., - [T, B] + [T, B, ...] with actions sampled from the behaviour policy. discounts: A float32 tensor of shape [T, B] with the discount encountered when following the behaviour policy. @@ -182,6 +197,7 @@ def multi_from_logits(behaviour_policy_logits, wrt. the target policy. bootstrap_value: A float32 of shape [B] with the value function estimate at time T. + dist_class: action distribution class for the logits. clip_rho_threshold: A scalar float32 tensor with the clipping threshold for importance weights (rho) when calculating the baseline targets (vs). rho^bar in the paper. @@ -208,13 +224,11 @@ def multi_from_logits(behaviour_policy_logits, behaviour_policy_logits[i], dtype=tf.float32) target_policy_logits[i] = tf.convert_to_tensor( target_policy_logits[i], dtype=tf.float32) - actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32) # Make sure tensor ranks are as expected. # The rest will be checked by from_action_log_probs. 
behaviour_policy_logits[i].shape.assert_has_rank(3) target_policy_logits[i].shape.assert_has_rank(3) - actions[i].shape.assert_has_rank(2) with tf.name_scope( name, @@ -223,9 +237,9 @@ def multi_from_logits(behaviour_policy_logits, discounts, rewards, values, bootstrap_value ]): target_action_log_probs = multi_log_probs_from_logits_and_actions( - target_policy_logits, actions) + target_policy_logits, actions, dist_class) behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( - behaviour_policy_logits, actions) + behaviour_policy_logits, actions, dist_class) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 702aefb50a6e..d94a59dcef2e 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -18,7 +18,6 @@ from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -40,6 +39,7 @@ def __init__(self, rewards, values, bootstrap_value, + dist_class, valid_mask, vf_loss_coeff=0.5, entropy_coeff=0.01, @@ -52,7 +52,7 @@ def __init__(self, handle episode cut boundaries. Args: - actions: An int32 tensor of shape [T, B, ACTION_SPACE]. + actions: An int|float32 tensor of shape [T, B, ACTION_SPACE]. actions_logp: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. dones: A bool tensor of shape [T, B]. @@ -70,6 +70,7 @@ def __init__(self, rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + dist_class: action distribution class for logits. valid_mask: A bool tensor of valid RNN input elements (#2992). 
""" @@ -78,11 +79,12 @@ def __init__(self, self.vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), + actions=tf.unstack(actions, axis=2), discounts=tf.to_float(~dones) * discount, rewards=rewards, values=values, bootstrap_value=bootstrap_value, + dist_class=dist_class, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, tf.float32)) @@ -140,30 +142,28 @@ def __init__(self, if isinstance(action_space, gym.spaces.Discrete): is_multidiscrete = False - actions_shape = [None] output_hidden_shape = [action_space.n] elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True - actions_shape = [None, len(action_space.nvec)] output_hidden_shape = action_space.nvec.astype(np.int32) else: - raise UnsupportedSpaceException( - "Action space {} is not supported for IMPALA.".format( - action_space)) + is_multidiscrete = False + output_hidden_shape = 1 # Create input placeholders + dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) if existing_inputs: actions, dones, behaviour_logits, rewards, observations, \ prev_actions, prev_rewards = existing_inputs[:7] existing_state_in = existing_inputs[7:-1] existing_seq_lens = existing_inputs[-1] else: - actions = tf.placeholder(tf.int64, actions_shape, name="ac") + actions = ModelCatalog.get_action_placeholder(action_space) dones = tf.placeholder(tf.bool, [None], name="dones") rewards = tf.placeholder(tf.float32, [None], name="rewards") behaviour_logits = tf.placeholder( - tf.float32, [None, sum(output_hidden_shape)], - name="behaviour_logits") + tf.float32, [None, logit_dim], name="behaviour_logits") observations = tf.placeholder( tf.float32, [None] + list(observation_space.shape)) existing_state_in = None @@ -174,8 +174,6 @@ def __init__(self, behaviour_logits, output_hidden_shape, axis=1) # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) prev_actions = ModelCatalog.get_action_placeholder(action_space) prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") self.model = ModelCatalog.get_model( @@ -261,6 +259,7 @@ def make_time_major(tensor, drop_last=False): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], + dist_class=dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index 64523c60d1b3..8ebb52353ede 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -18,7 +18,6 @@ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.evaluation.postprocessing import compute_advantages @@ -94,6 +93,7 @@ def __init__(self, rewards, values, bootstrap_value, + dist_class, valid_mask, vf_loss_coeff=0.5, entropy_coeff=0.01, @@ -107,18 +107,19 @@ def __init__(self, handle episode cut boundaries. 
Arguments: - actions: An int32 tensor of shape [T, B, NUM_ACTIONS]. + actions: An int|float32 tensor of shape [T, B, logit_dim]. prev_actions_logp: A float32 tensor of shape [T, B]. actions_logp: A float32 tensor of shape [T, B]. action_kl: A float32 tensor of shape [T, B]. actions_entropy: A float32 tensor of shape [T, B]. dones: A bool tensor of shape [T, B]. - behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. - target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS]. + behaviour_logits: A float32 tensor of shape [T, B, logit_dim]. + target_logits: A float32 tensor of shape [T, B, logit_dim]. discount: A float32 scalar. rewards: A float32 tensor of shape [T, B]. values: A float32 tensor of shape [T, B]. bootstrap_value: A float32 tensor of shape [B]. + dist_class: action distribution class for logits. valid_mask: A bool tensor of valid RNN input elements (#2992). """ @@ -127,11 +128,12 @@ def __init__(self, self.vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=behaviour_logits, target_policy_logits=target_logits, - actions=tf.unstack(tf.cast(actions, tf.int32), axis=2), + actions=tf.unstack(actions, axis=2), discounts=tf.to_float(~dones) * discount, rewards=rewards, values=values, bootstrap_value=bootstrap_value, + dist_class=dist_class, clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32), clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold, tf.float32)) @@ -218,10 +220,6 @@ def __init__(self, elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = action_space.nvec.astype(np.int32) - elif self.config["vtrace"]: - raise UnsupportedSpaceException( - "Action space {} is not supported for APPO + VTrace.", - format(action_space)) else: is_multidiscrete = False output_hidden_shape = 1 @@ -365,6 +363,7 @@ def make_time_major(tensor, drop_last=False): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], + dist_class=dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], diff --git a/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml b/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml new file mode 100644 index 000000000000..245e908cc89c --- /dev/null +++ b/python/ray/rllib/tuned_examples/regression_tests/pendulum-appo-vtrace.yaml @@ -0,0 +1,12 @@ +pendulum-appo-vt: + env: Pendulum-v0 + run: APPO + stop: + episode_reward_mean: -900 # just check it learns a bit + timesteps_total: 500000 + config: + num_gpus: 0 + num_workers: 1 + gamma: 0.95 + train_batch_size: 50 + vtrace: true From 3807fb505b121e1e9f583f5859941cf07b1d7438 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 16 May 2019 22:12:07 -0700 Subject: [PATCH 016/118] [rllib] TensorFlow 2 compatibility (#4802) --- doc/source/rllib-env.rst | 2 +- doc/source/rllib-models.rst | 2 +- .../rllib/agents/ddpg/ddpg_policy_graph.py | 33 +- .../ray/rllib/agents/dqn/dqn_policy_graph.py | 33 +- python/ray/rllib/agents/impala/vtrace.py | 2 - .../agents/impala/vtrace_policy_graph.py | 3 - python/ray/rllib/agents/impala/vtrace_test.py | 4 +- .../ray/rllib/agents/ppo/appo_policy_graph.py | 3 - python/ray/rllib/agents/ppo/test/test.py | 4 +- python/ray/rllib/examples/batch_norm_model.py | 22 +- python/ray/rllib/examples/carla/README | 14 - python/ray/rllib/examples/carla/env.py | 684 ------------------ 
python/ray/rllib/examples/carla/models.py | 108 --- python/ray/rllib/examples/carla/scenarios.py | 131 ---- python/ray/rllib/examples/carla/train_a3c.py | 51 -- python/ray/rllib/examples/carla/train_dqn.py | 65 -- python/ray/rllib/examples/carla/train_ppo.py | 55 -- .../ray/rllib/examples/custom_fast_model.py | 4 +- python/ray/rllib/examples/custom_loss.py | 4 +- .../examples/export/cartpole_dqn_export.py | 4 +- .../ray/rllib/examples/multiagent_cartpole.py | 30 +- .../examples/parametric_action_cartpole.py | 21 +- python/ray/rllib/models/action_dist.py | 6 +- python/ray/rllib/models/fcnet.py | 18 +- python/ray/rllib/models/lstm.py | 4 +- python/ray/rllib/models/visionnet.py | 25 +- .../rllib/optimizers/aso_multi_gpu_learner.py | 6 +- python/ray/rllib/tests/test_catalog.py | 4 +- python/ray/rllib/tests/test_lstm.py | 9 +- python/ray/rllib/tests/test_nested_spaces.py | 11 +- python/ray/rllib/tests/test_optimizers.py | 4 +- python/ray/rllib/utils/__init__.py | 9 +- 32 files changed, 140 insertions(+), 1235 deletions(-) delete mode 100644 python/ray/rllib/examples/carla/README delete mode 100644 python/ray/rllib/examples/carla/env.py delete mode 100644 python/ray/rllib/examples/carla/models.py delete mode 100644 python/ray/rllib/examples/carla/scenarios.py delete mode 100644 python/ray/rllib/examples/carla/train_a3c.py delete mode 100644 python/ray/rllib/examples/carla/train_dqn.py delete mode 100644 python/ray/rllib/examples/carla/train_ppo.py diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 1373c450d7ab..2701a689dc2c 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -92,7 +92,7 @@ In the above example, note that the ``env_creator`` function takes in an ``env_c OpenAI Gym ---------- -RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may also find the `SimpleCorridor `__ and `Carla simulator `__ example env implementations useful as a reference. +RLlib uses Gym as its environment interface for single-agent training. For more information on how to implement a custom Gym environment, see the `gym.Env class definition `__. You may find the `SimpleCorridor `__ example useful as a reference. Performance ~~~~~~~~~~~ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index 7fd860a65a3e..b429e04be417 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -134,7 +134,7 @@ Custom TF models should subclass the common RLlib `model class `__ and associated `training scripts `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. +For a full example of a custom model in code, see the `custom env example `__. You can also reference the `unit tests `__ for Tuple and Dict spaces, which show how to access nested observation fields. 
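Note on the doc change above: the custom-model API it points to is the Model subclass with _build_layers_v2(input_dict, num_outputs, options), the same pattern the batch_norm_model.py example is migrated to later in this patch. A minimal sketch under that assumption (MyFCModel and "my_fc_model" are hypothetical names, and the config usage in the last comment is the usual RLlib pattern rather than something this patch introduces):

    # Sketch only: a toy custom model registered with RLlib's ModelCatalog.
    from ray.rllib.models import Model, ModelCatalog
    from ray.rllib.models.misc import normc_initializer
    from ray.rllib.utils import try_import_tf

    tf = try_import_tf()


    class MyFCModel(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # input_dict["obs"] holds the observation tensor(s).
            last_layer = input_dict["obs"]
            for i, size in enumerate([256, 256]):
                last_layer = tf.layers.dense(
                    last_layer,
                    size,
                    kernel_initializer=normc_initializer(1.0),
                    activation=tf.nn.tanh,
                    name="fc{}".format(i))
            output = tf.layers.dense(
                last_layer,
                num_outputs,
                kernel_initializer=normc_initializer(0.01),
                activation=None,
                name="fc_out")
            return output, last_layer


    ModelCatalog.register_custom_model("my_fc_model", MyFCModel)
    # Typical usage: config={"model": {"custom_model": "my_fc_model"}}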
Custom Recurrent Models ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index 6c4917ad853f..675f9187f2c6 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -399,8 +399,6 @@ def set_state(self, state): self.set_pure_exploration_phase(state[2]) def _build_q_network(self, obs, obs_space, action_space, actions): - import tensorflow.contrib.layers as layers - if self.config["use_state_preprocessor"]: q_model = ModelCatalog.get_model({ "obs": obs, @@ -413,16 +411,12 @@ def _build_q_network(self, obs, obs_space, action_space, actions): activation = getattr(tf.nn, self.config["critic_hidden_activation"]) for hidden in self.config["critic_hiddens"]: - q_out = layers.fully_connected( - q_out, num_outputs=hidden, activation_fn=activation) - q_values = layers.fully_connected( - q_out, num_outputs=1, activation_fn=None) + q_out = tf.layers.dense(q_out, units=hidden, activation=activation) + q_values = tf.layers.dense(q_out, units=1, activation=None) return q_values, q_model def _build_policy_network(self, obs, obs_space, action_space): - import tensorflow.contrib.layers as layers - if self.config["use_state_preprocessor"]: model = ModelCatalog.get_model({ "obs": obs, @@ -434,16 +428,19 @@ def _build_policy_network(self, obs, obs_space, action_space): action_out = obs activation = getattr(tf.nn, self.config["actor_hidden_activation"]) - normalizer_fn = layers.layer_norm if self.config["parameter_noise"] \ - else None for hidden in self.config["actor_hiddens"]: - action_out = layers.fully_connected( - action_out, - num_outputs=hidden, - activation_fn=activation, - normalizer_fn=normalizer_fn) - action_out = layers.fully_connected( - action_out, num_outputs=self.dim_actions, activation_fn=None) + if self.config["parameter_noise"]: + import tensorflow.contrib.layers as layers + action_out = layers.fully_connected( + action_out, + num_outputs=hidden, + activation_fn=activation, + normalizer_fn=layers.layer_norm) + else: + action_out = tf.layers.dense( + action_out, units=hidden, activation=activation) + action_out = tf.layers.dense( + action_out, units=self.dim_actions, activation=None) # Use sigmoid to scale to [0,1], but also double magnitude of input to # emulate behaviour of tanh activation used in DDPG and TD3 papers. 
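Note on the hunk above: the recurring change in these policy graphs is swapping tf.contrib.layers.fully_connected for the core tf.layers.dense call, keeping the contrib path only where layer_norm is needed for parameter noise; tf.contrib does not exist in TensorFlow 2, so dropping the unconditional contrib usage is what the patch title refers to. A small, self-contained sketch of the mapping (tensor shape and hidden size are arbitrary; runs as-is under TF 1.x graph mode, and under TF 2 the same calls live in tf.compat.v1):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 16])
    hidden = 64

    # Before (hard tf.contrib dependency, removed by this patch):
    #   out = tf.contrib.layers.fully_connected(
    #       x, num_outputs=hidden, activation_fn=tf.nn.relu)

    # After (core API, as used in _build_q_network above):
    out = tf.layers.dense(x, units=hidden, activation=tf.nn.relu)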
@@ -507,7 +504,7 @@ def make_noisy_actions(): def make_uniform_random_actions(): # pure random exploration option - uniform_random_actions = tf.random.uniform( + uniform_random_actions = tf.random_uniform( tf.shape(deterministic_actions)) # rescale uniform random actions according to action range tf_range = tf.constant(action_range[None], dtype="float32") diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index 5af38ed9e958..1e682ce80cfa 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -154,8 +154,6 @@ def __init__(self, v_max=10.0, sigma0=0.5, parameter_noise=False): - import tensorflow.contrib.layers as layers - self.model = model with tf.variable_scope("action_value"): if hiddens: @@ -164,13 +162,18 @@ def __init__(self, if use_noisy: action_out = self.noisy_layer( "hidden_%d" % i, action_out, hiddens[i], sigma0) - else: + elif parameter_noise: + import tensorflow.contrib.layers as layers action_out = layers.fully_connected( action_out, num_outputs=hiddens[i], activation_fn=tf.nn.relu, - normalizer_fn=layers.layer_norm - if parameter_noise else None) + normalizer_fn=layers.layer_norm) + else: + action_out = tf.layers.dense( + action_out, + units=hiddens[i], + activation=tf.nn.relu) else: # Avoid postprocessing the outputs. This enables custom models # to be used for parametric action DQN. @@ -183,10 +186,8 @@ def __init__(self, sigma0, non_linear=False) elif hiddens: - action_scores = layers.fully_connected( - action_out, - num_outputs=num_actions * num_atoms, - activation_fn=None) + action_scores = tf.layers.dense( + action_out, units=num_actions * num_atoms, activation=None) else: action_scores = model.outputs if num_atoms > 1: @@ -214,13 +215,15 @@ def __init__(self, state_out = self.noisy_layer("dueling_hidden_%d" % i, state_out, hiddens[i], sigma0) - else: - state_out = layers.fully_connected( + elif parameter_noise: + state_out = tf.contrib.layers.fully_connected( state_out, num_outputs=hiddens[i], activation_fn=tf.nn.relu, - normalizer_fn=layers.layer_norm - if parameter_noise else None) + normalizer_fn=tf.contrib.layers.layer_norm) + else: + state_out = tf.layers.dense( + state_out, units=hiddens[i], activation=tf.nn.relu) if use_noisy: state_score = self.noisy_layer( "dueling_output", @@ -229,8 +232,8 @@ def __init__(self, sigma0, non_linear=False) else: - state_score = layers.fully_connected( - state_out, num_outputs=num_atoms, activation_fn=None) + state_score = tf.layers.dense( + state_out, units=num_atoms, activation=None) if num_atoms > 1: support_logits_per_action_mean = tf.reduce_mean( support_logits_per_action, 1) diff --git a/python/ray/rllib/agents/impala/vtrace.py b/python/ray/rllib/agents/impala/vtrace.py index 2319c633ea89..67e76929dfc3 100644 --- a/python/ray/rllib/agents/impala/vtrace.py +++ b/python/ray/rllib/agents/impala/vtrace.py @@ -38,8 +38,6 @@ from ray.rllib.utils import try_import_tf tf = try_import_tf() -if tf: - nest = tf.contrib.framework.nest VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [ "vs", "pg_advantages", "log_rhos", "behaviour_action_log_probs", diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index d94a59dcef2e..56b6de42ed5a 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -278,14 +278,11 @@ def make_time_major(tensor, 
drop_last=False): self.KL_stats.update({ "mean_KL_{}".format(i): tf.reduce_mean(kl), "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. - percentile(kl, 50.0), }) else: self.KL_stats = { "mean_KL": tf.reduce_mean(kls[0]), "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), } # Initialize TFPolicyGraph diff --git a/python/ray/rllib/agents/impala/vtrace_test.py b/python/ray/rllib/agents/impala/vtrace_test.py index 145ed4e7a2cd..e1f39991b097 100644 --- a/python/ray/rllib/agents/impala/vtrace_test.py +++ b/python/ray/rllib/agents/impala/vtrace_test.py @@ -26,8 +26,10 @@ from absl.testing import parameterized import numpy as np -import tensorflow as tf import vtrace +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() def _shaped_arange(*shape): diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index 8ebb52353ede..caaaf512bcb1 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -399,14 +399,11 @@ def make_time_major(tensor, drop_last=False): self.KL_stats.update({ "mean_KL_{}".format(i): tf.reduce_mean(kl), "max_KL_{}".format(i): tf.reduce_max(kl), - "median_KL_{}".format(i): tf.contrib.distributions. - percentile(kl, 50.0), }) else: self.KL_stats = { "mean_KL": tf.reduce_mean(kls[0]), "max_KL": tf.reduce_max(kls[0]), - "median_KL": tf.contrib.distributions.percentile(kls[0], 50.0), } # Initialize TFPolicyGraph diff --git a/python/ray/rllib/agents/ppo/test/test.py b/python/ray/rllib/agents/ppo/test/test.py index 432b22f9aed2..1091b639c6f4 100644 --- a/python/ray/rllib/agents/ppo/test/test.py +++ b/python/ray/rllib/agents/ppo/test/test.py @@ -4,11 +4,13 @@ import unittest import numpy as np -import tensorflow as tf from numpy.testing import assert_allclose from ray.rllib.models.action_dist import Categorical from ray.rllib.agents.ppo.utils import flatten, concatenate +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() # TODO(ekl): move to rllib/models dir diff --git a/python/ray/rllib/examples/batch_norm_model.py b/python/ray/rllib/examples/batch_norm_model.py index 7852a62c2c24..c8a3fc83c0e4 100644 --- a/python/ray/rllib/examples/batch_norm_model.py +++ b/python/ray/rllib/examples/batch_norm_model.py @@ -5,13 +5,13 @@ import argparse -import tensorflow as tf -import tensorflow.contrib.slim as slim - import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.models.misc import normc_initializer +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--num-iters", type=int, default=200) @@ -24,21 +24,21 @@ def _build_layers_v2(self, input_dict, num_outputs, options): hiddens = [256, 256] for i, size in enumerate(hiddens): label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=tf.nn.tanh, - scope=label) + kernel_initializer=normc_initializer(1.0), + activation=tf.nn.tanh, + name=label) # Add a batch norm layer last_layer = tf.layers.batch_normalization( last_layer, training=input_dict["is_training"]) - output = slim.fully_connected( + output = tf.layers.dense( last_layer, num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") + kernel_initializer=normc_initializer(0.01), + 
activation=None, + name="fc_out") return output, last_layer diff --git a/python/ray/rllib/examples/carla/README b/python/ray/rllib/examples/carla/README deleted file mode 100644 index a066b048a2a1..000000000000 --- a/python/ray/rllib/examples/carla/README +++ /dev/null @@ -1,14 +0,0 @@ -(Experimental) OpenAI gym environment for https://github.com/carla-simulator/carla - -To run, first download and unpack the Carla binaries from this URL: https://github.com/carla-simulator/carla/releases/tag/0.7.0 - -Note that currently you also need to clone the Python code from `carla/benchmark_branch` which includes the Carla planner. - -Then, you can try running env.py to drive the car. Run one of the train_* scripts to attempt training. - - $ pkill -9 Carla - $ export CARLA_SERVER=/PATH/TO/CARLA_0.7.0/CarlaUE4.sh - $ export CARLA_PY_PATH=/PATH/TO/CARLA_BENCHMARK_BRANCH_REPO/PythonClient - $ python env.py - -Check out the scenarios.py file for different training and test scenarios that can be used. diff --git a/python/ray/rllib/examples/carla/env.py b/python/ray/rllib/examples/carla/env.py deleted file mode 100644 index af5b619afcdb..000000000000 --- a/python/ray/rllib/examples/carla/env.py +++ /dev/null @@ -1,684 +0,0 @@ -"""OpenAI gym environment for Carla. Run this file for a demo.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from datetime import datetime -import atexit -import cv2 -import os -import json -import random -import signal -import subprocess -import sys -import time -import traceback - -import numpy as np -try: - import scipy.misc -except Exception: - pass - -import gym -from gym.spaces import Box, Discrete, Tuple - -from scenarios import DEFAULT_SCENARIO - -# Set this where you want to save image outputs (or empty string to disable) -CARLA_OUT_PATH = os.environ.get("CARLA_OUT", os.path.expanduser("~/carla_out")) -if CARLA_OUT_PATH and not os.path.exists(CARLA_OUT_PATH): - os.makedirs(CARLA_OUT_PATH) - -# Set this to the path of your Carla binary -SERVER_BINARY = os.environ.get("CARLA_SERVER", - os.path.expanduser("~/CARLA_0.7.0/CarlaUE4.sh")) - -assert os.path.exists(SERVER_BINARY) -if "CARLA_PY_PATH" in os.environ: - sys.path.append(os.path.expanduser(os.environ["CARLA_PY_PATH"])) -else: - # TODO(ekl) switch this to the binary path once the planner is in master - sys.path.append(os.path.expanduser("~/carla/PythonClient/")) - -try: - from carla.client import CarlaClient - from carla.sensor import Camera - from carla.settings import CarlaSettings - from carla.planner.planner import Planner, REACH_GOAL, GO_STRAIGHT, \ - TURN_RIGHT, TURN_LEFT, LANE_FOLLOW -except Exception as e: - print("Failed to import Carla python libs, try setting $CARLA_PY_PATH") - raise e - -# Carla planner commands -COMMANDS_ENUM = { - REACH_GOAL: "REACH_GOAL", - GO_STRAIGHT: "GO_STRAIGHT", - TURN_RIGHT: "TURN_RIGHT", - TURN_LEFT: "TURN_LEFT", - LANE_FOLLOW: "LANE_FOLLOW", -} - -# Mapping from string repr to one-hot encoding index to feed to the model -COMMAND_ORDINAL = { - "REACH_GOAL": 0, - "GO_STRAIGHT": 1, - "TURN_RIGHT": 2, - "TURN_LEFT": 3, - "LANE_FOLLOW": 4, -} - -# Number of retries if the server doesn't respond -RETRIES_ON_ERROR = 5 - -# Dummy Z coordinate to use when we only care about (x, y) -GROUND_Z = 22 - -# Default environment configuration -ENV_CONFIG = { - "log_images": True, - "enable_planner": True, - "framestack": 2, # note: only [1, 2] currently supported - "convert_images_to_video": True, - 
"early_terminate_on_collision": True, - "verbose": True, - "reward_function": "custom", - "render_x_res": 800, - "render_y_res": 600, - "x_res": 80, - "y_res": 80, - "server_map": "/Game/Maps/Town02", - "scenarios": [DEFAULT_SCENARIO], - "use_depth_camera": False, - "discrete_actions": True, - "squash_action_logits": False, -} - -DISCRETE_ACTIONS = { - # coast - 0: [0.0, 0.0], - # turn left - 1: [0.0, -0.5], - # turn right - 2: [0.0, 0.5], - # forward - 3: [1.0, 0.0], - # brake - 4: [-0.5, 0.0], - # forward left - 5: [1.0, -0.5], - # forward right - 6: [1.0, 0.5], - # brake left - 7: [-0.5, -0.5], - # brake right - 8: [-0.5, 0.5], -} - -live_carla_processes = set() - - -def cleanup(): - print("Killing live carla processes", live_carla_processes) - for pgid in live_carla_processes: - os.killpg(pgid, signal.SIGKILL) - - -atexit.register(cleanup) - - -class CarlaEnv(gym.Env): - def __init__(self, config=ENV_CONFIG): - self.config = config - self.city = self.config["server_map"].split("/")[-1] - if self.config["enable_planner"]: - self.planner = Planner(self.city) - - if config["discrete_actions"]: - self.action_space = Discrete(len(DISCRETE_ACTIONS)) - else: - self.action_space = Box(-1.0, 1.0, shape=(2, ), dtype=np.float32) - if config["use_depth_camera"]: - image_space = Box( - -1.0, - 1.0, - shape=(config["y_res"], config["x_res"], - 1 * config["framestack"]), - dtype=np.float32) - else: - image_space = Box( - 0, - 255, - shape=(config["y_res"], config["x_res"], - 3 * config["framestack"]), - dtype=np.uint8) - self.observation_space = Tuple( # forward_speed, dist to goal - [ - image_space, - Discrete(len(COMMANDS_ENUM)), # next_command - Box(-128.0, 128.0, shape=(2, ), dtype=np.float32) - ]) - - # TODO(ekl) this isn't really a proper gym spec - self._spec = lambda: None - self._spec.id = "Carla-v0" - - self.server_port = None - self.server_process = None - self.client = None - self.num_steps = 0 - self.total_reward = 0 - self.prev_measurement = None - self.prev_image = None - self.episode_id = None - self.measurements_file = None - self.weather = None - self.scenario = None - self.start_pos = None - self.end_pos = None - self.start_coord = None - self.end_coord = None - self.last_obs = None - - def init_server(self): - print("Initializing new Carla server...") - # Create a new server process and start the client. 
- self.server_port = random.randint(10000, 60000) - self.server_process = subprocess.Popen( - [ - SERVER_BINARY, self.config["server_map"], "-windowed", - "-ResX=400", "-ResY=300", "-carla-server", - "-carla-world-port={}".format(self.server_port) - ], - preexec_fn=os.setsid, - stdout=open(os.devnull, "w")) - live_carla_processes.add(os.getpgid(self.server_process.pid)) - - for i in range(RETRIES_ON_ERROR): - try: - self.client = CarlaClient("localhost", self.server_port) - return self.client.connect() - except Exception as e: - print("Error connecting: {}, attempt {}".format(e, i)) - time.sleep(2) - - def clear_server_state(self): - print("Clearing Carla server state") - try: - if self.client: - self.client.disconnect() - self.client = None - except Exception as e: - print("Error disconnecting client: {}".format(e)) - pass - if self.server_process: - pgid = os.getpgid(self.server_process.pid) - os.killpg(pgid, signal.SIGKILL) - live_carla_processes.remove(pgid) - self.server_port = None - self.server_process = None - - def __del__(self): - self.clear_server_state() - - def reset(self): - error = None - for _ in range(RETRIES_ON_ERROR): - try: - if not self.server_process: - self.init_server() - return self._reset() - except Exception as e: - print("Error during reset: {}".format(traceback.format_exc())) - self.clear_server_state() - error = e - raise error - - def _reset(self): - self.num_steps = 0 - self.total_reward = 0 - self.prev_measurement = None - self.prev_image = None - self.episode_id = datetime.today().strftime("%Y-%m-%d_%H-%M-%S_%f") - self.measurements_file = None - - # Create a CarlaSettings object. This object is a wrapper around - # the CarlaSettings.ini file. Here we set the configuration we - # want for the new episode. - settings = CarlaSettings() - self.scenario = random.choice(self.config["scenarios"]) - assert self.scenario["city"] == self.city, (self.scenario, self.city) - self.weather = random.choice(self.scenario["weather_distribution"]) - settings.set( - SynchronousMode=True, - SendNonPlayerAgentsInfo=True, - NumberOfVehicles=self.scenario["num_vehicles"], - NumberOfPedestrians=self.scenario["num_pedestrians"], - WeatherId=self.weather) - settings.randomize_seeds() - - if self.config["use_depth_camera"]: - camera1 = Camera("CameraDepth", PostProcessing="Depth") - camera1.set_image_size(self.config["render_x_res"], - self.config["render_y_res"]) - camera1.set_position(30, 0, 130) - settings.add_sensor(camera1) - - camera2 = Camera("CameraRGB") - camera2.set_image_size(self.config["render_x_res"], - self.config["render_y_res"]) - camera2.set_position(30, 0, 130) - settings.add_sensor(camera2) - - # Setup start and end positions - scene = self.client.load_settings(settings) - positions = scene.player_start_spots - self.start_pos = positions[self.scenario["start_pos_id"]] - self.end_pos = positions[self.scenario["end_pos_id"]] - self.start_coord = [ - self.start_pos.location.x // 100, self.start_pos.location.y // 100 - ] - self.end_coord = [ - self.end_pos.location.x // 100, self.end_pos.location.y // 100 - ] - print("Start pos {} ({}), end {} ({})".format( - self.scenario["start_pos_id"], self.start_coord, - self.scenario["end_pos_id"], self.end_coord)) - - # Notify the server that we want to start the episode at the - # player_start index. This function blocks until the server is ready - # to start the episode. 
- print("Starting new episode...") - self.client.start_episode(self.scenario["start_pos_id"]) - - image, py_measurements = self._read_observation() - self.prev_measurement = py_measurements - return self.encode_obs(self.preprocess_image(image), py_measurements) - - def encode_obs(self, image, py_measurements): - assert self.config["framestack"] in [1, 2] - prev_image = self.prev_image - self.prev_image = image - if prev_image is None: - prev_image = image - if self.config["framestack"] == 2: - image = np.concatenate([prev_image, image], axis=2) - obs = (image, COMMAND_ORDINAL[py_measurements["next_command"]], [ - py_measurements["forward_speed"], - py_measurements["distance_to_goal"] - ]) - self.last_obs = obs - return obs - - def step(self, action): - try: - obs = self._step(action) - return obs - except Exception: - print("Error during step, terminating episode early", - traceback.format_exc()) - self.clear_server_state() - return (self.last_obs, 0.0, True, {}) - - def _step(self, action): - if self.config["discrete_actions"]: - action = DISCRETE_ACTIONS[int(action)] - assert len(action) == 2, "Invalid action {}".format(action) - if self.config["squash_action_logits"]: - forward = 2 * float(sigmoid(action[0]) - 0.5) - throttle = float(np.clip(forward, 0, 1)) - brake = float(np.abs(np.clip(forward, -1, 0))) - steer = 2 * float(sigmoid(action[1]) - 0.5) - else: - throttle = float(np.clip(action[0], 0, 1)) - brake = float(np.abs(np.clip(action[0], -1, 0))) - steer = float(np.clip(action[1], -1, 1)) - reverse = False - hand_brake = False - - if self.config["verbose"]: - print("steer", steer, "throttle", throttle, "brake", brake, - "reverse", reverse) - - self.client.send_control( - steer=steer, - throttle=throttle, - brake=brake, - hand_brake=hand_brake, - reverse=reverse) - - # Process observations - image, py_measurements = self._read_observation() - if self.config["verbose"]: - print("Next command", py_measurements["next_command"]) - if type(action) is np.ndarray: - py_measurements["action"] = [float(a) for a in action] - else: - py_measurements["action"] = action - py_measurements["control"] = { - "steer": steer, - "throttle": throttle, - "brake": brake, - "reverse": reverse, - "hand_brake": hand_brake, - } - reward = compute_reward(self, self.prev_measurement, py_measurements) - self.total_reward += reward - py_measurements["reward"] = reward - py_measurements["total_reward"] = self.total_reward - done = (self.num_steps > self.scenario["max_steps"] - or py_measurements["next_command"] == "REACH_GOAL" - or (self.config["early_terminate_on_collision"] - and collided_done(py_measurements))) - py_measurements["done"] = done - self.prev_measurement = py_measurements - - # Write out measurements to file - if CARLA_OUT_PATH: - if not self.measurements_file: - self.measurements_file = open( - os.path.join( - CARLA_OUT_PATH, - "measurements_{}.json".format(self.episode_id)), "w") - self.measurements_file.write(json.dumps(py_measurements)) - self.measurements_file.write("\n") - if done: - self.measurements_file.close() - self.measurements_file = None - if self.config["convert_images_to_video"]: - self.images_to_video() - - self.num_steps += 1 - image = self.preprocess_image(image) - return (self.encode_obs(image, py_measurements), reward, done, - py_measurements) - - def images_to_video(self): - videos_dir = os.path.join(CARLA_OUT_PATH, "Videos") - if not os.path.exists(videos_dir): - os.makedirs(videos_dir) - ffmpeg_cmd = ( - "ffmpeg -loglevel -8 -r 60 -f image2 -s {x_res}x{y_res} " - 
"-start_number 0 -i " - "{img}_%04d.jpg -vcodec libx264 {vid}.mp4 && rm -f {img}_*.jpg " - ).format( - x_res=self.config["render_x_res"], - y_res=self.config["render_y_res"], - vid=os.path.join(videos_dir, self.episode_id), - img=os.path.join(CARLA_OUT_PATH, "CameraRGB", self.episode_id)) - print("Executing ffmpeg command", ffmpeg_cmd) - subprocess.call(ffmpeg_cmd, shell=True) - - def preprocess_image(self, image): - if self.config["use_depth_camera"]: - assert self.config["use_depth_camera"] - data = (image.data - 0.5) * 2 - data = data.reshape(self.config["render_y_res"], - self.config["render_x_res"], 1) - data = cv2.resize( - data, (self.config["x_res"], self.config["y_res"]), - interpolation=cv2.INTER_AREA) - data = np.expand_dims(data, 2) - else: - data = image.data.reshape(self.config["render_y_res"], - self.config["render_x_res"], 3) - data = cv2.resize( - data, (self.config["x_res"], self.config["y_res"]), - interpolation=cv2.INTER_AREA) - data = (data.astype(np.float32) - 128) / 128 - return data - - def _read_observation(self): - # Read the data produced by the server this frame. - measurements, sensor_data = self.client.read_data() - - # Print some of the measurements. - if self.config["verbose"]: - print_measurements(measurements) - - observation = None - if self.config["use_depth_camera"]: - camera_name = "CameraDepth" - else: - camera_name = "CameraRGB" - for name, image in sensor_data.items(): - if name == camera_name: - observation = image - - cur = measurements.player_measurements - - if self.config["enable_planner"]: - next_command = COMMANDS_ENUM[self.planner.get_next_command( - [cur.transform.location.x, cur.transform.location.y, GROUND_Z], - [ - cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z - ], - [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ - self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z - ])] - else: - next_command = "LANE_FOLLOW" - - if next_command == "REACH_GOAL": - distance_to_goal = 0.0 # avoids crash in planner - elif self.config["enable_planner"]: - distance_to_goal = self.planner.get_shortest_path_distance([ - cur.transform.location.x, cur.transform.location.y, GROUND_Z - ], [ - cur.transform.orientation.x, cur.transform.orientation.y, - GROUND_Z - ], [self.end_pos.location.x, self.end_pos.location.y, GROUND_Z], [ - self.end_pos.orientation.x, self.end_pos.orientation.y, - GROUND_Z - ]) / 100 - else: - distance_to_goal = -1 - - distance_to_goal_euclidean = float( - np.linalg.norm([ - cur.transform.location.x - self.end_pos.location.x, - cur.transform.location.y - self.end_pos.location.y - ]) / 100) - - py_measurements = { - "episode_id": self.episode_id, - "step": self.num_steps, - "x": cur.transform.location.x, - "y": cur.transform.location.y, - "x_orient": cur.transform.orientation.x, - "y_orient": cur.transform.orientation.y, - "forward_speed": cur.forward_speed, - "distance_to_goal": distance_to_goal, - "distance_to_goal_euclidean": distance_to_goal_euclidean, - "collision_vehicles": cur.collision_vehicles, - "collision_pedestrians": cur.collision_pedestrians, - "collision_other": cur.collision_other, - "intersection_offroad": cur.intersection_offroad, - "intersection_otherlane": cur.intersection_otherlane, - "weather": self.weather, - "map": self.config["server_map"], - "start_coord": self.start_coord, - "end_coord": self.end_coord, - "current_scenario": self.scenario, - "x_res": self.config["x_res"], - "y_res": self.config["y_res"], - "num_vehicles": self.scenario["num_vehicles"], - 
"num_pedestrians": self.scenario["num_pedestrians"], - "max_steps": self.scenario["max_steps"], - "next_command": next_command, - } - - if CARLA_OUT_PATH and self.config["log_images"]: - for name, image in sensor_data.items(): - out_dir = os.path.join(CARLA_OUT_PATH, name) - if not os.path.exists(out_dir): - os.makedirs(out_dir) - out_file = os.path.join( - out_dir, "{}_{:>04}.jpg".format(self.episode_id, - self.num_steps)) - scipy.misc.imsave(out_file, image.data) - - assert observation is not None, sensor_data - return observation, py_measurements - - -def compute_reward_corl2017(env, prev, current): - reward = 0.0 - - cur_dist = current["distance_to_goal"] - - prev_dist = prev["distance_to_goal"] - - if env.config["verbose"]: - print("Cur dist {}, prev dist {}".format(cur_dist, prev_dist)) - - # Distance travelled toward the goal in m - reward += np.clip(prev_dist - cur_dist, -10.0, 10.0) - - # Change in speed (km/h) - reward += 0.05 * (current["forward_speed"] - prev["forward_speed"]) - - # New collision damage - reward -= .00002 * ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - prev["collision_pedestrians"] - prev["collision_other"]) - - # New sidewalk intersection - reward -= 2 * ( - current["intersection_offroad"] - prev["intersection_offroad"]) - - # New opposite lane intersection - reward -= 2 * ( - current["intersection_otherlane"] - prev["intersection_otherlane"]) - - return reward - - -def compute_reward_custom(env, prev, current): - reward = 0.0 - - cur_dist = current["distance_to_goal"] - prev_dist = prev["distance_to_goal"] - - if env.config["verbose"]: - print("Cur dist {}, prev dist {}".format(cur_dist, prev_dist)) - - # Distance travelled toward the goal in m - reward += np.clip(prev_dist - cur_dist, -10.0, 10.0) - - # Speed reward, up 30.0 (km/h) - reward += np.clip(current["forward_speed"], 0.0, 30.0) / 10 - - # New collision damage - new_damage = ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - prev["collision_pedestrians"] - prev["collision_other"]) - if new_damage: - reward -= 100.0 - - # Sidewalk intersection - reward -= current["intersection_offroad"] - - # Opposite lane intersection - reward -= current["intersection_otherlane"] - - # Reached goal - if current["next_command"] == "REACH_GOAL": - reward += 100.0 - - return reward - - -def compute_reward_lane_keep(env, prev, current): - reward = 0.0 - - # Speed reward, up 30.0 (km/h) - reward += np.clip(current["forward_speed"], 0.0, 30.0) / 10 - - # New collision damage - new_damage = ( - current["collision_vehicles"] + current["collision_pedestrians"] + - current["collision_other"] - prev["collision_vehicles"] - - prev["collision_pedestrians"] - prev["collision_other"]) - if new_damage: - reward -= 100.0 - - # Sidewalk intersection - reward -= current["intersection_offroad"] - - # Opposite lane intersection - reward -= current["intersection_otherlane"] - - return reward - - -REWARD_FUNCTIONS = { - "corl2017": compute_reward_corl2017, - "custom": compute_reward_custom, - "lane_keep": compute_reward_lane_keep, -} - - -def compute_reward(env, prev, current): - return REWARD_FUNCTIONS[env.config["reward_function"]](env, prev, current) - - -def print_measurements(measurements): - number_of_agents = len(measurements.non_player_agents) - player_measurements = measurements.player_measurements - message = "Vehicle at ({pos_x:.1f}, {pos_y:.1f}), " - message += 
"{speed:.2f} km/h, " - message += "Collision: {{vehicles={col_cars:.0f}, " - message += "pedestrians={col_ped:.0f}, other={col_other:.0f}}}, " - message += "{other_lane:.0f}% other lane, {offroad:.0f}% off-road, " - message += "({agents_num:d} non-player agents in the scene)" - message = message.format( - pos_x=player_measurements.transform.location.x / 100, # cm -> m - pos_y=player_measurements.transform.location.y / 100, - speed=player_measurements.forward_speed, - col_cars=player_measurements.collision_vehicles, - col_ped=player_measurements.collision_pedestrians, - col_other=player_measurements.collision_other, - other_lane=100 * player_measurements.intersection_otherlane, - offroad=100 * player_measurements.intersection_offroad, - agents_num=number_of_agents) - print(message) - - -def sigmoid(x): - x = float(x) - return np.exp(x) / (1 + np.exp(x)) - - -def collided_done(py_measurements): - m = py_measurements - collided = (m["collision_vehicles"] > 0 or m["collision_pedestrians"] > 0 - or m["collision_other"] > 0) - return bool(collided or m["total_reward"] < -100) - - -if __name__ == "__main__": - for _ in range(2): - env = CarlaEnv() - obs = env.reset() - print("reset", obs) - start = time.time() - done = False - i = 0 - total_reward = 0.0 - while not done: - i += 1 - if ENV_CONFIG["discrete_actions"]: - obs, reward, done, info = env.step(1) - else: - obs, reward, done, info = env.step([0, 1, 0]) - total_reward += reward - print(i, "rew", reward, "total", total_reward, "done", done) - print("{} fps".format(100 / (time.time() - start))) diff --git a/python/ray/rllib/examples/carla/models.py b/python/ray/rllib/examples/carla/models.py deleted file mode 100644 index 3f8cc0c5ba47..000000000000 --- a/python/ray/rllib/examples/carla/models.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers import xavier_initializer - -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.misc import normc_initializer -from ray.rllib.models.model import Model - - -class CarlaModel(Model): - """Carla model that can process the observation tuple. - - The architecture processes the image using convolutional layers, the - metrics using fully connected layers, and then combines them with - further fully connected layers. 
- """ - - # TODO(ekl): use build_layers_v2 for native dict space support - def _build_layers(self, inputs, num_outputs, options): - # Parse options - image_shape = options["custom_options"]["image_shape"] - convs = options.get("conv_filters", [ - [16, [8, 8], 4], - [32, [5, 5], 3], - [32, [5, 5], 2], - [512, [10, 10], 1], - ]) - hiddens = options.get("fcnet_hiddens", [64]) - fcnet_activation = options.get("fcnet_activation", "tanh") - if fcnet_activation == "tanh": - activation = tf.nn.tanh - elif fcnet_activation == "relu": - activation = tf.nn.relu - - # Sanity checks - image_size = np.product(image_shape) - expected_shape = [image_size + 5 + 2] - assert inputs.shape.as_list()[1:] == expected_shape, \ - (inputs.shape.as_list()[1:], expected_shape) - - # Reshape the input vector back into its components - vision_in = tf.reshape(inputs[:, :image_size], - [tf.shape(inputs)[0]] + image_shape) - metrics_in = inputs[:, image_size:] - print("Vision in shape", vision_in) - print("Metrics in shape", metrics_in) - - # Setup vision layers - with tf.name_scope("carla_vision"): - for i, (out_size, kernel, stride) in enumerate(convs[:-1], 1): - vision_in = slim.conv2d( - vision_in, - out_size, - kernel, - stride, - scope="conv{}".format(i)) - out_size, kernel, stride = convs[-1] - vision_in = slim.conv2d( - vision_in, - out_size, - kernel, - stride, - padding="VALID", - scope="conv_out") - vision_in = tf.squeeze(vision_in, [1, 2]) - - # Setup metrics layer - with tf.name_scope("carla_metrics"): - metrics_in = slim.fully_connected( - metrics_in, - 64, - weights_initializer=xavier_initializer(), - activation_fn=activation, - scope="metrics_out") - - print("Shape of vision out is", vision_in.shape) - print("Shape of metric out is", metrics_in.shape) - - # Combine the metrics and vision inputs - with tf.name_scope("carla_out"): - i = 1 - last_layer = tf.concat([vision_in, metrics_in], axis=1) - print("Shape of concatenated out is", last_layer.shape) - for size in hiddens: - last_layer = slim.fully_connected( - last_layer, - size, - weights_initializer=xavier_initializer(), - activation_fn=activation, - scope="fc{}".format(i)) - i += 1 - output = slim.fully_connected( - last_layer, - num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") - - return output, last_layer - - -def register_carla_model(): - ModelCatalog.register_custom_model("carla", CarlaModel) diff --git a/python/ray/rllib/examples/carla/scenarios.py b/python/ray/rllib/examples/carla/scenarios.py deleted file mode 100644 index beedd2989d5c..000000000000 --- a/python/ray/rllib/examples/carla/scenarios.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Collection of Carla scenarios, including those from the CoRL 2017 paper.""" - -TEST_WEATHERS = [0, 2, 5, 7, 9, 10, 11, 12, 13] -TRAIN_WEATHERS = [1, 3, 4, 6, 8, 14] - - -def build_scenario(city, start, end, vehicles, pedestrians, max_steps, - weathers): - return { - "city": city, - "num_vehicles": vehicles, - "num_pedestrians": pedestrians, - "weather_distribution": weathers, - "start_pos_id": start, - "end_pos_id": end, - "max_steps": max_steps, - } - - -# Simple scenario for Town02 that involves driving down a road -DEFAULT_SCENARIO = build_scenario( - city="Town02", - start=36, - end=40, - vehicles=20, - pedestrians=40, - max_steps=200, - weathers=[0]) - -# Simple scenario for Town02 that involves driving down a road -LANE_KEEP = build_scenario( - city="Town02", - start=36, - end=40, - vehicles=0, - pedestrians=0, - max_steps=2000, - weathers=[0]) - -# Scenarios 
from the CoRL2017 paper -POSES_TOWN1_STRAIGHT = [[36, 40], [39, 35], [110, 114], [7, 3], [0, 4], [ - 68, 50 -], [61, 59], [47, 64], [147, 90], [33, 87], [26, 19], [80, 76], [45, 49], [ - 55, 44 -], [29, 107], [95, 104], [84, 34], [53, 67], [22, 17], [91, 148], [20, 107], - [78, 70], [95, 102], [68, 44], [45, 69]] - -POSES_TOWN1_ONE_CURVE = [[138, 17], [47, 16], [26, 9], [42, 49], [140, 124], [ - 85, 98 -], [65, 133], [137, 51], [76, 66], [46, 39], [40, 60], [0, 29], [4, 129], [ - 121, 140 -], [2, 129], [78, 44], [68, 85], [41, 102], [95, 70], [68, 129], [84, 69], - [47, 79], [110, 15], [130, 17], [0, 17]] - -POSES_TOWN1_NAV = [[105, 29], [27, 130], [102, 87], [132, 27], [24, 44], [ - 96, 26 -], [34, 67], [28, 1], [140, 134], [105, 9], [148, 129], [65, 18], [21, 16], [ - 147, 97 -], [42, 51], [30, 41], [18, 107], [69, 45], [102, 95], [18, 145], [111, 64], - [79, 45], [84, 69], [73, 31], [37, 81]] - -POSES_TOWN2_STRAIGHT = [[38, 34], [4, 2], [12, 10], [62, 55], [43, 47], [ - 64, 66 -], [78, 76], [59, 57], [61, 18], [35, 39], [12, 8], [0, 18], [75, 68], [ - 54, 60 -], [45, 49], [46, 42], [53, 46], [80, 29], [65, 63], [0, 81], [54, 63], - [51, 42], [16, 19], [17, 26], [77, 68]] - -POSES_TOWN2_ONE_CURVE = [[37, 76], [8, 24], [60, 69], [38, 10], [21, 1], [ - 58, 71 -], [74, 32], [44, 0], [71, 16], [14, 24], [34, 11], [43, 14], [75, 16], [ - 80, 21 -], [3, 23], [75, 59], [50, 47], [11, 19], [77, 34], [79, 25], [40, 63], - [58, 76], [79, 55], [16, 61], [27, 11]] - -POSES_TOWN2_NAV = [[19, 66], [79, 14], [19, 57], [23, 1], [53, 76], [42, 13], [ - 31, 71 -], [33, 5], [54, 30], [10, 61], [66, 3], [27, 12], [79, 19], [2, 29], [16, 14], - [5, 57], [70, 73], [46, 67], [57, 50], [61, 49], [21, 12], - [51, 81], [77, 68], [56, 65], [43, 54]] - -TOWN1_STRAIGHT = [ - build_scenario("Town01", start, end, 0, 0, 300, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_STRAIGHT -] - -TOWN1_ONE_CURVE = [ - build_scenario("Town01", start, end, 0, 0, 600, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_ONE_CURVE -] - -TOWN1_NAVIGATION = [ - build_scenario("Town01", start, end, 0, 0, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV -] - -TOWN1_NAVIGATION_DYNAMIC = [ - build_scenario("Town01", start, end, 20, 50, 900, TEST_WEATHERS) - for (start, end) in POSES_TOWN1_NAV -] - -TOWN2_STRAIGHT = [ - build_scenario("Town02", start, end, 0, 0, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT -] - -TOWN2_STRAIGHT_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 300, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_STRAIGHT -] - -TOWN2_ONE_CURVE = [ - build_scenario("Town02", start, end, 0, 0, 600, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_ONE_CURVE -] - -TOWN2_NAVIGATION = [ - build_scenario("Town02", start, end, 0, 0, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV -] - -TOWN2_NAVIGATION_DYNAMIC = [ - build_scenario("Town02", start, end, 20, 50, 900, TRAIN_WEATHERS) - for (start, end) in POSES_TOWN2_NAV -] - -TOWN1_ALL = (TOWN1_STRAIGHT + TOWN1_ONE_CURVE + TOWN1_NAVIGATION + - TOWN1_NAVIGATION_DYNAMIC) - -TOWN2_ALL = (TOWN2_STRAIGHT + TOWN2_ONE_CURVE + TOWN2_NAVIGATION + - TOWN2_NAVIGATION_DYNAMIC) diff --git a/python/ray/rllib/examples/carla/train_a3c.py b/python/ray/rllib/examples/carla/train_a3c.py deleted file mode 100644 index 8fbcfbc576d1..000000000000 --- a/python/ray/rllib/examples/carla/train_a3c.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from 
ray.tune import grid_search, run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_STRAIGHT - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "squash_action_logits": grid_search([False, True]), - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "reward_function": grid_search(["custom", "corl2017"]), - "scenarios": TOWN2_STRAIGHT, -}) - -register_carla_model() -redis_address = ray.services.get_node_ip_address() + ":6379" - -ray.init(redis_address=redis_address) -run_experiments({ - "carla-a3c": { - "run": "A3C", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "use_gpu_for_workers": True, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [80, 80, 6], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "gamma": 0.95, - "num_workers": 2, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_dqn.py b/python/ray/rllib/examples/carla/train_dqn.py deleted file mode 100644 index 27aa65444d38..000000000000 --- a/python/ray/rllib/examples/carla/train_dqn.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_ONE_CURVE - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "discrete_actions": True, - "server_map": "/Game/Maps/Town02", - "reward_function": "custom", - "scenarios": TOWN2_ONE_CURVE, -}) - -register_carla_model() - -ray.init() - - -def shape_out(spec): - return (spec.config.env_config.framestack * - (spec.config.env_config.use_depth_camera and 1 or 3)) - - -run_experiments({ - "carla-dqn": { - "run": "DQN", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [ - 80, - 80, - shape_out, - ], - }, - "conv_filters": [ - [16, [8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "timesteps_per_iteration": 100, - "learning_starts": 1000, - "schedule_max_timesteps": 100000, - "gamma": 0.8, - "tf_session_args": { - "gpu_options": { - "allow_growth": True - }, - }, - }, - }, -}) diff --git a/python/ray/rllib/examples/carla/train_ppo.py b/python/ray/rllib/examples/carla/train_ppo.py deleted file mode 100644 index 130acf3a5849..000000000000 --- a/python/ray/rllib/examples/carla/train_ppo.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.tune import run_experiments - -from env import CarlaEnv, ENV_CONFIG -from models import register_carla_model -from scenarios import TOWN2_STRAIGHT - -env_config = ENV_CONFIG.copy() -env_config.update({ - "verbose": False, - "x_res": 80, - "y_res": 80, - "use_depth_camera": False, - "discrete_actions": False, - "server_map": "/Game/Maps/Town02", - "scenarios": TOWN2_STRAIGHT, -}) -register_carla_model() - -ray.init() -run_experiments({ - "carla": { - "run": "PPO", - "env": CarlaEnv, - "config": { - "env_config": env_config, - "model": { - "custom_model": "carla", - "custom_options": { - "image_shape": [ - env_config["x_res"], env_config["y_res"], 6 - ], - }, - "conv_filters": [ - [16, 
[8, 8], 4], - [32, [4, 4], 2], - [512, [10, 10], 1], - ], - }, - "num_workers": 1, - "train_batch_size": 2000, - "sample_batch_size": 100, - "lambda": 0.95, - "clip_param": 0.2, - "num_sgd_iter": 20, - "lr": 0.0001, - "sgd_minibatch_size": 32, - "num_gpus": 1, - }, - }, -}) diff --git a/python/ray/rllib/examples/custom_fast_model.py b/python/ray/rllib/examples/custom_fast_model.py index 86201c87da7a..dce01e9e7754 100644 --- a/python/ray/rllib/examples/custom_fast_model.py +++ b/python/ray/rllib/examples/custom_fast_model.py @@ -11,11 +11,13 @@ from gym.spaces import Discrete, Box import gym import numpy as np -import tensorflow as tf import ray from ray.rllib.models import Model, ModelCatalog from ray.tune import run_experiments, sample_from +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class FastModel(Model): diff --git a/python/ray/rllib/examples/custom_loss.py b/python/ray/rllib/examples/custom_loss.py index 1f04f0fb5a6e..8905b48952da 100644 --- a/python/ray/rllib/examples/custom_loss.py +++ b/python/ray/rllib/examples/custom_loss.py @@ -15,7 +15,6 @@ import argparse import os -import tensorflow as tf import ray from ray import tune @@ -23,6 +22,9 @@ ModelCatalog) from ray.rllib.models.model import restore_original_dimensions from ray.rllib.offline import JsonReader +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--iters", type=int, default=200) diff --git a/python/ray/rllib/examples/export/cartpole_dqn_export.py b/python/ray/rllib/examples/export/cartpole_dqn_export.py index 6bfcae060d13..47a5e3b41ea7 100644 --- a/python/ray/rllib/examples/export/cartpole_dqn_export.py +++ b/python/ray/rllib/examples/export/cartpole_dqn_export.py @@ -6,9 +6,11 @@ import os import ray -import tensorflow as tf from ray.rllib.agents.registry import get_agent_class +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() ray.init(num_cpus=10) diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index d7485e27a0c6..6e0f93711540 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -16,14 +16,14 @@ import gym import random -import tensorflow as tf -import tensorflow.contrib.slim as slim - import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() @@ -43,12 +43,12 @@ def _build_layers_v2(self, input_dict, num_outputs, options): tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False): - last_layer = slim.fully_connected( - input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") - last_layer = slim.fully_connected( - last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") - output = slim.fully_connected( - last_layer, num_outputs, activation_fn=None, scope="fc_out") + last_layer = tf.layers.dense( + input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") + last_layer = tf.layers.dense( + last_layer, 64, activation=tf.nn.relu, name="fc2") + output = tf.layers.dense( + last_layer, num_outputs, activation=None, name="fc_out") return output, last_layer @@ -59,12 +59,12 @@ def _build_layers_v2(self, input_dict, num_outputs, options): tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, 
auxiliary_name_scope=False): - last_layer = slim.fully_connected( - input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1") - last_layer = slim.fully_connected( - last_layer, 64, activation_fn=tf.nn.relu, scope="fc2") - output = slim.fully_connected( - last_layer, num_outputs, activation_fn=None, scope="fc_out") + last_layer = tf.layers.dense( + input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") + last_layer = tf.layers.dense( + last_layer, 64, activation=tf.nn.relu, name="fc2") + output = tf.layers.dense( + last_layer, num_outputs, activation=None, name="fc_out") return output, last_layer diff --git a/python/ray/rllib/examples/parametric_action_cartpole.py b/python/ray/rllib/examples/parametric_action_cartpole.py index 3d57c268cae3..e16e1ab75870 100644 --- a/python/ray/rllib/examples/parametric_action_cartpole.py +++ b/python/ray/rllib/examples/parametric_action_cartpole.py @@ -23,14 +23,15 @@ import numpy as np import gym from gym.spaces import Box, Discrete, Dict -import tensorflow as tf -import tensorflow.contrib.slim as slim import ray from ray import tune from ray.rllib.models import Model, ModelCatalog from ray.rllib.models.misc import normc_initializer from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() parser = argparse.ArgumentParser() parser.add_argument("--stop", type=int, default=200) @@ -134,18 +135,18 @@ def _build_layers_v2(self, input_dict, num_outputs, options): hiddens = [256, 256] for i, size in enumerate(hiddens): label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=tf.nn.tanh, - scope=label) - output = slim.fully_connected( + kernel_initializer=normc_initializer(1.0), + activation=tf.nn.tanh, + name=label) + output = tf.layers.dense( last_layer, action_embed_size, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope="fc_out") + kernel_initializer=normc_initializer(0.01), + activation=None, + name="fc_out") # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE]. diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 1cad7d3aa9ac..9cf58b9dd317 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -12,7 +12,11 @@ tf = try_import_tf() if tf: - use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >= + if hasattr(tf, "__version__"): + version = tf.__version__ + else: + version = tf.VERSION + use_tf150_api = (distutils.version.LooseVersion(version) >= distutils.version.LooseVersion("1.5.0")) else: use_tf150_api = False diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 3cc0fbe403c5..c3bacbd46a7d 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -21,8 +21,6 @@ def _build_layers(self, inputs, num_outputs, options): model that processes the components separately, use _build_layers_v2(). 
""" - import tensorflow.contrib.slim as slim - hiddens = options.get("fcnet_hiddens") activation = get_activation_fn(options.get("fcnet_activation")) @@ -31,18 +29,18 @@ def _build_layers(self, inputs, num_outputs, options): last_layer = inputs for size in hiddens: label = "fc{}".format(i) - last_layer = slim.fully_connected( + last_layer = tf.layers.dense( last_layer, size, - weights_initializer=normc_initializer(1.0), - activation_fn=activation, - scope=label) + kernel_initializer=normc_initializer(1.0), + activation=activation, + name=label) i += 1 label = "fc_out" - output = slim.fully_connected( + output = tf.layers.dense( last_layer, num_outputs, - weights_initializer=normc_initializer(0.01), - activation_fn=None, - scope=label) + kernel_initializer=normc_initializer(0.01), + activation=None, + name=label) return output, last_layer diff --git a/python/ray/rllib/models/lstm.py b/python/ray/rllib/models/lstm.py index 5b9328c3c463..62b854a86ed9 100644 --- a/python/ray/rllib/models/lstm.py +++ b/python/ray/rllib/models/lstm.py @@ -38,8 +38,6 @@ class LSTM(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): - import tensorflow.contrib.rnn as rnn - cell_size = options.get("lstm_cell_size") if options.get("lstm_use_prev_action_reward"): action_dim = int( @@ -76,7 +74,7 @@ def _build_layers_v2(self, input_dict, num_outputs, options): self.state_in = [c_in, h_in] # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) + state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in) lstm_out, lstm_state = tf.nn.dynamic_rnn( lstm, last_layer, diff --git a/python/ray/rllib/models/visionnet.py b/python/ray/rllib/models/visionnet.py index 53eaf5d02c3f..6ad30ddb90c4 100644 --- a/python/ray/rllib/models/visionnet.py +++ b/python/ray/rllib/models/visionnet.py @@ -15,8 +15,6 @@ class VisionNetwork(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): - import tensorflow.contrib.slim as slim - inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: @@ -26,28 +24,29 @@ def _build_layers_v2(self, input_dict, num_outputs, options): with tf.name_scope("vision_net"): for i, (out_size, kernel, stride) in enumerate(filters[:-1], 1): - inputs = slim.conv2d( + inputs = tf.layers.conv2d( inputs, out_size, kernel, stride, - activation_fn=activation, - scope="conv{}".format(i)) + activation=activation, + padding="same", + name="conv{}".format(i)) out_size, kernel, stride = filters[-1] - fc1 = slim.conv2d( + fc1 = tf.layers.conv2d( inputs, out_size, kernel, stride, - activation_fn=activation, - padding="VALID", - scope="fc1") - fc2 = slim.conv2d( + activation=activation, + padding="valid", + name="fc1") + fc2 = tf.layers.conv2d( fc1, num_outputs, [1, 1], - activation_fn=None, - normalizer_fn=None, - scope="fc2") + activation=None, + padding="same", + name="fc2") return flatten(fc2), flatten(fc1) diff --git a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py index a584be7e6c53..328fee67d548 100644 --- a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py +++ b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py @@ -17,6 +17,9 @@ from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() logger = logging.getLogger(__name__) @@ -38,9 +41,6 @@ def __init__(self, learner_queue_size=16, 
num_data_load_threads=16, _fake_gpus=False): - # Multi-GPU requires TensorFlow to function. - import tensorflow as tf - LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.lr = lr diff --git a/python/ray/rllib/tests/test_catalog.py b/python/ray/rllib/tests/test_catalog.py index fe89152c6cbd..1c93b40ed484 100644 --- a/python/ray/rllib/tests/test_catalog.py +++ b/python/ray/rllib/tests/test_catalog.py @@ -1,6 +1,5 @@ import gym import numpy as np -import tensorflow as tf import unittest from gym.spaces import Box, Discrete, Tuple @@ -12,6 +11,9 @@ Preprocessor) from ray.rllib.models.fcnet import FullyConnectedNetwork from ray.rllib.models.visionnet import VisionNetwork +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class CustomPreprocessor(Preprocessor): diff --git a/python/ray/rllib/tests/test_lstm.py b/python/ray/rllib/tests/test_lstm.py index 385f2d7bc1ba..dd9c7ccd9d86 100644 --- a/python/ray/rllib/tests/test_lstm.py +++ b/python/ray/rllib/tests/test_lstm.py @@ -6,8 +6,6 @@ import numpy as np import pickle import unittest -import tensorflow as tf -import tensorflow.contrib.rnn as rnn import ray from ray.rllib.agents.ppo import PPOTrainer @@ -16,6 +14,9 @@ from ray.rllib.models.misc import linear, normc_initializer from ray.rllib.models.model import Model from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class LSTMUtilsTest(unittest.TestCase): @@ -104,7 +105,7 @@ def spy(sequences, state_in, state_out, seq_lens): last_layer = add_time_dimension(features, self.seq_lens) # Setup the LSTM cell - lstm = rnn.BasicLSTMCell(cell_size, state_is_tuple=True) + lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) self.state_init = [ np.zeros(lstm.state_size.c, np.float32), np.zeros(lstm.state_size.h, np.float32) @@ -121,7 +122,7 @@ def spy(sequences, state_in, state_out, seq_lens): self.state_in = [c_in, h_in] # Setup LSTM outputs - state_in = rnn.LSTMStateTuple(c_in, h_in) + state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in) lstm_out, lstm_state = tf.nn.dynamic_rnn( lstm, last_layer, diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index dc45ca3f605e..e4285e42287c 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -7,8 +7,6 @@ from gym import spaces from gym.envs.registration import EnvSpec import gym -import tensorflow.contrib.slim as slim -import tensorflow as tf import unittest import ray @@ -25,6 +23,9 @@ from ray.rllib.rollout import rollout from ray.rllib.tests.test_external_env import SimpleServing from ray.tune.registry import register_env +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() DICT_SPACE = spaces.Dict({ "sensors": spaces.Dict({ @@ -179,8 +180,8 @@ def spy(pos, front_cam, task): stateful=True) with tf.control_dependencies([spy_fn]): - output = slim.fully_connected( - input_dict["obs"]["sensors"]["position"], num_outputs) + output = tf.layers.dense(input_dict["obs"]["sensors"]["position"], + num_outputs) return output, output @@ -208,7 +209,7 @@ def spy(pos, cam, task): stateful=True) with tf.control_dependencies([spy_fn]): - output = slim.fully_connected(input_dict["obs"][0], num_outputs) + output = tf.layers.dense(input_dict["obs"][0], num_outputs) return output, output diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 
65992a220ba2..9c9e6b56b426 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -4,7 +4,6 @@ import gym import numpy as np -import tensorflow as tf import time import unittest @@ -16,6 +15,9 @@ from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.tests.mock_evaluator import _MockEvaluator +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() class AsyncOptimizerTest(unittest.TestCase): diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index 9ff0295690e2..a16cba22b611 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -33,10 +33,15 @@ def try_import_tf(): return None try: - import tensorflow as tf + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() return tf except ImportError: - return None + try: + import tensorflow as tf + return tf + except ImportError: + return None __all__ = [ From 84cf474abc0b1f438ea33e2d1ce0cb3db455ffc1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 16 May 2019 22:34:14 -0700 Subject: [PATCH 017/118] Change tagline in documentation and README. (#4807) * Update README.rst, index.rst, tutorial.rst and _config.yml --- README.rst | 2 +- doc/source/index.rst | 2 +- doc/source/tutorial.rst | 6 +++--- site/_config.yml | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.rst b/README.rst index 87ceba0d72c6..ada6f7c2d4d1 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ | -**Ray is a flexible, high-performance distributed execution framework.** +**Ray is a fast and simple framework for building and running distributed applications.** Ray is easy to install: ``pip install ray`` diff --git a/doc/source/index.rst b/doc/source/index.rst index 48c0c0d0e662..eba9eaa6ccac 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -7,7 +7,7 @@ Ray Fork me on GitHub -*Ray is a flexible, high-performance distributed execution framework.* +*Ray is a fast and simple framework for building and running distributed applications.* Ray is easy to install: ``pip install ray`` diff --git a/doc/source/tutorial.rst b/doc/source/tutorial.rst index c1b26d155e8c..889ca77c50c3 100644 --- a/doc/source/tutorial.rst +++ b/doc/source/tutorial.rst @@ -9,9 +9,9 @@ To use Ray, you need to understand the following: Overview -------- -Ray is a distributed execution engine. The same code can be run on -a single machine to achieve efficient multiprocessing, and it can be used on a -cluster for large computations. +Ray is a fast and simple framework for building and running distributed applications. +The same code can be run on a single machine to achieve efficient multiprocessing, +and it can be used on a cluster for large computations. When using Ray, several processes are involved. diff --git a/site/_config.yml b/site/_config.yml index 24ec957d22e7..676d4f0c0bb4 100644 --- a/site/_config.yml +++ b/site/_config.yml @@ -13,10 +13,10 @@ # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on. # You can create any custom variable you would like, and they will be accessible # in the templates via {{ site.myvariable }}. 
-title: "Ray: A Distributed Execution Framework for AI Applications" +title: "Ray: A fast and simple framework for distributed applications" email: "" description: > # this means to ignore newlines until "baseurl:" - Ray is a flexible, high-performance distributed execution framework for AI applications. + Ray is a fast and simple framework for building and running distributed applications. baseurl: "" # the subpath of your site, e.g. /blog url: "" # the base hostname & protocol for your site, e.g. http://example.com github_username: ray-project From ffe61fcc70ac0c4ae1dd9441507679303ae6cde0 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 16 May 2019 23:10:07 -0700 Subject: [PATCH 018/118] [tune] Support non-arg submit (#4803) --- python/ray/scripts/scripts.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 1951af208573..9489cd2ae43d 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -609,7 +609,10 @@ def submit(cluster_config_file, docker, screen, tmux, stop, start, target = os.path.join("~", os.path.basename(script)) rsync(cluster_config_file, script, target, cluster_name, down=False) - cmd = " ".join(["python", target, args]) + command_parts = ["python", target] + if args is not None: + command_parts += [args] + cmd = " ".join(command_parts) exec_cluster(cluster_config_file, cmd, docker, screen, tmux, stop, False, cluster_name, port_forward) From 88b45a53d6a85c1d837b4c7bf5ecbd2ee672e851 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 16 May 2019 23:11:06 -0700 Subject: [PATCH 019/118] [autoscaler] rsync cluster (#4785) --- python/ray/autoscaler/commands.py | 9 +++- python/ray/autoscaler/updater.py | 42 +++++++++++-------- python/ray/scripts/scripts.py | 8 ++-- .../tune/examples/mnist_pytorch_trainable.py | 7 +++- 4 files changed, 42 insertions(+), 24 deletions(-) diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 9a89261be7de..faaef8c6a153 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -423,6 +423,8 @@ def rsync(config_file, source, target, override_cluster_name, down): override_cluster_name: set the name of the cluster down: whether we're syncing remote -> local """ + assert bool(source) == bool(target), ( + "Must either provide both or neither source and target.") config = yaml.load(open(config_file).read()) if override_cluster_name is not None: @@ -448,7 +450,12 @@ def rsync(config_file, source, target, override_cluster_name, down): rsync = updater.rsync_down else: rsync = updater.rsync_up - rsync(source, target, check_error=False) + + if source and target: + rsync(source, target, check_error=False) + else: + updater.sync_file_mounts(rsync) + finally: provider.cleanup() diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index 9fff0c767467..c86750fe399d 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -183,25 +183,9 @@ def wait_for_ssh(self, deadline): return False - def do_update(self): - self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) - - deadline = time.time() + NODE_START_WAIT_S - self.set_ssh_ip_if_required() - - # Wait for SSH access - with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): - ssh_ok = self.wait_for_ssh(deadline) - assert ssh_ok, "Unable to SSH to node" - + def sync_file_mounts(self, sync_cmd): # Rsync file mounts - 
self.provider.set_node_tags(self.node_id, - {TAG_RAY_NODE_STATUS: "syncing-files"}) for remote_path, local_path in self.file_mounts.items(): - logger.info("NodeUpdater: " - "{}: Syncing {} to {}...".format( - self.node_id, local_path, remote_path)) assert os.path.exists(local_path), local_path if os.path.isdir(local_path): if not local_path.endswith("/"): @@ -217,7 +201,23 @@ def do_update(self): "mkdir -p {}".format(os.path.dirname(remote_path)), redirect=redirect, ) - self.rsync_up(local_path, remote_path, redirect=redirect) + sync_cmd(local_path, remote_path, redirect=redirect) + + def do_update(self): + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "waiting-for-ssh"}) + + deadline = time.time() + NODE_START_WAIT_S + self.set_ssh_ip_if_required() + + # Wait for SSH access + with LogTimer("NodeUpdater: " "{}: Got SSH".format(self.node_id)): + ssh_ok = self.wait_for_ssh(deadline) + assert ssh_ok, "Unable to SSH to node" + + self.provider.set_node_tags(self.node_id, + {TAG_RAY_NODE_STATUS: "syncing-files"}) + self.sync_file_mounts(self.rsync_up) # Run init commands self.provider.set_node_tags(self.node_id, @@ -236,6 +236,9 @@ def do_update(self): self.ssh_cmd(cmd, redirect=redirect) def rsync_up(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} to {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ @@ -247,6 +250,9 @@ def rsync_up(self, source, target, redirect=None, check_error=True): stderr=redirect or sys.stderr) def rsync_down(self, source, target, redirect=None, check_error=True): + logger.info("NodeUpdater: " + "{}: Syncing {} from {}...".format(self.node_id, source, + target)) self.set_ssh_ip_if_required() self.get_caller(check_error)( [ diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 9489cd2ae43d..9ed667b59671 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -529,8 +529,8 @@ def attach(cluster_config_file, start, tmux, cluster_name, new): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", @@ -543,8 +543,8 @@ def rsync_down(cluster_config_file, source, target, cluster_name): @cli.command() @click.argument("cluster_config_file", required=True, type=str) -@click.argument("source", required=True, type=str) -@click.argument("target", required=True, type=str) +@click.argument("source", required=False, type=str) +@click.argument("target", required=False, type=str) @click.option( "--cluster-name", "-n", diff --git a/python/ray/tune/examples/mnist_pytorch_trainable.py b/python/ray/tune/examples/mnist_pytorch_trainable.py index ac26d0353a98..7163dcfd6a01 100644 --- a/python/ray/tune/examples/mnist_pytorch_trainable.py +++ b/python/ray/tune/examples/mnist_pytorch_trainable.py @@ -49,6 +49,11 @@ action="store_true", default=False, help="disables CUDA training") +parser.add_argument( + "--redis-address", + default=None, + type=str, + help="The Redis address of the cluster.") parser.add_argument( "--seed", type=int, @@ -173,7 +178,7 @@ def _restore(self, checkpoint_path): from ray import tune from ray.tune.schedulers import HyperBandScheduler - ray.init() + ray.init(redis_address=args.redis_address) sched = HyperBandScheduler( 
time_attr="training_iteration", reward_attr="neg_mean_loss") tune.run( From e20855ccae216d398c6d4d8ca766a49dee82b72b Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Thu, 16 May 2019 23:11:35 -0700 Subject: [PATCH 020/118] [tune] Remove extra parsing functionality (#4804) --- python/ray/tune/config_parser.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 864ed6402639..139ef6f82bc3 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -76,22 +76,6 @@ def make_parser(parser_creator=None, **kwargs): default="", type=str, help="Optional URI to sync training results to (e.g. s3://bucket).") - parser.add_argument( - "--trial-name-creator", - default=None, - help="Optional creator function for the trial string, used in " - "generating a trial directory.") - parser.add_argument( - "--sync-function", - default=None, - help="Function for syncing the local_dir to upload_dir. If string, " - "then it must be a string template for syncer to run and needs to " - "include replacement fields '{local_dir}' and '{remote_dir}'.") - parser.add_argument( - "--loggers", - default=None, - help="List of logger creators to be used with each Trial. " - "Defaults to ray.tune.logger.DEFAULT_LOGGERS.") parser.add_argument( "--checkpoint-freq", default=0, @@ -187,7 +171,7 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs): A trial object with corresponding parameters to the specification. """ try: - args = parser.parse_args(to_argv(spec)) + args, _ = parser.parse_known_args(to_argv(spec)) except SystemExit: raise TuneError("Error parsing args, see above message", spec) if "resources_per_trial" in spec: From dcd6d4949ca296b584f7090fb2249ae7a07b3e8f Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Fri, 17 May 2019 16:13:28 +0800 Subject: [PATCH 021/118] Fix Java worker log dir (#4781) --- python/ray/node.py | 1 + python/ray/services.py | 11 ++++++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/ray/node.py b/python/ray/node.py index 733f21c9d728..85510147a35f 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -435,6 +435,7 @@ def start_raylet(self, use_valgrind=False, use_profiler=False): self._plasma_store_socket_name, self._ray_params.worker_path, self._temp_dir, + self._session_dir, self._ray_params.num_cpus, self._ray_params.num_gpus, self._ray_params.resources, diff --git a/python/ray/services.py b/python/ray/services.py index 3969019fcf6d..034a610e471c 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1063,6 +1063,7 @@ def start_raylet(redis_address, plasma_store_name, worker_path, temp_dir, + session_dir, num_cpus=None, num_gpus=None, resources=None, @@ -1088,6 +1089,7 @@ def start_raylet(redis_address, worker_path (str): The path of the Python file that new worker processes will execute. temp_dir (str): The path of the temporary directory Ray will use. + session_dir (str): The path of this session. num_cpus: The CPUs allocated for this raylet. num_gpus: The GPUs allocated for this raylet. resources: The custom resources allocated for this raylet. 
@@ -1145,7 +1147,7 @@ def start_raylet(redis_address, plasma_store_name, raylet_name, redis_password, - os.path.join(temp_dir, "sockets"), + session_dir, ) else: java_worker_command = "" @@ -1212,7 +1214,7 @@ def build_java_worker_command( plasma_store_name, raylet_name, redis_password, - temp_dir, + session_dir, ): """This method assembles the command used to start a Java worker. @@ -1223,7 +1225,7 @@ def build_java_worker_command( to. raylet_name (str): The name of the raylet socket to create. redis_password (str): The password of connect to redis. - temp_dir (str): The path of the temporary directory Ray will use. + session_dir (str): The path of this session. Returns: The command string for starting Java worker. """ @@ -1244,8 +1246,7 @@ def build_java_worker_command( command += "-Dray.redis.password={} ".format(redis_password) command += "-Dray.home={} ".format(RAY_HOME) - # TODO(suquark): We should use temp_dir as the input of a java worker. - command += "-Dray.log-dir={} ".format(os.path.join(temp_dir, "sockets")) + command += "-Dray.log-dir={} ".format(os.path.join(session_dir, "logs")) if java_worker_options: # Put `java_worker_options` in the last, so it can overwrite the From 1ef9c0729d104dd101b71323a6aec6c0ad502e03 Mon Sep 17 00:00:00 2001 From: Noah Golmant Date: Fri, 17 May 2019 11:34:05 -0700 Subject: [PATCH 022/118] [tune] Initial track integration (#4362) Introduces a minimally invasive utility for logging experiment results. A broad requirement for this tool is that it should integrate seamlessly with Tune execution. --- python/ray/tune/__init__.py | 2 +- .../ray/tune/automlboard/backend/collector.py | 6 +- python/ray/tune/examples/track_example.py | 71 +++++++++++ python/ray/tune/examples/utils.py | 4 +- python/ray/tune/function_runner.py | 23 +++- python/ray/tune/logger.py | 30 +++-- python/ray/tune/result.py | 2 +- python/ray/tune/tests/test_track.py | 84 +++++++++++++ python/ray/tune/track/__init__.py | 71 +++++++++++ python/ray/tune/track/session.py | 110 ++++++++++++++++++ python/ray/tune/trial.py | 20 ++-- 11 files changed, 397 insertions(+), 26 deletions(-) create mode 100644 python/ray/tune/examples/track_example.py create mode 100644 python/ray/tune/tests/test_track.py create mode 100644 python/ray/tune/track/__init__.py create mode 100644 python/ray/tune/track/session.py diff --git a/python/ray/tune/__init__.py b/python/ray/tune/__init__.py index 810256e07138..560a67e6b35b 100644 --- a/python/ray/tune/__init__.py +++ b/python/ray/tune/__init__.py @@ -14,5 +14,5 @@ __all__ = [ "Trainable", "TuneError", "grid_search", "register_env", "register_trainable", "run", "run_experiments", "Experiment", "function", - "sample_from", "uniform", "choice", "randint", "randn" + "sample_from", "track", "uniform", "choice", "randint", "randn" ] diff --git a/python/ray/tune/automlboard/backend/collector.py b/python/ray/tune/automlboard/backend/collector.py index 5566f479960f..dd87df1a450d 100644 --- a/python/ray/tune/automlboard/backend/collector.py +++ b/python/ray/tune/automlboard/backend/collector.py @@ -14,7 +14,7 @@ from ray.tune.automlboard.models.models import JobRecord, \ TrialRecord, ResultRecord from ray.tune.result import DEFAULT_RESULTS_DIR, JOB_META_FILE, \ - EXPR_PARARM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE + EXPR_PARAM_FILE, EXPR_RESULT_FILE, EXPR_META_FILE class CollectorService(object): @@ -327,7 +327,7 @@ def _build_trial_meta(cls, expr_dir): if not meta: job_id = expr_dir.split("/")[-2] trial_id = expr_dir[-8:] - params = parse_json(os.path.join(expr_dir, 
EXPR_PARARM_FILE)) + params = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE)) meta = { "trial_id": trial_id, "job_id": job_id, @@ -349,7 +349,7 @@ def _build_trial_meta(cls, expr_dir): if meta.get("end_time", None): meta["end_time"] = timestamp2date(meta["end_time"]) - meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARARM_FILE)) + meta["params"] = parse_json(os.path.join(expr_dir, EXPR_PARAM_FILE)) return meta diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py new file mode 100644 index 000000000000..1ccec39462d0 --- /dev/null +++ b/python/ray/tune/examples/track_example.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import keras +from keras.datasets import mnist +from keras.models import Sequential +from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) + +from ray.tune import track +from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data + +parser = argparse.ArgumentParser() +parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") +parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)") +parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)") +parser.add_argument( + "--hidden", type=int, default=64, help="Size of hidden layer.") +args, _ = parser.parse_known_args() + + +def train_mnist(args): + track.init(trial_name="track-example", trial_config=vars(args)) + batch_size = 128 + num_classes = 10 + epochs = 1 if args.smoke_test else 12 + mnist.load() + x_train, y_train, x_test, y_test, input_shape = get_mnist_data() + + model = Sequential() + model.add( + Conv2D( + 32, kernel_size=(3, 3), activation="relu", + input_shape=input_shape)) + model.add(Conv2D(64, (3, 3), activation="relu")) + model.add(MaxPooling2D(pool_size=(2, 2))) + model.add(Dropout(0.5)) + model.add(Flatten()) + model.add(Dense(args.hidden, activation="relu")) + model.add(Dropout(0.5)) + model.add(Dense(num_classes, activation="softmax")) + + model.compile( + loss="categorical_crossentropy", + optimizer=keras.optimizers.SGD(lr=args.lr, momentum=args.momentum), + metrics=["accuracy"]) + + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + validation_data=(x_test, y_test), + callbacks=[TuneKerasCallback(track.metric)]) + track.shutdown() + + +if __name__ == "__main__": + train_mnist(args) diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index 3c73bce2bae7..a5ab1dbdb6a1 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -15,7 +15,9 @@ def __init__(self, reporter, logs={}): def on_train_end(self, epoch, logs={}): self.reporter( - timesteps_total=self.iteration, done=1, mean_accuracy=logs["acc"]) + timesteps_total=self.iteration, + done=1, + mean_accuracy=logs.get("acc")) def on_batch_end(self, batch, logs={}): self.iteration += 1 diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index 551f1702759f..e30e2bdf5cf0 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -5,9 +5,11 @@ import logging import sys import time +import inspect import threading from six.moves import queue +from ray.tune import track from ray.tune import TuneError from ray.tune.trainable import Trainable from ray.tune.result import 
TIME_THIS_ITER_S, RESULT_DUPLICATE @@ -244,6 +246,17 @@ def _report_thread_runner_error(self, block=False): def wrap_function(train_func): + + use_track = False + try: + func_args = inspect.getargspec(train_func).args + use_track = ("reporter" not in func_args and len(func_args) == 1) + if use_track: + logger.info("tune.track signature detected.") + except Exception: + logger.info( + "Function inspection failed - assuming reporter signature.") + class WrappedFunc(FunctionRunner): def _trainable_func(self, config, reporter): output = train_func(config, reporter) @@ -253,4 +266,12 @@ def _trainable_func(self, config, reporter): reporter(**{RESULT_DUPLICATE: True}) return output - return WrappedFunc + class WrappedTrackFunc(FunctionRunner): + def _trainable_func(self, config, reporter): + track.init(_tune_reporter=reporter) + output = train_func(config) + reporter(**{RESULT_DUPLICATE: True}) + track.shutdown() + return output + + return WrappedTrackFunc if use_track else WrappedFunc diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 9d472cac36fe..4b9d5a914aa1 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -50,6 +50,11 @@ def on_result(self, result): raise NotImplementedError + def update_config(self, config): + """Updates the config for all loggers.""" + + pass + def close(self): """Releases all resources used by this logger.""" @@ -68,17 +73,7 @@ def on_result(self, result): class JsonLogger(Logger): def _init(self): - config_out = os.path.join(self.logdir, "params.json") - with open(config_out, "w") as f: - json.dump( - self.config, - f, - indent=2, - sort_keys=True, - cls=_SafeFallbackEncoder) - config_pkl = os.path.join(self.logdir, "params.pkl") - with open(config_pkl, "wb") as f: - cloudpickle.dump(self.config, f) + self.update_config(self.config) local_file = os.path.join(self.logdir, "result.json") self.local_out = open(local_file, "a") @@ -96,6 +91,15 @@ def flush(self): def close(self): self.local_out.close() + def update_config(self, config): + self.config = config + config_out = os.path.join(self.logdir, "params.json") + with open(config_out, "w") as f: + json.dump(self.config, f, cls=_SafeFallbackEncoder) + config_pkl = os.path.join(self.logdir, "params.pkl") + with open(config_pkl, "wb") as f: + cloudpickle.dump(self.config, f) + def to_tf_values(result, path): values = [] @@ -231,6 +235,10 @@ def on_result(self, result): self._log_syncer.set_worker_ip(result.get(NODE_IP)) self._log_syncer.sync_if_needed() + def update_config(self, config): + for _logger in self._loggers: + _logger.update_config(config) + def close(self): for _logger in self._loggers: _logger.close() diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 2978fe540d18..51a67d5931a7 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -68,7 +68,7 @@ EXPR_META_FILE = "trial_status.json" # File that stores parameters of the trial. -EXPR_PARARM_FILE = "params.json" +EXPR_PARAM_FILE = "params.json" # File that stores the progress of the trial. 
EXPR_PROGRESS_FILE = "progress.csv" diff --git a/python/ray/tune/tests/test_track.py b/python/ray/tune/tests/test_track.py new file mode 100644 index 000000000000..d3b6c38d745a --- /dev/null +++ b/python/ray/tune/tests/test_track.py @@ -0,0 +1,84 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pandas as pd +import unittest + +import ray +from ray import tune +from ray.tune import track +from ray.tune.result import EXPR_PARAM_FILE, EXPR_RESULT_FILE + + +def _check_json_val(fname, key, val): + with open(fname, "r") as f: + df = pd.read_json(f, typ="frame", lines=True) + return key in df.columns and (df[key].tail(n=1) == val).all() + + +class TrackApiTest(unittest.TestCase): + def tearDown(self): + track.shutdown() + ray.shutdown() + + def testSessionInitShutdown(self): + self.assertTrue(track._session is None) + + # Checks that the singleton _session is created/destroyed + # by track.init() and track.shutdown() + for _ in range(2): + # do it twice to see that we can reopen the session + track.init(trial_name="test_init") + self.assertTrue(track._session is not None) + track.shutdown() + self.assertTrue(track._session is None) + + def testLogCreation(self): + """Checks that track.init() starts logger and creates log files.""" + track.init(trial_name="test_init") + session = track.get_session() + self.assertTrue(session is not None) + + self.assertTrue(os.path.isdir(session.logdir)) + + params_path = os.path.join(session.logdir, EXPR_PARAM_FILE) + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + + self.assertTrue(os.path.exists(params_path)) + self.assertTrue(os.path.exists(result_path)) + self.assertTrue(session.logdir == track.trial_dir()) + + def testMetric(self): + track.init(trial_name="test_log") + session = track.get_session() + for i in range(5): + track.log(test=i) + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + self.assertTrue(_check_json_val(result_path, "test", i)) + + def testRayOutput(self): + """Checks that local and remote log format are the same.""" + ray.init() + + def testme(config): + for i in range(config["iters"]): + track.log(iteration=i, hi="test") + + trials = tune.run(testme, config={"iters": 5}) + trial_res = trials[0].last_result + self.assertTrue(trial_res["hi"], "test") + self.assertTrue(trial_res["training_iteration"], 5) + + def testLocalMetrics(self): + """Checks that metric state is updated correctly.""" + track.init(trial_name="test_logs") + session = track.get_session() + self.assertEqual(set(session.trial_config.keys()), {"trial_id"}) + + result_path = os.path.join(session.logdir, EXPR_RESULT_FILE) + track.log(test=1) + self.assertTrue(_check_json_val(result_path, "test", 1)) + track.log(iteration=1, test=2) + self.assertTrue(_check_json_val(result_path, "test", 2)) diff --git a/python/ray/tune/track/__init__.py b/python/ray/tune/track/__init__.py new file mode 100644 index 000000000000..a35511e89350 --- /dev/null +++ b/python/ray/tune/track/__init__.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from ray.tune.track.session import TrackSession + +logger = logging.getLogger(__name__) + +_session = None + + +def get_session(): + global _session + if not _session: + raise ValueError("Session not detected. 
Try `track.init()`?") + return _session + + +def init(ignore_reinit_error=True, **session_kwargs): + """Initializes the global trial context for this process. + + This creates a TrackSession object and the corresponding hooks for logging. + + Examples: + >>> from ray.tune import track + >>> track.init() + """ + global _session + + if _session: + # TODO(ng): would be nice to stack crawl at creation time to report + # where that initial trial was created, and that creation line + # info is helpful to keep around anyway. + reinit_msg = "A session already exists in the current context." + if ignore_reinit_error: + if not _session.is_tune_session: + logger.warning(reinit_msg) + return + else: + raise ValueError(reinit_msg) + + _session = TrackSession(**session_kwargs) + + +def shutdown(): + """Cleans up the trial and removes it from the global context.""" + + global _session + if _session: + _session.close() + _session = None + + +def log(**kwargs): + """Applies TrackSession.log to the trial in the current context.""" + _session = get_session() + return _session.log(**kwargs) + + +def trial_dir(): + """Returns the directory where trial results are saved. + + This includes json data containing the session's parameters and metrics. + """ + _session = get_session() + return _session.logdir + + +__all__ = ["TrackSession", "session", "log", "trial_dir", "init", "shutdown"] diff --git a/python/ray/tune/track/session.py b/python/ray/tune/track/session.py new file mode 100644 index 000000000000..faf850e5fea2 --- /dev/null +++ b/python/ray/tune/track/session.py @@ -0,0 +1,110 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from datetime import datetime + +from ray.tune.trial import Trial +from ray.tune.result import DEFAULT_RESULTS_DIR, TRAINING_ITERATION +from ray.tune.logger import UnifiedLogger, Logger + + +class _ReporterHook(Logger): + def __init__(self, tune_reporter): + self.tune_reporter = tune_reporter + + def on_result(self, metrics): + return self.tune_reporter(**metrics) + + +class TrackSession(object): + """Manages results for a single session. + + Represents a single Trial in an experiment. + + Attributes: + trial_name (str): Custom trial name. + experiment_dir (str): Directory where results for all trials + are stored. Each session is stored into a unique directory + inside experiment_dir. + upload_dir (str): Directory to sync results to. + trial_config (dict): Parameters that will be logged to disk. + _tune_reporter (StatusReporter): For rerouting when using Tune. + Will not instantiate logging if not None. + """ + + def __init__(self, + trial_name="", + experiment_dir=None, + upload_dir=None, + trial_config=None, + _tune_reporter=None): + self._experiment_dir = None + self._logdir = None + self._upload_dir = None + self.trial_config = None + self._iteration = -1 + self.is_tune_session = bool(_tune_reporter) + self.trial_id = Trial.generate_id() + if trial_name: + self.trial_id = trial_name + "_" + self.trial_id + if self.is_tune_session: + self._logger = _ReporterHook(_tune_reporter) + else: + self._initialize_logging(trial_name, experiment_dir, upload_dir, + trial_config) + + def _initialize_logging(self, + trial_name="", + experiment_dir=None, + upload_dir=None, + trial_config=None): + + # TODO(rliaw): In other parts of the code, this is `local_dir`. 
+ if experiment_dir is None: + experiment_dir = os.path.join(DEFAULT_RESULTS_DIR, "default") + + self._experiment_dir = os.path.expanduser(experiment_dir) + + # TODO(rliaw): Refactor `logdir` to `trial_dir`. + self._logdir = Trial.create_logdir(trial_name, self._experiment_dir) + self._upload_dir = upload_dir + self.trial_config = trial_config or {} + + # misc metadata to save as well + self.trial_config["trial_id"] = self.trial_id + self._logger = UnifiedLogger(self.trial_config, self._logdir, + self._upload_dir) + + def log(self, **metrics): + """Logs all named arguments specified in **metrics. + + This will log trial metrics locally, and they will be synchronized + with the driver periodically through ray. + + Arguments: + metrics: named arguments with corresponding values to log. + """ + + # TODO: Implement a batching mechanism for multiple calls to `log` + # within the same iteration. + self._iteration += 1 + metrics_dict = metrics.copy() + metrics_dict.update({"trial_id": self.trial_id}) + + # TODO: Move Trainable autopopulation to a util function + metrics_dict.setdefault(TRAINING_ITERATION, self._iteration) + self._logger.on_result(metrics_dict) + + def close(self): + self.trial_config["trial_completed"] = True + self.trial_config["end_time"] = datetime.now().isoformat() + # TODO(rliaw): Have Tune support updated configs + self._logger.update_config(self.trial_config) + self._logger.close() + + @property + def logdir(self): + """Trial logdir (subdir of given experiment directory)""" + return self._logdir diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index ad61e8d4b393..272945ba1cf4 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -8,6 +8,7 @@ from datetime import datetime import logging import json +import uuid import time import tempfile import os @@ -27,7 +28,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID, TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL, EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY) -from ray.utils import _random_string, binary_to_hex, hex_to_binary +from ray.utils import binary_to_hex, hex_to_binary DEBUG_PRINT_INTERVAL = 5 MAX_LEN_IDENTIFIER = 130 @@ -341,19 +342,22 @@ def _registration_check(cls, trainable_name): @classmethod def generate_id(cls): - return binary_to_hex(_random_string())[:8] + return str(uuid.uuid1().hex)[:8] + + @classmethod + def create_logdir(cls, identifier, local_dir): + if not os.path.exists(local_dir): + os.makedirs(local_dir) + return tempfile.mkdtemp( + prefix="{}_{}".format(identifier[:MAX_LEN_IDENTIFIER], date_str()), + dir=local_dir) def init_logger(self): """Init logger.""" if not self.result_logger: - if not os.path.exists(self.local_dir): - os.makedirs(self.local_dir) if not self.logdir: - self.logdir = tempfile.mkdtemp( - prefix="{}_{}".format( - str(self)[:MAX_LEN_IDENTIFIER], date_str()), - dir=self.local_dir) + self.logdir = Trial.create_logdir(str(self), self.local_dir) elif not os.path.exists(self.logdir): os.makedirs(self.logdir) From 6cb5b90bd6b2ec5fbdc5926acd3e0efed3ed1e03 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 18 May 2019 00:23:11 -0700 Subject: [PATCH 023/118] [rllib] [RFC] Dynamic definition of loss functions and modularization support (#4795) * dynamic graph * wip * clean up * fix * document trainer * wip * initialize the graph using a fake batch * clean up dynamic init * wip * spelling * use builder for ppo pol graph * add ppo graph * fix naming * order * docs * set class name correctly * add torch builder * add custom model 
support in builder * cleanup * remove underscores * fix py2 compat * Update dynamic_tf_policy_graph.py * Update tracking_dict.py * wip * rename * debug level * rename policy_graph -> policy in new classes * fix test * rename ppo tf policy * port appo too * forgot grads * default policy optimizer * make default config optional * add config to optimizer * use lr by default in optimizer * update * comments * remove optimizer * fix tuple actions support in dynamic tf graph --- python/ray/rllib/agents/a3c/a3c.py | 4 +- .../agents/a3c/a3c_torch_policy_graph.py | 173 +++--- python/ray/rllib/agents/pg/pg.py | 54 +- python/ray/rllib/agents/pg/pg_policy_graph.py | 105 +--- .../rllib/agents/pg/torch_pg_policy_graph.py | 93 +-- python/ray/rllib/agents/ppo/appo.py | 6 +- .../ray/rllib/agents/ppo/appo_policy_graph.py | 549 +++++++----------- python/ray/rllib/agents/ppo/ppo.py | 214 ++++--- .../ray/rllib/agents/ppo/ppo_policy_graph.py | 367 +++++------- python/ray/rllib/agents/trainer_template.py | 97 ++++ .../evaluation/dynamic_tf_policy_graph.py | 275 +++++++++ .../ray/rllib/evaluation/policy_evaluator.py | 10 +- .../ray/rllib/evaluation/tf_policy_graph.py | 64 +- .../rllib/evaluation/tf_policy_template.py | 146 +++++ .../rllib/evaluation/torch_policy_graph.py | 47 +- .../rllib/evaluation/torch_policy_template.py | 133 +++++ .../rllib/examples/multiagent_two_trainers.py | 4 +- python/ray/rllib/optimizers/multi_gpu_impl.py | 2 +- .../rllib/optimizers/multi_gpu_optimizer.py | 2 +- .../tests/test_external_multi_agent_env.py | 4 +- python/ray/rllib/tests/test_io.py | 4 +- .../ray/rllib/tests/test_multi_agent_env.py | 8 +- python/ray/rllib/tests/test_nested_spaces.py | 6 +- python/ray/rllib/tests/test_optimizers.py | 6 +- python/ray/rllib/utils/tracking_dict.py | 32 + 25 files changed, 1376 insertions(+), 1029 deletions(-) create mode 100644 python/ray/rllib/agents/trainer_template.py create mode 100644 python/ray/rllib/evaluation/dynamic_tf_policy_graph.py create mode 100644 python/ray/rllib/evaluation/tf_policy_template.py create mode 100644 python/ray/rllib/evaluation/torch_policy_template.py create mode 100644 python/ray/rllib/utils/tracking_dict.py diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 836d9f074999..eb384058de80 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -49,8 +49,8 @@ class A3CTrainer(Trainer): def _init(self, config, env_creator): if config["use_pytorch"]: from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ - A3CTorchPolicyGraph - policy_cls = A3CTorchPolicyGraph + A3CTorchPolicy + policy_cls = A3CTorchPolicy else: policy_cls = self._policy_graph diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py index d35aabe0d667..fa6f857f9eca 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py @@ -7,109 +7,84 @@ from torch import nn import ray -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from ray.rllib.utils.annotations import override - - -class A3CLoss(nn.Module): - def __init__(self, dist_class, vf_loss_coeff=0.5, entropy_coeff=0.01): - nn.Module.__init__(self) - self.dist_class = 
dist_class - self.vf_loss_coeff = vf_loss_coeff - self.entropy_coeff = entropy_coeff - - def forward(self, policy_model, observations, actions, advantages, - value_targets): - logits, _, values, _ = policy_model({ - SampleBatch.CUR_OBS: observations - }, []) - dist = self.dist_class(logits) - log_probs = dist.logp(actions) - self.entropy = dist.entropy().mean() - self.pi_err = -advantages.dot(log_probs.reshape(-1)) - self.value_err = F.mse_loss(values.reshape(-1), value_targets) - overall_err = sum([ - self.pi_err, - self.vf_loss_coeff * self.value_err, - -self.entropy_coeff * self.entropy, - ]) - - return overall_err - - -class A3CPostprocessing(object): - """Adds the VF preds and advantages fields to the trajectory.""" - - @override(TorchPolicyGraph) - def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1]) - return compute_advantages(sample_batch, last_r, self.config["gamma"], - self.config["lambda"]) - - -class A3CTorchPolicyGraph(A3CPostprocessing, TorchPolicyGraph): - """A simple, non-recurrent PyTorch policy example.""" - - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], torch=True) - model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = A3CLoss(dist_class, self.config["vf_loss_coeff"], - self.config["entropy_coeff"]) - TorchPolicyGraph.__init__( - self, - obs_space, - action_space, - model, - loss, - loss_inputs=[ - SampleBatch.CUR_OBS, SampleBatch.ACTIONS, - Postprocessing.ADVANTAGES, Postprocessing.VALUE_TARGETS - ], - action_distribution_cls=dist_class) - - @override(TorchPolicyGraph) - def optimizer(self): - return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) - - @override(TorchPolicyGraph) - def extra_grad_process(self): - info = {} - if self.config["grad_clip"]: - total_norm = nn.utils.clip_grad_norm_(self._model.parameters(), - self.config["grad_clip"]) - info["grad_gnorm"] = total_norm - return info - - @override(TorchPolicyGraph) - def extra_grad_info(self): - return { - "policy_entropy": self._loss.entropy.item(), - "policy_loss": self._loss.pi_err.item(), - "vf_loss": self._loss.value_err.item() - } - +from ray.rllib.evaluation.torch_policy_template import build_torch_policy + + +def actor_critic_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + dist = policy.dist_class(logits) + log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) + policy.entropy = dist.entropy().mean() + policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( + log_probs.reshape(-1)) + policy.value_err = F.mse_loss( + values.reshape(-1), batch_tensors[Postprocessing.VALUE_TARGETS]) + overall_err = sum([ + policy.pi_err, + policy.config["vf_loss_coeff"] * policy.value_err, + -policy.config["entropy_coeff"] * policy.entropy, + ]) + return overall_err + + +def loss_and_entropy_stats(policy, batch_tensors): + return { + "policy_entropy": policy.entropy.item(), + "policy_loss": policy.pi_err.item(), + "vf_loss": policy.value_err.item(), + } + + 
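The loss above stashes its intermediate terms (`policy.entropy`, `policy.pi_err`, `policy.value_err`) on the policy object so that the stats function can report them without a second forward pass. A minimal sketch of an extra stats function that follows the same convention (the name `total_loss_stats` and the reported key are illustrative, not taken from this patch):

    def total_loss_stats(policy, batch_tensors):
        # Assumes actor_critic_loss has already run for this batch, so the
        # intermediate loss terms are populated on the policy object.
        total = (policy.pi_err
                 + policy.config["vf_loss_coeff"] * policy.value_err
                 - policy.config["entropy_coeff"] * policy.entropy)
        return {"total_loss": total.item()}
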
+def add_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + +def model_value_predictions(policy, model_out): + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} + + +def apply_grad_clipping(policy): + info = {} + if policy.config["grad_clip"]: + total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(), + policy.config["grad_clip"]) + info["grad_gnorm"] = total_norm + return info + + +def torch_optimizer(policy, config): + return torch.optim.Adam(policy.model.parameters(), lr=config["lr"]) + + +class ValueNetworkMixin(object): def _value(self, obs): with self.lock: obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _, _, vf, _ = self._model({"obs": obs}, []) + _, _, vf, _ = self.model({"obs": obs}, []) return vf.detach().cpu().numpy().squeeze() + + +A3CTorchPolicy = build_torch_policy( + name="A3CTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=loss_and_entropy_stats, + postprocess_fn=add_advantages, + extra_action_out_fn=model_value_predictions, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=torch_optimizer, + mixins=[ValueNetworkMixin]) diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index e70fdcc8b2c6..ffbb899d1b9e 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -2,11 +2,9 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph - -from ray.rllib.optimizers import SyncSamplesOptimizer -from ray.rllib.utils.annotations import override +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy # yapf: disable # __sphinx_doc_begin__ @@ -22,40 +20,16 @@ # yapf: enable -class PGTrainer(Trainer): - """Simple policy gradient agent. - - This is an example agent to show how to implement algorithms in RLlib. - In most cases, you will probably want to use the PPO agent instead. 
- """ - - _name = "PG" - _default_config = DEFAULT_CONFIG - _policy_graph = PGPolicyGraph +def get_policy_class(config): + if config["use_pytorch"]: + from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy + return PGTorchPolicy + else: + return PGTFPolicy - @override(Trainer) - def _init(self, config, env_creator): - if config["use_pytorch"]: - from ray.rllib.agents.pg.torch_pg_policy_graph import \ - PGTorchPolicyGraph - policy_cls = PGTorchPolicyGraph - else: - policy_cls = self._policy_graph - self.local_evaluator = self.make_local_evaluator( - env_creator, policy_cls) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_cls, config["num_workers"]) - optimizer_config = dict( - config["optimizer"], - **{"train_batch_size": config["train_batch_size"]}) - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, self.remote_evaluators, **optimizer_config) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result +PGTrainer = build_trainer( + name="PG", + default_config=DEFAULT_CONFIG, + default_policy=PGTFPolicy, + get_policy_class=get_policy_class) diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index a55af79b1e61..54fcd041cc72 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -3,102 +3,33 @@ from __future__ import print_function import ray -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.tf_policy_template import build_tf_policy from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.utils.annotations import override from ray.rllib.utils import try_import_tf tf = try_import_tf() -class PGLoss(object): - """The basic policy gradient loss.""" +# The basic policy gradients loss +def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + advantages = batch_tensors[Postprocessing.ADVANTAGES] + return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages) - def __init__(self, action_dist, actions, advantages): - self.loss = -tf.reduce_mean(action_dist.logp(actions) * advantages) +# This adds the "advantages" column to the sample batch. 
+def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) -class PGPostprocessing(object): - """Adds the advantages field to the trajectory.""" - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - # This adds the "advantages" column to the sample batch - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) - - -class PGPolicyGraph(PGPostprocessing, TFPolicyGraph): - """Simple policy gradient example of defining a policy graph.""" - - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config) - self.config = config - - # Setup placeholders - obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape)) - dist_class, self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - - # Create the model network and action outputs - self.model = ModelCatalog.get_model({ - "obs": obs, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, self.logit_dim, self.config["model"]) - action_dist = dist_class(self.model.outputs) # logit for each action - - # Setup policy loss - actions = ModelCatalog.get_action_placeholder(action_space) - advantages = tf.placeholder(tf.float32, [None], name="adv") - loss = PGLoss(action_dist, actions, advantages).loss - - # Mapping from sample batch keys to placeholders. These keys will be - # read from postprocessed sample batches and fed into the specified - # placeholders during loss computation. 
- loss_in = [ - (SampleBatch.CUR_OBS, obs), - (SampleBatch.ACTIONS, actions), - (SampleBatch.PREV_ACTIONS, prev_actions), - (SampleBatch.PREV_REWARDS, prev_rewards), - (Postprocessing.ADVANTAGES, advantages), - ] - - # Initialize TFPolicyGraph - sess = tf.get_default_session() - TFPolicyGraph.__init__( - self, - obs_space, - action_space, - sess, - obs_input=obs, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=loss, - loss_inputs=loss_in, - model=self.model, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - sess.run(tf.global_variables_initializer()) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer(learning_rate=self.config["lr"]) +PGTFPolicy = build_tf_policy( + name="PGTFPolicy", + get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, + postprocess_fn=postprocess_advantages, + loss_fn=policy_gradient_loss) diff --git a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py index 746ef1bca42f..cda1b6eb5057 100644 --- a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/torch_pg_policy_graph.py @@ -2,82 +2,41 @@ from __future__ import division from __future__ import print_function -import torch -from torch import nn - import ray -from ray.rllib.models.catalog import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph -from ray.rllib.utils.annotations import override - - -class PGLoss(nn.Module): - def __init__(self, dist_class): - nn.Module.__init__(self) - self.dist_class = dist_class - - def forward(self, policy_model, observations, actions, advantages): - logits, _, values, _ = policy_model({ - SampleBatch.CUR_OBS: observations - }, []) - dist = self.dist_class(logits) - log_probs = dist.logp(actions) - self.pi_err = -advantages.dot(log_probs.reshape(-1)) - return self.pi_err - - -class PGPostprocessing(object): - """Adds the value func output and advantages field to the trajectory.""" +from ray.rllib.evaluation.torch_policy_template import build_torch_policy - @override(TorchPolicyGraph) - def extra_action_out(self, model_out): - return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - return compute_advantages( - sample_batch, 0.0, self.config["gamma"], use_gae=False) +def pg_torch_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + # save the error in the policy object + policy.pi_err = -batch_tensors[Postprocessing.ADVANTAGES].dot( + log_probs.reshape(-1)) + return policy.pi_err -class PGTorchPolicyGraph(PGPostprocessing, TorchPolicyGraph): - def __init__(self, obs_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - dist_class, 
self.logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"], torch=True) - model = ModelCatalog.get_torch_model(obs_space, self.logit_dim, - self.config["model"]) - loss = PGLoss(dist_class) +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) - TorchPolicyGraph.__init__( - self, - obs_space, - action_space, - model, - loss, - loss_inputs=[ - SampleBatch.CUR_OBS, SampleBatch.ACTIONS, - Postprocessing.ADVANTAGES - ], - action_distribution_cls=dist_class) - @override(TorchPolicyGraph) - def optimizer(self): - return torch.optim.Adam(self._model.parameters(), lr=self.config["lr"]) +def pg_loss_stats(policy, batch_tensors): + # the error is recorded when computing the loss + return {"policy_loss": policy.pi_err.item()} - @override(TorchPolicyGraph) - def extra_grad_info(self): - return {"policy_loss": self._loss.pi_err.item()} - def _value(self, obs): - with self.lock: - obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) - _, _, vf, _ = self.model({"obs": obs}, []) - return vf.detach().cpu().numpy().squeeze() +PGTorchPolicy = build_torch_policy( + name="PGTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=pg_torch_loss, + stats_fn=pg_loss_stats, + postprocess_fn=postprocess_advantages) diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index ac3251775d52..b32531dd7d5c 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOPolicyGraph +from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.utils.annotations import override @@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer): _name = "APPO" _default_config = DEFAULT_CONFIG - _policy_graph = AsyncPPOPolicyGraph + _policy_graph = AsyncPPOTFPolicy @override(impala.ImpalaTrainer) def _get_policy_graph(self): - return AsyncPPOPolicyGraph + return AsyncPPOTFPolicy diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index caaaf512bcb1..5aa76913194f 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -12,14 +12,11 @@ import ray from ray.rllib.agents.impala import vtrace -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.evaluation.tf_policy_template import build_tf_policy +from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.utils import try_import_tf @@ -27,6 +24,8 @@ logger = logging.getLogger(__name__) +BEHAVIOUR_LOGITS = "behaviour_logits" + class 
PPOSurrogateLoss(object): """Loss used when V-trace is disabled. @@ -163,333 +162,235 @@ def __init__(self, self.entropy * entropy_coeff) -class APPOPostprocessing(object): - """Adds the policy logits, VF preds, and advantages to the trajectory.""" - - @override(TFPolicyGraph) - def extra_compute_action_fetches(self): - out = {"behaviour_logits": self.model.outputs} - if not self.config["vtrace"]: - out["vf_preds"] = self.value_function - return dict(TFPolicyGraph.extra_compute_action_fetches(self), **out) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if not self.config["vtrace"]: - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append( - [sample_batch["state_out_{}".format(i)][-1]]) - last_r = self.value(sample_batch["new_obs"][-1], *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - else: - batch = sample_batch - del batch.data["new_obs"] # not used, so save some bandwidth - return batch - +def _make_time_major(policy, tensor, drop_last=False): + """Swaps batch and trajectory axis. -class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing, - TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - config = dict(ray.rllib.agents.impala.impala.DEFAULT_CONFIG, **config) - assert config["batch_mode"] == "truncate_episodes", \ - "Must use `truncate_episodes` batch mode with V-trace." - self.config = config - self.sess = tf.get_default_session() - self.grads = None - - if isinstance(action_space, gym.spaces.Discrete): - is_multidiscrete = False - output_hidden_shape = [action_space.n] - elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): - is_multidiscrete = True - output_hidden_shape = action_space.nvec.astype(np.int32) - else: - is_multidiscrete = False - output_hidden_shape = 1 - - # Policy network model - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - # Create input placeholders - if existing_inputs: - if self.config["vtrace"]: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards = existing_inputs[:7] - existing_state_in = existing_inputs[7:-1] - existing_seq_lens = existing_inputs[-1] - else: - actions, dones, behaviour_logits, rewards, observations, \ - prev_actions, prev_rewards, adv_ph, value_targets = \ - existing_inputs[:9] - existing_state_in = existing_inputs[9:-1] - existing_seq_lens = existing_inputs[-1] + Arguments: + policy: Policy reference + tensor: A tensor or list of tensors to reshape. + drop_last: A bool indicating whether to drop the last + trajectory item. + + Returns: + res: A tensor with swapped axes or a list of tensors with + swapped axes. + """ + if isinstance(tensor, list): + return [_make_time_major(policy, t, drop_last) for t in tensor] + + if policy.model.state_init: + B = tf.shape(policy.model.seq_lens)[0] + T = tf.shape(tensor)[0] // B + else: + # Important: chop the tensor into batches at known episode cut + # boundaries. 
TODO(ekl) this is kind of a hack + T = policy.config["sample_batch_size"] + B = tf.shape(tensor)[0] // T + rs = tf.reshape(tensor, tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) + + # swap B and T axes + res = tf.transpose( + rs, [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) + + if drop_last: + return res[:-1] + return res + + +def build_appo_surrogate_loss(policy, batch_tensors): + if isinstance(policy.action_space, gym.spaces.Discrete): + is_multidiscrete = False + output_hidden_shape = [policy.action_space.n] + elif isinstance(policy.action_space, + gym.spaces.multi_discrete.MultiDiscrete): + is_multidiscrete = True + output_hidden_shape = policy.action_space.nvec.astype(np.int32) + else: + is_multidiscrete = False + output_hidden_shape = 1 + + def make_time_major(*args, **kw): + return _make_time_major(policy, *args, **kw) + + actions = batch_tensors[SampleBatch.ACTIONS] + dones = batch_tensors[SampleBatch.DONES] + rewards = batch_tensors[SampleBatch.REWARDS] + behaviour_logits = batch_tensors[BEHAVIOUR_LOGITS] + unpacked_behaviour_logits = tf.split( + behaviour_logits, output_hidden_shape, axis=1) + unpacked_outputs = tf.split( + policy.model.outputs, output_hidden_shape, axis=1) + prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ + behaviour_logits + action_dist = policy.action_dist + prev_action_dist = policy.dist_class(prev_dist_inputs) + values = policy.value_function + + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) - 1 + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like(rewards) + + if policy.config["vtrace"]: + logger.info("Using V-Trace surrogate loss (vtrace=True)") + + # Prepare actions for loss + loss_actions = actions if is_multidiscrete else tf.expand_dims( + actions, axis=1) + + policy.loss = VTraceSurrogateLoss( + actions=make_time_major(loss_actions, drop_last=True), + prev_actions_logp=make_time_major( + prev_action_dist.logp(actions), drop_last=True), + actions_logp=make_time_major( + action_dist.logp(actions), drop_last=True), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major( + action_dist.entropy(), drop_last=True), + dones=make_time_major(dones, drop_last=True), + behaviour_logits=make_time_major( + unpacked_behaviour_logits, drop_last=True), + target_logits=make_time_major(unpacked_outputs, drop_last=True), + discount=policy.config["gamma"], + rewards=make_time_major(rewards, drop_last=True), + values=make_time_major(values, drop_last=True), + bootstrap_value=make_time_major(values)[-1], + dist_class=policy.dist_class, + valid_mask=make_time_major(mask, drop_last=True), + vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], + clip_pg_rho_threshold=policy.config[ + "vtrace_clip_pg_rho_threshold"], + clip_param=policy.config["clip_param"]) + else: + logger.info("Using PPO surrogate loss (vtrace=False)") + policy.loss = PPOSurrogateLoss( + prev_actions_logp=make_time_major(prev_action_dist.logp(actions)), + actions_logp=make_time_major(action_dist.logp(actions)), + action_kl=prev_action_dist.kl(action_dist), + actions_entropy=make_time_major(action_dist.entropy()), + values=make_time_major(values), + valid_mask=make_time_major(mask), + advantages=make_time_major( + batch_tensors[Postprocessing.ADVANTAGES]), + value_targets=make_time_major( + batch_tensors[Postprocessing.VALUE_TARGETS]), + 
vf_loss_coeff=policy.config["vf_loss_coeff"], + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"]) + + return policy.loss.total_loss + + +def stats(policy, batch_tensors): + values_batched = _make_time_major( + policy, policy.value_function, drop_last=policy.config["vtrace"]) + + return { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "policy_loss": policy.loss.pi_loss, + "entropy": policy.loss.entropy, + "var_gnorm": tf.global_norm(policy.var_list), + "vf_loss": policy.loss.vf_loss, + "vf_explained_var": explained_variance( + tf.reshape(policy.loss.value_targets, [-1]), + tf.reshape(values_batched, [-1])), + } + + +def grad_stats(policy, grads): + return { + "grad_gnorm": tf.global_norm(grads), + } + + +def postprocess_trajectory(policy, + sample_batch, + other_agent_batches=None, + episode=None): + if not policy.config["vtrace"]: + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 else: - actions = ModelCatalog.get_action_placeholder(action_space) - dones = tf.placeholder(tf.bool, [None], name="dones") - rewards = tf.placeholder(tf.float32, [None], name="rewards") - behaviour_logits = tf.placeholder( - tf.float32, [None, logit_dim], name="behaviour_logits") - observations = tf.placeholder( - tf.float32, [None] + list(observation_space.shape)) - existing_state_in = None - existing_seq_lens = None - - if not self.config["vtrace"]: - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - value_targets = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - self.observations = observations - - # Unpack behaviour logits - unpacked_behaviour_logits = tf.split( - behaviour_logits, output_hidden_shape, axis=1) - - # Setup the policy - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward") - self.model = ModelCatalog.get_model( - { - "obs": observations, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - unpacked_outputs = tf.split( - self.model.outputs, output_hidden_shape, axis=1) - - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ - behaviour_logits - - action_dist = dist_class(dist_inputs) - prev_action_dist = dist_class(prev_dist_inputs) - - values = self.model.value_function() - self.value_function = values + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy.value(sample_batch["new_obs"][-1], *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + else: + batch = sample_batch + del batch.data["new_obs"] # not used, so save some bandwidth + return batch + + +def add_values_and_logits(policy): + out = {BEHAVIOUR_LOGITS: policy.model.outputs} + if not policy.config["vtrace"]: + out[SampleBatch.VF_PREDS] = policy.value_function + return out + + +def validate_config(policy, obs_space, action_space, config): + assert config["batch_mode"] == "truncate_episodes", \ + "Must use `truncate_episodes` batch mode with V-trace." 
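A quick numeric sketch of what `_make_time_major` does in the non-recurrent case, where T is taken from `config["sample_batch_size"]` and B is inferred from the batch length; the numpy analogue below is illustrative only (the actual code operates on TF tensors):

    import numpy as np

    T = 4                               # config["sample_batch_size"]
    flat = np.arange(12)                # a flat [B * T] column, e.g. per-step rewards
    B = flat.shape[0] // T              # 3 trajectories in this batch
    time_major = flat.reshape(B, T).T   # shape [T, B]: row t holds timestep t of every trajectory
    bootstrap_values = time_major[-1]   # make_time_major(values)[-1] supplies the bootstrap value
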
+ + +def choose_optimizer(policy, config): + if policy.config["opt_type"] == "adam": + return tf.train.AdamOptimizer(policy.cur_lr) + else: + return tf.train.RMSPropOptimizer(policy.cur_lr, config["decay"], + config["momentum"], config["epsilon"]) + + +def clip_gradients(policy, optimizer, loss): + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + + +class ValueNetworkMixin(object): + def __init__(self): + self.value_function = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name) - def make_time_major(tensor, drop_last=False): - """Swaps batch and trajectory axis. - Args: - tensor: A tensor or list of tensors to reshape. - drop_last: A bool indicating whether to drop the last - trajectory item. - Returns: - res: A tensor with swapped axes or a list of tensors with - swapped axes. - """ - if isinstance(tensor, list): - return [make_time_major(t, drop_last) for t in tensor] - - if self.model.state_init: - B = tf.shape(self.model.seq_lens)[0] - T = tf.shape(tensor)[0] // B - else: - # Important: chop the tensor into batches at known episode cut - # boundaries. TODO(ekl) this is kind of a hack - T = self.config["sample_batch_size"] - B = tf.shape(tensor)[0] // T - rs = tf.reshape(tensor, - tf.concat([[B, T], tf.shape(tensor)[1:]], axis=0)) - - # swap B and T axes - res = tf.transpose( - rs, - [1, 0] + list(range(2, 1 + int(tf.shape(tensor).shape[0])))) - - if drop_last: - return res[:-1] - return res - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - 1 - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(rewards) - - # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
- if self.config["vtrace"]: - logger.info("Using V-Trace surrogate loss (vtrace=True)") - - # Prepare actions for loss - loss_actions = actions if is_multidiscrete else tf.expand_dims( - actions, axis=1) - - self.loss = VTraceSurrogateLoss( - actions=make_time_major(loss_actions, drop_last=True), - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions), drop_last=True), - actions_logp=make_time_major( - action_dist.logp(actions), drop_last=True), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major( - action_dist.entropy(), drop_last=True), - dones=make_time_major(dones, drop_last=True), - behaviour_logits=make_time_major( - unpacked_behaviour_logits, drop_last=True), - target_logits=make_time_major( - unpacked_outputs, drop_last=True), - discount=config["gamma"], - rewards=make_time_major(rewards, drop_last=True), - values=make_time_major(values, drop_last=True), - bootstrap_value=make_time_major(values)[-1], - dist_class=dist_class, - valid_mask=make_time_major(mask, drop_last=True), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - clip_pg_rho_threshold=self.config[ - "vtrace_clip_pg_rho_threshold"], - clip_param=self.config["clip_param"]) - else: - logger.info("Using PPO surrogate loss (vtrace=False)") - self.loss = PPOSurrogateLoss( - prev_actions_logp=make_time_major( - prev_action_dist.logp(actions)), - actions_logp=make_time_major(action_dist.logp(actions)), - action_kl=prev_action_dist.kl(action_dist), - actions_entropy=make_time_major(action_dist.entropy()), - values=make_time_major(values), - valid_mask=make_time_major(mask), - advantages=make_time_major(adv_ph), - value_targets=make_time_major(value_targets), - vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"]) - - # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - } - - # Initialize TFPolicyGraph - loss_in = [ - ("actions", actions), - ("dones", dones), - ("behaviour_logits", behaviour_logits), - ("rewards", rewards), - ("obs", observations), - ("prev_actions", prev_actions), - ("prev_rewards", prev_rewards), - ] - if not self.config["vtrace"]: - loss_in.append(("advantages", adv_ph)) - loss_in.append(("value_targets", value_targets)) - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=observations, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=self.loss.total_loss, - model=self.model, - loss_inputs=loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions, - prev_reward_input=prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"], - batch_divisibility_req=self.config["sample_batch_size"]) - - self.sess.run(tf.global_variables_initializer()) - - values_batched = make_time_major( 
- values, drop_last=self.config["vtrace"]) - self.stats_fetches = { - LEARNER_STATS_KEY: dict({ - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.loss.pi_loss, - "entropy": self.loss.entropy, - "grad_gnorm": tf.global_norm(self._grads), - "var_gnorm": tf.global_norm(self.var_list), - "vf_loss": self.loss.vf_loss, - "vf_explained_var": explained_variance( - tf.reshape(self.loss.value_targets, [-1]), - tf.reshape(values_batched, [-1])), - }, **self.KL_stats), - } - - def optimizer(self): - if self.config["opt_type"] == "adam": - return tf.train.AdamOptimizer(self.cur_lr) - else: - return tf.train.RMSPropOptimizer(self.cur_lr, self.config["decay"], - self.config["momentum"], - self.config["epsilon"]) - - def gradients(self, optimizer, loss): - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - def extra_compute_grad_fetches(self): - return self.stats_fetches - def value(self, ob, *args): - feed_dict = {self.observations: [ob], self.model.seq_lens: [1]} + feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]} assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) + vf = self._sess.run(self.value_function, feed_dict) return vf[0] - def get_initial_state(self): - return self.model.state_init - def copy(self, existing_inputs): - return AsyncPPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) +def setup_mixins(policy, obs_space, action_space, config): + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ValueNetworkMixin.__init__(policy) + + +AsyncPPOTFPolicy = build_tf_policy( + name="AsyncPPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG, + loss_fn=build_appo_surrogate_loss, + stats_fn=stats, + grad_stats_fn=grad_stats, + postprocess_fn=postprocess_trajectory, + optimizer_fn=choose_optimizer, + gradients_fn=clip_gradients, + extra_action_fetches_fn=add_values_and_logits, + before_init=validate_config, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, ValueNetworkMixin], + get_batch_divisibility_req=lambda p: p.config["sample_batch_size"]) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 8f69c91149e7..d3f5abdaa95c 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -4,10 +4,10 @@ import logging -from ray.rllib.agents import Trainer, with_common_config -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents import with_common_config +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer -from ray.rllib.utils.annotations import override logger = logging.getLogger(__name__) @@ -63,110 +63,104 @@ # yapf: enable -class PPOTrainer(Trainer): - """Multi-GPU optimized implementation of PPO in TensorFlow.""" - - _name = "PPO" - _default_config = DEFAULT_CONFIG - _policy_graph = PPOPolicyGraph - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) - self.remote_evaluators = 
self.make_remote_evaluators( - env_creator, self._policy_graph, config["num_workers"]) - if config["simple_optimizer"]: - self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, - self.remote_evaluators, - num_sgd_iter=config["num_sgd_iter"], - train_batch_size=config["train_batch_size"]) - else: - self.optimizer = LocalMultiGPUOptimizer( - self.local_evaluator, - self.remote_evaluators, - sgd_batch_size=config["sgd_minibatch_size"], - num_sgd_iter=config["num_sgd_iter"], - num_gpus=config["num_gpus"], - sample_batch_size=config["sample_batch_size"], - num_envs_per_worker=config["num_envs_per_worker"], - train_batch_size=config["train_batch_size"], - standardize_fields=["advantages"], - straggler_mitigation=config["straggler_mitigation"]) - - @override(Trainer) - def _train(self): - if "observation_filter" not in self.raw_user_config: - # TODO(ekl) remove this message after a few releases - logger.info( - "Important! Since 0.7.0, observation normalization is no " - "longer enabled by default. To enable running-mean " - "normalization, set 'observation_filter': 'MeanStdFilter'. " - "You can ignore this message if your environment doesn't " - "require observation normalization.") - prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - if "kl" in fetches: - # single-agent - self.local_evaluator.for_policy( - lambda pi: pi.update_kl(fetches["kl"])) - else: - - def update(pi, pi_id): - if pi_id in fetches: - pi.update_kl(fetches[pi_id]["kl"]) - else: - logger.debug( - "No data for {}, not updating kl".format(pi_id)) - - # multi-agent - self.local_evaluator.foreach_trainable_policy(update) - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=res.get("info", {})) - - # Warn about bad clipping configs - if self.config["vf_clip_param"] <= 0: - rew_scale = float("inf") - elif res["policy_reward_mean"]: - rew_scale = 0 # punt on handling multiagent case - else: - rew_scale = round( - abs(res["episode_reward_mean"]) / self.config["vf_clip_param"], - 0) - if rew_scale > 200: - logger.warning( - "The magnitude of your environment rewards are more than " - "{}x the scale of `vf_clip_param`. ".format(rew_scale) + - "This means that it will take more than " - "{} iterations for your value ".format(rew_scale) + - "function to converge. If this is not intended, consider " - "increasing `vf_clip_param`.") - return res - - def _validate_config(self): - if self.config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - if self.config["sgd_minibatch_size"] > self.config["train_batch_size"]: - raise ValueError( - "Minibatch size {} must be <= train batch size {}.".format( - self.config["sgd_minibatch_size"], - self.config["train_batch_size"])) - if (self.config["batch_mode"] == "truncate_episodes" - and not self.config["use_gae"]): - raise ValueError( - "Episode truncation is not supported without a value " - "function. Consider setting batch_mode=complete_episodes.") - if (self.config["multiagent"]["policy_graphs"] - and not self.config["simple_optimizer"]): - logger.info( - "In multi-agent mode, policies will be optimized sequentially " - "by the multi-GPU optimizer. 
Consider setting " - "simple_optimizer=True if this doesn't work for you.") - if not self.config["vf_share_layers"]: - logger.warning( - "FYI: By default, the value function will not share layers " - "with the policy model ('vf_share_layers': False).") +def make_optimizer(local_evaluator, remote_evaluators, config): + if config["simple_optimizer"]: + return SyncSamplesOptimizer( + local_evaluator, + remote_evaluators, + num_sgd_iter=config["num_sgd_iter"], + train_batch_size=config["train_batch_size"]) + + return LocalMultiGPUOptimizer( + local_evaluator, + remote_evaluators, + sgd_batch_size=config["sgd_minibatch_size"], + num_sgd_iter=config["num_sgd_iter"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + num_envs_per_worker=config["num_envs_per_worker"], + train_batch_size=config["train_batch_size"], + standardize_fields=["advantages"], + straggler_mitigation=config["straggler_mitigation"]) + + +def update_kl(trainer, fetches): + if "kl" in fetches: + # single-agent + trainer.local_evaluator.for_policy( + lambda pi: pi.update_kl(fetches["kl"])) + else: + + def update(pi, pi_id): + if pi_id in fetches: + pi.update_kl(fetches[pi_id]["kl"]) + else: + logger.debug("No data for {}, not updating kl".format(pi_id)) + + # multi-agent + trainer.local_evaluator.foreach_trainable_policy(update) + + +def warn_about_obs_filter(trainer): + if "observation_filter" not in trainer.raw_user_config: + # TODO(ekl) remove this message after a few releases + logger.info( + "Important! Since 0.7.0, observation normalization is no " + "longer enabled by default. To enable running-mean " + "normalization, set 'observation_filter': 'MeanStdFilter'. " + "You can ignore this message if your environment doesn't " + "require observation normalization.") + + +def warn_about_bad_reward_scales(trainer, result): + # Warn about bad clipping configs + if trainer.config["vf_clip_param"] <= 0: + rew_scale = float("inf") + elif result["policy_reward_mean"]: + rew_scale = 0 # punt on handling multiagent case + else: + rew_scale = round( + abs(result["episode_reward_mean"]) / + trainer.config["vf_clip_param"], 0) + if rew_scale > 200: + logger.warning( + "The magnitude of your environment rewards are more than " + "{}x the scale of `vf_clip_param`. ".format(rew_scale) + + "This means that it will take more than " + "{} iterations for your value ".format(rew_scale) + + "function to converge. If this is not intended, consider " + "increasing `vf_clip_param`.") + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sgd_minibatch_size"] > config["train_batch_size"]: + raise ValueError( + "Minibatch size {} must be <= train batch size {}.".format( + config["sgd_minibatch_size"], config["train_batch_size"])) + if (config["batch_mode"] == "truncate_episodes" and not config["use_gae"]): + raise ValueError( + "Episode truncation is not supported without a value " + "function. Consider setting batch_mode=complete_episodes.") + if (config["multiagent"]["policy_graphs"] + and not config["simple_optimizer"]): + logger.info( + "In multi-agent mode, policies will be optimized sequentially " + "by the multi-GPU optimizer. 
Consider setting " + "simple_optimizer=True if this doesn't work for you.") + if not config["vf_share_layers"]: + logger.warning( + "FYI: By default, the value function will not share layers " + "with the policy model ('vf_share_layers': False).") + + +PPOTrainer = build_trainer( + name="PPO", + default_config=DEFAULT_CONFIG, + default_policy=PPOTFPolicy, + make_policy_optimizer=make_optimizer, + validate_config=validate_config, + after_optimizer_step=update_kl, + before_train_step=warn_about_obs_filter, + after_train_result=warn_about_bad_reward_scales) diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index 61aced1db740..334ca788c936 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -7,13 +7,10 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ - LearningRateSchedule +from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule +from ray.rllib.evaluation.tf_policy_template import build_tf_policy from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -107,119 +104,106 @@ def reduce_mean_valid(t): self.loss = loss -class PPOPostprocessing(object): +def ppo_surrogate_loss(policy, batch_tensors): + if policy.model.state_in: + max_seq_len = tf.reduce_max(policy.model.seq_lens) + mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + mask = tf.reshape(mask, [-1]) + else: + mask = tf.ones_like( + batch_tensors[Postprocessing.ADVANTAGES], dtype=tf.bool) + + policy.loss_obj = PPOLoss( + policy.action_space, + batch_tensors[Postprocessing.VALUE_TARGETS], + batch_tensors[Postprocessing.ADVANTAGES], + batch_tensors[SampleBatch.ACTIONS], + batch_tensors[BEHAVIOUR_LOGITS], + batch_tensors[SampleBatch.VF_PREDS], + policy.action_dist, + policy.value_function, + policy.kl_coeff, + mask, + entropy_coeff=policy.config["entropy_coeff"], + clip_param=policy.config["clip_param"], + vf_clip_param=policy.config["vf_clip_param"], + vf_loss_coeff=policy.config["vf_loss_coeff"], + use_gae=policy.config["use_gae"]) + + return policy.loss_obj.loss + + +def kl_and_loss_stats(policy, batch_tensors): + policy.explained_variance = explained_variance( + batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function) + + stats_fetches = { + "cur_kl_coeff": policy.kl_coeff, + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "vf_explained_var": policy.explained_variance, + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + return stats_fetches + + +def vf_preds_and_logits_fetches(policy): + """Adds value function and logits outputs to experience batches.""" + return { + SampleBatch.VF_PREDS: policy.value_function, + BEHAVIOUR_LOGITS: policy.model.outputs, + } + + +def postprocess_ppo_gae(policy, + sample_batch, + other_agent_batches=None, + episode=None): """Adds the policy logits, VF preds, and advantages to the trajectory.""" - @override(TFPolicyGraph) - def 
extra_compute_action_fetches(self): - return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ - SampleBatch.VF_PREDS: self.value_function, - BEHAVIOUR_LOGITS: self.logits - }) - - @override(PolicyGraph) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch["dones"][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - batch = compute_advantages( - sample_batch, - last_r, - self.config["gamma"], - self.config["lambda"], - use_gae=self.config["use_gae"]) - return batch - - -class PPOPolicyGraph(LearningRateSchedule, PPOPostprocessing, TFPolicyGraph): - def __init__(self, - observation_space, - action_space, - config, - existing_inputs=None): - """ - Arguments: - observation_space: Environment observation space specification. - action_space: Environment action space specification. - config (dict): Configuration values for PPO graph. - existing_inputs (list): Optional list of tuples that specify the - placeholders upon which the graph should be built upon. - """ - config = dict(ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, **config) - self.sess = tf.get_default_session() - self.action_space = action_space - self.config = config - self.kl_coeff_val = self.config["kl_coeff"] - self.kl_target = self.config["kl_target"] - dist_cls, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - - if existing_inputs: - obs_ph, value_targets_ph, adv_ph, act_ph, \ - logits_ph, vf_preds_ph, prev_actions_ph, prev_rewards_ph = \ - existing_inputs[:8] - existing_state_in = existing_inputs[8:-1] - existing_seq_lens = existing_inputs[-1] - else: - obs_ph = tf.placeholder( - tf.float32, - name="obs", - shape=(None, ) + observation_space.shape) - adv_ph = tf.placeholder( - tf.float32, name="advantages", shape=(None, )) - act_ph = ModelCatalog.get_action_placeholder(action_space) - logits_ph = tf.placeholder( - tf.float32, name="logits", shape=(None, logit_dim)) - vf_preds_ph = tf.placeholder( - tf.float32, name="vf_preds", shape=(None, )) - value_targets_ph = tf.placeholder( - tf.float32, name="value_targets", shape=(None, )) - prev_actions_ph = ModelCatalog.get_action_placeholder(action_space) - prev_rewards_ph = tf.placeholder( - tf.float32, [None], name="prev_reward") - existing_state_in = None - existing_seq_lens = None - self.observations = obs_ph - self.prev_actions = prev_actions_ph - self.prev_rewards = prev_rewards_ph - - self.loss_in = [ - (SampleBatch.CUR_OBS, obs_ph), - (Postprocessing.VALUE_TARGETS, value_targets_ph), - (Postprocessing.ADVANTAGES, adv_ph), - (SampleBatch.ACTIONS, act_ph), - (BEHAVIOUR_LOGITS, logits_ph), - (SampleBatch.VF_PREDS, vf_preds_ph), - (SampleBatch.PREV_ACTIONS, prev_actions_ph), - (SampleBatch.PREV_REWARDS, prev_rewards_ph), - ] - self.model = ModelCatalog.get_model( - { - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, - "is_training": self._get_is_training_placeholder(), - }, - observation_space, - action_space, - logit_dim, - self.config["model"], - state_in=existing_state_in, - seq_lens=existing_seq_lens) - + completed = sample_batch["dones"][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + 
next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + batch = compute_advantages( + sample_batch, + last_r, + policy.config["gamma"], + policy.config["lambda"], + use_gae=policy.config["use_gae"]) + return batch + + +def clip_gradients(policy, optimizer, loss): + if policy.config["grad_clip"] is not None: + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, + policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + else: + return optimizer.compute_gradients( + loss, colocate_gradients_with_ops=True) + + +class KLCoeffMixin(object): + def __init__(self, config): # KL Coefficient + self.kl_coeff_val = config["kl_coeff"] + self.kl_target = config["kl_target"] self.kl_coeff = tf.get_variable( initializer=tf.constant_initializer(self.kl_coeff_val), name="kl_coeff", @@ -227,14 +211,22 @@ def __init__(self, trainable=False, dtype=tf.float32) - self.logits = self.model.outputs - curr_action_dist = dist_cls(self.logits) - self.sampler = curr_action_dist.sample() - if self.config["use_gae"]: - if self.config["vf_share_layers"]: + def update_kl(self, sampled_kl): + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeff_val *= 1.5 + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeff_val *= 0.5 + self.kl_coeff.load(self.kl_coeff_val, session=self._sess) + return self.kl_coeff_val + + +class ValueNetworkMixin(object): + def __init__(self, obs_space, action_space, config): + if config["use_gae"]: + if config["vf_share_layers"]: self.value_function = self.model.value_function() else: - vf_config = self.config["model"].copy() + vf_config = config["model"].copy() # Do not split the last layer of the value function into # mean parameters and standard deviation parameters and # do not make the standard deviations free variables. 
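For reference, `update_kl` above is a simple multiplicative controller on the KL penalty coefficient; a plain-Python sketch of the same rule (the `kl_target` value is illustrative):

    def update_kl_reference(kl_coeff, sampled_kl, kl_target=0.01):
        # Mirrors KLCoeffMixin.update_kl: raise the penalty when the measured KL
        # overshoots the target, relax it when the policy barely moved.
        if sampled_kl > 2.0 * kl_target:
            kl_coeff *= 1.5
        elif sampled_kl < 0.5 * kl_target:
            kl_coeff *= 0.5
        return kl_coeff
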
@@ -249,122 +241,43 @@ def __init__(self, "value_function() method.") with tf.variable_scope("value_function"): self.value_function = ModelCatalog.get_model({ - "obs": obs_ph, - "prev_actions": prev_actions_ph, - "prev_rewards": prev_rewards_ph, + "obs": self._obs_input, + "prev_actions": self._prev_action_input, + "prev_rewards": self._prev_reward_input, "is_training": self._get_is_training_placeholder(), - }, observation_space, action_space, 1, vf_config).outputs + }, obs_space, action_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: - self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) - - if self.model.state_in: - max_seq_len = tf.reduce_max(self.model.seq_lens) - mask = tf.sequence_mask(self.model.seq_lens, max_seq_len) - mask = tf.reshape(mask, [-1]) - else: - mask = tf.ones_like(adv_ph, dtype=tf.bool) - - self.loss_obj = PPOLoss( - action_space, - value_targets_ph, - adv_ph, - act_ph, - logits_ph, - vf_preds_ph, - curr_action_dist, - self.value_function, - self.kl_coeff, - mask, - entropy_coeff=self.config["entropy_coeff"], - clip_param=self.config["clip_param"], - vf_clip_param=self.config["vf_clip_param"], - vf_loss_coeff=self.config["vf_loss_coeff"], - use_gae=self.config["use_gae"]) - - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicyGraph.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=obs_ph, - action_sampler=self.sampler, - action_prob=curr_action_dist.sampled_action_prob(), - loss=self.loss_obj.loss, - model=self.model, - loss_inputs=self.loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=prev_actions_ph, - prev_reward_input=prev_rewards_ph, - seq_lens=self.model.seq_lens, - max_seq_len=config["model"]["max_seq_len"]) - - self.sess.run(tf.global_variables_initializer()) - self.explained_variance = explained_variance(value_targets_ph, - self.value_function) - self.stats_fetches = { - "cur_kl_coeff": self.kl_coeff, - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "total_loss": self.loss_obj.loss, - "policy_loss": self.loss_obj.mean_policy_loss, - "vf_loss": self.loss_obj.mean_vf_loss, - "vf_explained_var": self.explained_variance, - "kl": self.loss_obj.mean_kl, - "entropy": self.loss_obj.mean_entropy - } - - @override(TFPolicyGraph) - def copy(self, existing_inputs): - """Creates a copy of self using existing input placeholders.""" - return PPOPolicyGraph( - self.observation_space, - self.action_space, - self.config, - existing_inputs=existing_inputs) - - @override(TFPolicyGraph) - def gradients(self, optimizer, loss): - if self.config["grad_clip"] is not None: - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, - self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - else: - return optimizer.compute_gradients( - loss, colocate_gradients_with_ops=True) - - @override(PolicyGraph) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicyGraph) - def extra_compute_grad_fetches(self): - return {LEARNER_STATS_KEY: self.stats_fetches} - - def update_kl(self, sampled_kl): - if sampled_kl > 2.0 * self.kl_target: - self.kl_coeff_val *= 1.5 - elif sampled_kl < 0.5 * self.kl_target: - self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self.sess) - return 
self.kl_coeff_val + self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1]) def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self.observations: [ob], - self.prev_actions: [prev_action], - self.prev_rewards: [prev_reward], + self._obs_input: [ob], + self._prev_action_input: [prev_action], + self._prev_reward_input: [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.value_function, feed_dict) + vf = self._sess.run(self.value_function, feed_dict) return vf[0] + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + + +PPOTFPolicy = build_tf_policy( + name="PPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + stats_fn=kl_and_loss_stats, + extra_action_fetches_fn=vf_preds_and_logits_fetches, + postprocess_fn=postprocess_ppo_gae, + gradients_fn=clip_gradients, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin]) diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py new file mode 100644 index 000000000000..618bc3b30ace --- /dev/null +++ b/python/ray/rllib/agents/trainer_template.py @@ -0,0 +1,97 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.agents.trainer import Trainer +from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_trainer(name, + default_policy, + default_config=None, + make_policy_optimizer=None, + validate_config=None, + get_policy_class=None, + before_train_step=None, + after_optimizer_step=None, + after_train_result=None): + """Helper function for defining a custom trainer. + + Arguments: + name (str): name of the trainer (e.g., "PPO") + default_policy (cls): the default PolicyGraph class to use + default_config (dict): the default config dict of the algorithm, + otherwise uses the Trainer default config + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given + (local_evaluator, remote_evaluators, config) + validate_config (func): optional callback that checks a given config + for correctness. It may mutate the config as needed. + get_policy_class (func): optional callback that takes a config and + returns the policy graph class to override the default with + before_train_step (func): optional callback to run before each train() + call. It takes the trainer instance as an argument. + after_optimizer_step (func): optional callback to run after each + step() call to the policy optimizer. It takes the trainer instance + and the policy gradient fetches as arguments. + after_train_result (func): optional callback to run at the end of each + train() call. It takes the trainer instance and result dict as + arguments, and may mutate the result dict as needed. + + Returns: + a Trainer instance that uses the specified args. 
+ """ + + if name.endswith("Trainer"): + raise ValueError("Algorithm name should not include *Trainer suffix", + name) + + class trainer_cls(Trainer): + _name = name + _default_config = default_config or Trainer.COMMON_CONFIG + _policy_graph = default_policy + + def _init(self, config, env_creator): + if validate_config: + validate_config(config) + if get_policy_class is None: + policy_graph = default_policy + else: + policy_graph = get_policy_class(config) + self.local_evaluator = self.make_local_evaluator( + env_creator, policy_graph) + self.remote_evaluators = self.make_remote_evaluators( + env_creator, policy_graph, config["num_workers"]) + if make_policy_optimizer: + self.optimizer = make_policy_optimizer( + self.local_evaluator, self.remote_evaluators, config) + else: + optimizer_config = dict( + config["optimizer"], + **{"train_batch_size": config["train_batch_size"]}) + self.optimizer = SyncSamplesOptimizer(self.local_evaluator, + self.remote_evaluators, + **optimizer_config) + + @override(Trainer) + def _train(self): + if before_train_step: + before_train_step(self) + prev_steps = self.optimizer.num_steps_sampled + fetches = self.optimizer.step() + if after_optimizer_step: + after_optimizer_step(self, fetches) + res = self.collect_metrics() + res.update( + timesteps_this_iter=self.optimizer.num_steps_sampled - + prev_steps, + info=res.get("info", {})) + if after_train_result: + after_train_result(self, res) + return res + + trainer_cls.__name__ = name + "Trainer" + trainer_cls.__qualname__ = name + "Trainer" + return trainer_cls diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py new file mode 100644 index 000000000000..73e08fcf9093 --- /dev/null +++ b/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py @@ -0,0 +1,275 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict +import logging +import numpy as np + +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf +from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils.tracking_dict import UsageTrackingDict + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + + +class DynamicTFPolicyGraph(TFPolicyGraph): + """A TFPolicyGraph that auto-defines placeholders dynamically at runtime. + + Initialization of this class occurs in two phases. + * Phase 1: the model is created and model variables are initialized. + * Phase 2: a fake batch of data is created, sent to the trajectory + postprocessor, and then used to create placeholders for the loss + function. The loss and stats functions are initialized with these + placeholders. + """ + + def __init__(self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=None, + grad_stats_fn=None, + before_loss_init=None, + make_action_sampler=None, + existing_inputs=None, + get_batch_divisibility_req=None): + """Initialize a dynamic TF policy graph. + + Arguments: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + config (dict): Policy-specific configuration data. 
+ loss_fn (func): function that returns a loss tensor the policy + graph, and dict of experience tensor placeholders + stats_fn (func): optional function that returns a dict of + TF fetches given the policy graph and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy graph and loss gradient tensors + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as __init__ + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + existing_inputs (OrderedDict): when copying a policy graph, this + specifies an existing dict of placeholders to use instead of + defining new ones + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + """ + self.config = config + self._loss_fn = loss_fn + self._stats_fn = stats_fn + self._grad_stats_fn = grad_stats_fn + + # Setup standard placeholders + if existing_inputs is not None: + obs = existing_inputs[SampleBatch.CUR_OBS] + prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] + prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + else: + obs = tf.placeholder( + tf.float32, + shape=[None] + list(obs_space.shape), + name="observation") + prev_actions = ModelCatalog.get_action_placeholder(action_space) + prev_rewards = tf.placeholder( + tf.float32, [None], name="prev_reward") + + input_dict = { + "obs": obs, + "prev_actions": prev_actions, + "prev_rewards": prev_rewards, + "is_training": self._get_is_training_placeholder(), + } + + # Create the model network and action outputs + if make_action_sampler: + assert not existing_inputs, \ + "Cloning not supported with custom action sampler" + self.model = None + self.dist_class = None + self.action_dist = None + action_sampler, action_prob = make_action_sampler( + self, input_dict, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"]) + if existing_inputs: + existing_state_in = [ + v for k, v in existing_inputs.items() + if k.startswith("state_in_") + ] + if existing_state_in: + existing_seq_lens = existing_inputs["seq_lens"] + else: + existing_seq_lens = None + else: + existing_state_in = [] + existing_seq_lens = None + self.model = ModelCatalog.get_model( + input_dict, + obs_space, + action_space, + logit_dim, + self.config["model"], + state_in=existing_state_in, + seq_lens=existing_seq_lens) + self.action_dist = self.dist_class(self.model.outputs) + action_sampler = self.action_dist.sample() + action_prob = self.action_dist.sampled_action_prob() + + # Phase 1 init + sess = tf.get_default_session() + if get_batch_divisibility_req: + batch_divisibility_req = get_batch_divisibility_req(self) + else: + batch_divisibility_req = 1 + TFPolicyGraph.__init__( + self, + obs_space, + action_space, + sess, + obs_input=obs, + action_sampler=action_sampler, + action_prob=action_prob, + loss=None, # dynamically initialized on run + loss_inputs=[], + model=self.model, + state_inputs=self.model and self.model.state_in, + state_outputs=self.model and self.model.state_out, + prev_action_input=prev_actions, + prev_reward_input=prev_rewards, + seq_lens=self.model and self.model.seq_lens, + max_seq_len=config["model"]["max_seq_len"], + batch_divisibility_req=batch_divisibility_req) + + # Phase 2 init + before_loss_init(self, 
obs_space, action_space, config) + if not existing_inputs: + self._initialize_loss() + + @override(TFPolicyGraph) + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders.""" + + # Note that there might be RNN state inputs at the end of the list + if self._state_inputs: + num_state_inputs = len(self._state_inputs) + 1 + else: + num_state_inputs = 0 + if len(self._loss_inputs) + num_state_inputs != len(existing_inputs): + raise ValueError("Tensor list mismatch", self._loss_inputs, + self._state_inputs, existing_inputs) + for i, (k, v) in enumerate(self._loss_inputs): + if v.shape.as_list() != existing_inputs[i].shape.as_list(): + raise ValueError("Tensor shape mismatch", i, k, v.shape, + existing_inputs[i].shape) + # By convention, the loss inputs are followed by state inputs and then + # the seq len tensor + rnn_inputs = [] + for i in range(len(self._state_inputs)): + rnn_inputs.append(("state_in_{}".format(i), + existing_inputs[len(self._loss_inputs) + i])) + if rnn_inputs: + rnn_inputs.append(("seq_lens", existing_inputs[-1])) + input_dict = OrderedDict( + [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)] + rnn_inputs) + instance = self.__class__( + self.observation_space, + self.action_space, + self.config, + existing_inputs=input_dict) + loss = instance._loss_fn(instance, input_dict) + if instance._stats_fn: + instance._stats_fetches.update( + instance._stats_fn(instance, input_dict)) + TFPolicyGraph._initialize_loss( + instance, loss, [(k, existing_inputs[i]) + for i, (k, _) in enumerate(self._loss_inputs)]) + if instance._grad_stats_fn: + instance._stats_fetches.update( + instance._grad_stats_fn(instance, instance._grads)) + return instance + + @override(PolicyGraph) + def get_initial_state(self): + if self.model: + return self.model.state_init + else: + return [] + + def _initialize_loss(self): + def fake_array(tensor): + shape = tensor.shape.as_list() + shape[0] = 1 + return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) + + dummy_batch = { + SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), + SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), + SampleBatch.CUR_OBS: fake_array(self._obs_input), + SampleBatch.NEXT_OBS: fake_array(self._obs_input), + SampleBatch.ACTIONS: fake_array(self._prev_action_input), + SampleBatch.REWARDS: np.array([0], dtype=np.float32), + SampleBatch.DONES: np.array([False], dtype=np.bool), + } + state_init = self.get_initial_state() + for i, h in enumerate(state_init): + dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0) + dummy_batch["state_out_{}".format(i)] = np.expand_dims(h, 0) + if state_init: + dummy_batch["seq_lens"] = np.array([1], dtype=np.int32) + for k, v in self.extra_compute_action_fetches().items(): + dummy_batch[k] = fake_array(v) + + # postprocessing might depend on variable init, so run it first here + self._sess.run(tf.global_variables_initializer()) + postprocessed_batch = self.postprocess_trajectory( + SampleBatch(dummy_batch)) + + batch_tensors = UsageTrackingDict({ + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.PREV_ACTIONS, self._prev_action_input), + (SampleBatch.PREV_REWARDS, self._prev_reward_input), + (SampleBatch.CUR_OBS, self._obs_input), + ] + + for k, v in postprocessed_batch.items(): + if k in batch_tensors: + continue + elif v.dtype == np.object: + continue # can't handle arbitrary objects 
in TF + shape = (None, ) + v.shape[1:] + dtype = np.float32 if v.dtype == np.float64 else v.dtype + placeholder = tf.placeholder(dtype, shape=shape, name=k) + batch_tensors[k] = placeholder + + if log_once("loss_init"): + logger.info( + "Initializing loss function with dummy input:\n\n{}\n".format( + summarize(batch_tensors))) + + loss = self._loss_fn(self, batch_tensors) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + for k in sorted(batch_tensors.accessed_keys): + loss_inputs.append((k, batch_tensors[k])) + TFPolicyGraph._initialize_loss(self, loss, loss_inputs) + if self._grad_stats_fn: + self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) + self._sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index f6761122156e..48e19dfcb96e 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -65,7 +65,7 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Create a policy evaluator and using it to collect experiences. >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: gym.make("CartPole-v0"), - ... policy_graph=PGPolicyGraph) + ... policy_graph=PGTFPolicy) >>> print(evaluator.sample()) SampleBatch({ "obs": [[...]], "actions": [[...]], "rewards": [[...]], @@ -76,7 +76,7 @@ class PolicyEvaluator(EvaluatorInterface): ... evaluator_cls=PolicyEvaluator, ... evaluator_args={ ... "env_creator": lambda _: gym.make("CartPole-v0"), - ... "policy_graph": PGPolicyGraph, + ... "policy_graph": PGTFPolicy, ... }, ... num_workers=10) >>> for _ in range(10): optimizer.step() @@ -87,12 +87,12 @@ class PolicyEvaluator(EvaluatorInterface): ... policy_graphs={ ... # Use an ensemble of two policies for car agents ... "car_policy1": - ... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), ... "car_policy2": - ... (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}), + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}), ... # Use a single shared policy for all traffic lights ... "traffic_light_policy": - ... (PGPolicyGraph, Box(...), Discrete(...), {}), + ... (PGTFPolicy, Box(...), Discrete(...), {}), ... }, ... policy_mapping_fn=lambda agent_id: ... 
random.choice(["car_policy1", "car_policy2"]) diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index 2b1eca9e8d5b..b921e6cfb0d1 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -112,24 +112,45 @@ def __init__(self, self._prev_action_input = prev_action_input self._prev_reward_input = prev_reward_input self._sampler = action_sampler - self._loss_inputs = loss_inputs - self._loss_input_dict = dict(self._loss_inputs) self._is_training = self._get_is_training_placeholder() self._action_prob = action_prob self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] - for i, ph in enumerate(self._state_inputs): - self._loss_input_dict["state_in_{}".format(i)] = ph self._seq_lens = seq_lens self._max_seq_len = max_seq_len self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._stats_fetches = {} + + if loss is not None: + self._initialize_loss(loss, loss_inputs) + else: + self._loss = None + + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") + + def _initialize_loss(self, loss, loss_inputs): + self._loss_inputs = loss_inputs + self._loss_input_dict = dict(self._loss_inputs) + for i, ph in enumerate(self._state_inputs): + self._loss_input_dict["state_in_{}".format(i)] = ph if self.model: self._loss = self.model.custom_loss(loss, self._loss_input_dict) - self._stats_fetches = {"model": self.model.custom_stats()} + self._stats_fetches.update({"model": self.model.custom_stats()}) else: self._loss = loss - self._stats_fetches = {} self._optimizer = self.optimizer() self._grads_and_vars = [ @@ -141,9 +162,7 @@ def __init__(self, self._loss, self._sess) # gather update ops for any batch norm layers - if update_ops: - self._update_ops = update_ops - else: + if not self._update_ops: self._update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) if self._update_ops: @@ -153,21 +172,12 @@ def __init__(self, self._apply_op = self.build_apply_op(self._optimizer, self._grads_and_vars) - if len(self._state_inputs) != len(self._state_outputs): - raise ValueError( - "Number of state input and output tensors must match, got: " - "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) - if self._state_inputs and self._seq_lens is None: - raise ValueError( - "seq_lens tensor must be given if state inputs are defined") + if log_once("loss_used"): + logger.debug( + "These tensors were used in the loss_fn:\n\n{}\n".format( + summarize(self._loss_input_dict))) - logger.debug("Created {} with loss inputs: {}".format( - self, self._loss_input_dict)) + self._sess.run(tf.global_variables_initializer()) @override(PolicyGraph) def compute_actions(self, @@ -186,18 +196,21 @@ def 
compute_actions(self, @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @override(PolicyGraph) def apply_gradients(self, gradients): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "apply_gradients") fetches = self._build_apply_gradients(builder, gradients) builder.get(fetches) @override(PolicyGraph) def learn_on_batch(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" builder = TFRunBuilder(self._sess, "learn_on_batch") fetches = self._build_learn_on_batch(builder, postprocessed_batch) return builder.get(fetches) @@ -271,7 +284,10 @@ def extra_compute_grad_fetches(self): @DeveloperAPI def optimizer(self): """TF optimizer to use for policy optimization.""" - return tf.train.AdamOptimizer() + if hasattr(self, "config"): + return tf.train.AdamOptimizer(self.config["lr"]) + else: + return tf.train.AdamOptimizer() @DeveloperAPI def gradients(self, optimizer, loss): diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py new file mode 100644 index 000000000000..b2549e973a65 --- /dev/null +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_tf_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + postprocess_fn=None, + optimizer_fn=None, + gradients_fn=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_action_sampler=None, + mixins=None, + get_batch_divisibility_req=None): + """Helper function for creating a dynamic tf policy at runtime. + + Arguments: + name (str): name of the graph (e.g., "PPOTFPolicy") + loss_fn (func): function that returns a loss tensor given the policy + and a dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + postprocess_fn (func): optional experience postprocessing function + that takes the same args as PolicyGraph.postprocess_trajectory() + optimizer_fn (func): optional function that returns a tf.Optimizer + given the policy and config + gradients_fn (func): optional function that returns a list of gradients + given a tf optimizer and loss tensor. 
If not specified, this + defaults to optimizer.compute_gradients(loss) + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + mixins (list): list of any class mixins for the returned policy class. + These mixins will be applied in order and will have higher + precedence than the DynamicTFPolicyGraph class + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + + Returns: + a DynamicTFPolicyGraph instance that uses the specified args + """ + + if not name.endswith("TFPolicy"): + raise ValueError("Name should match *TFPolicy", name) + + base = DynamicTFPolicyGraph + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class graph_cls(base): + def __init__(self, + obs_space, + action_space, + config, + existing_inputs=None): + if get_default_config: + config = dict(get_default_config(), **config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, + config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + if extra_action_fetches_fn is None: + self._extra_action_fetches = {} + else: + self._extra_action_fetches = extra_action_fetches_fn(self) + + DynamicTFPolicyGraph.__init__( + self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + existing_inputs=existing_inputs) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TFPolicyGraph) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TFPolicyGraph.optimizer(self) + + @override(TFPolicyGraph) + def gradients(self, optimizer, loss): + if gradients_fn: + return gradients_fn(self, optimizer, loss) + else: + return TFPolicyGraph.gradients(self, optimizer, loss) + + @override(TFPolicyGraph) + def extra_compute_action_fetches(self): + return dict( + TFPolicyGraph.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + graph_cls.__name__ = name + graph_cls.__qualname__ = name + return graph_cls diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index fb5c879a1ab8..ccf1b9eeb81d 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -15,6 +15,7 @@ from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.utils.annotations import override +from ray.rllib.utils.tracking_dict import UsageTrackingDict class TorchPolicyGraph(PolicyGraph): @@ -30,7 +31,7 @@ class 
TorchPolicyGraph(PolicyGraph): """ def __init__(self, observation_space, action_space, model, loss, - loss_inputs, action_distribution_cls): + action_distribution_cls): """Build a policy graph from policy and loss torch modules. Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES @@ -42,13 +43,8 @@ def __init__(self, observation_space, action_space, model, loss, model (nn.Module): PyTorch policy module. Given observations as input, this module must return a list of outputs where the first item is action logits, and the rest can be any value. - loss (nn.Module): Loss defined as a PyTorch module. The inputs for - this module are defined by the `loss_inputs` param. This module - returns a single scalar loss. Note that this module should - internally be using the model module. - loss_inputs (list): List of SampleBatch columns that will be - passed to the loss module's forward() function when computing - the loss. For example, ["obs", "action", "advantages"]. + loss (func): Function that takes (policy_graph, batch_tensors) + and returns a single scalar loss. action_distribution_cls (ActionDistribution): Class for action distribution. """ @@ -60,7 +56,6 @@ def __init__(self, observation_space, action_space, model, loss, else torch.device("cpu")) self._model = model.to(self.device) self._loss = loss - self._loss_inputs = loss_inputs self._optimizer = self.optimizer() self._action_dist_cls = action_distribution_cls @@ -87,30 +82,26 @@ def compute_actions(self, @override(PolicyGraph) def learn_on_batch(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + with self.lock: - loss_in = [] - for key in self._loss_inputs: - loss_in.append( - torch.from_numpy(postprocessed_batch[key]).to(self.device)) - loss_out = self._loss(self._model, *loss_in) + loss_out = self._loss(self, batch_tensors) self._optimizer.zero_grad() loss_out.backward() grad_process_info = self.extra_grad_process() self._optimizer.step() - grad_info = self.extra_grad_info() + grad_info = self.extra_grad_info(batch_tensors) grad_info.update(grad_process_info) return {LEARNER_STATS_KEY: grad_info} @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + with self.lock: - loss_in = [] - for key in self._loss_inputs: - loss_in.append( - torch.from_numpy(postprocessed_batch[key]).to(self.device)) - loss_out = self._loss(self._model, *loss_in) + loss_out = self._loss(self, batch_tensors) self._optimizer.zero_grad() loss_out.backward() @@ -125,7 +116,7 @@ def compute_gradients(self, postprocessed_batch): else: grads.append(None) - grad_info = self.extra_grad_info() + grad_info = self.extra_grad_info(batch_tensors) grad_info.update(grad_process_info) return grads, {LEARNER_STATS_KEY: grad_info} @@ -163,11 +154,21 @@ def extra_action_out(self, model_out): model_out (list): Outputs of the policy model module.""" return {} - def extra_grad_info(self): + def extra_grad_info(self, batch_tensors): """Return dict of extra grad info.""" return {} def optimizer(self): """Custom PyTorch optimizer to use.""" - return torch.optim.Adam(self._model.parameters()) + if hasattr(self, "config"): + return torch.optim.Adam( + self._model.parameters(), lr=self.config["lr"]) + else: + return torch.optim.Adam(self._model.parameters()) + + def _lazy_tensor_dict(self, postprocessed_batch): + batch_tensors = UsageTrackingDict(postprocessed_batch) + batch_tensors.set_get_interceptor( + lambda arr: torch.from_numpy(arr).to(self.device)) + 
return batch_tensors diff --git a/python/ray/rllib/evaluation/torch_policy_template.py b/python/ray/rllib/evaluation/torch_policy_template.py new file mode 100644 index 000000000000..7f65c2b963b8 --- /dev/null +++ b/python/ray/rllib/evaluation/torch_policy_template.py @@ -0,0 +1,133 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_torch_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + postprocess_fn=None, + extra_action_out_fn=None, + extra_grad_process_fn=None, + optimizer_fn=None, + before_init=None, + after_init=None, + make_model_and_action_dist=None, + mixins=None): + """Helper function for creating a torch policy at runtime. + + Arguments: + name (str): name of the graph (e.g., "PPOTorchPolicy") + loss_fn (func): function that returns a loss tensor given the policy + and a dict of batch input tensors + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + values given the policy and batch input tensors + postprocess_fn (func): optional experience postprocessing function + that takes the same args as PolicyGraph.postprocess_trajectory() + extra_action_out_fn (func): optional function that returns + a dict of extra values to include in experiences + extra_grad_process_fn (func): optional function that is called after + gradients are computed and returns processing info + optimizer_fn (func): optional function that returns a torch optimizer + given the policy and config + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_model_and_action_dist (func): optional func that takes the same + arguments as policy init and returns a tuple of model instance and + torch action distribution class. If not specified, the default + model and action dist from the catalog will be used + mixins (list): list of any class mixins for the returned policy class. 
+ These mixins will be applied in order and will have higher + precedence than the TorchPolicyGraph class + + Returns: + a TorchPolicyGraph instance that uses the specified args + """ + + if not name.endswith("TorchPolicy"): + raise ValueError("Name should match *TorchPolicy", name) + + base = TorchPolicyGraph + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class graph_cls(base): + def __init__(self, obs_space, action_space, config): + if get_default_config: + config = dict(get_default_config(), **config) + self.config = config + + if before_init: + before_init(self, obs_space, action_space, config) + + if make_model_and_action_dist: + self.model, self.dist_class = make_model_and_action_dist( + self, obs_space, action_space, config) + else: + self.dist_class, logit_dim = ModelCatalog.get_action_dist( + action_space, self.config["model"], torch=True) + self.model = ModelCatalog.get_torch_model( + obs_space, logit_dim, self.config["model"]) + + TorchPolicyGraph.__init__(self, obs_space, action_space, + self.model, loss_fn, self.dist_class) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(PolicyGraph) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TorchPolicyGraph) + def extra_grad_process(self): + if extra_grad_process_fn: + return extra_grad_process_fn(self) + else: + return TorchPolicyGraph.extra_grad_process(self) + + @override(TorchPolicyGraph) + def extra_action_out(self, model_out): + if extra_action_out_fn: + return extra_action_out_fn(self, model_out) + else: + return TorchPolicyGraph.extra_action_out(self, model_out) + + @override(TorchPolicyGraph) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TorchPolicyGraph.optimizer(self) + + @override(TorchPolicyGraph) + def extra_grad_info(self, batch_tensors): + if stats_fn: + return stats_fn(self, batch_tensors) + else: + return TorchPolicyGraph.extra_grad_info(self, batch_tensors) + + graph_cls.__name__ = name + graph_cls.__qualname__ = name + return graph_cls diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 2c18f2bf4b96..1d4257e4eb9d 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -18,7 +18,7 @@ from ray.rllib.agents.dqn.dqn import DQNTrainer from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph from ray.rllib.agents.ppo.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.logger import pretty_print from ray.tune.registry import register_env @@ -39,7 +39,7 @@ # You can also have multiple policy graphs per trainer, but here we just # show one each for PPO and DQN. 
policy_graphs = { - "ppo_policy": (PPOPolicyGraph, obs_space, act_space, {}), + "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}), "dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}), } diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index d892dbe7dbac..8d1bbd4fb54d 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -255,7 +255,7 @@ def optimize(self, sess, batch_index): fetches = {"train": self._train_op} for tower in self._towers: - fetches.update(tower.loss_graph.extra_compute_grad_fetches()) + fetches.update(tower.loss_graph._get_grad_and_stats_fetches()) return sess.run(fetches, feed_dict=feed_dict) diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index 45df865e43ff..de2671e6a932 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -222,6 +222,6 @@ def stats(self): def _averaged(kv): out = {} for k, v in kv.items(): - if v[0] is not None: + if v[0] is not None and not isinstance(v[0], dict): out[k] = np.mean(v) return out diff --git a/python/ray/rllib/tests/test_external_multi_agent_env.py b/python/ray/rllib/tests/test_external_multi_agent_env.py index e5e182b38655..c01e6fa0b7ae 100644 --- a/python/ray/rllib/tests/test_external_multi_agent_env.py +++ b/python/ray/rllib/tests/test_external_multi_agent_env.py @@ -8,7 +8,7 @@ import unittest import ray -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv @@ -67,7 +67,7 @@ def testTrainExternalMultiCartpoleManyPolicies(self): obs_space = single_env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( diff --git a/python/ray/rllib/tests/test_io.py b/python/ray/rllib/tests/test_io.py index 9f92c9107c4e..0706be1019cc 100644 --- a/python/ray/rllib/tests/test_io.py +++ b/python/ray/rllib/tests/test_io.py @@ -15,7 +15,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.offline import IOContext, JsonWriter, JsonReader from ray.rllib.offline.json_writer import _to_json @@ -159,7 +159,7 @@ def testMultiAgent(self): def gen_policy(): obs_space = single_env.observation_space act_space = single_env.action_space - return (PGPolicyGraph, obs_space, act_space, {}) + return (PGTFPolicy, obs_space, act_space, {}) pg = PGTrainer( env="multi_cartpole", diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index eccb9aa82fb8..72130712d555 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -8,7 +8,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph 
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, AsyncGradientsOptimizer) @@ -470,7 +470,7 @@ def get_initial_state(self): self.assertEqual(batch["state_out_0"][1], h) def testReturningModelBasedRolloutsData(self): - class ModelBasedPolicyGraph(PGPolicyGraph): + class ModelBasedPolicyGraph(PGTFPolicy): def compute_actions(self, obs_batch, state_batches, @@ -584,7 +584,7 @@ def _testWithOptimizer(self, optimizer_cls): } else: policies = { - "p1": (PGPolicyGraph, obs_space, act_space, {}), + "p1": (PGTFPolicy, obs_space, act_space, {}), "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), } ev = PolicyEvaluator( @@ -640,7 +640,7 @@ def testTrainMultiCartpoleManyPolicies(self): obs_space = env.observation_space policies = {} for i in range(20): - policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space, + policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) ev = PolicyEvaluator( diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index e4285e42287c..b70bd9a2908e 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -12,7 +12,7 @@ import ray from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv @@ -333,10 +333,10 @@ def testMultiAgentComplexSpaces(self): "multiagent": { "policy_graphs": { "tuple_policy": ( - PGPolicyGraph, TUPLE_SPACE, act_space, + PGTFPolicy, TUPLE_SPACE, act_space, {"model": {"custom_model": "tuple_spy"}}), "dict_policy": ( - PGPolicyGraph, DICT_SPACE, act_space, + PGTFPolicy, DICT_SPACE, act_space, {"model": {"custom_model": "dict_spy"}}), }, "policy_mapping_fn": lambda a: { diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 9c9e6b56b426..5436baeafa90 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -9,7 +9,7 @@ import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph +from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer @@ -240,12 +240,12 @@ def make_sess(): local = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy_graph=PPOTFPolicy, tf_session_creator=make_sess) remotes = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOPolicyGraph, + policy_graph=PPOTFPolicy, tf_session_creator=make_sess) ] return local, remotes diff --git a/python/ray/rllib/utils/tracking_dict.py b/python/ray/rllib/utils/tracking_dict.py new file mode 100644 index 000000000000..c0f145734e78 --- /dev/null +++ b/python/ray/rllib/utils/tracking_dict.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class UsageTrackingDict(dict): + """Dict that tracks which keys have been accessed. 
+ + It can also intercept gets and allow an arbitrary callback to be applied + (i.e., to lazily convert numpy arrays to Tensors). + + We make the simplifying assumption only __getitem__ is used to access + values. + """ + + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self.accessed_keys = set() + self.intercepted_values = {} + self.get_interceptor = None + + def set_get_interceptor(self, fn): + self.get_interceptor = fn + + def __getitem__(self, key): + self.accessed_keys.add(key) + value = dict.__getitem__(self, key) + if self.get_interceptor: + if key not in self.intercepted_values: + self.intercepted_values[key] = self.get_interceptor(value) + value = self.intercepted_values[key] + return value From 02583a8598cd3e0442188ad0978f154ccdcec742 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 20 May 2019 16:46:05 -0700 Subject: [PATCH 024/118] [rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy/ (#4819) This implements some of the renames proposed in #4813 We leave behind backwards-compatibility aliases for *PolicyGraph and SampleBatch. --- doc/source/rllib-algorithms.rst | 2 +- doc/source/rllib-concepts.rst | 22 +- doc/source/rllib-env.rst | 18 +- doc/source/rllib-models.rst | 26 +- doc/source/rllib-offline.rst | 6 +- doc/source/rllib.rst | 4 +- python/ray/rllib/__init__.py | 8 +- python/ray/rllib/agents/a3c/__init__.py | 6 +- python/ray/rllib/agents/a3c/a3c.py | 8 +- ...3c_tf_policy_graph.py => a3c_tf_policy.py} | 26 +- ...ch_policy_graph.py => a3c_torch_policy.py} | 4 +- python/ray/rllib/agents/agent.py | 4 +- python/ray/rllib/agents/ars/__init__.py | 4 +- python/ray/rllib/agents/ars/ars.py | 2 +- python/ray/rllib/agents/ddpg/__init__.py | 6 +- python/ray/rllib/agents/ddpg/ddpg.py | 4 +- .../{ddpg_policy_graph.py => ddpg_policy.py} | 38 +- python/ray/rllib/agents/dqn/__init__.py | 6 +- python/ray/rllib/agents/dqn/dqn.py | 10 +- .../{dqn_policy_graph.py => dqn_policy.py} | 52 +- python/ray/rllib/agents/es/__init__.py | 4 +- python/ray/rllib/agents/es/es.py | 2 +- python/ray/rllib/agents/impala/__init__.py | 4 +- python/ray/rllib/agents/impala/impala.py | 14 +- ...trace_policy_graph.py => vtrace_policy.py} | 35 +- python/ray/rllib/agents/marwil/marwil.py | 8 +- ...arwil_policy_graph.py => marwil_policy.py} | 20 +- python/ray/rllib/agents/pg/__init__.py | 4 +- python/ray/rllib/agents/pg/pg.py | 4 +- .../pg/{pg_policy_graph.py => pg_policy.py} | 4 +- ..._pg_policy_graph.py => torch_pg_policy.py} | 4 +- python/ray/rllib/agents/ppo/__init__.py | 4 +- python/ray/rllib/agents/ppo/appo.py | 6 +- .../{appo_policy_graph.py => appo_policy.py} | 10 +- python/ray/rllib/agents/ppo/ppo.py | 5 +- .../{ppo_policy_graph.py => ppo_policy.py} | 6 +- python/ray/rllib/agents/qmix/qmix.py | 4 +- .../{qmix_policy_graph.py => qmix_policy.py} | 20 +- python/ray/rllib/agents/trainer.py | 49 +- python/ray/rllib/agents/trainer_template.py | 14 +- python/ray/rllib/evaluation/episode.py | 4 +- python/ray/rllib/evaluation/interface.py | 2 +- python/ray/rllib/evaluation/metrics.py | 9 +- .../ray/rllib/evaluation/policy_evaluator.py | 86 ++- python/ray/rllib/evaluation/policy_graph.py | 285 +--------- python/ray/rllib/evaluation/postprocessing.py | 2 +- python/ray/rllib/evaluation/sample_batch.py | 297 +--------- .../rllib/evaluation/sample_batch_builder.py | 4 +- python/ray/rllib/evaluation/sampler.py | 8 +- .../ray/rllib/evaluation/tf_policy_graph.py | 512 +---------------- .../rllib/evaluation/tf_policy_template.py | 40 +- .../rllib/evaluation/torch_policy_graph.py 
| 172 +----- .../rllib/examples/hierarchical_training.py | 2 +- .../ray/rllib/examples/multiagent_cartpole.py | 10 +- .../examples/multiagent_custom_policy.py | 6 +- .../rllib/examples/multiagent_two_trainers.py | 14 +- .../policy_evaluator_custom_workflow.py | 13 +- .../keras_policy_graph.py => keras_policy.py} | 12 +- python/ray/rllib/models/model.py | 2 +- python/ray/rllib/offline/input_reader.py | 2 +- python/ray/rllib/offline/json_reader.py | 2 +- python/ray/rllib/offline/json_writer.py | 2 +- .../ray/rllib/offline/off_policy_estimator.py | 6 +- .../rllib/optimizers/aso_multi_gpu_learner.py | 2 +- .../optimizers/async_replay_optimizer.py | 2 +- python/ray/rllib/optimizers/multi_gpu_impl.py | 2 +- .../rllib/optimizers/multi_gpu_optimizer.py | 10 +- python/ray/rllib/optimizers/rollout.py | 2 +- .../optimizers/sync_batch_replay_optimizer.py | 2 +- .../rllib/optimizers/sync_replay_optimizer.py | 2 +- .../optimizers/sync_samples_optimizer.py | 2 +- python/ray/rllib/policy/__init__.py | 17 + .../dynamic_tf_policy.py} | 28 +- python/ray/rllib/policy/policy.py | 291 ++++++++++ python/ray/rllib/policy/sample_batch.py | 296 ++++++++++ python/ray/rllib/policy/tf_policy.py | 513 ++++++++++++++++++ python/ray/rllib/policy/tf_policy_template.py | 146 +++++ python/ray/rllib/policy/torch_policy.py | 173 ++++++ .../torch_policy_template.py | 36 +- python/ray/rllib/rollout.py | 2 +- python/ray/rllib/tests/test_evaluators.py | 2 +- python/ray/rllib/tests/test_external_env.py | 14 +- .../tests/test_external_multi_agent_env.py | 16 +- python/ray/rllib/tests/test_io.py | 6 +- .../ray/rllib/tests/test_multi_agent_env.py | 78 +-- python/ray/rllib/tests/test_nested_spaces.py | 4 +- python/ray/rllib/tests/test_optimizers.py | 6 +- python/ray/rllib/tests/test_perf.py | 4 +- .../ray/rllib/tests/test_policy_evaluator.py | 46 +- python/ray/rllib/utils/__init__.py | 21 +- python/ray/rllib/utils/debug.py | 2 +- 91 files changed, 1955 insertions(+), 1739 deletions(-) rename python/ray/rllib/agents/a3c/{a3c_tf_policy_graph.py => a3c_tf_policy.py} (92%) rename python/ray/rllib/agents/a3c/{a3c_torch_policy_graph.py => a3c_torch_policy.py} (95%) rename python/ray/rllib/agents/ddpg/{ddpg_policy_graph.py => ddpg_policy.py} (97%) rename python/ray/rllib/agents/dqn/{dqn_policy_graph.py => dqn_policy.py} (95%) rename python/ray/rllib/agents/impala/{vtrace_policy_graph.py => vtrace_policy.py} (94%) rename python/ray/rllib/agents/marwil/{marwil_policy_graph.py => marwil_policy.py} (93%) rename python/ray/rllib/agents/pg/{pg_policy_graph.py => pg_policy.py} (89%) rename python/ray/rllib/agents/pg/{torch_pg_policy_graph.py => torch_pg_policy.py} (90%) rename python/ray/rllib/agents/ppo/{appo_policy_graph.py => appo_policy.py} (97%) rename python/ray/rllib/agents/ppo/{ppo_policy_graph.py => ppo_policy.py} (98%) rename python/ray/rllib/agents/qmix/{qmix_policy_graph.py => qmix_policy.py} (98%) rename python/ray/rllib/{evaluation/keras_policy_graph.py => keras_policy.py} (83%) create mode 100644 python/ray/rllib/policy/__init__.py rename python/ray/rllib/{evaluation/dynamic_tf_policy_graph.py => policy/dynamic_tf_policy.py} (94%) create mode 100644 python/ray/rllib/policy/policy.py create mode 100644 python/ray/rllib/policy/sample_batch.py create mode 100644 python/ray/rllib/policy/tf_policy.py create mode 100644 python/ray/rllib/policy/tf_policy_template.py create mode 100644 python/ray/rllib/policy/torch_policy.py rename python/ray/rllib/{evaluation => policy}/torch_policy_template.py (82%) diff --git 
a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 2f1a74b2458b..5a07280e3972 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -274,7 +274,7 @@ QMIX Monotonic Value Factorisation (QMIX, VDN, IQN) --------------------------------------------------- `[paper] `__ `[implementation] `__ Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent `grouping `__ in the environment (see the `two-step game example `__). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X. -Q-Mix is implemented in `PyTorch `__ and is currently *experimental*. +Q-Mix is implemented in `PyTorch `__ and is currently *experimental*. Tuned examples: `Two-step game `__ diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index d91a29f28b9f..e3e7948c864f 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -3,24 +3,24 @@ RLlib Concepts This page describes the internal concepts used to implement algorithms in RLlib. You might find this useful if modifying or adding new algorithms to RLlib. -Policy Graphs -------------- +Policies +-------- -Policy graph classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. +Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. -Most interaction with deep learning frameworks is isolated to the `PolicyGraph interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policy graphs, RLlib includes `Tensorflow `__ and `PyTorch-specific `__ templates. You can also write your own from scratch. Here is an example: +Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow `__ and `PyTorch-specific `__ templates. You can also write your own from scratch. Here is an example: .. code-block:: python - class CustomPolicy(PolicyGraph): - """Example of a custom policy graph written from scratch. + class CustomPolicy(Policy): + """Example of a custom policy written from scratch. - You might find it more convenient to extend TF/TorchPolicyGraph instead + You might find it more convenient to extend TF/TorchPolicy instead for a real policy. """ def __init__(self, observation_space, action_space, config): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) # example parameter self.w = 1.0 @@ -48,7 +48,7 @@ Most interaction with deep learning frameworks is isolated to the `PolicyGraph i Policy Evaluation ----------------- -Given an environment and policy graph, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". 
Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. +Given an environment and policy, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. You can use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvaluator.as_remote()``). @@ -81,9 +81,9 @@ Here is an example of creating a set of policy evaluation actors and using the t Policy Optimization ------------------- -Similar to how a `gradient-descent optimizer `__ can be used to improve a model, RLlib's `policy optimizers `__ implement different strategies for improving a policy graph. +Similar to how a `gradient-descent optimizer `__ can be used to improve a model, RLlib's `policy optimizers `__ implement different strategies for improving a policy. -For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy graph replica. This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. +For example, in A3C you'd want to compute gradients asynchronously on different workers, and apply them to a central policy replica. This strategy is implemented by the `AsyncGradientsOptimizer `__. Another alternative is to gather experiences synchronously in parallel and optimize the model centrally, as in `SyncSamplesOptimizer `__. Policy optimizers abstract these strategies away into reusable modules. This is how the example in the previous section looks when written using a policy optimizer: diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 2701a689dc2c..3d00ac69bcde 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -167,8 +167,8 @@ If all the agents will be using the same algorithm class to train, then you can trainer = pg.PGAgent(env="my_multiagent_env", config={ "multiagent": { - "policy_graphs": { - # the first tuple value is None -> uses default policy graph + "policies": { + # the first tuple value is None -> uses default policy "car1": (None, car_obs_space, car_act_space, {"gamma": 0.85}), "car2": (None, car_obs_space, car_act_space, {"gamma": 0.99}), "traffic_light": (None, tl_obs_space, tl_act_space, {}), @@ -234,10 +234,10 @@ This can be implemented as a multi-agent environment with three types of agents. .. 
code-block:: python "multiagent": { - "policy_graphs": { - "top_level": (custom_policy_graph or None, ...), - "mid_level": (custom_policy_graph or None, ...), - "low_level": (custom_policy_graph or None, ...), + "policies": { + "top_level": (custom_policy or None, ...), + "mid_level": (custom_policy or None, ...), + "low_level": (custom_policy or None, ...), }, "policy_mapping_fn": lambda agent_id: @@ -269,9 +269,9 @@ There is a full example of this in the `example training script `__. +2. Updating the critic: the centralized critic loss can be added to the loss of the custom policy, the same as with any other value function. For an example of defining loss inputs, see the `PGPolicy example `__. Grouping Agents ~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index b429e04be417..cdf42ea228c7 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -101,7 +101,7 @@ Custom TF models should subclass the common RLlib `model class `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy_graph.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). +You can use ``tf.layers.batch_normalization(x, training=input_dict["is_training"])`` to add batch norm layers to your custom model: `code example `__. RLlib will automatically run the update ops for the batch norm layers during optimization (see `tf_policy.py `__ and `multi_gpu_impl.py `__ for the exact handling of these updates). Custom Models (PyTorch) ----------------------- @@ -263,7 +263,7 @@ You can mix supervised losses into any RLlib algorithm through custom models. Fo **TensorFlow**: To add a supervised loss to a custom TF model, you need to override the ``custom_loss()`` method. This method takes in the existing policy loss for the algorithm, which you can add your own supervised loss to before returning. For debugging, you can also return a dictionary of scalar tensors in the ``custom_metrics()`` method. Here is a `runnable example `__ of adding an imitation loss to CartPole training that is defined over a `offline dataset `__. -**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy graph definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. +**PyTorch**: There is no explicit API for adding losses to custom torch models. However, you can modify the loss in the policy definition directly. Like for TF models, offline datasets can be incorporated by creating an input reader and calling ``reader.next()`` in the loss forward pass. Variable-length / Parametric Action Spaces @@ -312,15 +312,15 @@ Custom models can be used to work with environments where (1) the set of valid a Depending on your use case it may make sense to use just the masking, just action embeddings, or both. For a runnable example of this in code, check out `parametric_action_cartpole.py `__. Note that since masking introduces ``tf.float32.min`` values into the model output, this technique might not work with all algorithm options. For example, algorithms might crash if they incorrectly process the ``tf.float32.min`` values. The cartpole example has working configurations for DQN (must set ``hiddens=[]``), PPO (must disable running mean and set ``vf_share_layers=True``), and several other algorithms. 
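As a minimal sketch of the masking trick described above (assuming a model that emits one logit per action plus a 0/1 ``action_mask`` input; the names here are illustrative):

.. code-block:: python

    import tensorflow as tf

    def mask_logits(logits, action_mask):
        # Invalid actions (mask == 0) get tf.float32.min added to their
        # logits, so they receive ~zero probability after the softmax.
        inf_mask = tf.maximum(tf.log(action_mask), tf.float32.min)
        return logits + inf_mask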
-Customizing Policy Graphs +Customizing Policies ------------------------- -For deeper customization of algorithms, you can modify the policy graphs of the trainer classes. Here's an example of extending the DDPG policy graph to specify custom sub-network modules: +For deeper customization of algorithms, you can modify the policies of the trainer classes. Here's an example of extending the DDPG policy to specify custom sub-network modules: .. code-block:: python from ray.rllib.models import ModelCatalog - from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph as BaseDDPGPolicyGraph + from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy as BaseDDPGTFPolicy class CustomPNetwork(object): def __init__(self, dim_actions, hiddens, activation): @@ -336,7 +336,7 @@ For deeper customization of algorithms, you can modify the policy graphs of the self.value = layers.fully_connected( q_out, num_outputs=1, activation_fn=None) - class CustomDDPGPolicyGraph(BaseDDPGPolicyGraph): + class CustomDDPGTFPolicy(BaseDDPGTFPolicy): def _build_p_network(self, obs): return CustomPNetwork( self.dim_actions, @@ -349,26 +349,26 @@ For deeper customization of algorithms, you can modify the policy graphs of the self.config["critic_hiddens"], self.config["critic_hidden_activation"]).value -Then, you can create an trainer with your custom policy graph by: +Then, you can create a trainer with your custom policy by: .. code-block:: python from ray.rllib.agents.ddpg.ddpg import DDPGTrainer - from custom_policy_graph import CustomDDPGPolicyGraph + from custom_policy import CustomDDPGTFPolicy - DDPGTrainer._policy_graph = CustomDDPGPolicyGraph + DDPGTrainer._policy = CustomDDPGTFPolicy trainer = DDPGTrainer(...) -In this example we overrode existing methods of the existing DDPG policy graph, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the entire graph class entirely. +In this example we overrode existing methods of the DDPG policy, i.e., `_build_q_network`, `_build_p_network`, `_build_action_network`, `_build_actor_critic_loss`, but you can also replace the policy class entirely. Model-Based Rollouts ~~~~~~~~~~~~~~~~~~~~ -With a custom policy graph, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicyGraph for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy graph: +With a custom policy, you can also perform model-based rollouts and optionally incorporate the results of those rollouts as training data. For example, suppose you wanted to extend PGPolicy for model-based rollouts. This involves overriding the ``compute_actions`` method of that policy: .. code-block:: python - class ModelBasedPolicyGraph(PGPolicyGraph): + class ModelBasedPolicy(PGPolicy): def compute_actions(self, obs_batch, state_batches, diff --git a/doc/source/rllib-offline.rst b/doc/source/rllib-offline.rst index 42dd5f5b4909..825038af3d53 100644 --- a/doc/source/rllib-offline.rst +++ b/doc/source/rllib-offline.rst @@ -6,7 +6,7 @@ Working with Offline Datasets RLlib's offline dataset APIs enable working with experiences read from offline storage (e.g., disk, cloud storage, streaming systems, HDFS). For example, you might want to read experiences saved from previous training runs, or gathered from policies deployed in `web applications `__.
You can also log new agent experiences produced during online training for future use. -RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch `__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation `__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage. +RLlib represents trajectory sequences (i.e., ``(s, a, r, s', ...)`` tuples) with `SampleBatch `__ objects. Using a batch format enables efficient encoding and compression of experiences. During online training, RLlib uses `policy evaluation `__ actors to generate batches of experiences in parallel using the current policy. RLlib also uses this same batch format for reading and writing experiences to offline storage. Example: Training on previously saved experiences ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -65,7 +65,7 @@ This example plot shows the Q-value metric in addition to importance sampling (I .. image:: offline-q.png -**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy graph object and gamma value for the environment: +**Estimator Python API:** For greater control over the evaluation process, you can create off-policy estimators in your Python code and call ``estimator.estimate(episode_batch)`` to perform counterfactual estimation as needed. The estimators take in a policy object and gamma value for the environment: .. code-block:: python @@ -99,7 +99,7 @@ This `runnable example `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. +RLlib assumes that input batches are of `postprocessed experiences `__. This isn't typically critical for off-policy algorithms (e.g., DQN's `post-processing `__ is only needed if ``n_step > 1`` or ``worker_side_prioritization: True``). For off-policy algorithms, you can also safely set the ``postprocess_inputs: True`` config to auto-postprocess data. However, for on-policy algorithms like PPO, you'll need to pass in the extra values added during policy evaluation and postprocessing to ``batch_builder.add_values()``, e.g., ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` for PPO. This is needed since the calculation of these values depends on the parameters of the *behaviour* policy, which RLlib does not have access to in the offline setting (in online training, these values are automatically added during policy evaluation). 
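To make the last point concrete, here is a minimal sketch of recording one transition with the extra PPO values passed to ``batch_builder.add_values()``. The ``SampleBatchBuilder`` and ``JsonWriter`` helpers, their import paths, and all placeholder values below are assumptions for illustration and are not part of this patch; only the need to record ``logits``, ``vf_preds``, ``value_target``, and ``advantages`` comes from the text above.

.. code-block:: python

    import numpy as np

    # Assumed helpers for building and writing batches; these imports are
    # not shown in this patch and may differ in your RLlib version.
    from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
    from ray.rllib.offline.json_writer import JsonWriter

    batch_builder = SampleBatchBuilder()
    writer = JsonWriter("/tmp/demo-out")

    # A single placeholder transition. For on-policy algorithms like PPO,
    # the values computed during policy evaluation and postprocessing
    # (logits, vf_preds, value_target, advantages) must also be recorded,
    # since they depend on the behaviour policy's parameters.
    batch_builder.add_values(
        obs=np.zeros(4),
        actions=0,
        rewards=1.0,
        dones=True,
        new_obs=np.zeros(4),
        logits=np.zeros(2),
        vf_preds=0.0,
        value_target=1.0,
        advantages=0.5)
    writer.write(batch_builder.build_and_reset())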
diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 06c580035507..02b1bc3478ee 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -50,7 +50,7 @@ Models and Preprocessors * `Custom Preprocessors `__ * `Supervised Model Losses `__ * `Variable-length / Parametric Action Spaces `__ -* `Customizing Policy Graphs `__ +* `Customizing Policies `__ Algorithms ---------- @@ -98,7 +98,7 @@ Offline Datasets Concepts -------- -* `Policy Graphs `__ +* `Policies `__ * `Policy Evaluation `__ * `Policy Optimization `__ * `Trainers `__ diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index 613199cf795f..05f88ac653c4 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -10,12 +10,14 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.sample_batch import SampleBatch def _setup_logger(): @@ -43,7 +45,9 @@ def _register_all(): _register_all() __all__ = [ + "Policy", "PolicyGraph", + "TFPolicy", "TFPolicyGraph", "PolicyEvaluator", "SampleBatch", diff --git a/python/ray/rllib/agents/a3c/__init__.py b/python/ray/rllib/agents/a3c/__init__.py index 9c8205389ea2..4a8480eab695 100644 --- a/python/ray/rllib/agents/a3c/__init__.py +++ b/python/ray/rllib/agents/a3c/__init__.py @@ -1,9 +1,9 @@ from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG from ray.rllib.agents.a3c.a2c import A2CTrainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -A2CAgent = renamed_class(A2CTrainer) -A3CAgent = renamed_class(A3CTrainer) +A2CAgent = renamed_agent(A2CTrainer) +A3CAgent = renamed_agent(A3CTrainer) __all__ = [ "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG" diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index eb384058de80..56d7a09daa0f 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -4,7 +4,7 @@ import time -from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph +from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer from ray.rllib.utils.annotations import override @@ -43,16 +43,16 @@ class A3CTrainer(Trainer): _name = "A3C" _default_config = DEFAULT_CONFIG - _policy_graph = A3CPolicyGraph + _policy = A3CTFPolicy @override(Trainer) def _init(self, config, env_creator): if config["use_pytorch"]: - from ray.rllib.agents.a3c.a3c_torch_policy_graph import \ + from ray.rllib.agents.a3c.a3c_torch_policy import \ A3CTorchPolicy policy_cls = A3CTorchPolicy else: - policy_cls = self._policy_graph + policy_cls = self._policy if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py similarity index 92% rename from python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py rename to 
python/ray/rllib/agents/a3c/a3c_tf_policy.py index e6ae8d17bad3..eb5becceaa71 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -1,4 +1,4 @@ -"""Note: Keep in sync with changes to VTracePolicyGraph.""" +"""Note: Keep in sync with changes to VTraceTFPolicy.""" from __future__ import absolute_import from __future__ import division @@ -8,13 +8,13 @@ import ray from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override @@ -47,13 +47,13 @@ def __init__(self, class A3CPostprocessing(object): """Adds the VF preds and advantages fields to the trajectory.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), + TFPolicy.extra_compute_action_fetches(self), **{SampleBatch.VF_PREDS: self.vf}) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -73,7 +73,7 @@ def postprocess_trajectory(self, self.config["lambda"]) -class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph): +class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) self.config = config @@ -114,7 +114,7 @@ def __init__(self, observation_space, action_space, config): self.vf, self.config["vf_loss_coeff"], self.config["entropy_coeff"]) - # Initialize TFPolicyGraph + # Initialize TFPolicy loss_in = [ (SampleBatch.CUR_OBS, self.observations), (SampleBatch.ACTIONS, actions), @@ -125,7 +125,7 @@ def __init__(self, observation_space, action_space, config): ] LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -157,18 +157,18 @@ def __init__(self, observation_space, action_space, config): self.sess.run(tf.global_variables_initializer()) - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return self.model.state_init - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): grads = tf.gradients(loss, self.var_list) self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return self.stats_fetches diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_torch_policy.py similarity index 95% rename from python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py rename to python/ray/rllib/agents/a3c/a3c_torch_policy.py index fa6f857f9eca..6ccf6c48d35f 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy_graph.py +++ 
b/python/ray/rllib/agents/a3c/a3c_torch_policy.py @@ -9,8 +9,8 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_template import build_torch_policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy def actor_critic_loss(policy, batch_tensors): diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 5b0ecf268fe7..17da952ddedf 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -3,6 +3,6 @@ from __future__ import print_function from ray.rllib.agents.trainer import Trainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -Agent = renamed_class(Trainer) +Agent = renamed_agent(Trainer) diff --git a/python/ray/rllib/agents/ars/__init__.py b/python/ray/rllib/agents/ars/__init__.py index a1120ff8ce31..0681efe7ab37 100644 --- a/python/ray/rllib/agents/ars/__init__.py +++ b/python/ray/rllib/agents/ars/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG) -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ARSAgent = renamed_class(ARSTrainer) +ARSAgent = renamed_agent(ARSTrainer) __all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 65738a620b30..4330f0d90db0 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -17,7 +17,7 @@ from ray.rllib.agents.ars import optimizers from ray.rllib.agents.ars import policies from ray.rllib.agents.ars import utils -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.memory import ray_get_and_free from ray.rllib.utils import FilterManager diff --git a/python/ray/rllib/agents/ddpg/__init__.py b/python/ray/rllib/agents/ddpg/__init__.py index 9b90ca842ae5..3d681b8356c9 100644 --- a/python/ray/rllib/agents/ddpg/__init__.py +++ b/python/ray/rllib/agents/ddpg/__init__.py @@ -5,10 +5,10 @@ from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG from ray.rllib.agents.ddpg.td3 import TD3Trainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ApexDDPGAgent = renamed_class(ApexDDPGTrainer) -DDPGAgent = renamed_class(DDPGTrainer) +ApexDDPGAgent = renamed_agent(ApexDDPGTrainer) +DDPGAgent = renamed_agent(DDPGTrainer) __all__ = [ "DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer", diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index 7a140beeea24..66d3810e5e93 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -4,7 +4,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph +from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -163,7 +163,7 @@ class DDPGTrainer(DQNTrainer): """DDPG implementation in TensorFlow.""" _name = "DDPG" _default_config = DEFAULT_CONFIG - 
_policy_graph = DDPGPolicyGraph + _policy = DDPGTFPolicy @override(DQNTrainer) def _train(self): diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy.py similarity index 97% rename from python/ray/rllib/agents/ddpg/ddpg_policy_graph.py rename to python/ray/rllib/agents/ddpg/ddpg_policy.py index 675f9187f2c6..b80cfce4cdaa 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy.py @@ -7,15 +7,15 @@ import ray import ray.experimental.tf_utils -from ray.rllib.agents.dqn.dqn_policy_graph import ( - _huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn) -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.agents.dqn.dqn_policy import (_huber_loss, _minimize_and_clip, + _scope_vars, _postprocess_dqn) +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -35,7 +35,7 @@ class DDPGPostprocessing(object): """Implements n-step learning and param noise adjustments.""" - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -68,7 +68,7 @@ def postprocess_trajectory(self, return _postprocess_dqn(self, sample_batch) -class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph): +class DDPGTFPolicy(DDPGPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config) if not isinstance(action_space, Box): @@ -281,7 +281,7 @@ def __init__(self, observation_space, action_space, config): self.critic_loss = self.twin_q_model.custom_loss( self.critic_loss, input_dict) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -301,12 +301,12 @@ def __init__(self, observation_space, action_space, config): # Hard initial update self.update_target(tau=1.0) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): # we don't use this because we have two separate optimisers return None - @override(TFPolicyGraph) + @override(TFPolicy) def build_apply_op(self, optimizer, grads_and_vars): # for policy gradient, update policy net one time v.s. # update critic net `policy_delay` time(s) @@ -327,7 +327,7 @@ def make_apply_op(): with tf.control_dependencies([tf.assign_add(self.global_step, 1)]): return tf.group(actor_op, critic_op) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): if self.config["grad_norm_clipping"] is not None: actor_grads_and_vars = _minimize_and_clip( @@ -360,7 +360,7 @@ def gradients(self, optimizer, loss): + self._critic_grads_and_vars return grads_and_vars - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_feed_dict(self): return { # FIXME: what about turning off exploration? 
Isn't that a good @@ -370,31 +370,31 @@ def extra_compute_action_feed_dict(self): self.pure_exploration_phase: self.cur_pure_exploration_phase, } - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return { "td_error": self.td_error, LEARNER_STATS_KEY: self.stats, } - @override(TFPolicyGraph) + @override(TFPolicy) def get_weights(self): return self.variables.get_weights() - @override(TFPolicyGraph) + @override(TFPolicy) def set_weights(self, weights): self.variables.set_weights(weights) - @override(PolicyGraph) + @override(Policy) def get_state(self): return [ - TFPolicyGraph.get_state(self), self.cur_noise_scale, + TFPolicy.get_state(self), self.cur_noise_scale, self.cur_pure_exploration_phase ] - @override(PolicyGraph) + @override(Policy) def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) + TFPolicy.set_state(self, state[0]) self.set_epsilon(state[1]) self.set_pure_exploration_phase(state[2]) diff --git a/python/ray/rllib/agents/dqn/__init__.py b/python/ray/rllib/agents/dqn/__init__.py index 415ceae6c1de..d3de8cb802cc 100644 --- a/python/ray/rllib/agents/dqn/__init__.py +++ b/python/ray/rllib/agents/dqn/__init__.py @@ -4,10 +4,10 @@ from ray.rllib.agents.dqn.apex import ApexTrainer from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -DQNAgent = renamed_class(DQNTrainer) -ApexAgent = renamed_class(ApexTrainer) +DQNAgent = renamed_agent(DQNTrainer) +ApexAgent = renamed_agent(ApexTrainer) __all__ = [ "DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG" diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index d8fb480cbda6..7fdb6f66b433 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -8,9 +8,9 @@ from ray import tune from ray.rllib import optimizers from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule @@ -133,7 +133,7 @@ class DQNTrainer(Trainer): _name = "DQN" _default_config = DEFAULT_CONFIG - _policy_graph = DQNPolicyGraph + _policy = DQNTFPolicy _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS @override(Trainer) @@ -197,10 +197,10 @@ def on_episode_end(info): on_episode_end) self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) + env_creator, self._policy) def create_remote_evaluators(): - return self.make_remote_evaluators(env_creator, self._policy_graph, + return self.make_remote_evaluators(env_creator, self._policy, config["num_workers"]) if config["optimizer_class"] != "AsyncReplayOptimizer": diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy.py similarity index 95% rename from python/ray/rllib/agents/dqn/dqn_policy_graph.py rename to python/ray/rllib/agents/dqn/dqn_policy.py index 1e682ce80cfa..a1affa947a43 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy.py @@ -7,13 +7,13 @@ from scipy.stats import entropy import ray -from ray.rllib.evaluation.sample_batch import 
SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog, Categorical from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule from ray.rllib.utils import try_import_tf @@ -105,14 +105,14 @@ def __init__(self, class DQNPostprocessing(object): """Implements n-step learning and param noise adjustments.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), **{ + TFPolicy.extra_compute_action_fetches(self), **{ "q_values": self.q_values, }) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -345,7 +345,7 @@ def __init__(self, q_values, observations, num_actions, stochastic, eps, self.action_prob = None -class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph): +class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) if not isinstance(action_space, Discrete): @@ -446,7 +446,7 @@ def __init__(self, observation_space, action_space, config): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) - # initialize TFPolicyGraph + # initialize TFPolicy self.sess = tf.get_default_session() self.loss_inputs = [ (SampleBatch.CUR_OBS, self.obs_t), @@ -459,7 +459,7 @@ def __init__(self, observation_space, action_space, config): LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -477,12 +477,12 @@ def __init__(self, observation_space, action_space, config): "cur_lr": tf.cast(self.cur_lr, tf.float64), }, **self.loss.stats) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): return tf.train.AdamOptimizer( learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"]) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): if self.config["grad_norm_clipping"] is not None: grads_and_vars = _minimize_and_clip( @@ -496,27 +496,27 @@ def gradients(self, optimizer, loss): grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] return grads_and_vars - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_feed_dict(self): return { self.stochastic: True, self.eps: self.cur_epsilon, } - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return { "td_error": self.loss.td_error, LEARNER_STATS_KEY: self.stats_fetches, } - @override(PolicyGraph) + @override(Policy) def get_state(self): - return [TFPolicyGraph.get_state(self), self.cur_epsilon] + return [TFPolicy.get_state(self), self.cur_epsilon] - @override(PolicyGraph) + @override(Policy) def set_state(self, state): - TFPolicyGraph.set_state(self, state[0]) + TFPolicy.set_state(self, state[0]) self.set_epsilon(state[1]) def _build_parameter_noise(self, pnet_params): @@ -633,25 +633,25 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): 
rewards[i] += gamma**j * rewards[i + j] -def _postprocess_dqn(policy_graph, batch): +def _postprocess_dqn(policy, batch): # N-step Q adjustments - if policy_graph.config["n_step"] > 1: - _adjust_nstep(policy_graph.config["n_step"], - policy_graph.config["gamma"], batch[SampleBatch.CUR_OBS], - batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS], - batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES]) + if policy.config["n_step"] > 1: + _adjust_nstep(policy.config["n_step"], policy.config["gamma"], + batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS], + batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], + batch[SampleBatch.DONES]) if PRIO_WEIGHTS not in batch: batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS]) # Prioritize on the worker side - if batch.count > 0 and policy_graph.config["worker_side_prioritization"]: - td_errors = policy_graph.compute_td_error( + if batch.count > 0 and policy.config["worker_side_prioritization"]: + td_errors = policy.compute_td_error( batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES], batch[PRIO_WEIGHTS]) new_priorities = ( - np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"]) + np.abs(td_errors) + policy.config["prioritized_replay_eps"]) batch.data[PRIO_WEIGHTS] = new_priorities return batch diff --git a/python/ray/rllib/agents/es/__init__.py b/python/ray/rllib/agents/es/__init__.py index d7bec2a9e002..38b2b772ec57 100644 --- a/python/ray/rllib/agents/es/__init__.py +++ b/python/ray/rllib/agents/es/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG) -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ESAgent = renamed_class(ESTrainer) +ESAgent = renamed_agent(ESTrainer) __all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 2328b90e9ed0..e167129c6a93 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -16,7 +16,7 @@ from ray.rllib.agents.es import optimizers from ray.rllib.agents.es import policies from ray.rllib.agents.es import utils -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override from ray.rllib.utils.memory import ray_get_and_free from ray.rllib.utils import FilterManager diff --git a/python/ray/rllib/agents/impala/__init__.py b/python/ray/rllib/agents/impala/__init__.py index 81c64e8891ab..d7bdd7210fdd 100644 --- a/python/ray/rllib/agents/impala/__init__.py +++ b/python/ray/rllib/agents/impala/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -ImpalaAgent = renamed_class(ImpalaTrainer) +ImpalaAgent = renamed_agent(ImpalaTrainer) __all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index ffe74c087a3e..838f2975ce67 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -4,8 +4,8 @@ import time -from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph -from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph +from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy +from 
ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator @@ -105,14 +105,14 @@ class ImpalaTrainer(Trainer): _name = "IMPALA" _default_config = DEFAULT_CONFIG - _policy_graph = VTracePolicyGraph + _policy = VTraceTFPolicy @override(Trainer) def _init(self, config, env_creator): for k in OPTIMIZER_SHARED_CONFIGS: if k not in config["optimizer"]: config["optimizer"][k] = config[k] - policy_cls = self._get_policy_graph() + policy_cls = self._get_policy() self.local_evaluator = self.make_local_evaluator( self.env_creator, policy_cls) @@ -158,9 +158,9 @@ def _train(self): prev_steps) return result - def _get_policy_graph(self): + def _get_policy(self): if self.config["vtrace"]: - policy_cls = self._policy_graph + policy_cls = self._policy else: - policy_cls = A3CPolicyGraph + policy_cls = A3CTFPolicy return policy_cls diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy.py similarity index 94% rename from python/ray/rllib/agents/impala/vtrace_policy_graph.py rename to python/ray/rllib/agents/impala/vtrace_policy.py index 56b6de42ed5a..9b7c57b9355e 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -1,6 +1,6 @@ -"""Adapted from A3CPolicyGraph to add V-trace. +"""Adapted from A3CTFPolicy to add V-trace. -Keep in sync with changes to A3CPolicyGraph and VtraceSurrogatePolicyGraph.""" +Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy.""" from __future__ import absolute_import from __future__ import division @@ -11,9 +11,9 @@ import numpy as np from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \ +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog @@ -110,13 +110,13 @@ def __init__(self, class VTracePostprocessing(object): """Adds the policy logits to the trajectory.""" - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - TFPolicyGraph.extra_compute_action_fetches(self), + TFPolicy.extra_compute_action_fetches(self), **{BEHAVIOUR_LOGITS: self.model.outputs}) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -126,8 +126,7 @@ def postprocess_trajectory(self, return sample_batch -class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing, - TFPolicyGraph): +class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy): def __init__(self, observation_space, action_space, @@ -285,7 +284,7 @@ def make_time_major(tensor, drop_last=False): "max_KL": tf.reduce_max(kls[0]), } - # Initialize TFPolicyGraph + # Initialize TFPolicy loss_in = [ (SampleBatch.ACTIONS, actions), (SampleBatch.DONES, dones), @@ -297,7 +296,7 @@ def make_time_major(tensor, drop_last=False): ] LearningRateSchedule.__init__(self, self.config["lr"], self.config["lr_schedule"]) - TFPolicyGraph.__init__( + 
TFPolicy.__init__( self, observation_space, action_space, @@ -332,15 +331,15 @@ def make_time_major(tensor, drop_last=False): }, **self.KL_stats), } - @override(TFPolicyGraph) + @override(TFPolicy) def copy(self, existing_inputs): - return VTracePolicyGraph( + return VTraceTFPolicy( self.observation_space, self.action_space, self.config, existing_inputs=existing_inputs) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): if self.config["opt_type"] == "adam": return tf.train.AdamOptimizer(self.cur_lr) @@ -349,17 +348,17 @@ def optimizer(self): self.config["momentum"], self.config["epsilon"]) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): grads = tf.gradients(loss, self.var_list) self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) clipped_grads = list(zip(self.grads, self.var_list)) return clipped_grads - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return self.stats_fetches - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index b1e535b64530..d6c6eadeaa9c 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -3,7 +3,7 @@ from __future__ import print_function from ray.rllib.agents.trainer import Trainer, with_common_config -from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph +from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy from ray.rllib.optimizers import SyncBatchReplayOptimizer from ray.rllib.utils.annotations import override @@ -44,14 +44,14 @@ class MARWILTrainer(Trainer): _name = "MARWIL" _default_config = DEFAULT_CONFIG - _policy_graph = MARWILPolicyGraph + _policy = MARWILPolicy @override(Trainer) def _init(self, config, env_creator): self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy_graph) + env_creator, self._policy) self.remote_evaluators = self.make_remote_evaluators( - env_creator, self._policy_graph, config["num_workers"]) + env_creator, self._policy, config["num_workers"]) self.optimizer = SyncBatchReplayOptimizer( self.local_evaluator, self.remote_evaluators, diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy.py similarity index 93% rename from python/ray/rllib/agents/marwil/marwil_policy_graph.py rename to python/ray/rllib/agents/marwil/marwil_policy.py index 2c647db9aa96..add021025c9c 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py +++ b/python/ray/rllib/agents/marwil/marwil_policy.py @@ -6,12 +6,12 @@ from ray.rllib.models import ModelCatalog from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.utils.annotations import override -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.agents.dqn.dqn_policy import _scope_vars from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf @@ -59,7 +59,7 @@ def __init__(self, 
state_values, cumulative_rewards, logits, actions, class MARWILPostprocessing(object): """Adds the advantages field to the trajectory.""" - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -79,7 +79,7 @@ def postprocess_trajectory(self, return batch -class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph): +class MARWILPolicy(MARWILPostprocessing, TFPolicy): def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) self.config = config @@ -127,14 +127,14 @@ def __init__(self, observation_space, action_space, config): self.explained_variance = tf.reduce_mean( explained_variance(self.cum_rew_t, state_values)) - # initialize TFPolicyGraph + # initialize TFPolicy self.sess = tf.get_default_session() self.loss_inputs = [ (SampleBatch.CUR_OBS, self.obs_t), (SampleBatch.ACTIONS, self.act_t), (Postprocessing.ADVANTAGES, self.cum_rew_t), ] - TFPolicyGraph.__init__( + TFPolicy.__init__( self, observation_space, action_space, @@ -166,10 +166,10 @@ def _build_policy_loss(self, state_values, cum_rwds, logits, actions, return ReweightedImitationLoss(state_values, cum_rwds, logits, actions, action_space, self.config["beta"]) - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_grad_fetches(self): return {LEARNER_STATS_KEY: self.stats_fetches} - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return self.model.state_init diff --git a/python/ray/rllib/agents/pg/__init__.py b/python/ray/rllib/agents/pg/__init__.py index 2203188a7ca6..eb11c99bf625 100644 --- a/python/ray/rllib/agents/pg/__init__.py +++ b/python/ray/rllib/agents/pg/__init__.py @@ -1,6 +1,6 @@ from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -PGAgent = renamed_class(PGTrainer) +PGAgent = renamed_agent(PGTrainer) __all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index ffbb899d1b9e..71e2ab3fbd69 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -4,7 +4,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.trainer_template import build_trainer -from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy +from ray.rllib.agents.pg.pg_policy import PGTFPolicy # yapf: disable # __sphinx_doc_begin__ @@ -22,7 +22,7 @@ def get_policy_class(config): if config["use_pytorch"]: - from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy + from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy return PGTorchPolicy else: return PGTFPolicy diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy.py similarity index 89% rename from python/ray/rllib/agents/pg/pg_policy_graph.py rename to python/ray/rllib/agents/pg/pg_policy.py index 54fcd041cc72..7cca613928fb 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy.py @@ -5,8 +5,8 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.tf_policy_template import build_tf_policy -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import try_import_tf tf = try_import_tf() diff 
--git a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py b/python/ray/rllib/agents/pg/torch_pg_policy.py similarity index 90% rename from python/ray/rllib/agents/pg/torch_pg_policy_graph.py rename to python/ray/rllib/agents/pg/torch_pg_policy.py index cda1b6eb5057..d0f1cda71cc7 100644 --- a/python/ray/rllib/agents/pg/torch_pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/torch_pg_policy.py @@ -5,8 +5,8 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.torch_policy_template import build_torch_policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy def pg_torch_loss(policy, batch_tensors): diff --git a/python/ray/rllib/agents/ppo/__init__.py b/python/ray/rllib/agents/ppo/__init__.py index a02cbc23c684..a3d492baf24a 100644 --- a/python/ray/rllib/agents/ppo/__init__.py +++ b/python/ray/rllib/agents/ppo/__init__.py @@ -1,7 +1,7 @@ from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG from ray.rllib.agents.ppo.appo import APPOTrainer -from ray.rllib.utils import renamed_class +from ray.rllib.utils import renamed_agent -PPOAgent = renamed_class(PPOTrainer) +PPOAgent = renamed_agent(PPOTrainer) __all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"] diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index b32531dd7d5c..0438b2714221 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy +from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala from ray.rllib.utils.annotations import override @@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer): _name = "APPO" _default_config = DEFAULT_CONFIG - _policy_graph = AsyncPPOTFPolicy + _policy = AsyncPPOTFPolicy @override(impala.ImpalaTrainer) - def _get_policy_graph(self): + def _get_policy(self): return AsyncPPOTFPolicy diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy.py similarity index 97% rename from python/ray/rllib/agents/ppo/appo_policy_graph.py rename to python/ray/rllib/agents/ppo/appo_policy.py index 5aa76913194f..b740d6d81430 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy.py @@ -1,6 +1,6 @@ -"""Adapted from VTracePolicyGraph to use the PPO surrogate loss. +"""Adapted from VTraceTFPolicy to use the PPO surrogate loss. 
-Keep in sync with changes to VTracePolicyGraph.""" +Keep in sync with changes to VTraceTFPolicy.""" from __future__ import absolute_import from __future__ import division @@ -13,9 +13,9 @@ import ray from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.postprocessing import Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_template import build_tf_policy -from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.evaluation.postprocessing import compute_advantages from ray.rllib.utils import try_import_tf diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index d3f5abdaa95c..b395d935f119 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -5,7 +5,7 @@ import logging from ray.rllib.agents import with_common_config -from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer @@ -143,8 +143,7 @@ def validate_config(config): raise ValueError( "Episode truncation is not supported without a value " "function. Consider setting batch_mode=complete_episodes.") - if (config["multiagent"]["policy_graphs"] - and not config["simple_optimizer"]): + if (config["multiagent"]["policies"] and not config["simple_optimizer"]): logger.info( "In multi-agent mode, policies will be optimized sequentially " "by the multi-GPU optimizer. 
Consider setting " diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy.py similarity index 98% rename from python/ray/rllib/agents/ppo/ppo_policy_graph.py rename to python/ray/rllib/agents/ppo/ppo_policy.py index 334ca788c936..5a17d6c6d60c 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy.py @@ -7,9 +7,9 @@ import ray from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule -from ray.rllib.evaluation.tf_policy_template import build_tf_policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import LearningRateSchedule +from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.explained_variance import explained_variance from ray.rllib.utils import try_import_tf diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 420a567d8eff..2ad6a3e56f95 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -4,7 +4,7 @@ from ray.rllib.agents.trainer import with_common_config from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph +from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy # yapf: disable # __sphinx_doc_begin__ @@ -95,7 +95,7 @@ class QMixTrainer(DQNTrainer): _name = "QMIX" _default_config = DEFAULT_CONFIG - _policy_graph = QMixPolicyGraph + _policy = QMixTorchPolicy _optimizer_shared_configs = [ "learning_starts", "buffer_size", "train_batch_size" ] diff --git a/python/ray/rllib/agents/qmix/qmix_policy_graph.py b/python/ray/rllib/agents/qmix/qmix_policy.py similarity index 98% rename from python/ray/rllib/agents/qmix/qmix_policy_graph.py rename to python/ray/rllib/agents/qmix/qmix_policy.py index b7c9a7ad8120..26ec387de004 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy_graph.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -14,8 +14,8 @@ from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer from ray.rllib.agents.qmix.model import RNNModel, _get_size from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.action_dist import TupleActions from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.lstm import chop_into_sequences @@ -130,7 +130,7 @@ def forward(self, rewards, actions, terminated, mask, obs, next_obs, return loss, mask, masked_td_error, chosen_action_qvals, targets -class QMixPolicyGraph(PolicyGraph): +class QMixTorchPolicy(Policy): """QMix impl. Assumes homogeneous agents for now. 
You must use MultiAgentEnv.with_agent_groups() to group agents @@ -213,7 +213,7 @@ def __init__(self, obs_space, action_space, config): alpha=config["optim_alpha"], eps=config["optim_eps"]) - @override(PolicyGraph) + @override(Policy) def compute_actions(self, obs_batch, state_batches=None, @@ -243,7 +243,7 @@ def compute_actions(self, return TupleActions(list(actions.transpose([1, 0]))), hiddens, {} - @override(PolicyGraph) + @override(Policy) def learn_on_batch(self, samples): obs_batch, action_mask = self._unpack_observation( samples[SampleBatch.CUR_OBS]) @@ -314,22 +314,22 @@ def to_batches(arr): } return {LEARNER_STATS_KEY: stats} - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): return [ s.expand([self.n_agents, -1]).numpy() for s in self.model.state_init() ] - @override(PolicyGraph) + @override(Policy) def get_weights(self): return {"model": self.model.state_dict()} - @override(PolicyGraph) + @override(Policy) def set_weights(self, weights): self.model.load_state_dict(weights["model"]) - @override(PolicyGraph) + @override(Policy) def get_state(self): return { "model": self.model.state_dict(), @@ -340,7 +340,7 @@ def get_state(self): "cur_epsilon": self.cur_epsilon, } - @override(PolicyGraph) + @override(Policy) def set_state(self, state): self.model.load_state_dict(state["model"]) self.target_model.load_state_dict(state["target_model"]) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 8e6db02707d8..83b00a896b71 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -19,7 +19,7 @@ from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ _validate_multiagent_config -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI @@ -220,9 +220,9 @@ # === Multiagent === "multiagent": { - # Map from policy ids to tuples of (policy_graph_cls, obs_space, + # Map from policy ids to tuples of (policy_cls, obs_space, # act_space, config). See policy_evaluator.py for more info. - "policy_graphs": {}, + "policies": {}, # Function mapping agent ids to policy ids. "policy_mapping_fn": None, # Optional whitelist of policies to train, or None for all policies. @@ -435,9 +435,7 @@ def get_scope(): "using evaluation_config: {}".format(extra_config)) # Make local evaluation evaluators self.evaluation_ev = self.make_local_evaluator( - self.env_creator, - self._policy_graph, - extra_config=extra_config) + self.env_creator, self._policy, extra_config=extra_config) self.evaluation_metrics = self._evaluate() @override(Trainable) @@ -578,10 +576,10 @@ def _default_config(self): @PublicAPI def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy graph for the specified id, or None. + """Return policy for the specified id, or None. Arguments: - policy_id (str): id of policy graph to return. + policy_id (str): id of policy to return. 
""" return self.local_evaluator.get_policy(policy_id) @@ -606,16 +604,13 @@ def set_weights(self, weights): self.local_evaluator.set_weights(weights) @DeveloperAPI - def make_local_evaluator(self, - env_creator, - policy_graph, - extra_config=None): + def make_local_evaluator(self, env_creator, policy, extra_config=None): """Convenience method to return configured local evaluator.""" return self._make_evaluator( PolicyEvaluator, env_creator, - policy_graph, + policy, 0, merge_dicts( # important: allow local tf to use more CPUs for optimization @@ -627,7 +622,7 @@ def make_local_evaluator(self, extra_config or {})) @DeveloperAPI - def make_remote_evaluators(self, env_creator, policy_graph, count): + def make_remote_evaluators(self, env_creator, policy, count): """Convenience method to return a number of remote evaluators.""" remote_args = { @@ -639,8 +634,8 @@ def make_remote_evaluators(self, env_creator, policy_graph, count): cls = PolicyEvaluator.as_remote(**remote_args).remote return [ - self._make_evaluator(cls, env_creator, policy_graph, i + 1, - self.config) for i in range(count) + self._make_evaluator(cls, env_creator, policy, i + 1, self.config) + for i in range(count) ] @DeveloperAPI @@ -700,6 +695,13 @@ def resource_help(cls, config): @staticmethod def _validate_config(config): + if "policy_graphs" in config["multiagent"]: + logger.warning( + "The `policy_graphs` config has been renamed to `policies`.") + # Backwards compatibility + config["multiagent"]["policies"] = config["multiagent"][ + "policy_graphs"] + del config["multiagent"]["policy_graphs"] if "gpu" in config: raise ValueError( "The `gpu` config is deprecated, please use `num_gpus=0|1` " @@ -760,8 +762,7 @@ def _has_policy_optimizer(self): return hasattr(self, "optimizer") and isinstance( self.optimizer, PolicyOptimizer) - def _make_evaluator(self, cls, env_creator, policy_graph, worker_index, - config): + def _make_evaluator(self, cls, env_creator, policy, worker_index, config): def session_creator(): logger.debug("Creating TF session {}".format( config["tf_session_args"])) @@ -803,18 +804,18 @@ def session_creator(): else: input_evaluation = config["input_evaluation"] - # Fill in the default policy graph if 'None' is specified in multiagent - if self.config["multiagent"]["policy_graphs"]: - tmp = self.config["multiagent"]["policy_graphs"] + # Fill in the default policy if 'None' is specified in multiagent + if self.config["multiagent"]["policies"]: + tmp = self.config["multiagent"]["policies"] _validate_multiagent_config(tmp, allow_none_graph=True) for k, v in tmp.items(): if v[0] is None: - tmp[k] = (policy_graph, v[1], v[2], v[3]) - policy_graph = tmp + tmp[k] = (policy, v[1], v[2], v[3]) + policy = tmp return cls( env_creator, - policy_graph, + policy, policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], policies_to_train=self.config["multiagent"]["policies_to_train"], tf_session_creator=(session_creator diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index 618bc3b30ace..314202a1e842 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -21,7 +21,7 @@ def build_trainer(name, Arguments: name (str): name of the trainer (e.g., "PPO") - default_policy (cls): the default PolicyGraph class to use + default_policy (cls): the default Policy class to use default_config (dict): the default config dict of the algorithm, otherwises uses the Trainer default config make_policy_optimizer (func): optional 
function that returns a @@ -30,7 +30,7 @@ def build_trainer(name, validate_config (func): optional callback that checks a given config for correctness. It may mutate the config as needed. get_policy_class (func): optional callback that takes a config and - returns the policy graph class to override the default with + returns the policy class to override the default with before_train_step (func): optional callback to run before each train() call. It takes the trainer instance as an argument. after_optimizer_step (func): optional callback to run after each @@ -51,19 +51,19 @@ def build_trainer(name, class trainer_cls(Trainer): _name = name _default_config = default_config or Trainer.COMMON_CONFIG - _policy_graph = default_policy + _policy = default_policy def _init(self, config, env_creator): if validate_config: validate_config(config) if get_policy_class is None: - policy_graph = default_policy + policy = default_policy else: - policy_graph = get_policy_class(config) + policy = get_policy_class(config) self.local_evaluator = self.make_local_evaluator( - env_creator, policy_graph) + env_creator, policy) self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_graph, config["num_workers"]) + env_creator, policy, config["num_workers"]) if make_policy_optimizer: self.optimizer = make_policy_optimizer( self.local_evaluator, self.remote_evaluators, config) diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index b7afa222b149..8d7641b9c313 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -27,7 +27,7 @@ class MultiAgentEpisode(object): user_data (dict): Dict that you can use for temporary storage. Use case 1: Model-based rollouts in multi-agent: - A custom compute_actions() function in a policy graph can inspect the + A custom compute_actions() function in a policy can inspect the current episode state and perform a number of rollouts based on the policies and state of other agents in the environment. @@ -80,7 +80,7 @@ def soft_reset(self): @DeveloperAPI def policy_for(self, agent_id=_DUMMY_AGENT_ID): - """Returns the policy graph for the specified agent. + """Returns the policy for the specified agent. If the agent is new, the policy mapping fn will be called to bind the agent to a policy for the duration of the episode. diff --git a/python/ray/rllib/evaluation/interface.py b/python/ray/rllib/evaluation/interface.py index eb705a99b530..6bc626da1175 100644 --- a/python/ray/rllib/evaluation/interface.py +++ b/python/ray/rllib/evaluation/interface.py @@ -62,7 +62,7 @@ def compute_gradients(self, samples): Returns: (grads, info): A list of gradients that can be applied on a compatible evaluator. In the multi-agent case, returns a dict - of gradients keyed by policy graph ids. An info dictionary of + of gradients keyed by policy ids. An info dictionary of extra metadata is also returned. 
Examples: diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index a92c64bc9e4b..d8b3122fed4b 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -7,21 +7,18 @@ import collections import ray -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate +from ray.rllib.policy.policy import LEARNER_STATS_KEY from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) -# By convention, metrics from optimizing the loss can be reported in the -# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. -LEARNER_STATS_KEY = "learner_stats" - @DeveloperAPI def get_learner_stats(grad_info): - """Return optimization stats reported from the policy graph. + """Return optimization stats reported from the policy. Example: >>> grad_info = evaluator.learn_on_batch(samples) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 48e19dfcb96e..40df71006a8c 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -15,11 +15,10 @@ from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.evaluation.interface import EvaluatorInterface -from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \ - DEFAULT_POLICY_ID from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator @@ -52,9 +51,9 @@ def get_global_evaluator(): @DeveloperAPI class PolicyEvaluator(EvaluatorInterface): - """Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``. + """Common ``PolicyEvaluator`` implementation that wraps a ``Policy``. - This class wraps a policy graph instance and an environment class to + This class wraps a policy instance and an environment class to collect experiences from the environment. You can create many replicas of this class as Ray actors to scale RL training. @@ -65,7 +64,7 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Create a policy evaluator and using it to collect experiences. >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: gym.make("CartPole-v0"), - ... policy_graph=PGTFPolicy) + ... policy=PGTFPolicy) >>> print(evaluator.sample()) SampleBatch({ "obs": [[...]], "actions": [[...]], "rewards": [[...]], @@ -76,7 +75,7 @@ class PolicyEvaluator(EvaluatorInterface): ... evaluator_cls=PolicyEvaluator, ... evaluator_args={ ... "env_creator": lambda _: gym.make("CartPole-v0"), - ... "policy_graph": PGTFPolicy, + ... "policy": PGTFPolicy, ... }, ... 
num_workers=10) >>> for _ in range(10): optimizer.step() @@ -84,7 +83,7 @@ class PolicyEvaluator(EvaluatorInterface): >>> # Creating a multi-agent policy evaluator >>> evaluator = PolicyEvaluator( ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), - ... policy_graphs={ + ... policies={ ... # Use an ensemble of two policies for car agents ... "car_policy1": ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), @@ -113,7 +112,7 @@ def as_remote(cls, num_cpus=None, num_gpus=None, resources=None): @DeveloperAPI def __init__(self, env_creator, - policy_graph, + policy, policy_mapping_fn=None, policies_to_train=None, tf_session_creator=None, @@ -147,9 +146,9 @@ def __init__(self, Arguments: env_creator (func): Function that returns a gym.Env given an EnvContext wrapped configuration. - policy_graph (class|dict): Either a class implementing - PolicyGraph, or a dictionary of policy id strings to - (PolicyGraph, obs_space, action_space, config) tuples. If a + policy (class|dict): Either a class implementing + Policy, or a dictionary of policy id strings to + (Policy, obs_space, action_space, config) tuples. If a dict is specified, then we are in multi-agent mode and a policy_mapping_fn should also be set. policy_mapping_fn (func): A function that maps agent ids to @@ -159,7 +158,7 @@ def __init__(self, policies_to_train (list): Optional whitelist of policies to train, or None for all policies. tf_session_creator (func): A function that returns a TF session. - This is optional and only useful with TFPolicyGraph. + This is optional and only useful with TFPolicy. batch_steps (int): The target number of env transitions to include in each sample batch returned from this evaluator. batch_mode (str): One of the following batch modes: @@ -196,7 +195,7 @@ def __init__(self, model_config (dict): Config to use when creating the policy model. policy_config (dict): Config to pass to the policy. In the multi-agent case, this config will be merged with the - per-policy configs specified by `policy_graph`. + per-policy configs specified by `policy`. worker_index (int): For remote evaluators, this should be set to a non-zero and unique value. This index is passed to created envs through EnvContext so that envs can be configured per worker. @@ -301,7 +300,7 @@ def make_env(vector_index): vector_index=vector_index, remote=remote_worker_envs))) self.tf_sess = None - policy_dict = _validate_and_canonicalize(policy_graph, self.env) + policy_dict = _validate_and_canonicalize(policy, self.env) self.policies_to_train = policies_to_train or list(policy_dict.keys()) if _has_tensorflow_graph(policy_dict): if (ray.is_initialized() @@ -330,7 +329,7 @@ def make_env(vector_index): or isinstance(self.env, ExternalMultiAgentEnv)) or isinstance(self.env, BaseEnv)): raise ValueError( - "Have multiple policy graphs {}, but the env ".format( + "Have multiple policies {}, but the env ".format( self.policy_map) + "{} is not a subclass of BaseEnv, MultiAgentEnv or " "ExternalMultiAgentEnv?".format(self.env)) @@ -608,17 +607,17 @@ def foreach_env(self, func): @DeveloperAPI def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy graph for the specified id, or None. + """Return policy for the specified id, or None. Arguments: - policy_id (str): id of policy graph to return. + policy_id (str): id of policy to return. 
""" return self.policy_map.get(policy_id) @DeveloperAPI def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): - """Apply the given function to the specified policy graph.""" + """Apply the given function to the specified policy.""" return func(self.policy_map[policy_id]) @@ -708,7 +707,7 @@ def _build_policy_map(self, policy_dict, policy_config): preprocessors = {} for name, (cls, obs_space, act_space, conf) in sorted(policy_dict.items()): - logger.debug("Creating policy graph for {}".format(name)) + logger.debug("Creating policy for {}".format(name)) merged_conf = merge_dicts(policy_config, conf) if self.preprocessing_enabled: preprocessor = ModelCatalog.get_preprocessor_for_space( @@ -720,7 +719,7 @@ def _build_policy_map(self, policy_dict, policy_config): if isinstance(obs_space, gym.spaces.Dict) or \ isinstance(obs_space, gym.spaces.Tuple): raise ValueError( - "Found raw Tuple|Dict space as input to policy graph. " + "Found raw Tuple|Dict space as input to policy. " "Please preprocess these observations with a " "Tuple|DictFlatteningPreprocessor.") if tf: @@ -738,12 +737,12 @@ def __del__(self): self.sampler.shutdown = True -def _validate_and_canonicalize(policy_graph, env): - if isinstance(policy_graph, dict): - _validate_multiagent_config(policy_graph) - return policy_graph - elif not issubclass(policy_graph, PolicyGraph): - raise ValueError("policy_graph must be a rllib.PolicyGraph class") +def _validate_and_canonicalize(policy, env): + if isinstance(policy, dict): + _validate_multiagent_config(policy) + return policy + elif not issubclass(policy, Policy): + raise ValueError("policy must be a rllib.Policy class") else: if (isinstance(env, MultiAgentEnv) and not hasattr(env, "observation_space")): @@ -751,38 +750,35 @@ def _validate_and_canonicalize(policy_graph, env): "MultiAgentEnv must have observation_space defined if run " "in a single-agent configuration.") return { - DEFAULT_POLICY_ID: (policy_graph, env.observation_space, + DEFAULT_POLICY_ID: (policy, env.observation_space, env.action_space, {}) } -def _validate_multiagent_config(policy_graph, allow_none_graph=False): - for k, v in policy_graph.items(): +def _validate_multiagent_config(policy, allow_none_graph=False): + for k, v in policy.items(): if not isinstance(k, str): - raise ValueError("policy_graph keys must be strs, got {}".format( + raise ValueError("policy keys must be strs, got {}".format( type(k))) if not isinstance(v, tuple) or len(v) != 4: raise ValueError( - "policy_graph values must be tuples of " + "policy values must be tuples of " "(cls, obs_space, action_space, config), got {}".format(v)) if allow_none_graph and v[0] is None: pass - elif not issubclass(v[0], PolicyGraph): - raise ValueError( - "policy_graph tuple value 0 must be a rllib.PolicyGraph " - "class or None, got {}".format(v[0])) + elif not issubclass(v[0], Policy): + raise ValueError("policy tuple value 0 must be a rllib.Policy " + "class or None, got {}".format(v[0])) if not isinstance(v[1], gym.Space): raise ValueError( - "policy_graph tuple value 1 (observation_space) must be a " + "policy tuple value 1 (observation_space) must be a " "gym.Space, got {}".format(type(v[1]))) if not isinstance(v[2], gym.Space): - raise ValueError( - "policy_graph tuple value 2 (action_space) must be a " - "gym.Space, got {}".format(type(v[2]))) + raise ValueError("policy tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) if not isinstance(v[3], dict): - raise ValueError( - "policy_graph tuple value 3 (config) must be a dict, 
" - "got {}".format(type(v[3]))) + raise ValueError("policy tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) def _validate_env(env): @@ -805,6 +801,6 @@ def _monitor(env, path): def _has_tensorflow_graph(policy_dict): for policy, _, _, _ in policy_dict.values(): - if issubclass(policy, TFPolicyGraph): + if issubclass(policy, TFPolicy): return True return False diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index a577550975f9..5d0fdf2a4e57 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -2,286 +2,7 @@ from __future__ import division from __future__ import print_function -import numpy as np -import gym +from ray.rllib.policy.policy import Policy +from ray.rllib.utils import renamed_class -from ray.rllib.utils.annotations import DeveloperAPI - - -@DeveloperAPI -class PolicyGraph(object): - """An agent policy and loss, i.e., a TFPolicyGraph or other subclass. - - This object defines how to act in the environment, and also losses used to - improve the policy based on its experiences. Note that both policy and - loss are defined together for convenience, though the policy itself is - logically separate. - - All policies can directly extend PolicyGraph, however TensorFlow users may - find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib - to apply TensorFlow-specific optimizations such as fusing multiple policy - graphs and multi-GPU support. - - Attributes: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - """ - - @DeveloperAPI - def __init__(self, observation_space, action_space, config): - """Initialize the graph. - - This is the standard constructor for policy graphs. The policy graph - class you pass into PolicyEvaluator will be constructed with - these arguments. - - Args: - observation_space (gym.Space): Observation space of the policy. - action_space (gym.Space): Action space of the policy. - config (dict): Policy-specific configuration data. - """ - - self.observation_space = observation_space - self.action_space = action_space - - @DeveloperAPI - def compute_actions(self, - obs_batch, - state_batches, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - """Compute actions for the current policy. - - Arguments: - obs_batch (np.ndarray): batch of observations - state_batches (list): list of RNN state input batches, if any - prev_action_batch (np.ndarray): batch of previous action values - prev_reward_batch (np.ndarray): batch of previous rewards - info_batch (info): batch of info objects - episodes (list): MultiAgentEpisode for each obs in obs_batch. - This provides access to all of the internal episode state, - which may be useful for model-based or multiagent algorithms. - kwargs: forward compatibility placeholder - - Returns: - actions (np.ndarray): batch of output actions, with shape like - [BATCH_SIZE, ACTION_SHAPE]. - state_outs (list): list of RNN state output batches, if any, with - shape like [STATE_SIZE, BATCH_SIZE]. - info (dict): dictionary of extra feature batches, if any, with - shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. - """ - raise NotImplementedError - - @DeveloperAPI - def compute_single_action(self, - obs, - state, - prev_action=None, - prev_reward=None, - info=None, - episode=None, - clip_actions=False, - **kwargs): - """Unbatched version of compute_actions. 
- - Arguments: - obs (obj): single observation - state_batches (list): list of RNN state inputs, if any - prev_action (obj): previous action value, if any - prev_reward (int): previous reward, if any - info (dict): info object, if any - episode (MultiAgentEpisode): this provides access to all of the - internal episode state, which may be useful for model-based or - multi-agent algorithms. - clip_actions (bool): should the action be clipped - kwargs: forward compatibility placeholder - - Returns: - actions (obj): single action - state_outs (list): list of RNN state outputs, if any - info (dict): dictionary of extra features, if any - """ - - prev_action_batch = None - prev_reward_batch = None - info_batch = None - episodes = None - if prev_action is not None: - prev_action_batch = [prev_action] - if prev_reward is not None: - prev_reward_batch = [prev_reward] - if info is not None: - info_batch = [info] - if episode is not None: - episodes = [episode] - [action], state_out, info = self.compute_actions( - [obs], [[s] for s in state], - prev_action_batch=prev_action_batch, - prev_reward_batch=prev_reward_batch, - info_batch=info_batch, - episodes=episodes) - if clip_actions: - action = clip_action(action, self.action_space) - return action, [s[0] for s in state_out], \ - {k: v[0] for k, v in info.items()} - - @DeveloperAPI - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - """Implements algorithm-specific trajectory postprocessing. - - This will be called on each trajectory fragment computed during policy - evaluation. Each fragment is guaranteed to be only from one episode. - - Arguments: - sample_batch (SampleBatch): batch of experiences for the policy, - which will contain at most one episode trajectory. - other_agent_batches (dict): In a multi-agent env, this contains a - mapping of agent ids to (policy_graph, agent_batch) tuples - containing the policy graph and experiences of the other agent. - episode (MultiAgentEpisode): this provides access to all of the - internal episode state, which may be useful for model-based or - multi-agent algorithms. - - Returns: - SampleBatch: postprocessed sample batch. - """ - return sample_batch - - @DeveloperAPI - def learn_on_batch(self, samples): - """Fused compute gradients and apply gradients call. - - Either this or the combination of compute/apply grads must be - implemented by subclasses. - - Returns: - grad_info: dictionary of extra metadata from compute_gradients(). - - Examples: - >>> batch = ev.sample() - >>> ev.learn_on_batch(samples) - """ - - grads, grad_info = self.compute_gradients(samples) - self.apply_gradients(grads) - return grad_info - - @DeveloperAPI - def compute_gradients(self, postprocessed_batch): - """Computes gradients against a batch of experiences. - - Either this or learn_on_batch() must be implemented by subclasses. - - Returns: - grads (list): List of gradient output values - info (dict): Extra policy-specific values - """ - raise NotImplementedError - - @DeveloperAPI - def apply_gradients(self, gradients): - """Applies previously computed gradients. - - Either this or learn_on_batch() must be implemented by subclasses. - """ - raise NotImplementedError - - @DeveloperAPI - def get_weights(self): - """Returns model weights. - - Returns: - weights (obj): Serializable copy or view of model weights - """ - raise NotImplementedError - - @DeveloperAPI - def set_weights(self, weights): - """Sets model weights. 
- - Arguments: - weights (obj): Serializable copy or view of model weights - """ - raise NotImplementedError - - @DeveloperAPI - def get_initial_state(self): - """Returns initial RNN state for the current policy.""" - return [] - - @DeveloperAPI - def get_state(self): - """Saves all local state. - - Returns: - state (obj): Serialized local state. - """ - return self.get_weights() - - @DeveloperAPI - def set_state(self, state): - """Restores all local state. - - Arguments: - state (obj): Serialized local state. - """ - self.set_weights(state) - - @DeveloperAPI - def on_global_var_update(self, global_vars): - """Called on an update to global vars. - - Arguments: - global_vars (dict): Global variables broadcast from the driver. - """ - pass - - @DeveloperAPI - def export_model(self, export_dir): - """Export PolicyGraph to local directory for serving. - - Arguments: - export_dir (str): Local writable directory. - """ - raise NotImplementedError - - @DeveloperAPI - def export_checkpoint(self, export_dir): - """Export PolicyGraph checkpoint to local directory. - - Argument: - export_dir (str): Local writable directory. - """ - raise NotImplementedError - - -def clip_action(action, space): - """Called to clip actions to the specified range of this policy. - - Arguments: - action: Single action. - space: Action space the actions should be present in. - - Returns: - Clipped batch of actions. - """ - - if isinstance(space, gym.spaces.Box): - return np.clip(action, space.low, space.high) - elif isinstance(space, gym.spaces.Tuple): - if type(action) not in (tuple, list): - raise ValueError("Expected tuple space for actions {}: {}".format( - action, space)) - out = [] - for a, s in zip(action, space.spaces): - out.append(clip_action(a, s)) - return out - else: - return action +PolicyGraph = renamed_class(Policy, old_name="PolicyGraph") diff --git a/python/ray/rllib/evaluation/postprocessing.py b/python/ray/rllib/evaluation/postprocessing.py index aa2835f87e04..f236df6ed763 100644 --- a/python/ray/rllib/evaluation/postprocessing.py +++ b/python/ray/rllib/evaluation/postprocessing.py @@ -4,7 +4,7 @@ import numpy as np import scipy.signal -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI diff --git a/python/ray/rllib/evaluation/sample_batch.py b/python/ray/rllib/evaluation/sample_batch.py index c80f22bdbd1a..2c0f119a94b2 100644 --- a/python/ray/rllib/evaluation/sample_batch.py +++ b/python/ray/rllib/evaluation/sample_batch.py @@ -2,295 +2,10 @@ from __future__ import division from __future__ import print_function -import six -import collections -import numpy as np +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils import renamed_class -from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI -from ray.rllib.utils.compression import pack, unpack, is_compressed -from ray.rllib.utils.memory import concat_aligned - -# Defaults policy id for single agent environments -DEFAULT_POLICY_ID = "default_policy" - - -@PublicAPI -class MultiAgentBatch(object): - """A batch of experiences from multiple policies in the environment. - - Attributes: - policy_batches (dict): Mapping from policy id to a normal SampleBatch - of experiences. Note that these batches may be of different length. - count (int): The number of timesteps in the environment this batch - contains. 
This will be less than the number of transitions this - batch contains across all policies in total. - """ - - @PublicAPI - def __init__(self, policy_batches, count): - self.policy_batches = policy_batches - self.count = count - - @staticmethod - @PublicAPI - def wrap_as_needed(batches, count): - if len(batches) == 1 and DEFAULT_POLICY_ID in batches: - return batches[DEFAULT_POLICY_ID] - return MultiAgentBatch(batches, count) - - @staticmethod - @PublicAPI - def concat_samples(samples): - policy_batches = collections.defaultdict(list) - total_count = 0 - for s in samples: - assert isinstance(s, MultiAgentBatch) - for policy_id, batch in s.policy_batches.items(): - policy_batches[policy_id].append(batch) - total_count += s.count - out = {} - for policy_id, batches in policy_batches.items(): - out[policy_id] = SampleBatch.concat_samples(batches) - return MultiAgentBatch(out, total_count) - - @PublicAPI - def copy(self): - return MultiAgentBatch( - {k: v.copy() - for (k, v) in self.policy_batches.items()}, self.count) - - @PublicAPI - def total(self): - ct = 0 - for batch in self.policy_batches.values(): - ct += batch.count - return ct - - @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): - for batch in self.policy_batches.values(): - batch.compress(bulk=bulk, columns=columns) - - @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): - for batch in self.policy_batches.values(): - batch.decompress_if_needed(columns) - - def __str__(self): - return "MultiAgentBatch({}, count={})".format( - str(self.policy_batches), self.count) - - def __repr__(self): - return "MultiAgentBatch({}, count={})".format( - str(self.policy_batches), self.count) - - -@PublicAPI -class SampleBatch(object): - """Wrapper around a dictionary with string keys and array-like values. - - For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three - samples, each with an "obs" and "reward" attribute. - """ - - # Outputs from interacting with the environment - CUR_OBS = "obs" - NEXT_OBS = "new_obs" - ACTIONS = "actions" - REWARDS = "rewards" - PREV_ACTIONS = "prev_actions" - PREV_REWARDS = "prev_rewards" - DONES = "dones" - INFOS = "infos" - - # Uniquely identifies an episode - EPS_ID = "eps_id" - - # Uniquely identifies a sample batch. This is important to distinguish RNN - # sequences from the same episode when multiple sample batches are - # concatenated (fusing sequences across batches can be unsafe). 
- UNROLL_ID = "unroll_id" - - # Uniquely identifies an agent within an episode - AGENT_INDEX = "agent_index" - - # Value function predictions emitted by the behaviour policy - VF_PREDS = "vf_preds" - - @PublicAPI - def __init__(self, *args, **kwargs): - """Constructs a sample batch (same params as dict constructor).""" - - self.data = dict(*args, **kwargs) - lengths = [] - for k, v in self.data.copy().items(): - assert isinstance(k, six.string_types), self - lengths.append(len(v)) - self.data[k] = np.array(v, copy=False) - if not lengths: - raise ValueError("Empty sample batch") - assert len(set(lengths)) == 1, "data columns must be same length" - self.count = lengths[0] - - @staticmethod - @PublicAPI - def concat_samples(samples): - if isinstance(samples[0], MultiAgentBatch): - return MultiAgentBatch.concat_samples(samples) - out = {} - samples = [s for s in samples if s.count > 0] - for k in samples[0].keys(): - out[k] = concat_aligned([s[k] for s in samples]) - return SampleBatch(out) - - @PublicAPI - def concat(self, other): - """Returns a new SampleBatch with each data column concatenated. - - Examples: - >>> b1 = SampleBatch({"a": [1, 2]}) - >>> b2 = SampleBatch({"a": [3, 4, 5]}) - >>> print(b1.concat(b2)) - {"a": [1, 2, 3, 4, 5]} - """ - - assert self.keys() == other.keys(), "must have same columns" - out = {} - for k in self.keys(): - out[k] = concat_aligned([self[k], other[k]]) - return SampleBatch(out) - - @PublicAPI - def copy(self): - return SampleBatch( - {k: np.array(v, copy=True) - for (k, v) in self.data.items()}) - - @PublicAPI - def rows(self): - """Returns an iterator over data rows, i.e. dicts with column values. - - Examples: - >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) - >>> for row in batch.rows(): - print(row) - {"a": 1, "b": 4} - {"a": 2, "b": 5} - {"a": 3, "b": 6} - """ - - for i in range(self.count): - row = {} - for k in self.keys(): - row[k] = self[k][i] - yield row - - @PublicAPI - def columns(self, keys): - """Returns a list of just the specified columns. - - Examples: - >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) - >>> print(batch.columns(["a", "b"])) - [[1], [2]] - """ - - out = [] - for k in keys: - out.append(self[k]) - return out - - @PublicAPI - def shuffle(self): - """Shuffles the rows of this batch in-place.""" - - permutation = np.random.permutation(self.count) - for key, val in self.items(): - self[key] = val[permutation] - - @PublicAPI - def split_by_episode(self): - """Splits this batch's data by `eps_id`. - - Returns: - list of SampleBatch, one per distinct episode. - """ - - slices = [] - cur_eps_id = self.data["eps_id"][0] - offset = 0 - for i in range(self.count): - next_eps_id = self.data["eps_id"][i] - if next_eps_id != cur_eps_id: - slices.append(self.slice(offset, i)) - offset = i - cur_eps_id = next_eps_id - slices.append(self.slice(offset, self.count)) - for s in slices: - slen = len(set(s["eps_id"])) - assert slen == 1, (s, slen) - assert sum(s.count for s in slices) == self.count, (slices, self.count) - return slices - - @PublicAPI - def slice(self, start, end): - """Returns a slice of the row data of this batch. - - Arguments: - start (int): Starting index. - end (int): Ending index. - - Returns: - SampleBatch which has a slice of this batch's data. 
- """ - - return SampleBatch({k: v[start:end] for k, v in self.data.items()}) - - @PublicAPI - def keys(self): - return self.data.keys() - - @PublicAPI - def items(self): - return self.data.items() - - @PublicAPI - def __getitem__(self, key): - return self.data[key] - - @PublicAPI - def __setitem__(self, key, item): - self.data[key] = item - - @DeveloperAPI - def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): - for key in columns: - if key in self.data: - if bulk: - self.data[key] = pack(self.data[key]) - else: - self.data[key] = np.array( - [pack(o) for o in self.data[key]]) - - @DeveloperAPI - def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): - for key in columns: - if key in self.data: - arr = self.data[key] - if is_compressed(arr): - self.data[key] = unpack(arr) - elif len(arr) > 0 and is_compressed(arr[0]): - self.data[key] = np.array( - [unpack(o) for o in self.data[key]]) - - def __str__(self): - return "SampleBatch({})".format(str(self.data)) - - def __repr__(self): - return "SampleBatch({})".format(str(self.data)) - - def __iter__(self): - return self.data.__iter__() - - def __contains__(self, x): - return x in self.data +SampleBatch = renamed_class( + SampleBatch, old_name="rllib.evaluation.SampleBatch") +MultiAgentBatch = renamed_class( + MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch") diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index c6d69d7d97f1..0ead77d52847 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -6,7 +6,7 @@ import logging import numpy as np -from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize @@ -79,7 +79,7 @@ def __init__(self, policy_map, clip_rewards, postp_callback): """Initialize a MultiAgentSampleBatchBuilder. Arguments: - policy_map (dict): Maps policy ids to policy graph instances. + policy_map (dict): Maps policy ids to policy instances. clip_rewards (bool): Whether to clip rewards before postprocessing. postp_callback: function to call on each postprocessed batch. """ diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 25dd1ef5d9a7..47964c3c561b 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -12,7 +12,7 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action from ray.rllib.evaluation.sample_batch_builder import \ MultiAgentSampleBatchBuilder -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv from ray.rllib.models.action_dist import TupleActions @@ -20,7 +20,7 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.evaluation.policy_graph import clip_action +from ray.rllib.policy.policy import clip_action logger = logging.getLogger(__name__) @@ -236,7 +236,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn, Args: base_env (BaseEnv): env implementing BaseEnv. 
extra_batch_callback (fn): function to send extra batch data to. - policies (dict): Map of policy ids to PolicyGraph instances. + policies (dict): Map of policy ids to Policy instances. policy_mapping_fn (func): Function that maps agent ids to policy ids. This is called when an agent first enters the environment. The agent is then "bound" to the returned policy for the episode. @@ -528,7 +528,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes): rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) policy = _get_or_raise(policies, policy_id) if builder and (policy.compute_actions.__code__ is - TFPolicyGraph.compute_actions.__code__): + TFPolicy.compute_actions.__code__): # TODO(ekl): how can we make info batch available to TF code? pending_fetches[policy_id] = policy._build_compute_actions( builder, [t.obs for t in eval_data], diff --git a/python/ray/rllib/evaluation/tf_policy_graph.py b/python/ray/rllib/evaluation/tf_policy_graph.py index b921e6cfb0d1..2c4955a17ff1 100644 --- a/python/ray/rllib/evaluation/tf_policy_graph.py +++ b/python/ray/rllib/evaluation/tf_policy_graph.py @@ -2,513 +2,7 @@ from __future__ import division from __future__ import print_function -import os -import errno -import logging -import numpy as np +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import renamed_class -import ray -import ray.experimental.tf_utils -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.models.lstm import chop_into_sequences -from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.debug import log_once, summarize -from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule -from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf - -tf = try_import_tf() -logger = logging.getLogger(__name__) - - -@DeveloperAPI -class TFPolicyGraph(PolicyGraph): - """An agent policy and loss implemented in TensorFlow. - - Extending this class enables RLlib to perform TensorFlow specific - optimizations on the policy graph, e.g., parallelization across gpus or - fusing multiple graphs together in the multi-agent setting. - - Input tensors are typically shaped like [BATCH_SIZE, ...]. - - Attributes: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - model (rllib.models.Model): RLlib model used for the policy. - - Examples: - >>> policy = TFPolicyGraphSubclass( - sess, obs_input, action_sampler, loss, loss_inputs) - - >>> print(policy.compute_actions([1, 0, 2])) - (array([0, 1, 1]), [], {}) - - >>> print(policy.postprocess_trajectory(SampleBatch({...}))) - SampleBatch({"action": ..., "advantages": ..., ...}) - """ - - @DeveloperAPI - def __init__(self, - observation_space, - action_space, - sess, - obs_input, - action_sampler, - loss, - loss_inputs, - model=None, - action_prob=None, - state_inputs=None, - state_outputs=None, - prev_action_input=None, - prev_reward_input=None, - seq_lens=None, - max_seq_len=20, - batch_divisibility_req=1, - update_ops=None): - """Initialize the policy graph. - - Arguments: - observation_space (gym.Space): Observation space of the env. - action_space (gym.Space): Action space of the env. - sess (Session): TensorFlow session to use. - obs_input (Tensor): input placeholder for observations, of shape - [BATCH_SIZE, obs...]. 
- action_sampler (Tensor): Tensor for sampling an action, of shape - [BATCH_SIZE, action...] - loss (Tensor): scalar policy loss output tensor. - loss_inputs (list): a (name, placeholder) tuple for each loss - input argument. Each placeholder name must correspond to a - SampleBatch column key returned by postprocess_trajectory(), - and has shape [BATCH_SIZE, data...]. These keys will be read - from postprocessed sample batches and fed into the specified - placeholders during loss computation. - model (rllib.models.Model): used to integrate custom losses and - stats from user-defined RLlib models. - action_prob (Tensor): probability of the sampled action. - state_inputs (list): list of RNN state input Tensors. - state_outputs (list): list of RNN state output Tensors. - prev_action_input (Tensor): placeholder for previous actions - prev_reward_input (Tensor): placeholder for previous rewards - seq_lens (Tensor): placeholder for RNN sequence lengths, of shape - [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See - models/lstm.py for more information. - max_seq_len (int): max sequence length for LSTM training. - batch_divisibility_req (int): pad all agent experiences batches to - multiples of this value. This only has an effect if not using - a LSTM model. - update_ops (list): override the batchnorm update ops to run when - applying gradients. Otherwise we run all update ops found in - the current variable scope. - """ - - self.observation_space = observation_space - self.action_space = action_space - self.model = model - self._sess = sess - self._obs_input = obs_input - self._prev_action_input = prev_action_input - self._prev_reward_input = prev_reward_input - self._sampler = action_sampler - self._is_training = self._get_is_training_placeholder() - self._action_prob = action_prob - self._state_inputs = state_inputs or [] - self._state_outputs = state_outputs or [] - self._seq_lens = seq_lens - self._max_seq_len = max_seq_len - self._batch_divisibility_req = batch_divisibility_req - self._update_ops = update_ops - self._stats_fetches = {} - - if loss is not None: - self._initialize_loss(loss, loss_inputs) - else: - self._loss = None - - if len(self._state_inputs) != len(self._state_outputs): - raise ValueError( - "Number of state input and output tensors must match, got: " - "{} vs {}".format(self._state_inputs, self._state_outputs)) - if len(self.get_initial_state()) != len(self._state_inputs): - raise ValueError( - "Length of initial state must match number of state inputs, " - "got: {} vs {}".format(self.get_initial_state(), - self._state_inputs)) - if self._state_inputs and self._seq_lens is None: - raise ValueError( - "seq_lens tensor must be given if state inputs are defined") - - def _initialize_loss(self, loss, loss_inputs): - self._loss_inputs = loss_inputs - self._loss_input_dict = dict(self._loss_inputs) - for i, ph in enumerate(self._state_inputs): - self._loss_input_dict["state_in_{}".format(i)] = ph - - if self.model: - self._loss = self.model.custom_loss(loss, self._loss_input_dict) - self._stats_fetches.update({"model": self.model.custom_stats()}) - else: - self._loss = loss - - self._optimizer = self.optimizer() - self._grads_and_vars = [ - (g, v) for (g, v) in self.gradients(self._optimizer, self._loss) - if g is not None - ] - self._grads = [g for (g, v) in self._grads_and_vars] - self._variables = ray.experimental.tf_utils.TensorFlowVariables( - self._loss, self._sess) - - # gather update ops for any batch norm layers - if not self._update_ops: - self._update_ops = 
tf.get_collection( - tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) - if self._update_ops: - logger.debug("Update ops to run on apply gradient: {}".format( - self._update_ops)) - with tf.control_dependencies(self._update_ops): - self._apply_op = self.build_apply_op(self._optimizer, - self._grads_and_vars) - - if log_once("loss_used"): - logger.debug( - "These tensors were used in the loss_fn:\n\n{}\n".format( - summarize(self._loss_input_dict))) - - self._sess.run(tf.global_variables_initializer()) - - @override(PolicyGraph) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - builder = TFRunBuilder(self._sess, "compute_actions") - fetches = self._build_compute_actions(builder, obs_batch, - state_batches, prev_action_batch, - prev_reward_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def compute_gradients(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" - builder = TFRunBuilder(self._sess, "compute_gradients") - fetches = self._build_compute_gradients(builder, postprocessed_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def apply_gradients(self, gradients): - assert self._loss is not None, "Loss not initialized" - builder = TFRunBuilder(self._sess, "apply_gradients") - fetches = self._build_apply_gradients(builder, gradients) - builder.get(fetches) - - @override(PolicyGraph) - def learn_on_batch(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" - builder = TFRunBuilder(self._sess, "learn_on_batch") - fetches = self._build_learn_on_batch(builder, postprocessed_batch) - return builder.get(fetches) - - @override(PolicyGraph) - def get_weights(self): - return self._variables.get_flat() - - @override(PolicyGraph) - def set_weights(self, weights): - return self._variables.set_flat(weights) - - @override(PolicyGraph) - def export_model(self, export_dir): - """Export tensorflow graph to export_dir for serving.""" - with self._sess.graph.as_default(): - builder = tf.saved_model.builder.SavedModelBuilder(export_dir) - signature_def_map = self._build_signature_def() - builder.add_meta_graph_and_variables( - self._sess, [tf.saved_model.tag_constants.SERVING], - signature_def_map=signature_def_map) - builder.save() - - @override(PolicyGraph) - def export_checkpoint(self, export_dir, filename_prefix="model"): - """Export tensorflow checkpoint to export_dir.""" - try: - os.makedirs(export_dir) - except OSError as e: - # ignore error if export dir already exists - if e.errno != errno.EEXIST: - raise - save_path = os.path.join(export_dir, filename_prefix) - with self._sess.graph.as_default(): - saver = tf.train.Saver() - saver.save(self._sess, save_path) - - @DeveloperAPI - def copy(self, existing_inputs): - """Creates a copy of self using existing input placeholders. - - Optional, only required to work with the multi-GPU optimizer.""" - raise NotImplementedError - - @DeveloperAPI - def extra_compute_action_feed_dict(self): - """Extra dict to pass to the compute actions session run.""" - return {} - - @DeveloperAPI - def extra_compute_action_fetches(self): - """Extra values to fetch and return from compute_actions(). - - By default we only return action probability info (if present). 
- """ - if self._action_prob is not None: - return {"action_prob": self._action_prob} - else: - return {} - - @DeveloperAPI - def extra_compute_grad_feed_dict(self): - """Extra dict to pass to the compute gradients session run.""" - return {} # e.g, kl_coeff - - @DeveloperAPI - def extra_compute_grad_fetches(self): - """Extra values to fetch and return from compute_gradients().""" - return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. - - @DeveloperAPI - def optimizer(self): - """TF optimizer to use for policy optimization.""" - if hasattr(self, "config"): - return tf.train.AdamOptimizer(self.config["lr"]) - else: - return tf.train.AdamOptimizer() - - @DeveloperAPI - def gradients(self, optimizer, loss): - """Override for custom gradient computation.""" - return optimizer.compute_gradients(loss) - - @DeveloperAPI - def build_apply_op(self, optimizer, grads_and_vars): - """Override for custom gradient apply computation.""" - - # specify global_step for TD3 which needs to count the num updates - return optimizer.apply_gradients( - self._grads_and_vars, - global_step=tf.train.get_or_create_global_step()) - - @DeveloperAPI - def _get_is_training_placeholder(self): - """Get the placeholder for _is_training, i.e., for batch norm layers. - - This can be called safely before __init__ has run. - """ - if not hasattr(self, "_is_training"): - self._is_training = tf.placeholder_with_default(False, ()) - return self._is_training - - def _extra_input_signature_def(self): - """Extra input signatures to add when exporting tf model. - Inferred from extra_compute_action_feed_dict() - """ - feed_dict = self.extra_compute_action_feed_dict() - return { - k.name: tf.saved_model.utils.build_tensor_info(k) - for k in feed_dict.keys() - } - - def _extra_output_signature_def(self): - """Extra output signatures to add when exporting tf model. - Inferred from extra_compute_action_fetches() - """ - fetches = self.extra_compute_action_fetches() - return { - k: tf.saved_model.utils.build_tensor_info(fetches[k]) - for k in fetches.keys() - } - - def _build_signature_def(self): - """Build signature def map for tensorflow SavedModelBuilder. - """ - # build input signatures - input_signature = self._extra_input_signature_def() - input_signature["observations"] = \ - tf.saved_model.utils.build_tensor_info(self._obs_input) - - if self._seq_lens is not None: - input_signature["seq_lens"] = \ - tf.saved_model.utils.build_tensor_info(self._seq_lens) - if self._prev_action_input is not None: - input_signature["prev_action"] = \ - tf.saved_model.utils.build_tensor_info(self._prev_action_input) - if self._prev_reward_input is not None: - input_signature["prev_reward"] = \ - tf.saved_model.utils.build_tensor_info(self._prev_reward_input) - input_signature["is_training"] = \ - tf.saved_model.utils.build_tensor_info(self._is_training) - - for state_input in self._state_inputs: - input_signature[state_input.name] = \ - tf.saved_model.utils.build_tensor_info(state_input) - - # build output signatures - output_signature = self._extra_output_signature_def() - output_signature["actions"] = \ - tf.saved_model.utils.build_tensor_info(self._sampler) - for state_output in self._state_outputs: - output_signature[state_output.name] = \ - tf.saved_model.utils.build_tensor_info(state_output) - signature_def = ( - tf.saved_model.signature_def_utils.build_signature_def( - input_signature, output_signature, - tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) - signature_def_key = (tf.saved_model.signature_constants. 
- DEFAULT_SERVING_SIGNATURE_DEF_KEY) - signature_def_map = {signature_def_key: signature_def} - return signature_def_map - - def _build_compute_actions(self, - builder, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - episodes=None): - state_batches = state_batches or [] - if len(self._state_inputs) != len(state_batches): - raise ValueError( - "Must pass in RNN state batches for placeholders {}, got {}". - format(self._state_inputs, state_batches)) - builder.add_feed_dict(self.extra_compute_action_feed_dict()) - builder.add_feed_dict({self._obs_input: obs_batch}) - if state_batches: - builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) - if self._prev_action_input is not None and prev_action_batch: - builder.add_feed_dict({self._prev_action_input: prev_action_batch}) - if self._prev_reward_input is not None and prev_reward_batch: - builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) - builder.add_feed_dict({self._is_training: False}) - builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) - fetches = builder.add_fetches([self._sampler] + self._state_outputs + - [self.extra_compute_action_fetches()]) - return fetches[0], fetches[1:-1], fetches[-1] - - def _build_compute_gradients(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - fetches = builder.add_fetches( - [self._grads, self._get_grad_and_stats_fetches()]) - return fetches[0], fetches[1] - - def _build_apply_gradients(self, builder, gradients): - if len(gradients) != len(self._grads): - raise ValueError( - "Unexpected number of gradients to apply, got {} for {}". 
- format(gradients, self._grads)) - builder.add_feed_dict({self._is_training: True}) - builder.add_feed_dict(dict(zip(self._grads, gradients))) - fetches = builder.add_fetches([self._apply_op]) - return fetches[0] - - def _build_learn_on_batch(self, builder, postprocessed_batch): - builder.add_feed_dict(self.extra_compute_grad_feed_dict()) - builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) - builder.add_feed_dict({self._is_training: True}) - fetches = builder.add_fetches([ - self._apply_op, - self._get_grad_and_stats_fetches(), - ]) - return fetches[1] - - def _get_grad_and_stats_fetches(self): - fetches = self.extra_compute_grad_fetches() - if LEARNER_STATS_KEY not in fetches: - raise ValueError( - "Grad fetches should contain 'stats': {...} entry") - if self._stats_fetches: - fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches, - **fetches[LEARNER_STATS_KEY]) - return fetches - - def _get_loss_inputs_dict(self, batch): - feed_dict = {} - if self._batch_divisibility_req > 1: - meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent - else: - meets_divisibility_reqs = True - - # Simple case: not RNN nor do we need to pad - if not self._state_inputs and meets_divisibility_reqs: - for k, ph in self._loss_inputs: - feed_dict[ph] = batch[k] - return feed_dict - - if self._state_inputs: - max_seq_len = self._max_seq_len - dynamic_max = True - else: - max_seq_len = self._batch_divisibility_req - dynamic_max = False - - # RNN or multi-agent case - feature_keys = [k for k, v in self._loss_inputs] - state_keys = [ - "state_in_{}".format(i) for i in range(len(self._state_inputs)) - ] - feature_sequences, initial_states, seq_lens = chop_into_sequences( - batch[SampleBatch.EPS_ID], - batch[SampleBatch.UNROLL_ID], - batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], - [batch[k] for k in state_keys], - max_seq_len, - dynamic_max=dynamic_max) - for k, v in zip(feature_keys, feature_sequences): - feed_dict[self._loss_input_dict[k]] = v - for k, v in zip(state_keys, initial_states): - feed_dict[self._loss_input_dict[k]] = v - feed_dict[self._seq_lens] = seq_lens - - if log_once("rnn_feed_dict"): - logger.info("Padded input for RNN:\n\n{}\n".format( - summarize({ - "features": feature_sequences, - "initial_states": initial_states, - "seq_lens": seq_lens, - "max_seq_len": max_seq_len, - }))) - return feed_dict - - -@DeveloperAPI -class LearningRateSchedule(object): - """Mixin for TFPolicyGraph that adds a learning rate schedule.""" - - @DeveloperAPI - def __init__(self, lr, lr_schedule): - self.cur_lr = tf.get_variable("lr", initializer=lr) - if lr_schedule is None: - self.lr_schedule = ConstantSchedule(lr) - else: - self.lr_schedule = PiecewiseSchedule( - lr_schedule, outside_value=lr_schedule[-1][-1]) - - @override(PolicyGraph) - def on_global_var_update(self, global_vars): - super(LearningRateSchedule, self).on_global_var_update(global_vars) - self.cur_lr.load( - self.lr_schedule.value(global_vars["timestep"]), - session=self._sess) - - @override(TFPolicyGraph) - def optimizer(self): - return tf.train.AdamOptimizer(self.cur_lr) +TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph") diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py index b2549e973a65..36f482f18bf8 100644 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ b/python/ray/rllib/evaluation/tf_policy_template.py @@ 
-2,9 +2,9 @@ from __future__ import division from __future__ import print_function -from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.utils.annotations import override, DeveloperAPI @@ -27,7 +27,7 @@ def build_tf_policy(name, """Helper function for creating a dynamic tf policy at runtime. Arguments: - name (str): name of the graph (e.g., "PPOPolicy") + name (str): name of the policy (e.g., "PPOTFPolicy") loss_fn (func): function that returns a loss tensor the policy, and dict of experience tensor placeholders get_default_config (func): optional function that returns the default @@ -39,7 +39,7 @@ def build_tf_policy(name, extra_action_fetches_fn (func): optional function that returns a dict of TF fetches given the policy object postprocess_fn (func): optional experience postprocessing function - that takes the same args as PolicyGraph.postprocess_trajectory() + that takes the same args as Policy.postprocess_trajectory() optimizer_fn (func): optional function that returns a tf.Optimizer given the policy and config gradients_fn (func): optional function that returns a list of gradients @@ -57,18 +57,18 @@ def build_tf_policy(name, arguments mixins (list): list of any class mixins for the returned policy class. These mixins will be applied in order and will have higher - precedence than the DynamicTFPolicyGraph class + precedence than the DynamicTFPolicy class get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches Returns: - a DynamicTFPolicyGraph instance that uses the specified args + a DynamicTFPolicy instance that uses the specified args """ if not name.endswith("TFPolicy"): raise ValueError("Name should match *TFPolicy", name) - base = DynamicTFPolicyGraph + base = DynamicTFPolicy while mixins: class new_base(mixins.pop(), base): @@ -76,7 +76,7 @@ class new_base(mixins.pop(), base): base = new_base - class graph_cls(base): + class policy_cls(base): def __init__(self, obs_space, action_space, @@ -97,7 +97,7 @@ def before_loss_init_wrapper(policy, obs_space, action_space, else: self._extra_action_fetches = extra_action_fetches_fn(self) - DynamicTFPolicyGraph.__init__( + DynamicTFPolicy.__init__( self, obs_space, action_space, @@ -111,7 +111,7 @@ def before_loss_init_wrapper(policy, obs_space, action_space, if after_init: after_init(self, obs_space, action_space, config) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -121,26 +121,26 @@ def postprocess_trajectory(self, return postprocess_fn(self, sample_batch, other_agent_batches, episode) - @override(TFPolicyGraph) + @override(TFPolicy) def optimizer(self): if optimizer_fn: return optimizer_fn(self, self.config) else: - return TFPolicyGraph.optimizer(self) + return TFPolicy.optimizer(self) - @override(TFPolicyGraph) + @override(TFPolicy) def gradients(self, optimizer, loss): if gradients_fn: return gradients_fn(self, optimizer, loss) else: - return TFPolicyGraph.gradients(self, optimizer, loss) + return TFPolicy.gradients(self, optimizer, loss) - @override(TFPolicyGraph) + @override(TFPolicy) def extra_compute_action_fetches(self): return dict( - 
TFPolicyGraph.extra_compute_action_fetches(self), + TFPolicy.extra_compute_action_fetches(self), **self._extra_action_fetches) - graph_cls.__name__ = name - graph_cls.__qualname__ = name - return graph_cls + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index ccf1b9eeb81d..08cc29fed746 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -2,173 +2,7 @@ from __future__ import division from __future__ import print_function -import os +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.utils import renamed_class -import numpy as np -from threading import Lock - -try: - import torch -except ImportError: - pass # soft dep - -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.utils.annotations import override -from ray.rllib.utils.tracking_dict import UsageTrackingDict - - -class TorchPolicyGraph(PolicyGraph): - """Template for a PyTorch policy and loss to use with RLlib. - - This is similar to TFPolicyGraph, but for PyTorch. - - Attributes: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - lock (Lock): Lock that must be held around PyTorch ops on this graph. - This is necessary when using the async sampler. - """ - - def __init__(self, observation_space, action_space, model, loss, - action_distribution_cls): - """Build a policy graph from policy and loss torch modules. - - Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES - is set. Only single GPU is supported for now. - - Arguments: - observation_space (gym.Space): observation space of the policy. - action_space (gym.Space): action space of the policy. - model (nn.Module): PyTorch policy module. Given observations as - input, this module must return a list of outputs where the - first item is action logits, and the rest can be any value. - loss (func): Function that takes (policy_graph, batch_tensors) - and returns a single scalar loss. - action_distribution_cls (ActionDistribution): Class for action - distribution. 
- """ - self.observation_space = observation_space - self.action_space = action_space - self.lock = Lock() - self.device = (torch.device("cuda") - if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) - else torch.device("cpu")) - self._model = model.to(self.device) - self._loss = loss - self._optimizer = self.optimizer() - self._action_dist_cls = action_distribution_cls - - @override(PolicyGraph) - def compute_actions(self, - obs_batch, - state_batches=None, - prev_action_batch=None, - prev_reward_batch=None, - info_batch=None, - episodes=None, - **kwargs): - with self.lock: - with torch.no_grad(): - ob = torch.from_numpy(np.array(obs_batch)) \ - .float().to(self.device) - model_out = self._model({"obs": ob}, state_batches) - logits, _, vf, state = model_out - action_dist = self._action_dist_cls(logits) - actions = action_dist.sample() - return (actions.cpu().numpy(), - [h.cpu().numpy() for h in state], - self.extra_action_out(model_out)) - - @override(PolicyGraph) - def learn_on_batch(self, postprocessed_batch): - batch_tensors = self._lazy_tensor_dict(postprocessed_batch) - - with self.lock: - loss_out = self._loss(self, batch_tensors) - self._optimizer.zero_grad() - loss_out.backward() - - grad_process_info = self.extra_grad_process() - self._optimizer.step() - - grad_info = self.extra_grad_info(batch_tensors) - grad_info.update(grad_process_info) - return {LEARNER_STATS_KEY: grad_info} - - @override(PolicyGraph) - def compute_gradients(self, postprocessed_batch): - batch_tensors = self._lazy_tensor_dict(postprocessed_batch) - - with self.lock: - loss_out = self._loss(self, batch_tensors) - self._optimizer.zero_grad() - loss_out.backward() - - grad_process_info = self.extra_grad_process() - - # Note that return values are just references; - # calling zero_grad will modify the values - grads = [] - for p in self._model.parameters(): - if p.grad is not None: - grads.append(p.grad.data.cpu().numpy()) - else: - grads.append(None) - - grad_info = self.extra_grad_info(batch_tensors) - grad_info.update(grad_process_info) - return grads, {LEARNER_STATS_KEY: grad_info} - - @override(PolicyGraph) - def apply_gradients(self, gradients): - with self.lock: - for g, p in zip(gradients, self._model.parameters()): - if g is not None: - p.grad = torch.from_numpy(g).to(self.device) - self._optimizer.step() - - @override(PolicyGraph) - def get_weights(self): - with self.lock: - return {k: v.cpu() for k, v in self._model.state_dict().items()} - - @override(PolicyGraph) - def set_weights(self, weights): - with self.lock: - self._model.load_state_dict(weights) - - @override(PolicyGraph) - def get_initial_state(self): - return [s.numpy() for s in self._model.state_init()] - - def extra_grad_process(self): - """Allow subclass to do extra processing on gradients and - return processing info.""" - return {} - - def extra_action_out(self, model_out): - """Returns dict of extra info to include in experience batch. 
- - Arguments: - model_out (list): Outputs of the policy model module.""" - return {} - - def extra_grad_info(self, batch_tensors): - """Return dict of extra grad info.""" - - return {} - - def optimizer(self): - """Custom PyTorch optimizer to use.""" - if hasattr(self, "config"): - return torch.optim.Adam( - self._model.parameters(), lr=self.config["lr"]) - else: - return torch.optim.Adam(self._model.parameters()) - - def _lazy_tensor_dict(self, postprocessed_batch): - batch_tensors = UsageTrackingDict(postprocessed_batch) - batch_tensors.set_get_interceptor( - lambda arr: torch.from_numpy(arr).to(self.device)) - return batch_tensors +TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph") diff --git a/python/ray/rllib/examples/hierarchical_training.py b/python/ray/rllib/examples/hierarchical_training.py index c6d2db96837f..2fe61953dc96 100644 --- a/python/ray/rllib/examples/hierarchical_training.py +++ b/python/ray/rllib/examples/hierarchical_training.py @@ -209,7 +209,7 @@ def policy_mapping_fn(agent_id): "log_level": "INFO", "entropy_coeff": 0.01, "multiagent": { - "policy_graphs": { + "policies": { "high_level_policy": (None, maze.observation_space, Discrete(4), { "gamma": 0.9 diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index 6e0f93711540..efa77ecbf7a5 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -6,7 +6,7 @@ Control the number of agents and policies via --num-agents and --num-policies. This works with hundreds of agents and policies, but note that initializing -many TF policy graphs will take some time. +many TF policies will take some time. Also, TF evals might slow down with large numbers of policies. To debug TF execution, set the TF_TIMELINE_DIR environment variable. 
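As a quick orientation for the hunks below, the example now wires its ensemble into the renamed `policies` key (previously `policy_graphs`). A minimal sketch, reusing `gen_policy`, `args`, `tune`, and `random` from this example file, mirroring the changes that follow:

    policies = {
        "policy_{}".format(i): gen_policy(i)
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())
    config = {
        "multiagent": {
            "policies": policies,  # formerly "policy_graphs"
            "policy_mapping_fn": tune.function(
                lambda agent_id: random.choice(policy_ids)),
        },
    }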
@@ -90,12 +90,12 @@ def gen_policy(i): } return (None, obs_space, act_space, config) - # Setup PPO with an ensemble of `num_policies` different policy graphs - policy_graphs = { + # Setup PPO with an ensemble of `num_policies` different policies + policies = { "policy_{}".format(i): gen_policy(i) for i in range(args.num_policies) } - policy_ids = list(policy_graphs.keys()) + policy_ids = list(policies.keys()) tune.run( "PPO", @@ -105,7 +105,7 @@ def gen_policy(i): "log_level": "DEBUG", "num_sgd_iter": 10, "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": tune.function( lambda agent_id: random.choice(policy_ids)), }, diff --git a/python/ray/rllib/examples/multiagent_custom_policy.py b/python/ray/rllib/examples/multiagent_custom_policy.py index 855051d52ef4..d34d678098b6 100644 --- a/python/ray/rllib/examples/multiagent_custom_policy.py +++ b/python/ray/rllib/examples/multiagent_custom_policy.py @@ -22,7 +22,7 @@ import ray from ray import tune -from ray.rllib.evaluation import PolicyGraph +from ray.rllib.policy import Policy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.registry import register_env @@ -30,7 +30,7 @@ parser.add_argument("--num-iters", type=int, default=20) -class RandomPolicy(PolicyGraph): +class RandomPolicy(Policy): """Hand-coded policy that returns random actions.""" def compute_actions(self, @@ -65,7 +65,7 @@ def learn_on_batch(self, samples): config={ "env": "multi_cartpole", "multiagent": { - "policy_graphs": { + "policies": { "pg_policy": (None, obs_space, act_space, {}), "random": (RandomPolicy, obs_space, act_space, {}), }, diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 1d4257e4eb9d..68c0e742e857 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -16,9 +16,9 @@ import ray from ray.rllib.agents.dqn.dqn import DQNTrainer -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.agents.ppo.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.logger import pretty_print from ray.tune.registry import register_env @@ -36,11 +36,11 @@ obs_space = single_env.observation_space act_space = single_env.action_space - # You can also have multiple policy graphs per trainer, but here we just + # You can also have multiple policies per trainer, but here we just # show one each for PPO and DQN. 
- policy_graphs = { + policies = { "ppo_policy": (PPOTFPolicy, obs_space, act_space, {}), - "dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}), + "dqn_policy": (DQNTFPolicy, obs_space, act_space, {}), } def policy_mapping_fn(agent_id): @@ -53,7 +53,7 @@ def policy_mapping_fn(agent_id): env="multi_cartpole", config={ "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["ppo_policy"], }, @@ -66,7 +66,7 @@ def policy_mapping_fn(agent_id): env="multi_cartpole", config={ "multiagent": { - "policy_graphs": policy_graphs, + "policies": policies, "policy_mapping_fn": policy_mapping_fn, "policies_to_train": ["dqn_policy"], }, diff --git a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py index b07787129246..a8d80da994d2 100644 --- a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py +++ b/python/ray/rllib/examples/policy_evaluator_custom_workflow.py @@ -1,7 +1,7 @@ """Example of using policy evaluator classes directly to implement training. Instead of using the built-in Trainer classes provided by RLlib, here we define -a custom PolicyGraph class and manually coordinate distributed sample +a custom Policy class and manually coordinate distributed sample collection and policy optimization. """ @@ -14,7 +14,8 @@ import ray from ray import tune -from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch +from ray.rllib.policy import Policy +from ray.rllib.evaluation import PolicyEvaluator, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics parser = argparse.ArgumentParser() @@ -23,15 +24,15 @@ parser.add_argument("--num-workers", type=int, default=2) -class CustomPolicy(PolicyGraph): - """Example of a custom policy graph written from scratch. +class CustomPolicy(Policy): + """Example of a custom policy written from scratch. - You might find it more convenient to extend TF/TorchPolicyGraph instead + You might find it more convenient to extend TF/TorchPolicy instead for a real policy. """ def __init__(self, observation_space, action_space, config): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) # example parameter self.w = 1.0 diff --git a/python/ray/rllib/evaluation/keras_policy_graph.py b/python/ray/rllib/keras_policy.py similarity index 83% rename from python/ray/rllib/evaluation/keras_policy_graph.py rename to python/ray/rllib/keras_policy.py index 88d8e0a9be32..3008e133c1c6 100644 --- a/python/ray/rllib/evaluation/keras_policy_graph.py +++ b/python/ray/rllib/keras_policy.py @@ -4,19 +4,19 @@ import numpy as np -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy def _sample(probs): return [np.random.choice(len(pr), p=pr) for pr in probs] -class KerasPolicyGraph(PolicyGraph): - """Initialize the Keras Policy Graph. +class KerasPolicy(Policy): + """Initialize the Keras Policy. - This is a Policy Graph used for models with actor and critics. + This is a Policy used for models with actor and critics. Note: This class is built for specific usage of Actor-Critic models, - and is less general compared to TFPolicyGraph and TorchPolicyGraphs. + and is less general compared to TFPolicy and TorchPolicies. Args: observation_space (gym.Space): Observation space of the policy. 
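The `_sample` helper shown above draws one action index per row of a batch of probability vectors. A tiny usage sketch with made-up values (relying on the `numpy` import already present in this file):

    probs = np.array([[0.1, 0.9], [0.7, 0.3]])
    actions = _sample(probs)  # e.g. [1, 0]: one index per row, weighted by probs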
@@ -32,7 +32,7 @@ def __init__(self, config, actor=None, critic=None): - PolicyGraph.__init__(self, observation_space, action_space, config) + Policy.__init__(self, observation_space, action_space, config) self.actor = actor self.critic = critic self.models = [self.actor, self.critic] diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 4996f3cdf437..901ffa8024bf 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -161,7 +161,7 @@ def custom_loss(self, policy_loss, loss_inputs): You can find an runnable example in examples/custom_loss.py. Arguments: - policy_loss (Tensor): scalar policy loss from the policy graph. + policy_loss (Tensor): scalar policy loss from the policy. loss_inputs (dict): map of input placeholders for rollout data. Returns: diff --git a/python/ray/rllib/offline/input_reader.py b/python/ray/rllib/offline/input_reader.py index 5315773fd839..053c279343a8 100644 --- a/python/ray/rllib/offline/input_reader.py +++ b/python/ray/rllib/offline/input_reader.py @@ -6,7 +6,7 @@ import numpy as np import threading -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils import try_import_tf diff --git a/python/ray/rllib/offline/json_reader.py b/python/ray/rllib/offline/json_reader.py index e9568e75c7f4..55a002fb3ce6 100644 --- a/python/ray/rllib/offline/json_reader.py +++ b/python/ray/rllib/offline/json_reader.py @@ -17,7 +17,7 @@ from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.io_context import IOContext -from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \ +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \ DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import unpack_if_needed diff --git a/python/ray/rllib/offline/json_writer.py b/python/ray/rllib/offline/json_writer.py index 5613d1f67dc2..679b00158b9e 100644 --- a/python/ray/rllib/offline/json_writer.py +++ b/python/ray/rllib/offline/json_writer.py @@ -15,7 +15,7 @@ except ImportError: smart_open = None -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.output_writer import OutputWriter from ray.rllib.utils.annotations import override, PublicAPI diff --git a/python/ray/rllib/offline/off_policy_estimator.py b/python/ray/rllib/offline/off_policy_estimator.py index d09fe6baf052..7534e667f0bf 100644 --- a/python/ray/rllib/offline/off_policy_estimator.py +++ b/python/ray/rllib/offline/off_policy_estimator.py @@ -5,7 +5,7 @@ from collections import namedtuple import logging -from ray.rllib.evaluation.sample_batch import MultiAgentBatch +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import DeveloperAPI logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def __init__(self, policy, gamma): """Creates an off-policy estimator. Arguments: - policy (PolicyGraph): Policy graph to evaluate. + policy (Policy): Policy to evaluate. gamma (float): Discount of the MDP. 
""" self.policy = policy @@ -71,7 +71,7 @@ def action_prob(self, batch): raise ValueError( "Off-policy estimation is not possible unless the policy " "returns action probabilities when computing actions (i.e., " - "the 'action_prob' key is output by the policy graph). You " + "the 'action_prob' key is output by the policy). You " "can set `input_evaluation: []` to resolve this.") return info["action_prob"] diff --git a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py index 328fee67d548..b5040e45584c 100644 --- a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py +++ b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py @@ -11,7 +11,7 @@ from six.moves import queue from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.optimizers.aso_learner import LearnerThread from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index b040c8e8a99f..d66f942ae532 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -17,7 +17,7 @@ import ray from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer diff --git a/python/ray/rllib/optimizers/multi_gpu_impl.py b/python/ray/rllib/optimizers/multi_gpu_impl.py index 8d1bbd4fb54d..aad301b29eee 100644 --- a/python/ray/rllib/optimizers/multi_gpu_impl.py +++ b/python/ray/rllib/optimizers/multi_gpu_impl.py @@ -48,7 +48,7 @@ class LocalSyncParallelOptimizer(object): processed. If this is larger than the total data size, it will be clipped. build_graph: Function that takes the specified inputs and returns a - TF Policy Graph instance. + TF Policy instance. """ def __init__(self, diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index de2671e6a932..a25553c40111 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -9,14 +9,14 @@ import ray from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.optimizers.rollout import collect_samples, \ collect_samples_straggler_mitigation from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils import try_import_tf @@ -34,9 +34,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): details, see `multi_gpu_impl.LocalSyncParallelOptimizer`. 
This optimizer is Tensorflow-specific and require the underlying - PolicyGraph to be a TFPolicyGraph instance that support `.copy()`. + Policy to be a TFPolicy instance that support `.copy()`. - Note that all replicas of the TFPolicyGraph will merge their + Note that all replicas of the TFPolicy will merge their extra_compute_grad and apply_grad feed_dicts and fetches. This may result in unexpected behavior. """ @@ -83,7 +83,7 @@ def __init__(self, self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p))) logger.debug("Policies to train: {}".format(self.policies)) for policy_id, policy in self.policies.items(): - if not isinstance(policy, TFPolicyGraph): + if not isinstance(policy, TFPolicy): raise ValueError( "Only TF policies are supported with multi-GPU. Try using " "the simple optimizer instead.") diff --git a/python/ray/rllib/optimizers/rollout.py b/python/ray/rllib/optimizers/rollout.py index 063c2ff8999d..fa1c03f6081e 100644 --- a/python/ray/rllib/optimizers/rollout.py +++ b/python/ray/rllib/optimizers/rollout.py @@ -5,7 +5,7 @@ import logging import ray -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) diff --git a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py index 0a334e84ef79..e13d71c6e4cd 100644 --- a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py @@ -7,7 +7,7 @@ import ray from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.timer import TimerStat diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py index 2e765f2d8641..27858f3527c1 100644 --- a/python/ray/rllib/optimizers/sync_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py @@ -11,7 +11,7 @@ PrioritizedReplayBuffer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.evaluation.metrics import get_learner_stats -from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ +from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.compression import pack_if_needed diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index a08f0345eb2b..f5807ae343ef 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -6,7 +6,7 @@ import logging from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.evaluation.sample_batch import SampleBatch +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat diff --git a/python/ray/rllib/policy/__init__.py b/python/ray/rllib/policy/__init__.py new file mode 100644 index 
000000000000..0f172dcd566d --- /dev/null +++ b/python/ray/rllib/policy/__init__.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.policy.torch_policy_template import build_torch_policy +from ray.rllib.policy.tf_policy_template import build_tf_policy + +__all__ = [ + "Policy", + "TFPolicy", + "TorchPolicy", + "build_tf_policy", + "build_torch_policy", +] diff --git a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py b/python/ray/rllib/policy/dynamic_tf_policy.py similarity index 94% rename from python/ray/rllib/evaluation/dynamic_tf_policy_graph.py rename to python/ray/rllib/policy/dynamic_tf_policy.py index 73e08fcf9093..691fc1186272 100644 --- a/python/ray/rllib/evaluation/dynamic_tf_policy_graph.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -6,9 +6,9 @@ import logging import numpy as np -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.sample_batch import SampleBatch -from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils import try_import_tf @@ -20,8 +20,8 @@ logger = logging.getLogger(__name__) -class DynamicTFPolicyGraph(TFPolicyGraph): - """A TFPolicyGraph that auto-defines placeholders dynamically at runtime. +class DynamicTFPolicy(TFPolicy): + """A TFPolicy that auto-defines placeholders dynamically at runtime. Initialization of this class occurs in two phases. * Phase 1: the model is created and model variables are initialized. @@ -42,7 +42,7 @@ def __init__(self, make_action_sampler=None, existing_inputs=None, get_batch_divisibility_req=None): - """Initialize a dynamic TF policy graph. + """Initialize a dynamic TF policy. Arguments: observation_space (gym.Space): Observation space of the policy. @@ -51,16 +51,16 @@ def __init__(self, loss_fn (func): function that returns a loss tensor the policy graph, and dict of experience tensor placeholders stats_fn (func): optional function that returns a dict of - TF fetches given the policy graph and batch input tensors + TF fetches given the policy and batch input tensors grad_stats_fn (func): optional function that returns a dict of - TF fetches given the policy graph and loss gradient tensors + TF fetches given the policy and loss gradient tensors before_loss_init (func): optional function to run prior to loss init that takes the same arguments as __init__ make_action_sampler (func): optional function that returns a tuple of action and action prob tensors. 
The function takes (policy, input_dict, obs_space, action_space, config) as its arguments - existing_inputs (OrderedDict): when copying a policy graph, this + existing_inputs (OrderedDict): when copying a policy, this specifies an existing dict of placeholders to use instead of defining new ones get_batch_divisibility_req (func): optional function that returns @@ -134,7 +134,7 @@ def __init__(self, batch_divisibility_req = get_batch_divisibility_req(self) else: batch_divisibility_req = 1 - TFPolicyGraph.__init__( + TFPolicy.__init__( self, obs_space, action_space, @@ -158,7 +158,7 @@ def __init__(self, if not existing_inputs: self._initialize_loss() - @override(TFPolicyGraph) + @override(TFPolicy) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" @@ -194,7 +194,7 @@ def copy(self, existing_inputs): if instance._stats_fn: instance._stats_fetches.update( instance._stats_fn(instance, input_dict)) - TFPolicyGraph._initialize_loss( + TFPolicy._initialize_loss( instance, loss, [(k, existing_inputs[i]) for i, (k, _) in enumerate(self._loss_inputs)]) if instance._grad_stats_fn: @@ -202,7 +202,7 @@ def copy(self, existing_inputs): instance._grad_stats_fn(instance, instance._grads)) return instance - @override(PolicyGraph) + @override(Policy) def get_initial_state(self): if self.model: return self.model.state_init @@ -269,7 +269,7 @@ def fake_array(tensor): self._stats_fetches.update(self._stats_fn(self, batch_tensors)) for k in sorted(batch_tensors.accessed_keys): loss_inputs.append((k, batch_tensors[k])) - TFPolicyGraph._initialize_loss(self, loss, loss_inputs) + TFPolicy._initialize_loss(self, loss, loss_inputs) if self._grad_stats_fn: self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) self._sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/policy/policy.py b/python/ray/rllib/policy/policy.py new file mode 100644 index 000000000000..6f456e608007 --- /dev/null +++ b/python/ray/rllib/policy/policy.py @@ -0,0 +1,291 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import gym + +from ray.rllib.utils.annotations import DeveloperAPI + +# By convention, metrics from optimizing the loss can be reported in the +# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key. +LEARNER_STATS_KEY = "learner_stats" + + +@DeveloperAPI +class Policy(object): + """An agent policy and loss, i.e., a TFPolicy or other subclass. + + This object defines how to act in the environment, and also losses used to + improve the policy based on its experiences. Note that both policy and + loss are defined together for convenience, though the policy itself is + logically separate. + + All policies can directly extend Policy, however TensorFlow users may + find TFPolicy simpler to implement. TFPolicy also enables RLlib + to apply TensorFlow-specific optimizations such as fusing multiple policy + graphs and multi-GPU support. + + Attributes: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. + """ + + @DeveloperAPI + def __init__(self, observation_space, action_space, config): + """Initialize the graph. + + This is the standard constructor for policies. The policy + class you pass into PolicyEvaluator will be constructed with + these arguments. + + Args: + observation_space (gym.Space): Observation space of the policy. 
+ action_space (gym.Space): Action space of the policy. + config (dict): Policy-specific configuration data. + """ + + self.observation_space = observation_space + self.action_space = action_space + + @DeveloperAPI + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + """Compute actions for the current policy. + + Arguments: + obs_batch (np.ndarray): batch of observations + state_batches (list): list of RNN state input batches, if any + prev_action_batch (np.ndarray): batch of previous action values + prev_reward_batch (np.ndarray): batch of previous rewards + info_batch (info): batch of info objects + episodes (list): MultiAgentEpisode for each obs in obs_batch. + This provides access to all of the internal episode state, + which may be useful for model-based or multiagent algorithms. + kwargs: forward compatibility placeholder + + Returns: + actions (np.ndarray): batch of output actions, with shape like + [BATCH_SIZE, ACTION_SHAPE]. + state_outs (list): list of RNN state output batches, if any, with + shape like [STATE_SIZE, BATCH_SIZE]. + info (dict): dictionary of extra feature batches, if any, with + shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}. + """ + raise NotImplementedError + + @DeveloperAPI + def compute_single_action(self, + obs, + state, + prev_action=None, + prev_reward=None, + info=None, + episode=None, + clip_actions=False, + **kwargs): + """Unbatched version of compute_actions. + + Arguments: + obs (obj): single observation + state_batches (list): list of RNN state inputs, if any + prev_action (obj): previous action value, if any + prev_reward (int): previous reward, if any + info (dict): info object, if any + episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multi-agent algorithms. + clip_actions (bool): should the action be clipped + kwargs: forward compatibility placeholder + + Returns: + actions (obj): single action + state_outs (list): list of RNN state outputs, if any + info (dict): dictionary of extra features, if any + """ + + prev_action_batch = None + prev_reward_batch = None + info_batch = None + episodes = None + if prev_action is not None: + prev_action_batch = [prev_action] + if prev_reward is not None: + prev_reward_batch = [prev_reward] + if info is not None: + info_batch = [info] + if episode is not None: + episodes = [episode] + [action], state_out, info = self.compute_actions( + [obs], [[s] for s in state], + prev_action_batch=prev_action_batch, + prev_reward_batch=prev_reward_batch, + info_batch=info_batch, + episodes=episodes) + if clip_actions: + action = clip_action(action, self.action_space) + return action, [s[0] for s in state_out], \ + {k: v[0] for k, v in info.items()} + + @DeveloperAPI + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + """Implements algorithm-specific trajectory postprocessing. + + This will be called on each trajectory fragment computed during policy + evaluation. Each fragment is guaranteed to be only from one episode. + + Arguments: + sample_batch (SampleBatch): batch of experiences for the policy, + which will contain at most one episode trajectory. + other_agent_batches (dict): In a multi-agent env, this contains a + mapping of agent ids to (policy, agent_batch) tuples + containing the policy and experiences of the other agent. 
+ episode (MultiAgentEpisode): this provides access to all of the + internal episode state, which may be useful for model-based or + multi-agent algorithms. + + Returns: + SampleBatch: postprocessed sample batch. + """ + return sample_batch + + @DeveloperAPI + def learn_on_batch(self, samples): + """Fused compute gradients and apply gradients call. + + Either this or the combination of compute/apply grads must be + implemented by subclasses. + + Returns: + grad_info: dictionary of extra metadata from compute_gradients(). + + Examples: + >>> batch = ev.sample() + >>> ev.learn_on_batch(samples) + """ + + grads, grad_info = self.compute_gradients(samples) + self.apply_gradients(grads) + return grad_info + + @DeveloperAPI + def compute_gradients(self, postprocessed_batch): + """Computes gradients against a batch of experiences. + + Either this or learn_on_batch() must be implemented by subclasses. + + Returns: + grads (list): List of gradient output values + info (dict): Extra policy-specific values + """ + raise NotImplementedError + + @DeveloperAPI + def apply_gradients(self, gradients): + """Applies previously computed gradients. + + Either this or learn_on_batch() must be implemented by subclasses. + """ + raise NotImplementedError + + @DeveloperAPI + def get_weights(self): + """Returns model weights. + + Returns: + weights (obj): Serializable copy or view of model weights + """ + raise NotImplementedError + + @DeveloperAPI + def set_weights(self, weights): + """Sets model weights. + + Arguments: + weights (obj): Serializable copy or view of model weights + """ + raise NotImplementedError + + @DeveloperAPI + def get_initial_state(self): + """Returns initial RNN state for the current policy.""" + return [] + + @DeveloperAPI + def get_state(self): + """Saves all local state. + + Returns: + state (obj): Serialized local state. + """ + return self.get_weights() + + @DeveloperAPI + def set_state(self, state): + """Restores all local state. + + Arguments: + state (obj): Serialized local state. + """ + self.set_weights(state) + + @DeveloperAPI + def on_global_var_update(self, global_vars): + """Called on an update to global vars. + + Arguments: + global_vars (dict): Global variables broadcast from the driver. + """ + pass + + @DeveloperAPI + def export_model(self, export_dir): + """Export Policy to local directory for serving. + + Arguments: + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + @DeveloperAPI + def export_checkpoint(self, export_dir): + """Export Policy checkpoint to local directory. + + Argument: + export_dir (str): Local writable directory. + """ + raise NotImplementedError + + +def clip_action(action, space): + """Called to clip actions to the specified range of this policy. + + Arguments: + action: Single action. + space: Action space the actions should be present in. + + Returns: + Clipped batch of actions. 
+ """ + + if isinstance(space, gym.spaces.Box): + return np.clip(action, space.low, space.high) + elif isinstance(space, gym.spaces.Tuple): + if type(action) not in (tuple, list): + raise ValueError("Expected tuple space for actions {}: {}".format( + action, space)) + out = [] + for a, s in zip(action, space.spaces): + out.append(clip_action(a, s)) + return out + else: + return action diff --git a/python/ray/rllib/policy/sample_batch.py b/python/ray/rllib/policy/sample_batch.py new file mode 100644 index 000000000000..a9515eeeac5a --- /dev/null +++ b/python/ray/rllib/policy/sample_batch.py @@ -0,0 +1,296 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import collections +import numpy as np + +from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI +from ray.rllib.utils.compression import pack, unpack, is_compressed +from ray.rllib.utils.memory import concat_aligned + +# Default policy id for single agent environments +DEFAULT_POLICY_ID = "default_policy" + + +@PublicAPI +class MultiAgentBatch(object): + """A batch of experiences from multiple policies in the environment. + + Attributes: + policy_batches (dict): Mapping from policy id to a normal SampleBatch + of experiences. Note that these batches may be of different length. + count (int): The number of timesteps in the environment this batch + contains. This will be less than the number of transitions this + batch contains across all policies in total. + """ + + @PublicAPI + def __init__(self, policy_batches, count): + self.policy_batches = policy_batches + self.count = count + + @staticmethod + @PublicAPI + def wrap_as_needed(batches, count): + if len(batches) == 1 and DEFAULT_POLICY_ID in batches: + return batches[DEFAULT_POLICY_ID] + return MultiAgentBatch(batches, count) + + @staticmethod + @PublicAPI + def concat_samples(samples): + policy_batches = collections.defaultdict(list) + total_count = 0 + for s in samples: + assert isinstance(s, MultiAgentBatch) + for policy_id, batch in s.policy_batches.items(): + policy_batches[policy_id].append(batch) + total_count += s.count + out = {} + for policy_id, batches in policy_batches.items(): + out[policy_id] = SampleBatch.concat_samples(batches) + return MultiAgentBatch(out, total_count) + + @PublicAPI + def copy(self): + return MultiAgentBatch( + {k: v.copy() + for (k, v) in self.policy_batches.items()}, self.count) + + @PublicAPI + def total(self): + ct = 0 + for batch in self.policy_batches.values(): + ct += batch.count + return ct + + @DeveloperAPI + def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + for batch in self.policy_batches.values(): + batch.compress(bulk=bulk, columns=columns) + + @DeveloperAPI + def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + for batch in self.policy_batches.values(): + batch.decompress_if_needed(columns) + + def __str__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) + + def __repr__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) + + +@PublicAPI +class SampleBatch(object): + """Wrapper around a dictionary with string keys and array-like values. + + For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three + samples, each with an "obs" and "reward" attribute. 
+ """ + + # Outputs from interacting with the environment + CUR_OBS = "obs" + NEXT_OBS = "new_obs" + ACTIONS = "actions" + REWARDS = "rewards" + PREV_ACTIONS = "prev_actions" + PREV_REWARDS = "prev_rewards" + DONES = "dones" + INFOS = "infos" + + # Uniquely identifies an episode + EPS_ID = "eps_id" + + # Uniquely identifies a sample batch. This is important to distinguish RNN + # sequences from the same episode when multiple sample batches are + # concatenated (fusing sequences across batches can be unsafe). + UNROLL_ID = "unroll_id" + + # Uniquely identifies an agent within an episode + AGENT_INDEX = "agent_index" + + # Value function predictions emitted by the behaviour policy + VF_PREDS = "vf_preds" + + @PublicAPI + def __init__(self, *args, **kwargs): + """Constructs a sample batch (same params as dict constructor).""" + + self.data = dict(*args, **kwargs) + lengths = [] + for k, v in self.data.copy().items(): + assert isinstance(k, six.string_types), self + lengths.append(len(v)) + self.data[k] = np.array(v, copy=False) + if not lengths: + raise ValueError("Empty sample batch") + assert len(set(lengths)) == 1, "data columns must be same length" + self.count = lengths[0] + + @staticmethod + @PublicAPI + def concat_samples(samples): + if isinstance(samples[0], MultiAgentBatch): + return MultiAgentBatch.concat_samples(samples) + out = {} + samples = [s for s in samples if s.count > 0] + for k in samples[0].keys(): + out[k] = concat_aligned([s[k] for s in samples]) + return SampleBatch(out) + + @PublicAPI + def concat(self, other): + """Returns a new SampleBatch with each data column concatenated. + + Examples: + >>> b1 = SampleBatch({"a": [1, 2]}) + >>> b2 = SampleBatch({"a": [3, 4, 5]}) + >>> print(b1.concat(b2)) + {"a": [1, 2, 3, 4, 5]} + """ + + assert self.keys() == other.keys(), "must have same columns" + out = {} + for k in self.keys(): + out[k] = concat_aligned([self[k], other[k]]) + return SampleBatch(out) + + @PublicAPI + def copy(self): + return SampleBatch( + {k: np.array(v, copy=True) + for (k, v) in self.data.items()}) + + @PublicAPI + def rows(self): + """Returns an iterator over data rows, i.e. dicts with column values. + + Examples: + >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> for row in batch.rows(): + print(row) + {"a": 1, "b": 4} + {"a": 2, "b": 5} + {"a": 3, "b": 6} + """ + + for i in range(self.count): + row = {} + for k in self.keys(): + row[k] = self[k][i] + yield row + + @PublicAPI + def columns(self, keys): + """Returns a list of just the specified columns. + + Examples: + >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]}) + >>> print(batch.columns(["a", "b"])) + [[1], [2]] + """ + + out = [] + for k in keys: + out.append(self[k]) + return out + + @PublicAPI + def shuffle(self): + """Shuffles the rows of this batch in-place.""" + + permutation = np.random.permutation(self.count) + for key, val in self.items(): + self[key] = val[permutation] + + @PublicAPI + def split_by_episode(self): + """Splits this batch's data by `eps_id`. + + Returns: + list of SampleBatch, one per distinct episode. 
+ """ + + slices = [] + cur_eps_id = self.data["eps_id"][0] + offset = 0 + for i in range(self.count): + next_eps_id = self.data["eps_id"][i] + if next_eps_id != cur_eps_id: + slices.append(self.slice(offset, i)) + offset = i + cur_eps_id = next_eps_id + slices.append(self.slice(offset, self.count)) + for s in slices: + slen = len(set(s["eps_id"])) + assert slen == 1, (s, slen) + assert sum(s.count for s in slices) == self.count, (slices, self.count) + return slices + + @PublicAPI + def slice(self, start, end): + """Returns a slice of the row data of this batch. + + Arguments: + start (int): Starting index. + end (int): Ending index. + + Returns: + SampleBatch which has a slice of this batch's data. + """ + + return SampleBatch({k: v[start:end] for k, v in self.data.items()}) + + @PublicAPI + def keys(self): + return self.data.keys() + + @PublicAPI + def items(self): + return self.data.items() + + @PublicAPI + def __getitem__(self, key): + return self.data[key] + + @PublicAPI + def __setitem__(self, key, item): + self.data[key] = item + + @DeveloperAPI + def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])): + for key in columns: + if key in self.data: + if bulk: + self.data[key] = pack(self.data[key]) + else: + self.data[key] = np.array( + [pack(o) for o in self.data[key]]) + + @DeveloperAPI + def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])): + for key in columns: + if key in self.data: + arr = self.data[key] + if is_compressed(arr): + self.data[key] = unpack(arr) + elif len(arr) > 0 and is_compressed(arr[0]): + self.data[key] = np.array( + [unpack(o) for o in self.data[key]]) + + def __str__(self): + return "SampleBatch({})".format(str(self.data)) + + def __repr__(self): + return "SampleBatch({})".format(str(self.data)) + + def __iter__(self): + return self.data.__iter__() + + def __contains__(self, x): + return x in self.data diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py new file mode 100644 index 000000000000..bbb5795e52ab --- /dev/null +++ b/python/ray/rllib/policy/tf_policy.py @@ -0,0 +1,513 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import errno +import logging +import numpy as np + +import ray +import ray.experimental.tf_utils +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.debug import log_once, summarize +from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule +from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class TFPolicy(Policy): + """An agent policy and loss implemented in TensorFlow. + + Extending this class enables RLlib to perform TensorFlow specific + optimizations on the policy, e.g., parallelization across gpus or + fusing multiple graphs together in the multi-agent setting. + + Input tensors are typically shaped like [BATCH_SIZE, ...]. + + Attributes: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + model (rllib.models.Model): RLlib model used for the policy. 
+ + Examples: + >>> policy = TFPolicySubclass( + sess, obs_input, action_sampler, loss, loss_inputs) + + >>> print(policy.compute_actions([1, 0, 2])) + (array([0, 1, 1]), [], {}) + + >>> print(policy.postprocess_trajectory(SampleBatch({...}))) + SampleBatch({"action": ..., "advantages": ..., ...}) + """ + + @DeveloperAPI + def __init__(self, + observation_space, + action_space, + sess, + obs_input, + action_sampler, + loss, + loss_inputs, + model=None, + action_prob=None, + state_inputs=None, + state_outputs=None, + prev_action_input=None, + prev_reward_input=None, + seq_lens=None, + max_seq_len=20, + batch_divisibility_req=1, + update_ops=None): + """Initialize the policy. + + Arguments: + observation_space (gym.Space): Observation space of the env. + action_space (gym.Space): Action space of the env. + sess (Session): TensorFlow session to use. + obs_input (Tensor): input placeholder for observations, of shape + [BATCH_SIZE, obs...]. + action_sampler (Tensor): Tensor for sampling an action, of shape + [BATCH_SIZE, action...] + loss (Tensor): scalar policy loss output tensor. + loss_inputs (list): a (name, placeholder) tuple for each loss + input argument. Each placeholder name must correspond to a + SampleBatch column key returned by postprocess_trajectory(), + and has shape [BATCH_SIZE, data...]. These keys will be read + from postprocessed sample batches and fed into the specified + placeholders during loss computation. + model (rllib.models.Model): used to integrate custom losses and + stats from user-defined RLlib models. + action_prob (Tensor): probability of the sampled action. + state_inputs (list): list of RNN state input Tensors. + state_outputs (list): list of RNN state output Tensors. + prev_action_input (Tensor): placeholder for previous actions + prev_reward_input (Tensor): placeholder for previous rewards + seq_lens (Tensor): placeholder for RNN sequence lengths, of shape + [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See + models/lstm.py for more information. + max_seq_len (int): max sequence length for LSTM training. + batch_divisibility_req (int): pad all agent experiences batches to + multiples of this value. This only has an effect if not using + a LSTM model. + update_ops (list): override the batchnorm update ops to run when + applying gradients. Otherwise we run all update ops found in + the current variable scope. 
+ """ + + self.observation_space = observation_space + self.action_space = action_space + self.model = model + self._sess = sess + self._obs_input = obs_input + self._prev_action_input = prev_action_input + self._prev_reward_input = prev_reward_input + self._sampler = action_sampler + self._is_training = self._get_is_training_placeholder() + self._action_prob = action_prob + self._state_inputs = state_inputs or [] + self._state_outputs = state_outputs or [] + self._seq_lens = seq_lens + self._max_seq_len = max_seq_len + self._batch_divisibility_req = batch_divisibility_req + self._update_ops = update_ops + self._stats_fetches = {} + + if loss is not None: + self._initialize_loss(loss, loss_inputs) + else: + self._loss = None + + if len(self._state_inputs) != len(self._state_outputs): + raise ValueError( + "Number of state input and output tensors must match, got: " + "{} vs {}".format(self._state_inputs, self._state_outputs)) + if len(self.get_initial_state()) != len(self._state_inputs): + raise ValueError( + "Length of initial state must match number of state inputs, " + "got: {} vs {}".format(self.get_initial_state(), + self._state_inputs)) + if self._state_inputs and self._seq_lens is None: + raise ValueError( + "seq_lens tensor must be given if state inputs are defined") + + def _initialize_loss(self, loss, loss_inputs): + self._loss_inputs = loss_inputs + self._loss_input_dict = dict(self._loss_inputs) + for i, ph in enumerate(self._state_inputs): + self._loss_input_dict["state_in_{}".format(i)] = ph + + if self.model: + self._loss = self.model.custom_loss(loss, self._loss_input_dict) + self._stats_fetches.update({"model": self.model.custom_stats()}) + else: + self._loss = loss + + self._optimizer = self.optimizer() + self._grads_and_vars = [ + (g, v) for (g, v) in self.gradients(self._optimizer, self._loss) + if g is not None + ] + self._grads = [g for (g, v) in self._grads_and_vars] + self._variables = ray.experimental.tf_utils.TensorFlowVariables( + self._loss, self._sess) + + # gather update ops for any batch norm layers + if not self._update_ops: + self._update_ops = tf.get_collection( + tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name) + if self._update_ops: + logger.debug("Update ops to run on apply gradient: {}".format( + self._update_ops)) + with tf.control_dependencies(self._update_ops): + self._apply_op = self.build_apply_op(self._optimizer, + self._grads_and_vars) + + if log_once("loss_used"): + logger.debug( + "These tensors were used in the loss_fn:\n\n{}\n".format( + summarize(self._loss_input_dict))) + + self._sess.run(tf.global_variables_initializer()) + + @override(Policy) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + builder = TFRunBuilder(self._sess, "compute_actions") + fetches = self._build_compute_actions(builder, obs_batch, + state_batches, prev_action_batch, + prev_reward_batch) + return builder.get(fetches) + + @override(Policy) + def compute_gradients(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "compute_gradients") + fetches = self._build_compute_gradients(builder, postprocessed_batch) + return builder.get(fetches) + + @override(Policy) + def apply_gradients(self, gradients): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "apply_gradients") + fetches = self._build_apply_gradients(builder, gradients) + 
builder.get(fetches) + + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + assert self._loss is not None, "Loss not initialized" + builder = TFRunBuilder(self._sess, "learn_on_batch") + fetches = self._build_learn_on_batch(builder, postprocessed_batch) + return builder.get(fetches) + + @override(Policy) + def get_weights(self): + return self._variables.get_flat() + + @override(Policy) + def set_weights(self, weights): + return self._variables.set_flat(weights) + + @override(Policy) + def export_model(self, export_dir): + """Export tensorflow graph to export_dir for serving.""" + with self._sess.graph.as_default(): + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + signature_def_map = self._build_signature_def() + builder.add_meta_graph_and_variables( + self._sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map) + builder.save() + + @override(Policy) + def export_checkpoint(self, export_dir, filename_prefix="model"): + """Export tensorflow checkpoint to export_dir.""" + try: + os.makedirs(export_dir) + except OSError as e: + # ignore error if export dir already exists + if e.errno != errno.EEXIST: + raise + save_path = os.path.join(export_dir, filename_prefix) + with self._sess.graph.as_default(): + saver = tf.train.Saver() + saver.save(self._sess, save_path) + + @DeveloperAPI + def copy(self, existing_inputs): + """Creates a copy of self using existing input placeholders. + + Optional, only required to work with the multi-GPU optimizer.""" + raise NotImplementedError + + @DeveloperAPI + def extra_compute_action_feed_dict(self): + """Extra dict to pass to the compute actions session run.""" + return {} + + @DeveloperAPI + def extra_compute_action_fetches(self): + """Extra values to fetch and return from compute_actions(). + + By default we only return action probability info (if present). + """ + if self._action_prob is not None: + return {"action_prob": self._action_prob} + else: + return {} + + @DeveloperAPI + def extra_compute_grad_feed_dict(self): + """Extra dict to pass to the compute gradients session run.""" + return {} # e.g, kl_coeff + + @DeveloperAPI + def extra_compute_grad_fetches(self): + """Extra values to fetch and return from compute_gradients().""" + return {LEARNER_STATS_KEY: {}} # e.g, stats, td error, etc. + + @DeveloperAPI + def optimizer(self): + """TF optimizer to use for policy optimization.""" + if hasattr(self, "config"): + return tf.train.AdamOptimizer(self.config["lr"]) + else: + return tf.train.AdamOptimizer() + + @DeveloperAPI + def gradients(self, optimizer, loss): + """Override for custom gradient computation.""" + return optimizer.compute_gradients(loss) + + @DeveloperAPI + def build_apply_op(self, optimizer, grads_and_vars): + """Override for custom gradient apply computation.""" + + # specify global_step for TD3 which needs to count the num updates + return optimizer.apply_gradients( + self._grads_and_vars, + global_step=tf.train.get_or_create_global_step()) + + @DeveloperAPI + def _get_is_training_placeholder(self): + """Get the placeholder for _is_training, i.e., for batch norm layers. + + This can be called safely before __init__ has run. + """ + if not hasattr(self, "_is_training"): + self._is_training = tf.placeholder_with_default(False, ()) + return self._is_training + + def _extra_input_signature_def(self): + """Extra input signatures to add when exporting tf model. 
+ Inferred from extra_compute_action_feed_dict() + """ + feed_dict = self.extra_compute_action_feed_dict() + return { + k.name: tf.saved_model.utils.build_tensor_info(k) + for k in feed_dict.keys() + } + + def _extra_output_signature_def(self): + """Extra output signatures to add when exporting tf model. + Inferred from extra_compute_action_fetches() + """ + fetches = self.extra_compute_action_fetches() + return { + k: tf.saved_model.utils.build_tensor_info(fetches[k]) + for k in fetches.keys() + } + + def _build_signature_def(self): + """Build signature def map for tensorflow SavedModelBuilder. + """ + # build input signatures + input_signature = self._extra_input_signature_def() + input_signature["observations"] = \ + tf.saved_model.utils.build_tensor_info(self._obs_input) + + if self._seq_lens is not None: + input_signature["seq_lens"] = \ + tf.saved_model.utils.build_tensor_info(self._seq_lens) + if self._prev_action_input is not None: + input_signature["prev_action"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_action_input) + if self._prev_reward_input is not None: + input_signature["prev_reward"] = \ + tf.saved_model.utils.build_tensor_info(self._prev_reward_input) + input_signature["is_training"] = \ + tf.saved_model.utils.build_tensor_info(self._is_training) + + for state_input in self._state_inputs: + input_signature[state_input.name] = \ + tf.saved_model.utils.build_tensor_info(state_input) + + # build output signatures + output_signature = self._extra_output_signature_def() + output_signature["actions"] = \ + tf.saved_model.utils.build_tensor_info(self._sampler) + for state_output in self._state_outputs: + output_signature[state_output.name] = \ + tf.saved_model.utils.build_tensor_info(state_output) + signature_def = ( + tf.saved_model.signature_def_utils.build_signature_def( + input_signature, output_signature, + tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) + signature_def_key = (tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY) + signature_def_map = {signature_def_key: signature_def} + return signature_def_map + + def _build_compute_actions(self, + builder, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + episodes=None): + state_batches = state_batches or [] + if len(self._state_inputs) != len(state_batches): + raise ValueError( + "Must pass in RNN state batches for placeholders {}, got {}". 
+ format(self._state_inputs, state_batches)) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + builder.add_feed_dict({self._obs_input: obs_batch}) + if state_batches: + builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) + if self._prev_action_input is not None and prev_action_batch: + builder.add_feed_dict({self._prev_action_input: prev_action_batch}) + if self._prev_reward_input is not None and prev_reward_batch: + builder.add_feed_dict({self._prev_reward_input: prev_reward_batch}) + builder.add_feed_dict({self._is_training: False}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + fetches = builder.add_fetches([self._sampler] + self._state_outputs + + [self.extra_compute_action_fetches()]) + return fetches[0], fetches[1:-1], fetches[-1] + + def _build_compute_gradients(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + fetches = builder.add_fetches( + [self._grads, self._get_grad_and_stats_fetches()]) + return fetches[0], fetches[1] + + def _build_apply_gradients(self, builder, gradients): + if len(gradients) != len(self._grads): + raise ValueError( + "Unexpected number of gradients to apply, got {} for {}". + format(gradients, self._grads)) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches([self._apply_op]) + return fetches[0] + + def _build_learn_on_batch(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict({self._is_training: True}) + fetches = builder.add_fetches([ + self._apply_op, + self._get_grad_and_stats_fetches(), + ]) + return fetches[1] + + def _get_grad_and_stats_fetches(self): + fetches = self.extra_compute_grad_fetches() + if LEARNER_STATS_KEY not in fetches: + raise ValueError( + "Grad fetches should contain 'stats': {...} entry") + if self._stats_fetches: + fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches, + **fetches[LEARNER_STATS_KEY]) + return fetches + + def _get_loss_inputs_dict(self, batch): + feed_dict = {} + if self._batch_divisibility_req > 1: + meets_divisibility_reqs = ( + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + else: + meets_divisibility_reqs = True + + # Simple case: not RNN nor do we need to pad + if not self._state_inputs and meets_divisibility_reqs: + for k, ph in self._loss_inputs: + feed_dict[ph] = batch[k] + return feed_dict + + if self._state_inputs: + max_seq_len = self._max_seq_len + dynamic_max = True + else: + max_seq_len = self._batch_divisibility_req + dynamic_max = False + + # RNN or multi-agent case + feature_keys = [k for k, v in self._loss_inputs] + state_keys = [ + "state_in_{}".format(i) for i in range(len(self._state_inputs)) + ] + feature_sequences, initial_states, seq_lens = chop_into_sequences( + batch[SampleBatch.EPS_ID], + batch[SampleBatch.UNROLL_ID], + batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys], + [batch[k] for k in state_keys], + max_seq_len, + dynamic_max=dynamic_max) + for k, v in zip(feature_keys, feature_sequences): + feed_dict[self._loss_input_dict[k]] = v + for k, v in zip(state_keys, initial_states): + feed_dict[self._loss_input_dict[k]] = v + 
feed_dict[self._seq_lens] = seq_lens + + if log_once("rnn_feed_dict"): + logger.info("Padded input for RNN:\n\n{}\n".format( + summarize({ + "features": feature_sequences, + "initial_states": initial_states, + "seq_lens": seq_lens, + "max_seq_len": max_seq_len, + }))) + return feed_dict + + +@DeveloperAPI +class LearningRateSchedule(object): + """Mixin for TFPolicy that adds a learning rate schedule.""" + + @DeveloperAPI + def __init__(self, lr, lr_schedule): + self.cur_lr = tf.get_variable("lr", initializer=lr) + if lr_schedule is None: + self.lr_schedule = ConstantSchedule(lr) + else: + self.lr_schedule = PiecewiseSchedule( + lr_schedule, outside_value=lr_schedule[-1][-1]) + + @override(Policy) + def on_global_var_update(self, global_vars): + super(LearningRateSchedule, self).on_global_var_update(global_vars) + self.cur_lr.load( + self.lr_schedule.value(global_vars["timestep"]), + session=self._sess) + + @override(TFPolicy) + def optimizer(self): + return tf.train.AdamOptimizer(self.cur_lr) diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py new file mode 100644 index 000000000000..36f482f18bf8 --- /dev/null +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +def build_tf_policy(name, + loss_fn, + get_default_config=None, + stats_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + postprocess_fn=None, + optimizer_fn=None, + gradients_fn=None, + before_init=None, + before_loss_init=None, + after_init=None, + make_action_sampler=None, + mixins=None, + get_batch_divisibility_req=None): + """Helper function for creating a dynamic tf policy at runtime. + + Arguments: + name (str): name of the policy (e.g., "PPOTFPolicy") + loss_fn (func): function that returns a loss tensor the policy, + and dict of experience tensor placeholders + get_default_config (func): optional function that returns the default + config to merge with any overrides + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + postprocess_fn (func): optional experience postprocessing function + that takes the same args as Policy.postprocess_trajectory() + optimizer_fn (func): optional function that returns a tf.Optimizer + given the policy and config + gradients_fn (func): optional function that returns a list of gradients + given a tf optimizer and loss tensor. 
If not specified, this + defaults to optimizer.compute_gradients(loss) + before_init (func): optional function to run at the beginning of + policy init that takes the same arguments as the policy constructor + before_loss_init (func): optional function to run prior to loss + init that takes the same arguments as the policy constructor + after_init (func): optional function to run at the end of policy init + that takes the same arguments as the policy constructor + make_action_sampler (func): optional function that returns a + tuple of action and action prob tensors. The function takes + (policy, input_dict, obs_space, action_space, config) as its + arguments + mixins (list): list of any class mixins for the returned policy class. + These mixins will be applied in order and will have higher + precedence than the DynamicTFPolicy class + get_batch_divisibility_req (func): optional function that returns + the divisibility requirement for sample batches + + Returns: + a DynamicTFPolicy instance that uses the specified args + """ + + if not name.endswith("TFPolicy"): + raise ValueError("Name should match *TFPolicy", name) + + base = DynamicTFPolicy + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + class policy_cls(base): + def __init__(self, + obs_space, + action_space, + config, + existing_inputs=None): + if get_default_config: + config = dict(get_default_config(), **config) + + if before_init: + before_init(self, obs_space, action_space, config) + + def before_loss_init_wrapper(policy, obs_space, action_space, + config): + if before_loss_init: + before_loss_init(policy, obs_space, action_space, config) + if extra_action_fetches_fn is None: + self._extra_action_fetches = {} + else: + self._extra_action_fetches = extra_action_fetches_fn(self) + + DynamicTFPolicy.__init__( + self, + obs_space, + action_space, + config, + loss_fn, + stats_fn=stats_fn, + grad_stats_fn=grad_stats_fn, + before_loss_init=before_loss_init_wrapper, + existing_inputs=existing_inputs) + + if after_init: + after_init(self, obs_space, action_space, config) + + @override(Policy) + def postprocess_trajectory(self, + sample_batch, + other_agent_batches=None, + episode=None): + if not postprocess_fn: + return sample_batch + return postprocess_fn(self, sample_batch, other_agent_batches, + episode) + + @override(TFPolicy) + def optimizer(self): + if optimizer_fn: + return optimizer_fn(self, self.config) + else: + return TFPolicy.optimizer(self) + + @override(TFPolicy) + def gradients(self, optimizer, loss): + if gradients_fn: + return gradients_fn(self, optimizer, loss) + else: + return TFPolicy.gradients(self, optimizer, loss) + + @override(TFPolicy) + def extra_compute_action_fetches(self): + return dict( + TFPolicy.extra_compute_action_fetches(self), + **self._extra_action_fetches) + + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/python/ray/rllib/policy/torch_policy.py b/python/ray/rllib/policy/torch_policy.py new file mode 100644 index 000000000000..633e438c5ad7 --- /dev/null +++ b/python/ray/rllib/policy/torch_policy.py @@ -0,0 +1,173 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +from threading import Lock + +try: + import torch +except ImportError: + pass # soft dep + +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY +from ray.rllib.utils.annotations import override +from ray.rllib.utils.tracking_dict import 
UsageTrackingDict + + +class TorchPolicy(Policy): + """Template for a PyTorch policy and loss to use with RLlib. + + This is similar to TFPolicy, but for PyTorch. + + Attributes: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + lock (Lock): Lock that must be held around PyTorch ops on this graph. + This is necessary when using the async sampler. + """ + + def __init__(self, observation_space, action_space, model, loss, + action_distribution_cls): + """Build a policy from policy and loss torch modules. + + Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES + is set. Only single GPU is supported for now. + + Arguments: + observation_space (gym.Space): observation space of the policy. + action_space (gym.Space): action space of the policy. + model (nn.Module): PyTorch policy module. Given observations as + input, this module must return a list of outputs where the + first item is action logits, and the rest can be any value. + loss (func): Function that takes (policy, batch_tensors) + and returns a single scalar loss. + action_distribution_cls (ActionDistribution): Class for action + distribution. + """ + self.observation_space = observation_space + self.action_space = action_space + self.lock = Lock() + self.device = (torch.device("cuda") + if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None)) + else torch.device("cpu")) + self._model = model.to(self.device) + self._loss = loss + self._optimizer = self.optimizer() + self._action_dist_cls = action_distribution_cls + + @override(Policy) + def compute_actions(self, + obs_batch, + state_batches=None, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + with self.lock: + with torch.no_grad(): + ob = torch.from_numpy(np.array(obs_batch)) \ + .float().to(self.device) + model_out = self._model({"obs": ob}, state_batches) + logits, _, vf, state = model_out + action_dist = self._action_dist_cls(logits) + actions = action_dist.sample() + return (actions.cpu().numpy(), + [h.cpu().numpy() for h in state], + self.extra_action_out(model_out)) + + @override(Policy) + def learn_on_batch(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + + with self.lock: + loss_out = self._loss(self, batch_tensors) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + self._optimizer.step() + + grad_info = self.extra_grad_info(batch_tensors) + grad_info.update(grad_process_info) + return {LEARNER_STATS_KEY: grad_info} + + @override(Policy) + def compute_gradients(self, postprocessed_batch): + batch_tensors = self._lazy_tensor_dict(postprocessed_batch) + + with self.lock: + loss_out = self._loss(self, batch_tensors) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + + # Note that return values are just references; + # calling zero_grad will modify the values + grads = [] + for p in self._model.parameters(): + if p.grad is not None: + grads.append(p.grad.data.cpu().numpy()) + else: + grads.append(None) + + grad_info = self.extra_grad_info(batch_tensors) + grad_info.update(grad_process_info) + return grads, {LEARNER_STATS_KEY: grad_info} + + @override(Policy) + def apply_gradients(self, gradients): + with self.lock: + for g, p in zip(gradients, self._model.parameters()): + if g is not None: + p.grad = torch.from_numpy(g).to(self.device) + self._optimizer.step() + + @override(Policy) + def 
get_weights(self): + with self.lock: + return {k: v.cpu() for k, v in self._model.state_dict().items()} + + @override(Policy) + def set_weights(self, weights): + with self.lock: + self._model.load_state_dict(weights) + + @override(Policy) + def get_initial_state(self): + return [s.numpy() for s in self._model.state_init()] + + def extra_grad_process(self): + """Allow subclass to do extra processing on gradients and + return processing info.""" + return {} + + def extra_action_out(self, model_out): + """Returns dict of extra info to include in experience batch. + + Arguments: + model_out (list): Outputs of the policy model module.""" + return {} + + def extra_grad_info(self, batch_tensors): + """Return dict of extra grad info.""" + + return {} + + def optimizer(self): + """Custom PyTorch optimizer to use.""" + if hasattr(self, "config"): + return torch.optim.Adam( + self._model.parameters(), lr=self.config["lr"]) + else: + return torch.optim.Adam(self._model.parameters()) + + def _lazy_tensor_dict(self, postprocessed_batch): + batch_tensors = UsageTrackingDict(postprocessed_batch) + batch_tensors.set_get_interceptor( + lambda arr: torch.from_numpy(arr).to(self.device)) + return batch_tensors diff --git a/python/ray/rllib/evaluation/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py similarity index 82% rename from python/ray/rllib/evaluation/torch_policy_template.py rename to python/ray/rllib/policy/torch_policy_template.py index 7f65c2b963b8..049591c04671 100644 --- a/python/ray/rllib/evaluation/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -2,8 +2,8 @@ from __future__ import division from __future__ import print_function -from ray.rllib.evaluation.policy_graph import PolicyGraph -from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override, DeveloperAPI @@ -24,7 +24,7 @@ def build_torch_policy(name, """Helper function for creating a torch policy at runtime. Arguments: - name (str): name of the graph (e.g., "PPOPolicy") + name (str): name of the policy (e.g., "PPOTFPolicy") loss_fn (func): function that returns a loss tensor the policy, and dict of experience tensor placeholders get_default_config (func): optional function that returns the default @@ -32,7 +32,7 @@ def build_torch_policy(name, stats_fn (func): optional function that returns a dict of values given the policy and batch input tensors postprocess_fn (func): optional experience postprocessing function - that takes the same args as PolicyGraph.postprocess_trajectory() + that takes the same args as Policy.postprocess_trajectory() extra_action_out_fn (func): optional function that returns a dict of extra values to include in experiences extra_grad_process_fn (func): optional function that is called after @@ -49,16 +49,16 @@ def build_torch_policy(name, model and action dist from the catalog will be used mixins (list): list of any class mixins for the returned policy class. 
These mixins will be applied in order and will have higher - precedence than the TorchPolicyGraph class + precedence than the TorchPolicy class Returns: - a TorchPolicyGraph instance that uses the specified args + a TorchPolicy instance that uses the specified args """ if not name.endswith("TorchPolicy"): raise ValueError("Name should match *TorchPolicy", name) - base = TorchPolicyGraph + base = TorchPolicy while mixins: class new_base(mixins.pop(), base): @@ -84,13 +84,13 @@ def __init__(self, obs_space, action_space, config): self.model = ModelCatalog.get_torch_model( obs_space, logit_dim, self.config["model"]) - TorchPolicyGraph.__init__(self, obs_space, action_space, - self.model, loss_fn, self.dist_class) + TorchPolicy.__init__(self, obs_space, action_space, self.model, + loss_fn, self.dist_class) if after_init: after_init(self, obs_space, action_space, config) - @override(PolicyGraph) + @override(Policy) def postprocess_trajectory(self, sample_batch, other_agent_batches=None, @@ -100,33 +100,33 @@ def postprocess_trajectory(self, return postprocess_fn(self, sample_batch, other_agent_batches, episode) - @override(TorchPolicyGraph) + @override(TorchPolicy) def extra_grad_process(self): if extra_grad_process_fn: return extra_grad_process_fn(self) else: - return TorchPolicyGraph.extra_grad_process(self) + return TorchPolicy.extra_grad_process(self) - @override(TorchPolicyGraph) + @override(TorchPolicy) def extra_action_out(self, model_out): if extra_action_out_fn: return extra_action_out_fn(self, model_out) else: - return TorchPolicyGraph.extra_action_out(self, model_out) + return TorchPolicy.extra_action_out(self, model_out) - @override(TorchPolicyGraph) + @override(TorchPolicy) def optimizer(self): if optimizer_fn: return optimizer_fn(self, self.config) else: - return TorchPolicyGraph.optimizer(self) + return TorchPolicy.optimizer(self) - @override(TorchPolicyGraph) + @override(TorchPolicy) def extra_grad_info(self, batch_tensors): if stats_fn: return stats_fn(self, batch_tensors) else: - return TorchPolicyGraph.extra_grad_info(self, batch_tensors) + return TorchPolicy.extra_grad_info(self, batch_tensors) graph_cls.__name__ = name graph_cls.__qualname__ = name diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index 2bb25f5c40af..efa5743c0a54 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -15,7 +15,7 @@ from ray.rllib.agents.registry import get_agent_class from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.tune.util import merge_dicts EXAMPLE_USAGE = """ diff --git a/python/ray/rllib/tests/test_evaluators.py b/python/ray/rllib/tests/test_evaluators.py index 36ded2b4e800..7f2ef740e4f5 100644 --- a/python/ray/rllib/tests/test_evaluators.py +++ b/python/ray/rllib/tests/test_evaluators.py @@ -7,7 +7,7 @@ import ray from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.a3c import A3CTrainer -from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep +from ray.rllib.agents.dqn.dqn_policy import _adjust_nstep from ray.tune.registry import register_env import gym diff --git a/python/ray/rllib/tests/test_external_env.py b/python/ray/rllib/tests/test_external_env.py index 3379639612f6..3b2158959267 100644 --- a/python/ray/rllib/tests/test_external_env.py +++ b/python/ray/rllib/tests/test_external_env.py @@ -13,8 +13,8 @@ from ray.rllib.agents.pg import 
PGTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph, - MockPolicyGraph, MockEnv) +from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy, + MockEnv) from ray.tune.registry import register_env @@ -121,7 +121,7 @@ class TestExternalEnv(unittest.TestCase): def testExternalEnvCompleteEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -131,7 +131,7 @@ def testExternalEnvCompleteEpisodes(self): def testExternalEnvTruncateEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="truncate_episodes") for _ in range(3): @@ -141,7 +141,7 @@ def testExternalEnvTruncateEpisodes(self): def testExternalEnvOffPolicy(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -153,7 +153,7 @@ def testExternalEnvOffPolicy(self): def testExternalEnvBadActions(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=BadPolicyGraph, + policy=BadPolicy, sample_async=True, batch_steps=40, batch_mode="truncate_episodes") @@ -198,7 +198,7 @@ def testTrainCartpoleMulti(self): def testExternalEnvHorizonNotSupported(self): ev = PolicyEvaluator( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, episode_horizon=20, batch_steps=10, batch_mode="complete_episodes") diff --git a/python/ray/rllib/tests/test_external_multi_agent_env.py b/python/ray/rllib/tests/test_external_multi_agent_env.py index c01e6fa0b7ae..fcb3de634cbe 100644 --- a/python/ray/rllib/tests/test_external_multi_agent_env.py +++ b/python/ray/rllib/tests/test_external_multi_agent_env.py @@ -8,11 +8,11 @@ import unittest import ray -from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv -from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph +from ray.rllib.tests.test_policy_evaluator import MockPolicy from ray.rllib.tests.test_external_env import make_simple_serving from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole from ray.rllib.evaluation.metrics import collect_metrics @@ -25,7 +25,7 @@ def testExternalMultiAgentEnvCompleteEpisodes(self): agents = 4 ev = PolicyEvaluator( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="complete_episodes") for _ in range(3): @@ -37,7 +37,7 @@ def testExternalMultiAgentEnvTruncateEpisodes(self): agents = 4 ev = PolicyEvaluator( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=40, batch_mode="truncate_episodes") for _ in range(3): @@ -51,9 +51,9 @@ def testExternalMultiAgentEnvSample(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: 
SimpleMultiServing(BasicMultiAgent(agents)), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50) @@ -72,7 +72,7 @@ def testTrainExternalMultiCartpoleManyPolicies(self): policy_ids = list(policies.keys()) ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) optimizer = SyncSamplesOptimizer(ev, []) diff --git a/python/ray/rllib/tests/test_io.py b/python/ray/rllib/tests/test_io.py index 0706be1019cc..c98e4553dcf1 100644 --- a/python/ray/rllib/tests/test_io.py +++ b/python/ray/rllib/tests/test_io.py @@ -15,7 +15,7 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.offline import IOContext, JsonWriter, JsonReader from ray.rllib.offline.json_writer import _to_json @@ -167,7 +167,7 @@ def gen_policy(): "num_workers": 0, "output": self.test_dir, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, @@ -188,7 +188,7 @@ def gen_policy(): "input_evaluation": ["simulation"], "train_batch_size": 2000, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index 72130712d555..be4bfcd3428f 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -8,14 +8,14 @@ import ray from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy -from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.agents.pg.pg_policy import PGTFPolicy +from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, AsyncGradientsOptimizer) from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2, - MockPolicyGraph) + MockPolicy) from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -329,9 +329,9 @@ def testMultiAgentSample(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50) @@ -347,9 +347,9 @@ def testMultiAgentSampleSyncRemote(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + 
"p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50, @@ -364,9 +364,9 @@ def testMultiAgentSampleAsyncRemote(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_steps=50, @@ -380,9 +380,9 @@ def testMultiAgentSampleWithHorizon(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: BasicMultiAgent(5), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), episode_horizon=10, # test with episode horizon set @@ -395,9 +395,9 @@ def testSampleFromEarlyDoneEnv(self): obs_space = gym.spaces.Discrete(2) ev = PolicyEvaluator( env_creator=lambda _: EarlyDoneMultiAgent(), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), - "p1": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), + "p1": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), batch_mode="complete_episodes", @@ -411,8 +411,8 @@ def testMultiAgentSampleRoundRobin(self): obs_space = gym.spaces.Discrete(10) ev = PolicyEvaluator( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), - policy_graph={ - "p0": (MockPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (MockPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p0", batch_steps=50) @@ -445,7 +445,7 @@ def testMultiAgentSampleRoundRobin(self): def testCustomRNNStateValues(self): h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}} - class StatefulPolicyGraph(PolicyGraph): + class StatefulPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -460,7 +460,7 @@ def get_initial_state(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=StatefulPolicyGraph, + policy=StatefulPolicy, batch_steps=5) batch = ev.sample() self.assertEqual(batch.count, 5) @@ -470,7 +470,7 @@ def get_initial_state(self): self.assertEqual(batch["state_out_0"][1], h) def testReturningModelBasedRolloutsData(self): - class ModelBasedPolicyGraph(PGTFPolicy): + class ModelBasedPolicy(PGTFPolicy): def compute_actions(self, obs_batch, state_batches, @@ -505,9 +505,9 @@ def compute_actions(self, act_space = single_env.action_space ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(2), - policy_graph={ - "p0": (ModelBasedPolicyGraph, obs_space, act_space, {}), - "p1": (ModelBasedPolicyGraph, obs_space, act_space, {}), + policy={ + "p0": (ModelBasedPolicy, obs_space, act_space, {}), + "p1": (ModelBasedPolicy, obs_space, act_space, {}), }, policy_mapping_fn=lambda agent_id: "p0", batch_steps=5) @@ -547,7 +547,7 @@ def gen_policy(): config={ "num_workers": 0, "multiagent": { - "policy_graphs": { + "policies": { "policy_1": gen_policy(), "policy_2": gen_policy(), }, @@ -579,17 +579,17 @@ def _testWithOptimizer(self, optimizer_cls): # happen since the replay buffer doesn't encode extra fields like # 
"advantages" that PG uses. policies = { - "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config), - "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + "p1": (DQNTFPolicy, obs_space, act_space, dqn_config), + "p2": (DQNTFPolicy, obs_space, act_space, dqn_config), } else: policies = { "p1": (PGTFPolicy, obs_space, act_space, {}), - "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + "p2": (DQNTFPolicy, obs_space, act_space, dqn_config), } ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2], batch_steps=50) if optimizer_cls == AsyncGradientsOptimizer: @@ -600,7 +600,7 @@ def policy_mapper(agent_id): remote_evs = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=policy_mapper, batch_steps=50) ] @@ -610,12 +610,16 @@ def policy_mapper(agent_id): for i in range(200): ev.foreach_policy(lambda p, _: p.set_epsilon( max(0.02, 1 - i * .02)) - if isinstance(p, DQNPolicyGraph) else None) + if isinstance(p, DQNTFPolicy) else None) optimizer.step() result = collect_metrics(ev, remote_evs) if i % 20 == 0: - ev.foreach_policy(lambda p, _: p.update_target() if isinstance( - p, DQNPolicyGraph) else None) + + def do_update(p): + if isinstance(p, DQNTFPolicy): + p.update_target() + + ev.foreach_policy(lambda p, _: do_update(p)) print("Iter {}, rew {}".format(i, result["policy_reward_mean"])) print("Total reward", result["episode_reward_mean"]) @@ -645,7 +649,7 @@ def testTrainMultiCartpoleManyPolicies(self): policy_ids = list(policies.keys()) ev = PolicyEvaluator( env_creator=lambda _: MultiCartpole(n), - policy_graph=policies, + policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) optimizer = SyncSamplesOptimizer(ev, []) diff --git a/python/ray/rllib/tests/test_nested_spaces.py b/python/ray/rllib/tests/test_nested_spaces.py index b70bd9a2908e..0220ba01722c 100644 --- a/python/ray/rllib/tests/test_nested_spaces.py +++ b/python/ray/rllib/tests/test_nested_spaces.py @@ -12,7 +12,7 @@ import ray from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.agents.pg import PGTrainer -from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy +from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv @@ -331,7 +331,7 @@ def testMultiAgentComplexSpaces(self): "sample_batch_size": 5, "train_batch_size": 5, "multiagent": { - "policy_graphs": { + "policies": { "tuple_policy": ( PGTFPolicy, TUPLE_SPACE, act_space, {"model": {"custom_model": "tuple_spy"}}), diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index 5436baeafa90..f851cfc33f12 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -9,7 +9,7 @@ import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy +from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.evaluation import SampleBatch from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer @@ -240,12 +240,12 @@ def make_sess(): local = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOTFPolicy, + policy=PPOTFPolicy, 
tf_session_creator=make_sess) remotes = [ PolicyEvaluator.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=PPOTFPolicy, + policy=PPOTFPolicy, tf_session_creator=make_sess) ] return local, remotes diff --git a/python/ray/rllib/tests/test_perf.py b/python/ray/rllib/tests/test_perf.py index f437c9628dfd..e31530f44ced 100644 --- a/python/ray/rllib/tests/test_perf.py +++ b/python/ray/rllib/tests/test_perf.py @@ -8,7 +8,7 @@ import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph +from ray.rllib.tests.test_policy_evaluator import MockPolicy class TestPerf(unittest.TestCase): @@ -19,7 +19,7 @@ def testBaselinePerformance(self): for _ in range(20): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=100) start = time.time() count = 0 diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index 6283a5b66314..dc0dcaff6782 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -14,14 +14,14 @@ from ray.rllib.agents.a3c import A2CTrainer from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.metrics import collect_metrics -from ray.rllib.evaluation.policy_graph import PolicyGraph +from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages -from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch from ray.rllib.env.vector_env import VectorEnv from ray.tune.registry import register_env -class MockPolicyGraph(PolicyGraph): +class MockPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -39,7 +39,7 @@ def postprocess_trajectory(self, return compute_advantages(batch, 100.0, 0.9, use_gae=False) -class BadPolicyGraph(PolicyGraph): +class BadPolicy(Policy): def compute_actions(self, obs_batch, state_batches, @@ -132,8 +132,7 @@ def get_unwrapped(self): class TestPolicyEvaluator(unittest.TestCase): def testBasic(self): ev = PolicyEvaluator( - env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph) + env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch = ev.sample() for key in [ "obs", "actions", "rewards", "dones", "advantages", @@ -157,8 +156,7 @@ def to_prev(vec): def testBatchIds(self): ev = PolicyEvaluator( - env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph) + env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch1 = ev.sample() batch2 = ev.sample() self.assertEqual(len(set(batch1["unroll_id"])), 1) @@ -229,7 +227,7 @@ def testRewardClipping(self): # clipping on ev = PolicyEvaluator( env_creator=lambda _: MockEnv2(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, clip_rewards=True, batch_mode="complete_episodes") self.assertEqual(max(ev.sample()["rewards"]), 1) @@ -239,7 +237,7 @@ def testRewardClipping(self): # clipping off ev2 = PolicyEvaluator( env_creator=lambda _: MockEnv2(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, clip_rewards=False, batch_mode="complete_episodes") self.assertEqual(max(ev2.sample()["rewards"]), 100) @@ -249,7 +247,7 @@ def testRewardClipping(self): def testHardHorizon(self): ev = PolicyEvaluator( env_creator=lambda _: 
MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes", batch_steps=10, episode_horizon=4, @@ -263,7 +261,7 @@ def testHardHorizon(self): def testSoftHorizon(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes", batch_steps=10, episode_horizon=4, @@ -277,11 +275,11 @@ def testSoftHorizon(self): def testMetrics(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes") remote_ev = PolicyEvaluator.as_remote().remote( env_creator=lambda _: MockEnv(episode_length=10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="complete_episodes") ev.sample() ray.get(remote_ev.sample.remote()) @@ -293,7 +291,7 @@ def testAsync(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), sample_async=True, - policy_graph=MockPolicyGraph) + policy=MockPolicy) batch = ev.sample() for key in ["obs", "actions", "rewards", "dones", "advantages"]: self.assertIn(key, batch) @@ -302,7 +300,7 @@ def testAsync(self): def testAutoVectorization(self): ev = PolicyEvaluator( env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=2, num_envs=8) @@ -325,7 +323,7 @@ def testAutoVectorization(self): def testBatchesLargerWhenVectorized(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(episode_length=8), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=4, num_envs=4) @@ -340,7 +338,7 @@ def testBatchesLargerWhenVectorized(self): def testVectorEnvSupport(self): ev = PolicyEvaluator( env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_mode="truncate_episodes", batch_steps=10) for _ in range(8): @@ -357,7 +355,7 @@ def testVectorEnvSupport(self): def testTruncateEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=15, batch_mode="truncate_episodes") batch = ev.sample() @@ -366,7 +364,7 @@ def testTruncateEpisodes(self): def testCompleteEpisodes(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=5, batch_mode="complete_episodes") batch = ev.sample() @@ -375,7 +373,7 @@ def testCompleteEpisodes(self): def testCompleteEpisodesPacking(self): ev = PolicyEvaluator( env_creator=lambda _: MockEnv(10), - policy_graph=MockPolicyGraph, + policy=MockPolicy, batch_steps=15, batch_mode="complete_episodes") batch = ev.sample() @@ -387,7 +385,7 @@ def testCompleteEpisodesPacking(self): def testFilterSync(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") time.sleep(2) @@ -400,7 +398,7 @@ def testFilterSync(self): def testGetFilters(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") self.sample_and_flush(ev) @@ -415,7 +413,7 @@ def testGetFilters(self): def testSyncFilter(self): ev = PolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), - 
policy_graph=MockPolicyGraph, + policy=MockPolicy, sample_async=True, observation_filter="ConcurrentMeanStdFilter") obs_f = self.sample_and_flush(ev) diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index a16cba22b611..aad5590fd097 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -10,13 +10,30 @@ logger = logging.getLogger(__name__) -def renamed_class(cls): +def renamed_class(cls, old_name): + """Helper class for renaming classes with a warning.""" + + class DeprecationWrapper(cls): + # note: **kw not supported for ray.remote classes + def __init__(self, *args, **kw): + new_name = cls.__module__ + "." + cls.__name__ + logger.warn("DeprecationWarning: {} has been renamed to {}. ". + format(old_name, new_name) + + "This will raise an error in the future.") + cls.__init__(self, *args, **kw) + + DeprecationWrapper.__name__ = cls.__name__ + + return DeprecationWrapper + + +def renamed_agent(cls): """Helper class for renaming Agent => Trainer with a warning.""" class DeprecationWrapper(cls): def __init__(self, config=None, env=None, logger_creator=None): old_name = cls.__name__.replace("Trainer", "Agent") - new_name = cls.__name__ + new_name = cls.__module__ + "." + cls.__name__ logger.warn("DeprecationWarning: {} has been renamed to {}. ". format(old_name, new_name) + "This will raise an error in the future.") diff --git a/python/ray/rllib/utils/debug.py b/python/ray/rllib/utils/debug.py index ce86326f27a0..0f636b0f00ef 100644 --- a/python/ray/rllib/utils/debug.py +++ b/python/ray/rllib/utils/debug.py @@ -6,7 +6,7 @@ import pprint import time -from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch _logged = set() _disabled = False From 081708bdefb1bc14b9086e203c22113daa0a4d35 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 21 May 2019 17:13:48 +0800 Subject: [PATCH 025/118] [Java] Dynamic resource API in Java (#4824) --- java/api/src/main/java/org/ray/api/Ray.java | 15 +++++++ .../java/org/ray/api/runtime/RayRuntime.java | 9 ++++ .../org/ray/runtime/AbstractRayRuntime.java | 9 ++++ .../ray/runtime/raylet/MockRayletClient.java | 5 +++ .../org/ray/runtime/raylet/RayletClient.java | 2 + .../ray/runtime/raylet/RayletClientImpl.java | 7 +++ .../org/ray/api/test/DynamicResourceTest.java | 44 +++++++++++++++++++ ...org_ray_runtime_raylet_RayletClientImpl.cc | 18 ++++++++ .../org_ray_runtime_raylet_RayletClientImpl.h | 8 ++++ 9 files changed, 117 insertions(+) create mode 100644 java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index fa82ea685706..3ebfc16687c1 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -123,6 +123,21 @@ public static RayRuntime internal() { return runtime; } + /** + * Update the resource for the specified client. + * Set the resource for the specific node. + */ + public static void setResource(UniqueId nodeId, String resourceName, double capacity) { + runtime.setResource(resourceName, capacity, nodeId); + } + + /** + * Set the resource for local node. + */ + public static void setResource(String resourceName, double capacity) { + runtime.setResource(resourceName, capacity, UniqueId.NIL); + } + /** * Get the runtime context. 
*/ diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 521032316366..7767253c52ff 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -65,6 +65,15 @@ public interface RayRuntime { */ void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + /** + * Set the resource for the specific node. + * + * @param resourceName The name of resource. + * @param capacity The capacity of the resource. + * @param nodeId The node that we want to set its resource. + */ + void setResource(String resourceName, double capacity, UniqueId nodeId); + /** * Invoke a remote function. * diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index af8cff9d79d9..e77d9a6f570f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -210,6 +210,15 @@ public void free(List objectIds, boolean localOnly, boolean deleteCrea rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + Preconditions.checkArgument(Double.compare(capacity, 0) >= 0); + if (nodeId == null) { + nodeId = UniqueId.NIL; + } + rayletClient.setResource(resourceName, capacity, nodeId); + } + private List> splitIntoBatches(List objectIds) { List> batches = new ArrayList<>(); int objectsSize = objectIds.size(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 385431c7055f..640789c3b0aa 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -209,6 +209,11 @@ public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpoi throw new NotImplementedException("Not implemented."); } + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + LOGGER.error("Not implemented under SINGLE_PROCESS mode."); + } + @Override public void destroy() { exec.shutdown(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index fc6fc75b0fbd..19db27f6d900 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -30,5 +30,7 @@ WaitResult wait(List> waitFor, int numReturns, int void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId); + void setResource(String resourceName, double capacity, UniqueId nodeId); + void destroy(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 0ed1f9c86fbf..b46d6b611a8e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -308,6 +308,10 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { return buffer; } + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + nativeSetResource(client, 
resourceName, capacity, nodeId.getBytes()); + } + public void destroy() { nativeDestroy(client); } @@ -357,4 +361,7 @@ private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId, byte[] checkpointId); + + private static native void nativeSetResource(long conn, String resourceName, double capacity, + byte[] nodeId) throws RayException; } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java new file mode 100644 index 000000000000..ffda0732287e --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -0,0 +1,44 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import org.ray.api.Ray; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.WaitResult; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.CallOptions; +import org.ray.api.runtimecontext.NodeInfo; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class DynamicResourceTest extends BaseTest { + + @RayRemote + public static String sayHi() { + return "hi"; + } + + @Test + public void testSetResource() { + TestUtils.skipTestUnderSingleProcess(); + CallOptions op1 = new CallOptions(ImmutableMap.of("A", 10.0)); + RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); + WaitResult result = Ray.wait(ImmutableList.of(obj), 1, 1000); + Assert.assertEquals(result.getReady().size(), 0); + + Ray.setResource("A", 10.0); + + // Assert node info. + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + Assert.assertEquals(nodes.size(), 1); + Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); + + // Assert ray call result. 
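+    // Once setResource("A", 10.0) has taken effect, the call above (which
+    // requires {"A": 10.0}) becomes schedulable, so this wait() is expected
+    // to report the object as ready and Ray.get() to return "hi".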
+ result = Ray.wait(ImmutableList.of(obj), 1, 1000); + Assert.assertEquals(result.getReady().size(), 1); + Assert.assertEquals(Ray.get(obj.getId()), "hi"); + } + +} diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index eb9d2f0e5a83..ac32911ef2d0 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -302,6 +302,24 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo ThrowRayExceptionIfNotOK(env, status); } +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(JNIEnv *env, jclass, + jlong client, jstring resourceName, jdouble capacity, jbyteArray nodeId) { + auto raylet_client = reinterpret_cast(client); + UniqueIdFromJByteArray node_id(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + + auto status = raylet_client->SetResource(native_resource_name, + static_cast(capacity), node_id.GetId()); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + ThrowRayExceptionIfNotOK(env, status); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h index c00c7c009814..91338a12e176 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h @@ -116,6 +116,14 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( JNIEnv *, jclass, jlong, jbyteArray, jbyteArray); +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( + JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + #ifdef __cplusplus } #endif From 5391b613109f56a2257f67d3c0baf4186f47371b Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Tue, 21 May 2019 13:11:24 +0200 Subject: [PATCH 026/118] Add default values for Wgym flags --- python/ray/rllib/agents/trainer.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 00f115978ecd..029034e94258 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -2,35 +2,37 @@ from __future__ import division from __future__ import print_function -from datetime import datetime import copy import logging import os import pickle -import six -import time import tempfile +import time +from datetime import datetime from types import FunctionType import ray +import six from ray.exceptions import RayError -from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput -from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ _validate_multiagent_config from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.models import MODEL_DEFAULTS +from 
ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ + ShuffledInput from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils import FilterManager, deep_update, merge_dicts -from ray.rllib.utils.memory import ray_get_and_free from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI +from ray.rllib.utils.memory import ray_get_and_free +from ray.tune.logger import UnifiedLogger from ray.tune.registry import ENV_CREATOR, register_env, _global_registry +from ray.tune.result import DEFAULT_RESULTS_DIR from ray.tune.trainable import Trainable from ray.tune.trial import Resources, ExportFormat -from ray.tune.logger import UnifiedLogger -from ray.tune.result import DEFAULT_RESULTS_DIR + +from python.ray.tune.logger import to_tf_values tf = try_import_tf() @@ -108,7 +110,10 @@ # and to disable exploration by computing deterministic actions # TODO(kismuz): implement determ. actions and include relevant keys hints "evaluation_config": { - "beholder": False + "beholder": False, + "should_log_histograms": False, + "to_tf_values": to_tf_values, + "debug_learner_session_port": None, }, # === Resources === From 87bb2e58bc86ec3a6a1a0cc6d25eb2431b31c4b4 Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Tue, 21 May 2019 13:13:41 +0200 Subject: [PATCH 027/118] Fix import --- python/ray/rllib/agents/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 029034e94258..0e727955aab3 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -32,7 +32,7 @@ from ray.tune.trainable import Trainable from ray.tune.trial import Resources, ExportFormat -from python.ray.tune.logger import to_tf_values +from ray.tune.logger import to_tf_values tf = try_import_tf() From 259cdfa0defd278e449a235258746089f1ed3fcb Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Wed, 22 May 2019 11:08:24 +0800 Subject: [PATCH 028/118] Fix issue when starting `raylet_monitor` (#4829) --- python/ray/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/services.py b/python/ray/services.py index 034a610e471c..2e9759428154 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1563,7 +1563,7 @@ def start_raylet_monitor(redis_address, "--config_list={}".format(config_str), ] if redis_password: - command += [redis_password] + command += ["--redis_password={}".format(redis_password)] process_info = start_ray_process( command, ray_constants.PROCESS_TYPE_RAYLET_MONITOR, From 1a39fee9c6fa0f81e1184182b47ef106269613fe Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Wed, 22 May 2019 14:46:30 +0800 Subject: [PATCH 029/118] Refactor ID Serial 1: Separate ObjectID and TaskID from UniqueID (#4776) * Enable BaseId. * Change TaskID and make python test pass * Remove unnecessary functions and fix test failure and change TaskID to 16 bytes. 
* Java code change draft * Refine * Lint * Update java/api/src/main/java/org/ray/api/id/TaskId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/BaseId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/BaseId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/ObjectId.java Co-Authored-By: Hao Chen * Address comment * Lint * Fix SINGLE_PROCESS * Fix comments * Refine code * Refine test * Resolve conflict --- java/api/src/main/java/org/ray/api/Ray.java | 5 +- .../src/main/java/org/ray/api/RayObject.java | 4 +- .../exception/UnreconstructableException.java | 6 +- .../src/main/java/org/ray/api/id/BaseId.java | 99 ++++++++ .../main/java/org/ray/api/id/ObjectId.java | 62 +++++ .../src/main/java/org/ray/api/id/TaskId.java | 56 ++++ .../main/java/org/ray/api/id/UniqueId.java | 76 +----- .../java/org/ray/api/runtime/RayRuntime.java | 7 +- .../org/ray/runtime/AbstractRayRuntime.java | 46 ++-- .../java/org/ray/runtime/RayActorImpl.java | 9 +- .../java/org/ray/runtime/RayObjectImpl.java | 8 +- .../src/main/java/org/ray/runtime/Worker.java | 7 +- .../java/org/ray/runtime/WorkerContext.java | 11 +- .../java/org/ray/runtime/gcs/GcsClient.java | 12 +- .../runtime/objectstore/MockObjectStore.java | 31 ++- .../runtime/objectstore/ObjectStoreProxy.java | 16 +- .../ray/runtime/raylet/MockRayletClient.java | 34 +-- .../org/ray/runtime/raylet/RayletClient.java | 12 +- .../ray/runtime/raylet/RayletClientImpl.java | 45 ++-- .../ray/runtime/task/ArgumentsBuilder.java | 6 +- .../org/ray/runtime/task/FunctionArg.java | 8 +- .../java/org/ray/runtime/task/TaskSpec.java | 18 +- .../util/{UniqueIdUtil.java => IdUtil.java} | 113 +++++---- .../org/ray/api/test/ClientExceptionTest.java | 4 +- .../org/ray/api/test/ObjectStoreTest.java | 4 +- .../java/org/ray/api/test/PlasmaFreeTest.java | 3 +- .../java/org/ray/api/test/StressTest.java | 8 +- .../java/org/ray/api/test/UniqueIdTest.java | 58 +++-- python/ray/_raylet.pyx | 4 +- python/ray/actor.py | 7 +- python/ray/includes/common.pxd | 6 - python/ray/includes/unique_ids.pxd | 70 +++-- python/ray/includes/unique_ids.pxi | 152 ++++++++--- python/ray/monitor.py | 12 +- python/ray/tests/test_basic.py | 27 +- python/ray/tests/test_failure.py | 3 +- python/ray/utils.py | 4 + python/ray/worker.py | 9 +- src/ray/constants.h | 5 +- src/ray/gcs/client_test.cc | 36 +-- src/ray/gcs/redis_context.cc | 39 --- src/ray/gcs/redis_context.h | 51 +++- src/ray/gcs/redis_module/ray_redis_module.cc | 12 +- src/ray/gcs/tables.cc | 22 +- src/ray/gcs/tables.h | 2 +- src/ray/id.cc | 165 +++--------- src/ray/id.h | 239 ++++++++++++++---- src/ray/id_def.h | 2 - .../test/object_manager_test.cc | 2 +- src/ray/raylet/lineage_cache.cc | 2 +- src/ray/raylet/node_manager.cc | 9 +- src/ray/raylet/reconstruction_policy.cc | 4 +- src/ray/raylet/reconstruction_policy_test.cc | 26 +- src/ray/raylet/task_dependency_manager.cc | 12 +- .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/task_spec.cc | 2 +- src/ray/raylet/task_test.cc | 26 +- 57 files changed, 1076 insertions(+), 644 deletions(-) create mode 100644 java/api/src/main/java/org/ray/api/id/BaseId.java create mode 100644 java/api/src/main/java/org/ray/api/id/ObjectId.java create mode 100644 java/api/src/main/java/org/ray/api/id/TaskId.java rename java/runtime/src/main/java/org/ray/runtime/util/{UniqueIdUtil.java => IdUtil.java} (64%) diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index 
3ebfc16687c1..cdad95e16758 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,6 +1,7 @@ package org.ray.api; import java.util.List; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtime.RayRuntimeFactory; @@ -65,7 +66,7 @@ public static RayObject put(T obj) { * @param objectId The ID of the object to get. * @return The Java object. */ - public static T get(UniqueId objectId) { + public static T get(ObjectId objectId) { return runtime.get(objectId); } @@ -75,7 +76,7 @@ public static T get(UniqueId objectId) { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - public static List get(List objectIds) { + public static List get(List objectIds) { return runtime.get(objectIds); } diff --git a/java/api/src/main/java/org/ray/api/RayObject.java b/java/api/src/main/java/org/ray/api/RayObject.java index a1971be40773..faf42f826aa1 100644 --- a/java/api/src/main/java/org/ray/api/RayObject.java +++ b/java/api/src/main/java/org/ray/api/RayObject.java @@ -1,6 +1,6 @@ package org.ray.api; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents an object in the object store. @@ -17,7 +17,7 @@ public interface RayObject { /** * Get the object id. */ - UniqueId getId(); + ObjectId getId(); } diff --git a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java index 8362295baf1a..0eb2ed9e7dca 100644 --- a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java +++ b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java @@ -1,6 +1,6 @@ package org.ray.api.exception; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Indicates that an object is lost (either evicted or explicitly deleted) and cannot be @@ -11,9 +11,9 @@ */ public class UnreconstructableException extends RayException { - public final UniqueId objectId; + public final ObjectId objectId; - public UnreconstructableException(UniqueId objectId) { + public UnreconstructableException(ObjectId objectId) { super(String.format( "Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.", objectId)); diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java new file mode 100644 index 000000000000..3c5e1e3a3619 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -0,0 +1,99 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.xml.bind.DatatypeConverter; + +public abstract class BaseId implements Serializable { + private static final long serialVersionUID = 8588849129675565761L; + private final byte[] id; + private int hashCodeCache = 0; + private Boolean isNilCache = null; + + /** + * Create a BaseId instance according to the input byte array. + */ + public BaseId(byte[] id) { + if (id.length != size()) { + throw new IllegalArgumentException("Failed to construct BaseId, expect " + size() + + " bytes, but got " + id.length + " bytes."); + } + this.id = id; + } + + /** + * Get the byte data of this id. + */ + public byte[] getBytes() { + return id; + } + + /** + * Convert the byte data to a ByteBuffer. + */ + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(id); + } + + /** + * @return True if this id is nil. 
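+   * The check is evaluated lazily on the first call and the result is cached.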
+ */ + public boolean isNil() { + if (isNilCache == null) { + isNilCache = true; + for (int i = 0; i < size(); ++i) { + if (id[i] != (byte) 0xff) { + isNilCache = false; + break; + } + } + } + return isNilCache; + } + + /** + * Derived class should implement this function. + * @return The length of this id in bytes. + */ + public abstract int size(); + + @Override + public int hashCode() { + // Lazy evaluation. + if (hashCodeCache == 0) { + hashCodeCache = Arrays.hashCode(id); + } + return hashCodeCache; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + + if (!this.getClass().equals(obj.getClass())) { + return false; + } + + BaseId r = (BaseId) obj; + return Arrays.equals(id, r.id); + } + + @Override + public String toString() { + return DatatypeConverter.printHexBinary(id).toLowerCase(); + } + + protected static byte[] hexString2Bytes(String hex) { + return DatatypeConverter.parseHexBinary(hex); + } + + protected static byte[] byteBuffer2Bytes(ByteBuffer bb) { + byte[] id = new byte[bb.remaining()]; + bb.get(id); + return id; + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/ObjectId.java b/java/api/src/main/java/org/ray/api/id/ObjectId.java new file mode 100644 index 000000000000..49c0f39ebe5b --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/ObjectId.java @@ -0,0 +1,62 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray object. + */ +public class ObjectId extends BaseId implements Serializable { + + public static final int LENGTH = 20; + public static final ObjectId NIL = genNil(); + + /** + * Create an ObjectId from a hex string. + */ + public static ObjectId fromHexString(String hex) { + return new ObjectId(hexString2Bytes(hex)); + } + + /** + * Create an ObjectId from a ByteBuffer. + */ + public static ObjectId fromByteBuffer(ByteBuffer bb) { + return new ObjectId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil ObjectId. + */ + private static ObjectId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new ObjectId(b); + } + + /** + * Generate an ObjectId with random value. + */ + public static ObjectId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new ObjectId(b); + } + + public ObjectId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } + + public TaskId getTaskId() { + byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH); + return new TaskId(taskIdBytes); + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/TaskId.java b/java/api/src/main/java/org/ray/api/id/TaskId.java new file mode 100644 index 000000000000..8f1fe0694ea4 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/TaskId.java @@ -0,0 +1,56 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray task. + */ +public class TaskId extends BaseId implements Serializable { + + public static final int LENGTH = 16; + public static final TaskId NIL = genNil(); + + /** + * Create a TaskId from a hex string. + */ + public static TaskId fromHexString(String hex) { + return new TaskId(hexString2Bytes(hex)); + } + + /** + * Creates a TaskId from a ByteBuffer. 
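As an aside for readers of this patch, a small usage sketch of the new ID classes introduced above (the hex value is made up for illustration): an ObjectId is 20 bytes whose first TaskId.LENGTH (16) bytes are the id of the task that produced it, so the task id can be recovered by truncation.

ObjectId objectId = ObjectId.fromHexString("123456789abcdef123456789abcdef0001000000");
TaskId taskId = objectId.getTaskId();                      // the 16-byte prefix of the object id
assert taskId.toString().equals("123456789abcdef123456789abcdef00");
assert objectId.size() == ObjectId.LENGTH && taskId.size() == TaskId.LENGTH;
assert !objectId.isNil();                                  // nil ids are all 0xFF bytes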
+ */ + public static TaskId fromByteBuffer(ByteBuffer bb) { + return new TaskId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil TaskId. + */ + private static TaskId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new TaskId(b); + } + + /** + * Generate an TaskId with random value. + */ + public static TaskId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new TaskId(b); + } + + public TaskId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/api/src/main/java/org/ray/api/id/UniqueId.java b/java/api/src/main/java/org/ray/api/id/UniqueId.java index f93bdc737229..4fd723ff26bf 100644 --- a/java/api/src/main/java/org/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/org/ray/api/id/UniqueId.java @@ -4,41 +4,34 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Random; -import javax.xml.bind.DatatypeConverter; /** * Represents a unique id of all Ray concepts, including - * objects, tasks, workers, actors, etc. + * workers, actors, checkpoints, etc. */ -public class UniqueId implements Serializable { +public class UniqueId extends BaseId implements Serializable { public static final int LENGTH = 20; public static final UniqueId NIL = genNil(); - private static final long serialVersionUID = 8588849129675565761L; - private final byte[] id; /** * Create a UniqueId from a hex string. */ public static UniqueId fromHexString(String hex) { - byte[] bytes = DatatypeConverter.parseHexBinary(hex); - return new UniqueId(bytes); + return new UniqueId(hexString2Bytes(hex)); } /** * Creates a UniqueId from a ByteBuffer. */ public static UniqueId fromByteBuffer(ByteBuffer bb) { - byte[] id = new byte[bb.remaining()]; - bb.get(id); - - return new UniqueId(id); + return new UniqueId(byteBuffer2Bytes(bb)); } /** * Generate a nil UniqueId. */ - public static UniqueId genNil() { + private static UniqueId genNil() { byte[] b = new byte[LENGTH]; Arrays.fill(b, (byte) 0xFF); return new UniqueId(b); @@ -54,64 +47,11 @@ public static UniqueId randomId() { } public UniqueId(byte[] id) { - if (id.length != LENGTH) { - throw new IllegalArgumentException("Illegal argument for UniqueId, expect " + LENGTH - + " bytes, but got " + id.length + " bytes."); - } - - this.id = id; - } - - /** - * Get the byte data of this UniqueId. - */ - public byte[] getBytes() { - return id; - } - - /** - * Convert the byte data to a ByteBuffer. - */ - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(id); - } - - /** - * Create a copy of this UniqueId. - */ - public UniqueId copy() { - byte[] nid = Arrays.copyOf(id, id.length); - return new UniqueId(nid); - } - - /** - * Returns true if this id is nil. 
- */ - public boolean isNil() { - return this.equals(NIL); - } - - @Override - public int hashCode() { - return Arrays.hashCode(id); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - - if (!(obj instanceof UniqueId)) { - return false; - } - - UniqueId r = (UniqueId) obj; - return Arrays.equals(id, r.id); + super(id); } @Override - public String toString() { - return DatatypeConverter.printHexBinary(id).toLowerCase(); + public int size() { + return LENGTH; } } diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 7767253c52ff..5a29c9a39dd1 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -6,6 +6,7 @@ import org.ray.api.RayPyActor; import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; @@ -35,7 +36,7 @@ public interface RayRuntime { * @param objectId The ID of the object to get. * @return The Java object. */ - T get(UniqueId objectId); + T get(ObjectId objectId); /** * Get a list of objects from the object store. @@ -43,7 +44,7 @@ public interface RayRuntime { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - List get(List objectIds); + List get(List objectIds); /** * Wait for a list of RayObjects to be locally available, until specified number of objects are @@ -63,7 +64,7 @@ public interface RayRuntime { * @param localOnly Whether only free objects for local object store or not. * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); /** * Set the resource for the specific node. 
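With the RayRuntime interface above now keyed on ObjectId, a caller-side sketch of the public API reads as follows (illustrative only; it assumes Ray.init() has been called and uses Guava's ImmutableList as elsewhere in this codebase; generic type parameters appear stripped in the rendered diff above but are written out here):

RayObject<String> hello = Ray.put("hello");
ObjectId id = hello.getId();                               // previously a UniqueId
String single = Ray.get(id);                               // single get by ObjectId
List<String> batch = Ray.get(ImmutableList.of(id, id));    // batched get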
diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index e77d9a6f570f..01f8dbd12ba0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -15,6 +15,8 @@ import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.BaseTaskOptions; @@ -32,7 +34,7 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,15 +90,15 @@ public AbstractRayRuntime(RayConfig rayConfig) { @Override public RayObject put(T obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); put(objectId, obj); return new RayObjectImpl<>(objectId); } - public void put(UniqueId objectId, T obj) { - UniqueId taskId = workerContext.getCurrentTaskId(); + public void put(ObjectId objectId, T obj) { + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj); } @@ -109,28 +111,28 @@ public void put(UniqueId objectId, T obj) { * @return A RayObject instance that represents the in-store object. */ public RayObject putSerialized(byte[] obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); - UniqueId taskId = workerContext.getCurrentTaskId(); + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); objectStoreProxy.putSerialized(objectId, obj); return new RayObjectImpl<>(objectId); } @Override - public T get(UniqueId objectId) throws RayException { + public T get(ObjectId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); return ret.get(0); } @Override - public List get(List objectIds) { + public List get(List objectIds) { List ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null)); boolean wasBlocked = false; try { // A map that stores the unready object ids and their original indexes. - Map unready = new HashMap<>(); + Map unready = new HashMap<>(); for (int i = 0; i < objectIds.size(); i++) { unready.put(objectIds.get(i), i); } @@ -138,7 +140,7 @@ public List get(List objectIds) { // Repeat until we get all objects. while (!unready.isEmpty()) { - List unreadyIds = new ArrayList<>(unready.keySet()); + List unreadyIds = new ArrayList<>(unready.keySet()); // For the initial fetch, we only fetch the objects, do not reconstruct them. boolean fetchOnly = numAttempts == 0; @@ -147,7 +149,7 @@ public List get(List objectIds) { wasBlocked = true; } // Call `fetchOrReconstruct` in batches. 
- for (List batch : splitIntoBatches(unreadyIds)) { + for (List batch : splitIntoBatches(unreadyIds)) { rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId()); } @@ -161,7 +163,7 @@ public List get(List objectIds) { throw getResult.exception; } else { // Set the result to the return list, and remove it from the unready map. - UniqueId id = unreadyIds.get(i); + ObjectId id = unreadyIds.get(i); ret.set(unready.get(id), getResult.object); unready.remove(id); } @@ -172,11 +174,11 @@ public List get(List objectIds) { if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) { // Print a warning if we've attempted too many times, but some objects are still // unavailable. - List idsToPrint = new ArrayList<>(unready.keySet()); + List idsToPrint = new ArrayList<>(unready.keySet()); if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) { idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING); } - String ids = idsToPrint.stream().map(UniqueId::toString) + String ids = idsToPrint.stream().map(ObjectId::toString) .collect(Collectors.joining(", ")); if (idsToPrint.size() < unready.size()) { ids += ", etc"; @@ -206,7 +208,7 @@ public List get(List objectIds) { } @Override - public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); } @@ -219,13 +221,13 @@ public void setResource(String resourceName, double capacity, UniqueId nodeId) { rayletClient.setResource(resourceName, capacity, nodeId); } - private List> splitIntoBatches(List objectIds) { - List> batches = new ArrayList<>(); + private List> splitIntoBatches(List objectIds) { + List> batches = new ArrayList<>(); int objectsSize = objectIds.size(); for (int i = 0; i < objectsSize; i += FETCH_BATCH_SIZE) { int endIndex = i + FETCH_BATCH_SIZE; - List batchIds = (endIndex < objectsSize) + List batchIds = (endIndex < objectsSize) ? objectIds.subList(i, endIndex) : objectIds.subList(i, objectsSize); @@ -271,7 +273,7 @@ public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, args, true, options); - RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); + RayActorImpl actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes())); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); rayletClient.submitTask(spec); @@ -343,14 +345,14 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes boolean isActorCreationTask, BaseTaskOptions taskOptions) { Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), + TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 
1 : 2; - UniqueId[] returnIds = UniqueIdUtil.genReturnIds(taskId, numReturns); + ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns); UniqueId actorCreationId = UniqueId.NIL; if (isActorCreationTask) { - actorCreationId = returnIds[0]; + actorCreationId = new UniqueId(returnIds[0].getBytes()); } Map resources; diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index 7899869aef42..c5a9703c9164 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import org.ray.api.RayActor; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.util.Sha1Digestor; @@ -30,7 +31,7 @@ public class RayActorImpl implements RayActor, Externalizable { * The unique id of the last return of the last task. * It's used as a dependency for the next task. */ - protected UniqueId taskCursor; + protected ObjectId taskCursor; /** * The number of times that this actor handle has been forked. * It's used to make sure ids of actor handles are unique. @@ -72,7 +73,7 @@ public UniqueId getHandleId() { return handleId; } - public void setTaskCursor(UniqueId taskCursor) { + public void setTaskCursor(ObjectId taskCursor) { this.taskCursor = taskCursor; } @@ -84,7 +85,7 @@ public void clearNewActorHandles() { this.newActorHandles.clear(); } - public UniqueId getTaskCursor() { + public ObjectId getTaskCursor() { return taskCursor; } @@ -121,7 +122,7 @@ public void writeExternal(ObjectOutput out) throws IOException { public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { this.id = (UniqueId) in.readObject(); this.handleId = (UniqueId) in.readObject(); - this.taskCursor = (UniqueId) in.readObject(); + this.taskCursor = (ObjectId) in.readObject(); this.taskCounter = (int) in.readObject(); this.numForks = (int) in.readObject(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java index 1516543a1e2a..9f8e567f8e09 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java @@ -3,13 +3,13 @@ import java.io.Serializable; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; public final class RayObjectImpl implements RayObject, Serializable { - private final UniqueId id; + private final ObjectId id; - public RayObjectImpl(UniqueId id) { + public RayObjectImpl(ObjectId id) { this.id = id; } @@ -19,7 +19,7 @@ public T get() { } @Override - public UniqueId getId() { + public ObjectId getId() { return id; } diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 813a62fdc07e..b4de226e2914 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -7,6 +7,7 @@ import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; @@ -80,7 +81,7 @@ public void loop() { */ public void execute(TaskSpec spec) { LOGGER.debug("Executing 
task {}", spec); - UniqueId returnId = spec.returnIds[0]; + ObjectId returnId = spec.returnIds[0]; ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); try { // Get method @@ -91,7 +92,7 @@ public void execute(TaskSpec spec) { Thread.currentThread().setContextClassLoader(rayFunction.classLoader); if (spec.isActorCreationTask()) { - currentActorId = returnId; + currentActorId = new UniqueId(returnId.getBytes()); } // Get local actor object and arguments. @@ -119,7 +120,7 @@ public void execute(TaskSpec spec) { } runtime.put(returnId, result); } else { - maybeLoadCheckpoint(result, returnId); + maybeLoadCheckpoint(result, new UniqueId(returnId.getBytes())); currentActor = result; } LOGGER.debug("Finished executing task {}", spec.taskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 57f23cf31b19..44703bf673fd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,6 +1,7 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.config.WorkerMode; @@ -14,7 +15,7 @@ public class WorkerContext { private UniqueId workerId; - private ThreadLocal currentTaskId; + private ThreadLocal currentTaskId; /** * Number of objects that have been put from current task. @@ -46,17 +47,17 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) mainThreadId = Thread.currentThread().getId(); taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); - currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); + currentTaskId = ThreadLocal.withInitial(TaskId::randomId); this.runMode = runMode; currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { workerId = driverId; - currentTaskId.set(UniqueId.randomId()); + currentTaskId.set(TaskId.randomId()); currentDriverId = driverId; } else { workerId = UniqueId.randomId(); - this.currentTaskId.set(UniqueId.NIL); + this.currentTaskId.set(TaskId.NIL); this.currentDriverId = UniqueId.NIL; } } @@ -65,7 +66,7 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) * @return For the main thread, this method returns the ID of this worker's current running task; * for other threads, this method returns a random ID. 
*/ - public UniqueId getCurrentTaskId() { + public TaskId getCurrentTaskId() { return currentTaskId.get(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 7439dfa430f8..431b48ded58c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -9,13 +9,15 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.id.BaseId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; import org.ray.runtime.generated.EntryType; import org.ray.runtime.generated.TablePrefix; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,7 +114,7 @@ public boolean actorExists(UniqueId actorId) { /** * Query whether the raylet task exists in Gcs. */ - public boolean rayletTaskExistsInGcs(UniqueId taskId) { + public boolean rayletTaskExistsInGcs(TaskId taskId) { byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); @@ -132,7 +134,7 @@ public List getCheckpointsForActor(UniqueId actorId) { if (result != null) { ActorCheckpointIdData data = ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( data.checkpointIdsAsByteBuffer()); for (int i = 0; i < checkpointIds.length; i++) { @@ -143,8 +145,8 @@ public List getCheckpointsForActor(UniqueId actorId) { return checkpoints; } - private RedisClient getShardClient(UniqueId key) { - return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key), + private RedisClient getShardClient(BaseId key) { + return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), shards.size())); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 4b80d3e4c276..f3d64c8340a2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -9,7 +9,7 @@ import java.util.stream.Collectors; import org.apache.arrow.plasma.ObjectStoreLink; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayDevRuntime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,16 +24,16 @@ public class MockObjectStore implements ObjectStoreLink { private static final int GET_CHECK_INTERVAL_MS = 100; private final RayDevRuntime runtime; - private final Map data = new ConcurrentHashMap<>(); - private final Map metadata = new ConcurrentHashMap<>(); - private final List> objectPutCallbacks; + private final Map data = new ConcurrentHashMap<>(); + private final Map metadata = new ConcurrentHashMap<>(); + private final List> objectPutCallbacks; public MockObjectStore(RayDevRuntime runtime) { this.runtime = runtime; this.objectPutCallbacks = new ArrayList<>(); } - public void addObjectPutCallback(Consumer callback) { + public void addObjectPutCallback(Consumer 
callback) { this.objectPutCallbacks.add(callback); } @@ -44,13 +44,12 @@ public void put(byte[] objectId, byte[] value, byte[] metadataValue) { .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value)); System.exit(-1); } - UniqueId uniqueId = new UniqueId(objectId); - data.put(uniqueId, value); + ObjectId id = new ObjectId(objectId); + data.put(id, value); if (metadataValue != null) { - metadata.put(uniqueId, metadataValue); + metadata.put(id, metadataValue); } - UniqueId id = new UniqueId(objectId); - for (Consumer callback : objectPutCallbacks) { + for (Consumer callback : objectPutCallbacks) { callback.accept(id); } } @@ -85,7 +84,7 @@ public List get(byte[][] objectIds, int timeoutMs) { } ready = 0; for (byte[] id : objectIds) { - if (data.containsKey(new UniqueId(id))) { + if (data.containsKey(new ObjectId(id))) { ready += 1; } } @@ -93,8 +92,8 @@ public List get(byte[][] objectIds, int timeoutMs) { } ArrayList rets = new ArrayList<>(); for (byte[] objId : objectIds) { - UniqueId uniqueId = new UniqueId(objId); - rets.add(new ObjectStoreData(metadata.get(uniqueId), data.get(uniqueId))); + ObjectId objectId = new ObjectId(objId); + rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId))); } return rets; } @@ -121,7 +120,7 @@ public void delete(byte[] objectId) { @Override public boolean contains(byte[] objectId) { - return data.containsKey(new UniqueId(objectId)); + return data.containsKey(new ObjectId(objectId)); } private String logPrefix() { @@ -138,11 +137,11 @@ private String getUserTrace() { return stes[k].getFileName() + ":" + stes[k].getLineNumber(); } - public boolean isObjectReady(UniqueId id) { + public boolean isObjectReady(ObjectId id) { return data.containsKey(id); } - public void free(UniqueId id) { + public void free(ObjectId id) { data.remove(id); metadata.remove(id); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 64b9e2b73a9f..f9e310249a35 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -12,13 +12,13 @@ import org.ray.api.exception.RayException; import org.ray.api.exception.RayWorkerException; import org.ray.api.exception.UnreconstructableException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; -import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) { * @param Type of the object. * @return The GetResult object. */ - public GetResult get(UniqueId id, int timeoutMs) { + public GetResult get(ObjectId id, int timeoutMs) { List> list = get(ImmutableList.of(id), timeoutMs); return list.get(0); } @@ -74,8 +74,8 @@ public GetResult get(UniqueId id, int timeoutMs) { * @param Type of these objects. * @return A list of GetResult objects. 
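A brief sketch of how a caller consumes one of these GetResult values, following the usage in AbstractRayRuntime earlier in this patch (the exception and object fields appear there; throwing the exception directly is an assumption of this sketch):

GetResult<String> result = objectStoreProxy.get(objectId, /* timeoutMs */ 0);
if (result.exception != null) {
  throw result.exception;        // the stored value encodes an error (e.g. UnreconstructableException)
}
String value = result.object;    // the deserialized value when the object was available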
*/ - public List> get(List ids, int timeoutMs) { - byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids); + public List> get(List ids, int timeoutMs) { + byte[][] binaryIds = IdUtil.getIdBytes(ids); List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs); List> results = new ArrayList<>(); @@ -114,7 +114,7 @@ public List> get(List ids, int timeoutMs) { } @SuppressWarnings("unchecked") - private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) { + private GetResult deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) { if (Arrays.equals(meta, RAW_TYPE_META)) { return (GetResult) new GetResult<>(true, data, null); } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { @@ -133,7 +133,7 @@ private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId * @param id Id of the object. * @param object The object to put. */ - public void put(UniqueId id, Object object) { + public void put(ObjectId id, Object object) { try { if (object instanceof byte[]) { // If the object is a byte array, skip serializing it and use a special metadata to @@ -153,7 +153,7 @@ public void put(UniqueId id, Object object) { * @param id Id of the object. * @param serializedObject The serialized object to put. */ - public void putSerialized(UniqueId id, byte[] serializedObject) { + public void putSerialized(ObjectId id, byte[] serializedObject) { try { objectStore.get().put(id.getBytes(), serializedObject, null); } catch (DuplicateObjectException e) { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 640789c3b0aa..fe1f61d0bc11 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -17,6 +17,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.Worker; @@ -33,7 +35,7 @@ public class MockRayletClient implements RayletClient { private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class); - private final Map> waitingTasks = new ConcurrentHashMap<>(); + private final Map> waitingTasks = new ConcurrentHashMap<>(); private final MockObjectStore store; private final RayDevRuntime runtime; private final ExecutorService exec; @@ -52,7 +54,7 @@ public MockRayletClient(RayDevRuntime runtime, int numberThreads) { currentWorker = new ThreadLocal<>(); } - public synchronized void onObjectPut(UniqueId id) { + public synchronized void onObjectPut(ObjectId id) { Set tasks = waitingTasks.get(id); if (tasks != null) { waitingTasks.remove(id); @@ -98,7 +100,7 @@ private void returnWorker(Worker worker) { @Override public synchronized void submitTask(TaskSpec task) { LOGGER.debug("Submitting task: {}.", task); - Set unreadyObjects = getUnreadyObjects(task); + Set unreadyObjects = getUnreadyObjects(task); if (unreadyObjects.isEmpty()) { // If all dependencies are ready, execute this task. exec.submit(() -> { @@ -109,7 +111,7 @@ public synchronized void submitTask(TaskSpec task) { // put the dummy object in object store, so those tasks which depends on it // can be executed. 
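The waiting-task bookkeeping used by this mock raylet can be summarised with a short sketch (illustrative; helper names other than onObjectPut and submitTask are hypothetical): a task blocked on unready ObjectIds is parked under each missing id and resubmitted once MockObjectStore reports that id via onObjectPut.

Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();

void park(TaskSpec task, Set<ObjectId> unready) {          // hypothetical helper
  for (ObjectId missing : unready) {
    waitingTasks.computeIfAbsent(missing, k -> new HashSet<>()).add(task);
  }
}

void onObjectPut(ObjectId id) {
  Set<TaskSpec> unblocked = waitingTasks.remove(id);
  if (unblocked != null) {
    unblocked.forEach(this::submitTask);                   // re-run the scheduling check
  }
}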
if (task.isActorCreationTask() || task.isActorTask()) { - UniqueId[] returnIds = task.returnIds; + ObjectId[] returnIds = task.returnIds; store.put(returnIds[returnIds.length - 1].getBytes(), new byte[]{}, new byte[]{}); } @@ -119,14 +121,14 @@ public synchronized void submitTask(TaskSpec task) { }); } else { // If some dependencies aren't ready yet, put this task in waiting list. - for (UniqueId id : unreadyObjects) { + for (ObjectId id : unreadyObjects) { waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task); } } } - private Set getUnreadyObjects(TaskSpec spec) { - Set unreadyObjects = new HashSet<>(); + private Set getUnreadyObjects(TaskSpec spec) { + Set unreadyObjects = new HashSet<>(); // Check whether task arguments are ready. for (FunctionArg arg : spec.args) { if (arg.id != null) { @@ -136,7 +138,7 @@ private Set getUnreadyObjects(TaskSpec spec) { } } // Check whether task dependencies are ready. - for (UniqueId id : spec.getExecutionDependencies()) { + for (ObjectId id : spec.getExecutionDependencies()) { if (!store.isObjectReady(id)) { unreadyObjects.add(id); } @@ -151,24 +153,24 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { - return UniqueId.randomId(); + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { + return TaskId.randomId(); } @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId currentTaskId) { if (waitFor == null || waitFor.isEmpty()) { return new WaitResult<>(ImmutableList.of(), ImmutableList.of()); } @@ -191,9 +193,9 @@ public WaitResult wait(List> waitFor, int numReturns, int } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - for (UniqueId id : objectIds) { + for (ObjectId id : objectIds) { store.free(id); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 19db27f6d900..4a78fde9430e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,8 @@ import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -15,16 +17,16 @@ public interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, TaskId currentTaskId); - void notifyUnblocked(UniqueId currentTaskId); + void notifyUnblocked(TaskId currentTaskId); - UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex); + TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex); WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId); + timeoutMs, TaskId currentTaskId); - void freePlasmaObjects(List 
objectIds, boolean localOnly, boolean deleteCreatingTasks); + void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); UniqueId prepareCheckpoint(UniqueId actorId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index b46d6b611a8e..b4bfa5a7fd47 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -11,6 +11,8 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.generated.Arg; @@ -20,7 +22,7 @@ import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,18 +52,18 @@ public RayletClientImpl(String schedulerSockName, UniqueId clientId, @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId currentTaskId) { Preconditions.checkNotNull(waitFor); if (waitFor.isEmpty()) { return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); } - List ids = new ArrayList<>(); + List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); } - boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), + boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids), numReturns, timeoutMs, false, currentTaskId.getBytes()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -101,31 +103,31 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocked on objects for task {}, object IDs are {}", - UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); + objectIds.get(0).getTaskId(), objectIds); } - nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + nativeFetchOrReconstruct(client, IdUtil.getIdBytes(objectIds), fetchOnly, currentTaskId.getBytes()); } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex); - return new UniqueId(bytes); + return new TaskId(bytes); } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { nativeNotifyUnblocked(client, currentTaskId.getBytes()); } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); + byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds); nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); } @@ -144,8 +146,8 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { 
bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); - UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer()); - UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); + TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer()); + TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); int parentCounter = info.parentCounter(); UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); int maxActorReconstructions = info.maxActorReconstructions(); @@ -154,7 +156,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { int actorCounter = info.actorCounter(); // Deserialize new actor handles - UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer( info.newActorHandlesAsByteBuffer()); // Deserialize args @@ -166,8 +168,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { if (objectIdsLength > 0) { Preconditions.checkArgument(objectIdsLength == 1, "This arg has more than one id: {}", objectIdsLength); - UniqueId id = UniqueIdUtil.getUniqueIdsFromByteBuffer(arg.objectIdsAsByteBuffer())[0]; - args[i] = FunctionArg.passByReference(id); + args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer())); } else { ByteBuffer lbb = arg.dataAsByteBuffer(); Preconditions.checkState(lbb != null && lbb.remaining() > 0); @@ -177,7 +178,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { } } // Deserialize return ids - UniqueId[] returnIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(info.returnsAsByteBuffer()); + ObjectId[] returnIds = IdUtil.getObjectIdsFromByteBuffer(info.returnsAsByteBuffer()); // Deserialize required resources; Map resources = new HashMap<>(); @@ -213,7 +214,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { // Serialize the new actor handles. 
int newActorHandlesOffset - = fbb.createString(UniqueIdUtil.concatUniqueIds(task.newActorHandles)); + = fbb.createString(IdUtil.concatIds(task.newActorHandles)); // Serialize args int[] argsOffsets = new int[task.args.length]; @@ -222,7 +223,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int dataOffset = 0; if (task.args[i].id != null) { objectIdOffset = fbb.createString( - UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id})); + IdUtil.concatIds(new ObjectId[]{task.args[i].id})); } else { objectIdOffset = fbb.createString(""); } @@ -234,7 +235,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int argsOffset = fbb.createVectorOfTables(argsOffsets); // Serialize returns - int returnsOffset = fbb.createString(UniqueIdUtil.concatUniqueIds(task.returnIds)); + int returnsOffset = fbb.createString(IdUtil.concatIds(task.returnIds)); // Serialize required resources // The required_resources vector indicates the quantities of the different diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 1da6dec31eb1..52447cf79334 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -5,7 +5,7 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; @@ -24,7 +24,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { FunctionArg[] ret = new FunctionArg[args.length]; for (int i = 0; i < ret.length; i++) { Object arg = args[i]; - UniqueId id = null; + ObjectId id = null; byte[] data = null; if (arg == null) { data = Serializer.encode(null); @@ -59,7 +59,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { */ public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) { Object[] realArgs = new Object[task.args.length]; - List idsToFetch = new ArrayList<>(); + List idsToFetch = new ArrayList<>(); List indices = new ArrayList<>(); for (int i = 0; i < task.args.length; i++) { FunctionArg arg = task.args[i]; diff --git a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java index 19a16e872b55..95bdcb0da653 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java @@ -1,6 +1,6 @@ package org.ray.runtime.task; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents a function argument in task spec. @@ -12,13 +12,13 @@ public class FunctionArg { /** * The id of this argument (passed by reference). */ - public final UniqueId id; + public final ObjectId id; /** * Serialized data of this argument (passed by value). */ public final byte[] data; - private FunctionArg(UniqueId id, byte[] data) { + private FunctionArg(ObjectId id, byte[] data) { this.id = id; this.data = data; } @@ -26,7 +26,7 @@ private FunctionArg(UniqueId id, byte[] data) { /** * Create a FunctionArg that will be passed by reference. 
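To make the two argument modes concrete, a small sketch (not part of the patch; the by-value path is only implied by the data field and is an assumption here):

ObjectId argId = Ray.put(new int[1024 * 1024]).getId();         // a large argument stays in the object store
FunctionArg byReference = FunctionArg.passByReference(argId);   // byReference.id != null, byReference.data == null
// Small arguments are instead serialized into FunctionArg.data and shipped inline with the task spec.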
*/ - public static FunctionArg passByReference(UniqueId id) { + public static FunctionArg passByReference(ObjectId id) { return new FunctionArg(id, null); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index d8f715ce6a76..8a98e11c61ae 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -5,6 +5,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; @@ -19,10 +21,10 @@ public class TaskSpec { public final UniqueId driverId; // Task ID of the task. - public final UniqueId taskId; + public final TaskId taskId; // Task ID of the parent task. - public final UniqueId parentTaskId; + public final TaskId parentTaskId; // A count of the number of tasks submitted by the parent task before this one. public final int parentCounter; @@ -49,7 +51,7 @@ public class TaskSpec { public final FunctionArg[] args; // return ids - public final UniqueId[] returnIds; + public final ObjectId[] returnIds; // The task's resource demands. public final Map resources; @@ -62,7 +64,7 @@ public class TaskSpec { // is Python, the type is PyFunctionDescriptor. private final FunctionDescriptor functionDescriptor; - private List executionDependencies; + private List executionDependencies; public boolean isActorTask() { return !actorId.isNil(); @@ -74,8 +76,8 @@ public boolean isActorCreationTask() { public TaskSpec( UniqueId driverId, - UniqueId taskId, - UniqueId parentTaskId, + TaskId taskId, + TaskId parentTaskId, int parentCounter, UniqueId actorCreationId, int maxActorReconstructions, @@ -84,7 +86,7 @@ public TaskSpec( int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args, - UniqueId[] returnIds, + ObjectId[] returnIds, Map resources, TaskLanguage language, FunctionDescriptor functionDescriptor) { @@ -125,7 +127,7 @@ public PyFunctionDescriptor getPyFunctionDescriptor() { return (PyFunctionDescriptor) functionDescriptor; } - public List getExecutionDependencies() { + public List getExecutionDependencies() { return executionDependencies; } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java similarity index 64% rename from java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java rename to java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index fa8b51ffaac8..62c56d17ceed 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -3,19 +3,20 @@ import com.google.common.base.Preconditions; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.Arrays; import java.util.List; +import org.ray.api.id.BaseId; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; /** - * Helper method for UniqueId. + * Helper method for different Ids. 
* Note: any changes to these methods must be synced with C++ helper functions * in src/ray/id.h */ -public class UniqueIdUtil { - public static final int OBJECT_INDEX_POS = 0; - public static final int OBJECT_INDEX_LENGTH = 4; +public class IdUtil { + public static final int OBJECT_INDEX_POS = 16; /** * Compute the object ID of an object returned by the task. @@ -24,7 +25,7 @@ public class UniqueIdUtil { * @param returnIndex What number return value this object is in the task. * @return The computed object ID. */ - public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { + public static ObjectId computeReturnId(TaskId taskId, int returnIndex) { return computeObjectId(taskId, returnIndex); } @@ -34,14 +35,13 @@ public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { * @param index The index which can distinguish different objects in one task. * @return The computed object ID. */ - private static UniqueId computeObjectId(UniqueId taskId, int index) { - byte[] objId = new byte[UniqueId.LENGTH]; - System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH); - ByteBuffer wbb = ByteBuffer.wrap(objId); + private static ObjectId computeObjectId(TaskId taskId, int index) { + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(taskId.getBytes(), 0, bytes, 0, taskId.size()); + ByteBuffer wbb = ByteBuffer.wrap(bytes); wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index); - - return new UniqueId(objId); + wbb.putInt(OBJECT_INDEX_POS, index); + return new ObjectId(bytes); } /** @@ -51,26 +51,11 @@ private static UniqueId computeObjectId(UniqueId taskId, int index) { * @param putIndex What number put this object was created by in the task. * @return The computed object ID. */ - public static UniqueId computePutId(UniqueId taskId, int putIndex) { + public static ObjectId computePutId(TaskId taskId, int putIndex) { // We multiply putIndex by -1 to distinguish from returnIndex. return computeObjectId(taskId, -1 * putIndex); } - /** - * Compute the task ID of the task that created the object. - * - * @param objectId The object ID. - * @return The task ID of the task that created this object. - */ - public static UniqueId computeTaskId(UniqueId objectId) { - byte[] taskId = new byte[UniqueId.LENGTH]; - System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH); - Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS, - UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0); - - return new UniqueId(taskId); - } - /** * Generate the return ids of a task. * @@ -78,15 +63,15 @@ public static UniqueId computeTaskId(UniqueId objectId) { * @param numReturns The number of returnIds. * @return The Return Ids of this task. 
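The new id layout implemented by computeObjectId above can be illustrated with the values used in UniqueIdTest later in this patch: the 16-byte task id forms the prefix of the 20-byte object id, and the return or (negated) put index is written little-endian into the trailing four bytes at OBJECT_INDEX_POS = 16.

TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00");
ObjectId firstReturn = IdUtil.computeReturnId(taskId, 1);
// firstReturn.toString() -> "123456789abcdef123456789abcdef0001000000"
ObjectId firstPut = IdUtil.computePutId(taskId, 1);        // put indices are negated: -1 -> 0xFFFFFFFF
// firstPut.toString()    -> "123456789abcdef123456789abcdef00ffffffff"
assert firstReturn.getTaskId().equals(taskId) && firstPut.getTaskId().equals(taskId);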
*/ - public static UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { - UniqueId[] ret = new UniqueId[numReturns]; + public static ObjectId[] genReturnIds(TaskId taskId, int numReturns) { + ObjectId[] ret = new ObjectId[numReturns]; for (int i = 0; i < numReturns; i++) { - ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1); + ret[i] = IdUtil.computeReturnId(taskId, i + 1); } return ret; } - public static byte[][] getIdBytes(List objectIds) { + public static byte[][] getIdBytes(List objectIds) { int size = objectIds.size(); byte[][] ids = new byte[size][]; for (int i = 0; i < size; i++) { @@ -95,6 +80,24 @@ public static byte[][] getIdBytes(List objectIds) { return ids; } + public static byte[][] getByteListFromByteBuffer(ByteBuffer byteBufferOfIds, int length) { + Preconditions.checkArgument(byteBufferOfIds != null); + + byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; + byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + + int count = bytesOfIds.length / length; + byte[][] idBytes = new byte[count][]; + + for (int i = 0; i < count; ++i) { + byte[] id = new byte[length]; + System.arraycopy(bytesOfIds, i * length, id, 0, length); + idBytes[i] = id; + } + + return idBytes; + } + /** * Get unique IDs from concatenated ByteBuffer. * @@ -102,21 +105,31 @@ public static byte[][] getIdBytes(List objectIds) { * @return The array of unique IDs. */ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { - Preconditions.checkArgument(byteBufferOfIds != null); + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + UniqueId[] uniqueIds = new UniqueId[idBytes.length]; - byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; - byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + for (int i = 0; i < idBytes.length; ++i) { + uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); + } + + return uniqueIds; + } - int count = bytesOfIds.length / UniqueId.LENGTH; - UniqueId[] uniqueIds = new UniqueId[count]; + /** + * Get object IDs from concatenated ByteBuffer. + * + * @param byteBufferOfIds The ByteBuffer concatenated from IDs. + * @return The array of object IDs. + */ + public static ObjectId[] getObjectIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + ObjectId[] objectIds = new ObjectId[idBytes.length]; - for (int i = 0; i < count; ++i) { - byte[] id = new byte[UniqueId.LENGTH]; - System.arraycopy(bytesOfIds, i * UniqueId.LENGTH, id, 0, UniqueId.LENGTH); - uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(id)); + for (int i = 0; i < idBytes.length; ++i) { + objectIds[i] = ObjectId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); } - return uniqueIds; + return objectIds; } /** @@ -125,11 +138,15 @@ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) * @param ids The array of IDs that will be concatenated. * @return A ByteBuffer that contains bytes of concatenated IDs. 
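The concatenation helpers here are inverses of each other; a round-trip sketch mirroring testUniqueIdsAndByteBufferInterConversion later in this patch:

ObjectId[] ids = {ObjectId.randomId(), ObjectId.randomId()};
ByteBuffer packed = IdUtil.concatIds(ids);                   // 2 * 20 bytes, packed back to back
ObjectId[] unpacked = IdUtil.getObjectIdsFromByteBuffer(packed);
assert ids[0].equals(unpacked[0]) && ids[1].equals(unpacked[1]);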
*/ - public static ByteBuffer concatUniqueIds(UniqueId[] ids) { - byte[] bytesOfIds = new byte[UniqueId.LENGTH * ids.length]; + public static ByteBuffer concatIds(T[] ids) { + int length = 0; + if (ids != null && ids.length != 0) { + length = ids[0].size() * ids.length; + } + byte[] bytesOfIds = new byte[length]; for (int i = 0; i < ids.length; ++i) { System.arraycopy(ids[i].getBytes(), 0, bytesOfIds, - i * UniqueId.LENGTH, UniqueId.LENGTH); + i * ids[i].size(), ids[i].size()); } return ByteBuffer.wrap(bytesOfIds); @@ -139,8 +156,8 @@ public static ByteBuffer concatUniqueIds(UniqueId[] ids) { /** * Compute the murmur hash code of this ID. */ - public static long murmurHashCode(UniqueId id) { - return murmurHash64A(id.getBytes(), UniqueId.LENGTH, 0); + public static long murmurHashCode(BaseId id) { + return murmurHash64A(id.getBytes(), id.size(), 0); } /** diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index b588822712c5..227ff7e5865b 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -6,7 +6,7 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.exception.RayException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayObjectImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +20,7 @@ public class ClientExceptionTest extends BaseTest { @Test public void testWaitAndCrash() { TestUtils.skipTestUnderSingleProcess(); - UniqueId randomId = UniqueId.randomId(); + ObjectId randomId = ObjectId.randomId(); RayObject notExisting = new RayObjectImpl(randomId); Thread thread = new Thread(() -> { diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java index eaa99a2892fd..be584ba6d1be 100644 --- a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testPutAndGet() { @Test public void testGetMultipleObjects() { List ints = ImmutableList.of(1, 2, 3, 4, 5); - List ids = ints.stream().map(obj -> Ray.put(obj).getId()) + List ids = ints.stream().map(obj -> Ray.put(obj).getId()) .collect(Collectors.toList()); Assert.assertEquals(ints, Ray.get(ids)); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 1e344e5028b3..3c36f2201a8b 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -6,7 +6,6 @@ import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.util.UniqueIdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -38,7 +37,7 @@ public void testDeleteCreatingTasks() { final boolean result = TestUtils.waitForCondition( () -> !(((AbstractRayRuntime)Ray.internal()).getGcsClient()) - .rayletTaskExistsInGcs(UniqueIdUtil.computeTaskId(helloId.getId())), 50); + .rayletTaskExistsInGcs(helloId.getId().getTaskId()), 50); 
Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index b5bf1356ea4f..e2efecbf222e 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -7,7 +7,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testSubmittingTasks() { for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) { int numTasks = 1000 / numIterations; for (int i = 0; i < numIterations; i++) { - List resultIds = new ArrayList<>(); + List resultIds = new ArrayList<>(); for (int j = 0; j < numTasks; j++) { resultIds.add(Ray.call(StressTest::echo, 1).getId()); } @@ -60,7 +60,7 @@ public Worker(RayActor actor) { } public int ping(int n) { - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < n; i++) { objectIds.add(Ray.call(Actor::ping, actor).getId()); } @@ -76,7 +76,7 @@ public int ping(int n) { public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < 10; i++) { RayActor worker = Ray.createActor(Worker::new, actor); objectIds.add(Ray.call(Worker::ping, worker, 100).getId()); diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 5b3d773dbf2c..cc1bc7a53f3e 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -3,8 +3,10 @@ import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -42,7 +44,7 @@ public void testConstructUniqueId() { // Test `genNil()` - UniqueId id6 = UniqueId.genNil(); + UniqueId id6 = UniqueId.NIL; Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString()); Assert.assertTrue(id6.isNil()); } @@ -50,33 +52,33 @@ public void testConstructUniqueId() { @Test public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. 
- UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId returnId = UniqueIdUtil.computeReturnId(taskId, 1); - Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); + ObjectId returnId = IdUtil.computeReturnId(taskId, 1); + Assert.assertEquals("123456789abcdef123456789abcdef0001000000", returnId.toString()); - returnId = UniqueIdUtil.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); + returnId = IdUtil.computeReturnId(taskId, 0x01020304); + Assert.assertEquals("123456789abcdef123456789abcdef0004030201", returnId.toString()); } @Test public void testComputeTaskId() { - UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); - UniqueId taskId = UniqueIdUtil.computeTaskId(objId); + ObjectId objId = ObjectId.fromHexString("123456789ABCDEF123456789ABCDEF0034421980"); + TaskId taskId = objId.getTaskId(); - Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); + Assert.assertEquals("123456789abcdef123456789abcdef00", taskId.toString()); } @Test public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. - UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId putId = UniqueIdUtil.computePutId(taskId, 1); - Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + ObjectId putId = IdUtil.computePutId(taskId, 1); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FFFFFFFF".toLowerCase(), putId.toString()); - putId = UniqueIdUtil.computePutId(taskId, 0x01020304); - Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + putId = IdUtil.computePutId(taskId, 0x01020304); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FCFCFDFE".toLowerCase(), putId.toString()); } @Test @@ -87,8 +89,8 @@ public void testUniqueIdsAndByteBufferInterConversion() { ids[i] = UniqueId.randomId(); } - ByteBuffer temp = UniqueIdUtil.concatUniqueIds(ids); - UniqueId[] res = UniqueIdUtil.getUniqueIdsFromByteBuffer(temp); + ByteBuffer temp = IdUtil.concatIds(ids); + UniqueId[] res = IdUtil.getUniqueIdsFromByteBuffer(temp); for (int i = 0; i < len; ++i) { Assert.assertEquals(ids[i], res[i]); @@ -98,8 +100,28 @@ public void testUniqueIdsAndByteBufferInterConversion() { @Test void testMurmurHash() { UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); - long remainder = Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(id), 1000000000); + long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); Assert.assertEquals(remainder, 787616861); } + @Test + void testConcateIds() { + String taskHexStr = "123456789ABCDEF123456789ABCDEF00"; + String objectHexStr = taskHexStr + "01020304"; + ObjectId objectId1 = ObjectId.fromHexString(objectHexStr); + ObjectId objectId2 = ObjectId.fromHexString(objectHexStr); + TaskId[] taskIds = new TaskId[2]; + taskIds[0] = objectId1.getTaskId(); + taskIds[1] = objectId2.getTaskId(); + ObjectId[] objectIds = new ObjectId[2]; + objectIds[0] = objectId1; + objectIds[1] = objectId2; + String taskHexCompareStr = taskHexStr + taskHexStr; + String objectHexCompareStr = objectHexStr + objectHexStr; + Assert.assertEquals(DatatypeConverter.printHexBinary( + 
IdUtil.concatIds(taskIds).array()), taskHexCompareStr); + Assert.assertEquals(DatatypeConverter.printHexBinary( + IdUtil.concatIds(objectIds).array()), objectHexCompareStr); + } + } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bae62f9b1c88..a5f106f1e911 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -88,11 +88,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(ComputePutId(task_id.native(), put_index).binary()) + return ObjectID(CObjectID.for_put(task_id.native(), put_index).binary()) def compute_task_id(ObjectID object_id): - return TaskID(ComputeTaskId(object_id.native()).binary()) + return TaskID(object_id.native().task_id().binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): diff --git a/python/ray/actor.py b/python/ray/actor.py index 7c24208028b4..e806a5f8fae3 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -17,7 +17,6 @@ import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import _random_string from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID, DriverID) @@ -308,7 +307,7 @@ def _remote(self, raise Exception("Actors cannot be created before ray.init() " "has been called.") - actor_id = ActorID(_random_string()) + actor_id = ActorID.from_random() # The actor cursor is a dummy object representing the most recent # actor method invocation. For each subsequent method invocation, # the current cursor should be added as a dependency, and then @@ -670,7 +669,7 @@ def _serialization_helper(self, ray_forking): # to release, since it could be unpickled and submit another # dependent task at any time. Therefore, we notify the backend of a # random handle ID that will never actually be used. - new_actor_handle_id = ActorHandleID(_random_string()) + new_actor_handle_id = ActorHandleID.from_random() # Notify the backend to expect this new actor handle. The backend will # not release the cursor for any new handles until the first task for # each of the new handles is submitted. 
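With _random_string() removed, each Python ID class now draws randomness sized to its concrete type. Ignoring the Cython plumbing, from_random() roughly amounts to the sketch below; the size constants mirror kUniqueIDSize and kTaskIDSize from src/ray/constants.h, and random_id is a hypothetical helper, not part of the ray API:

    import os

    UNIQUE_ID_SIZE = 20  # kUniqueIDSize: full-length Ray IDs
    TASK_ID_SIZE = 16    # kUniqueIDSize minus the 4-byte object index

    def random_id(size=UNIQUE_ID_SIZE):
        # from_random() boils down to size() bytes of OS-level entropy, which
        # is also what makes the IDs fork-safe (no inherited RNG state).
        return os.urandom(size)

    actor_id = random_id()             # e.g. ActorID.from_random()
    task_id = random_id(TASK_ID_SIZE)  # e.g. TaskID.from_random()
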
@@ -780,7 +779,7 @@ def __ray_checkpoint__(self): Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ - class_id = ActorClassID(_random_string()) + class_id = ActorClassID.from_random() return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus, resources) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 3b6463fc9ea6..bdb4316fcc4e 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -81,15 +81,9 @@ cdef extern from "ray/status.h" namespace "ray::StatusCode" nogil: cdef extern from "ray/id.h" namespace "ray" nogil: - const CTaskID FinishTaskId(const CTaskID &task_id) - const CObjectID ComputeReturnId(const CTaskID &task_id, - int64_t return_index) - const CObjectID ComputePutId(const CTaskID &task_id, int64_t put_index) - const CTaskID ComputeTaskId(const CObjectID &object_id) const CTaskID GenerateTaskId(const CDriverID &driver_id, const CTaskID &parent_task_id, int parent_task_counter) - int64_t ComputeObjectIndex(const CObjectID &object_id) cdef extern from "ray/gcs/format/gcs_generated.h" nogil: diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index a607b2a86419..fbe793cc023b 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -1,12 +1,35 @@ from libcpp cimport bool as c_bool from libcpp.string cimport string as c_string -from libc.stdint cimport uint8_t +from libc.stdint cimport uint8_t, int64_t cdef extern from "ray/id.h" namespace "ray" nogil: - cdef cppclass CUniqueID "ray::UniqueID": + cdef cppclass CBaseID[T]: + @staticmethod + T from_random() + + @staticmethod + T from_binary(const c_string &binary) + + @staticmethod + const T nil() + + @staticmethod + size_t size() + + size_t hash() const + c_bool is_nil() const + c_bool operator==(const CBaseID &rhs) const + c_bool operator!=(const CBaseID &rhs) const + const uint8_t *data() const; + + c_string binary() const; + c_string hex() const; + + cdef cppclass CUniqueID "ray::UniqueID"(CBaseID): CUniqueID() - CUniqueID(const c_string &binary) - CUniqueID(const CUniqueID &from_id) + + @staticmethod + size_t size() @staticmethod CUniqueID from_random() @@ -17,15 +40,8 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod const CUniqueID nil() - size_t hash() const - c_bool is_nil() const - c_bool operator==(const CUniqueID& rhs) const - c_bool operator!=(const CUniqueID& rhs) const - const uint8_t *data() const - uint8_t *mutable_data() - size_t size() const - c_string binary() const - c_string hex() const + @staticmethod + size_t size() cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): @@ -67,16 +83,40 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod CDriverID from_binary(const c_string &binary) - cdef cppclass CTaskID "ray::TaskID"(CUniqueID): + cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): @staticmethod CTaskID from_binary(const c_string &binary) - cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): + @staticmethod + const CTaskID nil() + + @staticmethod + size_t size() + + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): @staticmethod CObjectID from_binary(const c_string &binary) + @staticmethod + const CObjectID nil() + + @staticmethod + CObjectID for_put(const CTaskID &task_id, int64_t index); + + @staticmethod + CObjectID for_task_return(const CTaskID &task_id, int64_t index); + + @staticmethod + size_t size() + + c_bool is_put() + + int64_t object_index() const + + CTaskID task_id() const + 
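The CObjectID declarations above reflect the new ID layout: an ObjectID is the 16-byte TaskID of its creating task followed by a signed 32-bit creation index, positive for task returns and negative for puts. A small sketch of the encode direction, assuming the little-endian byte order implied by the UniqueIdTest vectors earlier in this patch (for_task_return and for_put here are illustrative helpers, not the Cython bindings):

    import struct

    TASK_ID_SIZE = 16  # kTaskIDSize

    def for_task_return(task_id, return_index):
        # ObjectID = 16-byte TaskID prefix + signed 32-bit index (little-endian).
        return task_id + struct.pack("<i", return_index)

    def for_put(task_id, put_index):
        # Puts negate the index so they stay distinguishable from returns.
        return task_id + struct.pack("<i", -put_index)

    tid = bytes.fromhex("123456789abcdef123456789abcdef00")
    assert for_task_return(tid, 1).hex() == tid.hex() + "01000000"
    assert for_put(tid, 0x01020304).hex() == tid.hex() + "fcfcfdfe"
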
cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): @staticmethod diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index c96668f2bf07..b9773d56fb20 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -6,10 +6,8 @@ See https://github.com/ray-project/ray/issues/3721. # WARNING: Any additional ID types defined in this file must be added to the # _ID_TYPES list at the bottom of this file. -from ray.includes.common cimport ( - ComputePutId, - ComputeTaskId, -) +import os + from ray.includes.unique_ids cimport ( CActorCheckpointID, CActorClassID, @@ -28,12 +26,12 @@ from ray.includes.unique_ids cimport ( from ray.utils import decode -def check_id(b): +def check_id(b, size=kUniqueIDSize): if not isinstance(b, bytes): raise TypeError("Unsupported type: " + str(type(b))) - if len(b) != kUniqueIDSize: + if len(b) != size: raise ValueError("ID string needs to have length " + - str(kUniqueIDSize)) + str(size)) cdef extern from "ray/constants.h" nogil: @@ -41,28 +39,27 @@ cdef extern from "ray/constants.h" nogil: cdef int64_t kMaxTaskPuts -cdef class UniqueID: - cdef CUniqueID data +cdef class BaseID: - def __init__(self, id): - check_id(id) - self.data = CUniqueID.from_binary(id) + # To avoid the error of "Python int too large to convert to C ssize_t", + # here `cdef size_t` is required. + cdef size_t hash(self): + pass - @classmethod - def from_binary(cls, id_bytes): - if not isinstance(id_bytes, bytes): - raise TypeError("Expect bytes, got " + str(type(id_bytes))) - return cls(id_bytes) + def binary(self): + pass - @classmethod - def nil(cls): - return cls(CUniqueID.nil().binary()) + def size(self): + pass - def __hash__(self): - return self.data.hash() + def hex(self): + pass def is_nil(self): - return self.data.is_nil() + pass + + def __hash__(self): + return self.hash() def __eq__(self, other): return type(self) == type(other) and self.binary() == other.binary() @@ -70,18 +67,9 @@ cdef class UniqueID: def __ne__(self, other): return self.binary() != other.binary() - def size(self): - return self.data.size() - - def binary(self): - return self.data.binary() - def __bytes__(self): return self.binary() - def hex(self): - return decode(self.data.hex()) - def __hex__(self): return self.hex() @@ -98,11 +86,52 @@ cdef class UniqueID: # NOTE: The hash function used here must match the one in # GetRedisContext in src/ray/gcs/tables.h. Changes to the # hash function should only be made through std::hash in - # src/common/common.h + # src/common/common.h. + # Do not use __hash__ that returns signed uint64_t, which + # is different from std::hash in c++ code. 
+ return self.hash() + + +cdef class UniqueID(BaseID): + cdef CUniqueID data + + def __init__(self, id): + check_id(id) + self.data = CUniqueID.from_binary(id) + + @classmethod + def from_binary(cls, id_bytes): + if not isinstance(id_bytes, bytes): + raise TypeError("Expect bytes, got " + str(type(id_bytes))) + return cls(id_bytes) + + @classmethod + def nil(cls): + return cls(CUniqueID.nil().binary()) + + + @classmethod + def from_random(cls): + return cls(os.urandom(CUniqueID.size())) + + def size(self): + return CUniqueID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): return self.data.hash() -cdef class ObjectID(UniqueID): +cdef class ObjectID(BaseID): + cdef CObjectID data def __init__(self, id): check_id(id) @@ -111,16 +140,67 @@ cdef class ObjectID(UniqueID): cdef CObjectID native(self): return self.data + def size(self): + return CObjectID.size() -cdef class TaskID(UniqueID): + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CObjectID.nil().binary()) + + @classmethod + def from_random(cls): + return cls(os.urandom(CObjectID.size())) + + +cdef class TaskID(BaseID): + cdef CTaskID data def __init__(self, id): - check_id(id) + check_id(id, CTaskID.size()) self.data = CTaskID.from_binary(id) cdef CTaskID native(self): return self.data + def size(self): + return CTaskID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CTaskID.nil().binary()) + + @classmethod + def size(cla): + return CTaskID.size() + + @classmethod + def from_random(cls): + return cls(os.urandom(CTaskID.size())) + cdef class ClientID(UniqueID): diff --git a/python/ray/monitor.py b/python/ray/monitor.py index ded86611e88c..cc6432cbc8de 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -16,8 +16,8 @@ import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, - setup_logger) +from ray.utils import (binary_to_hex, binary_to_object_id, binary_to_task_id, + hex_to_binary, setup_logger) logger = logging.getLogger(__name__) @@ -169,8 +169,12 @@ def _xray_clean_up_entries_for_driver(self, driver_id): driver_object_id_bins.add(object_id.binary()) def to_shard_index(id_bin): - return binary_to_object_id(id_bin).redis_shard_hash() % len( - self.state.redis_clients) + if len(id_bin) == ray.TaskID.size(): + return binary_to_task_id(id_bin).redis_shard_hash() % len( + self.state.redis_clients) + else: + return binary_to_object_id(id_bin).redis_shard_hash() % len( + self.state.redis_clients) # Form the redis keys to delete. 
sharded_keys = [[] for _ in range(len(self.state.redis_clients))] diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 3f8c7cb2b3a1..ffd0fb630e80 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import json import logging +from multiprocessing import Process import os import random import re @@ -28,7 +29,6 @@ import ray import ray.tests.cluster_utils import ray.tests.utils -from ray.utils import _random_string logger = logging.getLogger(__name__) @@ -2630,12 +2630,33 @@ def test_object_id_properties(): ray.ObjectID(id_bytes + b"1234") with pytest.raises(ValueError, match=r".*needs to have length 20.*"): ray.ObjectID(b"0123456789") - object_id = ray.ObjectID(_random_string()) + object_id = ray.ObjectID.from_random() assert not object_id.is_nil() assert object_id.binary() != id_bytes id_dumps = pickle.dumps(object_id) id_from_dumps = pickle.loads(id_dumps) assert id_from_dumps == object_id + file_prefix = "test_object_id_properties" + + # Make sure the ids are fork safe. + def write(index): + str = ray.ObjectID.from_random().hex() + with open("{}{}".format(file_prefix, index), "w") as fo: + fo.write(str) + + def read(index): + with open("{}{}".format(file_prefix, index), "r") as fi: + for line in fi: + return line + + processes = [Process(target=write, args=(_, )) for _ in range(4)] + for process in processes: + process.start() + for process in processes: + process.join() + hexes = {read(i) for i in range(4)} + [os.remove("{}{}".format(file_prefix, i)) for i in range(4)] + assert len(hexes) == 4 @pytest.fixture @@ -2768,7 +2789,7 @@ def test_pandas_parquet_serialization(): def test_socket_dir_not_existing(shutdown_only): - random_name = ray.ObjectID(_random_string()).hex() + random_name = ray.ObjectID.from_random().hex() temp_raylet_socket_dir = "/tmp/ray/tests/{}".format(random_name) temp_raylet_socket_name = os.path.join(temp_raylet_socket_dir, "raylet_socket") diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 8fb58e576ea1..650cce68b246 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -15,7 +15,6 @@ import ray import ray.ray_constants as ray_constants -from ray.utils import _random_string from ray.tests.cluster_utils import Cluster from ray.tests.utils import ( relevant_errors, @@ -667,7 +666,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes): def test_raylet_crash_when_get(ray_start_regular): - nonexistent_id = ray.ObjectID(_random_string()) + nonexistent_id = ray.ObjectID.from_random() def sleep_to_kill_raylet(): # Don't kill raylet before default workers get connected. diff --git a/python/ray/utils.py b/python/ray/utils.py index 0f26aa22d03a..7b87486e325e 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -216,6 +216,10 @@ def binary_to_object_id(binary_object_id): return ray.ObjectID(binary_object_id) +def binary_to_task_id(binary_task_id): + return ray.TaskID(binary_task_id) + + def binary_to_hex(identifier): hex_identifier = binascii.hexlify(identifier) if sys.version_info >= (3, 0): diff --git a/python/ray/worker.py b/python/ray/worker.py index bbcf1bb2235e..5feb71344bce 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -198,7 +198,7 @@ def task_context(self): # to the current task ID may not be correct. Generate a # random task ID so that the backend can differentiate # between different threads. 
- self._task_context.current_task_id = TaskID(_random_string()) + self._task_context.current_task_id = TaskID.from_random() if getattr(self, "_multithreading_warned", False) is not True: logger.warning( "Calling ray.get or ray.wait in a separate thread " @@ -1725,7 +1725,7 @@ def connect(node, else: # This is the code path of driver mode. if driver_id is None: - driver_id = DriverID(_random_string()) + driver_id = DriverID.from_random() if not isinstance(driver_id, DriverID): raise TypeError("The type of given driver id must be DriverID.") @@ -1834,6 +1834,7 @@ def connect(node, # Create an object store client. worker.plasma_client = thread_safe_client( plasma.connect(node.plasma_store_socket_name, None, 0, 300)) + driver_id_str = _random_string() # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1865,7 +1866,7 @@ def connect(node, function_descriptor.get_function_descriptor_list(), [], # arguments. 0, # num_returns. - TaskID(_random_string()), # parent_task_id. + TaskID(driver_id_str[:TaskID.size()]), # parent_task_id. 0, # parent_counter. ActorID.nil(), # actor_creation_id. ObjectID.nil(), # actor_creation_dummy_object_id. @@ -1894,7 +1895,7 @@ def connect(node, node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), - DriverID(worker.current_task_id.binary()), + DriverID(driver_id_str), ) # Start the import thread diff --git a/src/ray/constants.h b/src/ray/constants.h index 2035938be267..c92e6a74aa5d 100644 --- a/src/ray/constants.h +++ b/src/ray/constants.h @@ -4,7 +4,7 @@ #include #include -/// Length of Ray IDs in bytes. +/// Length of Ray full-length IDs in bytes. constexpr int64_t kUniqueIDSize = 20; /// An ObjectID's bytes are split into the task ID itself and the index of the @@ -13,6 +13,9 @@ constexpr int kObjectIdIndexSize = 32; static_assert(kObjectIdIndexSize % CHAR_BIT == 0, "ObjectID prefix not a multiple of bytes"); +/// Length of Ray TaskID in bytes. 32-bit integer is used for object index. +constexpr int64_t kTaskIDSize = kUniqueIDSize - kObjectIdIndexSize / 8; + /// The maximum number of objects that can be returned by a task when finishing /// execution. An ObjectID's bytes are split into the task ID itself and the /// index of the object's creation. A positive index indicates an object diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index f7e25a4873ab..7f69c482e5eb 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -89,7 +89,7 @@ void TestTableLookup(const DriverID &driver_id, data->task_specification = "123"; // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -104,7 +104,7 @@ void TestTableLookup(const DriverID &driver_id, }; // Check that the lookup does not return an empty entry. - auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; @@ -139,7 +139,7 @@ void TestLogLookup(const DriverID &driver_id, auto data = std::make_shared(); data->node_manager_id = node_manager_id; // Check that we added the correct object entries. 
- auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -150,7 +150,7 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { @@ -181,11 +181,11 @@ void TestTableLookupFailure(const DriverID &driver_id, TaskID task_id = TaskID::from_random(); // Check that the lookup does not return data. - auto lookup_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id, + auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id); test->Stop(); }; @@ -215,7 +215,7 @@ void TestLogAppendAt(const DriverID &driver_id, } // Check that we added the correct task. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id, + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); @@ -241,7 +241,7 @@ void TestLogAppendAt(const DriverID &driver_id, /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { @@ -271,7 +271,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli auto data = std::make_shared(); data->manager = manager; // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -297,7 +297,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli data->manager = manager; // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { + gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); @@ -338,7 +338,7 @@ void TestDeleteKeysFromLog( task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -350,7 +350,7 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. 
auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); @@ -386,7 +386,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -434,7 +434,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, object_id = ObjectID::from_random(); ids.push_back(object_id); // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -607,7 +607,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, } // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, - const UniqueID &id, + const DriverID &id, const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. @@ -657,7 +657,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsTableNotificationMode notification_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { @@ -752,7 +752,7 @@ void TestTableSubscribeId(const DriverID &driver_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const UniqueID &id) { + gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -962,7 +962,7 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The failure callback should not be called since all keys are non-empty // when notifications are requested. 
- auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index fe61df288d6b..6b03fa735007 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -226,45 +226,6 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { } } -Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, - RedisCallback redisCallback, int log_length) { - int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); - if (length > 0) { - if (log_length >= 0) { - std::string redis_command = command + " %d %d %b %b %d"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length, log_length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } else { - std::string redis_command = command + " %d %d %b %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - } else { - RAY_CHECK(log_length == -1); - std::string redis_command = command + " %d %d %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size()); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - return Status::OK(); -} - Status RedisContext::RunArgvAsync(const std::vector &args) { // Build the arguments. std::vector argv; diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 0af5a121e573..93a343464892 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -11,6 +11,12 @@ #include "ray/gcs/format/gcs_generated.h" +extern "C" { +#include "ray/thirdparty/hiredis/adapters/ae.h" +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" +} + struct redisContext; struct redisAsyncContext; struct aeEventLoop; @@ -22,6 +28,8 @@ namespace gcs { /// operation. using RedisCallback = std::function; +void GlobalRedisCallback(void *c, void *r, void *privdata); + class RedisCallbackManager { public: static RedisCallbackManager &instance() { @@ -83,7 +91,8 @@ class RedisContext { /// at which the data must be appended. For all other commands, set to /// -1 for unused. If set, then data must be provided. /// \return Status. 
- Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data, + template + Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, int64_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -113,6 +122,46 @@ class RedisContext { redisAsyncContext *subscribe_context_; }; +template +Status RedisContext::RunAsync(const std::string &command, const ID &id, + const uint8_t *data, int64_t length, + const TablePrefix prefix, const TablePubsub pubsub_channel, + RedisCallback redisCallback, int log_length) { + int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); + if (length > 0) { + if (log_length >= 0) { + std::string redis_command = command + " %d %d %b %b %d"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, length, log_length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } else { + std::string redis_command = command + " %d %d %b %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + } else { + RAY_CHECK(log_length == -1); + std::string redis_command = command + " %d %d %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size()); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 0405367e15f0..b9891e8cae32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -676,13 +676,15 @@ int TableDelete_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int size_t len = 0; const char *data_ptr = nullptr; data_ptr = RedisModule_StringPtrLen(data, &len); - REPLY_AND_RETURN_IF_FALSE( - len % kUniqueIDSize == 0, - "The deletion data length must be a multiple of the UniqueID size."); - size_t ids_to_delete = len / kUniqueIDSize; + // The first uint16_t are used to encode the number of ids to delete. 
+ size_t ids_to_delete = *reinterpret_cast(data_ptr); + size_t id_length = (len - sizeof(uint16_t)) / ids_to_delete; + REPLY_AND_RETURN_IF_FALSE((len - sizeof(uint16_t)) % ids_to_delete == 0, + "The deletion data length must be multiple of the ID size"); + data_ptr += sizeof(uint16_t); for (size_t i = 0; i < ids_to_delete; ++i) { RedisModuleString *id_data = - RedisModule_CreateString(ctx, data_ptr + i * kUniqueIDSize, kUniqueIDSize); + RedisModule_CreateString(ctx, data_ptr + i * id_length, id_length); RAY_IGNORE_EXPR(DeleteKeyHelper(ctx, prefix_str, id_data)); } return RedisModule_ReplyWithSimpleString(ctx, "OK"); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index dbd39349caf7..3d4708940d1a 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -192,15 +192,25 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids } // Breaking really large deletion commands into batches of smaller size. const size_t batch_size = - RayConfig::instance().maximum_gcs_deletion_batch_size() * kUniqueIDSize; + RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::size(); for (const auto &pair : sharded_data) { std::string current_data = pair.second.str(); for (size_t cur = 0; cur < pair.second.str().size(); cur += batch_size) { - RAY_IGNORE_EXPR(pair.first->RunAsync( - "RAY.TABLE_DELETE", UniqueID::nil(), - reinterpret_cast(current_data.c_str() + cur), - std::min(batch_size, current_data.size() - cur), prefix_, pubsub_channel_, - /*redisCallback=*/nullptr)); + size_t data_field_size = std::min(batch_size, current_data.size() - cur); + uint16_t id_count = data_field_size / ID::size(); + // Send data contains id count and all the id data. + std::string send_data(data_field_size + sizeof(id_count), 0); + uint8_t *buffer = reinterpret_cast(&send_data[0]); + *reinterpret_cast(buffer) = id_count; + RAY_IGNORE_EXPR( + std::copy_n(reinterpret_cast(current_data.c_str() + cur), + data_field_size, buffer + sizeof(uint16_t))); + + RAY_IGNORE_EXPR( + pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::nil(), + reinterpret_cast(send_data.c_str()), + send_data.size(), prefix_, pubsub_channel_, + /*redisCallback=*/nullptr)); } } } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 056bf7b97ec7..58a087d8c666 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -206,7 +206,7 @@ class Log : public LogInterface, virtual public PubsubInterface { protected: std::shared_ptr GetRedisContext(const ID &id) { - static std::hash index; + static std::hash index; return shard_contexts_[index(id) % shard_contexts_.size()]; } diff --git a/src/ray/id.cc b/src/ray/id.cc index 8d72cef8b300..a011430ad1cf 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -26,82 +26,16 @@ std::mt19937 RandomlySeededMersenneTwister() { uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); -UniqueID::UniqueID() { - // Set the ID to nil. 
- std::fill_n(id_, kUniqueIDSize, 255); -} - -UniqueID::UniqueID(const std::string &binary) { - std::memcpy(&id_, binary.data(), kUniqueIDSize); -} - -UniqueID::UniqueID(const plasma::UniqueID &from) { - std::memcpy(&id_, from.data(), kUniqueIDSize); -} - -UniqueID UniqueID::from_random() { - std::string data(kUniqueIDSize, 0); - // NOTE(pcm): The right way to do this is to have one std::mt19937 per - // thread (using the thread_local keyword), but that's not supported on - // older versions of macOS (see https://stackoverflow.com/a/29929949) - static std::mutex random_engine_mutex; - std::lock_guard lock(random_engine_mutex); - static std::mt19937 generator = RandomlySeededMersenneTwister(); - std::uniform_int_distribution dist(0, std::numeric_limits::max()); - for (int i = 0; i < kUniqueIDSize; i++) { - data[i] = static_cast(dist(generator)); - } - return UniqueID::from_binary(data); -} - -UniqueID UniqueID::from_binary(const std::string &binary) { return UniqueID(binary); } - -const UniqueID &UniqueID::nil() { - static const UniqueID nil_id; - return nil_id; -} - -bool UniqueID::is_nil() const { - const uint8_t *d = data(); - for (int i = 0; i < kUniqueIDSize; ++i) { - if (d[i] != 255) { - return false; - } - } - return true; -} - -const uint8_t *UniqueID::data() const { return id_; } - -size_t UniqueID::size() { return kUniqueIDSize; } - -std::string UniqueID::binary() const { - return std::string(reinterpret_cast(id_), kUniqueIDSize); -} - -std::string UniqueID::hex() const { - constexpr char hex[] = "0123456789abcdef"; - std::string result; - for (int i = 0; i < kUniqueIDSize; i++) { - unsigned int val = id_[i]; - result.push_back(hex[val >> 4]); - result.push_back(hex[val & 0xf]); - } - return result; -} - -plasma::UniqueID UniqueID::to_plasma_id() const { +plasma::UniqueID ObjectID::to_plasma_id() const { plasma::UniqueID result; - std::memcpy(result.mutable_data(), &id_, kUniqueIDSize); + std::memcpy(result.mutable_data(), data(), kUniqueIDSize); return result; } -bool UniqueID::operator==(const UniqueID &rhs) const { - return std::memcmp(data(), rhs.data(), kUniqueIDSize) == 0; +ObjectID::ObjectID(const plasma::UniqueID &from) { + std::memcpy(this->mutable_data(), from.data(), kUniqueIDSize); } -bool UniqueID::operator!=(const UniqueID &rhs) const { return !(*this == rhs); } - // This code is from https://sites.google.com/site/murmurhash/ // and is public domain. 
uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { @@ -151,60 +85,32 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { return h; } -size_t UniqueID::hash() const { - // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is - // default value 0) - if (!hash_) { - hash_ = MurmurHash64A(&id_[0], kUniqueIDSize, 0); - } - return hash_; +TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) { + std::string driver_id_str = driver_id.binary(); + driver_id_str.resize(size()); + return TaskID::from_binary(driver_id_str); } -std::ostream &operator<<(std::ostream &os, const UniqueID &id) { - if (id.is_nil()) { - os << "NIL_ID"; - } else { - os << id.hex(); - } - return os; +TaskID ObjectID::task_id() const { + return TaskID::from_binary( + std::string(reinterpret_cast(id_), TaskID::size())); } -const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { - RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); - ObjectID return_id = ObjectID(task_id); - int64_t *first_bytes = reinterpret_cast(&return_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - // OR the first byte of the object ID with the return index. - *first_bytes = *first_bytes | (object_index & ~bitmask); - return return_id; +ObjectID ObjectID::for_put(const TaskID &task_id, int64_t put_index) { + RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts) << "index=" << put_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = -put_index; + return object_id; } -const TaskID FinishTaskId(const TaskID &task_id) { - return TaskID(ComputeObjectId(task_id, 0)); -} - -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); - return ComputeObjectId(task_id, return_index); -} - -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { - RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts); - // We multiply put_index by -1 to distinguish from return_index. - return ComputeObjectId(task_id, -1 * put_index); -} - -const TaskID ComputeTaskId(const ObjectID &object_id) { - TaskID task_id = TaskID(object_id); - int64_t *first_bytes = reinterpret_cast(&task_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - return task_id; +ObjectID ObjectID::for_task_return(const TaskID &task_id, int64_t return_index) { + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" + << return_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = return_index; + return object_id; } const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, @@ -220,16 +126,21 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task // Compute the final task ID from the hash. 
BYTE buff[DIGEST_SIZE]; sha256_final(&ctx, buff); - return FinishTaskId(TaskID::from_binary(std::string(buff, buff + kUniqueIDSize))); + return TaskID::from_binary(std::string(buff, buff + TaskID::size())); } -int64_t ComputeObjectIndex(const ObjectID &object_id) { - const int64_t *first_bytes = reinterpret_cast(&object_id); - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - int64_t index = *first_bytes & (~bitmask); - index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize); - index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize); - return index; -} +#define ID_OSTREAM_OPERATOR(id_type) \ + std::ostream &operator<<(std::ostream &os, const id_type &id) { \ + if (id.is_nil()) { \ + os << "NIL_ID"; \ + } else { \ + os << id.hex(); \ + } \ + return os; \ + } + +ID_OSTREAM_OPERATOR(UniqueID); +ID_OSTREAM_OPERATOR(TaskID); +ID_OSTREAM_OPERATOR(ObjectID); } // namespace ray diff --git a/src/ray/id.h b/src/ray/id.h index 9467c1a3f11d..f90f66549358 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -2,44 +2,128 @@ #define RAY_ID_H_ #include +#include +#include #include +#include +#include #include #include "plasma/common.h" #include "ray/constants.h" +#include "ray/util/logging.h" #include "ray/util/visibility.h" namespace ray { -class RAY_EXPORT UniqueID { +class DriverID; +class UniqueID; + +// Declaration. +std::mt19937 RandomlySeededMersenneTwister(); +uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); + +// Change the compiler alignment to 1 byte (default is 8). +#pragma pack(push, 1) + +template +class BaseID { public: - UniqueID(); - UniqueID(const plasma::UniqueID &from); - static UniqueID from_random(); - static UniqueID from_binary(const std::string &binary); - static const UniqueID &nil(); + BaseID(); + static T from_random(); + static T from_binary(const std::string &binary); + static const T &nil(); + static size_t size() { return T::size(); } + size_t hash() const; bool is_nil() const; - bool operator==(const UniqueID &rhs) const; - bool operator!=(const UniqueID &rhs) const; + bool operator==(const BaseID &rhs) const; + bool operator!=(const BaseID &rhs) const; const uint8_t *data() const; - static size_t size(); std::string binary() const; std::string hex() const; - plasma::UniqueID to_plasma_id() const; - private: + protected: + BaseID(const std::string &binary) { + std::memcpy(const_cast(this->data()), binary.data(), T::size()); + } + // All IDs are immutable for hash evaluations. mutable_data is only allow to use + // in construction time, so this function is protected. + uint8_t *mutable_data(); + // For lazy evaluation, be careful to have one Id contained in another. + // This hash code will be duplicated. + mutable size_t hash_ = 0; +}; + +class UniqueID : public BaseID { + public: + UniqueID() : BaseID(){}; + static size_t size() { return kUniqueIDSize; } + + protected: UniqueID(const std::string &binary); protected: uint8_t id_[kUniqueIDSize]; - mutable size_t hash_ = 0; }; -static_assert(std::is_standard_layout::value, "UniqueID must be standard"); +class TaskID : public BaseID { + public: + TaskID() : BaseID() {} + static size_t size() { return kTaskIDSize; } + static TaskID GetDriverTaskID(const DriverID &driver_id); + + private: + uint8_t id_[kTaskIDSize]; +}; + +class ObjectID : public BaseID { + public: + ObjectID() : BaseID() {} + static size_t size() { return kUniqueIDSize; } + plasma::ObjectID to_plasma_id() const; + ObjectID(const plasma::UniqueID &from); + + /// Get the index of this object in the task that created it. 
+ /// + /// \return The index of object creation according to the task that created + /// this object. This is positive if the task returned the object and negative + /// if created by a put. + int32_t object_index() const { return index_; } + + /// Compute the task ID of the task that created the object. + /// + /// \return The task ID of the task that created this object. + TaskID task_id() const; + + /// Compute the object ID of an object put by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param index What index of the object put in the task. + /// \return The computed object ID. + static ObjectID for_put(const TaskID &task_id, int64_t put_index); + + /// Compute the object ID of an object returned by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param return_index What index of the object returned by in the task. + /// \return The computed object ID. + static ObjectID for_task_return(const TaskID &task_id, int64_t return_index); + + private: + uint8_t id_[kTaskIDSize]; + int32_t index_; +}; + +static_assert(sizeof(TaskID) == kTaskIDSize + sizeof(size_t), + "TaskID size is not as expected"); +static_assert(sizeof(ObjectID) == sizeof(int32_t) + sizeof(TaskID), + "ObjectID size is not as expected"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); +std::ostream &operator<<(std::ostream &os, const TaskID &id); +std::ostream &operator<<(std::ostream &os, const ObjectID &id); #define DEFINE_UNIQUE_ID(type) \ class RAY_EXPORT type : public UniqueID { \ @@ -63,35 +147,8 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id); #undef DEFINE_UNIQUE_ID -// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we -// can make these methods of the derived classes. -/// Finish computing a task ID. Since objects created by the task share a -/// prefix of the ID, the suffix of the task ID is zeroed out by this function. -/// -/// \param task_id A task ID to finish. -/// \return The finished task ID. It may now be used to compute IDs for objects -/// created by the task. -const TaskID FinishTaskId(const TaskID &task_id); - -/// Compute the object ID of an object returned by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param return_index What number return value this object is in the task. -/// \return The computed object ID. -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index); - -/// Compute the object ID of an object put by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param put_index What number put this object was created by in the task. -/// \return The computed object ID. -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index); - -/// Compute the task ID of the task that created the object. -/// -/// \param object_id The object ID. -/// \return The task ID of the task that created this object. -const TaskID ComputeTaskId(const ObjectID &object_id); +// Restore the compiler alignment to defult (8 bytes). +#pragma pack(pop) /// Generate a task ID from the given info. /// @@ -102,13 +159,95 @@ const TaskID ComputeTaskId(const ObjectID &object_id); const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, int parent_task_counter); -/// Compute the index of this object in the task that created it. -/// -/// \param object_id The object ID. 
-/// \return The index of object creation according to the task that created -/// this object. This is positive if the task returned the object and negative -/// if created by a put. -int64_t ComputeObjectIndex(const ObjectID &object_id); +template +BaseID::BaseID() { + // Using const_cast to directly change data is dangerous. The cached + // hash may not be changed. This is used in construction time. + std::fill_n(this->mutable_data(), T::size(), 0xff); +} + +template +T BaseID::from_random() { + std::string data(T::size(), 0); + // NOTE(pcm): The right way to do this is to have one std::mt19937 per + // thread (using the thread_local keyword), but that's not supported on + // older versions of macOS (see https://stackoverflow.com/a/29929949) + static std::mutex random_engine_mutex; + std::lock_guard lock(random_engine_mutex); + static std::mt19937 generator = RandomlySeededMersenneTwister(); + std::uniform_int_distribution dist(0, std::numeric_limits::max()); + for (int i = 0; i < T::size(); i++) { + data[i] = static_cast(dist(generator)); + } + return T::from_binary(data); +} + +template +T BaseID::from_binary(const std::string &binary) { + T t = T::nil(); + std::memcpy(t.mutable_data(), binary.data(), T::size()); + return t; +} + +template +const T &BaseID::nil() { + static const T nil_id; + return nil_id; +} + +template +bool BaseID::is_nil() const { + static T nil_id = T::nil(); + return *this == nil_id; +} + +template +size_t BaseID::hash() const { + // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is + // default value 0) + if (!hash_) { + hash_ = MurmurHash64A(data(), T::size(), 0); + } + return hash_; +} + +template +bool BaseID::operator==(const BaseID &rhs) const { + return std::memcmp(data(), rhs.data(), T::size()) == 0; +} + +template +bool BaseID::operator!=(const BaseID &rhs) const { + return !(*this == rhs); +} + +template +uint8_t *BaseID::mutable_data() { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +const uint8_t *BaseID::data() const { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +std::string BaseID::binary() const { + return std::string(reinterpret_cast(data()), T::size()); +} + +template +std::string BaseID::hex() const { + constexpr char hex[] = "0123456789abcdef"; + const uint8_t *id = data(); + std::string result; + for (int i = 0; i < T::size(); i++) { + unsigned int val = id[i]; + result.push_back(hex[val >> 4]); + result.push_back(hex[val & 0xf]); + } + return result; +} } // namespace ray @@ -125,6 +264,8 @@ namespace std { }; DEFINE_UNIQUE_ID(UniqueID); +DEFINE_UNIQUE_ID(TaskID); +DEFINE_UNIQUE_ID(ObjectID); #include "id_def.h" #undef DEFINE_UNIQUE_ID diff --git a/src/ray/id_def.h b/src/ray/id_def.h index 8a5e7e943262..96c7d59d1098 100644 --- a/src/ray/id_def.h +++ b/src/ray/id_def.h @@ -4,8 +4,6 @@ // Macro definition format: DEFINE_UNIQUE_ID(id_type). // NOTE: This file should NOT be included in any file other than id.h. -DEFINE_UNIQUE_ID(TaskID) -DEFINE_UNIQUE_ID(ObjectID) DEFINE_UNIQUE_ID(FunctionID) DEFINE_UNIQUE_ID(ActorClassID) DEFINE_UNIQUE_ID(ActorID) diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index a373ea9b9365..98eeb9186192 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -288,7 +288,7 @@ class TestObjectManager : public TestObjectManagerBase { // object. 
ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::ObjectID::from_random(); + UniqueID sub_id = ray::UniqueID::from_random(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, [this, sub_id, object_1, object_2]( diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 94f1dc11f189..4c3fac24f19e 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -48,7 +48,7 @@ void LineageEntry::ComputeParentTaskIds() { parent_task_ids_.clear(); // A task's parents are the tasks that created its arguments. for (const auto &dependency : task_.GetDependencies()) { - parent_task_ids_.insert(ComputeTaskId(dependency)); + parent_task_ids_.insert(dependency.task_id()); } } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index efd190ba5b27..2e25407f12fb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -852,7 +852,7 @@ void NodeManager::ProcessClientMessage( // Clean up their creating tasks from GCS. std::vector creating_task_ids; for (const auto &object_id : object_ids) { - creating_task_ids.push_back(ComputeTaskId(object_id)); + creating_task_ids.push_back(object_id.task_id()); } gcs_client_->raylet_task_table().Delete(DriverID::nil(), creating_task_ids); } @@ -887,11 +887,12 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. - const DriverID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(TaskID(driver_task_id)); + const DriverID driver_id = from_flatbuf(*message->driver_id()); + TaskID driver_task_id = TaskID::GetDriverTaskID(driver_id); + worker->AssignTaskId(driver_task_id); worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(TaskID(driver_task_id)); + local_queues_.AddDriverTaskId(driver_task_id); } } diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 4274ff5a2018..d1a648a34ce4 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -171,7 +171,7 @@ void ReconstructionPolicy::HandleTaskLeaseNotification(const TaskID &task_id, } void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); // Add this object to the list of objects created by the same task. if (it == listening_tasks_.end()) { @@ -185,7 +185,7 @@ void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) } void ReconstructionPolicy::Cancel(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { // We already stopped listening for this task. 
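Call sites such as the reconstruction policy above no longer need a free-standing ComputeTaskId(): the creating task and the creation index can be read straight out of the ObjectID bytes. A sketch of the decode direction, under the same layout and endianness assumptions as the encode sketch earlier (creating_task_id and object_index are illustrative helpers):

    import struct

    TASK_ID_SIZE = 16  # kTaskIDSize

    def creating_task_id(object_id):
        # ObjectID::task_id(): the TaskID is the 16-byte prefix.
        return object_id[:TASK_ID_SIZE]

    def object_index(object_id):
        # ObjectID::object_index(): > 0 for task returns, < 0 for puts.
        return struct.unpack("<i", object_id[TASK_ID_SIZE:])[0]

    oid = bytes.fromhex("123456789abcdef123456789abcdef00") + struct.pack("<i", -2)
    assert creating_task_id(oid).hex() == "123456789abcdef123456789abcdef00"
    assert object_index(oid) == -2  # the second object put by the creating task
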
diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index d9fb92388aa6..7f8887b15372 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -224,8 +224,7 @@ class ReconstructionPolicyTest : public ::testing::Test { TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -243,8 +242,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); mock_object_directory_->SetObjectLocations(object_id, {ClientID::from_random()}); // Listen for both objects. @@ -267,8 +265,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); ClientID client_id = ClientID::from_random(); mock_object_directory_->SetObjectLocations(object_id, {client_id}); @@ -292,9 +289,8 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { // Create two object IDs produced by the same task. TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id1 = ComputeReturnId(task_id, 1); - ObjectID object_id2 = ComputeReturnId(task_id, 2); + ObjectID object_id1 = ObjectID::for_task_return(task_id, 1); + ObjectID object_id2 = ObjectID::for_task_return(task_id, 2); // Listen for both objects. reconstruction_policy_->ListenAndMaybeReconstruct(object_id1); @@ -313,8 +309,7 @@ TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Run the test for much longer than the reconstruction timeout. int64_t test_period = 2 * reconstruction_timeout_ms_; @@ -340,8 +335,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -368,8 +362,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. 
reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -395,8 +388,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 4fbebb8df79f..dc24c95d46e4 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -24,7 +24,7 @@ bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { } bool TaskDependencyManager::CheckObjectRequired(const ObjectID &object_id) const { - const TaskID task_id = ComputeTaskId(object_id); + const TaskID task_id = object_id.task_id(); auto task_entry = required_tasks_.find(task_id); // If there are no subscribed tasks that are dependent on the object, then do // nothing. @@ -82,7 +82,7 @@ std::vector TaskDependencyManager::HandleObjectLocal( // Find any tasks that are dependent on the newly available object. std::vector ready_task_ids; - auto creating_task_entry = required_tasks_.find(ComputeTaskId(object_id)); + auto creating_task_entry = required_tasks_.find(object_id.task_id()); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { @@ -113,7 +113,7 @@ std::vector TaskDependencyManager::HandleObjectMissing( // Find any tasks that are dependent on the missing object. std::vector waiting_task_ids; - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); @@ -149,7 +149,7 @@ bool TaskDependencyManager::SubscribeDependencies( auto inserted = task_entry.object_dependencies.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); // Determine whether the dependency can be fulfilled by the local node. if (local_objects_.count(object_id) == 0) { // The object is not local. @@ -186,7 +186,7 @@ bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // Remove the task from the list of tasks that are dependent on this // object. // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); std::vector &dependent_tasks = creating_task_entry->second[object_id]; auto it = std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id); @@ -324,7 +324,7 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( // Cancel all of the objects that were required by the removed tasks. 
for (const auto &object_id : required_objects) { - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); required_tasks_.erase(creating_task_id); HandleRemoteDependencyCanceled(object_id); } diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 5126d82555af..62bbf17069d5 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -266,7 +266,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TEST_F(TaskDependencyManagerTest, TestDependentPut) { // Create a task with 3 arguments. auto task1 = ExampleTask({}, 0); - ObjectID put_id = ComputePutId(task1.GetTaskSpecification().TaskId(), 1); + ObjectID put_id = ObjectID::for_put(task1.GetTaskSpecification().TaskId(), 1); auto task2 = ExampleTask({put_id}, 0); // No objects have been registered in the task dependency manager, so the put diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 5f301c47c1c3..d4ec4f5c5e75 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -95,7 +95,7 @@ TaskSpecification::TaskSpecification( // Generate return ids. std::vector returns; for (int64_t i = 1; i < num_returns + 1; ++i) { - returns.push_back(ComputeReturnId(task_id, i)); + returns.push_back(ObjectID::for_task_return(task_id, i)); } // Serialize the TaskSpecification. diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 9f3545bdf638..03a4caff16ee 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -9,21 +9,21 @@ namespace raylet { void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { // Round trip test for computing the object ID for a task's return value, // then computing the task ID that created the object. - ObjectID return_id = ComputeReturnId(task_id, return_index); - ASSERT_EQ(ComputeTaskId(return_id), task_id); - ASSERT_EQ(ComputeObjectIndex(return_id), return_index); + ObjectID return_id = ObjectID::for_task_return(task_id, return_index); + ASSERT_EQ(return_id.task_id(), task_id); + ASSERT_EQ(return_id.object_index(), return_index); } void TestTaskPutId(const TaskID &task_id, int64_t put_index) { // Round trip test for computing the object ID for a task's put value, then // computing the task ID that created the object. - ObjectID put_id = ComputePutId(task_id, put_index); - ASSERT_EQ(ComputeTaskId(put_id), task_id); - ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index); + ObjectID put_id = ObjectID::for_put(task_id, put_index); + ASSERT_EQ(put_id.task_id(), task_id); + ASSERT_EQ(put_id.object_index(), -1 * put_index); } TEST(TaskSpecTest, TestTaskReturnIds) { - TaskID task_id = FinishTaskId(TaskID::from_random()); + TaskID task_id = TaskID::from_random(); // Check that we can compute between a task ID and the object IDs of its // return values and puts. 
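The round-trip tests above also document a design choice: a return object and a put object from the same task carry the same task ID, and the sign of the index is what distinguishes them. A small sketch under the same assumed API (not part of the patch):

.. code-block:: cpp

    ray::TaskID task_id = ray::TaskID::from_random();
    // Returns use positive indices, puts use negative ones, so a single
    // integer field in the ObjectID separates the two kinds of objects.
    ray::ObjectID ret_id = ray::ObjectID::for_task_return(task_id, 3);
    ray::ObjectID put_id = ray::ObjectID::for_put(task_id, 3);
    // ret_id.object_index() == 3, put_id.object_index() == -3,
    // while ret_id.task_id() and put_id.task_id() both equal task_id.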
@@ -35,6 +35,18 @@ TEST(TaskSpecTest, TestTaskReturnIds) { TestTaskPutId(task_id, kMaxTaskPuts); } +TEST(IdPropertyTest, TestIdProperty) { + TaskID task_id = TaskID::from_random(); + ASSERT_EQ(task_id, TaskID::from_binary(task_id.binary())); + ObjectID object_id = ObjectID::from_random(); + ASSERT_EQ(object_id, ObjectID::from_binary(object_id.binary())); + + ASSERT_TRUE(TaskID().is_nil()); + ASSERT_TRUE(TaskID::nil().is_nil()); + ASSERT_TRUE(ObjectID().is_nil()); + ASSERT_TRUE(ObjectID::nil().is_nil()); +} + } // namespace raylet } // namespace ray From 20150851923249587c0ac981d9232dfa32f85512 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 23 May 2019 09:22:46 -0700 Subject: [PATCH 030/118] Fix bug in which actor classes are not exported multiple times. (#4838) --- python/ray/actor.py | 15 ++++++++++----- python/ray/remote_function.py | 4 +++- python/ray/tests/test_basic.py | 26 ++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/python/ray/actor.py b/python/ray/actor.py index e806a5f8fae3..420fe5c3a58e 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -186,8 +186,9 @@ class ActorClass(object): task. _resources: The default resources required by the actor creation task. _actor_method_cpus: The number of CPUs required by actor method tasks. - _exported: True if the actor class has been exported and false - otherwise. + _last_export_session: The index of the last session in which the remote + function was exported. This is used to determine if we need to + export the remote function again. _actor_methods: The actor methods. _method_decorators: Optional decorators that should be applied to the method invocation function before invoking the actor methods. These @@ -208,7 +209,7 @@ def __init__(self, modified_class, class_id, max_reconstructions, num_cpus, self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._exported = False + self._last_export_session = None self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) @@ -341,10 +342,14 @@ def _remote(self, *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if not self._exported: + if (self._last_export_session is None + or self._last_export_session < worker._session_index): + # If this actor class was exported in a previous session, we + # need to export this function again, because current GCS + # doesn't have it. + self._last_export_session = worker._session_index worker.function_actor_manager.export_actor_class( self._modified_class, self._actor_method_names) - self._exported = True resources = ray.utils.resources_from_resource_arguments( cpus_to_use, self._num_gpus, self._resources, num_cpus, diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 3bc3fc2bd92e..e4828fd47bb5 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -43,6 +43,9 @@ class RemoteFunction(object): return the resulting ObjectIDs. For an example, see "test_decorated_function" in "python/ray/tests/test_basic.py". _function_signature: The function signature. + _last_export_session: The index of the last session in which the remote + function was exported. This is used to determine if we need to + export the remote function again. """ def __init__(self, function, num_cpus, num_gpus, resources, @@ -68,7 +71,6 @@ def __init__(self, function, num_cpus, num_gpus, resources, # Export the function. 
worker = ray.worker.get_global_worker() - # In which session this function was exported last time. self._last_export_session = worker._session_index worker.function_actor_manager.export(self) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index ffd0fb630e80..056aedd4f86c 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2942,3 +2942,29 @@ def get_postprocessor(object_ids, values): assert ray.get( [ray.put(i) for i in [0, 1, 3, 5, -1, -3, 4]]) == [1, 3, 5, 4] + + +def test_export_after_shutdown(ray_start_regular): + # This test checks that we can use actor and remote function definitions + # across multiple Ray sessions. + + @ray.remote + def f(): + pass + + @ray.remote + class Actor(object): + def method(self): + pass + + ray.get(f.remote()) + a = Actor.remote() + ray.get(a.method.remote()) + + ray.shutdown() + + # Start Ray and use the remote function and actor again. + ray.init(num_cpus=1) + ray.get(f.remote()) + a = Actor.remote() + ray.get(a.method.remote()) From ba6c595094a27828358e891227ed3852c6f1e50f Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 23 May 2019 17:02:20 -0700 Subject: [PATCH 031/118] Bump Ray master version to 0.8.0.dev0 (#4845) --- python/ray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index b15fb13cbf29..ecd9138b4a93 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -95,7 +95,7 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. -__version__ = "0.7.0" +__version__ = "0.8.0.dev0" __all__ = [ "LOCAL_MODE", From 4e281ba938d4eed3bb3b34ee8ecf9fcabd05a1db Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Thu, 23 May 2019 18:06:07 -0700 Subject: [PATCH 032/118] Add section to bump version of master branch and cleanup release docs (#4846) --- dev/RELEASE_PROCESS.rst | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/dev/RELEASE_PROCESS.rst b/dev/RELEASE_PROCESS.rst index 88a243c28bb6..62862506e1ed 100644 --- a/dev/RELEASE_PROCESS.rst +++ b/dev/RELEASE_PROCESS.rst @@ -41,12 +41,12 @@ This document describes the process for creating new releases. 6. **Download all the wheels:** Now the release is ready to begin final testing. The wheels are automatically uploaded to S3, even on the release - branch. The wheels can ``pip install``ed from the following URLs: + branch. To test, ``pip install`` from the following URLs: .. code-block:: bash export RAY_HASH=... # e.g., 618147f57fb40368448da3b2fb4fd213828fa12b - export RAY_VERSION=... # e.g., 0.6.6 + export RAY_VERSION=... # e.g., 0.7.0 pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl @@ -120,10 +120,28 @@ This document describes the process for creating new releases. git pull origin master --tags git log $(git describe --tags --abbrev=0)..HEAD --pretty=format:"%s" | sort +11. **Bump version on Ray master branch:** Create a pull request to increment the + version of the master branch. The format of the new version is as follows: + + New minor release (e.g., 0.7.0): Increment the minor version and append ``.dev0`` to + the version. 
For example, if the version of the new release is 0.7.0, the master + branch needs to be updated to 0.8.0.dev0. `Example PR for minor release` + + New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the number + after ``dev`` equals the micro version. For example, if the version of the new + release is 0.7.1, the master branch needs to be updated to 0.8.0.dev1. + +12. **Update version numbers throughout codebase:** Suppose we just released 0.7.1. The + previous release version number (in this case 0.7.0) and the previous dev version number + (in this case 0.8.0.dev0) appear in many places throughout the code base including + the installation documentation, the example autoscaler config files, and the testing + scripts. Search for all of the occurrences of these version numbers and update them to + use the new release and dev version numbers. + .. _documentation: https://ray.readthedocs.io/en/latest/installation.html#trying-snapshots-from-master .. _`documentation for building wheels`: https://github.com/ray-project/ray/blob/master/python/README-building-wheels.md .. _`ci/stress_tests/run_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_stress_tests.sh .. _`ci/stress_tests/run_application_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_application_stress_tests.sh .. _`this example`: https://github.com/ray-project/ray/pull/4226 -.. _`these wheels here`: https://ray.readthedocs.io/en/latest/installation.html .. _`GitHub website`: https://github.com/ray-project/ray/releases +.. _`Example PR for minor release`: https://github.com/ray-project/ray/pull/4845 From 71f95e1c54c93189bfefc0abc5fcbf16681e2f6b Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Fri, 24 May 2019 10:33:42 +0200 Subject: [PATCH 033/118] Fix import --- python/ray/rllib/agents/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 029034e94258..0e727955aab3 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -32,7 +32,7 @@ from ray.tune.trainable import Trainable from ray.tune.trial import Resources, ExportFormat -from python.ray.tune.logger import to_tf_values +from ray.tune.logger import to_tf_values tf = try_import_tf() From 49fe894e2219e99dc836176dfe447afa7a0ab331 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 24 May 2019 13:44:39 -0700 Subject: [PATCH 034/118] =?UTF-8?q?Export=20remote=20functions=20when=20fi?= =?UTF-8?q?rst=20used=20and=20also=20fix=20bug=20in=20which=20rem=E2=80=A6?= =?UTF-8?q?=20(#4844)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Export remote functions when first used and also fix bug in which remote functions and actor classes are not exported from workers during subsequent ray sessions. * Documentation update * Fix tests. 
* Fix grammar --- doc/source/internals-overview.rst | 26 -------------------------- python/ray/actor.py | 18 +++++++++++------- python/ray/function_manager.py | 2 +- python/ray/remote_function.py | 19 ++++++++++--------- python/ray/tests/test_basic.py | 31 +++++++++++++++++++++++++++++++ python/ray/tests/test_failure.py | 13 ++++++++++++- python/ray/tests/test_monitors.py | 9 +-------- 7 files changed, 66 insertions(+), 52 deletions(-) diff --git a/doc/source/internals-overview.rst b/doc/source/internals-overview.rst index 7a762b6bdff0..109b923ecc3b 100644 --- a/doc/source/internals-overview.rst +++ b/doc/source/internals-overview.rst @@ -66,32 +66,6 @@ listens for the addition of remote functions to the centralized control state. When a new remote function is added, the thread fetches the pickled remote function, unpickles it, and can then execute that function. -Notes and limitations -~~~~~~~~~~~~~~~~~~~~~ - -- Because we export remote functions as soon as they are defined, that means - that remote functions can't close over variables that are defined after the - remote function is defined. For example, the following code gives an error. - - .. code-block:: python - - @ray.remote - def f(x): - return helper(x) - - def helper(x): - return x + 1 - - If you call ``f.remote(0)``, it will give an error of the form. - - .. code-block:: python - - Traceback (most recent call last): - File "", line 3, in f - NameError: name 'helper' is not defined - - On the other hand, if ``helper`` is defined before ``f``, then it will work. - Calling a remote function ------------------------- diff --git a/python/ray/actor.py b/python/ray/actor.py index 420fe5c3a58e..dce9a0b26074 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -186,9 +186,12 @@ class ActorClass(object): task. _resources: The default resources required by the actor creation task. _actor_method_cpus: The number of CPUs required by actor method tasks. - _last_export_session: The index of the last session in which the remote - function was exported. This is used to determine if we need to - export the remote function again. + _last_driver_id_exported_for: The ID of the driver ID of the last Ray + session during which this actor class definition was exported. This + is an imperfect mechanism used to determine if we need to export + the remote function again. It is imperfect in the sense that the + actor class definition could be exported multiple times by + different workers. _actor_methods: The actor methods. _method_decorators: Optional decorators that should be applied to the method invocation function before invoking the actor methods. These @@ -209,7 +212,7 @@ def __init__(self, modified_class, class_id, max_reconstructions, num_cpus, self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._last_export_session = None + self._last_driver_id_exported_for = None self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) @@ -342,12 +345,13 @@ def _remote(self, *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if (self._last_export_session is None - or self._last_export_session < worker._session_index): + if (self._last_driver_id_exported_for is None + or self._last_driver_id_exported_for != + worker.task_driver_id): # If this actor class was exported in a previous session, we # need to export this function again, because current GCS # doesn't have it. 
- self._last_export_session = worker._session_index + self._last_driver_id_exported_for = worker.task_driver_id worker.function_actor_manager.export_actor_class( self._modified_class, self._actor_method_names) diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index e4a172fc1e71..4914c9f87050 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -342,7 +342,7 @@ def export(self, remote_function): # and export it later. self._functions_to_export.append(remote_function) return - if self._worker.mode != ray.worker.SCRIPT_MODE: + if self._worker.mode == ray.worker.LOCAL_MODE: # Don't need to export if the worker is not a driver. return self._do_export(remote_function) diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index e4828fd47bb5..44d2777a2900 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -43,9 +43,12 @@ class RemoteFunction(object): return the resulting ObjectIDs. For an example, see "test_decorated_function" in "python/ray/tests/test_basic.py". _function_signature: The function signature. - _last_export_session: The index of the last session in which the remote - function was exported. This is used to determine if we need to - export the remote function again. + _last_driver_id_exported_for: The ID of the driver ID of the last Ray + session during which this remote function definition was exported. + This is an imperfect mechanism used to determine if we need to + export the remote function again. It is imperfect in the sense that + the actor class definition could be exported multiple times by + different workers. """ def __init__(self, function, num_cpus, num_gpus, resources, @@ -69,10 +72,7 @@ def __init__(self, function, num_cpus, num_gpus, resources, self._function_signature = ray.signature.extract_signature( self._function) - # Export the function. - worker = ray.worker.get_global_worker() - self._last_export_session = worker._session_index - worker.function_actor_manager.export(self) + self._last_driver_id_exported_for = None def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. Instead " @@ -111,10 +111,11 @@ def _remote(self, worker = ray.worker.get_global_worker() worker.check_connected() - if self._last_export_session < worker._session_index: + if (self._last_driver_id_exported_for is None + or self._last_driver_id_exported_for != worker.task_driver_id): # If this function was exported in a previous session, we need to # export this function again, because current GCS doesn't have it. - self._last_export_session = worker._session_index + self._last_driver_id_exported_for = worker.task_driver_id worker.function_actor_manager.export(self) kwargs = {} if kwargs is None else kwargs diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 056aedd4f86c..d6eebe1517bb 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -303,6 +303,23 @@ def f(x): assert_equal(obj, ray.get(ray.put(obj))) +def test_nested_functions(ray_start_regular): + # Make sure that remote functions can use other values that are defined + # after the remote function but before the first function invocation. 
+ @ray.remote + def f(): + return g(), ray.get(h.remote()) + + def g(): + return 1 + + @ray.remote + def h(): + return 2 + + assert ray.get(f.remote()) == (1, 2) + + def test_ray_recursive_objects(ray_start_regular): class ClassA(object): pass @@ -2968,3 +2985,17 @@ def method(self): ray.get(f.remote()) a = Actor.remote() ray.get(a.method.remote()) + + ray.shutdown() + + # Start Ray again and make sure that these definitions can be exported from + # workers. + ray.init(num_cpus=2) + + @ray.remote + def export_definitions_from_worker(remote_function, actor_class): + ray.get(remote_function.remote()) + actor_handle = actor_class.remote() + ray.get(actor_handle.method.remote()) + + ray.get(export_definitions_from_worker.remote(f, Actor)) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 650cce68b246..6a782ee726da 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -95,7 +95,15 @@ def temporary_helper_function(): # fail when it is unpickled. @ray.remote def g(): - return module.temporary_python_file() + try: + module.temporary_python_file() + except Exception: + # This test is not concerned with the error from running this + # function. Only from unpickling the remote function. + pass + + # Invoke the function so that the definition is exported. + g.remote() wait_for_errors(ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, 2) errors = relevant_errors(ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR) @@ -499,6 +507,9 @@ def test_export_large_objects(ray_start_regular): def f(): large_object + # Invoke the function so that the definition is exported. + f.remote() + # Make sure that a warning is generated. wait_for_errors(ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, 1) diff --git a/python/ray/tests/test_monitors.py b/python/ray/tests/test_monitors.py index d588732c11f4..36ed55a52474 100644 --- a/python/ray/tests/test_monitors.py +++ b/python/ray/tests/test_monitors.py @@ -46,13 +46,6 @@ def Driver(success): # Two new objects. ray.get(ray.put(1111)) ray.get(ray.put(1111)) - attempts = 0 - while (2, 1, summary_start[2]) != StateSummary(): - time.sleep(0.1) - attempts += 1 - if attempts == max_attempts_before_failing: - success.value = False - break @ray.remote def f(): @@ -61,7 +54,7 @@ def f(): # 1 new function. attempts = 0 - while (2, 1, summary_start[2] + 1) != StateSummary(): + while (2, 1, summary_start[2]) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: From a7d01aba9b94232cf2ee385e4cad15f435022033 Mon Sep 17 00:00:00 2001 From: Devin Petersohn Date: Fri, 24 May 2019 16:49:13 -0700 Subject: [PATCH 035/118] Update wheel versions in documentation to 0.8.0.dev0 and 0.7.0. (#4847) --- README.rst | 2 +- .../run_perf_integration.sh | 2 +- .../application_cluster_template.yaml | 4 ++-- ci/stress_tests/stress_testing_config.yaml | 2 +- doc/source/installation.rst | 16 ++++++++-------- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 2 +- python/ray/autoscaler/aws/example-full.yaml | 6 +++--- .../ray/autoscaler/aws/example-gpu-docker.yaml | 6 +++--- python/ray/autoscaler/gcp/example-full.yaml | 6 +++--- .../ray/autoscaler/gcp/example-gpu-docker.yaml | 6 +++--- 11 files changed, 27 insertions(+), 27 deletions(-) diff --git a/README.rst b/README.rst index ada6f7c2d4d1..aed1d2e81ae9 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,7 @@ .. 
image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest -.. image:: https://img.shields.io/badge/pypi-0.6.6-blue.svg +.. image:: https://img.shields.io/badge/pypi-0.7.0-blue.svg :target: https://pypi.org/project/ray/ | diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index 927f8bf5e83d..f723d5122981 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev3-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/stress_tests/application_cluster_template.yaml b/ci/stress_tests/application_cluster_template.yaml index e8fc0efa2bcd..541419da55af 100644 --- a/ci/stress_tests/application_cluster_template.yaml +++ b/ci/stress_tests/application_cluster_template.yaml @@ -90,8 +90,8 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_<<>>/bin:$PATH"' >> ~/.bashrc - - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-<<>>-manylinux1_x86_64.whl - - rllib || pip install -U ray-0.7.0.dev2-<<>>-manylinux1_x86_64.whl[rllib] + - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl + - rllib || pip install -U ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl[rllib] - pip install tensorflow-gpu==1.12.0 - echo "sudo halt" | at now + 60 minutes # Consider uncommenting these if you also want to run apt-get commands during setup diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index 3ea6f7f717a3..07c27bab79b6 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -100,7 +100,7 @@ setup_commands: # - git clone https://github.com/ray-project/ray || true - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl - echo "sudo halt" | at now + 60 minutes # Custom commands that will be run on the head node after common setup. diff --git a/doc/source/installation.rst b/doc/source/installation.rst index 0aa925d47c63..ad92cb347e83 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -33,14 +33,14 @@ Here are links to the latest wheels (which are built off of master). To install =================== =================== -.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp37-cp37m-manylinux1_x86_64.whl -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl -.. 
_`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl -.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp37-cp37m-macosx_10_6_intel.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 0891ac02c8f9..664370eb0479 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 9546b676b779..b0cf426c1b1d 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index c8ebe9dc31c2..7399450aeedb 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -113,9 +113,9 @@ setup_commands: # has your Ray repo pre-cloned. 
Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/aws/example-gpu-docker.yaml b/python/ray/autoscaler/aws/example-gpu-docker.yaml index 37c0323fc757..79fdc055b091 100644 --- a/python/ray/autoscaler/aws/example-gpu-docker.yaml +++ b/python/ray/autoscaler/aws/example-gpu-docker.yaml @@ -105,9 +105,9 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index 9575691158c1..4ab2093dd865 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -127,9 +127,9 @@ setup_commands: && echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.profile # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. 
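This commit is essentially the sweep described in step 12 of the release process earlier in the series: every occurrence of the retired release and dev versions is replaced with the new ones. One way to confirm nothing was missed (a hedged sketch; the version strings are examples, not part of the patch):

.. code-block:: bash

    # List any lines that still mention the versions being retired.
    # 0.7.0.dev2 and 0.6.6 are example values for the old dev and release versions.
    git grep -n -e '0\.7\.0\.dev2' -e '0\.6\.6'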
diff --git a/python/ray/autoscaler/gcp/example-gpu-docker.yaml b/python/ray/autoscaler/gcp/example-gpu-docker.yaml index 43b9d867b5b5..75e0497094cb 100644 --- a/python/ray/autoscaler/gcp/example-gpu-docker.yaml +++ b/python/ray/autoscaler/gcp/example-gpu-docker.yaml @@ -140,9 +140,9 @@ setup_commands: # - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.7.0.dev2-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: From 0ce0ecbe9ce9d1c405ef5d708e8d0b79626b40ce Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sat, 25 May 2019 02:19:28 -0700 Subject: [PATCH 036/118] [tune] Later expansion of local_dir (#4806) --- python/ray/tune/trial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 272945ba1cf4..91ea941b8cf0 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -273,7 +273,7 @@ def __init__(self, # Trial config self.trainable_name = trainable_name self.config = config or {} - self.local_dir = os.path.expanduser(local_dir) + self.local_dir = local_dir # This remains unexpanded for syncing. self.experiment_tag = experiment_tag self.resources = ( resources @@ -346,6 +346,7 @@ def generate_id(cls): @classmethod def create_logdir(cls, identifier, local_dir): + local_dir = os.path.expanduser(local_dir) if not os.path.exists(local_dir): os.makedirs(local_dir) return tempfile.mkdtemp( From 7237ea70c41d0a86c925734ee9e606914588cfa9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 25 May 2019 10:45:26 -0700 Subject: [PATCH 037/118] [rllib] [RFC] Deprecate Python 2 / RLlib (#4832) --- python/ray/rllib/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index 05f88ac653c4..92844e485ff3 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -3,6 +3,7 @@ from __future__ import print_function import logging +import sys # Note: do not introduce unnecessary library dependencies here, e.g. gym. # This file is imported from the tune module in order to register RLlib agents. 
@@ -30,6 +31,11 @@ def _setup_logger(): logger.addHandler(handler) logger.propagate = False + if sys.version_info[0] < 3: + logger.warn( + "RLlib Python 2 support is deprecated, and will be removed " + "in a future release.") + def _register_all(): From ea8d7b4dc07563d48a40952b4150ae9df007180c Mon Sep 17 00:00:00 2001 From: IkedaYutaro Date: Sun, 26 May 2019 15:13:58 +0900 Subject: [PATCH 038/118] Fix a typo in kubernetes yaml (#4872) --- kubernetes/submit.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kubernetes/submit.yaml b/kubernetes/submit.yaml index 80eeecd9751e..e6e66ae8d944 100644 --- a/kubernetes/submit.yaml +++ b/kubernetes/submit.yaml @@ -86,7 +86,7 @@ spec: spec: affinity: podAntiAffinity: - requiredDuringSchedulingIgnoreDuringExecution: + requiredDuringSchedulingIgnoredDuringExecution: - labelSelector: matchLabels: type: ray From 67035191443e8d645024806b2df72e96e6b05e64 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 26 May 2019 11:27:53 -0700 Subject: [PATCH 039/118] Move global state API out of global_state object. (#4857) --- .../test_many_tasks_and_transfers.py | 4 +- doc/source/api.rst | 20 ++ doc/source/development.rst | 10 +- doc/source/user-profiling.rst | 2 +- kubernetes/example.py | 2 +- python/ray/__init__.py | 16 +- python/ray/actor.py | 4 +- python/ray/experimental/__init__.py | 13 +- python/ray/experimental/features.py | 186 -------------- python/ray/monitor.py | 17 +- python/ray/scripts/scripts.py | 2 +- python/ray/services.py | 2 +- python/ray/{experimental => }/state.py | 233 ++++++++++++++++-- python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/test_actor.py | 4 +- python/ray/tests/test_basic.py | 73 +++--- python/ray/tests/test_dynres.py | 73 +++--- python/ray/tests/test_failure.py | 9 +- python/ray/tests/test_global_state.py | 18 +- python/ray/tests/test_monitors.py | 17 +- python/ray/tests/test_multi_node.py | 22 +- python/ray/tests/test_multi_node_2.py | 8 +- python/ray/tests/test_object_manager.py | 4 +- python/ray/tests/test_stress.py | 2 +- python/ray/tests/utils.py | 2 +- python/ray/tune/ray_trial_executor.py | 2 +- python/ray/tune/tests/test_cluster.py | 10 +- python/ray/tune/tests/test_trial_runner.py | 2 +- python/ray/worker.py | 29 +-- 29 files changed, 387 insertions(+), 403 deletions(-) delete mode 100644 python/ray/experimental/features.py rename python/ray/{experimental => }/state.py (84%) diff --git a/ci/stress_tests/test_many_tasks_and_transfers.py b/ci/stress_tests/test_many_tasks_and_transfers.py index c4c4825b1bb9..985e05c28a02 100644 --- a/ci/stress_tests/test_many_tasks_and_transfers.py +++ b/ci/stress_tests/test_many_tasks_and_transfers.py @@ -24,10 +24,10 @@ # Wait until the expected number of nodes have joined the cluster. while True: - if len(ray.global_state.client_table()) >= num_remote_nodes + 1: + if len(ray.nodes()) >= num_remote_nodes + 1: break logger.info("Nodes have all joined. There are %s resources.", - ray.global_state.cluster_resources()) + ray.cluster_resources()) # Require 1 GPU to force the tasks to be on remote machines. diff --git a/doc/source/api.rst b/doc/source/api.rst index 65e31e5a4ded..a149fbb5bb77 100644 --- a/doc/source/api.rst +++ b/doc/source/api.rst @@ -27,6 +27,26 @@ The Ray API .. autofunction:: ray.method +Inspect the Cluster State +------------------------- + +.. autofunction:: ray.nodes() + +.. autofunction:: ray.tasks() + +.. autofunction:: ray.objects() + +.. autofunction:: ray.timeline() + +.. autofunction:: ray.object_transfer_timeline() + +.. 
autofunction:: ray.cluster_resources() + +.. autofunction:: ray.available_resources() + +.. autofunction:: ray.errors() + + The Ray Command Line API ------------------------ diff --git a/doc/source/development.rst b/doc/source/development.rst index e4d50327a43a..1fdc65fa35cf 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -60,7 +60,7 @@ Python script with the following: .. code-block:: bash RAY_RAYLET_GDB=1 RAY_RAYLET_TMUX=1 python - + You can then list the ``tmux`` sessions with ``tmux ls`` and attach to the appropriate one. @@ -71,14 +71,14 @@ allow core dump files to be written. Inspecting Redis shards ~~~~~~~~~~~~~~~~~~~~~~~ -To inspect Redis, you can use the ``ray.experimental.state.GlobalState`` Python -API. The easiest way to do this is to start or connect to a Ray cluster with -``ray.init()``, then query the API like so: +To inspect Redis, you can use the global state API. The easiest way to do this +is to start or connect to a Ray cluster with ``ray.init()``, then query the API +like so: .. code-block:: python ray.init() - ray.worker.global_state.client_table() + ray.nodes() # Returns current information about the nodes in the cluster, such as: # [{'ClientID': '2a9d2b34ad24a37ed54e4fcd32bf19f915742f5b', # 'EntryType': 0, diff --git a/doc/source/user-profiling.rst b/doc/source/user-profiling.rst index 4bf152e52a00..511531f061a8 100644 --- a/doc/source/user-profiling.rst +++ b/doc/source/user-profiling.rst @@ -18,7 +18,7 @@ following command. .. code-block:: python - ray.global_state.chrome_tracing_dump(filename="/tmp/timeline.json") + ray.timeline(filename="/tmp/timeline.json") Then open `chrome://tracing`_ in the Chrome web browser, and load ``timeline.json``. diff --git a/kubernetes/example.py b/kubernetes/example.py index e80a6b6c9b30..5ba0272c73e5 100644 --- a/kubernetes/example.py +++ b/kubernetes/example.py @@ -14,7 +14,7 @@ # Wait for all 4 nodes to join the cluster. while True: - num_nodes = len(ray.global_state.client_table()) + num_nodes = len(ray.nodes()) if num_nodes < 4: print("{} nodes have joined so far. Waiting for more." .format(num_nodes)) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index ecd9138b4a93..e1b65cdcf6c7 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -66,6 +66,9 @@ _config = _Config() from ray.profiling import profile # noqa: E402 +from ray.state import (global_state, nodes, tasks, objects, timeline, + object_transfer_timeline, cluster_resources, + available_resources, errors) # noqa: E402 from ray.worker import ( LOCAL_MODE, PYTHON_MODE, @@ -73,12 +76,10 @@ WORKER_MODE, connect, disconnect, - error_info, get, get_gpu_ids, get_resource_ids, get_webui_url, - global_state, init, is_initialized, put, @@ -98,6 +99,15 @@ __version__ = "0.8.0.dev0" __all__ = [ + "global_state", + "nodes", + "tasks", + "objects", + "timeline", + "object_transfer_timeline", + "cluster_resources", + "available_resources", + "errors", "LOCAL_MODE", "PYTHON_MODE", "SCRIPT_MODE", @@ -108,12 +118,10 @@ "actor", "connect", "disconnect", - "error_info", "get", "get_gpu_ids", "get_resource_ids", "get_webui_url", - "global_state", "init", "internal", "is_initialized", diff --git a/python/ray/actor.py b/python/ray/actor.py index dce9a0b26074..65642d9928ee 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -811,7 +811,7 @@ def exit_actor(): worker.raylet_client.disconnect() ray.disconnect() # Disconnect global state from GCS. 
- ray.global_state.disconnect() + ray.state.state.disconnect() sys.exit(0) assert False, "This process should have terminated." else: @@ -931,7 +931,7 @@ def get_checkpoints_for_actor(actor_id): """Get the available checkpoints for the given actor ID, return a list sorted by checkpoint timestamp in descending order. """ - checkpoint_info = ray.worker.global_state.actor_checkpoint_info(actor_id) + checkpoint_info = ray.state.state.actor_checkpoint_info(actor_id) if checkpoint_info is None: return [] checkpoints = [ diff --git a/python/ray/experimental/__init__.py b/python/ray/experimental/__init__.py index 5b811ff0ffb2..cb6438d0f2d5 100644 --- a/python/ray/experimental/__init__.py +++ b/python/ray/experimental/__init__.py @@ -2,10 +2,6 @@ from __future__ import division from __future__ import print_function -from .features import ( - flush_redis_unsafe, flush_task_and_object_metadata_unsafe, - flush_finished_tasks_unsafe, flush_evicted_objects_unsafe, - _flush_finished_tasks_unsafe_shard, _flush_evicted_objects_unsafe_shard) from .gcs_flush_policy import (set_flushing_policy, GcsFlushPolicy, SimpleGcsFlushPolicy) from .named_actors import get_actor, register_actor @@ -20,10 +16,7 @@ def TensorFlowVariables(*args, **kwargs): __all__ = [ - "TensorFlowVariables", "flush_redis_unsafe", - "flush_task_and_object_metadata_unsafe", "flush_finished_tasks_unsafe", - "flush_evicted_objects_unsafe", "_flush_finished_tasks_unsafe_shard", - "_flush_evicted_objects_unsafe_shard", "get_actor", "register_actor", - "get", "wait", "set_flushing_policy", "GcsFlushPolicy", - "SimpleGcsFlushPolicy", "set_resource" + "TensorFlowVariables", "get_actor", "register_actor", "get", "wait", + "set_flushing_policy", "GcsFlushPolicy", "SimpleGcsFlushPolicy", + "set_resource" ] diff --git a/python/ray/experimental/features.py b/python/ray/experimental/features.py deleted file mode 100644 index 90f893f271fb..000000000000 --- a/python/ray/experimental/features.py +++ /dev/null @@ -1,186 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import ray -from ray.utils import binary_to_hex - -OBJECT_INFO_PREFIX = b"OI:" -OBJECT_LOCATION_PREFIX = b"OL:" -TASK_PREFIX = b"TT:" - - -def flush_redis_unsafe(redis_client=None): - """This removes some non-critical state from the primary Redis shard. - - This removes the log files as well as the event log from Redis. This can - be used to try to address out-of-memory errors caused by the accumulation - of metadata in Redis. However, it will only partially address the issue as - much of the data is in the task table (and object table), which are not - flushed. - - Args: - redis_client: optional, if not provided then ray.init() must have been - called. - """ - if redis_client is None: - ray.worker.global_worker.check_connected() - redis_client = ray.worker.global_worker.redis_client - - # Delete the log files from the primary Redis shard. - keys = redis_client.keys("LOGFILE:*") - if len(keys) > 0: - num_deleted = redis_client.delete(*keys) - else: - num_deleted = 0 - print("Deleted {} log files from Redis.".format(num_deleted)) - - # Delete the event log from the primary Redis shard. - keys = redis_client.keys("event_log:*") - if len(keys) > 0: - num_deleted = redis_client.delete(*keys) - else: - num_deleted = 0 - print("Deleted {} event logs from Redis.".format(num_deleted)) - - -def flush_task_and_object_metadata_unsafe(): - """This removes some critical state from the Redis shards. 
- - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the object and task metadata. This can be used to try - to address out-of-memory errors caused by the accumulation of metadata in - Redis. However, after running this command, fault tolerance will most - likely not work. - """ - ray.worker.global_worker.check_connected() - - def flush_shard(redis_client): - # Flush the task table. Note that this also flushes the driver tasks - # which may be undesirable. - num_task_keys_deleted = 0 - for key in redis_client.scan_iter(match=TASK_PREFIX + b"*"): - num_task_keys_deleted += redis_client.delete(key) - print("Deleted {} task keys from Redis.".format(num_task_keys_deleted)) - - # Flush the object information. - num_object_keys_deleted = 0 - for key in redis_client.scan_iter(match=OBJECT_INFO_PREFIX + b"*"): - num_object_keys_deleted += redis_client.delete(key) - print("Deleted {} object info keys from Redis.".format( - num_object_keys_deleted)) - - # Flush the object locations. - num_object_location_keys_deleted = 0 - for key in redis_client.scan_iter(match=OBJECT_LOCATION_PREFIX + b"*"): - num_object_location_keys_deleted += redis_client.delete(key) - print("Deleted {} object location keys from Redis.".format( - num_object_location_keys_deleted)) - - # Loop over the shards and flush all of them. - for redis_client in ray.worker.global_state.redis_clients: - flush_shard(redis_client) - - -def _task_table_shard(shard_index): - redis_client = ray.global_state.redis_clients[shard_index] - task_table_keys = redis_client.keys(TASK_PREFIX + b"*") - results = {} - for key in task_table_keys: - task_id_binary = key[len(TASK_PREFIX):] - results[binary_to_hex(task_id_binary)] = ray.global_state._task_table( - ray.TaskID(task_id_binary)) - - return results - - -def _object_table_shard(shard_index): - redis_client = ray.global_state.redis_clients[shard_index] - object_table_keys = redis_client.keys(OBJECT_LOCATION_PREFIX + b"*") - results = {} - for key in object_table_keys: - object_id_binary = key[len(OBJECT_LOCATION_PREFIX):] - results[binary_to_hex(object_id_binary)] = ( - ray.global_state._object_table(ray.ObjectID(object_id_binary))) - - return results - - -def _flush_finished_tasks_unsafe_shard(shard_index): - ray.worker.global_worker.check_connected() - - redis_client = ray.global_state.redis_clients[shard_index] - tasks = _task_table_shard(shard_index) - - keys_to_delete = [] - for task_id, task_info in tasks.items(): - if task_info["State"] == ray.experimental.state.TASK_STATUS_DONE: - keys_to_delete.append(TASK_PREFIX + - ray.utils.hex_to_binary(task_id)) - - num_task_keys_deleted = 0 - if len(keys_to_delete) > 0: - num_task_keys_deleted = redis_client.execute_command( - "del", *keys_to_delete) - - print("Deleted {} finished tasks from Redis shard." 
- .format(num_task_keys_deleted)) - - -def _flush_evicted_objects_unsafe_shard(shard_index): - ray.worker.global_worker.check_connected() - - redis_client = ray.global_state.redis_clients[shard_index] - objects = _object_table_shard(shard_index) - - keys_to_delete = [] - for object_id, object_info in objects.items(): - if object_info["ManagerIDs"] == []: - keys_to_delete.append(OBJECT_LOCATION_PREFIX + - ray.utils.hex_to_binary(object_id)) - keys_to_delete.append(OBJECT_INFO_PREFIX + - ray.utils.hex_to_binary(object_id)) - - num_object_keys_deleted = 0 - if len(keys_to_delete) > 0: - num_object_keys_deleted = redis_client.execute_command( - "del", *keys_to_delete) - - print("Deleted {} keys for evicted objects from Redis." - .format(num_object_keys_deleted)) - - -def flush_finished_tasks_unsafe(): - """This removes some critical state from the Redis shards. - - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the metadata for finished tasks. This can be used to - try to address out-of-memory errors caused by the accumulation of metadata - in Redis. However, after running this command, fault tolerance will most - likely not work. - """ - ray.worker.global_worker.check_connected() - - for shard_index in range(len(ray.global_state.redis_clients)): - _flush_finished_tasks_unsafe_shard(shard_index) - - -def flush_evicted_objects_unsafe(): - """This removes some critical state from the Redis shards. - - In a multitenant environment, this will flush metadata for all jobs, which - may be undesirable. - - This removes all of the metadata for objects that have been evicted. This - can be used to try to address out-of-memory errors caused by the - accumulation of metadata in Redis. However, after running this command, - fault tolerance will most likely not work. - """ - ray.worker.global_worker.check_connected() - - for shard_index in range(len(ray.global_state.redis_clients)): - _flush_evicted_objects_unsafe_shard(shard_index) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index cc6432cbc8de..09a154d7b548 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -37,8 +37,7 @@ class Monitor(object): def __init__(self, redis_address, autoscaling_config, redis_password=None): # Initialize the Redis clients. - self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state( + ray.state.state._initialize_global_state( args.redis_address, redis_password=redis_password) self.redis = ray.services.create_redis_client( redis_address, password=redis_password) @@ -149,7 +148,7 @@ def _xray_clean_up_entries_for_driver(self, driver_id): xray_object_table_prefix = ( ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) - task_table_objects = self.state.task_table() + task_table_objects = ray.tasks() driver_id_hex = binary_to_hex(driver_id) driver_task_id_bins = set() for task_id_hex, task_info in task_table_objects.items(): @@ -161,7 +160,7 @@ def _xray_clean_up_entries_for_driver(self, driver_id): driver_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. 
- object_table_objects = self.state.object_table() + object_table_objects = ray.objects() driver_object_id_bins = set() for object_id, _ in object_table_objects.items(): task_id_bin = ray._raylet.compute_task_id(object_id).binary() @@ -171,13 +170,13 @@ def _xray_clean_up_entries_for_driver(self, driver_id): def to_shard_index(id_bin): if len(id_bin) == ray.TaskID.size(): return binary_to_task_id(id_bin).redis_shard_hash() % len( - self.state.redis_clients) + ray.state.state.redis_clients) else: return binary_to_object_id(id_bin).redis_shard_hash() % len( - self.state.redis_clients) + ray.state.state.redis_clients) # Form the redis keys to delete. - sharded_keys = [[] for _ in range(len(self.state.redis_clients))] + sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))] for task_id_bin in driver_task_id_bins: sharded_keys[to_shard_index(task_id_bin)].append( xray_task_table_prefix + task_id_bin) @@ -190,7 +189,7 @@ def to_shard_index(id_bin): keys = sharded_keys[shard_index] if len(keys) == 0: continue - redis = self.state.redis_clients[shard_index] + redis = ray.state.state.redis_clients[shard_index] num_deleted = redis.delete(*keys) logger.info("Monitor: " "Removed {} dead redis entries of the " @@ -256,7 +255,7 @@ def process_messages(self, max_messages=10000): message_handler(channel, data) def update_raylet_map(self): - all_raylet_nodes = self.state.client_table() + all_raylet_nodes = ray.nodes() self.raylet_id_to_ip_map = {} for raylet_info in all_raylet_nodes: client_id = (raylet_info.get("DBClientID") diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 9ed667b59671..5a0529a51c5d 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -746,7 +746,7 @@ def timeline(redis_address): ray.init(redis_address=redis_address) time = datetime.today().strftime("%Y-%m-%d_%H-%M-%S") filename = "/tmp/ray-timeline-{}.json".format(time) - ray.global_state.chrome_tracing_dump(filename=filename) + ray.timeline(filename=filename) size = os.path.getsize(filename) logger.info("Trace file written to {} ({} bytes).".format(filename, size)) logger.info( diff --git a/python/ray/services.py b/python/ray/services.py index 2e9759428154..7dc594963cec 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -101,7 +101,7 @@ def get_address_info_from_redis_helper(redis_address, # Redis) must have run "CONFIG SET protected-mode no". redis_client = create_redis_client(redis_address, password=redis_password) - client_table = ray.experimental.state.parse_client_table(redis_client) + client_table = ray.state._parse_client_table(redis_client) if len(client_table) == 0: raise Exception( "Redis has started but no raylets have registered yet.") diff --git a/python/ray/experimental/state.py b/python/ray/state.py similarity index 84% rename from python/ray/experimental/state.py rename to python/ray/state.py index 51b36dc83fc7..6b2c8a4ef8bc 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/state.py @@ -4,6 +4,7 @@ from collections import defaultdict import json +import logging import sys import time @@ -17,8 +18,10 @@ from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) +logger = logging.getLogger(__name__) -def parse_client_table(redis_client): + +def _parse_client_table(redis_client): """Read the client table. Args: @@ -128,11 +131,11 @@ def _check_connected(self): yet. 
""" if self.redis_client is None: - raise Exception("The ray.global_state API cannot be used before " + raise Exception("The ray global state API cannot be used before " "ray.init has been called.") if self.redis_clients is None: - raise Exception("The ray.global_state API cannot be used before " + raise Exception("The ray global state API cannot be used before " "ray.init has been called.") def disconnect(self): @@ -408,7 +411,7 @@ def client_table(self): """ self._check_connected() - return parse_client_table(self.redis_client) + return _parse_client_table(self.redis_client) def _profile_table(self, batch_id): """Get the profile events for a given batch of profile events. @@ -461,6 +464,7 @@ def _profile_table(self, batch_id): return profile_events def profile_table(self): + self._check_connected() profile_table_keys = self._keys( ray.gcs_utils.TablePrefix_PROFILE_string + "*") batch_identifiers_binary = [ @@ -561,6 +565,8 @@ def chrome_tracing_dump(self, filename=None): # TODO(rkn): This should support viewing just a window of time or a # limited number of events. + self._check_connected() + profile_table = self.profile_table() all_events = [] @@ -626,8 +632,10 @@ def chrome_tracing_object_transfer_dump(self, filename=None): If filename is not provided, this returns a list of profiling events. Each profile event is a dictionary. """ + self._check_connected() + client_id_to_address = {} - for client_info in ray.global_state.client_table(): + for client_info in self.client_table(): client_id_to_address[client_info["ClientID"]] = "{}:{}".format( client_info["NodeManagerAddress"], client_info["ObjectManagerPort"]) @@ -703,6 +711,8 @@ def chrome_tracing_object_transfer_dump(self, filename=None): def workers(self): """Get a dictionary mapping worker ID to worker information.""" + self._check_connected() + worker_keys = self.redis_client.keys("Worker*") workers_data = {} @@ -723,22 +733,6 @@ def workers(self): worker_info[b"stdout_file"]) return workers_data - def actors(self): - actor_keys = self.redis_client.keys("Actor:*") - actor_info = {} - for key in actor_keys: - info = self.redis_client.hgetall(key) - actor_id = key[len("Actor:"):] - assert len(actor_id) == ID_SIZE - actor_info[binary_to_hex(actor_id)] = { - "class_id": binary_to_hex(info[b"class_id"]), - "driver_id": binary_to_hex(info[b"driver_id"]), - "raylet_id": binary_to_hex(info[b"raylet_id"]), - "num_gpus": int(info[b"num_gpus"]), - "removed": decode(info[b"removed"]) == "True" - } - return actor_info - def _job_length(self): event_log_sets = self.redis_client.keys("event_log*") overall_smallest = sys.maxsize @@ -769,6 +763,8 @@ def cluster_resources(self): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + self._check_connected() + resources = defaultdict(int) clients = self.client_table() for client in clients: @@ -798,6 +794,8 @@ def available_resources(self): A dictionary mapping resource name to the total quantity of that resource in the cluster. """ + self._check_connected() + available_resources_by_id = {} subscribe_clients = [ @@ -899,6 +897,8 @@ def error_messages(self, driver_id=None): A dictionary mapping driver ID to a list of the error messages for that driver. 
""" + self._check_connected() + if driver_id is not None: assert isinstance(driver_id, ray.DriverID) return self._error_messages(driver_id) @@ -954,3 +954,194 @@ def actor_checkpoint_info(self, actor_id): entry.Timestamps(i) for i in range(num_checkpoints) ], } + + +class DeprecatedGlobalState(object): + """A class used to print errors when the old global state API is used.""" + + def object_table(self, object_id=None): + logger.warning( + "ray.global_state.object_table() is deprecated and will be " + "removed in a subsequent release. Use ray.objects() instead.") + return ray.objects(object_id=object_id) + + def task_table(self, task_id=None): + logger.warning( + "ray.global_state.task_table() is deprecated and will be " + "removed in a subsequent release. Use ray.tasks() instead.") + return ray.tasks(task_id=task_id) + + def function_table(self, function_id=None): + raise DeprecationWarning( + "ray.global_state.function_table() is deprecated.") + + def client_table(self): + logger.warning( + "ray.global_state.client_table() is deprecated and will be " + "removed in a subsequent release. Use ray.nodes() instead.") + return ray.nodes() + + def profile_table(self): + raise DeprecationWarning( + "ray.global_state.profile_table() is deprecated.") + + def chrome_tracing_dump(self, filename=None): + logger.warning( + "ray.global_state.chrome_tracing_dump() is deprecated and will be " + "removed in a subsequent release. Use ray.timeline() instead.") + return ray.timeline(filename=filename) + + def chrome_tracing_object_transfer_dump(self, filename=None): + logger.warning( + "ray.global_state.chrome_tracing_object_transfer_dump() is " + "deprecated and will be removed in a subsequent release. Use " + "ray.object_transfer_timeline() instead.") + return ray.object_transfer_timeline(filename=filename) + + def workers(self): + raise DeprecationWarning("ray.global_state.workers() is deprecated.") + + def cluster_resources(self): + logger.warning( + "ray.global_state.cluster_resources() is deprecated and will be " + "removed in a subsequent release. Use ray.cluster_resources() " + "instead.") + return ray.cluster_resources() + + def available_resources(self): + logger.warning( + "ray.global_state.available_resources() is deprecated and will be " + "removed in a subsequent release. Use ray.available_resources() " + "instead.") + return ray.available_resources() + + def error_messages(self, driver_id=None): + logger.warning( + "ray.global_state.error_messages() is deprecated and will be " + "removed in a subsequent release. Use ray.errors() " + "instead.") + return ray.errors(driver_id=driver_id) + + +state = GlobalState() +"""A global object used to access the cluster's global state.""" + +global_state = DeprecatedGlobalState() + + +def nodes(): + """Get a list of the nodes in the cluster. + + Returns: + Information about the Ray clients in the cluster. + """ + return state.client_table() + + +def tasks(task_id=None): + """Fetch and parse the task table information for one or more task IDs. + + Args: + task_id: A hex string of the task ID to fetch information about. If + this is None, then the task object table is fetched. + + Returns: + Information from the task table. + """ + return state.task_table(task_id=task_id) + + +def objects(object_id=None): + """Fetch and parse the object table info for one or more object IDs. + + Args: + object_id: An object ID to fetch information about. If this is None, + then the entire object table is fetched. + + Returns: + Information from the object table. 
+ """ + return state.object_table(object_id=object_id) + + +def timeline(filename=None): + """Return a list of profiling events that can viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. + """ + return state.chrome_tracing_dump(filename=filename) + + +def object_transfer_timeline(filename=None): + """Return a list of transfer events that can viewed as a timeline. + + To view this information as a timeline, simply dump it as a json file by + passing in "filename" or using using json.dump, and then load go to + chrome://tracing in the Chrome web browser and load the dumped file. Make + sure to enable "Flow events" in the "View Options" menu. + + Args: + filename: If a filename is provided, the timeline is dumped to that + file. + + Returns: + If filename is not provided, this returns a list of profiling events. + Each profile event is a dictionary. + """ + return state.chrome_tracing_object_transfer_dump(filename=filename) + + +def cluster_resources(): + """Get the current total cluster resources. + + Note that this information can grow stale as nodes are added to or removed + from the cluster. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + return state.cluster_resources() + + +def available_resources(): + """Get the current available cluster resources. + + This is different from `cluster_resources` in that this will return idle + (available) resources rather than total resources. + + Note that this information can grow stale as tasks start and finish. + + Returns: + A dictionary mapping resource name to the total quantity of that + resource in the cluster. + """ + return state.available_resources() + + +def errors(include_cluster_errors=True): + """Get error messages from the cluster. + + Args: + include_cluster_errors: True if we should include error messages for + all drivers, and false if we should only include error messages for + this specific driver. + + Returns: + Error messages pushed from the cluster. 
+ """ + worker = ray.worker.global_worker + error_messages = state.error_messages(driver_id=worker.task_driver_id) + if include_cluster_errors: + error_messages += state.error_messages(driver_id=ray.DriverID.nil()) + return error_messages diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index a7ed3e14a89a..703c3a1420ed 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -141,7 +141,7 @@ def _wait_for_node(self, node, timeout=30): start_time = time.time() while time.time() - start_time < timeout: - clients = ray.experimental.state.parse_client_table(redis_client) + clients = ray.state._parse_client_table(redis_client) object_store_socket_names = [ client["ObjectStoreSocketName"] for client in clients ] @@ -174,7 +174,7 @@ def wait_for_nodes(self, timeout=30): start_time = time.time() while time.time() - start_time < timeout: - clients = ray.experimental.state.parse_client_table(redis_client) + clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients if client["EntryType"] == EntryType.INSERTION diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index d7da081fd18c..dd726e00f27b 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -2439,7 +2439,7 @@ def save_checkpoint(self, actor_id, checkpoint_context): assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False # Check that checkpointing errors were pushed to the driver. - errors = ray.error_info() + errors = ray.errors() assert len(errors) > 0 for error in errors: # An error for the actor process dying may also get pushed. @@ -2483,7 +2483,7 @@ def load_checkpoint(self, actor_id, checkpoints): assert ray.get(actor.was_resumed_from_checkpoint.remote()) is False # Check that checkpointing errors were pushed to the driver. - errors = ray.error_info() + errors = ray.errors() assert len(errors) > 0 for error in errors: # An error for the actor process dying may also get pushed. diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index d6eebe1517bb..50aeca025362 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -935,7 +935,7 @@ def f(block, accepted_resources): stop_time = time.time() + 10 correct_available_resources = False while time.time() < stop_time: - if ray.global_state.available_resources() == { + if ray.available_resources() == { "CPU": 2.0, "GPU": 2.0, "Custom": 2.0, @@ -1176,7 +1176,7 @@ def f(): if time.time() - start_time > timeout_seconds: raise Exception("Timed out while waiting for information in " "profile table.") - profile_data = ray.global_state.chrome_tracing_dump() + profile_data = ray.timeline() event_types = {event["cat"] for event in profile_data} expected_types = [ "worker_idle", @@ -1252,7 +1252,7 @@ def f(x): # The profiling information only flushes once every second. time.sleep(1.1) - transfer_dump = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_dump = ray.object_transfer_timeline() # Make sure the transfer dump can be serialized with JSON. json.loads(json.dumps(transfer_dump)) assert len(transfer_dump) >= num_nodes**2 @@ -1559,12 +1559,12 @@ def run_one_test(actors, local_only, delete_creating_tasks): # Case3: These cases test the deleting creating tasks for the object. 
(a, b, c) = run_one_test(actors, False, False) - task_table = ray.global_state.task_table() + task_table = ray.tasks() for obj in [a, b, c]: assert ray._raylet.compute_task_id(obj).hex() in task_table (a, b, c) = run_one_test(actors, False, True) - task_table = ray.global_state.task_table() + task_table = ray.tasks() for obj in [a, b, c]: assert ray._raylet.compute_task_id(obj).hex() not in task_table @@ -2026,7 +2026,7 @@ def run_lots_of_tasks(): results.append(run_on_0_2.remote()) return names, results - client_table = ray.global_state.client_table() + client_table = ray.nodes() store_names = [] store_names += [ client["ObjectStoreSocketName"] for client in client_table @@ -2214,13 +2214,13 @@ def test_zero_capacity_deletion_semantics(shutdown_only): ray.init(num_cpus=2, num_gpus=1, resources={"test_resource": 1}) def test(): - resources = ray.global_state.available_resources() + resources = ray.available_resources() MAX_RETRY_ATTEMPTS = 5 retry_count = 0 while resources and retry_count < MAX_RETRY_ATTEMPTS: time.sleep(0.1) - resources = ray.global_state.available_resources() + resources = ray.available_resources() retry_count += 1 if retry_count >= MAX_RETRY_ATTEMPTS: @@ -2394,7 +2394,7 @@ def f(x): def wait_for_num_tasks(num_tasks, timeout=10): start_time = time.time() while time.time() - start_time < timeout: - if len(ray.global_state.task_table()) >= num_tasks: + if len(ray.tasks()) >= num_tasks: return time.sleep(0.1) raise Exception("Timed out while waiting for global state.") @@ -2403,7 +2403,7 @@ def wait_for_num_tasks(num_tasks, timeout=10): def wait_for_num_objects(num_objects, timeout=10): start_time = time.time() while time.time() - start_time < timeout: - if len(ray.global_state.object_table()) >= num_objects: + if len(ray.objects()) >= num_objects: return time.sleep(0.1) raise Exception("Timed out while waiting for global state.") @@ -2414,31 +2414,27 @@ def wait_for_num_objects(num_objects, timeout=10): reason="New GCS API doesn't have a Python API yet.") def test_global_state_api(shutdown_only): with pytest.raises(Exception): - ray.global_state.object_table() + ray.objects() with pytest.raises(Exception): - ray.global_state.task_table() + ray.tasks() with pytest.raises(Exception): - ray.global_state.client_table() - - with pytest.raises(Exception): - ray.global_state.function_table() + ray.nodes() ray.init(num_cpus=5, num_gpus=3, resources={"CustomResource": 1}) resources = {"CPU": 5, "GPU": 3, "CustomResource": 1} - assert ray.global_state.cluster_resources() == resources + assert ray.cluster_resources() == resources - assert ray.global_state.object_table() == {} + assert ray.objects() == {} - driver_id = ray.experimental.state.binary_to_hex( - ray.worker.global_worker.worker_id) + driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id) driver_task_id = ray.worker.global_worker.current_task_id.hex() # One task is put in the task table which corresponds to this driver. 
wait_for_num_tasks(1) - task_table = ray.global_state.task_table() + task_table = ray.tasks() assert len(task_table) == 1 assert driver_task_id == list(task_table.keys())[0] task_spec = task_table[driver_task_id]["TaskSpec"] @@ -2451,7 +2447,7 @@ def test_global_state_api(shutdown_only): assert task_spec["FunctionID"] == nil_id_hex assert task_spec["ReturnObjectIDs"] == [] - client_table = ray.global_state.client_table() + client_table = ray.nodes() node_ip_address = ray.worker.global_worker.node_ip_address assert len(client_table) == 1 @@ -2466,24 +2462,19 @@ def f(*xs): # Wait for one additional task to complete. wait_for_num_tasks(1 + 1) - task_table = ray.global_state.task_table() + task_table = ray.tasks() assert len(task_table) == 1 + 1 task_id_set = set(task_table.keys()) task_id_set.remove(driver_task_id) task_id = list(task_id_set)[0] - function_table = ray.global_state.function_table() task_spec = task_table[task_id]["TaskSpec"] assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [1, "hi", x_id] assert task_spec["DriverID"] == driver_id assert task_spec["ReturnObjectIDs"] == [result_id] - function_table_entry = function_table[task_spec["FunctionID"]] - assert function_table_entry["Name"] == "ray.tests.test_basic.f" - assert function_table_entry["DriverID"] == driver_id - assert function_table_entry["Module"] == "ray.tests.test_basic" - assert task_table[task_id] == ray.global_state.task_table(task_id) + assert task_table[task_id] == ray.tasks(task_id) # Wait for two objects, one for the x_id and one for result_id. wait_for_num_objects(2) @@ -2492,7 +2483,7 @@ def wait_for_object_table(): timeout = 10 start_time = time.time() while time.time() - start_time < timeout: - object_table = ray.global_state.object_table() + object_table = ray.objects() tables_ready = (object_table[x_id]["ManagerIDs"] is not None and object_table[result_id]["ManagerIDs"] is not None) if tables_ready: @@ -2501,11 +2492,11 @@ def wait_for_object_table(): raise Exception("Timed out while waiting for object table to " "update.") - object_table = ray.global_state.object_table() + object_table = ray.objects() assert len(object_table) == 2 - assert object_table[x_id] == ray.global_state.object_table(x_id) - object_table_entry = ray.global_state.object_table(result_id) + assert object_table[x_id] == ray.objects(x_id) + object_table_entry = ray.objects(result_id) assert object_table[result_id] == object_table_entry @@ -2611,14 +2602,6 @@ def f(): while len(worker_ids) != num_workers: worker_ids = set(ray.get([f.remote() for _ in range(10)])) - worker_info = ray.global_state.workers() - assert len(worker_info) >= num_workers - for worker_id, info in worker_info.items(): - assert "node_ip_address" in info - assert "plasma_store_socket" in info - assert "stderr_file" in info - assert "stdout_file" in info - def test_specific_driver_id(): dummy_driver_id = ray.DriverID(b"00112233445566778899") @@ -2816,7 +2799,7 @@ def test_socket_dir_not_existing(shutdown_only): def test_raylet_is_robust_to_random_messages(ray_start_regular): node_manager_address = None node_manager_port = None - for client in ray.global_state.client_table(): + for client in ray.nodes(): if "NodeManagerAddress" in client: node_manager_address = client["NodeManagerAddress"] node_manager_port = client["NodeManagerPort"] @@ -2908,7 +2891,7 @@ def test_shutdown_disconnect_global_state(): ray.shutdown() with pytest.raises(Exception) as e: - ray.global_state.object_table() + ray.objects() assert str(e.value).endswith("ray.init has been 
called.") @@ -2922,8 +2905,8 @@ def test_redis_lru_with_set(ray_start_object_store_memory): removed = False start_time = time.time() while time.time() < start_time + 10: - if ray.global_state.redis_clients[0].delete(b"OBJECT" + - x_id.binary()) == 1: + if ray.state.state.redis_clients[0].delete(b"OBJECT" + + x_id.binary()) == 1: removed = True break assert removed diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py index 6f39839301c9..ea647adf1207 100644 --- a/python/ray/tests/test_dynres.py +++ b/python/ray/tests/test_dynres.py @@ -23,8 +23,8 @@ def set_res(resource_name, resource_capacity): ray.get(set_res.remote(res_name, res_capacity)) - available_res = ray.global_state.available_resources() - cluster_res = ray.global_state.cluster_resources() + available_res = ray.available_resources() + cluster_res = ray.cluster_resources() assert available_res[res_name] == res_capacity assert cluster_res[res_name] == res_capacity @@ -43,8 +43,8 @@ def delete_res(resource_name): ray.get(delete_res.remote(res_name)) - available_res = ray.global_state.available_resources() - cluster_res = ray.global_state.cluster_resources() + available_res = ray.available_resources() + cluster_res = ray.cluster_resources() assert res_name not in available_res assert res_name not in cluster_res @@ -69,7 +69,7 @@ def f(): oid = remote_task.remote() # This is infeasible ray.get(set_res.remote(res_name, res_capacity)) # Now should be feasible - available_res = ray.global_state.available_resources() + available_res = ray.available_resources() assert available_res[res_name] == res_capacity successful, unsuccessful = ray.wait([oid], timeout=1) @@ -88,7 +88,7 @@ def test_dynamic_res_updation_clientid(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - target_clientid = ray.global_state.client_table()[1]["ClientID"] + target_clientid = ray.nodes()[1]["ClientID"] @ray.remote def set_res(resource_name, resource_capacity, client_id): @@ -102,7 +102,7 @@ def set_res(resource_name, resource_capacity, client_id): new_capacity = res_capacity + 1 ray.get(set_res.remote(res_name, new_capacity, target_clientid)) - target_client = next(client for client in ray.global_state.client_table() + target_client = next(client for client in ray.nodes() if client["ClientID"] == target_clientid) resources = target_client["Resources"] @@ -122,7 +122,7 @@ def test_dynamic_res_creation_clientid(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - target_clientid = ray.global_state.client_table()[1]["ClientID"] + target_clientid = ray.nodes()[1]["ClientID"] @ray.remote def set_res(resource_name, resource_capacity, res_client_id): @@ -130,7 +130,7 @@ def set_res(resource_name, resource_capacity, res_client_id): resource_name, resource_capacity, client_id=res_client_id) ray.get(set_res.remote(res_name, res_capacity, target_clientid)) - target_client = next(client for client in ray.global_state.client_table() + target_client = next(client for client in ray.nodes() if client["ClientID"] == target_clientid) resources = target_client["Resources"] @@ -152,9 +152,7 @@ def test_dynamic_res_creation_clientid_multiple(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - target_clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + target_clientids = [client["ClientID"] for client in ray.nodes()] @ray.remote def set_res(resource_name, resource_capacity, res_client_id): @@ -172,9 +170,8 @@ def set_res(resource_name, resource_capacity, res_client_id): while 
time.time() - start_time < TIMEOUT and not success: resources_created = [] for cid in target_clientids: - target_client = next(client - for client in ray.global_state.client_table() - if client["ClientID"] == cid) + target_client = next( + client for client in ray.nodes() if client["ClientID"] == cid) resources = target_client["Resources"] resources_created.append(resources[res_name] == res_capacity) success = all(resources_created) @@ -196,7 +193,7 @@ def test_dynamic_res_deletion_clientid(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - target_clientid = ray.global_state.client_table()[1]["ClientID"] + target_clientid = ray.nodes()[1]["ClientID"] # Launch the delete task @ray.remote @@ -206,10 +203,10 @@ def delete_res(resource_name, res_client_id): ray.get(delete_res.remote(res_name, target_clientid)) - target_client = next(client for client in ray.global_state.client_table() + target_client = next(client for client in ray.nodes() if client["ClientID"] == target_clientid) resources = target_client["Resources"] - print(ray.global_state.cluster_resources()) + print(ray.cluster_resources()) assert res_name not in resources @@ -228,9 +225,7 @@ def test_dynamic_res_creation_scheduler_consistency(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] @ray.remote def set_res(resource_name, resource_capacity, res_client_id): @@ -267,9 +262,7 @@ def test_dynamic_res_deletion_scheduler_consistency(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] @ray.remote def delete_res(resource_name, res_client_id): @@ -284,7 +277,7 @@ def set_res(resource_name, resource_capacity, res_client_id): # Create the resource on node1 target_clientid = clientids[1] ray.get(set_res.remote(res_name, res_capacity, target_clientid)) - assert ray.global_state.cluster_resources()[res_name] == res_capacity + assert ray.cluster_resources()[res_name] == res_capacity # Delete the resource ray.get(delete_res.remote(res_name, target_clientid)) @@ -322,9 +315,7 @@ def test_dynamic_res_concurrent_res_increment(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] target_clientid = clientids[1] @ray.remote @@ -334,7 +325,7 @@ def set_res(resource_name, resource_capacity, res_client_id): # Create the resource on node 1 ray.get(set_res.remote(res_name, res_capacity, target_clientid)) - assert ray.global_state.cluster_resources()[res_name] == res_capacity + assert ray.cluster_resources()[res_name] == res_capacity # Task to hold the resource till the driver signals to finish @ray.remote @@ -376,7 +367,7 @@ def test_func(): }) # This should be infeasible successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) assert unsuccessful # The task did not complete because it's infeasible - assert ray.global_state.available_resources()[res_name] == updated_capacity + assert ray.available_resources()[res_name] == updated_capacity def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): @@ -403,9 +394,7 @@ def test_dynamic_res_concurrent_res_decrement(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = 
[ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] target_clientid = clientids[1] @ray.remote @@ -415,7 +404,7 @@ def set_res(resource_name, resource_capacity, res_client_id): # Create the resource on node 1 ray.get(set_res.remote(res_name, res_capacity, target_clientid)) - assert ray.global_state.cluster_resources()[res_name] == res_capacity + assert ray.cluster_resources()[res_name] == res_capacity # Task to hold the resource till the driver signals to finish @ray.remote @@ -457,7 +446,7 @@ def test_func(): }) # This should be infeasible successful, unsuccessful = ray.wait([task_3], timeout=TIMEOUT_DURATION) assert unsuccessful # The task did not complete because it's infeasible - assert ray.global_state.available_resources()[res_name] == updated_capacity + assert ray.available_resources()[res_name] == updated_capacity def test_dynamic_res_concurrent_res_delete(ray_start_cluster): @@ -482,9 +471,7 @@ def test_dynamic_res_concurrent_res_delete(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] target_clientid = clientids[1] @ray.remote @@ -499,7 +486,7 @@ def delete_res(resource_name, res_client_id): # Create the resource on node 1 ray.get(set_res.remote(res_name, res_capacity, target_clientid)) - assert ray.global_state.cluster_resources()[res_name] == res_capacity + assert ray.cluster_resources()[res_name] == res_capacity # Task to hold the resource till the driver signals to finish @ray.remote @@ -534,7 +521,7 @@ def test_func(): args=[], resources={res_name: 1}) # This should be infeasible successful, unsuccessful = ray.wait([task_2], timeout=TIMEOUT_DURATION) assert unsuccessful # The task did not complete because it's infeasible - assert res_name not in ray.global_state.available_resources() + assert res_name not in ray.available_resources() def test_dynamic_res_creation_stress(ray_start_cluster): @@ -553,9 +540,7 @@ def test_dynamic_res_creation_stress(ray_start_cluster): ray.init(redis_address=cluster.redis_address) - clientids = [ - client["ClientID"] for client in ray.global_state.client_table() - ] + clientids = [client["ClientID"] for client in ray.nodes()] target_clientid = clientids[1] @ray.remote @@ -578,7 +563,7 @@ def delete_res(resource_name, res_client_id): start_time = time.time() while time.time() - start_time < TIMEOUT and not success: - resources = ray.global_state.cluster_resources() + resources = ray.cluster_resources() all_resources_created = [] for i in range(0, NUM_RES_TO_CREATE): all_resources_created.append(str(i) in resources) diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 6a782ee726da..51b906695c2d 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -164,7 +164,7 @@ def get_val(self): return 1 # There should be no errors yet. - assert len(ray.error_info()) == 0 + assert len(ray.errors()) == 0 # Create an actor. 
foo = Foo.remote() @@ -376,8 +376,9 @@ class Actor(object): a = Actor.remote() a.__ray_terminate__.remote() time.sleep(1) - assert len(ray.error_info()) == 0, ( - "Should not have propogated an error - {}".format(ray.error_info())) + assert len( + ray.errors()) == 0, ("Should not have propogated an error - {}".format( + ray.errors())) @pytest.mark.skip("This test does not work yet.") @@ -653,7 +654,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes): cluster = ray_start_cluster_2_nodes cluster.wait_for_nodes() - client_ids = {item["ClientID"] for item in ray.global_state.client_table()} + client_ids = {item["ClientID"] for item in ray.nodes()} # Try to make sure that the monitor has received at least one heartbeat # from the node. diff --git a/python/ray/tests/test_global_state.py b/python/ray/tests/test_global_state.py index bc82eb8590c3..db71fc69c73b 100644 --- a/python/ray/tests/test_global_state.py +++ b/python/ray/tests/test_global_state.py @@ -18,8 +18,8 @@ reason="Timeout package not installed; skipping test that may hang.") @pytest.mark.timeout(10) def test_replenish_resources(ray_start_regular): - cluster_resources = ray.global_state.cluster_resources() - available_resources = ray.global_state.available_resources() + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() assert cluster_resources == available_resources @ray.remote @@ -30,7 +30,7 @@ def cpu_task(): resources_reset = False while not resources_reset: - available_resources = ray.global_state.available_resources() + available_resources = ray.available_resources() resources_reset = (cluster_resources == available_resources) assert resources_reset @@ -40,7 +40,7 @@ def cpu_task(): reason="Timeout package not installed; skipping test that may hang.") @pytest.mark.timeout(10) def test_uses_resources(ray_start_regular): - cluster_resources = ray.global_state.cluster_resources() + cluster_resources = ray.cluster_resources() @ray.remote def cpu_task(): @@ -50,7 +50,7 @@ def cpu_task(): resource_used = False while not resource_used: - available_resources = ray.global_state.available_resources() + available_resources = ray.available_resources() resource_used = available_resources.get( "CPU", 0) == cluster_resources.get("CPU", 0) - 1 @@ -64,17 +64,17 @@ def cpu_task(): def test_add_remove_cluster_resources(ray_start_cluster_head): """Tests that Global State API is consistent with actual cluster.""" cluster = ray_start_cluster_head - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 nodes = [] nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 6 + assert ray.cluster_resources()["CPU"] == 6 diff --git a/python/ray/tests/test_monitors.py b/python/ray/tests/test_monitors.py index 36ed55a52474..9eebe7e45087 100644 --- a/python/ray/tests/test_monitors.py +++ b/python/ray/tests/test_monitors.py @@ -30,17 +30,16 @@ def _test_cleanup_on_driver_exit(num_redis_shards): time.sleep(2) def StateSummary(): - obj_tbl_len = len(ray.global_state.object_table()) - task_tbl_len = len(ray.global_state.task_table()) - func_tbl_len = 
len(ray.global_state.function_table()) - return obj_tbl_len, task_tbl_len, func_tbl_len + obj_tbl_len = len(ray.objects()) + task_tbl_len = len(ray.tasks()) + return obj_tbl_len, task_tbl_len def Driver(success): success.value = True # Start driver. ray.init(redis_address=redis_address) summary_start = StateSummary() - if (0, 1) != summary_start[:2]: + if (0, 1) != summary_start: success.value = False # Two new objects. @@ -54,7 +53,7 @@ def f(): # 1 new function. attempts = 0 - while (2, 1, summary_start[2]) != StateSummary(): + while (2, 1) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: @@ -63,7 +62,7 @@ def f(): ray.get(f.remote()) attempts = 0 - while (4, 2, summary_start[2] + 1) != StateSummary(): + while (4, 2) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: @@ -83,12 +82,12 @@ def f(): # Check that objects, tasks, and functions are cleaned up. ray.init(redis_address=redis_address) attempts = 0 - while (0, 1) != StateSummary()[:2]: + while (0, 1) != StateSummary(): time.sleep(0.1) attempts += 1 if attempts == max_attempts_before_failing: break - assert (0, 1) == StateSummary()[:2] + assert (0, 1) == StateSummary() ray.shutdown() subprocess.Popen(["ray", "stop"]).wait() diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index a963f6b15ea1..07f0d621c483 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -19,7 +19,7 @@ def test_error_isolation(call_ray_start): ray.init(redis_address=redis_address) # There shouldn't be any errors yet. - assert len(ray.error_info()) == 0 + assert len(ray.errors()) == 0 error_string1 = "error_string1" error_string2 = "error_string2" @@ -33,13 +33,13 @@ def f(): ray.get(f.remote()) # Wait for the error to appear in Redis. - while len(ray.error_info()) != 1: + while len(ray.errors()) != 1: time.sleep(0.1) print("Waiting for error to appear.") # Make sure we got the error. - assert len(ray.error_info()) == 1 - assert error_string1 in ray.error_info()[0]["message"] + assert len(ray.errors()) == 1 + assert error_string1 in ray.errors()[0]["message"] # Start another driver and make sure that it does not receive this # error. Make the other driver throw an error, and make sure it @@ -51,7 +51,7 @@ def f(): ray.init(redis_address="{}") time.sleep(1) -assert len(ray.error_info()) == 0 +assert len(ray.errors()) == 0 @ray.remote def f(): @@ -62,12 +62,12 @@ def f(): except Exception as e: pass -while len(ray.error_info()) != 1: - print(len(ray.error_info())) +while len(ray.errors()) != 1: + print(len(ray.errors())) time.sleep(0.1) -assert len(ray.error_info()) == 1 +assert len(ray.errors()) == 1 -assert "{}" in ray.error_info()[0]["message"] +assert "{}" in ray.errors()[0]["message"] print("success") """.format(redis_address, error_string2, error_string2) @@ -78,8 +78,8 @@ def f(): # Make sure that the other error message doesn't show up for this # driver. 
- assert len(ray.error_info()) == 1 - assert error_string1 in ray.error_info()[0]["message"] + assert len(ray.errors()) == 1 + assert error_string1 in ray.errors()[0]["message"] def test_remote_function_isolation(call_ray_start): diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index e66a3799e25e..979f4728330f 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -52,10 +52,10 @@ def test_internal_config(ray_start_cluster_head): cluster.remove_node(worker) time.sleep(1) - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 time.sleep(2) - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 def test_wait_for_nodes(ray_start_cluster_head): @@ -70,12 +70,12 @@ def test_wait_for_nodes(ray_start_cluster_head): [cluster.remove_node(w) for w in workers] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 worker2 = cluster.add_node() cluster.wait_for_nodes() cluster.remove_node(worker2) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 def test_worker_plasma_store_failure(ray_start_cluster_head): diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index e02e3d9a7d6e..bbe47a7e47d0 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -80,7 +80,7 @@ def create_object(): # Wait for profiling information to be pushed to the profile table. time.sleep(1) - transfer_events = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_events = ray.object_transfer_timeline() # Make sure that each object was transferred a reasonable number of times. for x_id in object_ids: @@ -160,7 +160,7 @@ def set_weights(self, x): # Wait for profiling information to be pushed to the profile table. time.sleep(1) - transfer_events = ray.global_state.chrome_tracing_object_transfer_dump() + transfer_events = ray.object_transfer_timeline() # Make sure that each object was transferred a reasonable number of times. 
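The transfer events checked in this test come from ray.object_transfer_timeline(), the renamed chrome_tracing_object_transfer_dump(). A small usage sketch, illustrative only and not part of the patch (the output path is hypothetical):

import json
import ray

ray.init()
# ... run tasks that move objects between nodes ...

# Task-level profiling events, in a chrome://tracing compatible format.
ray.timeline(filename="/tmp/ray-timeline.json")

# Object transfer events can also be returned in-process; they are plain
# dicts, so they round-trip through JSON. When viewing a dumped file in
# chrome://tracing, enable "Flow events" under "View Options".
transfer_events = ray.object_transfer_timeline()
json.dumps(transfer_events)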
for x_id in object_ids: diff --git a/python/ray/tests/test_stress.py b/python/ray/tests/test_stress.py index 4f94e2310b7c..1135d71011bf 100644 --- a/python/ray/tests/test_stress.py +++ b/python/ray/tests/test_stress.py @@ -393,7 +393,7 @@ def wait_for_errors(error_check): errors = [] time_left = 100 while time_left > 0: - errors = ray.error_info() + errors = ray.errors() if error_check(errors): break time_left -= 1 diff --git a/python/ray/tests/utils.py b/python/ray/tests/utils.py index 22146e89fa65..bd9291d8fa81 100644 --- a/python/ray/tests/utils.py +++ b/python/ray/tests/utils.py @@ -84,7 +84,7 @@ def run_string_as_driver_nonblocking(driver_script): def relevant_errors(error_type): - return [info for info in ray.error_info() if info["type"] == error_type] + return [info for info in ray.errors() if info["type"] == error_type] def wait_for_errors(error_type, num_errors, timeout=10): diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index e4938ac609fa..548e092cfb1d 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -356,7 +356,7 @@ def _return_resources(self, resources): def _update_avail_resources(self, num_retries=5): for i in range(num_retries): try: - resources = ray.global_state.cluster_resources() + resources = ray.cluster_resources() except Exception: # TODO(rliaw): Remove this when local mode is fixed. # https://github.com/ray-project/ray/issues/4147 diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 4f962299d51a..e00e5da371c5 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -71,7 +71,7 @@ def test_counting_resources(start_connected_cluster): """Tests that Tune accounting is consistent with actual cluster.""" cluster = start_connected_cluster nodes = [] - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 runner = TrialRunner(BasicVariantGenerator()) kwargs = {"stopping_criterion": {"training_iteration": 10}} @@ -82,17 +82,17 @@ def test_counting_resources(start_connected_cluster): runner.step() # run 1 nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 2 + assert ray.cluster_resources()["CPU"] == 2 cluster.remove_node(nodes.pop()) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 runner.step() # run 2 assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 1 for i in range(5): nodes += [cluster.add_node(num_cpus=1)] cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 6 + assert ray.cluster_resources()["CPU"] == 6 runner.step() # 1 result assert sum(t.status == Trial.RUNNING for t in runner.get_trials()) == 2 @@ -120,7 +120,7 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster): cluster.remove_node(node) cluster.add_node(num_cpus=1) cluster.wait_for_nodes() - assert ray.global_state.cluster_resources()["CPU"] == 1 + assert ray.cluster_resources()["CPU"] == 1 for i in range(3): runner.step() diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 19930559ce7d..a9bf8e3239c6 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1532,7 +1532,7 @@ def testFailureRecoveryNodeRemoval(self): runner.add_trial(Trial("__fake", **kwargs)) trials = 
runner.get_trials() - with patch("ray.global_state.cluster_resources") as resource_mock: + with patch("ray.cluster_resources") as resource_mock: resource_mock.return_value = {"CPU": 1, "GPU": 1} runner.step() self.assertEqual(trials[0].status, Trial.RUNNING) diff --git a/python/ray/worker.py b/python/ray/worker.py index 5feb71344bce..c886159aafec 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -25,7 +25,6 @@ import ray.cloudpickle as pickle import ray.experimental.signal as ray_signal import ray.experimental.no_return -import ray.experimental.state as state import ray.gcs_utils import ray.memory_monitor as memory_monitor import ray.node @@ -35,6 +34,7 @@ import ray.serialization as serialization import ray.services as services import ray.signature +import ray.state from ray import ( ActorHandleID, @@ -1108,8 +1108,6 @@ def get_webui_url(): per worker process. """ -global_state = state.GlobalState() - _global_node = None """ray.node.Node: The global node object that is created by ray.init().""" @@ -1134,14 +1132,6 @@ def print_failed_task(task_status): task_status["error_message"])) -def error_info(): - """Return information about failed tasks.""" - worker = global_worker - worker.check_connected() - return (global_state.error_messages(driver_id=worker.task_driver_id) + - global_state.error_messages(driver_id=DriverID.nil())) - - def _initialize_serialization(driver_id, worker=global_worker): """Initialize the serialization library. @@ -1488,7 +1478,7 @@ def shutdown(exiting_interpreter=False): disconnect() # Disconnect global state from GCS. - global_state.disconnect() + ray.state.state.disconnect() # Shut down the Ray processes. global _global_node @@ -1644,7 +1634,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): try: # Get the exports that occurred before the call to subscribe. - error_messages = global_state.error_messages(worker.task_driver_id) + error_messages = ray.errors(include_cluster_errors=False) for error_message in error_messages: logger.error(error_message) @@ -1774,7 +1764,7 @@ def connect(node, worker.lock = threading.RLock() # Create an object for interfacing with the global state. - global_state._initialize_global_state( + ray.state.state._initialize_global_state( node.redis_address, redis_password=node.redis_password) # Register the worker with Redis. @@ -1881,11 +1871,12 @@ def connect(node, ) # Add the driver task to the task table. - global_state._execute_command(driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command(driver_task.task_id(), + "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePubsub.RAYLET_TASK, + driver_task.task_id().binary(), + driver_task._serialized_raylet_task()) # Set the driver's current task ID to the task ID assigned to the # driver task. From 7a78e1e3209ee0c5eba284bdda79ea5bf37db674 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 26 May 2019 16:13:50 -0700 Subject: [PATCH 040/118] Install bazel in autoscaler development configs. 
(#4874) --- ci/long_running_tests/config.yaml | 1 + ci/stress_tests/stress_testing_config.yaml | 1 + python/ray/autoscaler/aws/development-example.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/ci/long_running_tests/config.yaml b/ci/long_running_tests/config.yaml index cbc7feb435af..b9667ae648bb 100644 --- a/ci/long_running_tests/config.yaml +++ b/ci/long_running_tests/config.yaml @@ -49,6 +49,7 @@ setup_commands: # - sudo apt-get update # - sudo apt-get install -y build-essential curl unzip # - git clone https://github.com/ray-project/ray || true + # - ray/ci/travis/install-bazel.sh # - cd ray/python; git checkout master; git pull; pip install -e . --verbose # Install nightly Ray wheels. - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/<>/ray-<>-cp36-cp36m-manylinux1_x86_64.whl diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index 07c27bab79b6..f71ae8f2dc18 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -98,6 +98,7 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # # Build Ray. # - git clone https://github.com/ray-project/ray || true + # - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl diff --git a/python/ray/autoscaler/aws/development-example.yaml b/python/ray/autoscaler/aws/development-example.yaml index 0986a48ecc05..539c28643faa 100644 --- a/python/ray/autoscaler/aws/development-example.yaml +++ b/python/ray/autoscaler/aws/development-example.yaml @@ -94,6 +94,7 @@ setup_commands: - echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.bashrc # Build Ray. - git clone https://github.com/ray-project/ray || true + - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 - cd ray/python; pip install -e . --verbose From 574e1c76959bf1b0a91ec65d2c21380b549c6900 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 27 May 2019 13:23:17 -0700 Subject: [PATCH 041/118] [tune] Fix up Ax Search and Examples (#4851) * update Ax for cleaner API * docs update --- doc/source/tune-searchalg.rst | 4 ++- python/ray/tune/examples/ax_example.py | 7 +++-- python/ray/tune/suggest/ax.py | 42 ++++++++++---------------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/doc/source/tune-searchalg.rst b/doc/source/tune-searchalg.rst index 0a2bf491a676..2dae8eaf4abe 100644 --- a/doc/source/tune-searchalg.rst +++ b/doc/source/tune-searchalg.rst @@ -171,7 +171,9 @@ This algorithm requires specifying a search space and objective. You can use `Ax .. code-block:: python - tune.run(... , search_alg=AxSearch(parameter_dicts, ... )) + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( ... ) + tune.run(... , search_alg=AxSearch(client)) An example of this can be found in `ax_example.py `__. 
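Pulled together, the new wiring looks roughly like the sketch below; the parameter names, bounds, metric name, trainable, and sample count are placeholders, and the real runnable version is the updated ax_example.py in the next hunk. The objective name must match a metric that the trainable reports.

.. code-block:: python

    from ax.service.ax_client import AxClient
    from ray import tune
    from ray.tune.suggest.ax import AxSearch

    # The search space and objective now live on the AxClient, not on AxSearch.
    client = AxClient(enforce_sequential_optimization=False)
    client.create_experiment(
        parameters=[
            {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
            {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
        ],
        objective_name="my_metric",
        minimize=True,
    )

    algo = AxSearch(client, max_concurrent=4)
    tune.run(my_trainable, search_alg=algo, num_samples=20)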
diff --git a/python/ray/tune/examples/ax_example.py b/python/ray/tune/examples/ax_example.py index 07bb7f79a1f3..8620986a26ea 100644 --- a/python/ray/tune/examples/ax_example.py +++ b/python/ray/tune/examples/ax_example.py @@ -51,11 +51,13 @@ def easy_objective(config, reporter): if __name__ == "__main__": import argparse + from ax.service.ax_client import AxClient parser = argparse.ArgumentParser() parser.add_argument( "--smoke-test", action="store_true", help="Finish quickly for testing") args, _ = parser.parse_known_args() + ray.init() config = { @@ -101,13 +103,14 @@ def easy_objective(config, reporter): "bounds": [0.0, 1.0], }, ] - algo = AxSearch( + client = AxClient(enforce_sequential_optimization=False) + client.create_experiment( parameters=parameters, objective_name="hartmann6", - max_concurrent=4, minimize=True, # Optional, defaults to False. parameter_constraints=["x1 + x2 <= 2.0"], # Optional. outcome_constraints=["l2norm <= 1.25"], # Optional. ) + algo = AxSearch(client, max_concurrent=4) scheduler = AsyncHyperBandScheduler(reward_attr="hartmann6") run(easy_objective, name="ax", search_alg=algo, **config) diff --git a/python/ray/tune/suggest/ax.py b/python/ray/tune/suggest/ax.py index a48852e84864..75b982d67087 100644 --- a/python/ray/tune/suggest/ax.py +++ b/python/ray/tune/suggest/ax.py @@ -6,16 +6,19 @@ import ax except ImportError: ax = None +import logging from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class AxSearch(SuggestionAlgorithm): """A wrapper around Ax to provide trial suggestions. - Requires Ax to be installed. - Ax is an open source tool from Facebook for configuring and - optimizing experiments. More information can be found in https://ax.dev/. + Requires Ax to be installed. Ax is an open source tool from + Facebook for configuring and optimizing experiments. More information + can be found in https://ax.dev/. Parameters: parameters (list[dict]): Parameters in the experiment search space. @@ -48,40 +51,27 @@ class AxSearch(SuggestionAlgorithm): >>> objective_name="hartmann6", max_concurrent=4) """ - def __init__(self, - parameters, - objective_name, - max_concurrent=10, - minimize=False, - parameter_constraints=None, - outcome_constraints=None, - **kwargs): + def __init__(self, ax_client, max_concurrent=10, **kwargs): assert ax is not None, "Ax must be installed!" - from ax.service import ax_client assert type(max_concurrent) is int and max_concurrent > 0 - self._ax = ax_client.AxClient(enforce_sequential_optimization=False) - self._ax.create_experiment( - name="ax", - parameters=parameters, - objective_name=objective_name, - minimize=minimize, - parameter_constraints=parameter_constraints or [], - outcome_constraints=outcome_constraints or [], - ) + self._ax = ax_client + exp = self._ax.experiment + self._objective_name = exp.optimization_config.objective.metric.name + if self._ax._enforce_sequential_optimization: + logger.warning("Detected sequential enforcement. 
Setting max " + "concurrency to 1.") + max_concurrent = 1 self._max_concurrent = max_concurrent - self._parameters = [d["name"] for d in parameters] - self._objective_name = objective_name + self._parameters = list(exp.parameters) self._live_index_mapping = {} - super(AxSearch, self).__init__(**kwargs) def _suggest(self, trial_id): if self._num_live_trials() >= self._max_concurrent: return None parameters, trial_index = self._ax.get_next_trial() - suggested_config = list(parameters.values()) self._live_index_mapping[trial_id] = trial_index - return dict(zip(self._parameters, suggested_config)) + return parameters def on_trial_result(self, trial_id, result): pass From a45c61e19b950e3229b2863eaad3aa85d890eaa3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 27 May 2019 14:17:32 -0700 Subject: [PATCH 042/118] [rllib] Update concepts docs and add "Building Policies in Torch/TensorFlow" section (#4821) * wip * fix index * fix bugs * todo * add imports * note on get ph * note on get ph * rename to building custom algs * add rnn state info --- doc/source/index.rst | 4 +- doc/source/rllib-concepts.rst | 431 ++++++++++++++++++-- doc/source/rllib-env.rst | 2 +- doc/source/rllib.rst | 20 +- python/ray/rllib/agents/pg/pg.py | 2 +- python/ray/rllib/agents/ppo/ppo.py | 6 +- python/ray/rllib/agents/trainer_template.py | 13 +- 7 files changed, 430 insertions(+), 48 deletions(-) diff --git a/doc/source/index.rst b/doc/source/index.rst index eba9eaa6ccac..a90e0224bb02 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -98,10 +98,10 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin rllib-models.rst rllib-algorithms.rst rllib-offline.rst - rllib-dev.rst rllib-concepts.rst - rllib-package-ref.rst rllib-examples.rst + rllib-dev.rst + rllib-package-ref.rst .. toctree:: :maxdepth: 1 diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index e3e7948c864f..06e890832295 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -1,5 +1,5 @@ -RLlib Concepts -============== +RLlib Concepts and Building Custom Algorithms +============================================= This page describes the internal concepts used to implement algorithms in RLlib. You might find this useful if modifying or adding new algorithms to RLlib. @@ -8,15 +8,16 @@ Policies Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. -Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow `__ and `PyTorch-specific `__ templates. You can also write your own from scratch. Here is an example: +Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example: .. code-block:: python class CustomPolicy(Policy): """Example of a custom policy written from scratch. 
- You might find it more convenient to extend TF/TorchPolicy instead - for a real policy. + You might find it more convenient to use the `build_tf_policy` and + `build_torch_policy` helpers instead for a real policy, which are + described in the next sections. """ def __init__(self, observation_space, action_space, config): @@ -45,37 +46,413 @@ Most interaction with deep learning frameworks is isolated to the `Policy interf def set_weights(self, weights): self.w = weights["w"] + +The above basic policy, when run, will produce batches of observations with the basic ``obs``, ``new_obs``, ``actions``, ``rewards``, ``dones``, and ``infos`` columns. There are two more mechanisms to pass along and emit extra information: + +**Policy recurrent state**: Suppose you want to compute actions based on the current timestep of the episode. While it is possible to have the environment provide this as part of the observation, we can instead compute and store it as part of the Policy recurrent state: + +.. code-block:: python + + def get_initial_state(self): + """Returns initial RNN state for the current policy.""" + return [0] # list of single state element (t=0) + # you could also return multiple values, e.g., [0, "foo"] + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + assert len(state_batches) == len(self.get_initial_state()) + new_state_batches = [[ + t + 1 for t in state_batches[0] + ]] + return ..., new_state_batches, {} + + def learn_on_batch(self, samples): + # can access array of the state elements at each timestep + # or state_in_1, 2, etc. if there are multiple state elements + assert "state_in_0" in samples.keys() + assert "state_out_0" in samples.keys() + + +**Extra action info output**: You can also emit extra outputs at each step which will be available for learning on. For example, you might want to output the behaviour policy logits as extra action info, which can be used for importance weighting, but in general arbitrary values can be stored here (as long as they are convertible to numpy arrays): + +.. code-block:: python + + def compute_actions(self, + obs_batch, + state_batches, + prev_action_batch=None, + prev_reward_batch=None, + info_batch=None, + episodes=None, + **kwargs): + action_info_batch = { + "some_value": ["foo" for _ in obs_batch], + "other_value": [12345 for _ in obs_batch], + } + return ..., [], action_info_batch + + def learn_on_batch(self, samples): + # can access array of the extra values at each timestep + assert "some_value" in samples.keys() + assert "other_value" in samples.keys() + + +Building Policies in TensorFlow +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This section covers how to build a TensorFlow RLlib policy using ``tf_policy_template.build_tf_policy()``. + +To start, you first have to define a loss function. In RLlib, loss functions are defined over batches of trajectory data produced by policy evaluation. A basic policy gradient loss that only tries to maximize the 1-step reward can be defined as follows: + +.. 
code-block:: python + + import tensorflow as tf + from ray.rllib.policy.sample_batch import SampleBatch + + def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + +In the above snippet, ``actions`` is a Tensor placeholder of shape ``[batch_size, action_dim...]``, and ``rewards`` is a placeholder of shape ``[batch_size]``. The ``policy.action_dist`` object is an `ActionDistribution `__ that represents the output of the neural network policy model. Passing this loss function to ``build_tf_policy`` is enough to produce a very basic TF policy: + +.. code-block:: python + + from ray.rllib.policy.tf_policy_template import build_tf_policy + + # + MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss) + +We can create a `Trainer <#trainers>`__ and try running this policy on a toy env with two parallel rollout workers: + +.. code-block:: python + + import ray + from ray import tune + from ray.rllib.agents.trainer_template import build_trainer + + # + MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy) + + ray.init() + tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2}) + + +If you run the above snippet, you'll probably notice that CartPole doesn't learn so well: + +.. code-block:: bash + + == Status == + Using FIFO scheduling algorithm. + Resources requested: 3/4 CPUs, 0/0 GPUs + Memory usage on this node: 4.6/12.3 GB + Result logdir: /home/ubuntu/ray_results/MyAlgTrainer + Number of trials: 1 ({'RUNNING': 1}) + RUNNING trials: + - MyAlgTrainer_CartPole-v0_0: RUNNING, [3 CPUs, 0 GPUs], [pid=26784], + 32 s, 156 iter, 62400 ts, 23.1 rew + +Let's modify our policy loss to include rewards summed over time. To enable this advantage calculation, we need to define a *trajectory postprocessor* for the policy. This can be done by defining ``postprocess_fn``: + +.. code-block:: python + + from ray.rllib.evaluation.postprocessing import compute_advantages, \ + Postprocessing + + def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + return compute_advantages( + sample_batch, 0.0, policy.config["gamma"], use_gae=False) + + def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + advantages = batch_tensors[Postprocessing.ADVANTAGES] + return -tf.reduce_mean(policy.action_dist.logp(actions) * advantages) + + MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, + postprocess_fn=postprocess_advantages) + +The ``postprocess_advantages()`` function above uses RLlib's ``compute_advantages`` function to compute advantages for each timestep. If you re-run the trainer with this improved policy, you'll find that it quickly achieves the max reward of 200. + +You might be wondering how RLlib makes the advantages placeholder automatically available as ``batch_tensors[Postprocessing.ADVANTAGES]``. When building your policy, RLlib will create a "dummy" trajectory batch where all observations, actions, rewards, etc. are zeros. It then calls your ``postprocess_fn``, and generates TF placeholders based on the numpy shapes of the postprocessed batch. RLlib tracks which placeholders ``loss_fn`` and ``stats_fn`` access, and then feeds the corresponding sample data into those placeholders during loss optimization. 
You can also access these placeholders via ``policy.get_placeholder()`` after loss initialization. + +**Example 1: Proximal Policy Optimization** + +In the above section you saw how to compose a simple policy gradient algorithm with RLlib. In this example, we'll dive into how PPO was built with RLlib and how you can modify it. First, check out the `PPO trainer definition `__: + +.. code-block:: python + + PPOTrainer = build_trainer( + name="PPOTrainer", + default_config=DEFAULT_CONFIG, + default_policy=PPOTFPolicy, + make_policy_optimizer=choose_policy_optimizer, + validate_config=validate_config, + after_optimizer_step=update_kl, + before_train_step=warn_about_obs_filter, + after_train_result=warn_about_bad_reward_scales) + +Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``. + +The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default) and a multi-GPU optimizer that implements minibatch SGD: + +.. code-block:: python + + def choose_policy_optimizer(workers, config): + if config["simple_optimizer"]: + return SyncSamplesOptimizer( + workers, + num_sgd_iter=config["num_sgd_iter"], + train_batch_size=config["train_batch_size"]) + + return LocalMultiGPUOptimizer( + workers, + sgd_batch_size=config["sgd_minibatch_size"], + num_sgd_iter=config["num_sgd_iter"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + num_envs_per_worker=config["num_envs_per_worker"], + train_batch_size=config["train_batch_size"], + standardize_fields=["advantages"], + straggler_mitigation=config["straggler_mitigation"]) + +Suppose we want to customize PPO to use an asynchronous-gradient optimization strategy similar to A3C. To do that, we could define a new function that returns ``AsyncGradientsOptimizer`` and pass in ``make_policy_optimizer=make_async_optimizer`` when building the trainer: + +.. code-block:: python + + from ray.rllib.agents.ppo.ppo_policy import * + from ray.rllib.optimizers import AsyncGradientsOptimizer + from ray.rllib.policy.tf_policy_template import build_tf_policy + + def make_async_optimizer(workers, config): + return AsyncGradientsOptimizer(workers, grads_per_step=100) + + PPOTrainer = build_trainer( + ..., + make_policy_optimizer=make_async_optimizer) + + +Now let's take a look at the ``update_kl`` function. This is used to adaptively adjust the KL penalty coefficient on the PPO loss, which bounds the policy change per training step. You'll notice the code handles both single and multi-agent cases (where there may be multiple policies, each with different KL coeffs): + +.. 
code-block:: python + + def update_kl(trainer, fetches): + if "kl" in fetches: + # single-agent + trainer.workers.local_worker().for_policy( + lambda pi: pi.update_kl(fetches["kl"])) + else: + + def update(pi, pi_id): + if pi_id in fetches: + pi.update_kl(fetches[pi_id]["kl"]) + else: + logger.debug("No data for {}, not updating kl".format(pi_id)) + + # multi-agent + trainer.workers.local_worker().foreach_trainable_policy(update) + +The ``update_kl`` method on the policy is defined in `PPOTFPolicy `__ via the ``KLCoeffMixin``, along with several other advanced features. Let's look at each new feature used by the policy: + +.. code-block:: python + + PPOTFPolicy = build_tf_policy( + name="PPOTFPolicy", + get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG, + loss_fn=ppo_surrogate_loss, + stats_fn=kl_and_loss_stats, + extra_action_fetches_fn=vf_preds_and_logits_fetches, + postprocess_fn=postprocess_ppo_gae, + gradients_fn=clip_gradients, + before_loss_init=setup_mixins, + mixins=[LearningRateSchedule, KLCoeffMixin, ValueNetworkMixin]) + +``stats_fn``: The stats function returns a dictionary of Tensors that will be reported with the training results. This also includes the ``kl`` metric which is used by the trainer to adjust the KL penalty. Note that many of the values below reference ``policy.loss_obj``, which is assigned by ``loss_fn`` (not shown here since the PPO loss is quite complex). RLlib will always call ``stats_fn`` after ``loss_fn``, so you can rely on using values saved by ``loss_fn`` as part of your statistics: + +.. code-block:: python + + def kl_and_loss_stats(policy, batch_tensors): + policy.explained_variance = explained_variance( + batch_tensors[Postprocessing.VALUE_TARGETS], policy.value_function) + + stats_fetches = { + "cur_kl_coeff": policy.kl_coeff, + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "total_loss": policy.loss_obj.loss, + "policy_loss": policy.loss_obj.mean_policy_loss, + "vf_loss": policy.loss_obj.mean_vf_loss, + "vf_explained_var": policy.explained_variance, + "kl": policy.loss_obj.mean_kl, + "entropy": policy.loss_obj.mean_entropy, + } + + return stats_fetches + +``extra_action_fetches_fn``: This function defines extra outputs that will be recorded when generating actions with the policy. For example, this enables saving the raw policy logits in the experience batch, which means they can then be referenced in the PPO loss function via ``batch_tensors[BEHAVIOUR_LOGITS]``. Other values such as the current value prediction can also be emitted for debugging or optimization purposes: + +.. code-block:: python + + def vf_preds_and_logits_fetches(policy): + return { + SampleBatch.VF_PREDS: policy.value_function, + BEHAVIOUR_LOGITS: policy.model.outputs, + } + +``gradients_fn``: If defined, this function returns TF gradients for the loss function. You'd typically only want to override this to apply transformations such as gradient clipping: + +.. code-block:: python + + def clip_gradients(policy, optimizer, loss): + if policy.config["grad_clip"] is not None: + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + grads = tf.gradients(loss, policy.var_list) + policy.grads, _ = tf.clip_by_global_norm(grads, + policy.config["grad_clip"]) + clipped_grads = list(zip(policy.grads, policy.var_list)) + return clipped_grads + else: + return optimizer.compute_gradients( + loss, colocate_gradients_with_ops=True) + +``mixins``: To add arbitrary stateful components, you can add mixin classes to the policy. 
Methods defined by these mixins will have higher priority than the base policy class, so you can use these to override methods (as in the case of ``LearningRateSchedule``), or define extra methods and attributes (e.g., ``KLCoeffMixin``, ``ValueNetworkMixin``). Like any other Python superclass, these should be initialized at some point, which is what the ``setup_mixins`` function does: + +.. code-block:: python + + def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy, obs_space, action_space, config) + KLCoeffMixin.__init__(policy, config) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + +In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``before_loss_init``), but other callbacks you can use include ``before_init`` and ``after_init``. + +**Example 2: Deep Q Networks** + +(todo) + +Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient. + +Building Policies in PyTorch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: + +.. code-block:: python + + A3CTorchPolicy = build_torch_policy( + name="A3CTorchPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=loss_and_entropy_stats, + postprocess_fn=add_advantages, + extra_action_out_fn=model_value_predictions, + extra_grad_process_fn=apply_grad_clipping, + optimizer_fn=torch_optimizer, + mixins=[ValueNetworkMixin]) + +``loss_fn``: Similar to the TF example, the actor critic loss is defined over ``batch_tensors``. We imperatively execute the forward pass by calling ``policy.model()`` on the observations followed by ``policy.dist_class()`` on the output logits. The output Tensors are saved as attributes of the policy object (e.g., ``policy.entropy = dist.entropy().mean()``), and we return the scalar loss: + +.. code-block:: python + + def actor_critic_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + dist = policy.dist_class(logits) + log_probs = dist.logp(batch_tensors[SampleBatch.ACTIONS]) + policy.entropy = dist.entropy().mean() + ... + return overall_err + +``stats_fn``: The stats function references ``entropy``, ``pi_err``, and ``value_err`` saved from the call to the loss function, similar to the PPO TF example: + +.. code-block:: python + + def loss_and_entropy_stats(policy, batch_tensors): + return { + "policy_entropy": policy.entropy.item(), + "policy_loss": policy.pi_err.item(), + "vf_loss": policy.value_err.item(), + } + +``extra_action_out_fn``: We save value function predictions given model outputs. This makes the value function predictions of the model available in the trajectory as ``batch_tensors[SampleBatch.VF_PREDS]``: + +.. code-block:: python + + def model_value_predictions(policy, model_out): + return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} + +``postprocess_fn`` and ``mixins``: Similar to the PPO example, we need access to the value function during postprocessing (i.e., ``add_advantages`` below calls ``policy._value()``). The value function is exposed through a mixin class that defines the method: + +.. 
code-block:: python + + def add_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1]) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + class ValueNetworkMixin(object): + def _value(self, obs): + with self.lock: + obs = torch.from_numpy(obs).float().unsqueeze(0).to(self.device) + _, _, vf, _ = self.model({"obs": obs}, []) + return vf.detach().cpu().numpy().squeeze() + +You can find the full policy definition in `a3c_torch_policy.py `__. + +In summary, the main difference between the PyTorch and TensorFlow policy builder functions is that the TF loss and stats functions are built symbolically when the policy is initialized, whereas for PyTorch these functions are called imperatively each time they are used. + Policy Evaluation ----------------- -Given an environment and policy, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `PolicyEvaluator `__ class that manages all of this, and this class is used in most RLlib algorithms. +Given an environment and policy, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `RolloutWorker `__ class that manages all of this, and this class is used in most RLlib algorithms. -You can use policy evaluation standalone to produce batches of experiences. This can be done by calling ``ev.sample()`` on an evaluator instance, or ``ev.sample.remote()`` in parallel on evaluator instances created as Ray actors (see ``PolicyEvaluator.as_remote()``). +You can use rollout workers standalone to produce batches of experiences. This can be done by calling ``worker.sample()`` on a worker instance, or ``worker.sample.remote()`` in parallel on worker instances created as Ray actors (see ``RolloutWorkers.create_remote``). -Here is an example of creating a set of policy evaluation actors and using the to gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the evaluators for the next round of rollouts: +Here is an example of creating a set of rollout workers and using them to gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the workers for the next round of rollouts: .. 
code-block:: python - # Setup policy and remote policy evaluation actors + # Setup policy and rollout workers env = gym.make("CartPole-v0") policy = CustomPolicy(env.observation_space, env.action_space, {}) - remote_evaluators = [ - PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), - CustomPolicy) - for _ in range(10) - ] + workers = WorkerSet( + policy=CustomPolicy, + env_creator=lambda c: gym.make("CartPole-v0"), + num_workers=10) while True: # Gather a batch of samples T1 = SampleBatch.concat_samples( - ray.get([w.sample.remote() for w in remote_evaluators])) + ray.get([w.sample.remote() for w in workers.remote_workers()])) # Improve the policy using the T1 batch policy.learn_on_batch(T1) # Broadcast weights to the policy evaluation workers weights = ray.put({"default_policy": policy.get_weights()}) - for w in remote_evaluators: + for w in workers.remote_workers(): w.set_weights.remote(weights) Policy Optimization @@ -90,16 +467,13 @@ This is how the example in the previous section looks when written using a polic .. code-block:: python # Same setup as before - local_evaluator = PolicyEvaluator(lambda c: gym.make("CartPole-v0"), CustomPolicy) - remote_evaluators = [ - PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), - CustomPolicy) - for _ in range(10) - ] + workers = WorkerSet( + policy=CustomPolicy, + env_creator=lambda c: gym.make("CartPole-v0"), + num_workers=10) # this optimizer implements the IMPALA architecture - optimizer = AsyncSamplesOptimizer( - local_evaluator, remote_evaluators, train_batch_size=500) + optimizer = AsyncSamplesOptimizer(workers, train_batch_size=500) while True: optimizer.step() @@ -108,9 +482,9 @@ This is how the example in the previous section looks when written using a polic Trainers -------- -Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the policy evaluators and optimizer, and collection of training metrics. Trainers also implement the `Trainable API `__ for easy experiment management. +Trainers are the boilerplate classes that put the above components together, making algorithms accessible via Python API and the command line. They manage algorithm configuration, setup of the rollout workers and optimizer, and collection of training metrics. Trainers also implement the `Trainable API `__ for easy experiment management. -Example of two equivalent ways of interacting with the PPO trainer: +Example of three equivalent ways of interacting with the PPO trainer, all of which log results in ``~/ray_results``: .. code-block:: python @@ -121,3 +495,8 @@ Example of two equivalent ways of interacting with the PPO trainer: .. code-block:: bash rllib train --run=PPO --env=CartPole-v0 --config='{"train_batch_size": 4000}' + +.. code-block:: python + + from ray import tune + tune.run(PPOTrainer, config={"env": "CartPole-v0", "train_batch_size": 4000}) diff --git a/doc/source/rllib-env.rst b/doc/source/rllib-env.rst index 3d00ac69bcde..b04b91c3c265 100644 --- a/doc/source/rllib-env.rst +++ b/doc/source/rllib-env.rst @@ -275,7 +275,7 @@ Implementing a centralized critic that takes as input the observations and actio .. 
code-block:: python - def postprocess_trajectory(self, sample_batch, other_agent_batches, episode): + def postprocess_trajectory(policy, sample_batch, other_agent_batches, episode): agents = ["agent_1", "agent_2", "agent_3"] # simple example of 3 agents global_obs_batch = np.stack( [other_agent_batches[agent_id][1]["obs"] for agent_id in agents], diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index 02b1bc3478ee..e77a0ab427f8 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -5,8 +5,7 @@ RLlib is an open-source library for reinforcement learning that offers both high .. image:: rllib-stack.svg -Learn more about RLlib's design by reading the `ICML paper `__. -To get started, take a look over the `custom env example `__ and the `API documentation `__. +To get started, take a look over the `custom env example `__ and the `API documentation `__. If you're looking to develop custom algorithms with RLlib, also check out `concepts and custom algorithms `__. Installation ------------ @@ -96,12 +95,17 @@ Offline Datasets * `Input API `__ * `Output API `__ -Concepts --------- -* `Policies `__ -* `Policy Evaluation `__ -* `Policy Optimization `__ -* `Trainers `__ +Concepts and Building Custom Algorithms +--------------------------------------- +* `Policies `__ + + - `Building Policies in TensorFlow `__ + + - `Building Policies in PyTorch `__ + +* `Policy Evaluation `__ +* `Policy Optimization `__ +* `Trainers `__ Examples -------- diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index 71e2ab3fbd69..299cdcac3de4 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -29,7 +29,7 @@ def get_policy_class(config): PGTrainer = build_trainer( - name="PG", + name="PGTrainer", default_config=DEFAULT_CONFIG, default_policy=PGTFPolicy, get_policy_class=get_policy_class) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index b395d935f119..daf43d14821d 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -63,7 +63,7 @@ # yapf: enable -def make_optimizer(local_evaluator, remote_evaluators, config): +def choose_policy_optimizer(local_evaluator, remote_evaluators, config): if config["simple_optimizer"]: return SyncSamplesOptimizer( local_evaluator, @@ -155,10 +155,10 @@ def validate_config(config): PPOTrainer = build_trainer( - name="PPO", + name="PPOTrainer", default_config=DEFAULT_CONFIG, default_policy=PPOTFPolicy, - make_policy_optimizer=make_optimizer, + make_policy_optimizer=choose_policy_optimizer, validate_config=validate_config, after_optimizer_step=update_kl, before_train_step=warn_about_obs_filter, diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index 314202a1e842..aae8e35f64f8 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -2,7 +2,7 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer +from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.utils.annotations import override, DeveloperAPI @@ -44,13 +44,12 @@ def build_trainer(name, a Trainer instance that uses the specified args. 
""" - if name.endswith("Trainer"): - raise ValueError("Algorithm name should not include *Trainer suffix", - name) + if not name.endswith("Trainer"): + raise ValueError("Algorithm name should have *Trainer suffix", name) class trainer_cls(Trainer): _name = name - _default_config = default_config or Trainer.COMMON_CONFIG + _default_config = default_config or COMMON_CONFIG _policy = default_policy def _init(self, config, env_creator): @@ -92,6 +91,6 @@ def _train(self): after_train_result(self, res) return res - trainer_cls.__name__ = name + "Trainer" - trainer_cls.__qualname__ = name + "Trainer" + trainer_cls.__name__ = name + trainer_cls.__qualname__ = name return trainer_cls From d7be5a5d36348c8cfe49bef1f7f39fa7c8acfb37 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 27 May 2019 17:24:45 -0700 Subject: [PATCH 043/118] [rllib] Fix error getting kl when simple_optimizer: True in multi-agent PPO --- ci/jenkins_tests/run_rllib_tests.sh | 3 +++ python/ray/rllib/examples/multiagent_cartpole.py | 2 ++ python/ray/rllib/optimizers/sync_samples_optimizer.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index fa10c14b8c5a..13acff28d39c 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -368,6 +368,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_cartpole.py --num-iters=2 --simple + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/multiagent_two_trainers.py --num-iters=2 diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py index efa77ecbf7a5..275c54390f97 100644 --- a/python/ray/rllib/examples/multiagent_cartpole.py +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -30,6 +30,7 @@ parser.add_argument("--num-agents", type=int, default=4) parser.add_argument("--num-policies", type=int, default=2) parser.add_argument("--num-iters", type=int, default=20) +parser.add_argument("--simple", action="store_true") class CustomModel1(Model): @@ -103,6 +104,7 @@ def gen_policy(i): config={ "env": "multi_cartpole", "log_level": "DEBUG", + "simple_optimizer": args.simple, "num_sgd_iter": 10, "multiagent": { "policies": policies, diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index f5807ae343ef..a49b290d3e2c 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -69,7 +69,7 @@ def step(self): self.num_steps_sampled += samples.count self.num_steps_trained += samples.count - return fetches + return self.learner_stats @override(PolicyOptimizer) def stats(self): From fa0892f2852b4b2cf46ea33603ea4cd6e410e97a Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Tue, 28 May 2019 13:30:41 +0800 Subject: [PATCH 044/118] Replace ReturnIds with NumReturns in TaskInfo to reduce the size (#4854) * Refine TaskInfo * Fix * Add a test to print task info size * Lint * Refine --- .../org/ray/runtime/AbstractRayRuntime.java | 2 +- 
.../ray/runtime/raylet/RayletClientImpl.java | 11 ++-- .../java/org/ray/runtime/task/TaskSpec.java | 16 ++++-- src/ray/gcs/format/gcs.fbs | 5 +- src/ray/raylet/task_spec.cc | 14 ++--- src/ray/raylet/task_test.cc | 51 +++++++++++++++++++ 6 files changed, 74 insertions(+), 25 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 01f8dbd12ba0..fbd03bf10483 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -390,7 +390,7 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes actor.increaseTaskCounter(), actor.getNewActorHandles().toArray(new UniqueId[0]), ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON), - returnIds, + numReturns, resources, language, functionDescriptor diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index b4bfa5a7fd47..01b9e4675016 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -154,6 +154,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { UniqueId actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer()); UniqueId actorHandleId = UniqueId.fromByteBuffer(info.actorHandleIdAsByteBuffer()); int actorCounter = info.actorCounter(); + int numReturns = info.numReturns(); // Deserialize new actor handles UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer( @@ -177,8 +178,6 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { args[i] = FunctionArg.passByValue(data); } } - // Deserialize return ids - ObjectId[] returnIds = IdUtil.getObjectIdsFromByteBuffer(info.returnsAsByteBuffer()); // Deserialize required resources; Map resources = new HashMap<>(); @@ -193,7 +192,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { ); return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, returnIds, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -211,6 +210,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer()); final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer()); final int actorCounter = task.actorCounter; + final int numReturnsOffset = task.numReturns; // Serialize the new actor handles. int newActorHandlesOffset @@ -234,9 +234,6 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { } int argsOffset = fbb.createVectorOfTables(argsOffsets); - // Serialize returns - int returnsOffset = fbb.createString(IdUtil.concatIds(task.returnIds)); - // Serialize required resources // The required_resources vector indicates the quantities of the different // resources required by this task. 
The index in this vector corresponds to @@ -292,7 +289,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { actorCounter, newActorHandlesOffset, argsOffset, - returnsOffset, + numReturnsOffset, requiredResourcesOffset, requiredPlacementResourcesOffset, language, diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 8a98e11c61ae..3473a9bdb3cc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -11,6 +11,7 @@ import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.PyFunctionDescriptor; +import org.ray.runtime.util.IdUtil; /** * Represents necessary information of a task for scheduling and executing. @@ -50,7 +51,10 @@ public class TaskSpec { // Task arguments. public final FunctionArg[] args; - // return ids + // number of return objects. + public final int numReturns; + + // returns ids. public final ObjectId[] returnIds; // The task's resource demands. @@ -86,7 +90,7 @@ public TaskSpec( int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args, - ObjectId[] returnIds, + int numReturns, Map resources, TaskLanguage language, FunctionDescriptor functionDescriptor) { @@ -101,7 +105,11 @@ public TaskSpec( this.actorCounter = actorCounter; this.newActorHandles = newActorHandles; this.args = args; - this.returnIds = returnIds; + this.numReturns = numReturns; + returnIds = new ObjectId[numReturns]; + for (int i = 0; i < numReturns; ++i) { + returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); + } this.resources = resources; this.language = language; if (language == TaskLanguage.JAVA) { @@ -145,7 +153,7 @@ public String toString() { ", actorCounter=" + actorCounter + ", newActorHandles=" + Arrays.toString(newActorHandles) + ", args=" + Arrays.toString(args) + - ", returnIds=" + Arrays.toString(returnIds) + + ", numReturns=" + numReturns + ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 7cf250247461..b81f388d88c5 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -89,9 +89,8 @@ table TaskInfo { new_actor_handles: string; // Task arguments. args: [Arg]; - // Object IDs of return values. This is a long string that concatenate - // all of the return object IDs of this task. - returns: string; + // Number of return objects. + num_returns: int; // The required_resources vector indicates the quantities of the different // resources required by this task. required_resources: [ResourcePair]; diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index d4ec4f5c5e75..17a8b185fc78 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -92,20 +92,14 @@ TaskSpecification::TaskSpecification( arguments.push_back(argument->ToFlatbuffer(fbb)); } - // Generate return ids. - std::vector returns; - for (int64_t i = 1; i < num_returns + 1; ++i) { - returns.push_back(ObjectID::for_task_return(task_id, i)); - } - // Serialize the TaskSpecification. 
auto spec = CreateTaskInfo( fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, - ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), - ids_to_flatbuf(fbb, returns), map_to_flatbuf(fbb, required_resources), + ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, + map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); @@ -167,12 +161,12 @@ int64_t TaskSpecification::NumArgs() const { int64_t TaskSpecification::NumReturns() const { auto message = flatbuffers::GetRoot(spec_.data()); - return (message->returns()->size() / kUniqueIDSize); + return message->num_returns(); } ObjectID TaskSpecification::ReturnId(int64_t return_index) const { auto message = flatbuffers::GetRoot(spec_.data()); - return ids_from_flatbuf(*message->returns())[return_index]; + return ObjectID::for_task_return(TaskId(), return_index + 1); } bool TaskSpecification::ArgByRef(int64_t arg_index) const { diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 03a4caff16ee..6d0cfa37017a 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -1,5 +1,6 @@ #include "gtest/gtest.h" +#include "ray/common/common_protocol.h" #include "ray/raylet/task_spec.h" namespace ray { @@ -47,6 +48,56 @@ TEST(IdPropertyTest, TestIdProperty) { ASSERT_TRUE(ObjectID::nil().is_nil()); } +TEST(TaskSpecTest, TaskInfoSize) { + std::vector references = {ObjectID::from_random(), ObjectID::from_random()}; + auto arguments_1 = std::make_shared(references); + std::string one_arg("This is an value argument."); + auto arguments_2 = std::make_shared( + reinterpret_cast(one_arg.c_str()), one_arg.size()); + std::vector> task_arguments({arguments_1, arguments_2}); + auto task_id = TaskID::from_random(); + { + flatbuffers::FlatBufferBuilder fbb; + std::vector> arguments; + for (auto &argument : task_arguments) { + arguments.push_back(argument->ToFlatbuffer(fbb)); + } + // General task. + auto spec = CreateTaskInfo( + fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::from_random()), 0, to_flatbuf(fbb, ActorID::nil()), + to_flatbuf(fbb, ObjectID::nil()), 0, to_flatbuf(fbb, ActorID::nil()), + to_flatbuf(fbb, ActorHandleID::nil()), 0, + ids_to_flatbuf(fbb, std::vector()), fbb.CreateVector(arguments), 1, + map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON, + string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); + fbb.Finish(spec); + RAY_LOG(ERROR) << "Ordinary task info size: " << fbb.GetSize(); + } + + { + flatbuffers::FlatBufferBuilder fbb; + std::vector> arguments; + for (auto &argument : task_arguments) { + arguments.push_back(argument->ToFlatbuffer(fbb)); + } + // General task. 
+ auto spec = CreateTaskInfo( + fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::from_random()), 10, + to_flatbuf(fbb, ActorID::from_random()), to_flatbuf(fbb, ObjectID::from_random()), + 10000000, to_flatbuf(fbb, ActorID::from_random()), + to_flatbuf(fbb, ActorHandleID::from_random()), 20, + ids_to_flatbuf(fbb, std::vector( + {ObjectID::from_random(), ObjectID::from_random()})), + fbb.CreateVector(arguments), 2, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), + Language::PYTHON, + string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); + fbb.Finish(spec); + RAY_LOG(ERROR) << "Actor task info size: " << fbb.GetSize(); + } +} + } // namespace raylet } // namespace ray From 64a01b2ab6e5b3291fdbaa1e2b09c716d1e4274d Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Tue, 28 May 2019 14:29:35 +0800 Subject: [PATCH 045/118] Update deps commits of opencensus to support building with bzl 0.25.x (#4862) * Update deps to support bzl 2.5.x * Fix --- bazel/ray_deps_setup.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index 6094ed3c9303..dafa72b773fe 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -80,18 +80,18 @@ def ray_deps_setup(): http_archive( name = "io_opencensus_cpp", - strip_prefix = "opencensus-cpp-0.3.0", - urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/v0.3.0.zip"], + strip_prefix = "opencensus-cpp-3aa11f20dd610cb8d2f7c62e58d1e69196aadf11", + urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/3aa11f20dd610cb8d2f7c62e58d1e69196aadf11.zip"], ) # OpenCensus depends on Abseil so we have to explicitly pull it in. # This is how diamond dependencies are prevented. 
git_repository( name = "com_google_absl", - commit = "88a152ae747c3c42dc9167d46c590929b048d436", + commit = "5b65c4af5107176555b23a638e5947686410ac1f", remote = "https://github.com/abseil/abseil-cpp.git", ) - + # OpenCensus depends on jupp0r/prometheus-cpp http_archive( name = "com_github_jupp0r_prometheus_cpp", From 64eb7b322c69a576622b0bc2b3f2059ff428809c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 28 May 2019 16:04:16 -0700 Subject: [PATCH 046/118] Upgrade arrow to latest master (#4858) --- bazel/BUILD.plasma | 6 ++++ bazel/ray_deps_setup.bzl | 30 +++++++++---------- build.sh | 4 +-- .../rllib/tests/test_env_with_subprocess.py | 1 + 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/bazel/BUILD.plasma b/bazel/BUILD.plasma index a2b75b99735e..f264c4a13943 100644 --- a/bazel/BUILD.plasma +++ b/bazel/BUILD.plasma @@ -25,11 +25,13 @@ cc_library( name = "arrow", srcs = [ "cpp/src/arrow/buffer.cc", + "cpp/src/arrow/io/interfaces.cc", "cpp/src/arrow/memory_pool.cc", "cpp/src/arrow/status.cc", "cpp/src/arrow/util/io-util.cc", "cpp/src/arrow/util/logging.cc", "cpp/src/arrow/util/memory.cc", + "cpp/src/arrow/util/string_builder.cc", "cpp/src/arrow/util/thread-pool.cc", ], hdrs = [ @@ -42,6 +44,7 @@ cc_library( "cpp/src/arrow/util/logging.h", "cpp/src/arrow/util/macros.h", "cpp/src/arrow/util/memory.h", + "cpp/src/arrow/util/stl.h", "cpp/src/arrow/util/string_builder.h", "cpp/src/arrow/util/string_view.h", "cpp/src/arrow/util/thread-pool.h", @@ -53,6 +56,9 @@ cc_library( "cpp/src/arrow/vendored/xxhash/xxhash.h", ], strip_include_prefix = "cpp/src", + deps = [ + "@boost//:filesystem", + ], ) cc_library( diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index dafa72b773fe..b3cd21b9b3b1 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -3,9 +3,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def ray_deps_setup(): RULES_JVM_EXTERNAL_TAG = "1.2" - + RULES_JVM_EXTERNAL_SHA = "e5c68b87f750309a79f59c2b69ead5c3221ffa54ff9496306937bfa1c9c8c86b" - + http_archive( name = "rules_jvm_external", sha256 = RULES_JVM_EXTERNAL_SHA, @@ -18,7 +18,7 @@ def ray_deps_setup(): strip_prefix = "bazel-common-f1115e0f777f08c3cdb115526c4e663005bec69b", url = "https://github.com/google/bazel-common/archive/f1115e0f777f08c3cdb115526c4e663005bec69b.zip", ) - + BAZEL_SKYLIB_TAG = "0.6.0" http_archive( @@ -26,64 +26,64 @@ def ray_deps_setup(): strip_prefix = "bazel-skylib-%s" % BAZEL_SKYLIB_TAG, url = "https://github.com/bazelbuild/bazel-skylib/archive/%s.tar.gz" % BAZEL_SKYLIB_TAG, ) - + git_repository( name = "com_github_checkstyle_java", commit = "85f37871ca03b9d3fee63c69c8107f167e24e77b", remote = "https://github.com/ruifangChen/checkstyle_java", ) - + git_repository( name = "com_github_nelhage_rules_boost", commit = "5171b9724fbb39c5fdad37b9ca9b544e8858d8ac", remote = "https://github.com/ray-project/rules_boost", ) - + git_repository( name = "com_github_google_flatbuffers", commit = "63d51afd1196336a7d1f56a988091ef05deb1c62", remote = "https://github.com/google/flatbuffers.git", ) - + git_repository( name = "com_google_googletest", commit = "3306848f697568aacf4bcca330f6bdd5ce671899", remote = "https://github.com/google/googletest", ) - + git_repository( name = "com_github_gflags_gflags", remote = "https://github.com/gflags/gflags.git", tag = "v2.2.2", ) - + new_git_repository( name = "com_github_google_glog", build_file = "@//bazel:BUILD.glog", commit = "5c576f78c49b28d89b23fbb1fc80f54c879ec02e", remote = "https://github.com/google/glog", ) 
- + new_git_repository( name = "plasma", build_file = "@//bazel:BUILD.plasma", - commit = "d00497b38be84fd77c40cbf77f3422f2a81c44f9", + commit = "9fcc12fc094b85ec2e3e9798bae5c8151d14df5e", remote = "https://github.com/apache/arrow", ) - + new_git_repository( name = "cython", build_file = "@//bazel:BUILD.cython", commit = "49414dbc7ddc2ca2979d6dbe1e44714b10d72e7e", remote = "https://github.com/cython/cython", ) - + http_archive( name = "io_opencensus_cpp", strip_prefix = "opencensus-cpp-3aa11f20dd610cb8d2f7c62e58d1e69196aadf11", urls = ["https://github.com/census-instrumentation/opencensus-cpp/archive/3aa11f20dd610cb8d2f7c62e58d1e69196aadf11.zip"], ) - + # OpenCensus depends on Abseil so we have to explicitly pull it in. # This is how diamond dependencies are prevented. git_repository( @@ -96,7 +96,7 @@ def ray_deps_setup(): http_archive( name = "com_github_jupp0r_prometheus_cpp", strip_prefix = "prometheus-cpp-master", - + # TODO(qwang): We should use the repository of `jupp0r` here when this PR # `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged. urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"], diff --git a/build.sh b/build.sh index 5e391058fa36..8ec5e5b8c106 100755 --- a/build.sh +++ b/build.sh @@ -101,8 +101,8 @@ pushd "$BUILD_DIR" # generated from https://github.com/ray-project/arrow-build from # the commit listed in the command. $PYTHON_EXECUTABLE -m pip install \ - --target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.12.0.RAY \ - --find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/ca1fa51f0901f5a4298f0e4faea00f24e5dd7bb7/index.html + --target="$ROOT_DIR/python/ray/pyarrow_files" pyarrow==0.14.0.RAY \ + --find-links https://s3-us-west-2.amazonaws.com/arrow-wheels/9f35817b35f9d0614a736a497d70de2cf07fed52/index.html export PYTHON_BIN_PATH="$PYTHON_EXECUTABLE" if [ "$RAY_BUILD_JAVA" == "YES" ]; then diff --git a/python/ray/rllib/tests/test_env_with_subprocess.py b/python/ray/rllib/tests/test_env_with_subprocess.py index ecde6c626ca0..1a760ff7e04c 100644 --- a/python/ray/rllib/tests/test_env_with_subprocess.py +++ b/python/ray/rllib/tests/test_env_with_subprocess.py @@ -80,6 +80,7 @@ def leaked_processes(): }, }, }) + time.sleep(5.0) leaked = leaked_processes() assert not leaked, "LEAKED PROCESSES: {}".format(leaked) assert not os.path.exists(UNIQUE_FILE_0), "atexit handler not called" From acee89b1f63d2640026801f34d9d480840feef62 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 29 May 2019 12:09:34 -0700 Subject: [PATCH 047/118] [tune] Auto-init Ray + default SearchAlg (#4815) --- python/ray/tune/ray_trial_executor.py | 7 +++ python/ray/tune/tests/test_trial_runner.py | 51 ++++++++++------------ python/ray/tune/tests/test_tune_restore.py | 15 +++++++ python/ray/tune/tests/test_tune_server.py | 4 +- python/ray/tune/trial_runner.py | 22 +++------- python/ray/tune/tune.py | 15 +++++-- 6 files changed, 63 insertions(+), 51 deletions(-) diff --git a/python/ray/tune/ray_trial_executor.py b/python/ray/tune/ray_trial_executor.py index 548e092cfb1d..81f02661f69f 100644 --- a/python/ray/tune/ray_trial_executor.py +++ b/python/ray/tune/ray_trial_executor.py @@ -39,6 +39,7 @@ class RayTrialExecutor(TrialExecutor): def __init__(self, queue_trials=False, reuse_actors=False, + ray_auto_init=False, refresh_period=RESOURCE_REFRESH_PERIOD): super(RayTrialExecutor, self).__init__(queue_trials) self._running = {} @@ -55,6 +56,12 @@ def __init__(self, self._refresh_period = refresh_period self._last_resource_refresh = float("-inf") 
self._last_nontrivial_wait = time.time() + if not ray.is_initialized() and ray_auto_init: + logger.info("Initializing Ray automatically." + "For cluster usage or custom Ray initialization, " + "call `ray.init(...)` before `tune.run`.") + ray.init() + if ray.is_initialized(): self._update_avail_resources() diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index a9bf8e3239c6..37022ceab615 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1218,7 +1218,7 @@ def train(config, reporter): def testExtraResources(self): ray.init(num_cpus=4, num_gpus=2) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1239,7 +1239,7 @@ def testExtraResources(self): def testCustomResources(self): ray.init(num_cpus=4, num_gpus=2, resources={"a": 2}) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1260,7 +1260,7 @@ def testCustomResources(self): def testExtraCustomResources(self): ray.init(num_cpus=4, num_gpus=2, resources={"a": 2}) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1283,7 +1283,7 @@ def testExtraCustomResources(self): def testCustomResources2(self): ray.init(num_cpus=4, num_gpus=2, resources={"a": 2}) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() resource1 = Resources(cpu=1, gpu=0, extra_custom_resources={"a": 2}) self.assertTrue(runner.has_resources(resource1)) resource2 = Resources(cpu=1, gpu=0, custom_resources={"a": 2}) @@ -1295,7 +1295,7 @@ def testCustomResources2(self): def testFractionalGpus(self): ray.init(num_cpus=4, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "resources": Resources(cpu=1, gpu=0.5), } @@ -1318,7 +1318,7 @@ def testFractionalGpus(self): def testResourceScheduler(self): ray.init(num_cpus=4, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1347,7 +1347,7 @@ def testResourceScheduler(self): def testMultiStepRun(self): ray.init(num_cpus=4, num_gpus=2) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 5 @@ -1377,7 +1377,7 @@ def testMultiStepRun(self): def testMultiStepRun2(self): """Checks that runner.step throws when overstepping.""" ray.init(num_cpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -1411,8 +1411,7 @@ def on_trial_result(self, trial_runner, trial, result): executor.start_trial(trial) return TrialScheduler.CONTINUE - runner = TrialRunner( - BasicVariantGenerator(), scheduler=ChangingScheduler()) + runner = TrialRunner(scheduler=ChangingScheduler()) kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -1434,7 +1433,7 @@ def on_trial_result(self, trial_runner, trial, result): def testErrorHandling(self): ray.init(num_cpus=4, num_gpus=2) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1456,7 +1455,7 @@ def testErrorHandling(self): def testThrowOnOverstep(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() 
runner.step() self.assertRaises(TuneError, runner.step) @@ -1550,7 +1549,7 @@ def testFailureRecoveryNodeRemoval(self): def testFailureRecoveryMaxFailures(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "resources": Resources(cpu=1, gpu=1), "checkpoint_freq": 1, @@ -1579,7 +1578,7 @@ def testFailureRecoveryMaxFailures(self): def testCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 1 @@ -1610,7 +1609,7 @@ def testCheckpointing(self): def testRestoreMetricsAfterCheckpointing(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "resources": Resources(cpu=1, gpu=1), } @@ -1642,7 +1641,7 @@ def testRestoreMetricsAfterCheckpointing(self): def testCheckpointingAtEnd(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -1663,7 +1662,7 @@ def testCheckpointingAtEnd(self): def testResultDone(self): """Tests that last_result is marked `done` after trial is complete.""" ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -1682,7 +1681,7 @@ def testResultDone(self): def testPauseThenResume(self): ray.init(num_cpus=1, num_gpus=1) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -1713,7 +1712,7 @@ def testPauseThenResume(self): def testStepHook(self): ray.init(num_cpus=4, num_gpus=2) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() def on_step_begin(self): self._update_avail_resources() @@ -1743,7 +1742,7 @@ def on_step_end(self): def testStopTrial(self): ray.init(num_cpus=4, num_gpus=2) - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() kwargs = { "stopping_criterion": { "training_iteration": 5 @@ -1953,8 +1952,7 @@ def testTrialSaveRestore(self): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner( - BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(metadata_checkpoint_dir=tmpdir) trials = [ Trial( "__fake", @@ -2013,8 +2011,7 @@ def testTrialNoSave(self): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner( - BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(metadata_checkpoint_dir=tmpdir) runner.add_trial( Trial( @@ -2069,8 +2066,7 @@ def testCheckpointWithFunction(self): }, checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner( - BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(metadata_checkpoint_dir=tmpdir) runner.add_trial(trial) for i in range(5): runner.step() @@ -2091,8 +2087,7 @@ def count_checkpoints(cdir): ray.init() trial = Trial("__fake", checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner( - BasicVariantGenerator(), metadata_checkpoint_dir=tmpdir) + runner = TrialRunner(metadata_checkpoint_dir=tmpdir) runner.add_trial(trial) for i in range(5): runner.step() diff --git a/python/ray/tune/tests/test_tune_restore.py b/python/ray/tune/tests/test_tune_restore.py index 3742cf598676..768e9d72b04b 100644 --- a/python/ray/tune/tests/test_tune_restore.py +++ 
b/python/ray/tune/tests/test_tune_restore.py @@ -53,5 +53,20 @@ def testTuneRestore(self): ) +class AutoInitTest(unittest.TestCase): + def testTuneRestore(self): + self.assertFalse(ray.is_initialized()) + tune.run( + "__fake", + name="TestAutoInit", + stop={"training_iteration": 1}, + ray_auto_init=True) + self.assertTrue(ray.is_initialized()) + + def tearDown(self): + ray.shutdown() + _register_all() + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tests/test_tune_server.py b/python/ray/tune/tests/test_tune_server.py index 7df7a698b38d..dd9a9134c8b6 100644 --- a/python/ray/tune/tests/test_tune_server.py +++ b/python/ray/tune/tests/test_tune_server.py @@ -12,7 +12,6 @@ from ray.rllib import _register_all from ray.tune.trial import Trial, Resources from ray.tune.web_server import TuneClient -from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial_runner import TrialRunner @@ -34,8 +33,7 @@ class TuneServerSuite(unittest.TestCase): def basicSetup(self): ray.init(num_cpus=4, num_gpus=1) port = get_valid_port() - self.runner = TrialRunner( - BasicVariantGenerator(), launch_web_server=True, server_port=port) + self.runner = TrialRunner(launch_web_server=True, server_port=port) runner = self.runner kwargs = { "stopping_criterion": { diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 8ffcb6e317b8..dfd809732857 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -18,6 +18,7 @@ from ray.tune.trial import Trial, Checkpoint from ray.tune.sample import function from ray.tune.schedulers import FIFOScheduler, TrialScheduler +from ray.tune.suggest import BasicVariantGenerator from ray.tune.util import warn_if_slow from ray.utils import binary_to_hex, hex_to_binary from ray.tune.web_server import TuneServer @@ -77,7 +78,7 @@ class TrialRunner(object): """A TrialRunner implements the event loop for scheduling trials on Ray. Example: - runner = TrialRunner(BasicVariantGenerator()) + runner = TrialRunner() runner.add_trial(Trial(...)) runner.add_trial(Trial(...)) while not runner.is_finished(): @@ -98,14 +99,12 @@ class TrialRunner(object): CKPT_FILE_TMPL = "experiment_state-{}.json" def __init__(self, - search_alg, + search_alg=None, scheduler=None, launch_web_server=False, metadata_checkpoint_dir=None, server_port=TuneServer.DEFAULT_PORT, verbose=True, - queue_trials=False, - reuse_actors=False, trial_executor=None): """Initializes a new TrialRunner. @@ -119,20 +118,15 @@ def __init__(self, server_port (int): Port number for launching TuneServer verbose (bool): Flag for verbosity. If False, trial results will not be output. - queue_trials (bool): Whether to queue trials when the cluster does - not currently have enough resources to launch one. This should - be set to True when running on an autoscaling cluster to enable - automatic scale-up. reuse_actors (bool): Whether to reuse actors between different trials when possible. This can drastically speed up experiments that start and stop actors often (e.g., PBT in time-multiplexing mode). trial_executor (TrialExecutor): Defaults to RayTrialExecutor. 
""" - self._search_alg = search_alg + self._search_alg = search_alg or BasicVariantGenerator() self._scheduler_alg = scheduler or FIFOScheduler() - self.trial_executor = (trial_executor or RayTrialExecutor( - queue_trials=queue_trials, reuse_actors=reuse_actors)) + self.trial_executor = trial_executor or RayTrialExecutor() # For debugging, it may be useful to halt trials after some time has # elapsed. TODO(ekl) consider exposing this in the API. @@ -141,7 +135,6 @@ def __init__(self, self._total_time = 0 self._iteration = 0 self._verbose = verbose - self._queue_trials = queue_trials self._server = None self._server_port = server_port @@ -229,11 +222,8 @@ def restore(cls, "This will ignore any new changes to the specification." ])) - from ray.tune.suggest import BasicVariantGenerator runner = TrialRunner( - search_alg or BasicVariantGenerator(), - scheduler=scheduler, - trial_executor=trial_executor) + search_alg, scheduler=scheduler, trial_executor=trial_executor) runner.__setstate__(runner_state["runner_data"]) diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 51f5dcdf265c..0d84b665167a 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -11,6 +11,7 @@ from ray.tune.experiment import convert_to_experiment_list, Experiment from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL +from ray.tune.ray_trial_executor import RayTrialExecutor from ray.tune.log_sync import wait_for_log_sync from ray.tune.trial_runner import TrialRunner from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler, @@ -90,7 +91,8 @@ def run(run_or_experiment, queue_trials=False, reuse_actors=False, trial_executor=None, - raise_on_failed_trial=True): + raise_on_failed_trial=True, + ray_auto_init=True): """Executes training. Args: @@ -166,6 +168,9 @@ def run(run_or_experiment, trial_executor (TrialExecutor): Manage the execution of trials. raise_on_failed_trial (bool): Raise TuneError if there exists failed trial (of ERROR state) when the experiments complete. + ray_auto_init (bool): Automatically starts a local Ray cluster + if using a RayTrialExecutor (which is the default) and + if Ray is not initialized. Defaults to True. Returns: List of Trial objects. @@ -187,6 +192,10 @@ def run(run_or_experiment, } ) """ + trial_executor = trial_executor or RayTrialExecutor( + queue_trials=queue_trials, + reuse_actors=reuse_actors, + ray_auto_init=ray_auto_init) experiment = run_or_experiment if not isinstance(run_or_experiment, Experiment): experiment = Experiment( @@ -229,14 +238,12 @@ def run(run_or_experiment, search_alg.add_configurations([experiment]) runner = TrialRunner( - search_alg, + search_alg=search_alg, scheduler=scheduler, metadata_checkpoint_dir=checkpoint_dir, launch_web_server=with_server, server_port=server_port, verbose=bool(verbose > 1), - queue_trials=queue_trials, - reuse_actors=reuse_actors, trial_executor=trial_executor) if verbose: From a218a14c92e7896b244f2f67369741983b37c12d Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 29 May 2019 16:57:28 -0700 Subject: [PATCH 048/118] Bump version from 0.8.0.dev0 to 0.7.1. (#4890) --- python/ray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index e1b65cdcf6c7..421b1c6838ac 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -96,7 +96,7 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. 
-__version__ = "0.8.0.dev0" +__version__ = "0.7.1" __all__ = [ "global_state", From 2dd0beb5bd7d0f77f4cdfb15fe40de6f1ac5c62e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 May 2019 18:17:14 -0700 Subject: [PATCH 049/118] [rllib] Allow access to batches prior to postprocessing (#4871) --- doc/source/rllib-algorithms.rst | 4 ++++ doc/source/rllib-models.rst | 4 ++++ python/ray/rllib/agents/trainer.py | 10 ++++++++-- python/ray/rllib/evaluation/sample_batch_builder.py | 8 +++++++- .../ray/rllib/examples/custom_metrics_and_callbacks.py | 2 +- 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/doc/source/rllib-algorithms.rst b/doc/source/rllib-algorithms.rst index 5a07280e3972..a9291bc4a984 100644 --- a/doc/source/rllib-algorithms.rst +++ b/doc/source/rllib-algorithms.rst @@ -101,6 +101,10 @@ Tuned examples: `PongNoFrameskip-v4 `__): +.. warning:: + + Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support. + .. literalinclude:: ../../python/ray/rllib/agents/ppo/appo.py :language: python :start-after: __sphinx_doc_begin__ diff --git a/doc/source/rllib-models.rst b/doc/source/rllib-models.rst index cdf42ea228c7..6a05e5b1c3e6 100644 --- a/doc/source/rllib-models.rst +++ b/doc/source/rllib-models.rst @@ -35,6 +35,10 @@ Custom Models (TensorFlow) Custom TF models should subclass the common RLlib `model class `__ and override the ``_build_layers_v2`` method. This method takes in a dict of tensor inputs (the observation ``obs``, ``prev_action``, and ``prev_reward``, ``is_training``), and returns a feature layer and float vector of the specified output size. You can also override the ``value_function`` method to implement a custom value branch. Additional supervised / self-supervised losses can be added via the ``custom_loss`` method. The model can then be registered and used in place of a built-in model: +.. warning:: + + Keras custom models are not compatible with multi-GPU (this includes PPO in single-GPU mode). This is because the multi-GPU implementation in RLlib relies on variable scopes to implement cross-GPU support. + .. code-block:: python import ray diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 83b00a896b71..4294affb1172 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -54,14 +54,20 @@ # Callbacks that will be run during various phases of training. These all # take a single "info" dict as an argument. For episode callbacks, custom # metrics can be attached to the episode by updating the episode object's - # custom metrics dict (see examples/custom_metrics_and_callbacks.py). + # custom metrics dict (see examples/custom_metrics_and_callbacks.py). You + # may also mutate the passed in batch data in your callback. 
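# ---------------------------------------------------------------------------
# [Illustrative aside, not part of the patch above] With the extra keys this
# series passes to "on_postprocess_traj" (agent_id, pre_batch, post_batch,
# all_pre_batches), a callback can inspect a trajectory both before and after
# postprocessing and attach custom metrics to the episode. A minimal sketch
# assuming the post-patch callback signature; the config fragment is wired
# into a trainer or tune.run() the same way as the other callbacks.
def on_postprocess_traj(info):
    episode = info["episode"]
    post_batch = info["post_batch"]  # batch after postprocessing
    # info also carries "agent_id", "pre_batch" (the data before
    # postprocessing), and "all_pre_batches" for the other agents.
    episode.custom_metrics["postprocessed_steps"] = post_batch.count

config_fragment = {
    "callbacks": {
        "on_postprocess_traj": on_postprocess_traj,
    },
}
# ---------------------------------------------------------------------------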
"callbacks": { "on_episode_start": None, # arg: {"env": .., "episode": ...} "on_episode_step": None, # arg: {"env": .., "episode": ...} "on_episode_end": None, # arg: {"env": .., "episode": ...} "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} "on_train_result": None, # arg: {"trainer": ..., "result": ...} - "on_postprocess_traj": None, # arg: {"batch": ..., "episode": ...} + "on_postprocess_traj": None, # arg: { + # "agent_id": ..., "episode": ..., + # "pre_batch": (before processing), + # "post_batch": (after processing), + # "all_pre_batches": (other agent ids), + # } }, # Whether to attempt to continue training if a worker crashes. "ignore_worker_failures": False, diff --git a/python/ray/rllib/evaluation/sample_batch_builder.py b/python/ray/rllib/evaluation/sample_batch_builder.py index 0ead77d52847..e82ca7357b70 100644 --- a/python/ray/rllib/evaluation/sample_batch_builder.py +++ b/python/ray/rllib/evaluation/sample_batch_builder.py @@ -165,7 +165,13 @@ def postprocess_batch_so_far(self, episode): self.policy_builders[self.agent_to_policy[agent_id]].add_batch( post_batch) if self.postp_callback: - self.postp_callback({"episode": episode, "batch": post_batch}) + self.postp_callback({ + "episode": episode, + "agent_id": agent_id, + "pre_batch": pre_batches[agent_id], + "post_batch": post_batch, + "all_pre_batches": pre_batches, + }) self.agent_builders.clear() self.agent_to_policy.clear() diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py index 27d91331f32d..ba7795bf0553 100644 --- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -46,7 +46,7 @@ def on_train_result(info): def on_postprocess_traj(info): episode = info["episode"] - batch = info["batch"] + batch = info["post_batch"] print("postprocessed {} steps".format(batch.count)) if "num_batches" not in episode.custom_metrics: episode.custom_metrics["num_batches"] = 0 From 3f4d37cd0e0c625f0e8806d6ca628b81838a976f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 29 May 2019 20:41:02 -0700 Subject: [PATCH 050/118] [rllib] Fix Multidiscrete support (#4869) --- .../ray/rllib/agents/impala/vtrace_policy.py | 31 +++---------------- python/ray/rllib/agents/ppo/appo_policy.py | 7 ++--- python/ray/rllib/models/action_dist.py | 11 ++++--- python/ray/rllib/models/catalog.py | 3 +- .../ray/rllib/tests/test_supported_spaces.py | 12 +++++-- 5 files changed, 26 insertions(+), 38 deletions(-) diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index 9b7c57b9355e..9b283c7172cc 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -15,7 +15,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule -from ray.rllib.models.action_dist import MultiCategorical +from ray.rllib.models.action_dist import Categorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance @@ -191,9 +191,7 @@ def __init__(self, unpacked_outputs = tf.split( self.model.outputs, output_hidden_shape, axis=1) - dist_inputs = unpacked_outputs if is_multidiscrete else \ - self.model.outputs - action_dist = dist_class(dist_inputs) + action_dist = dist_class(self.model.outputs) values = self.model.value_function() 
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, @@ -258,32 +256,13 @@ def make_time_major(tensor, drop_last=False): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], - dist_class=dist_class, + dist_class=Categorical if is_multidiscrete else dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=self.config["vf_loss_coeff"], entropy_coeff=self.config["entropy_coeff"], clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) - # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) - - kls = model_dist.kl(behaviour_dist) - if len(kls) > 1: - self.KL_stats = {} - - for i, kl in enumerate(kls): - self.KL_stats.update({ - "mean_KL_{}".format(i): tf.reduce_mean(kl), - "max_KL_{}".format(i): tf.reduce_max(kl), - }) - else: - self.KL_stats = { - "mean_KL": tf.reduce_mean(kls[0]), - "max_KL": tf.reduce_max(kls[0]), - } - # Initialize TFPolicy loss_in = [ (SampleBatch.ACTIONS, actions), @@ -318,7 +297,7 @@ def make_time_major(tensor, drop_last=False): self.sess.run(tf.global_variables_initializer()) self.stats_fetches = { - LEARNER_STATS_KEY: dict({ + LEARNER_STATS_KEY: { "cur_lr": tf.cast(self.cur_lr, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, @@ -328,7 +307,7 @@ def make_time_major(tensor, drop_last=False): "vf_explained_var": explained_variance( tf.reshape(self.loss.vtrace_returns.vs, [-1]), tf.reshape(make_time_major(values, drop_last=True), [-1])), - }, **self.KL_stats), + }, } @override(TFPolicy) diff --git a/python/ray/rllib/agents/ppo/appo_policy.py b/python/ray/rllib/agents/ppo/appo_policy.py index b740d6d81430..9f213063ab94 100644 --- a/python/ray/rllib/agents/ppo/appo_policy.py +++ b/python/ray/rllib/agents/ppo/appo_policy.py @@ -13,6 +13,7 @@ import ray from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.models.action_dist import Categorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.tf_policy import LearningRateSchedule @@ -220,10 +221,8 @@ def make_time_major(*args, **kw): behaviour_logits, output_hidden_shape, axis=1) unpacked_outputs = tf.split( policy.model.outputs, output_hidden_shape, axis=1) - prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \ - behaviour_logits action_dist = policy.action_dist - prev_action_dist = policy.dist_class(prev_dist_inputs) + prev_action_dist = policy.dist_class(behaviour_logits) values = policy.value_function if policy.model.state_in: @@ -257,7 +256,7 @@ def make_time_major(*args, **kw): rewards=make_time_major(rewards, drop_last=True), values=make_time_major(values, drop_last=True), bootstrap_value=make_time_major(values)[-1], - dist_class=policy.dist_class, + dist_class=Categorical if is_multidiscrete else policy.dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=policy.config["vf_loss_coeff"], entropy_coeff=policy.config["entropy_coeff"], diff --git a/python/ray/rllib/models/action_dist.py b/python/ray/rllib/models/action_dist.py index 9cf58b9dd317..303f3bed2b21 100644 --- a/python/ray/rllib/models/action_dist.py +++ b/python/ray/rllib/models/action_dist.py @@ -76,7 +76,7 @@ class 
Categorical(ActionDistribution): @override(ActionDistribution) def logp(self, x): return -tf.nn.sparse_softmax_cross_entropy_with_logits( - logits=self.inputs, labels=x) + logits=self.inputs, labels=tf.cast(x, tf.int32)) @override(ActionDistribution) def entropy(self): @@ -126,14 +126,17 @@ def _build_sample_op(self): class MultiCategorical(ActionDistribution): """Categorical distribution for discrete action spaces.""" - def __init__(self, inputs): - self.cats = [Categorical(input_) for input_ in inputs] + def __init__(self, inputs, input_lens): + self.cats = [ + Categorical(input_) + for input_ in tf.split(inputs, input_lens, axis=1) + ] self.sample_op = self._build_sample_op() def logp(self, actions): # If tensor is provided, unstack it into list if isinstance(actions, tf.Tensor): - actions = tf.unstack(actions, axis=1) + actions = tf.unstack(tf.cast(actions, tf.int32), axis=1) logps = tf.stack( [cat.logp(act) for cat, act in zip(self.cats, actions)]) return tf.reduce_sum(logps, axis=0) diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index d237474480e5..a3a68c22ef1c 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -149,7 +149,8 @@ def get_action_dist(action_space, config, dist_type=None, torch=False): elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete): if torch: raise NotImplementedError - return MultiCategorical, int(sum(action_space.nvec)) + return partial(MultiCategorical, input_lens=action_space.nvec), \ + int(sum(action_space.nvec)) raise NotImplementedError("Unsupported args: {} {}".format( action_space, dist_type)) diff --git a/python/ray/rllib/tests/test_supported_spaces.py b/python/ray/rllib/tests/test_supported_spaces.py index c3ea442c8e8d..a7f4976ef5b2 100644 --- a/python/ray/rllib/tests/test_supported_spaces.py +++ b/python/ray/rllib/tests/test_supported_spaces.py @@ -2,7 +2,7 @@ import traceback import gym -from gym.spaces import Box, Discrete, Tuple, Dict +from gym.spaces import Box, Discrete, Tuple, Dict, MultiDiscrete from gym.envs.registration import EnvSpec import numpy as np import sys @@ -17,6 +17,7 @@ ACTION_SPACES_TO_TEST = { "discrete": Discrete(5), "vector": Box(-1.0, 1.0, (5, ), dtype=np.float32), + "multidiscrete": MultiDiscrete([1, 2, 3, 4]), "tuple": Tuple( [Discrete(2), Discrete(3), @@ -61,7 +62,7 @@ def step(self, action): return StubEnv -def check_support(alg, config, stats, check_bounds=False): +def check_support(alg, config, stats, check_bounds=False, name=None): for a_name, action_space in ACTION_SPACES_TO_TEST.items(): for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items(): print("=== Testing", alg, action_space, obs_space, "===") @@ -87,7 +88,7 @@ def check_support(alg, config, stats, check_bounds=False): pass print(stat) print() - stats[alg, a_name, o_name] = stat + stats[name or alg, a_name, o_name] = stat def check_support_multiagent(alg, config): @@ -114,6 +115,11 @@ def testAll(self): stats = {} check_support("IMPALA", {"num_gpus": 0}, stats) check_support("APPO", {"num_gpus": 0, "vtrace": False}, stats) + check_support( + "APPO", { + "num_gpus": 0, + "vtrace": True + }, stats, name="APPO-vt") check_support( "DDPG", { "exploration_ou_noise_scale": 100.0, From b7c284aaa3c4869327c2e80e11b2a8ab36e0d8bd Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 30 May 2019 11:54:30 +0800 Subject: [PATCH 051/118] Refactor redis callback handling (#4841) * Add CallbackReply * Fix * fix linting by format.sh * Fix linting * Address comments. 
* Fix --- src/ray/gcs/redis_context.cc | 140 +++++++++++-------- src/ray/gcs/redis_context.h | 36 ++++- src/ray/gcs/redis_module/ray_redis_module.cc | 2 +- src/ray/gcs/tables.cc | 29 ++-- 4 files changed, 132 insertions(+), 75 deletions(-) diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 6b03fa735007..e0c5a6565412 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -20,7 +20,8 @@ namespace { /// A helper function to call the callback and delete it from the callback /// manager if necessary. -void ProcessCallback(int64_t callback_index, const std::string &data) { +void ProcessCallback(int64_t callback_index, + const ray::gcs::CallbackReply &callback_reply) { RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, " << "but it actually is " << callback_index; auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index); @@ -31,7 +32,7 @@ void ProcessCallback(int64_t callback_index, const std::string &data) { } // Invoke the callback. if (callback_item.callback != nullptr) { - callback_item.callback(data); + callback_item.callback(callback_reply); } if (!callback_item.is_subscription) { // Delete the callback if it's not a subscription callback. @@ -45,74 +46,91 @@ namespace ray { namespace gcs { -// This is a global redis callback which will be registered for every -// asynchronous redis call. It dispatches the appropriate callback -// that was registered with the RedisCallbackManager. -void GlobalRedisCallback(void *c, void *r, void *privdata) { - if (r == nullptr) { - return; +CallbackReply::CallbackReply(redisReply *redis_reply) { + RAY_CHECK(nullptr != redis_reply); + RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) << "Got an error in redis reply: " + << redis_reply->str; + this->redis_reply_ = redis_reply; +} + +bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; } + +int64_t CallbackReply::ReadAsInteger() const { + RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + return static_cast(redis_reply_->integer); +} + +std::string CallbackReply::ReadAsString() const { + RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + return std::string(redis_reply_->str, redis_reply_->len); +} + +Status CallbackReply::ReadAsStatus() const { + RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + const std::string status_str(redis_reply_->str, redis_reply_->len); + if ("OK" == status_str) { + return Status::OK(); } - int64_t callback_index = reinterpret_cast(privdata); - redisReply *reply = reinterpret_cast(r); + + return Status::RedisError(status_str); +} + +std::string CallbackReply::ReadAsPubsubData() const { + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + std::string data = ""; - // Parse the response. - switch (reply->type) { - case (REDIS_REPLY_NIL): { - // Do not add any data for a nil response. - } break; - case (REDIS_REPLY_STRING): { - data = std::string(reply->str, reply->len); - } break; - case (REDIS_REPLY_STATUS): { - } break; - case (REDIS_REPLY_ERROR): { - RAY_LOG(FATAL) << "Redis error: " << reply->str; - } break; - case (REDIS_REPLY_INTEGER): { - data = std::to_string(reply->integer); - break; + // Parse the published message. 
+ redisReply *message_type = redis_reply_->element[0]; + if (strcmp(message_type->str, "subscribe") == 0) { + // If the message is for the initial subscription call, return the empty + // string as a response to signify that subscription was successful. + } else if (strcmp(message_type->str, "message") == 0) { + // If the message is from a PUBLISH, make sure the data is nonempty. + redisReply *message = redis_reply_->element[redis_reply_->elements - 1]; + // data is a notification message. + data = std::string(message->str, message->len); + RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel."; + } else { + RAY_LOG(FATAL) << "This is not a pubsub reply: data=" << message_type->str; } - default: - RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string " - << reply->str; + + return data; +} + +void CallbackReply::ReadAsStringArray(std::vector *array) const { + RAY_CHECK(nullptr != array) << "Argument `array` must not be nullptr."; + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type); + + const auto array_size = static_cast(redis_reply_->elements); + if (array_size > 0) { + auto *entry = redis_reply_->element[0]; + const bool is_pubsub_reply = + strcmp(entry->str, "subscribe") == 0 || strcmp(entry->str, "message") == 0; + RAY_CHECK(!is_pubsub_reply) << "Subpub reply cannot be read as a string array."; + } + + array->resize(array_size); + for (size_t i = 0; i < array_size; ++i) { + auto *entry = redis_reply_->element[i]; + RAY_CHECK(REDIS_REPLY_STRING == entry->type) << "Unexcepted type: " << entry->type; + array->push_back(std::string(entry->str, entry->len)); } - ProcessCallback(callback_index, data); } -void SubscribeRedisCallback(void *c, void *r, void *privdata) { +// This is a global redis callback which will be registered for every +// asynchronous redis call. It dispatches the appropriate callback +// that was registered with the RedisCallbackManager. +void GlobalRedisCallback(void *c, void *r, void *privdata) { if (r == nullptr) { return; } int64_t callback_index = reinterpret_cast(privdata); redisReply *reply = reinterpret_cast(r); - std::string data = ""; - // Parse the response. - switch (reply->type) { - case (REDIS_REPLY_ARRAY): { - // Parse the published message. - redisReply *message_type = reply->element[0]; - if (strcmp(message_type->str, "subscribe") == 0) { - // If the message is for the initial subscription call, return the empty - // string as a response to signify that subscription was successful. - } else if (strcmp(message_type->str, "message") == 0) { - // If the message is from a PUBLISH, make sure the data is nonempty. - redisReply *message = reply->element[reply->elements - 1]; - auto notification = std::string(message->str, message->len); - RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel"; - data = notification; - } else { - RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str; - } - - } break; - case (REDIS_REPLY_ERROR): { - RAY_LOG(FATAL) << "Redis error: " << reply->str; - } break; - default: - RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string " - << reply->str; - } - ProcessCallback(callback_index, data); + ProcessCallback(callback_index, CallbackReply(reply)); } int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription) { @@ -259,13 +277,13 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, // Subscribe to all messages. 
std::string redis_command = "SUBSCRIBE %d"; status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&SubscribeRedisCallback), + subscribe_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel); } else { // Subscribe only to messages sent to this client. std::string redis_command = "SUBSCRIBE %d:%b"; status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&SubscribeRedisCallback), + subscribe_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel, client_id.data(), client_id.size()); } diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 93a343464892..b82915374b0a 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -24,9 +24,43 @@ struct aeEventLoop; namespace ray { namespace gcs { + +/// A simple reply wrapper for redis reply. +class CallbackReply { + public: + explicit CallbackReply(redisReply *redis_reply); + + /// Whether this reply is `nil` type reply. + bool IsNil() const; + + /// Read this reply data as an integer. + int64_t ReadAsInteger() const; + + /// Read this reply data as a string. + /// + /// Note that this will return an empty string if + /// the type of this reply is `nil` or `status`. + std::string ReadAsString() const; + + /// Read this reply data as a status. + Status ReadAsStatus() const; + + /// Read this reply data as a pub-sub data. + std::string ReadAsPubsubData() const; + + /// Read this reply data as a string array. + /// + /// \param array Since the return-value may be large, + /// make it as an output parameter. + void ReadAsStringArray(std::vector *array) const; + + private: + redisReply *redis_reply_; +}; + /// Every callback should take in a vector of the results from the Redis /// operation. -using RedisCallback = std::function; +using RedisCallback = std::function; void GlobalRedisCallback(void *c, void *r, void *privdata); diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index b9891e8cae32..0014778896cd 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -351,7 +351,7 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, // The requested index did not match the current length of the log. Return // an error message as a string. static const char *reply = "ERR entry exists"; - RedisModule_ReplyWithStringBuffer(ctx, reply, strlen(reply)); + RedisModule_ReplyWithSimpleString(ctx, reply); return REDISMODULE_ERR; } } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 3d4708940d1a..ccf05f2b5151 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -41,10 +41,11 @@ template Status Log::Append(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const std::string &data) { - // If data is not empty, then Redis failed to append the entry. - RAY_CHECK(data.empty()) << "TABLE_APPEND command failed: " << data; - + auto callback = [this, id, dataT, done](const CallbackReply &reply) { + const auto status = reply.ReadAsStatus(); + // Failed to append the entry. 
+ RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" + << status.ToString(); if (done != nullptr) { (done)(client_, id, *dataT); } @@ -62,8 +63,9 @@ Status Log::AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const std::string &data) { - if (data.empty()) { + auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + const auto status = reply.ReadAsStatus(); + if (status.ok()) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -85,10 +87,11 @@ template Status Log::Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup) { num_lookups_++; - auto callback = [this, id, lookup](const std::string &data) { + auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { std::vector results; - if (!data.empty()) { + if (!reply.IsNil()) { + const auto data = reply.ReadAsString(); auto root = flatbuffers::GetRoot(data.data()); RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { @@ -125,7 +128,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) << "Client called Subscribe twice on the same table"; - auto callback = [this, subscribe, done](const std::string &data) { + auto callback = [this, subscribe, done](const CallbackReply &reply) { + const auto data = reply.ReadAsPubsubData(); + if (data.empty()) { // No notification data is provided. This is the callback for the // initial subscription request. @@ -231,7 +236,7 @@ template Status Table::Add(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -296,7 +301,7 @@ template Status Set::Add(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -313,7 +318,7 @@ template Status Set::Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); } From 2912a7cb860bf027a4b8886896ca0533c9965f39 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Thu, 30 May 2019 02:43:17 -0700 Subject: [PATCH 052/118] Initial high-level code structure of CoreWorker. 
(#4875) --- BUILD.bazel | 35 ++++++++- src/ray/common/buffer.h | 45 ++++++++++++ src/ray/core_worker/common.h | 71 ++++++++++++++++++ src/ray/core_worker/core_worker.h | 68 ++++++++++++++++++ src/ray/core_worker/core_worker_test.cc | 38 ++++++++++ src/ray/core_worker/object_interface.cc | 25 +++++++ src/ray/core_worker/object_interface.h | 61 ++++++++++++++++ src/ray/core_worker/task_execution.cc | 7 ++ src/ray/core_worker/task_execution.h | 36 ++++++++++ src/ray/core_worker/task_interface.cc | 26 +++++++ src/ray/core_worker/task_interface.h | 96 +++++++++++++++++++++++++ 11 files changed, 506 insertions(+), 2 deletions(-) create mode 100644 src/ray/common/buffer.h create mode 100644 src/ray/core_worker/common.h create mode 100644 src/ray/core_worker/core_worker.h create mode 100644 src/ray/core_worker/core_worker_test.cc create mode 100644 src/ray/core_worker/object_interface.cc create mode 100644 src/ray/core_worker/object_interface.h create mode 100644 src/ray/core_worker/task_execution.cc create mode 100644 src/ray/core_worker/task_execution.h create mode 100644 src/ray/core_worker/task_interface.cc create mode 100644 src/ray/core_worker/task_interface.h diff --git a/BUILD.bazel b/BUILD.bazel index e2cbdd64bf51..0bdbe5741cf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -105,6 +105,36 @@ cc_library( ], ) +cc_library( + name = "core_worker_lib", + srcs = glob( + [ + "src/ray/core_worker/*.cc", + ], + exclude = [ + "src/ray/core_worker/*_test.cc", + ], + ), + hdrs = glob([ + "src/ray/core_worker/*.h", + ]), + copts = COPTS, + deps = [ + ":ray_common", + ":ray_util", + ], +) + +cc_test( + name = "core_worker_test", + srcs = ["src/ray/core_worker/core_worker_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "lineage_cache_test", srcs = ["src/ray/raylet/lineage_cache_test.cc"], @@ -277,6 +307,7 @@ cc_library( "src/ray/common/common_protocol.cc", ], hdrs = [ + "src/ray/common/buffer.h", "src/ray/common/client_connection.h", "src/ray/common/common_protocol.h", ], @@ -637,8 +668,8 @@ genrule( cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && - for f in $(locations //:python_node_manager_fbs); do - cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; + for f in $(locations //:python_node_manager_fbs); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && echo $$WORK_DIR > $@ """, diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h new file mode 100644 index 000000000000..358d903799c7 --- /dev/null +++ b/src/ray/common/buffer.h @@ -0,0 +1,45 @@ +#ifndef RAY_COMMON_BUFFER_H +#define RAY_COMMON_BUFFER_H + +#include +#include + +namespace ray { + +/// The interface that represents a buffer of bytes. +class Buffer { + public: + /// Pointer to the data. + virtual uint8_t *Data() const = 0; + + /// Size of this buffer. + virtual size_t Size() const = 0; + + virtual ~Buffer() {} + + bool operator==(const Buffer &rhs) const { + return this->Data() == rhs.Data() && this->Size() == rhs.Size(); + } +}; + +/// Represents a byte buffer in local memory. 
+class LocalMemoryBuffer : public Buffer { + public: + LocalMemoryBuffer(uint8_t *data, size_t size) : data_(data), size_(size) {} + + uint8_t *Data() const override { return data_; } + + size_t Size() const override { return size_; } + + ~LocalMemoryBuffer() {} + + private: + /// Pointer to the data. + uint8_t *data_; + /// Size of the buffer. + size_t size_; +}; + +} // namespace ray + +#endif // RAY_COMMON_BUFFER_H diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h new file mode 100644 index 000000000000..b53c35b25fa8 --- /dev/null +++ b/src/ray/core_worker/common.h @@ -0,0 +1,71 @@ +#ifndef RAY_CORE_WORKER_COMMON_H +#define RAY_CORE_WORKER_COMMON_H + +#include + +#include "ray/common/buffer.h" +#include "ray/id.h" + +namespace ray { + +/// Type of this worker. +enum class WorkerType { WORKER, DRIVER }; + +/// Language of Ray tasks and workers. +enum class Language { PYTHON, JAVA }; + +/// Information about a remote function. +struct RayFunction { + /// Language of the remote function. + const Language language; + /// Function descriptor of the remote function. + const std::vector function_descriptor; +}; + +/// Argument of a task. +class TaskArg { + public: + /// Create a pass-by-reference task argument. + /// + /// \param[in] object_id Id of the argument. + /// \return The task argument. + static TaskArg PassByReference(const ObjectID &object_id) { + return TaskArg(std::make_shared(object_id), nullptr); + } + + /// Create a pass-by-reference task argument. + /// + /// \param[in] object_id Id of the argument. + /// \return The task argument. + static TaskArg PassByValue(const std::shared_ptr &data) { + return TaskArg(nullptr, data); + } + + /// Return true if this argument is passed by reference, false if passed by value. + bool IsPassedByReference() const { return id_ != nullptr; } + + /// Get the reference object ID. + ObjectID &GetReference() { + RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference."; + return *id_; + } + + /// Get the value. + std::shared_ptr GetValue() { + RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value."; + return data_; + } + + private: + TaskArg(const std::shared_ptr id, const std::shared_ptr data) + : id_(id), data_(data) {} + + /// Id of the argument, if passed by reference, otherwise nullptr. + const std::shared_ptr id_; + /// Data of the argument, if passed by value, otherwise nullptr. + const std::shared_ptr data_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_COMMON_H diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h new file mode 100644 index 000000000000..96e51dbc4532 --- /dev/null +++ b/src/ray/core_worker/core_worker.h @@ -0,0 +1,68 @@ +#ifndef RAY_CORE_WORKER_CORE_WORKER_H +#define RAY_CORE_WORKER_CORE_WORKER_H + +#include "common.h" +#include "object_interface.h" +#include "ray/common/buffer.h" +#include "task_execution.h" +#include "task_interface.h" + +namespace ray { + +/// The root class that contains all the core and language-independent functionalities +/// of the worker. This class is supposed to be used to implement app-language (Java, +/// Python, etc) workers. +class CoreWorker { + public: + /// Construct a CoreWorker instance. + /// + /// \param[in] worker_type Type of this worker. + /// \param[in] langauge Language of this worker. 
+ CoreWorker(const WorkerType worker_type, const Language language) + : worker_type_(worker_type), + language_(language), + task_interface_(*this), + object_interface_(*this), + task_execution_interface_(*this) {} + + /// Connect this worker to Raylet. + Status Connect() { return Status::OK(); } + + /// Type of this worker. + enum WorkerType WorkerType() const { return worker_type_; } + + /// Language of this worker. + enum Language Language() const { return language_; } + + /// Return the `CoreWorkerTaskInterface` that contains the methods related to task + /// submisson. + CoreWorkerTaskInterface &Tasks() { return task_interface_; } + + /// Return the `CoreWorkerObjectInterface` that contains methods related to object + /// store. + CoreWorkerObjectInterface &Objects() { return object_interface_; } + + /// Return the `CoreWorkerTaskExecutionInterface` that contains methods related to + /// task execution. + CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; } + + private: + /// Type of this worker. + const enum WorkerType worker_type_; + + /// Language of this worker. + const enum Language language_; + + /// The `CoreWorkerTaskInterface` instance. + CoreWorkerTaskInterface task_interface_; + + /// The `CoreWorkerObjectInterface` instance. + CoreWorkerObjectInterface object_interface_; + + /// The `CoreWorkerTaskExecutionInterface` instance. + CoreWorkerTaskExecutionInterface task_execution_interface_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_CORE_WORKER_H diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc new file mode 100644 index 000000000000..b1be58da95b8 --- /dev/null +++ b/src/ray/core_worker/core_worker_test.cc @@ -0,0 +1,38 @@ +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#include "core_worker.h" +#include "ray/common/buffer.h" + +namespace ray { + +class CoreWorkerTest : public ::testing::Test { + public: + CoreWorkerTest() : core_worker_(WorkerType::WORKER, Language::PYTHON) {} + + protected: + CoreWorker core_worker_; +}; + +TEST_F(CoreWorkerTest, TestTaskArg) { + // Test by-reference argument. + ObjectID id = ObjectID::from_random(); + TaskArg by_ref = TaskArg::PassByReference(id); + ASSERT_TRUE(by_ref.IsPassedByReference()); + ASSERT_EQ(by_ref.GetReference(), id); + // Test by-value argument. 
+ std::shared_ptr buffer = + std::make_shared(static_cast(0), 0); + TaskArg by_value = TaskArg::PassByValue(buffer); + ASSERT_FALSE(by_value.IsPassedByReference()); + auto data = by_value.GetValue(); + ASSERT_TRUE(data != nullptr); + ASSERT_EQ(*data, *buffer); +} + +TEST_F(CoreWorkerTest, TestAttributeGetters) { + ASSERT_EQ(core_worker_.WorkerType(), WorkerType::WORKER); + ASSERT_EQ(core_worker_.Language(), Language::PYTHON); +} + +} // namespace ray diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc new file mode 100644 index 000000000000..d5d5d6f883f6 --- /dev/null +++ b/src/ray/core_worker/object_interface.cc @@ -0,0 +1,25 @@ +#include "object_interface.h" + +namespace ray { + +Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID *object_id) { + return Status::OK(); +} + +Status CoreWorkerObjectInterface::Get(const std::vector &ids, + int64_t timeout_ms, std::vector *results) { + return Status::OK(); +} + +Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, + int num_objects, int64_t timeout_ms, + std::vector *results) { + return Status::OK(); +} + +Status CoreWorkerObjectInterface::Delete(const std::vector &object_ids, + bool local_only, bool delete_creating_tasks) { + return Status::OK(); +} + +} // namespace ray diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h new file mode 100644 index 000000000000..424c123ee543 --- /dev/null +++ b/src/ray/core_worker/object_interface.h @@ -0,0 +1,61 @@ +#ifndef RAY_CORE_WORKER_OBJECT_INTERFACE_H +#define RAY_CORE_WORKER_OBJECT_INTERFACE_H + +#include "common.h" +#include "ray/common/buffer.h" +#include "ray/id.h" +#include "ray/status.h" + +namespace ray { + +class CoreWorker; + +/// The interface that contains all `CoreWorker` methods that are related to object store. +class CoreWorkerObjectInterface { + public: + CoreWorkerObjectInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + + /// Put an object into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[out] object_id Generated ID of the object. + /// \return Status. + Status Put(const Buffer &buffer, const ObjectID *object_id); + + /// Get a list of objects from the object store. + /// + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[out] results Result list of objects data. + /// \return Status. + Status Get(const std::vector &ids, int64_t timeout_ms, + std::vector *results); + + /// Wait for a list of objects to appear in the object store. + /// + /// \param[in] IDs of the objects to wait for. + /// \param[in] num_returns Number of objects that should appear. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[out] results A bitset that indicates each object has appeared or not. + /// \return Status. + Status Wait(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, std::vector *results); + + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. + /// \param[in] local_only Whether only delete the objects in local node, or all nodes in + /// the cluster. + /// \param[in] delete_creating_tasks Whether also delete the tasks that + /// created these objects. \return Status. 
+  Status Delete(const std::vector &object_ids, bool local_only,
+                bool delete_creating_tasks);
+
+ private:
+  /// Reference to the parent CoreWorker instance.
+  CoreWorker &core_worker_;
+};
+
+}  // namespace ray
+
+#endif  // RAY_CORE_WORKER_OBJECT_INTERFACE_H
diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc
new file mode 100644
index 000000000000..aea48b4de34a
--- /dev/null
+++ b/src/ray/core_worker/task_execution.cc
@@ -0,0 +1,7 @@
+#include "task_execution.h"
+
+namespace ray {
+
+void CoreWorkerTaskExecutionInterface::Start(const TaskExecutor &executor) {}
+
+}  // namespace ray
diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h
new file mode 100644
index 000000000000..308b1e6868d6
--- /dev/null
+++ b/src/ray/core_worker/task_execution.h
@@ -0,0 +1,36 @@
+#ifndef RAY_CORE_WORKER_TASK_EXECUTION_H
+#define RAY_CORE_WORKER_TASK_EXECUTION_H
+
+#include "common.h"
+#include "ray/common/buffer.h"
+#include "ray/status.h"
+
+namespace ray {
+
+class CoreWorker;
+
+/// The interface that contains all `CoreWorker` methods that are related to task
+/// execution.
+class CoreWorkerTaskExecutionInterface {
+ public:
+  CoreWorkerTaskExecutionInterface(CoreWorker &core_worker) : core_worker_(core_worker) {}
+
+  /// The callback provided by app-language workers that executes tasks.
+  ///
+  /// \param ray_function[in] Information about the function to execute.
+  /// \param args[in] Arguments of the task.
+  /// \return Status.
+  using TaskExecutor = std::function &args)>;
+
+  /// Start receiving and executing tasks in an infinite loop.
+  void Start(const TaskExecutor &executor);
+
+ private:
+  /// Reference to the parent CoreWorker instance.
+  CoreWorker &core_worker_;
+};
+
+}  // namespace ray
+
+#endif  // RAY_CORE_WORKER_TASK_EXECUTION_H
diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc
new file mode 100644
index 000000000000..ab8b8950c298
--- /dev/null
+++ b/src/ray/core_worker/task_interface.cc
@@ -0,0 +1,26 @@
+#include "task_interface.h"
+
+namespace ray {
+
+Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
+                                           const std::vector &args,
+                                           const TaskOptions &task_options,
+                                           std::vector *return_ids) {
+  return Status::OK();
+}
+
+Status CoreWorkerTaskInterface::CreateActor(
+    const RayFunction &function, const std::vector &args,
+    const ActorCreationOptions &actor_creation_options, ActorHandle *actor_handle) {
+  return Status::OK();
+}
+
+Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle,
+                                                const RayFunction &function,
+                                                const std::vector &args,
+                                                const TaskOptions &task_options,
+                                                std::vector *return_ids) {
+  return Status::OK();
+}
+
+}  // namespace ray
diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h
new file mode 100644
index 000000000000..f667d8d5a06f
--- /dev/null
+++ b/src/ray/core_worker/task_interface.h
@@ -0,0 +1,96 @@
+#ifndef RAY_CORE_WORKER_TASK_INTERFACE_H
+#define RAY_CORE_WORKER_TASK_INTERFACE_H
+
+#include "common.h"
+#include "ray/common/buffer.h"
+#include "ray/id.h"
+#include "ray/status.h"
+
+namespace ray {
+
+class CoreWorker;
+
+/// Options of a non-actor-creation task.
+struct TaskOptions {
+  /// Number of returns of this task.
+  const int num_returns = 1;
+  /// Resources required by this task.
+  const std::unordered_map resources;
+};
+
+/// Options of an actor creation task.
+struct ActorCreationOptions { + /// Maximum number of times that the actor should be reconstructed when it dies + /// unexpectedly. It must be non-negative. If it's 0, the actor won't be reconstructed. + const uint64_t max_reconstructions = 0; + /// Resources required by the whole lifetime of this actor. + const std::unordered_map resources; +}; + +/// A handle to an actor. +class ActorHandle { + public: + ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id) + : actor_id_(actor_id), actor_handle_id_(actor_handle_id) {} + + /// ID of the actor. + const class ActorID &ActorID() const { return actor_id_; } + + /// ID of this actor handle. + const class ActorHandleID &ActorHandleID() const { return actor_handle_id_; } + + private: + /// ID of the actor. + const class ActorID actor_id_; + /// ID of this actor handle. + const class ActorHandleID actor_handle_id_; +}; + +/// The interface that contains all `CoreWorker` methods that are related to task +/// submission. +class CoreWorkerTaskInterface { + public: + CoreWorkerTaskInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + + /// Submit a normal task. + /// + /// \param[in] function The remote function to execute. + /// \param[in] args Arguments of this task. + /// \param[in] task_options Options for this task. + /// \param[out] return_ids Ids of the return objects. + /// \return Status. + Status SubmitTask(const RayFunction &function, const std::vector &args, + const TaskOptions &task_options, std::vector *return_ids); + + /// Create an actor. + /// + /// \param[in] function The remote function that generates the actor object. + /// \param[in] args Arguments of this task. + /// \param[in] actor_creation_options Options for this actor creation task. + /// \param[out] actor_handle Handle to the actor. + /// \return Status. + Status CreateActor(const RayFunction &function, const std::vector &args, + const ActorCreationOptions &actor_creation_options, + ActorHandle *actor_handle); + + /// Submit an actor task. + /// + /// \param[in] actor_handle Handle to the actor. + /// \param[in] function The remote function to execute. + /// \param[in] args Arguments of this task. + /// \param[in] task_options Options for this task. + /// \param[out] return_ids Ids of the return objects. + /// \return Status. + Status SubmitActorTask(ActorHandle &actor_handle, const RayFunction &function, + const std::vector &args, + const TaskOptions &task_options, + std::vector *return_ids); + + private: + /// Reference to the parent CoreWorker instance. + CoreWorker &core_worker_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_TASK_INTERFACE_H From 4e0be8b4500d00dc0deb675e634d46a6a95b388c Mon Sep 17 00:00:00 2001 From: Si-Yuan Date: Thu, 30 May 2019 19:43:27 +0800 Subject: [PATCH 053/118] Drop duplicated string format (#4897) This string format is unnecessary. java_worker_options has been appended to the commandline later. 
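[Illustrative aside, not part of any patch in this series] The CoreWorkerTaskInterface
declared a few patches above (SubmitTask / CreateActor / SubmitActorTask) roughly
mirrors the three ways work is already submitted from the Python API. A minimal
sketch of those three paths, assuming only a local Ray installation; the function and
actor below are hypothetical examples, not code from the patches.

    import ray

    ray.init()

    @ray.remote
    def add_one(x):            # a plain task, roughly the SubmitTask path
        return x + 1

    @ray.remote
    class Counter(object):     # actor creation, roughly the CreateActor path
        def __init__(self):
            self.n = 0

        def increment(self):   # an actor task, roughly the SubmitActorTask path
            self.n += 1
            return self.n

    result_id = add_one.remote(1)
    counter = Counter.remote()
    print(ray.get([result_id, counter.increment.remote()]))  # [2, 1]
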
--- python/ray/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/services.py b/python/ray/services.py index 7dc594963cec..00ae4e1a2b09 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1231,7 +1231,7 @@ def build_java_worker_command( """ assert java_worker_options is not None - command = "java ".format(java_worker_options) + command = "java " if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) From 1f0809e2b49bfdb915c77b0d67dda191017d7939 Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Fri, 31 May 2019 11:31:19 +0800 Subject: [PATCH 054/118] Refactor ID Serial 2: change all ID functions to `CamelCase` (#4896) --- .../java/org/ray/runtime/util/IdUtil.java | 2 +- python/ray/_raylet.pyx | 14 +-- python/ray/includes/task.pxi | 16 +-- python/ray/includes/unique_ids.pxd | 60 ++++----- python/ray/includes/unique_ids.pxi | 66 +++++----- src/ray/common/client_connection.cc | 4 +- src/ray/common/common_protocol.h | 12 +- src/ray/gcs/client.cc | 8 +- src/ray/gcs/client_test.cc | 62 ++++----- src/ray/gcs/redis_context.cc | 4 +- src/ray/gcs/redis_context.h | 6 +- src/ray/gcs/redis_module/ray_redis_module.cc | 2 +- src/ray/gcs/tables.cc | 58 ++++----- src/ray/gcs/tables.h | 4 +- src/ray/id.cc | 38 +++--- src/ray/id.h | 118 +++++++++--------- src/ray/object_manager/object_buffer_pool.cc | 14 +-- src/ray/object_manager/object_directory.cc | 22 ++-- src/ray/object_manager/object_manager.cc | 24 ++-- src/ray/object_manager/object_manager.h | 2 +- .../test/object_manager_stress_test.cc | 26 ++-- .../test/object_manager_test.cc | 40 +++--- src/ray/raylet/actor_registration.cc | 26 ++-- src/ray/raylet/client_connection_test.cc | 2 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 18 +-- src/ray/raylet/lineage_cache.cc | 8 +- src/ray/raylet/lineage_cache_test.cc | 32 ++--- src/ray/raylet/monitor.cc | 8 +- src/ray/raylet/node_manager.cc | 94 +++++++------- .../raylet/object_manager_integration_test.cc | 20 +-- src/ray/raylet/raylet_client.cc | 6 +- src/ray/raylet/reconstruction_policy.cc | 12 +- src/ray/raylet/reconstruction_policy_test.cc | 52 ++++---- src/ray/raylet/task_dependency_manager.cc | 16 +-- .../raylet/task_dependency_manager_test.cc | 22 ++-- src/ray/raylet/task_execution_spec.cc | 4 +- src/ray/raylet/task_spec.cc | 13 +- src/ray/raylet/task_test.cc | 56 ++++----- src/ray/raylet/worker.cc | 4 +- src/ray/raylet/worker_pool.cc | 8 +- src/ray/raylet/worker_pool_test.cc | 12 +- 41 files changed, 506 insertions(+), 509 deletions(-) diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 62c56d17ceed..04f75500b29d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -161,7 +161,7 @@ public static long murmurHashCode(BaseId id) { } /** - * This method is the same as `hash()` method of `ID` class in ray/src/ray/id.h + * This method is the same as `Hash()` method of `ID` class in ray/src/ray/id.h */ private static long murmurHash64A(byte[] data, int length, int seed) { final long m = 0xc6a4a7935bd1e995L; diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index a5f106f1e911..1cea8354dada 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -80,7 +80,7 @@ cdef c_vector[CObjectID] ObjectIDsToVector(object_ids): cdef VectorToObjectIDs(c_vector[CObjectID] object_ids): result = [] for i in range(object_ids.size()): - 
result.append(ObjectID(object_ids[i].binary())) + result.append(ObjectID(object_ids[i].Binary())) return result @@ -88,11 +88,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(CObjectID.for_put(task_id.native(), put_index).binary()) + return ObjectID(CObjectID.ForPut(task_id.native(), put_index).Binary()) def compute_task_id(ObjectID object_id): - return TaskID(object_id.native().task_id().binary()) + return TaskID(object_id.native().TaskId().Binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): @@ -362,7 +362,7 @@ cdef class RayletClient: with nogil: check_status(self.client.get().PrepareActorCheckpoint( c_actor_id, checkpoint_id)) - return ActorCheckpointID(checkpoint_id.binary()) + return ActorCheckpointID(checkpoint_id.Binary()) def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id): @@ -370,7 +370,7 @@ cdef class RayletClient: actor_id.native(), checkpoint_id.native())) def set_resource(self, basestring resource_name, double capacity, ClientID client_id): - self.client.get().SetResource(resource_name.encode("ascii"), capacity, CClientID.from_binary(client_id.binary())) + self.client.get().SetResource(resource_name.encode("ascii"), capacity, CClientID.FromBinary(client_id.binary())) @property def language(self): @@ -378,11 +378,11 @@ cdef class RayletClient: @property def client_id(self): - return ClientID(self.client.get().GetClientID().binary()) + return ClientID(self.client.get().GetClientID().Binary()) @property def driver_id(self): - return DriverID(self.client.get().GetDriverID().binary()) + return DriverID(self.client.get().GetDriverID().Binary()) @property def is_worker(self): diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index 70e7e584d457..f8258e0a6e65 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -124,15 +124,15 @@ cdef class Task: def driver_id(self): """Return the driver ID for this task.""" - return DriverID(self.task_spec.get().DriverId().binary()) + return DriverID(self.task_spec.get().DriverId().Binary()) def task_id(self): """Return the task ID for this task.""" - return TaskID(self.task_spec.get().TaskId().binary()) + return TaskID(self.task_spec.get().TaskId().Binary()) def parent_task_id(self): """Return the task ID of the parent task.""" - return TaskID(self.task_spec.get().ParentTaskId().binary()) + return TaskID(self.task_spec.get().ParentTaskId().Binary()) def parent_counter(self): """Return the parent counter of this task.""" @@ -162,7 +162,7 @@ cdef class Task: if count > 0: assert count == 1 arg_list.append( - ObjectID(task_spec.ArgId(i, 0).binary())) + ObjectID(task_spec.ArgId(i, 0).Binary())) else: serialized_str = ( task_spec.ArgVal(i)[:task_spec.ArgValLength(i)]) @@ -178,7 +178,7 @@ cdef class Task: cdef CTaskSpecification *task_spec = self.task_spec.get() return_id_list = [] for i in range(task_spec.NumReturns()): - return_id_list.append(ObjectID(task_spec.ReturnId(i).binary())) + return_id_list.append(ObjectID(task_spec.ReturnId(i).Binary())) return return_id_list def required_resources(self): @@ -207,16 +207,16 @@ cdef class Task: def actor_creation_id(self): """Return the actor creation ID for the task.""" - return ActorID(self.task_spec.get().ActorCreationId().binary()) + return ActorID(self.task_spec.get().ActorCreationId().Binary()) def actor_creation_dummy_object_id(self): 
"""Return the actor creation dummy object ID for the task.""" return ObjectID( - self.task_spec.get().ActorCreationDummyObjectId().binary()) + self.task_spec.get().ActorCreationDummyObjectId().Binary()) def actor_id(self): """Return the actor ID for this task.""" - return ActorID(self.task_spec.get().ActorId().binary()) + return ActorID(self.task_spec.get().ActorId().Binary()) def actor_counter(self): """Return the actor counter for this task.""" diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index fbe793cc023b..8bf369c649b7 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -8,116 +8,116 @@ cdef extern from "ray/id.h" namespace "ray" nogil: T from_random() @staticmethod - T from_binary(const c_string &binary) + T FromBinary(const c_string &binary) @staticmethod - const T nil() + const T Nil() @staticmethod - size_t size() + size_t Size() - size_t hash() const - c_bool is_nil() const + size_t Hash() const + c_bool IsNil() const c_bool operator==(const CBaseID &rhs) const c_bool operator!=(const CBaseID &rhs) const const uint8_t *data() const; - c_string binary() const; - c_string hex() const; + c_string Binary() const; + c_string Hex() const; cdef cppclass CUniqueID "ray::UniqueID"(CBaseID): CUniqueID() @staticmethod - size_t size() + size_t Size() @staticmethod CUniqueID from_random() @staticmethod - CUniqueID from_binary(const c_string &binary) + CUniqueID FromBinary(const c_string &binary) @staticmethod - const CUniqueID nil() + const CUniqueID Nil() @staticmethod - size_t size() + size_t Size() cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): @staticmethod - CActorCheckpointID from_binary(const c_string &binary) + CActorCheckpointID FromBinary(const c_string &binary) cdef cppclass CActorClassID "ray::ActorClassID"(CUniqueID): @staticmethod - CActorClassID from_binary(const c_string &binary) + CActorClassID FromBinary(const c_string &binary) cdef cppclass CActorID "ray::ActorID"(CUniqueID): @staticmethod - CActorID from_binary(const c_string &binary) + CActorID FromBinary(const c_string &binary) cdef cppclass CActorHandleID "ray::ActorHandleID"(CUniqueID): @staticmethod - CActorHandleID from_binary(const c_string &binary) + CActorHandleID FromBinary(const c_string &binary) cdef cppclass CClientID "ray::ClientID"(CUniqueID): @staticmethod - CClientID from_binary(const c_string &binary) + CClientID FromBinary(const c_string &binary) cdef cppclass CConfigID "ray::ConfigID"(CUniqueID): @staticmethod - CConfigID from_binary(const c_string &binary) + CConfigID FromBinary(const c_string &binary) cdef cppclass CFunctionID "ray::FunctionID"(CUniqueID): @staticmethod - CFunctionID from_binary(const c_string &binary) + CFunctionID FromBinary(const c_string &binary) cdef cppclass CDriverID "ray::DriverID"(CUniqueID): @staticmethod - CDriverID from_binary(const c_string &binary) + CDriverID FromBinary(const c_string &binary) cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): @staticmethod - CTaskID from_binary(const c_string &binary) + CTaskID FromBinary(const c_string &binary) @staticmethod - const CTaskID nil() + const CTaskID Nil() @staticmethod - size_t size() + size_t Size() cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): @staticmethod - CObjectID from_binary(const c_string &binary) + CObjectID FromBinary(const c_string &binary) @staticmethod - const CObjectID nil() + const CObjectID Nil() @staticmethod - CObjectID for_put(const CTaskID &task_id, int64_t index); + CObjectID 
ForPut(const CTaskID &task_id, int64_t index); @staticmethod - CObjectID for_task_return(const CTaskID &task_id, int64_t index); + CObjectID ForTaskReturn(const CTaskID &task_id, int64_t index); @staticmethod - size_t size() + size_t Size() c_bool is_put() - int64_t object_index() const + int64_t ObjectIndex() const - CTaskID task_id() const + CTaskID TaskId() const cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): @staticmethod - CWorkerID from_binary(const c_string &binary) + CWorkerID FromBinary(const c_string &binary) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index b9773d56fb20..cd3c58003fed 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -97,7 +97,7 @@ cdef class UniqueID(BaseID): def __init__(self, id): check_id(id) - self.data = CUniqueID.from_binary(id) + self.data = CUniqueID.FromBinary(id) @classmethod def from_binary(cls, id_bytes): @@ -107,27 +107,27 @@ cdef class UniqueID(BaseID): @classmethod def nil(cls): - return cls(CUniqueID.nil().binary()) + return cls(CUniqueID.Nil().Binary()) @classmethod def from_random(cls): - return cls(os.urandom(CUniqueID.size())) + return cls(os.urandom(CUniqueID.Size())) def size(self): - return CUniqueID.size() + return CUniqueID.Size() def binary(self): - return self.data.binary() + return self.data.Binary() def hex(self): - return decode(self.data.hex()) + return decode(self.data.Hex()) def is_nil(self): - return self.data.is_nil() + return self.data.IsNil() cdef size_t hash(self): - return self.data.hash() + return self.data.Hash() cdef class ObjectID(BaseID): @@ -135,78 +135,78 @@ cdef class ObjectID(BaseID): def __init__(self, id): check_id(id) - self.data = CObjectID.from_binary(id) + self.data = CObjectID.FromBinary(id) cdef CObjectID native(self): return self.data def size(self): - return CObjectID.size() + return CObjectID.Size() def binary(self): - return self.data.binary() + return self.data.Binary() def hex(self): - return decode(self.data.hex()) + return decode(self.data.Hex()) def is_nil(self): - return self.data.is_nil() + return self.data.IsNil() cdef size_t hash(self): - return self.data.hash() + return self.data.Hash() @classmethod def nil(cls): - return cls(CObjectID.nil().binary()) + return cls(CObjectID.Nil().Binary()) @classmethod def from_random(cls): - return cls(os.urandom(CObjectID.size())) + return cls(os.urandom(CObjectID.Size())) cdef class TaskID(BaseID): cdef CTaskID data def __init__(self, id): - check_id(id, CTaskID.size()) - self.data = CTaskID.from_binary(id) + check_id(id, CTaskID.Size()) + self.data = CTaskID.FromBinary(id) cdef CTaskID native(self): return self.data def size(self): - return CTaskID.size() + return CTaskID.Size() def binary(self): - return self.data.binary() + return self.data.Binary() def hex(self): - return decode(self.data.hex()) + return decode(self.data.Hex()) def is_nil(self): - return self.data.is_nil() + return self.data.IsNil() cdef size_t hash(self): - return self.data.hash() + return self.data.Hash() @classmethod def nil(cls): - return cls(CTaskID.nil().binary()) + return cls(CTaskID.Nil().Binary()) @classmethod def size(cla): - return CTaskID.size() + return CTaskID.Size() @classmethod def from_random(cls): - return cls(os.urandom(CTaskID.size())) + return cls(os.urandom(CTaskID.Size())) cdef class ClientID(UniqueID): def __init__(self, id): check_id(id) - self.data = CClientID.from_binary(id) + self.data = CClientID.FromBinary(id) cdef CClientID native(self): return self.data @@ -216,7 
+216,7 @@ cdef class DriverID(UniqueID): def __init__(self, id): check_id(id) - self.data = CDriverID.from_binary(id) + self.data = CDriverID.FromBinary(id) cdef CDriverID native(self): return self.data @@ -226,7 +226,7 @@ cdef class ActorID(UniqueID): def __init__(self, id): check_id(id) - self.data = CActorID.from_binary(id) + self.data = CActorID.FromBinary(id) cdef CActorID native(self): return self.data @@ -236,7 +236,7 @@ cdef class ActorHandleID(UniqueID): def __init__(self, id): check_id(id) - self.data = CActorHandleID.from_binary(id) + self.data = CActorHandleID.FromBinary(id) cdef CActorHandleID native(self): return self.data @@ -246,7 +246,7 @@ cdef class ActorCheckpointID(UniqueID): def __init__(self, id): check_id(id) - self.data = CActorCheckpointID.from_binary(id) + self.data = CActorCheckpointID.FromBinary(id) cdef CActorCheckpointID native(self): return self.data @@ -256,7 +256,7 @@ cdef class FunctionID(UniqueID): def __init__(self, id): check_id(id) - self.data = CFunctionID.from_binary(id) + self.data = CFunctionID.FromBinary(id) cdef CFunctionID native(self): return self.data @@ -266,7 +266,7 @@ cdef class ActorClassID(UniqueID): def __init__(self, id): check_id(id) - self.data = CActorClassID.from_binary(id) + self.data = CActorClassID.FromBinary(id) cdef CActorClassID native(self): return self.data diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index 4ad961008e91..b36382bbb99a 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -231,7 +231,7 @@ ClientConnection::ClientConnection( const std::string &debug_label, const std::vector &message_type_enum_names, int64_t error_message_type) : ServerConnection(std::move(socket)), - client_id_(ClientID::nil()), + client_id_(ClientID::Nil()), message_handler_(message_handler), debug_label_(debug_label), message_type_enum_names_(message_type_enum_names), @@ -307,7 +307,7 @@ bool ClientConnection::CheckRayCookie() { ss << ", remote endpoint info: " << remote_endpoint_info; } - if (!client_id_.is_nil()) { + if (!client_id_.IsNil()) { // This is from a known client, which indicates a bug. 
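
(Illustrative aside, not part of the patch: the unique_ids.pxi wrappers above keep their snake_case Python-facing methods and only delegate to the renamed CamelCase C++ methods, so user-level code should be unchanged by this rename. A minimal sketch, assuming a Ray build that includes these bindings and that ObjectID is importable from the top-level ray package:)

# Hypothetical usage sketch; import path and availability are assumptions.
from ray import ObjectID

oid = ObjectID.from_random()
assert not oid.is_nil()
print(oid.hex())                    # Python-side hex() keeps its lowercase name
print(oid.size(), len(oid.binary()))  # size() and binary() are likewise unchanged
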
RAY_LOG(FATAL) << ss.str(); } else { diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 792bb173b105..63a0bf8c259c 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -104,13 +104,13 @@ string_vec_to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, template flatbuffers::Offset to_flatbuf(flatbuffers::FlatBufferBuilder &fbb, ID id) { - return fbb.CreateString(reinterpret_cast(id.data()), id.size()); + return fbb.CreateString(reinterpret_cast(id.Data()), id.Size()); } template ID from_flatbuf(const flatbuffers::String &string) { - RAY_CHECK(string.size() == ID::size()); - return ID::from_binary(string.str()); + RAY_CHECK(string.size() == ID::Size()); + return ID::FromBinary(string.str()); } template @@ -127,14 +127,14 @@ template const std::vector ids_from_flatbuf(const flatbuffers::String &string) { const auto &ids = string_from_flatbuf(string); std::vector ret; - size_t id_size = ID::size(); + size_t id_size = ID::Size(); RAY_CHECK(ids.size() % id_size == 0); auto count = ids.size() / id_size; for (size_t i = 0; i < count; ++i) { auto pos = static_cast(id_size * i); const auto &id = ids.substr(pos, id_size); - ret.push_back(ID::from_binary(id)); + ret.push_back(ID::FromBinary(id)); } return ret; @@ -145,7 +145,7 @@ flatbuffers::Offset ids_to_flatbuf( flatbuffers::FlatBufferBuilder &fbb, const std::vector &ids) { std::string result; for (const auto &id : ids) { - result += id.binary(); + result += id.Binary(); } return fbb.CreateString(result); diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index b51421e10a14..642f5f2cf156 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -146,19 +146,19 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, CommandType command_type) - : AsyncGcsClient(address, port, ClientID::from_random(), command_type) {} + : AsyncGcsClient(address, port, ClientID::FromRandom(), command_type) {} AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, CommandType command_type, bool is_test_client) - : AsyncGcsClient(address, port, ClientID::from_random(), command_type, + : AsyncGcsClient(address, port, ClientID::FromRandom(), command_type, is_test_client) {} AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, const std::string &password = "") - : AsyncGcsClient(address, port, ClientID::from_random(), false, password) {} + : AsyncGcsClient(address, port, ClientID::FromRandom(), false, password) {} AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client) - : AsyncGcsClient(address, port, ClientID::from_random(), is_test_client) {} + : AsyncGcsClient(address, port, ClientID::FromRandom(), is_test_client) {} Status AsyncGcsClient::Attach(boost::asio::io_service &io_service) { // Take care of sharding contexts. 
diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 7f69c482e5eb..c203a4a9482a 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -29,7 +29,7 @@ class TestGcs : public ::testing::Test { TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { client_ = std::make_shared("127.0.0.1", 6379, command_type_, /*is_test_client=*/true); - driver_id_ = DriverID::from_random(); + driver_id_ = DriverID::FromRandom(); } virtual ~TestGcs() { @@ -84,7 +84,7 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); auto data = std::make_shared(); data->task_specification = "123"; @@ -133,7 +133,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookup); void TestLogLookup(const DriverID &driver_id, std::shared_ptr client) { // Append some entries to the log at an object ID. - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { auto data = std::make_shared(); @@ -178,7 +178,7 @@ TEST_F(TestGcsWithAsio, TestLogLookup) { void TestTableLookupFailure(const DriverID &driver_id, std::shared_ptr client) { - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, @@ -205,7 +205,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { @@ -265,7 +265,7 @@ TEST_F(TestGcsWithAsio, TestLogAppendAt) { void TestSet(const DriverID &driver_id, std::shared_ptr client) { // Add some entries to the set at an object ID. - ObjectID object_id = ObjectID::from_random(); + ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { auto data = std::make_shared(); @@ -335,7 +335,7 @@ void TestDeleteKeysFromLog( std::vector ids; TaskID task_id; for (auto &data : data_vector) { - task_id = TaskID::from_random(); + task_id = TaskID::FromRandom(); ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, @@ -383,7 +383,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, std::vector ids; TaskID task_id; for (auto &data : data_vector) { - task_id = TaskID::from_random(); + task_id = TaskID::FromRandom(); ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, @@ -431,7 +431,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, std::vector ids; ObjectID object_id; for (auto &data : data_vector) { - object_id = ObjectID::from_random(); + object_id = ObjectID::FromRandom(); ids.push_back(object_id); // Check that we added the correct object entries. 
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, @@ -477,7 +477,7 @@ void TestDeleteKeys(const DriverID &driver_id, auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { auto data = std::make_shared(); - data->node_manager_id = ObjectID::from_random().hex(); + data->node_manager_id = ObjectID::FromRandom().Hex(); task_reconstruction_vector.push_back(data); } }; @@ -506,7 +506,7 @@ void TestDeleteKeys(const DriverID &driver_id, auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::from_random().hex(); + task_data->task_specification = ObjectID::FromRandom().Hex(); task_vector.push_back(task_data); } }; @@ -532,7 +532,7 @@ void TestDeleteKeys(const DriverID &driver_id, auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { auto data = std::make_shared(); - data->manager = ObjectID::from_random().hex(); + data->manager = ObjectID::FromRandom().Hex(); object_vector.push_back(data); } }; @@ -603,7 +603,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; for (int i = 0; i < 3; i++) { - driver_ids.emplace_back(DriverID::from_random()); + driver_ids.emplace_back(DriverID::FromRandom()); } // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, @@ -612,7 +612,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].binary()); + ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -633,7 +633,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, // subscribed, we will append to the key several times and check that we get // notified for each. RAY_CHECK_OK(client->driver_table().Subscribe( - driver_id, ClientID::nil(), notification_callback, subscribe_callback)); + driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called (or an assertion failure). @@ -651,7 +651,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector object_ids; for (int i = 0; i < 3; i++) { - object_ids.emplace_back(ObjectID::from_random()); + object_ids.emplace_back(ObjectID::FromRandom()); } std::vector managers = {"abc", "def", "ghi"}; @@ -711,7 +711,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // subscribed, we will append to the key several times and check that we get // notified for each. RAY_CHECK_OK(client->object_table().Subscribe( - driver_id, ClientID::nil(), notification_callback, subscribe_callback)); + driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called (or an assertion failure). @@ -728,11 +728,11 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { void TestTableSubscribeId(const DriverID &driver_id, std::shared_ptr client) { // Add a table entry. 
- TaskID task_id1 = TaskID::from_random(); + TaskID task_id1 = TaskID::FromRandom(); std::vector task_specs1 = {"abc", "def", "ghi"}; // Add a table entry at a second key. - TaskID task_id2 = TaskID::from_random(); + TaskID task_id2 = TaskID::FromRandom(); std::vector task_specs2 = {"jkl", "mno", "pqr"}; // The callback for a notification from the table. This should only be @@ -804,14 +804,14 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); void TestLogSubscribeId(const DriverID &driver_id, std::shared_ptr client) { // Add a log entry. - DriverID driver_id1 = DriverID::from_random(); + DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; auto data1 = std::make_shared(); data1->driver_id = driver_ids1[0]; RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. - DriverID driver_id2 = DriverID::from_random(); + DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; auto data2 = std::make_shared(); data2->driver_id = driver_ids2[0]; @@ -878,14 +878,14 @@ TEST_F(TestGcsWithAsio, TestLogSubscribeId) { void TestSetSubscribeId(const DriverID &driver_id, std::shared_ptr client) { // Add a set entry. - ObjectID object_id1 = ObjectID::from_random(); + ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; auto data1 = std::make_shared(); data1->manager = managers1[0]; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. - ObjectID object_id2 = ObjectID::from_random(); + ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; auto data2 = std::make_shared(); data2->manager = managers2[0]; @@ -954,7 +954,7 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeId) { void TestTableSubscribeCancel(const DriverID &driver_id, std::shared_ptr client) { // Add a table entry. - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->task_specification = task_specs[0]; @@ -1029,7 +1029,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); void TestLogSubscribeCancel(const DriverID &driver_id, std::shared_ptr client) { // Add a log entry. - DriverID random_driver_id = DriverID::from_random(); + DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->driver_id = driver_ids[0]; @@ -1102,7 +1102,7 @@ TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { void TestSetSubscribeCancel(const DriverID &driver_id, std::shared_ptr client) { // Add a set entry. 
- ObjectID object_id = ObjectID::from_random(); + ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->manager = managers[0]; @@ -1186,13 +1186,13 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client const ClientTableDataT &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); - ASSERT_EQ(ClientID::from_binary(data.client_id), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); ClientTableDataT cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::from_binary(cached_client.client_id), added_id); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); } @@ -1290,13 +1290,13 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. - ClientID dead_client_id = ClientID::from_random(); + ClientID dead_client_id = ClientID::FromRandom(); RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id)); // Make sure we only get a notification for the removal of the client we // marked as dead. client->client_table().RegisterClientRemovedCallback([dead_client_id]( gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ASSERT_EQ(ClientID::from_binary(data.client_id), dead_client_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); test->Stop(); }); test->Start(); diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index e0c5a6565412..42d921d932d7 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -273,7 +273,7 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, RAY_CHECK(out_callback_index != nullptr); *out_callback_index = callback_index; int status = 0; - if (client_id.is_nil()) { + if (client_id.IsNil()) { // Subscribe to all messages. 
std::string redis_command = "SUBSCRIBE %d"; status = redisAsyncCommand( @@ -285,7 +285,7 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, status = redisAsyncCommand( subscribe_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel, - client_id.data(), client_id.size()); + client_id.Data(), client_id.Size()); } if (status == REDIS_ERR) { diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index b82915374b0a..264f61b1ceaa 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -168,7 +168,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, int status = redisAsyncCommand( async_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length, log_length); + pubsub_channel, id.Data(), id.Size(), data, length, log_length); if (status == REDIS_ERR) { return Status::RedisError(std::string(async_context_->errstr)); } @@ -177,7 +177,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, int status = redisAsyncCommand( async_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length); + pubsub_channel, id.Data(), id.Size(), data, length); if (status == REDIS_ERR) { return Status::RedisError(std::string(async_context_->errstr)); } @@ -188,7 +188,7 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, int status = redisAsyncCommand( async_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size()); + pubsub_channel, id.Data(), id.Size()); if (status == REDIS_ERR) { return Status::RedisError(std::string(async_context_->errstr)); } diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 0014778896cd..6a7742c6b5a4 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -648,7 +648,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str const char *redis_string_str = RedisModule_StringPtrLen(id_data, &redis_string_size); auto id_binary = std::string(redis_string_str, redis_string_size); ostream << "Undesired type for RAY.TableDelete: " << key_type - << " id:" << ray::UniqueID::from_binary(id_binary); + << " id:" << ray::UniqueID::FromBinary(id_binary); RAY_LOG(ERROR) << ostream.str(); return Status::RedisError(ostream.str()); } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ccf05f2b5151..3a381313fd21 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -172,7 +172,7 @@ Status Log::RequestNotifications(const DriverID &driver_id, const ID & RAY_CHECK(subscribe_callback_index_ >= 0) << "Client requested notifications on a key before Subscribe completed"; return GetRedisContext(id)->RunAsync("RAY.TABLE_REQUEST_NOTIFICATIONS", id, - client_id.data(), client_id.size(), prefix_, + client_id.Data(), client_id.Size(), prefix_, pubsub_channel_, nullptr); } @@ -182,7 +182,7 @@ Status Log::CancelNotifications(const DriverID &driver_id, const ID &i RAY_CHECK(subscribe_callback_index_ >= 0) << "Client canceled notifications on a key before Subscribe completed"; return GetRedisContext(id)->RunAsync("RAY.TABLE_CANCEL_NOTIFICATIONS", id, - client_id.data(), 
client_id.size(), prefix_, + client_id.Data(), client_id.Size(), prefix_, pubsub_channel_, nullptr); } @@ -193,16 +193,16 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids } std::unordered_map sharded_data; for (const auto &id : ids) { - sharded_data[GetRedisContext(id).get()] << id.binary(); + sharded_data[GetRedisContext(id).get()] << id.Binary(); } // Breaking really large deletion commands into batches of smaller size. const size_t batch_size = - RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::size(); + RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::Size(); for (const auto &pair : sharded_data) { std::string current_data = pair.second.str(); for (size_t cur = 0; cur < pair.second.str().size(); cur += batch_size) { size_t data_field_size = std::min(batch_size, current_data.size() - cur); - uint16_t id_count = data_field_size / ID::size(); + uint16_t id_count = data_field_size / ID::Size(); // Send data contains id count and all the id data. std::string send_data(data_field_size + sizeof(id_count), 0); uint8_t *buffer = reinterpret_cast(&send_data[0]); @@ -212,7 +212,7 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids data_field_size, buffer + sizeof(uint16_t))); RAY_IGNORE_EXPR( - pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::nil(), + pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::Nil(), reinterpret_cast(send_data.c_str()), send_data.size(), prefix_, pubsub_channel_, /*redisCallback=*/nullptr)); @@ -342,7 +342,7 @@ std::string Set::DebugString() const { Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); - data->driver_id = driver_id.binary(); + data->driver_id = driver_id.Binary(); data->type = type; data->error_message = error_message; data->timestamp = timestamp; @@ -359,7 +359,7 @@ Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events // call "Pack" and undo the "UnPack". profile_events.UnPackTo(data.get()); - return Append(DriverID::nil(), UniqueID::from_random(), data, + return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -369,7 +369,7 @@ std::string ProfileTable::DebugString() const { Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { auto data = std::make_shared(); - data->driver_id = driver_id.binary(); + data->driver_id = driver_id.Binary(); data->is_dead = is_dead; return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -378,7 +378,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -388,7 +388,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. 
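
(Illustrative aside, not part of the patch: a minimal Python sketch of the batching that Log::Delete above performs before issuing RAY.TABLE_DELETE. The maximum batch size, the 20-byte ID width, and the little-endian count prefix are assumptions for the sketch only.)

import struct

def delete_batches(ids, id_size=20, max_ids_per_batch=500):
    # Concatenate the fixed-width binary IDs destined for one Redis shard.
    data = b"".join(ids)
    batch_bytes = max_ids_per_batch * id_size
    for cur in range(0, len(data), batch_bytes):
        chunk = data[cur:cur + batch_bytes]
        id_count = len(chunk) // id_size
        # Each command payload carries a uint16 ID count followed by the ID bytes.
        yield struct.pack("<H", id_count) + chunk
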
for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -399,7 +399,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( resource_createupdated_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && + if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } @@ -410,7 +410,7 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.is_nil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } @@ -418,7 +418,7 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal void ClientTable::HandleNotification(AsyncGcsClient *client, const ClientTableDataT &data) { - ClientID client_id = ClientID::from_binary(data.client_id); + ClientID client_id = ClientID::FromBinary(data.client_id); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -524,7 +524,7 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::from_binary(data.client_id); + auto connected_client_id = ClientID::FromBinary(data.client_id); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } @@ -583,13 +583,13 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Callback to request notifications from the client table once we've // successfully subscribed. auto subscription_callback = [this](AsyncGcsClient *c) { - RAY_CHECK_OK(RequestNotifications(DriverID::nil(), client_log_key_, client_id_)); + RAY_CHECK_OK(RequestNotifications(DriverID::Nil(), client_log_key_, client_id_)); }; // Subscribe to the client table. - RAY_CHECK_OK(Subscribe(DriverID::nil(), client_id_, notification_callback, + RAY_CHECK_OK(Subscribe(DriverID::Nil(), client_id_, notification_callback, subscription_callback)); }; - return Append(DriverID::nil(), client_log_key_, data, add_callback); + return Append(DriverID::Nil(), client_log_key_, data, add_callback); } Status ClientTable::Disconnect(const DisconnectCallback &callback) { @@ -598,12 +598,12 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { HandleConnected(client, data); - RAY_CHECK_OK(CancelNotifications(DriverID::nil(), client_log_key_, id)); + RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { callback(); } }; - RAY_RETURN_NOT_OK(Append(DriverID::nil(), client_log_key_, data, add_callback)); + RAY_RETURN_NOT_OK(Append(DriverID::Nil(), client_log_key_, data, add_callback)); // We successfully added the deletion entry. Mark ourselves as disconnected. 
disconnected_ = true; return Status::OK(); @@ -611,19 +611,19 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); - data->client_id = dead_client_id.binary(); + data->client_id = dead_client_id.Binary(); data->entry_type = EntryType::DELETION; - return Append(DriverID::nil(), client_log_key_, data, nullptr); + return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, ClientTableDataT &client_info) const { - RAY_CHECK(!client_id.is_nil()); + RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::nil().binary(); + client_info.client_id = ClientID::Nil().Binary(); } } @@ -633,7 +633,7 @@ const std::unordered_map &ClientTable::GetAllClients Status ClientTable::Lookup(const Callback &lookup) { RAY_CHECK(lookup != nullptr); - return Log::Lookup(DriverID::nil(), client_log_key_, lookup); + return Log::Lookup(DriverID::Nil(), client_log_key_, lookup); } std::string ClientTable::DebugString() const { @@ -653,12 +653,12 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, std::shared_ptr copy = std::make_shared(data); copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.binary(); + copy->checkpoint_ids += checkpoint_id.Binary(); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); while (copy->timestamps.size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. const auto &checkpoint_id = - ActorCheckpointID::from_binary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); + ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " << actor_id; copy->timestamps.erase(copy->timestamps.begin()); @@ -671,9 +671,9 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, ray::gcs::AsyncGcsClient *client, const UniqueID &id) { std::shared_ptr data = std::make_shared(); - data->actor_id = id.binary(); + data->actor_id = id.Binary(); data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.binary(); + data->checkpoint_ids = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 58a087d8c666..af739dc2ed32 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -559,7 +559,7 @@ class TaskLeaseTable : public Table { // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.binary(), + EnumNameTablePrefix(prefix_) + id.Binary(), std::to_string(data->timeout)}; return GetRedisContext(id)->RunArgvAsync(args); @@ -695,7 +695,7 @@ class ClientTable : public Log { prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.binary(); + local_client_.client_id = client_id.Binary(); }; /// Connect as a client to the GCS. 
This registers us in the client table diff --git a/src/ray/id.cc b/src/ray/id.cc index a011430ad1cf..4c9ce4dc9244 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -26,14 +26,14 @@ std::mt19937 RandomlySeededMersenneTwister() { uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); -plasma::UniqueID ObjectID::to_plasma_id() const { +plasma::UniqueID ObjectID::ToPlasmaId() const { plasma::UniqueID result; - std::memcpy(result.mutable_data(), data(), kUniqueIDSize); + std::memcpy(result.mutable_data(), Data(), kUniqueIDSize); return result; } ObjectID::ObjectID(const plasma::UniqueID &from) { - std::memcpy(this->mutable_data(), from.data(), kUniqueIDSize); + std::memcpy(this->MutableData(), from.data(), kUniqueIDSize); } // This code is from https://sites.google.com/site/murmurhash/ @@ -86,29 +86,29 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { } TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) { - std::string driver_id_str = driver_id.binary(); - driver_id_str.resize(size()); - return TaskID::from_binary(driver_id_str); + std::string driver_id_str = driver_id.Binary(); + driver_id_str.resize(Size()); + return TaskID::FromBinary(driver_id_str); } -TaskID ObjectID::task_id() const { - return TaskID::from_binary( - std::string(reinterpret_cast(id_), TaskID::size())); +TaskID ObjectID::TaskId() const { + return TaskID::FromBinary( + std::string(reinterpret_cast(id_), TaskID::Size())); } -ObjectID ObjectID::for_put(const TaskID &task_id, int64_t put_index) { +ObjectID ObjectID::ForPut(const TaskID &task_id, int64_t put_index) { RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts) << "index=" << put_index; ObjectID object_id; - std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); object_id.index_ = -put_index; return object_id; } -ObjectID ObjectID::for_task_return(const TaskID &task_id, int64_t return_index) { +ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) { RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" << return_index; ObjectID object_id; - std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); object_id.index_ = return_index; return object_id; } @@ -118,23 +118,23 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task // Compute hashes. SHA256_CTX ctx; sha256_init(&ctx); - sha256_update(&ctx, reinterpret_cast(driver_id.data()), driver_id.size()); - sha256_update(&ctx, reinterpret_cast(parent_task_id.data()), - parent_task_id.size()); + sha256_update(&ctx, reinterpret_cast(driver_id.Data()), driver_id.Size()); + sha256_update(&ctx, reinterpret_cast(parent_task_id.Data()), + parent_task_id.Size()); sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); // Compute the final task ID from the hash. 
BYTE buff[DIGEST_SIZE]; sha256_final(&ctx, buff); - return TaskID::from_binary(std::string(buff, buff + TaskID::size())); + return TaskID::FromBinary(std::string(buff, buff + TaskID::Size())); } #define ID_OSTREAM_OPERATOR(id_type) \ std::ostream &operator<<(std::ostream &os, const id_type &id) { \ - if (id.is_nil()) { \ + if (id.IsNil()) { \ os << "NIL_ID"; \ } else { \ - os << id.hex(); \ + os << id.Hex(); \ } \ return os; \ } diff --git a/src/ray/id.h b/src/ray/id.h index f90f66549358..7153a95f7750 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -31,26 +31,26 @@ template class BaseID { public: BaseID(); - static T from_random(); - static T from_binary(const std::string &binary); - static const T &nil(); - static size_t size() { return T::size(); } + static T FromRandom(); + static T FromBinary(const std::string &binary); + static const T &Nil(); + static size_t Size() { return T::Size(); } - size_t hash() const; - bool is_nil() const; + size_t Hash() const; + bool IsNil() const; bool operator==(const BaseID &rhs) const; bool operator!=(const BaseID &rhs) const; - const uint8_t *data() const; - std::string binary() const; - std::string hex() const; + const uint8_t *Data() const; + std::string Binary() const; + std::string Hex() const; protected: BaseID(const std::string &binary) { - std::memcpy(const_cast(this->data()), binary.data(), T::size()); + std::memcpy(const_cast(this->Data()), binary.data(), T::Size()); } - // All IDs are immutable for hash evaluations. mutable_data is only allow to use + // All IDs are immutable for hash evaluations. MutableData is only allow to use // in construction time, so this function is protected. - uint8_t *mutable_data(); + uint8_t *MutableData(); // For lazy evaluation, be careful to have one Id contained in another. // This hash code will be duplicated. mutable size_t hash_ = 0; @@ -59,7 +59,7 @@ class BaseID { class UniqueID : public BaseID { public: UniqueID() : BaseID(){}; - static size_t size() { return kUniqueIDSize; } + static size_t Size() { return kUniqueIDSize; } protected: UniqueID(const std::string &binary); @@ -71,7 +71,7 @@ class UniqueID : public BaseID { class TaskID : public BaseID { public: TaskID() : BaseID() {} - static size_t size() { return kTaskIDSize; } + static size_t Size() { return kTaskIDSize; } static TaskID GetDriverTaskID(const DriverID &driver_id); private: @@ -81,8 +81,8 @@ class TaskID : public BaseID { class ObjectID : public BaseID { public: ObjectID() : BaseID() {} - static size_t size() { return kUniqueIDSize; } - plasma::ObjectID to_plasma_id() const; + static size_t Size() { return kUniqueIDSize; } + plasma::ObjectID ToPlasmaId() const; ObjectID(const plasma::UniqueID &from); /// Get the index of this object in the task that created it. @@ -90,26 +90,26 @@ class ObjectID : public BaseID { /// \return The index of object creation according to the task that created /// this object. This is positive if the task returned the object and negative /// if created by a put. - int32_t object_index() const { return index_; } + int32_t ObjectIndex() const { return index_; } /// Compute the task ID of the task that created the object. /// /// \return The task ID of the task that created this object. - TaskID task_id() const; + TaskID TaskId() const; /// Compute the object ID of an object put by the task. /// /// \param task_id The task ID of the task that created the object. /// \param index What index of the object put in the task. /// \return The computed object ID. 
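
(Illustrative aside, not part of the patch: a minimal Python sketch of the object-ID layout that ObjectID::ForPut / ObjectID::ForTaskReturn in src/ray/id.cc above rely on: the creating task's ID bytes followed by an index whose sign distinguishes puts from returns. The 16-byte task-ID width, 4-byte index field, and little-endian packing are assumptions for the sketch only.)

import struct

TASK_ID_SIZE = 16  # assumed kTaskIDSize

def for_task_return(task_id: bytes, return_index: int) -> bytes:
    # Returns use a positive index appended after the creating task's ID.
    assert return_index >= 1
    return task_id[:TASK_ID_SIZE] + struct.pack("<i", return_index)

def for_put(task_id: bytes, put_index: int) -> bytes:
    # Puts negate the index, so the sign distinguishes puts from returns.
    assert put_index >= 1
    return task_id[:TASK_ID_SIZE] + struct.pack("<i", -put_index)

def task_id_of(object_id: bytes) -> bytes:
    # TaskId() recovers the creating task's ID from the leading bytes.
    return object_id[:TASK_ID_SIZE]
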
- static ObjectID for_put(const TaskID &task_id, int64_t put_index); + static ObjectID ForPut(const TaskID &task_id, int64_t put_index); /// Compute the object ID of an object returned by the task. /// /// \param task_id The task ID of the task that created the object. /// \param return_index What index of the object returned by in the task. /// \return The computed object ID. - static ObjectID for_task_return(const TaskID &task_id, int64_t return_index); + static ObjectID ForTaskReturn(const TaskID &task_id, int64_t return_index); private: uint8_t id_[kTaskIDSize]; @@ -125,22 +125,22 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id); std::ostream &operator<<(std::ostream &os, const TaskID &id); std::ostream &operator<<(std::ostream &os, const ObjectID &id); -#define DEFINE_UNIQUE_ID(type) \ - class RAY_EXPORT type : public UniqueID { \ - public: \ - explicit type(const UniqueID &from) { \ - std::memcpy(&id_, from.data(), kUniqueIDSize); \ - } \ - type() : UniqueID() {} \ - static type from_random() { return type(UniqueID::from_random()); } \ - static type from_binary(const std::string &binary) { return type(binary); } \ - static type nil() { return type(UniqueID::nil()); } \ - static size_t size() { return kUniqueIDSize; } \ - \ - private: \ - explicit type(const std::string &binary) { \ - std::memcpy(&id_, binary.data(), kUniqueIDSize); \ - } \ +#define DEFINE_UNIQUE_ID(type) \ + class RAY_EXPORT type : public UniqueID { \ + public: \ + explicit type(const UniqueID &from) { \ + std::memcpy(&id_, from.Data(), kUniqueIDSize); \ + } \ + type() : UniqueID() {} \ + static type FromRandom() { return type(UniqueID::FromRandom()); } \ + static type FromBinary(const std::string &binary) { return type(binary); } \ + static type Nil() { return type(UniqueID::Nil()); } \ + static size_t Size() { return kUniqueIDSize; } \ + \ + private: \ + explicit type(const std::string &binary) { \ + std::memcpy(&id_, binary.data(), kUniqueIDSize); \ + } \ }; #include "id_def.h" @@ -163,12 +163,12 @@ template BaseID::BaseID() { // Using const_cast to directly change data is dangerous. The cached // hash may not be changed. This is used in construction time. 
- std::fill_n(this->mutable_data(), T::size(), 0xff); + std::fill_n(this->MutableData(), T::Size(), 0xff); } template -T BaseID::from_random() { - std::string data(T::size(), 0); +T BaseID::FromRandom() { + std::string data(T::Size(), 0); // NOTE(pcm): The right way to do this is to have one std::mt19937 per // thread (using the thread_local keyword), but that's not supported on // older versions of macOS (see https://stackoverflow.com/a/29929949) @@ -176,44 +176,44 @@ T BaseID::from_random() { std::lock_guard lock(random_engine_mutex); static std::mt19937 generator = RandomlySeededMersenneTwister(); std::uniform_int_distribution dist(0, std::numeric_limits::max()); - for (int i = 0; i < T::size(); i++) { + for (int i = 0; i < T::Size(); i++) { data[i] = static_cast(dist(generator)); } - return T::from_binary(data); + return T::FromBinary(data); } template -T BaseID::from_binary(const std::string &binary) { - T t = T::nil(); - std::memcpy(t.mutable_data(), binary.data(), T::size()); +T BaseID::FromBinary(const std::string &binary) { + T t = T::Nil(); + std::memcpy(t.MutableData(), binary.data(), T::Size()); return t; } template -const T &BaseID::nil() { +const T &BaseID::Nil() { static const T nil_id; return nil_id; } template -bool BaseID::is_nil() const { - static T nil_id = T::nil(); +bool BaseID::IsNil() const { + static T nil_id = T::Nil(); return *this == nil_id; } template -size_t BaseID::hash() const { +size_t BaseID::Hash() const { // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is // default value 0) if (!hash_) { - hash_ = MurmurHash64A(data(), T::size(), 0); + hash_ = MurmurHash64A(Data(), T::Size(), 0); } return hash_; } template bool BaseID::operator==(const BaseID &rhs) const { - return std::memcmp(data(), rhs.data(), T::size()) == 0; + return std::memcmp(Data(), rhs.Data(), T::Size()) == 0; } template @@ -222,26 +222,26 @@ bool BaseID::operator!=(const BaseID &rhs) const { } template -uint8_t *BaseID::mutable_data() { +uint8_t *BaseID::MutableData() { return reinterpret_cast(this) + sizeof(hash_); } template -const uint8_t *BaseID::data() const { +const uint8_t *BaseID::Data() const { return reinterpret_cast(this) + sizeof(hash_); } template -std::string BaseID::binary() const { - return std::string(reinterpret_cast(data()), T::size()); +std::string BaseID::Binary() const { + return std::string(reinterpret_cast(Data()), T::Size()); } template -std::string BaseID::hex() const { +std::string BaseID::Hex() const { constexpr char hex[] = "0123456789abcdef"; - const uint8_t *id = data(); + const uint8_t *id = Data(); std::string result; - for (int i = 0; i < T::size(); i++) { + for (int i = 0; i < T::Size(); i++) { unsigned int val = id[i]; result.push_back(hex[val >> 4]); result.push_back(hex[val & 0xf]); @@ -256,11 +256,11 @@ namespace std { #define DEFINE_UNIQUE_ID(type) \ template <> \ struct hash<::ray::type> { \ - size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + size_t operator()(const ::ray::type &id) const { return id.Hash(); } \ }; \ template <> \ struct hash { \ - size_t operator()(const ::ray::type &id) const { return id.hash(); } \ + size_t operator()(const ::ray::type &id) const { return id.Hash(); } \ }; DEFINE_UNIQUE_ID(UniqueID); diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index fe0471797c0d..fba426d732c3 100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -43,7 +43,7 @@ std::pair 
ObjectBufferPool::Ge std::lock_guard lock(pool_mutex_); if (get_buffer_state_.count(object_id) == 0) { plasma::ObjectBuffer object_buffer; - plasma::ObjectID plasma_id = object_id.to_plasma_id(); + plasma::ObjectID plasma_id = object_id.ToPlasmaId(); RAY_ARROW_CHECK_OK(store_client_.Get(&plasma_id, 1, 0, &object_buffer)); if (object_buffer.data == nullptr) { RAY_LOG(ERROR) << "Failed to get object"; @@ -72,14 +72,14 @@ void ObjectBufferPool::ReleaseGetChunk(const ObjectID &object_id, uint64_t chunk GetBufferState &buffer_state = get_buffer_state_[object_id]; buffer_state.references--; if (buffer_state.references == 0) { - RAY_ARROW_CHECK_OK(store_client_.Release(object_id.to_plasma_id())); + RAY_ARROW_CHECK_OK(store_client_.Release(object_id.ToPlasmaId())); get_buffer_state_.erase(object_id); } } void ObjectBufferPool::AbortGet(const ObjectID &object_id) { std::lock_guard lock(pool_mutex_); - RAY_ARROW_CHECK_OK(store_client_.Release(object_id.to_plasma_id())); + RAY_ARROW_CHECK_OK(store_client_.Release(object_id.ToPlasmaId())); get_buffer_state_.erase(object_id); } @@ -88,7 +88,7 @@ std::pair ObjectBufferPool::Cr uint64_t chunk_index) { std::lock_guard lock(pool_mutex_); if (create_buffer_state_.count(object_id) == 0) { - const plasma::ObjectID plasma_id = object_id.to_plasma_id(); + const plasma::ObjectID plasma_id = object_id.ToPlasmaId(); int64_t object_size = data_size - metadata_size; // Try to create shared buffer. std::shared_ptr data; @@ -150,7 +150,7 @@ void ObjectBufferPool::SealChunk(const ObjectID &object_id, const uint64_t chunk create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::SEALED; create_buffer_state_[object_id].num_seals_remaining--; if (create_buffer_state_[object_id].num_seals_remaining == 0) { - const plasma::ObjectID plasma_id = object_id.to_plasma_id(); + const plasma::ObjectID plasma_id = object_id.ToPlasmaId(); RAY_ARROW_CHECK_OK(store_client_.Seal(plasma_id)); RAY_ARROW_CHECK_OK(store_client_.Release(plasma_id)); create_buffer_state_.erase(object_id); @@ -158,7 +158,7 @@ void ObjectBufferPool::SealChunk(const ObjectID &object_id, const uint64_t chunk } void ObjectBufferPool::AbortCreate(const ObjectID &object_id) { - const plasma::ObjectID plasma_id = object_id.to_plasma_id(); + const plasma::ObjectID plasma_id = object_id.ToPlasmaId(); RAY_ARROW_CHECK_OK(store_client_.Release(plasma_id)); RAY_ARROW_CHECK_OK(store_client_.Abort(plasma_id)); create_buffer_state_.erase(object_id); @@ -186,7 +186,7 @@ void ObjectBufferPool::FreeObjects(const std::vector &object_ids) { std::vector plasma_ids; plasma_ids.reserve(object_ids.size()); for (const auto &id : object_ids) { - plasma_ids.push_back(id.to_plasma_id()); + plasma_ids.push_back(id.ToPlasmaId()); } std::lock_guard lock(pool_mutex_); RAY_ARROW_CHECK_OK(store_client_.Delete(plasma_ids)); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 85157abcdbe9..1f05559f4b87 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -19,7 +19,7 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, // with GcsTableNotificationMode, we can determine whether the update mode is // addition or deletion. 
for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::from_binary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager); if (notification_mode != GcsTableNotificationMode::REMOVE) { client_ids->insert(client_id); } else { @@ -71,7 +71,7 @@ void ObjectDirectory::RegisterBackend() { } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( - DriverID::nil(), gcs_client_->client_table().GetLocalClientId(), + DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); } @@ -81,10 +81,10 @@ ray::Status ObjectDirectory::ReportObjectAdded( RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. auto data = std::make_shared(); - data->manager = client_id.binary(); + data->manager = client_id.Binary(); data->object_size = object_info.data_size; ray::Status status = - gcs_client_->object_table().Add(DriverID::nil(), object_id, data, nullptr); + gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; } @@ -94,10 +94,10 @@ ray::Status ObjectDirectory::ReportObjectRemoved( RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. auto data = std::make_shared(); - data->manager = client_id.binary(); + data->manager = client_id.Binary(); data->object_size = object_info.data_size; ray::Status status = - gcs_client_->object_table().Remove(DriverID::nil(), object_id, data, nullptr); + gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; }; @@ -105,8 +105,8 @@ void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { ClientTableDataT client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::from_binary(client_data.client_id); - if (!result_client_id.is_nil()) { + ClientID result_client_id = ClientID::FromBinary(client_data.client_id); + if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); if (client_data.entry_type == EntryType::INSERTION) { connection_info.ip = client_data.node_manager_address; @@ -157,7 +157,7 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i if (it == listeners_.end()) { it = listeners_.emplace(object_id, LocationListenerState()).first; status = gcs_client_->object_table().RequestNotifications( - DriverID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); } auto &listener_state = it->second; // TODO(hme): Make this fatal after implementing Pull suppression. @@ -185,7 +185,7 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback entry->second.callbacks.erase(callback_id); if (entry->second.callbacks.empty()) { status = gcs_client_->object_table().CancelNotifications( - DriverID::nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); listeners_.erase(entry); } return status; @@ -208,7 +208,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, // SubscribeObjectLocations call, so look up the object's locations // directly from the GCS. 
status = gcs_client_->object_table().Lookup( - DriverID::nil(), object_id, + DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 3d829027b7cf..954162c21aef 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -64,7 +64,7 @@ void ObjectManager::StopIOService() { void ObjectManager::HandleObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { // Notify the object directory that the object has been added to this node. - ObjectID object_id = ObjectID::from_binary(object_info.object_id); + ObjectID object_id = ObjectID::FromBinary(object_info.object_id); RAY_LOG(DEBUG) << "Object added " << object_id; RAY_CHECK(local_objects_.count(object_id) == 0); local_objects_[object_id].object_info = object_info; @@ -272,7 +272,7 @@ void ObjectManager::PullSendRequest(const ObjectID &object_id, flatbuffers::FlatBufferBuilder fbb; auto message = object_manager_protocol::CreatePullRequestMessage( - fbb, fbb.CreateString(client_id_.binary()), fbb.CreateString(object_id.binary())); + fbb, fbb.CreateString(client_id_.Binary()), fbb.CreateString(object_id.Binary())); fbb.Finish(message); conn->WriteMessageAsync( static_cast(object_manager_protocol::MessageType::PullRequest), @@ -315,7 +315,7 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, profile_event.end_time = end_time; // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.hex() + "\",\"" + client_id.hex() + "\"," + + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + std::to_string(chunk_index) + ",\"" + status.ToString() + "\"]"; profile_events_.push_back(profile_event); @@ -335,7 +335,7 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, profile_event.end_time = end_time; // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. 
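
(Illustrative aside, not part of the patch: the extra_data string assembled above is a JSON list of [object_id_hex, client_id_hex, chunk_index, status]; a hedged Python equivalent for producing or checking that format follows.)

import json

def make_extra_data(object_id_hex, client_id_hex, chunk_index, status):
    # Matches the hand-built JSON list in HandleSendFinished/HandleReceiveFinished:
    # two quoted hex IDs, an unquoted chunk index, and a quoted status string.
    return json.dumps([object_id_hex, client_id_hex, chunk_index, status])

# e.g. make_extra_data("ab" * 20, "cd" * 20, 3, "OK")
# -> '["abab...", "cdcd...", 3, "OK"]'
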
- profile_event.extra_data = "[\"" + object_id.hex() + "\",\"" + client_id.hex() + "\"," + + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + std::to_string(chunk_index) + ",\"" + status.ToString() + "\"]"; profile_events_.push_back(profile_event); @@ -408,7 +408,7 @@ void ObjectManager::Push(const ObjectID &object_id, const ClientID &client_id) { static_cast(object_info.data_size + object_info.metadata_size); uint64_t metadata_size = static_cast(object_info.metadata_size); uint64_t num_chunks = buffer_pool_.GetNumChunks(data_size); - UniqueID push_id = UniqueID::from_random(); + UniqueID push_id = UniqueID::FromRandom(); for (uint64_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) { send_service_.post([this, push_id, client_id, object_id, data_size, metadata_size, chunk_index, connection_info]() { @@ -527,7 +527,7 @@ void ObjectManager::CancelPull(const ObjectID &object_id) { ray::Status ObjectManager::Wait(const std::vector &object_ids, int64_t timeout_ms, uint64_t num_required_objects, bool wait_local, const WaitCallback &callback) { - UniqueID wait_id = UniqueID::from_random(); + UniqueID wait_id = UniqueID::FromRandom(); RAY_LOG(DEBUG) << "Wait request " << wait_id << " on " << client_id_; RAY_RETURN_NOT_OK(AddWaitRequest(wait_id, object_ids, timeout_ms, num_required_objects, wait_local, callback)); @@ -773,7 +773,7 @@ void ObjectManager::ConnectClient(std::shared_ptr &conn, // TODO: trash connection on failure. auto info = flatbuffers::GetRoot(message); - ClientID client_id = ClientID::from_binary(info->client_id()->str()); + ClientID client_id = ClientID::FromBinary(info->client_id()->str()); bool is_transfer = info->is_transfer(); conn->SetClientID(client_id); if (is_transfer) { @@ -798,14 +798,14 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con const uint8_t *message) { // Serialize and push object to requesting client. auto pr = flatbuffers::GetRoot(message); - ObjectID object_id = ObjectID::from_binary(pr->object_id()->str()); - ClientID client_id = ClientID::from_binary(pr->client_id()->str()); + ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); + ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); ProfileEventT profile_event; profile_event.event_type = "receive_pull_request"; profile_event.start_time = current_sys_time_seconds(); profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.hex() + "\",\"" + client_id.hex() + "\"]"; + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -817,7 +817,7 @@ void ObjectManager::ReceivePushRequest(std::shared_ptr &con // Serialize. 
auto object_header = flatbuffers::GetRoot(message); - const ObjectID object_id = ObjectID::from_binary(object_header->object_id()->str()); + const ObjectID object_id = ObjectID::FromBinary(object_header->object_id()->str()); uint64_t chunk_index = object_header->chunk_index(); uint64_t data_size = object_header->data_size(); uint64_t metadata_size = object_header->metadata_size(); @@ -941,7 +941,7 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { ProfileTableDataT profile_info; profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.binary(); + profile_info.component_id = client_id_.Binary(); for (auto const &profile_event : profile_events_) { profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 0d301052358f..cb0cff83f349 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -394,7 +394,7 @@ class ObjectManager : public ObjectManagerInterface { /// This is used as the callback identifier in Pull for /// SubscribeObjectLocations. We only need one identifier because we never need to /// subscribe multiple times to the same object during Pull. - UniqueID object_directory_pull_callback_id_ = UniqueID::from_random(); + UniqueID object_directory_pull_callback_id_ = UniqueID::FromRandom(); /// A set of active wait requests. std::unordered_map active_wait_requests_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 91b0ffc3d576..6d7c0be0f856 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -121,8 +121,8 @@ class TestObjectManagerBase : public ::testing::Test { flushall_redis(); // start store - store_id_1 = StartStore(UniqueID::from_random().hex()); - store_id_2 = StartStore(UniqueID::from_random().hex()); + store_id_1 = StartStore(UniqueID::FromRandom().Hex()); + store_id_2 = StartStore(UniqueID::FromRandom().Hex()); uint pull_timeout_ms = 1000; int max_sends_a = 2; @@ -174,14 +174,14 @@ class TestObjectManagerBase : public ::testing::Test { } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { - ObjectID object_id = ObjectID::from_random(); + ObjectID object_id = ObjectID::FromRandom(); RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - RAY_ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, - metadata_size, &data)); - RAY_ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + RAY_ARROW_CHECK_OK( + client.Create(object_id.ToPlasmaId(), data_size, metadata, metadata_size, &data)); + RAY_ARROW_CHECK_OK(client.Seal(object_id.ToPlasmaId())); return object_id; } @@ -242,7 +242,7 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback([this]( gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::from_binary(data.client_id); + ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -262,7 +262,7 @@ class 
StressTestObjectManager : public TestObjectManagerBase { ray::Status status = ray::Status::OK(); status = server1->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - object_added_handler_1(ObjectID::from_binary(object_info.object_id)); + object_added_handler_1(ObjectID::FromBinary(object_info.object_id)); if (v1.size() == num_expected_objects && v1.size() == v2.size()) { TransferTestComplete(); } @@ -270,7 +270,7 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_CHECK_OK(status); status = server2->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - object_added_handler_2(ObjectID::from_binary(object_info.object_id)); + object_added_handler_2(ObjectID::FromBinary(object_info.object_id)); if (v2.size() == num_expected_objects && v1.size() == v2.size()) { TransferTestComplete(); } @@ -290,7 +290,7 @@ class StressTestObjectManager : public TestObjectManagerBase { plasma::ObjectBuffer GetObject(plasma::PlasmaClient &client, ObjectID &object_id) { plasma::ObjectBuffer object_buffer; - plasma::ObjectID plasma_id = object_id.to_plasma_id(); + plasma::ObjectID plasma_id = object_id.ToPlasmaId(); RAY_ARROW_CHECK_OK(client.Get(&plasma_id, 1, 0, &object_buffer)); return object_buffer; } @@ -298,7 +298,7 @@ class StressTestObjectManager : public TestObjectManagerBase { static unsigned char *GetDigest(plasma::PlasmaClient &client, ObjectID &object_id) { const int64_t size = sizeof(uint64_t); static unsigned char digest_1[size]; - RAY_ARROW_CHECK_OK(client.Hash(object_id.to_plasma_id(), &digest_1[0])); + RAY_ARROW_CHECK_OK(client.Hash(object_id.ToPlasmaId(), &digest_1[0])); return digest_1; } @@ -439,12 +439,12 @@ class StressTestObjectManager : public TestObjectManagerBase { << "\n"; ClientTableDataT data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data.client_id) << "\n" + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" << "ClientIp=" << data.node_manager_address << "\n" << "ClientPort=" << data.node_manager_port; ClientTableDataT data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::from_binary(data2.client_id) << "\n" + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" << "ClientIp=" << data2.node_manager_address << "\n" << "ClientPort=" << data2.node_manager_port; } diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 98eeb9186192..983a8fa7bc05 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -114,8 +114,8 @@ class TestObjectManagerBase : public ::testing::Test { flushall_redis(); // start store - store_id_1 = StartStore(UniqueID::from_random().hex()); - store_id_2 = StartStore(UniqueID::from_random().hex()); + store_id_1 = StartStore(UniqueID::FromRandom().Hex()); + store_id_2 = StartStore(UniqueID::FromRandom().Hex()); uint pull_timeout_ms = 1; push_timeout_ms = 1000; @@ -162,7 +162,7 @@ class TestObjectManagerBase : public ::testing::Test { } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { - return WriteDataToClient(client, data_size, ObjectID::from_random()); + return WriteDataToClient(client, data_size, ObjectID::FromRandom()); } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size, @@ -171,9 
+171,9 @@ class TestObjectManagerBase : public ::testing::Test { uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - RAY_ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, - metadata_size, &data)); - RAY_ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + RAY_ARROW_CHECK_OK( + client.Create(object_id.ToPlasmaId(), data_size, metadata, metadata_size, &data)); + RAY_ARROW_CHECK_OK(client.Seal(object_id.ToPlasmaId())); return object_id; } @@ -221,7 +221,7 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback([this]( gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::from_binary(data.client_id); + ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -240,13 +240,13 @@ class TestObjectManager : public TestObjectManagerBase { ray::Status status = ray::Status::OK(); status = server1->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - object_added_handler_1(ObjectID::from_binary(object_info.object_id)); + object_added_handler_1(ObjectID::FromBinary(object_info.object_id)); NotificationTestCompleteIfSatisfied(); }); RAY_CHECK_OK(status); status = server2->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - object_added_handler_2(ObjectID::from_binary(object_info.object_id)); + object_added_handler_2(ObjectID::FromBinary(object_info.object_id)); NotificationTestCompleteIfSatisfied(); }); RAY_CHECK_OK(status); @@ -254,11 +254,11 @@ class TestObjectManager : public TestObjectManagerBase { uint data_size = 1000000; // dummy_id is not local. The push function will timeout. - ObjectID dummy_id = ObjectID::from_random(); + ObjectID dummy_id = ObjectID::FromRandom(); server1->object_manager_.Push(dummy_id, gcs_client_2->client_table().GetLocalClientId()); - created_object_id1 = ObjectID::from_random(); + created_object_id1 = ObjectID::FromRandom(); WriteDataToClient(client1, data_size, created_object_id1); // Server1 holds Object1 so this Push call will success. server1->object_manager_.Push(created_object_id1, @@ -268,7 +268,7 @@ class TestObjectManager : public TestObjectManagerBase { timer.reset(new boost::asio::deadline_timer(main_service)); auto period = boost::posix_time::milliseconds(push_timeout_ms + 10); timer->expires_from_now(period); - created_object_id2 = ObjectID::from_random(); + created_object_id2 = ObjectID::FromRandom(); timer->async_wait([this, data_size](const boost::system::error_code &error) { WriteDataToClient(client2, data_size, created_object_id2); }); @@ -288,7 +288,7 @@ class TestObjectManager : public TestObjectManagerBase { // object. 
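// Illustrative sketch of the Create/Seal pattern the tests above use after the
// rename: ObjectID::FromRandom() makes a test object id and ToPlasmaId()
// converts it for the plasma client. Assumes plasma::PlasmaClient from Arrow,
// ray/id.h, and the RAY_ARROW_CHECK_OK macro used in the tests; the helper name
// is hypothetical.
#include <memory>
#include "plasma/client.h"
#include "ray/id.h"

ray::ObjectID WriteDummyObject(plasma::PlasmaClient &client, int64_t data_size) {
  ray::ObjectID object_id = ray::ObjectID::FromRandom();
  uint8_t metadata[] = {5};
  std::shared_ptr<arrow::Buffer> data;
  // The raylet-side ObjectID must be converted to a plasma::ObjectID first.
  RAY_ARROW_CHECK_OK(client.Create(object_id.ToPlasmaId(), data_size, metadata,
                                   sizeof(metadata), &data));
  RAY_ARROW_CHECK_OK(client.Seal(object_id.ToPlasmaId()));
  return object_id;
}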
ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::UniqueID::from_random(); + UniqueID sub_id = ray::UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, [this, sub_id, object_1, object_2]( @@ -307,7 +307,7 @@ class TestObjectManager : public TestObjectManagerBase { std::vector object_ids = {object_1, object_2}; boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); - UniqueID wait_id = UniqueID::from_random(); + UniqueID wait_id = UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.AddWaitRequest( wait_id, object_ids, timeout_ms, required_objects, false, @@ -378,7 +378,7 @@ class TestObjectManager : public TestObjectManagerBase { } if (include_nonexistent) { num_objects += 1; - object_ids.push_back(ObjectID::from_random()); + object_ids.push_back(ObjectID::FromRandom()); } boost::posix_time::ptime start_time = boost::posix_time::second_clock::local_time(); RAY_CHECK_OK(server1->object_manager_.Wait( @@ -457,17 +457,17 @@ class TestObjectManager : public TestObjectManagerBase { << "\n"; ClientTableDataT data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::from_binary(data.client_id).is_nil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::from_binary(data.client_id); + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::from_binary(data.client_id)); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); ClientTableDataT data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::from_binary(data2.client_id); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::from_binary(data2.client_id)); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); } }; diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 1cc55367ed07..cc587bc4d74e 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -14,28 +14,28 @@ ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, const ActorCheckpointDataT &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::from_binary(checkpoint_data.execution_dependency)) { + execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { // Restore `frontier_`. 
for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::from_binary(checkpoint_data.handle_ids[i]); + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); auto &frontier_entry = frontier_[handle_id]; frontier_entry.task_counter = checkpoint_data.task_counters[i]; frontier_entry.execution_dependency = - ObjectID::from_binary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); } // Restore `dummy_objects_`. for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::from_binary(checkpoint_data.unreleased_dummy_objects[i]); + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::from_binary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::from_binary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,7 +43,7 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::from_binary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id); } const int64_t ActorRegistration::GetMaxReconstructions() const { @@ -65,7 +65,7 @@ ObjectID ActorRegistration::ExtendFrontier(const ActorHandleID &handle_id, // Release the reference to the previous cursor for this // actor handle, if there was one. ObjectID object_to_release; - if (!frontier_entry.execution_dependency.is_nil()) { + if (!frontier_entry.execution_dependency.IsNil()) { auto it = dummy_objects_.find(frontier_entry.execution_dependency); RAY_CHECK(it != dummy_objects_.end()); it->second--; @@ -110,16 +110,16 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( // Use actor's current state to generate checkpoint data. 
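// Sketch of the Binary()/FromBinary() round trip that the checkpoint code above
// relies on: ids are stored in GCS records as raw binary strings and parsed back
// with the CamelCase factory. Assumes ray/id.h; the function names here are
// illustrative only.
#include <string>
#include "ray/id.h"

std::string SaveDependency(const ray::ObjectID &execution_dependency) {
  // Binary() replaces binary(): the raw bytes go into the checkpoint record.
  return execution_dependency.Binary();
}

ray::ObjectID RestoreDependency(const std::string &stored) {
  // FromBinary() replaces from_binary() on the restore path.
  return ray::ObjectID::FromBinary(stored);
}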
auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().binary(); + checkpoint_data->actor_id = actor_id.Binary(); + checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.binary()); + checkpoint_data->handle_ids.push_back(frontier.first.Binary()); checkpoint_data->task_counters.push_back(frontier.second.task_counter); checkpoint_data->frontier_dependencies.push_back( - frontier.second.execution_dependency.binary()); + frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.binary()); + checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); } return checkpoint_data; diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc index 361a04de1147..952088bb2f4b 100644 --- a/src/ray/raylet/client_connection_test.cc +++ b/src/ray/raylet/client_connection_test.cc @@ -180,7 +180,7 @@ TEST_F(ClientConnectionTest, ProcessBadMessage) { "reader", {}, error_message_type_); // If client ID is set, bad message would crash the test. - // reader->SetClientID(UniqueID::from_random()); + // reader->SetClientID(UniqueID::FromRandom()); // Intentionally write a message with incorrect cookie. // Verify it won't crash as long as client ID is not set. diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index ac32911ef2d0..ba9fef4f44d6 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -12,10 +12,10 @@ class UniqueIdFromJByteArray { const ID &GetId() const { return id; } UniqueIdFromJByteArray(JNIEnv *env, const jbyteArray &bytes) { - std::string id_str(ID::size(), 0); - env->GetByteArrayRegion(bytes, 0, ID::size(), + std::string id_str(ID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ID::Size(), reinterpret_cast(&id_str.front())); - id = ID::from_binary(id_str); + id = ID::FromBinary(id_str); } private: @@ -231,12 +231,12 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( TaskID task_id = ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter); - jbyteArray result = env->NewByteArray(task_id.size()); + jbyteArray result = env->NewByteArray(task_id.Size()); if (nullptr == result) { return nullptr; } - env->SetByteArrayRegion(result, 0, task_id.size(), - reinterpret_cast(task_id.data())); + env->SetByteArrayRegion(result, 0, task_id.Size(), + reinterpret_cast(task_id.Data())); return result; } @@ -280,9 +280,9 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env if (ThrowRayExceptionIfNotOK(env, status)) { return nullptr; } - jbyteArray result = env->NewByteArray(checkpoint_id.size()); - env->SetByteArrayRegion(result, 0, checkpoint_id.size(), - reinterpret_cast(checkpoint_id.data())); + jbyteArray result = env->NewByteArray(checkpoint_id.Size()); + env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), + reinterpret_cast(checkpoint_id.Data())); return result; } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 
4c3fac24f19e..910c3481bf58 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -48,7 +48,7 @@ void LineageEntry::ComputeParentTaskIds() { parent_task_ids_.clear(); // A task's parents are the tasks that created its arguments. for (const auto &dependency : task_.GetDependencies()) { - parent_task_ids_.insert(dependency.task_id()); + parent_task_ids_.insert(dependency.TaskId()); } } @@ -296,7 +296,7 @@ bool LineageCache::RemoveWaitingTask(const TaskID &task_id) { } void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { - RAY_CHECK(!node_id.is_nil()); + RAY_CHECK(!node_id.IsNil()); lineage_.GetEntryMutable(task_id)->MarkExplicitlyForwarded(node_id); } @@ -374,7 +374,7 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) { if (unsubscribed) { // Request notifications for the task if we haven't already requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.RequestNotifications(DriverID::nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.RequestNotifications(DriverID::Nil(), task_id, client_id_)); } // Return whether we were previously unsubscribed to this task and are now // subscribed. @@ -387,7 +387,7 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) { if (subscribed) { // Cancel notifications for the task if we previously requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.CancelNotifications(DriverID::nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_)); subscribed_tasks_.erase(it); } // Return whether we were previously subscribed to this task and are now diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index af411066e914..a61ae846a925 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -43,7 +43,7 @@ class MockGcs : public gcs::TableInterface, notification_callback_(client, task_id, data); } }; - return Add(DriverID::nil(), task_id, task_data, callback); + return Add(DriverID::Nil(), task_id, task_data, callback); } Status RequestNotifications(const DriverID &driver_id, const TaskID &task_id, @@ -91,7 +91,7 @@ class LineageCacheTest : public ::testing::Test { LineageCacheTest() : max_lineage_size_(10), mock_gcs_(), - lineage_cache_(ClientID::from_random(), mock_gcs_, mock_gcs_, max_lineage_size_) { + lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); @@ -113,7 +113,7 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments, num_returns, required_resources, Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); @@ -160,7 +160,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineageOrDie) { // Get the uncommitted lineage for the last task (the leaf) of one of the chains. 
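// Sketch of the Nil() sentinel usage visible in the lineage-cache hunks above:
// GCS pubsub calls take a DriverID, and DriverID::Nil() (formerly nil()) is
// passed when the request is not scoped to a particular driver. The pubsub
// parameter below stands in for the real GCS table interface and is an
// assumption, not the actual signature.
#include "ray/id.h"

template <typename Pubsub>
void SubscribeToTask(Pubsub &task_pubsub, const ray::TaskID &task_id,
                     const ray::ClientID &client_id) {
  // Nil() replaces nil(); the nil driver id means "not tied to one driver".
  RAY_CHECK_OK(task_pubsub.RequestNotifications(ray::DriverID::Nil(), task_id,
                                                client_id));
}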
auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_ids1.back(), ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(task_ids1.back(), ClientID::Nil()); // Check that the uncommitted lineage is exactly equal to the first chain of tasks. ASSERT_EQ(task_ids1.size(), uncommitted_lineage.GetEntries().size()); for (auto &task_id : task_ids1) { @@ -181,7 +181,7 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineageOrDie) { // Get the uncommitted lineage for the inserted task. uncommitted_lineage = lineage_cache_.GetUncommittedLineageOrDie( - combined_task_ids.back(), ClientID::nil()); + combined_task_ids.back(), ClientID::Nil()); // Check that the uncommitted lineage is exactly equal to the entire set of // tasks inserted so far. ASSERT_EQ(combined_task_ids.size(), uncommitted_lineage.GetEntries().size()); @@ -200,8 +200,8 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { task_ids.push_back(task.GetTaskSpecification().TaskId()); } - auto node_id = ClientID::from_random(); - auto node_id2 = ClientID::from_random(); + auto node_id = ClientID::FromRandom(); + auto node_id2 = ClientID::FromRandom(); auto forwarded_task_id = task_ids[task_ids.size() - 2]; auto remaining_task_id = task_ids[task_ids.size() - 1]; lineage_cache_.MarkTaskAsForwarded(forwarded_task_id, node_id); @@ -285,7 +285,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { mock_gcs_.Flush(); ASSERT_EQ(lineage_cache_ .GetUncommittedLineageOrDie(tasks.back().GetTaskSpecification().TaskId(), - ClientID::nil()) + ClientID::Nil()) .GetEntries() .size(), tasks.size()); @@ -298,7 +298,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { mock_gcs_.Flush(); ASSERT_EQ(lineage_cache_ .GetUncommittedLineageOrDie(tasks.back().GetTaskSpecification().TaskId(), - ClientID::nil()) + ClientID::Nil()) .GetEntries() .size(), tasks.size()); @@ -335,7 +335,7 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ .GetUncommittedLineageOrDie(child_task.GetTaskSpecification().TaskId(), - ClientID::nil()) + ClientID::Nil()) .GetEntries() .size(), total_tasks); @@ -351,7 +351,7 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ .GetUncommittedLineageOrDie( - child_task.GetTaskSpecification().TaskId(), ClientID::nil()) + child_task.GetTaskSpecification().TaskId(), ClientID::Nil()) .GetEntries() .size(), total_tasks); @@ -376,7 +376,7 @@ TEST_F(LineageCacheTest, TestForwardTasksRoundTrip) { const auto task_id = it->GetTaskSpecification().TaskId(); // Simulate removing the task and forwarding it to another node. auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id, ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(task_id, ClientID::Nil()); ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); // Simulate receiving the task again. Make sure we can add the task back. 
flatbuffers::FlatBufferBuilder fbb; @@ -400,7 +400,7 @@ TEST_F(LineageCacheTest, TestForwardTask) { tasks.erase(it); auto task_id_to_remove = forwarded_task.GetTaskSpecification().TaskId(); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id_to_remove, ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(task_id_to_remove, ClientID::Nil()); ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id_to_remove)); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 3); @@ -450,7 +450,7 @@ TEST_F(LineageCacheTest, TestEviction) { // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::Nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); // Simulate executing the first task on a remote node and adding it to the @@ -484,7 +484,7 @@ TEST_F(LineageCacheTest, TestEviction) { // All tasks have now been flushed. Check that enough lineage has been // evicted that the uncommitted lineage is now less than the maximum size. uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::Nil()); ASSERT_TRUE(uncommitted_lineage.GetEntries().size() < max_lineage_size_); // The remaining task should have no uncommitted lineage. ASSERT_EQ(uncommitted_lineage.GetEntries().size(), 1); @@ -510,7 +510,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::nil()); + lineage_cache_.GetUncommittedLineageOrDie(last_task_id, ClientID::Nil()); ASSERT_EQ(uncommitted_lineage.GetEntries().size(), lineage_size); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 1e20fe3f4131..171b4dc9439e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -35,7 +35,7 @@ void Monitor::Start() { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( - DriverID::nil(), ClientID::nil(), heartbeat_callback, nullptr, nullptr)); + DriverID::Nil(), ClientID::Nil(), heartbeat_callback, nullptr, nullptr)); Tick(); } @@ -52,7 +52,7 @@ void Monitor::Tick() { const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.binary() == data.client_id && + if (client_id.Binary() == data.client_id && data.entry_type == EntryType::DELETION) { // The node has been marked dead by itself. marked = true; @@ -70,7 +70,7 @@ void Monitor::Tick() { << " has missed too many heartbeats from it."; // We use the nil DriverID to broadcast the message to all drivers. 
RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver( - DriverID::nil(), type, error_message.str(), current_time_ms())); + DriverID::Nil(), type, error_message.str(), current_time_ms())); } }; RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback)); @@ -89,7 +89,7 @@ void Monitor::Tick() { batch->batch.push_back(std::unique_ptr( new HeartbeatTableDataT(heartbeat.second))); } - RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::nil(), ClientID::nil(), + RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); heartbeat_buffer_.clear(); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 2e25407f12fb..1b582d7617cb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -110,7 +110,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_CHECK_OK(object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - ObjectID object_id = ObjectID::from_binary(object_info.object_id); + ObjectID object_id = ObjectID::FromBinary(object_info.object_id); HandleObjectLocal(object_id); })); RAY_CHECK_OK(object_manager_.SubscribeObjDeleted( @@ -131,13 +131,13 @@ ray::Status NodeManager::RegisterGcs() { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( - DriverID::nil(), gcs_client_->client_table().GetLocalClientId(), + DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), task_committed_callback, nullptr, nullptr)); const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::from_binary(task_lease.node_manager_id); + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -155,7 +155,7 @@ ray::Status NodeManager::RegisterGcs() { reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0); }; RAY_RETURN_NOT_OK(gcs_client_->task_lease_table().Subscribe( - DriverID::nil(), gcs_client_->client_table().GetLocalClientId(), + DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), task_lease_notification_callback, task_lease_empty_callback, nullptr)); // Register a callback to handle actor notifications. @@ -170,7 +170,7 @@ ray::Status NodeManager::RegisterGcs() { }; RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - DriverID::nil(), ClientID::nil(), actor_notification_callback, nullptr)); + DriverID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. 
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, @@ -208,7 +208,7 @@ ray::Status NodeManager::RegisterGcs() { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( - DriverID::nil(), ClientID::nil(), heartbeat_batch_added, + DriverID::Nil(), ClientID::Nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); @@ -219,7 +219,7 @@ ray::Status NodeManager::RegisterGcs() { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( - DriverID::nil(), ClientID::nil(), driver_table_handler, nullptr)); + DriverID::Nil(), ClientID::Nil(), driver_table_handler, nullptr)); // Start sending heartbeats to the GCS. last_heartbeat_at_ms_ = current_time_ms(); @@ -253,10 +253,10 @@ void NodeManager::KillWorker(std::shared_ptr worker) { void NodeManager::HandleDriverTableUpdate( const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::from_binary(entry.driver_id) + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) << " " << entry.is_dead; if (entry.is_dead) { - auto driver_id = DriverID::from_binary(entry.driver_id); + auto driver_id = DriverID::FromBinary(entry.driver_id); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -291,7 +291,7 @@ void NodeManager::Heartbeat() { auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.binary(); + heartbeat_data->client_id = my_client_id.Binary(); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. for (const auto &resource_pair : @@ -311,7 +311,7 @@ void NodeManager::Heartbeat() { } ray::Status status = heartbeat_table.Add( - DriverID::nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, + DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, /*success_callback=*/nullptr); RAY_CHECK_OK_PREPEND(status, "Heartbeat failed"); @@ -359,7 +359,7 @@ void NodeManager::GetObjectManagerProfileInfo() { } void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -393,7 +393,7 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { << ". This may be since the node was recently removed."; // We use the nil DriverID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::nil(), type, error_message.str(), current_time_ms())); + DriverID::Nil(), type, error_message.str(), current_time_ms())); return; } @@ -432,7 +432,7 @@ ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id, void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. 
- const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -478,7 +478,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { } void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " @@ -514,7 +514,7 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { } void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::from_binary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); ResourceSet new_res_set(client_data.resources_total_label, @@ -608,7 +608,7 @@ void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_ const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. for (const auto &heartbeat_data : heartbeat_batch.batch) { - const ClientID &client_id = ClientID::from_binary(heartbeat_data->client_id); + const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); if (client_id == local_client_id) { // Skip heartbeats from self. continue; @@ -638,12 +638,12 @@ void NodeManager::PublishActorStateTransition( const ActorTableDataT &data) { auto redis_context = client->primary_context(); if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { - std::vector args = {"XADD", id.hex(), "*", "signal", + std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); } }; - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(DriverID::nil(), actor_id, + RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(DriverID::Nil(), actor_id, actor_notification, success_callback, failure_callback, log_length)); } @@ -852,9 +852,9 @@ void NodeManager::ProcessClientMessage( // Clean up their creating tasks from GCS. std::vector creating_task_ids; for (const auto &object_id : object_ids) { - creating_task_ids.push_back(object_id.task_id()); + creating_task_ids.push_back(object_id.TaskId()); } - gcs_client_->raylet_task_table().Delete(DriverID::nil(), creating_task_ids); + gcs_client_->raylet_task_table().Delete(DriverID::Nil(), creating_task_ids); } } break; case protocol::MessageType::PrepareActorCheckpointRequest: { @@ -945,7 +945,7 @@ void NodeManager::ProcessGetTaskMessage( std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker); // If the worker was assigned a task, mark it as finished. - if (!worker->GetAssignedTaskId().is_nil()) { + if (!worker->GetAssignedTaskId().IsNil()) { FinishAssignedTask(*worker); } // Return the worker to the idle pool. 
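// Sketch of the IsNil()/Nil() guard pattern used when a worker finishes or
// disconnects in the hunks above: act only if an id is actually set, then reset
// it to the nil value. The Worker interface here is assumed for illustration.
#include "ray/id.h"

template <typename Worker>
void FinishIfAssigned(Worker &worker) {
  // IsNil() replaces is_nil(); TaskID::Nil() replaces TaskID::nil().
  if (!worker.GetAssignedTaskId().IsNil()) {
    // ... mark the assigned task as finished ...
    worker.AssignTaskId(ray::TaskID::Nil());
  }
}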
@@ -1003,7 +1003,7 @@ void NodeManager::ProcessDisconnectClientMessage( } const ActorID &actor_id = worker->GetActorId(); - if (!actor_id.is_nil()) { + if (!actor_id.IsNil()) { // If the worker was an actor, update actor state, reconstruct the actor if needed, // and clean up actor's tasks if the actor is permanently dead. HandleDisconnectedActor(actor_id, true, intentional_disconnect); @@ -1012,10 +1012,10 @@ void NodeManager::ProcessDisconnectClientMessage( const TaskID &task_id = worker->GetAssignedTaskId(); // If the worker was running a task, clean up the task and push an error to // the driver, unless the worker is already dead. - if (!task_id.is_nil() && !worker->IsDead()) { + if (!task_id.IsNil() && !worker->IsDead()) { // If the worker was an actor, the task was already cleaned up in // `HandleDisconnectedActor`. - if (actor_id.is_nil()) { + if (actor_id.IsNil()) { const Task &task = local_queues_.RemoveTask(task_id); TreatTaskAsFailed(task, ErrorType::WORKER_DIED); } @@ -1062,7 +1062,7 @@ void NodeManager::ProcessDisconnectClientMessage( gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()), /*is_dead=*/true)); auto driver_id = worker->GetAssignedTaskId(); - RAY_CHECK(!driver_id.is_nil()); + RAY_CHECK(!driver_id.IsNil()); local_queues_.RemoveDriverTaskId(driver_id); worker_pool_.DisconnectDriver(worker); @@ -1197,13 +1197,13 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( const auto task_id = worker->GetAssignedTaskId(); const Task &task = local_queues_.GetTaskOfState(task_id, TaskState::RUNNING); // Generate checkpoint id and data. - ActorCheckpointID checkpoint_id = ActorCheckpointID::from_random(); + ActorCheckpointID checkpoint_id = ActorCheckpointID::FromRandom(); auto checkpoint_data = actor_entry->second.GenerateCheckpointData(actor_entry->first, task); // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( - DriverID::nil(), checkpoint_id, checkpoint_data, + DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, const ActorCheckpointDataT &data) { @@ -1212,7 +1212,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( // Save this actor-to-checkpoint mapping, and remove old checkpoints associated // with this actor. RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId( - DriverID::nil(), actor_id, checkpoint_id)); + DriverID::Nil(), actor_id, checkpoint_id)); // Send reply to worker. 
flatbuffers::FlatBufferBuilder fbb; auto reply = ray::protocol::CreatePrepareActorCheckpointReply( @@ -1293,7 +1293,7 @@ void NodeManager::ProcessSetResourceRequest( ClientID client_id = from_flatbuf(*message->client_id()); // If the python arg was null, set client_id to the local client - if (client_id.is_nil()) { + if (client_id.IsNil()) { client_id = gcs_client_->client_table().GetLocalClientId(); } @@ -1331,7 +1331,7 @@ void NodeManager::ProcessSetResourceRequest( auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( - DriverID::nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); + DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); } void NodeManager::ScheduleTasks( @@ -1450,7 +1450,7 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ } const std::string meta = std::to_string(static_cast(error_type)); for (int64_t i = 0; i < num_returns; i++) { - const auto object_id = spec.ReturnId(i).to_plasma_id(); + const auto object_id = spec.ReturnId(i).ToPlasmaId(); arrow::Status status = store_client_.CreateAndSeal(object_id, "", meta); if (!status.ok() && !status.IsPlasmaObjectExists()) { // If we failed to save the error code, log a warning and push an error message @@ -1605,7 +1605,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag HandleActorStateTransition(actor_id, ActorRegistration(data.back())); } }; - RAY_CHECK_OK(gcs_client_->actor_table().Lookup(DriverID::nil(), spec.ActorId(), + RAY_CHECK_OK(gcs_client_->actor_table().Lookup(DriverID::Nil(), spec.ActorId(), lookup_callback)); actor_creation_dummy_object = spec.ActorCreationDummyObjectId(); } else { @@ -1796,7 +1796,7 @@ bool NodeManager::AssignTask(const Task &task) { const std::string warning_message = worker_pool_.WarningAboutSize(); if (warning_message != "") { RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::nil(), "worker_pool_large", warning_message, current_time_ms())); + DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms())); } } // We couldn't assign this task, as no worker available. @@ -1875,7 +1875,7 @@ bool NodeManager::AssignTask(const Task &task) { // The execution dependency is initialized to the actor creation task's // return value, and is subsequently updated to the assigned tasks' // return values, so it should never be nil. - RAY_CHECK(!execution_dependency.is_nil()); + RAY_CHECK(!execution_dependency.IsNil()); // Update the task's execution dependencies to reflect the actual // execution order, to support deterministic reconstruction. // NOTE(swang): The update of an actor task's execution dependencies is @@ -1946,11 +1946,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { task_dependency_manager_.TaskCanceled(task_id); // Unset the worker's assigned task. - worker.AssignTaskId(TaskID::nil()); + worker.AssignTaskId(TaskID::Nil()); // Unset the worker's assigned driver Id if this is not an actor. if (!task.GetTaskSpecification().IsActorCreationTask() && !task.GetTaskSpecification().IsActorTask()) { - worker.AssignDriverId(DriverID::nil()); + worker.AssignDriverId(DriverID::Nil()); } } @@ -1966,10 +1966,10 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. 
- new_actor_data.actor_id = actor_id.binary(); + new_actor_data.actor_id = actor_id.Binary(); new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().binary(); + task.GetTaskSpecification().ActorDummyObject().Binary(); + new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); new_actor_data.max_reconstructions = task.GetTaskSpecification().MaxActorReconstructions(); // This is the first time that the actor has been created, so the number @@ -1990,7 +1990,7 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().binary(); + gcs_client_->client_table().GetLocalClientId().Binary(); new_actor_data.state = ActorState::ALIVE; return new_actor_data; } @@ -2001,7 +2001,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { bool resumed_from_checkpoint = false; if (task.GetTaskSpecification().IsActorCreationTask()) { actor_id = task.GetTaskSpecification().ActorCreationId(); - actor_handle_id = ActorHandleID::nil(); + actor_handle_id = ActorHandleID::Nil(); if (checkpoint_id_to_restore_.count(actor_id) > 0) { resumed_from_checkpoint = true; } @@ -2024,7 +2024,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { RAY_LOG(DEBUG) << "Looking up checkpoint " << checkpoint_id << " for actor " << actor_id; RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Lookup( - DriverID::nil(), checkpoint_id, + DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, const ActorCheckpointDataT &checkpoint_data) { @@ -2074,7 +2074,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); const ObjectID object_to_release = actor_entry->second.ExtendFrontier(actor_handle_id, dummy_object); - if (!object_to_release.is_nil()) { + if (!object_to_release.IsNil()) { // If there were no new actor handles created, then no other actor task // will depend on this execution dependency, so it safe to release. HandleObjectMissing(object_to_release); @@ -2094,7 +2094,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { // Retrieve the task spec in order to re-execute the task. 
RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup( - DriverID::nil(), task_id, + DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const ray::protocol::TaskT &task_data) { @@ -2380,7 +2380,7 @@ std::string NodeManager::DebugString() const { result << "\nInitialConfigResources: " << initial_config_.resource_config.ToString(); result << "\nClusterResources:"; for (auto &pair : cluster_resource_map_) { - result << "\n" << pair.first.hex() << ": " << pair.second.DebugString(); + result << "\n" << pair.first.Hex() << ": " << pair.second.DebugString(); } result << "\n" << object_manager_.DebugString(); result << "\n" << gcs_client_->DebugString(); @@ -2399,7 +2399,7 @@ std::string NodeManager::DebugString() const { result << "\nRemoteConnections:"; for (auto &pair : remote_server_connections_) { - result << "\n" << pair.first.hex() << ": " << pair.second->DebugString(); + result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString(); } result << "\nDebugString() time ms: " << (current_time_ms() - now_ms); return result.str(); diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 861020595448..a774e1409195 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -99,14 +99,14 @@ class TestObjectManagerBase : public ::testing::Test { } ObjectID WriteDataToClient(plasma::PlasmaClient &client, int64_t data_size) { - ObjectID object_id = ObjectID::from_random(); + ObjectID object_id = ObjectID::FromRandom(); RAY_LOG(DEBUG) << "ObjectID Created: " << object_id; uint8_t metadata[] = {5}; int64_t metadata_size = sizeof(metadata); std::shared_ptr data; - RAY_ARROW_CHECK_OK(client.Create(object_id.to_plasma_id(), data_size, metadata, - metadata_size, &data)); - RAY_ARROW_CHECK_OK(client.Seal(object_id.to_plasma_id())); + RAY_ARROW_CHECK_OK( + client.Create(object_id.ToPlasmaId(), data_size, metadata, metadata_size, &data)); + RAY_ARROW_CHECK_OK(client.Seal(object_id.ToPlasmaId())); return object_id; } @@ -138,7 +138,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback([this]( gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::from_binary(data.client_id); + ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -158,7 +158,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { ray::Status status = ray::Status::OK(); status = server1->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - v1.push_back(ObjectID::from_binary(object_info.object_id)); + v1.push_back(ObjectID::FromBinary(object_info.object_id)); if (v1.size() == num_expected_objects && v1.size() == v2.size()) { TestPushComplete(); } @@ -166,7 +166,7 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { RAY_CHECK_OK(status); status = server2->object_manager_.SubscribeObjAdded( [this](const object_manager::protocol::ObjectInfoT &object_info) { - v2.push_back(ObjectID::from_binary(object_info.object_id)); + v2.push_back(ObjectID::FromBinary(object_info.object_id)); if (v2.size() == num_expected_objects && v1.size() == v2.size()) { TestPushComplete(); } @@ -208,13 
+208,13 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { << "\n"; ClientTableDataT data; gcs_client_2->client_table().GetClient(client_id_1, data); - RAY_LOG(INFO) << (ClientID::from_binary(data.client_id).is_nil()); - RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data.client_id); + RAY_LOG(INFO) << (ClientID::FromBinary(data.client_id).IsNil()); + RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data.client_id); RAY_LOG(INFO) << "ClientIp=" << data.node_manager_address; RAY_LOG(INFO) << "ClientPort=" << data.node_manager_port; ClientTableDataT data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(INFO) << "ClientID=" << ClientID::from_binary(data2.client_id); + RAY_LOG(INFO) << "ClientID=" << ClientID::FromBinary(data2.client_id); RAY_LOG(INFO) << "ClientIp=" << data2.node_manager_address; RAY_LOG(INFO) << "ClientPort=" << data2.node_manager_port; } diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 0f488089e6d0..ac312b79d13e 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -312,12 +312,12 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ auto reply_message = flatbuffers::GetRoot(reply.get()); auto found = reply_message->found(); for (uint i = 0; i < found->size(); i++) { - ObjectID object_id = ObjectID::from_binary(found->Get(i)->str()); + ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str()); result->first.push_back(object_id); } auto remaining = reply_message->remaining(); for (uint i = 0; i < remaining->size(); i++) { - ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str()); + ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str()); result->second.push_back(object_id); } return ray::Status::OK(); @@ -373,7 +373,7 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, if (!status.ok()) return status; auto reply_message = flatbuffers::GetRoot(reply.get()); - checkpoint_id = ActorCheckpointID::from_binary(reply_message->checkpoint_id()->str()); + checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str()); return ray::Status::OK(); } diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index d1a648a34ce4..97c86ea73cd8 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -52,7 +52,7 @@ void ReconstructionPolicy::SetTaskTimeout( // required by the task are no longer needed soon after. If the // task is still required after this initial period, then we now // subscribe to task lease notifications. - RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(DriverID::nil(), task_id, + RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(DriverID::Nil(), task_id, client_id_)); it->second.subscribed = true; } @@ -108,9 +108,9 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // an entry for this reconstruction. 
auto reconstruction_entry = std::make_shared(); reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.binary(); + reconstruction_entry->node_manager_id = client_id_.Binary(); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( - DriverID::nil(), task_id, reconstruction_entry, + DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionDataT &data) { @@ -171,7 +171,7 @@ void ReconstructionPolicy::HandleTaskLeaseNotification(const TaskID &task_id, } void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) { - TaskID task_id = object_id.task_id(); + TaskID task_id = object_id.TaskId(); auto it = listening_tasks_.find(task_id); // Add this object to the list of objects created by the same task. if (it == listening_tasks_.end()) { @@ -185,7 +185,7 @@ void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) } void ReconstructionPolicy::Cancel(const ObjectID &object_id) { - TaskID task_id = object_id.task_id(); + TaskID task_id = object_id.TaskId(); auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { // We already stopped listening for this task. @@ -199,7 +199,7 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) { // Cancel notifications for the task lease if we were subscribed to them. if (it->second.subscribed) { RAY_CHECK_OK( - task_lease_pubsub_.CancelNotifications(DriverID::nil(), task_id, client_id_)); + task_lease_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_)); } listening_tasks_.erase(it); } diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 7f8887b15372..4ccebd0c0c09 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -154,7 +154,7 @@ class ReconstructionPolicyTest : public ::testing::Test { reconstruction_policy_(std::make_shared( io_service_, [this](const TaskID &task_id) { TriggerReconstruction(task_id); }, - reconstruction_timeout_ms_, ClientID::from_random(), mock_gcs_, + reconstruction_timeout_ms_, ClientID::FromRandom(), mock_gcs_, mock_object_directory_, mock_gcs_)), timer_canceled_(false) { mock_gcs_.Subscribe( @@ -223,8 +223,8 @@ class ReconstructionPolicyTest : public ::testing::Test { }; TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -241,9 +241,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { } TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); - mock_object_directory_->SetObjectLocations(object_id, {ClientID::from_random()}); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + mock_object_directory_->SetObjectLocations(object_id, {ClientID::FromRandom()}); // Listen for both objects. 
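// Sketch of the relationship the reconstruction-policy tests above depend on:
// ObjectID::ForTaskReturn(task, index) (formerly for_task_return) derives a
// return-object id from a task id, and object_id.TaskId() (formerly task_id())
// recovers the creating task. Assumes ray/id.h.
#include "ray/id.h"

bool ReturnIdRoundTrips() {
  ray::TaskID task_id = ray::TaskID::FromRandom();
  ray::ObjectID return_id = ray::ObjectID::ForTaskReturn(task_id, 1);
  // The reconstruction policy maps a missing object back to the task that
  // produced it via TaskId().
  return return_id.TaskId() == task_id;
}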
reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -264,9 +264,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { } TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); - ClientID client_id = ClientID::from_random(); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); + ClientID client_id = ClientID::FromRandom(); mock_object_directory_->SetObjectLocations(object_id, {client_id}); // Listen for both objects. @@ -288,9 +288,9 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { // Create two object IDs produced by the same task. - TaskID task_id = TaskID::from_random(); - ObjectID object_id1 = ObjectID::for_task_return(task_id, 1); - ObjectID object_id2 = ObjectID::for_task_return(task_id, 2); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id1 = ObjectID::ForTaskReturn(task_id, 1); + ObjectID object_id2 = ObjectID::ForTaskReturn(task_id, 2); // Listen for both objects. reconstruction_policy_->ListenAndMaybeReconstruct(object_id1); @@ -308,17 +308,17 @@ TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { } TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); // Run the test for much longer than the reconstruction timeout. int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::from_random().binary(); + task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = 2 * test_period; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -334,18 +334,18 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { } TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::from_random().binary(); + task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = reconstruction_timeout_ms_; - mock_gcs_.Add(DriverID::nil(), task_id, task_lease_data); + mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. 
Run(reconstruction_timeout_ms_ * 2); @@ -361,8 +361,8 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { } TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -387,17 +387,17 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { } TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { - TaskID task_id = TaskID::from_random(); - ObjectID object_id = ObjectID::for_task_return(task_id, 1); + TaskID task_id = TaskID::FromRandom(); + ObjectID object_id = ObjectID::ForTaskReturn(task_id, 1); // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::from_random().binary(); + task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); task_reconstruction_data->num_reconstructions = 0; RAY_CHECK_OK( - mock_gcs_.AppendAt(DriverID::nil(), task_id, task_reconstruction_data, nullptr, + mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index dc24c95d46e4..c5155b96b0c1 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -24,7 +24,7 @@ bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { } bool TaskDependencyManager::CheckObjectRequired(const ObjectID &object_id) const { - const TaskID task_id = object_id.task_id(); + const TaskID task_id = object_id.TaskId(); auto task_entry = required_tasks_.find(task_id); // If there are no subscribed tasks that are dependent on the object, then do // nothing. @@ -82,7 +82,7 @@ std::vector TaskDependencyManager::HandleObjectLocal( // Find any tasks that are dependent on the newly available object. std::vector ready_task_ids; - auto creating_task_entry = required_tasks_.find(object_id.task_id()); + auto creating_task_entry = required_tasks_.find(object_id.TaskId()); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { @@ -113,7 +113,7 @@ std::vector TaskDependencyManager::HandleObjectMissing( // Find any tasks that are dependent on the missing object. std::vector waiting_task_ids; - TaskID creating_task_id = object_id.task_id(); + TaskID creating_task_id = object_id.TaskId(); auto creating_task_entry = required_tasks_.find(creating_task_id); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); @@ -149,7 +149,7 @@ bool TaskDependencyManager::SubscribeDependencies( auto inserted = task_entry.object_dependencies.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. 
- TaskID creating_task_id = object_id.task_id(); + TaskID creating_task_id = object_id.TaskId(); // Determine whether the dependency can be fulfilled by the local node. if (local_objects_.count(object_id) == 0) { // The object is not local. @@ -186,7 +186,7 @@ bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // Remove the task from the list of tasks that are dependent on this // object. // Get the ID of the task that creates the dependency. - TaskID creating_task_id = object_id.task_id(); + TaskID creating_task_id = object_id.TaskId(); auto creating_task_entry = required_tasks_.find(creating_task_id); std::vector &dependent_tasks = creating_task_entry->second[object_id]; auto it = std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id); @@ -262,10 +262,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { } auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.hex(); + task_lease_data->node_manager_id = client_id_.Hex(); task_lease_data->acquired_at = current_sys_time_ms(); task_lease_data->timeout = it->second.lease_period; - RAY_CHECK_OK(task_lease_table_.Add(DriverID::nil(), task_id, task_lease_data, nullptr)); + RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); it->second.lease_timer->expires_from_now(period); @@ -324,7 +324,7 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( // Cancel all of the objects that were required by the removed tasks. for (const auto &object_id : required_objects) { - TaskID creating_task_id = object_id.task_id(); + TaskID creating_task_id = object_id.TaskId(); required_tasks_.erase(creating_task_id); HandleRemoteDependencyCanceled(object_id); } diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 62bbf17069d5..e0f832a12870 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -43,7 +43,7 @@ class TaskDependencyManagerTest : public ::testing::Test { gcs_mock_(), initial_lease_period_ms_(100), task_dependency_manager_(object_manager_mock_, reconstruction_policy_mock_, - io_service_, ClientID::nil(), initial_lease_period_ms_, + io_service_, ClientID::Nil(), initial_lease_period_ms_, gcs_mock_) {} void Run(uint64_t timeout_ms) { @@ -75,7 +75,7 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(DriverID::nil(), TaskID::from_random(), 0, task_arguments, + auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments, num_returns, required_resources, Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); @@ -105,9 +105,9 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { int num_arguments = 3; std::vector arguments; for (int i = 0; i < num_arguments; i++) { - arguments.push_back(ObjectID::from_random()); + arguments.push_back(ObjectID::FromRandom()); } - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); // No objects have been registered in the task dependency manager, so all // arguments should be remote. 
for (const auto &argument_id : arguments) { @@ -139,12 +139,12 @@ TEST_F(TaskDependencyManagerTest, TestSimpleTask) { TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribe) { // Create a task with 3 arguments. - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); int num_arguments = 3; std::vector arguments; for (int i = 0; i < num_arguments; i++) { // Add the new argument to the list of dependencies to subscribe to. - ObjectID argument_id = ObjectID::from_random(); + ObjectID argument_id = ObjectID::FromRandom(); arguments.push_back(argument_id); // Subscribe to the task's dependencies. All arguments except the last are // duplicates of previous subscription calls. Each argument should only be @@ -176,7 +176,7 @@ TEST_F(TaskDependencyManagerTest, TestDuplicateSubscribe) { TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { // Create 3 tasks that are dependent on the same object. - ObjectID argument_id = ObjectID::from_random(); + ObjectID argument_id = ObjectID::FromRandom(); std::vector dependent_tasks; int num_dependent_tasks = 3; // The object should only be requested from the object manager once for all @@ -184,7 +184,7 @@ TEST_F(TaskDependencyManagerTest, TestMultipleTasks) { EXPECT_CALL(object_manager_mock_, Pull(argument_id)); EXPECT_CALL(reconstruction_policy_mock_, ListenAndMaybeReconstruct(argument_id)); for (int i = 0; i < num_dependent_tasks; i++) { - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); dependent_tasks.push_back(task_id); // Subscribe to each of the task's dependencies. bool ready = task_dependency_manager_.SubscribeDependencies(task_id, {argument_id}); @@ -266,7 +266,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TEST_F(TaskDependencyManagerTest, TestDependentPut) { // Create a task with 3 arguments. auto task1 = ExampleTask({}, 0); - ObjectID put_id = ObjectID::for_put(task1.GetTaskSpecification().TaskId(), 1); + ObjectID put_id = ObjectID::ForPut(task1.GetTaskSpecification().TaskId(), 1); auto task2 = ExampleTask({put_id}, 0); // No objects have been registered in the task dependency manager, so the put @@ -326,9 +326,9 @@ TEST_F(TaskDependencyManagerTest, TestEviction) { int num_arguments = 3; std::vector arguments; for (int i = 0; i < num_arguments; i++) { - arguments.push_back(ObjectID::from_random()); + arguments.push_back(ObjectID::FromRandom()); } - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); // No objects have been registered in the task dependency manager, so all // arguments should be remote. 
for (const auto &argument_id : arguments) { diff --git a/src/ray/raylet/task_execution_spec.cc b/src/ray/raylet/task_execution_spec.cc index c5b1486d5bcc..dc7bf30b83d2 100644 --- a/src/ray/raylet/task_execution_spec.cc +++ b/src/ray/raylet/task_execution_spec.cc @@ -25,7 +25,7 @@ TaskExecutionSpecification::ToFlatbuffer(flatbuffers::FlatBufferBuilder &fbb) co std::vector TaskExecutionSpecification::ExecutionDependencies() const { std::vector dependencies; for (const auto &dependency : execution_spec_.dependencies) { - dependencies.push_back(ObjectID::from_binary(dependency)); + dependencies.push_back(ObjectID::FromBinary(dependency)); } return dependencies; } @@ -34,7 +34,7 @@ void TaskExecutionSpecification::SetExecutionDependencies( const std::vector &dependencies) { execution_spec_.dependencies.clear(); for (const auto &dependency : dependencies) { - execution_spec_.dependencies.push_back(dependency.binary()); + execution_spec_.dependencies.push_back(dependency.Binary()); } } diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 17a8b185fc78..eeab29272126 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -65,8 +65,8 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const Language &language, const std::vector &function_descriptor) - : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::nil(), - ObjectID::nil(), 0, ActorID::nil(), ActorHandleID::nil(), -1, {}, + : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::Nil(), + ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), -1, {}, task_arguments, num_returns, required_resources, std::unordered_map(), language, function_descriptor) {} @@ -165,8 +165,7 @@ int64_t TaskSpecification::NumReturns() const { } ObjectID TaskSpecification::ReturnId(int64_t return_index) const { - auto message = flatbuffers::GetRoot(spec_.data()); - return ObjectID::for_task_return(TaskId(), return_index + 1); + return ObjectID::ForTaskReturn(TaskId(), return_index + 1); } bool TaskSpecification::ArgByRef(int64_t arg_index) const { @@ -215,11 +214,9 @@ Language TaskSpecification::GetLanguage() const { return message->language(); } -bool TaskSpecification::IsActorCreationTask() const { - return !ActorCreationId().is_nil(); -} +bool TaskSpecification::IsActorCreationTask() const { return !ActorCreationId().IsNil(); } -bool TaskSpecification::IsActorTask() const { return !ActorId().is_nil(); } +bool TaskSpecification::IsActorTask() const { return !ActorId().IsNil(); } ActorID TaskSpecification::ActorCreationId() const { auto message = flatbuffers::GetRoot(spec_.data()); diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 6d0cfa37017a..1e26cb33bf82 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -10,21 +10,21 @@ namespace raylet { void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { // Round trip test for computing the object ID for a task's return value, // then computing the task ID that created the object. 
- ObjectID return_id = ObjectID::for_task_return(task_id, return_index); - ASSERT_EQ(return_id.task_id(), task_id); - ASSERT_EQ(return_id.object_index(), return_index); + ObjectID return_id = ObjectID::ForTaskReturn(task_id, return_index); + ASSERT_EQ(return_id.TaskId(), task_id); + ASSERT_EQ(return_id.ObjectIndex(), return_index); } void TestTaskPutId(const TaskID &task_id, int64_t put_index) { // Round trip test for computing the object ID for a task's put value, then // computing the task ID that created the object. - ObjectID put_id = ObjectID::for_put(task_id, put_index); - ASSERT_EQ(put_id.task_id(), task_id); - ASSERT_EQ(put_id.object_index(), -1 * put_index); + ObjectID put_id = ObjectID::ForPut(task_id, put_index); + ASSERT_EQ(put_id.TaskId(), task_id); + ASSERT_EQ(put_id.ObjectIndex(), -1 * put_index); } TEST(TaskSpecTest, TestTaskReturnIds) { - TaskID task_id = TaskID::from_random(); + TaskID task_id = TaskID::FromRandom(); // Check that we can compute between a task ID and the object IDs of its // return values and puts. @@ -37,25 +37,25 @@ TEST(TaskSpecTest, TestTaskReturnIds) { } TEST(IdPropertyTest, TestIdProperty) { - TaskID task_id = TaskID::from_random(); - ASSERT_EQ(task_id, TaskID::from_binary(task_id.binary())); - ObjectID object_id = ObjectID::from_random(); - ASSERT_EQ(object_id, ObjectID::from_binary(object_id.binary())); + TaskID task_id = TaskID::FromRandom(); + ASSERT_EQ(task_id, TaskID::FromBinary(task_id.Binary())); + ObjectID object_id = ObjectID::FromRandom(); + ASSERT_EQ(object_id, ObjectID::FromBinary(object_id.Binary())); - ASSERT_TRUE(TaskID().is_nil()); - ASSERT_TRUE(TaskID::nil().is_nil()); - ASSERT_TRUE(ObjectID().is_nil()); - ASSERT_TRUE(ObjectID::nil().is_nil()); + ASSERT_TRUE(TaskID().IsNil()); + ASSERT_TRUE(TaskID::Nil().IsNil()); + ASSERT_TRUE(ObjectID().IsNil()); + ASSERT_TRUE(ObjectID::Nil().IsNil()); } TEST(TaskSpecTest, TaskInfoSize) { - std::vector references = {ObjectID::from_random(), ObjectID::from_random()}; + std::vector references = {ObjectID::FromRandom(), ObjectID::FromRandom()}; auto arguments_1 = std::make_shared(references); std::string one_arg("This is an value argument."); auto arguments_2 = std::make_shared( reinterpret_cast(one_arg.c_str()), one_arg.size()); std::vector> task_arguments({arguments_1, arguments_2}); - auto task_id = TaskID::from_random(); + auto task_id = TaskID::FromRandom(); { flatbuffers::FlatBufferBuilder fbb; std::vector> arguments; @@ -64,10 +64,10 @@ TEST(TaskSpecTest, TaskInfoSize) { } // General task. auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, TaskID::from_random()), 0, to_flatbuf(fbb, ActorID::nil()), - to_flatbuf(fbb, ObjectID::nil()), 0, to_flatbuf(fbb, ActorID::nil()), - to_flatbuf(fbb, ActorHandleID::nil()), 0, + fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::FromRandom()), 0, to_flatbuf(fbb, ActorID::Nil()), + to_flatbuf(fbb, ObjectID::Nil()), 0, to_flatbuf(fbb, ActorID::Nil()), + to_flatbuf(fbb, ActorHandleID::Nil()), 0, ids_to_flatbuf(fbb, std::vector()), fbb.CreateVector(arguments), 1, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON, string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); @@ -83,13 +83,13 @@ TEST(TaskSpecTest, TaskInfoSize) { } // General task. 
auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, DriverID::from_random()), to_flatbuf(fbb, task_id), - to_flatbuf(fbb, TaskID::from_random()), 10, - to_flatbuf(fbb, ActorID::from_random()), to_flatbuf(fbb, ObjectID::from_random()), - 10000000, to_flatbuf(fbb, ActorID::from_random()), - to_flatbuf(fbb, ActorHandleID::from_random()), 20, - ids_to_flatbuf(fbb, std::vector( - {ObjectID::from_random(), ObjectID::from_random()})), + fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id), + to_flatbuf(fbb, TaskID::FromRandom()), 10, to_flatbuf(fbb, ActorID::FromRandom()), + to_flatbuf(fbb, ObjectID::FromRandom()), 10000000, + to_flatbuf(fbb, ActorID::FromRandom()), + to_flatbuf(fbb, ActorHandleID::FromRandom()), 20, + ids_to_flatbuf( + fbb, std::vector({ObjectID::FromRandom(), ObjectID::FromRandom()})), fbb.CreateVector(arguments), 2, map_to_flatbuf(fbb, {}), map_to_flatbuf(fbb, {}), Language::PYTHON, string_vec_to_flatbuf(fbb, {"PackageName", "ClassName", "FunctionName"})); diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index c6686e4b6f6b..36bfc6d846b9 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -57,9 +57,9 @@ void Worker::AssignDriverId(const DriverID &driver_id) { const DriverID &Worker::GetAssignedDriverId() const { return assigned_driver_id_; } void Worker::AssignActorId(const ActorID &actor_id) { - RAY_CHECK(actor_id_.is_nil()) + RAY_CHECK(actor_id_.IsNil()) << "A worker that is already an actor cannot be assigned an actor ID again."; - RAY_CHECK(!actor_id.is_nil()); + RAY_CHECK(!actor_id.IsNil()); actor_id_ = actor_id; } diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 3138b88cf696..27e7fea05311 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -172,7 +172,7 @@ void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { } void WorkerPool::RegisterDriver(const std::shared_ptr &driver) { - RAY_CHECK(!driver->GetAssignedTaskId().is_nil()); + RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); auto &state = GetStateForLanguage(driver->GetLanguage()); state.registered_drivers.insert(std::move(driver)); } @@ -201,11 +201,11 @@ std::shared_ptr WorkerPool::GetRegisteredDriver( void WorkerPool::PushWorker(const std::shared_ptr &worker) { // Since the worker is now idle, unset its assigned task ID. - RAY_CHECK(worker->GetAssignedTaskId().is_nil()) + RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); // Add the worker to the idle pool. 
- if (worker->GetActorId().is_nil()) { + if (worker->GetActorId().IsNil()) { state.idle.insert(std::move(worker)); } else { state.idle_actor[worker->GetActorId()] = std::move(worker); @@ -216,7 +216,7 @@ std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec auto &state = GetStateForLanguage(task_spec.GetLanguage()); const auto &actor_id = task_spec.ActorId(); std::shared_ptr worker = nullptr; - if (actor_id.is_nil()) { + if (actor_id.IsNil()) { if (!state.idle.empty()) { worker = std::move(*state.idle.begin()); state.idle.erase(state.idle.begin()); diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 9799dfb80a40..143ffd57dda6 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -72,11 +72,11 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::nil(), + const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::nil(), TaskID::nil(), 0, ActorID::nil(), - ObjectID::nil(), 0, actor_id, ActorHandleID::nil(), 0, {}, {}, + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), + ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -155,7 +155,7 @@ TEST_F(WorkerPoolTest, PopActorWorker) { // Assign an actor ID to the worker. const auto task_spec = ExampleTaskSpec(); auto actor = worker_pool_.PopWorker(task_spec); - auto actor_id = ActorID::from_random(); + auto actor_id = ActorID::FromRandom(); actor->AssignActorId(actor_id); worker_pool_.PushWorker(actor); @@ -173,10 +173,10 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { auto py_worker = CreateWorker(1234, Language::PYTHON); worker_pool_.PushWorker(py_worker); // Check that no worker will be popped if the given task is a Java task - const auto java_task_spec = ExampleTaskSpec(ActorID::nil(), Language::JAVA); + const auto java_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::JAVA); ASSERT_EQ(worker_pool_.PopWorker(java_task_spec), nullptr); // Check that the worker can be popped if the given task is a Python task - const auto py_task_spec = ExampleTaskSpec(ActorID::nil(), Language::PYTHON); + const auto py_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON); ASSERT_NE(worker_pool_.PopWorker(py_task_spec), nullptr); // Create a Java Worker, and add it to the pool From 0066d7cf2ac723d8d47f1498f01eed7fcfda5f8e Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Fri, 31 May 2019 16:41:32 +0800 Subject: [PATCH 055/118] Hotfix for change of from_random to FromRandom (#4909) --- src/ray/core_worker/core_worker_test.cc | 2 +- src/ray/raylet/mock_gcs_client.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index b1be58da95b8..6711c874a973 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -16,7 +16,7 @@ class CoreWorkerTest : public ::testing::Test { TEST_F(CoreWorkerTest, TestTaskArg) { // Test by-reference argument. 
- ObjectID id = ObjectID::from_random(); + ObjectID id = ObjectID::FromRandom(); TaskArg by_ref = TaskArg::PassByReference(id); ASSERT_TRUE(by_ref.IsPassedByReference()); ASSERT_EQ(by_ref.GetReference(), id); diff --git a/src/ray/raylet/mock_gcs_client.cc b/src/ray/raylet/mock_gcs_client.cc index 69b197899b29..1a75b6593fe8 100644 --- a/src/ray/raylet/mock_gcs_client.cc +++ b/src/ray/raylet/mock_gcs_client.cc @@ -108,7 +108,7 @@ ray::Status ClientTable::Remove(const ClientID &client_id, DoneCallback done_cal } ClientID GcsClient::Register(const std::string &ip, uint16_t port) { - ClientID client_id = ClientID().from_random(); + ClientID client_id = ClientID::FromRandom(); // TODO: handle client registration failure. ray::Status status = client_table().Add(std::move(client_id), ip, port, []() {}); return client_id; From 1c073e92e4f23c7b61e16ad3d3b77c6c69ca35cc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sat, 1 Jun 2019 16:13:21 +0800 Subject: [PATCH 056/118] [rllib] Fix documentation on custom policies (#4910) * wip * add docs * lint * todo sections * fix doc --- ci/jenkins_tests/run_rllib_tests.sh | 6 +++ doc/source/rllib-concepts.rst | 33 +++++++++++-- doc/source/rllib.rst | 2 + python/ray/rllib/agents/trainer.py | 2 + python/ray/rllib/examples/custom_tf_policy.py | 47 +++++++++++++++++++ .../ray/rllib/examples/custom_torch_policy.py | 45 ++++++++++++++++++ 6 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 python/ray/rllib/examples/custom_tf_policy.py create mode 100644 python/ray/rllib/examples/custom_torch_policy.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 13acff28d39c..78fbf6a3ab46 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -389,6 +389,12 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2 + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2 + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 06e890832295..8556e419ae08 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -6,7 +6,7 @@ This page describes the internal concepts used to implement algorithms in RLlib. Policies -------- -Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `graph definition `__. +Policy classes encapsulate the core numerical components of RL algorithms. This typically includes the policy model that determines actions to take, a trajectory postprocessor for experiences, and a loss function to improve the policy given postprocessed experiences. For a simple example, see the policy gradients `policy definition `__. 
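That loss is essentially the negative log-likelihood of the sampled actions weighted by the observed rewards. A minimal sketch, matching the ``custom_tf_policy.py`` example added in this PR and assuming the usual ``SampleBatch`` and ``tf`` imports:

.. code-block:: python

    def policy_gradient_loss(policy, batch_tensors):
        actions = batch_tensors[SampleBatch.ACTIONS]
        rewards = batch_tensors[SampleBatch.REWARDS]
        # Minimize the negative of E[log pi(a|s) * R].
        return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards)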
Most interaction with deep learning frameworks is isolated to the `Policy interface `__, allowing RLlib to support multiple frameworks. To simplify the definition of policies, RLlib includes `Tensorflow <#building-policies-in-tensorflow>`__ and `PyTorch-specific <#building-policies-in-pytorch>`__ templates. You can also write your own from scratch. Here is an example: @@ -148,7 +148,7 @@ We can create a `Trainer <#trainers>`__ and try running this policy on a toy env tune.run(MyTrainer, config={"env": "CartPole-v0", "num_workers": 2}) -If you run the above snippet, you'll probably notice that CartPole doesn't learn so well: +If you run the above snippet `(runnable file here) `__, you'll probably notice that CartPole doesn't learn so well: .. code-block:: bash @@ -208,7 +208,7 @@ In the above section you saw how to compose a simple policy gradient algorithm w Besides some boilerplate for defining the PPO configuration and some warnings, there are two important arguments to take note of here: ``make_policy_optimizer=choose_policy_optimizer``, and ``after_optimizer_step=update_kl``. -The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer (the default), or a multi-GPU optimizer that implements minibatch SGD: +The ``choose_policy_optimizer`` function chooses which `Policy Optimizer <#policy-optimization>`__ to use for distributed training. You can think of these policy optimizers as coordinating the distributed workflow needed to improve the policy. Depending on the trainer config, PPO can switch between a simple synchronous optimizer, or a multi-GPU optimizer that implements minibatch SGD (the default): .. code-block:: python @@ -349,7 +349,27 @@ Finally, note that you do not have to use ``build_tf_policy`` to define a Tensor Building Policies in PyTorch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: +Defining a policy in PyTorch is quite similar to that for TensorFlow (and the process of defining a trainer given a Torch policy is exactly the same). Here's a simple example of a trivial torch policy `(runnable file here) `__: + +.. code-block:: python + + from ray.rllib.policy.sample_batch import SampleBatch + from ray.rllib.policy.torch_policy_template import build_torch_policy + + def policy_gradient_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) + + # + MyTorchPolicy = build_torch_policy( + name="MyTorchPolicy", + loss_fn=policy_gradient_loss) + +Now, building on the TF examples above, let's look at how the `A3C torch policy `__ is defined: .. 
code-block:: python @@ -423,6 +443,11 @@ You can find the full policy definition in `a3c_torch_policy.py `__ + - `Extending Existing Policies `__ + * `Policy Evaluation `__ * `Policy Optimization `__ * `Trainers `__ diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index 4294affb1172..fb20f56baa21 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -100,6 +100,8 @@ "clip_actions": True, # Whether to use rllib or deepmind preprocessors by default "preprocessor_pref": "deepmind", + # The default learning rate + "lr": 0.0001, # === Evaluation === # Evaluate with every `evaluation_interval` training iterations. diff --git a/python/ray/rllib/examples/custom_tf_policy.py b/python/ray/rllib/examples/custom_tf_policy.py new file mode 100644 index 000000000000..0442dff83d71 --- /dev/null +++ b/python/ray/rllib/examples/custom_tf_policy.py @@ -0,0 +1,47 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +def policy_gradient_loss(policy, batch_tensors): + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + return -tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + + +# +MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, +) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 2, + }) diff --git a/python/ray/rllib/examples/custom_torch_policy.py b/python/ray/rllib/examples/custom_torch_policy.py new file mode 100644 index 000000000000..7ab2786cfb6e --- /dev/null +++ b/python/ray/rllib/examples/custom_torch_policy.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_policy_template import build_torch_policy + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +def policy_gradient_loss(policy, batch_tensors): + logits, _, values, _ = policy.model({ + SampleBatch.CUR_OBS: batch_tensors[SampleBatch.CUR_OBS] + }, []) + action_dist = policy.dist_class(logits) + log_probs = action_dist.logp(batch_tensors[SampleBatch.ACTIONS]) + return -batch_tensors[SampleBatch.REWARDS].dot(log_probs) + + +# +MyTorchPolicy = build_torch_policy( + name="MyTorchPolicy", loss_fn=policy_gradient_loss) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTorchPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 2, + }) From 9aa1cd613d2457176c731807b237d8747b539b1b Mon Sep 17 00:00:00 2001 From: Eric Liang 
Date: Sat, 1 Jun 2019 16:58:49 +0800 Subject: [PATCH 057/118] [rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (#4894) * fix torch extra out * preserve setitem * fix docs --- doc/source/rllib-concepts.rst | 2 +- .../ray/rllib/agents/a3c/a3c_torch_policy.py | 2 +- python/ray/rllib/policy/torch_policy.py | 30 ++++++++++++++----- .../ray/rllib/policy/torch_policy_template.py | 8 +++-- python/ray/rllib/utils/tracking_dict.py | 5 ++++ 5 files changed, 34 insertions(+), 13 deletions(-) diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 8556e419ae08..2f9603b69f58 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -413,7 +413,7 @@ Now, building on the TF examples above, let's look at how the `A3C torch policy .. code-block:: python - def model_value_predictions(policy, model_out): + def model_value_predictions(policy, input_dict, state_batches, model_out): return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} ``postprocess_fn`` and ``mixins``: Similar to the PPO example, we need access to the value function during postprocessing (i.e., ``add_advantages`` below calls ``policy._value()``. The value function is exposed through a mixin class that defines the method: diff --git a/python/ray/rllib/agents/a3c/a3c_torch_policy.py b/python/ray/rllib/agents/a3c/a3c_torch_policy.py index 6ccf6c48d35f..f11ff51bec27 100644 --- a/python/ray/rllib/agents/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_torch_policy.py @@ -53,7 +53,7 @@ def add_advantages(policy, policy.config["lambda"]) -def model_value_predictions(policy, model_out): +def model_value_predictions(policy, input_dict, state_batches, model_out): return {SampleBatch.VF_PREDS: model_out[2].cpu().numpy()} diff --git a/python/ray/rllib/policy/torch_policy.py b/python/ray/rllib/policy/torch_policy.py index 633e438c5ad7..045902621ec6 100644 --- a/python/ray/rllib/policy/torch_policy.py +++ b/python/ray/rllib/policy/torch_policy.py @@ -2,9 +2,9 @@ from __future__ import division from __future__ import print_function +import numpy as np import os -import numpy as np from threading import Lock try: @@ -69,15 +69,21 @@ def compute_actions(self, **kwargs): with self.lock: with torch.no_grad(): - ob = torch.from_numpy(np.array(obs_batch)) \ - .float().to(self.device) - model_out = self._model({"obs": ob}, state_batches) + input_dict = self._lazy_tensor_dict({ + "obs": obs_batch, + }) + if prev_action_batch: + input_dict["prev_actions"] = prev_action_batch + if prev_reward_batch: + input_dict["prev_rewards"] = prev_reward_batch + model_out = self._model(input_dict, state_batches) logits, _, vf, state = model_out action_dist = self._action_dist_cls(logits) actions = action_dist.sample() return (actions.cpu().numpy(), [h.cpu().numpy() for h in state], - self.extra_action_out(model_out)) + self.extra_action_out(input_dict, state_batches, + model_out)) @override(Policy) def learn_on_batch(self, postprocessed_batch): @@ -146,10 +152,12 @@ def extra_grad_process(self): return processing info.""" return {} - def extra_action_out(self, model_out): + def extra_action_out(self, input_dict, state_batches, model_out): """Returns dict of extra info to include in experience batch. Arguments: + input_dict (dict): Dict of model input tensors. + state_batches (list): List of state tensors. 
model_out (list): Outputs of the policy model module.""" return {} @@ -168,6 +176,12 @@ def optimizer(self): def _lazy_tensor_dict(self, postprocessed_batch): batch_tensors = UsageTrackingDict(postprocessed_batch) - batch_tensors.set_get_interceptor( - lambda arr: torch.from_numpy(arr).to(self.device)) + + def convert(arr): + tensor = torch.from_numpy(np.asarray(arr)) + if tensor.dtype == torch.double: + tensor = tensor.float() + return tensor.to(self.device) + + batch_tensors.set_get_interceptor(convert) return batch_tensors diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 049591c04671..19e943600210 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -108,11 +108,13 @@ def extra_grad_process(self): return TorchPolicy.extra_grad_process(self) @override(TorchPolicy) - def extra_action_out(self, model_out): + def extra_action_out(self, input_dict, state_batches, model_out): if extra_action_out_fn: - return extra_action_out_fn(self, model_out) + return extra_action_out_fn(self, input_dict, state_batches, + model_out) else: - return TorchPolicy.extra_action_out(self, model_out) + return TorchPolicy.extra_action_out(self, input_dict, + state_batches, model_out) @override(TorchPolicy) def optimizer(self): diff --git a/python/ray/rllib/utils/tracking_dict.py b/python/ray/rllib/utils/tracking_dict.py index c0f145734e78..9b64925dc251 100644 --- a/python/ray/rllib/utils/tracking_dict.py +++ b/python/ray/rllib/utils/tracking_dict.py @@ -30,3 +30,8 @@ def __getitem__(self, key): self.intercepted_values[key] = self.get_interceptor(value) value = self.intercepted_values[key] return value + + def __setitem__(self, key, value): + dict.__setitem__(self, key, value) + if key in self.intercepted_values: + self.intercepted_values[key] = value From 88bab5d3c4edd949b94c185deb9378135169456a Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sat, 1 Jun 2019 12:38:59 -0700 Subject: [PATCH 058/118] [tune] Pretty print params json in logger.py (#4903) --- python/ray/tune/logger.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/logger.py b/python/ray/tune/logger.py index 4b9d5a914aa1..895f4819e0c0 100644 --- a/python/ray/tune/logger.py +++ b/python/ray/tune/logger.py @@ -95,7 +95,12 @@ def update_config(self, config): self.config = config config_out = os.path.join(self.logdir, "params.json") with open(config_out, "w") as f: - json.dump(self.config, f, cls=_SafeFallbackEncoder) + json.dump( + self.config, + f, + indent=2, + sort_keys=True, + cls=_SafeFallbackEncoder) config_pkl = os.path.join(self.logdir, "params.pkl") with open(config_pkl, "wb") as f: cloudpickle.dump(self.config, f) From c2ade075a3fc912a099b68851879f293cdbb63d8 Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Sat, 1 Jun 2019 21:39:22 -0700 Subject: [PATCH 059/118] [sgd] Distributed Training via PyTorch (#4797) Implements distributed SGD using distributed PyTorch. 
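A minimal usage sketch of the API this adds (the creator functions are user-supplied placeholders, mirroring the new tests under sgd/tests):

    from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources

    # model_creator:     dict -> torch.nn.Module
    # data_creator:      dict -> (training Dataset, validation Dataset)
    # optimizer_creator: (model, dict) -> (criterion, optimizer)
    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator,
        num_replicas=2,
        resources_per_replica=Resources(num_cpus=1),
        batch_size=16,
        backend="auto")

    stats = trainer.train()        # one distributed epoch across all replicas
    val_stats = trainer.validate()
    trainer.shutdown()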
--- ci/jenkins_tests/run_multi_node_tests.sh | 23 +- doc/source/conf.py | 4 + doc/source/distributed_training.rst | 48 ++++ doc/source/index.rst | 3 +- .../ray/experimental/sgd/pytorch/__init__.py | 8 + .../sgd/pytorch/pytorch_runner.py | 182 +++++++++++++ .../sgd/pytorch/pytorch_trainer.py | 150 +++++++++++ python/ray/experimental/sgd/pytorch/utils.py | 240 ++++++++++++++++++ python/ray/experimental/sgd/tests/__init__.py | 0 .../experimental/sgd/tests/pytorch_utils.py | 40 +++ .../experimental/sgd/tests/test_pytorch.py | 76 ++++++ 11 files changed, 751 insertions(+), 23 deletions(-) create mode 100644 doc/source/distributed_training.rst create mode 100644 python/ray/experimental/sgd/pytorch/__init__.py create mode 100644 python/ray/experimental/sgd/pytorch/pytorch_runner.py create mode 100644 python/ray/experimental/sgd/pytorch/pytorch_trainer.py create mode 100644 python/ray/experimental/sgd/pytorch/utils.py create mode 100644 python/ray/experimental/sgd/tests/__init__.py create mode 100644 python/ray/experimental/sgd/tests/pytorch_utils.py create mode 100644 python/ray/experimental/sgd/tests/test_pytorch.py diff --git a/ci/jenkins_tests/run_multi_node_tests.sh b/ci/jenkins_tests/run_multi_node_tests.sh index 1ab086792ca0..a07e36f87d57 100755 --- a/ci/jenkins_tests/run_multi_node_tests.sh +++ b/ci/jenkins_tests/run_multi_node_tests.sh @@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \ ######################## SGD TESTS ################################# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ - --batch-size=1 --strategy=simple - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \ - --batch-size=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \ - --batch-size=1 --strategy=simple - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \ - --batch-size=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ - --num-workers=1 --devices-per-worker=1 --strategy=ps - -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \ - --num-workers=1 --devices-per-worker=1 --strategy=ps --tune + python -m pytest /ray/python/ray/experimental/sgd/tests diff --git a/doc/source/conf.py b/doc/source/conf.py index e0bd2c6dad4c..b0ae3416d4ab 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -53,6 +53,10 @@ "tensorflow.python", "tensorflow.python.client", "tensorflow.python.util", + "torch", + "torch.distributed", + "torch.nn", + "torch.utils.data", ] for mod_name in MOCK_MODULES: sys.modules[mod_name] = mock.Mock() diff --git a/doc/source/distributed_training.rst b/doc/source/distributed_training.rst new file mode 100644 index 000000000000..61f3442eb181 --- /dev/null +++ b/doc/source/distributed_training.rst @@ -0,0 +1,48 @@ +Distributed Training (Experimental) +=================================== + + +Ray includes 
abstractions for distributed model training that integrate with +deep learning frameworks, such as PyTorch. + +Ray Train is built on top of the Ray task and actor abstractions to provide +seamless integration into existing Ray applications. + +PyTorch Interface +----------------- + +To use Ray Train with PyTorch, pass model and data creator functions to the +``ray.experimental.sgd.pytorch.PyTorchTrainer`` class. +To drive the distributed training, ``trainer.train()`` can be called +repeatedly. + +.. code-block:: python + + model_creator = lambda config: YourPyTorchModel() + data_creator = lambda config: YourTrainingSet(), YourValidationSet() + + trainer = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator=utils.sgd_mse_optimizer, + config={"lr": 1e-4}, + num_replicas=2, + resources_per_replica=Resources(num_gpus=1), + batch_size=16, + backend="auto") + + for i in range(NUM_EPOCHS): + trainer.train() + +Under the hood, Ray Train will create *replicas* of your model +(controlled by ``num_replicas``) which are each managed by a worker. +Multiple devices (e.g. GPUs) can be managed by each replica (controlled by ``resources_per_replica``), +which allows training of lage models across multiple GPUs. +The ``PyTorchTrainer`` class coordinates the distributed computation and training to improve the model. + +The full documentation for ``PyTorchTrainer`` is as follows: + +.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer + :members: + + .. automethod:: __init__ diff --git a/doc/source/index.rst b/doc/source/index.rst index a90e0224bb02..a8efb7a537dc 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -42,7 +42,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin - `Tune`_: Scalable Hyperparameter Search - `RLlib`_: Scalable Reinforcement Learning -- `Distributed Training `__ +- `Distributed Training `__ .. _`Tune`: tune.html .. _`RLlib`: rllib.html @@ -107,6 +107,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin :maxdepth: 1 :caption: Other Libraries + distributed_training.rst distributed_sgd.rst pandas_on_ray.rst diff --git a/python/ray/experimental/sgd/pytorch/__init__.py b/python/ray/experimental/sgd/pytorch/__init__.py new file mode 100644 index 000000000000..74a33016d88b --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer +from ray.experimental.sgd.pytorch.utils import Resources + +__all__ = ["PyTorchTrainer", "Resources"] diff --git a/python/ray/experimental/sgd/pytorch/pytorch_runner.py b/python/ray/experimental/sgd/pytorch/pytorch_runner.py new file mode 100644 index 000000000000..5fe4ba1009f9 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/pytorch_runner.py @@ -0,0 +1,182 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import torch +import torch.distributed as dist +import torch.utils.data + +import ray +from ray.experimental.sgd.pytorch import utils + +logger = logging.getLogger(__name__) + + +class PyTorchRunner(object): + """Manages a distributed PyTorch model replica""" + + def __init__(self, + model_creator, + data_creator, + optimizer_creator, + config=None, + batch_size=16, + backend="gloo"): + """Initializes the runner. 
+ + Args: + model_creator (dict -> torch.nn.Module): creates the model using + the config. + data_creator (dict -> Dataset, Dataset): creates the training and + validation data sets using the config. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + creates the loss and optimizer using the model and the config. + config (dict): configuration passed to 'model_creator', + 'data_creator', and 'optimizer_creator'. + batch_size (int): batch size used in an update. + backend (string): backend used by distributed PyTorch. + """ + + self.model_creator = model_creator + self.data_creator = data_creator + self.optimizer_creator = optimizer_creator + self.config = {} if config is None else config + self.batch_size = batch_size + self.backend = backend + self.verbose = True + + self.epoch = 0 + self._timers = { + k: utils.TimerStat(window_size=1) + for k in [ + "setup_proc", "setup_model", "get_state", "set_state", + "validation", "training" + ] + } + + def setup(self, url, world_rank, world_size): + """Connects to the distributed PyTorch backend and initializes the model. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. + """ + self._setup_distributed_pytorch(url, world_rank, world_size) + self._setup_training() + + def _setup_distributed_pytorch(self, url, world_rank, world_size): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + with self._timers["setup_proc"]: + self.world_rank = world_rank + logger.debug( + "Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(self.backend)) + dist.init_process_group( + backend=self.backend, + init_method=url, + rank=world_rank, + world_size=world_size) + + def _setup_training(self): + logger.debug("Creating model") + self.model = self.model_creator(self.config) + if torch.cuda.is_available(): + self.model = torch.nn.parallel.DistributedDataParallel( + self.model.cuda()) + else: + self.model = torch.nn.parallel.DistributedDataParallelCPU( + self.model) + + logger.debug("Creating optimizer") + self.criterion, self.optimizer = self.optimizer_creator( + self.model, self.config) + + if torch.cuda.is_available(): + self.criterion = self.criterion.cuda() + + logger.debug("Creating dataset") + self.training_set, self.validation_set = self.data_creator(self.config) + + # TODO: make num_workers configurable + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_set) + self.train_loader = torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=(self.train_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.train_sampler) + + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.validation_set)) + self.validation_loader = torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=(self.validation_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.validation_sampler) + + def get_node_ip(self): + """Returns the IP address of the current node""" + return ray.services.get_node_ip_address() + + def step(self): + """Runs a training epoch and updates the model parameters""" + logger.debug("Starting step") + self.train_sampler.set_epoch(self.epoch) + + logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) + with self._timers["training"]: + train_stats = utils.train(self.train_loader, self.model, + self.criterion, 
self.optimizer) + train_stats["epoch"] = self.epoch + + self.epoch += 1 + + train_stats.update(self.stats()) + return train_stats + + def validate(self): + """Evaluates the model on the validation data set""" + with self._timers["validation"]: + validation_stats = utils.validate(self.validation_loader, + self.model, self.criterion) + + validation_stats.update(self.stats()) + return validation_stats + + def stats(self): + """Returns a dictionary of statistics collected""" + stats = {"epoch": self.epoch} + for k, t in self._timers.items(): + stats[k + "_time_mean"] = t.mean + stats[k + "_time_total"] = t.sum + t.reset() + return stats + + def get_state(self): + """Returns the state of the runner""" + return { + "epoch": self.epoch, + "model": self.model.state_dict(), + "optimizer": self.optimizer.state_dict(), + "stats": self.stats() + } + + def set_state(self, state): + """Sets the state of the model""" + # TODO: restore timer stats + self.model.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.epoch = state["stats"]["epoch"] + + def shutdown(self): + """Attempts to shut down the worker""" + dist.destroy_process_group() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py new file mode 100644 index 000000000000..073ad3d34042 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py @@ -0,0 +1,150 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import sys +import torch +import logging + +import ray + +from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner +from ray.experimental.sgd.pytorch import utils + +logger = logging.getLogger(__name__) + + +class PyTorchTrainer(object): + """Train a PyTorch model using distributed PyTorch. + + Launches a set of actors which connect via distributed PyTorch and + coordinate gradient updates to train the provided model. + """ + + def __init__(self, + model_creator, + data_creator, + optimizer_creator=utils.sgd_mse_optimizer, + config=None, + num_replicas=1, + resources_per_replica=None, + batch_size=16, + backend="auto"): + """Sets up the PyTorch trainer. + + Args: + model_creator (dict -> torch.nn.Module): creates the model + using the config. + data_creator (dict -> Dataset, Dataset): creates the training + and validation data sets using the config. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + creates the loss and optimizer using the model and the config. + config (dict): configuration passed to 'model_creator', + 'data_creator', and 'optimizer_creator'. + num_replicas (int): the number of workers used in distributed + training. + resources_per_replica (Resources): resources used by each worker. + Defaults to Resources(num_cpus=1). + batch_size (int): batch size for an update. + backend (string): backend used by distributed PyTorch. + """ + # TODO: add support for mixed precision + # TODO: add support for callbacks + if sys.platform == "darwin": + raise Exception( + ("Distributed PyTorch is not supported on macOS. 
For more " + "information, see " + "https://github.com/pytorch/examples/issues/467.")) + + self.model_creator = model_creator + self.config = {} if config is None else config + self.optimizer_timer = utils.TimerStat(window_size=1) + + if resources_per_replica is None: + resources_per_replica = utils.Resources( + num_cpus=1, num_gpus=0, resources={}) + + if backend == "auto": + backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo" + + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)(PyTorchRunner) + + batch_size_per_replica = batch_size // num_replicas + if batch_size % num_replicas > 0: + new_batch_size = batch_size_per_replica * num_replicas + logger.warn( + ("Changing batch size from {old_batch_size} to " + "{new_batch_size} to evenly distribute batches across " + "{num_replicas} replicas.").format( + old_batch_size=batch_size, + new_batch_size=new_batch_size, + num_replicas=num_replicas)) + + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size_per_replica, backend) + for i in range(num_replicas) + ] + + ip = ray.get(self.workers[0].get_node_ip.remote()) + port = utils.find_free_port() + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + + # Get setup tasks in order to throw errors on failure + ray.get([ + worker.setup.remote(address, i, len(self.workers)) + for i, worker in enumerate(self.workers) + ]) + + def train(self): + """Runs a training epoch""" + with self.optimizer_timer: + worker_stats = ray.get([w.step.remote() for w in self.workers]) + + train_stats = worker_stats[0].copy() + train_stats["train_loss"] = np.mean( + [s["train_loss"] for s in worker_stats]) + return train_stats + + def validate(self): + """Evaluates the model on the validation data set""" + worker_stats = ray.get([w.validate.remote() for w in self.workers]) + validation_stats = worker_stats[0].copy() + validation_stats["validation_loss"] = np.mean( + [s["validation_loss"] for s in worker_stats]) + return validation_stats + + def get_model(self): + """Returns the learned model""" + model = self.model_creator(self.config) + state = ray.get(self.workers[0].get_state.remote()) + + # Remove module. 
prefix added by distrbuted pytorch + state_dict = { + k.replace("module.", ""): v + for k, v in state["model"].items() + } + + model.load_state_dict(state_dict) + return model + + def save(self, ckpt): + """Saves the model at the provided checkpoint""" + state = ray.get(self.workers[0].get_state.remote()) + torch.save(state, ckpt) + + def restore(self, ckpt): + """Restores the model from the provided checkpoint""" + state = torch.load(ckpt) + state_id = ray.put(state) + ray.get([worker.set_state.remote(state_id) for worker in self.workers]) + + def shutdown(self): + """Shuts down workers and releases resources""" + for worker in self.workers: + worker.shutdown.remote() + worker.__ray_terminate__.remote() diff --git a/python/ray/experimental/sgd/pytorch/utils.py b/python/ray/experimental/sgd/pytorch/utils.py new file mode 100644 index 000000000000..f7c6e4abac97 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/utils.py @@ -0,0 +1,240 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +from contextlib import closing +import numpy as np +import socket +import time +import torch +import torch.nn as nn + + +def train(train_iterator, model, criterion, optimizer): + """Runs 1 training epoch""" + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + + timers = {k: TimerStat() for k in ["d2h", "fwd", "grad", "apply"]} + + # switch to train mode + model.train() + + end = time.time() + + for i, (features, target) in enumerate(train_iterator): + # measure data loading time + data_time.update(time.time() - end) + + # Create non_blocking tensors for distributed training + with timers["d2h"]: + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + with timers["fwd"]: + output = model(features) + loss = criterion(output, target) + + # measure accuracy and record loss + losses.update(loss.item(), features.size(0)) + + with timers["grad"]: + # compute gradients in a backward pass + optimizer.zero_grad() + loss.backward() + + with timers["apply"]: + # Call step of optimizer to update model params + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + stats = { + "batch_time": batch_time.avg, + "batch_processed": losses.count, + "train_loss": losses.avg, + "data_time": data_time.avg, + } + stats.update({k: t.mean for k, t in timers.items()}) + return stats + + +def validate(val_loader, model, criterion): + batch_time = AverageMeter() + losses = AverageMeter() + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (features, target) in enumerate(val_loader): + + if torch.cuda.is_available(): + features = features.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(features) + loss = criterion(output, target) + + # measure accuracy and record loss + losses.update(loss.item(), features.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + stats = {"batch_time": batch_time.avg, "validation_loss": losses.avg} + return stats + + +class TimerStat(object): + """A running stat for conveniently logging the duration of a code block. + + Note that this class is *not* thread-safe. + + Examples: + Time a call to 'time.sleep'. + + >>> import time + >>> sleep_timer = TimerStat() + >>> with sleep_timer: + ... 
time.sleep(1) + >>> round(sleep_timer.mean) + 1 + """ + + def __init__(self, window_size=10): + self._window_size = window_size + self._samples = [] + self._units_processed = [] + self._start_time = None + self._total_time = 0.0 + self.count = 0 + + def __enter__(self): + assert self._start_time is None, "concurrent updates not supported" + self._start_time = time.time() + + def __exit__(self, type, value, tb): + assert self._start_time is not None + time_delta = time.time() - self._start_time + self.push(time_delta) + self._start_time = None + + def push(self, time_delta): + self._samples.append(time_delta) + if len(self._samples) > self._window_size: + self._samples.pop(0) + self.count += 1 + self._total_time += time_delta + + def push_units_processed(self, n): + self._units_processed.append(n) + if len(self._units_processed) > self._window_size: + self._units_processed.pop(0) + + @property + def mean(self): + return np.mean(self._samples) + + @property + def median(self): + return np.median(self._samples) + + @property + def sum(self): + return np.sum(self._samples) + + @property + def max(self): + return np.max(self._samples) + + @property + def first(self): + return self._samples[0] if self._samples else None + + @property + def last(self): + return self._samples[-1] if self._samples else None + + @property + def size(self): + return len(self._samples) + + @property + def mean_units_processed(self): + return float(np.mean(self._units_processed)) + + @property + def mean_throughput(self): + time_total = sum(self._samples) + if not time_total: + return 0.0 + return sum(self._units_processed) / time_total + + def reset(self): + self._samples = [] + self._units_processed = [] + self._start_time = None + self._total_time = 0.0 + self.count = 0 + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class Resources( + namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])): + __slots__ = () + + def __new__(cls, num_cpus=1, num_gpus=0, resources=None): + if resources is None: + resources = {} + + return super(Resources, cls).__new__(cls, num_cpus, num_gpus, + resources) + + +def sgd_mse_optimizer(model, config): + """Returns the mean squared error criterion and SGD optimizer. + + Args: + model (torch.nn.Module): the model to optimize. + config (dict): configuration for the optimizer. + lr (float): the learning rate. defaults to 0.01. 
+ """ + learning_rate = config.get("lr", 0.01) + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) + return criterion, optimizer diff --git a/python/ray/experimental/sgd/tests/__init__.py b/python/ray/experimental/sgd/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/experimental/sgd/tests/pytorch_utils.py b/python/ray/experimental/sgd/tests/pytorch_utils.py new file mode 100644 index 000000000000..6299fff1c13c --- /dev/null +++ b/python/ray/experimental/sgd/tests/pytorch_utils.py @@ -0,0 +1,40 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.data + + +class LinearDataset(torch.utils.data.Dataset): + """y = a * x + b""" + + def __init__(self, a, b, size=1000): + x = np.random.random(size).astype(np.float32) * 10 + x = np.arange(0, 10, 10 / size, dtype=np.float32) + self.x = torch.from_numpy(x) + self.y = torch.from_numpy(a * x + b) + + def __getitem__(self, index): + return self.x[index, None], self.y[index, None] + + def __len__(self): + return len(self.x) + + +def model_creator(config): + return nn.Linear(1, 1) + + +def optimizer_creator(model, config): + """Returns criterion, optimizer""" + criterion = nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + return criterion, optimizer + + +def data_creator(config): + """Returns training set, validation set""" + return LinearDataset(2, 5), LinearDataset(2, 5, size=400) diff --git a/python/ray/experimental/sgd/tests/test_pytorch.py b/python/ray/experimental/sgd/tests/test_pytorch.py new file mode 100644 index 000000000000..faff23f8a809 --- /dev/null +++ b/python/ray/experimental/sgd/tests/test_pytorch.py @@ -0,0 +1,76 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import sys +import tempfile +import torch + +from ray.tests.conftest import ray_start_2_cpus # noqa: F401 +from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources + +from ray.experimental.sgd.tests.pytorch_utils import ( + model_creator, optimizer_creator, data_creator) + + +@pytest.mark.skipif( # noqa: F811 + sys.platform == "darwin", reason="Doesn't work on macOS.") +def test_train(ray_start_2_cpus): # noqa: F811 + trainer = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + resources_per_replica=Resources(num_cpus=1)) + train_loss1 = trainer.train()["train_loss"] + validation_loss1 = trainer.validate()["validation_loss"] + + train_loss2 = trainer.train()["train_loss"] + validation_loss2 = trainer.validate()["validation_loss"] + + print(train_loss1, train_loss2) + print(validation_loss1, validation_loss2) + + assert train_loss2 <= train_loss1 + assert validation_loss2 <= validation_loss1 + + +@pytest.mark.skipif( # noqa: F811 + sys.platform == "darwin", reason="Doesn't work on macOS.") +def test_save_and_restore(ray_start_2_cpus): # noqa: F811 + trainer1 = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + resources_per_replica=Resources(num_cpus=1)) + trainer1.train() + + filename = os.path.join(tempfile.mkdtemp(), "checkpoint") + trainer1.save(filename) + + model1 = trainer1.get_model() + + trainer1.shutdown() + + trainer2 = PyTorchTrainer( + model_creator, + data_creator, + optimizer_creator, + num_replicas=2, + 
resources_per_replica=Resources(num_cpus=1)) + trainer2.restore(filename) + + os.remove(filename) + + model2 = trainer2.get_model() + + model1_state_dict = model1.state_dict() + model2_state_dict = model2.state_dict() + + assert set(model1_state_dict.keys()) == set(model2_state_dict.keys()) + + for k in model1_state_dict: + assert torch.equal(model1_state_dict[k], model2_state_dict[k]) From 665d081fe914eca4806413253aa02c679368785d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 2 Jun 2019 14:14:31 +0800 Subject: [PATCH 060/118] [rllib] Rough port of DQN to build_tf_policy() pattern (#4823) --- python/ray/rllib/agents/a3c/a3c_tf_policy.py | 224 +++---- python/ray/rllib/agents/dqn/dqn_policy.py | 555 +++++++++--------- python/ray/rllib/agents/ppo/appo_policy.py | 7 +- python/ray/rllib/agents/ppo/ppo_policy.py | 22 +- .../rllib/evaluation/tf_policy_template.py | 146 ----- python/ray/rllib/policy/dynamic_tf_policy.py | 102 ++-- python/ray/rllib/policy/tf_policy.py | 41 +- python/ray/rllib/policy/tf_policy_template.py | 76 ++- 8 files changed, 547 insertions(+), 626 deletions(-) delete mode 100644 python/ray/rllib/evaluation/tf_policy_template.py diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py index eb5becceaa71..ed3676472850 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -4,20 +4,13 @@ from __future__ import division from __future__ import print_function -import gym - import ray -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages, \ Postprocessing -from ray.rllib.policy.tf_policy import TFPolicy, \ - LearningRateSchedule -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.utils.annotations import override +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.policy.tf_policy import LearningRateSchedule from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -44,144 +37,97 @@ def __init__(self, self.entropy * entropy_coeff) -class A3CPostprocessing(object): - """Adds the VF preds and advantages fields to the trajectory.""" - - @override(TFPolicy) - def extra_compute_action_fetches(self): - return dict( - TFPolicy.extra_compute_action_fetches(self), - **{SampleBatch.VF_PREDS: self.vf}) - - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - completed = sample_batch[SampleBatch.DONES][-1] - if completed: - last_r = 0.0 - else: - next_state = [] - for i in range(len(self.model.state_in)): - next_state.append([sample_batch["state_out_{}".format(i)][-1]]) - last_r = self._value(sample_batch[SampleBatch.NEXT_OBS][-1], - sample_batch[SampleBatch.ACTIONS][-1], - sample_batch[SampleBatch.REWARDS][-1], - *next_state) - return compute_advantages(sample_batch, last_r, self.config["gamma"], - self.config["lambda"]) - - -class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy): - def __init__(self, observation_space, action_space, config): - config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config) - self.config = config - self.sess = tf.get_default_session() - - # Setup the policy - self.observations = tf.placeholder( - tf.float32, [None] + 
list(observation_space.shape)) - dist_class, logit_dim = ModelCatalog.get_action_dist( - action_space, self.config["model"]) - self.prev_actions = ModelCatalog.get_action_placeholder(action_space) - self.prev_rewards = tf.placeholder( - tf.float32, [None], name="prev_reward") - self.model = ModelCatalog.get_model({ - "obs": self.observations, - "prev_actions": self.prev_actions, - "prev_rewards": self.prev_rewards, - "is_training": self._get_is_training_placeholder(), - }, observation_space, action_space, logit_dim, self.config["model"]) - action_dist = dist_class(self.model.outputs) +def actor_critic_loss(policy, batch_tensors): + policy.loss = A3CLoss( + policy.action_dist, batch_tensors[SampleBatch.ACTIONS], + batch_tensors[Postprocessing.ADVANTAGES], + batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf, + policy.config["vf_loss_coeff"], policy.config["entropy_coeff"]) + return policy.loss.total_loss + + +def postprocess_advantages(policy, + sample_batch, + other_agent_batches=None, + episode=None): + completed = sample_batch[SampleBatch.DONES][-1] + if completed: + last_r = 0.0 + else: + next_state = [] + for i in range(len(policy.model.state_in)): + next_state.append([sample_batch["state_out_{}".format(i)][-1]]) + last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1], + sample_batch[SampleBatch.ACTIONS][-1], + sample_batch[SampleBatch.REWARDS][-1], + *next_state) + return compute_advantages(sample_batch, last_r, policy.config["gamma"], + policy.config["lambda"]) + + +def add_value_function_fetch(policy): + return {SampleBatch.VF_PREDS: policy.vf} + + +class ValueNetworkMixin(object): + def __init__(self): self.vf = self.model.value_function() - self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - tf.get_variable_scope().name) - - # Setup the policy loss - if isinstance(action_space, gym.spaces.Box): - ac_size = action_space.shape[0] - actions = tf.placeholder(tf.float32, [None, ac_size], name="ac") - elif isinstance(action_space, gym.spaces.Discrete): - actions = tf.placeholder(tf.int64, [None], name="ac") - else: - raise UnsupportedSpaceException( - "Action space {} is not supported for A3C.".format( - action_space)) - advantages = tf.placeholder(tf.float32, [None], name="advantages") - self.v_target = tf.placeholder(tf.float32, [None], name="v_target") - self.loss = A3CLoss(action_dist, actions, advantages, self.v_target, - self.vf, self.config["vf_loss_coeff"], - self.config["entropy_coeff"]) - - # Initialize TFPolicy - loss_in = [ - (SampleBatch.CUR_OBS, self.observations), - (SampleBatch.ACTIONS, actions), - (SampleBatch.PREV_ACTIONS, self.prev_actions), - (SampleBatch.PREV_REWARDS, self.prev_rewards), - (Postprocessing.ADVANTAGES, advantages), - (Postprocessing.VALUE_TARGETS, self.v_target), - ] - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicy.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=self.observations, - action_sampler=action_dist.sample(), - action_prob=action_dist.sampled_action_prob(), - loss=self.loss.total_loss, - model=self.model, - loss_inputs=loss_in, - state_inputs=self.model.state_in, - state_outputs=self.model.state_out, - prev_action_input=self.prev_actions, - prev_reward_input=self.prev_rewards, - seq_lens=self.model.seq_lens, - max_seq_len=self.config["model"]["max_seq_len"]) - - self.stats_fetches = { - LEARNER_STATS_KEY: { - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.loss.pi_loss, - "policy_entropy": self.loss.entropy, - 
"grad_gnorm": tf.global_norm(self._grads), - "var_gnorm": tf.global_norm(self.var_list), - "vf_loss": self.loss.vf_loss, - "vf_explained_var": explained_variance(self.v_target, self.vf), - }, - } - - self.sess.run(tf.global_variables_initializer()) - - @override(Policy) - def get_initial_state(self): - return self.model.state_init - - @override(TFPolicy) - def gradients(self, optimizer, loss): - grads = tf.gradients(loss, self.var_list) - self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads = list(zip(self.grads, self.var_list)) - return clipped_grads - - @override(TFPolicy) - def extra_compute_grad_fetches(self): - return self.stats_fetches def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self.observations: [ob], - self.prev_actions: [prev_action], - self.prev_rewards: [prev_reward], + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action], + self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self.sess.run(self.vf, feed_dict) + vf = self.get_session().run(self.vf, feed_dict) return vf[0] + + +def stats(policy, batch_tensors): + return { + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + "policy_loss": policy.loss.pi_loss, + "policy_entropy": policy.loss.entropy, + "var_gnorm": tf.global_norm(policy.var_list), + "vf_loss": policy.loss.vf_loss, + } + + +def grad_stats(policy, grads): + return { + "grad_gnorm": tf.global_norm(grads), + "vf_explained_var": explained_variance( + policy.get_placeholder(Postprocessing.VALUE_TARGETS), policy.vf), + } + + +def clip_gradients(policy, optimizer, loss): + grads = tf.gradients(loss, policy.var_list) + grads, _ = tf.clip_by_global_norm(grads, policy.config["grad_clip"]) + clipped_grads = list(zip(grads, policy.var_list)) + return clipped_grads + + +def setup_mixins(policy, obs_space, action_space, config): + ValueNetworkMixin.__init__(policy) + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + policy.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, + tf.get_variable_scope().name) + + +A3CTFPolicy = build_tf_policy( + name="A3CTFPolicy", + get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, + loss_fn=actor_critic_loss, + stats_fn=stats, + grad_stats_fn=grad_stats, + gradients_fn=clip_gradients, + postprocess_fn=postprocess_advantages, + extra_action_fetches_fn=add_value_function_fetch, + before_loss_init=setup_mixins, + mixins=[ValueNetworkMixin, LearningRateSchedule]) diff --git a/python/ray/rllib/agents/dqn/dqn_policy.py b/python/ray/rllib/agents/dqn/dqn_policy.py index a1affa947a43..505930406fbb 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy.py +++ b/python/ray/rllib/agents/dqn/dqn_policy.py @@ -7,14 +7,14 @@ from scipy.stats import entropy import ray +from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY from ray.rllib.models import ModelCatalog, Categorical from ray.rllib.utils.annotations import override from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy, \ LearningRateSchedule +from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils import try_import_tf tf = 
try_import_tf() @@ -102,46 +102,6 @@ def __init__(self, } -class DQNPostprocessing(object): - """Implements n-step learning and param noise adjustments.""" - - @override(TFPolicy) - def extra_compute_action_fetches(self): - return dict( - TFPolicy.extra_compute_action_fetches(self), **{ - "q_values": self.q_values, - }) - - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if self.config["parameter_noise"]: - # adjust the sigma of parameter space noise - states = [list(x) for x in sample_batch.columns(["obs"])][0] - - noisy_action_distribution = self.sess.run( - self.action_probs, feed_dict={self.cur_observations: states}) - self.sess.run(self.remove_noise_op) - clean_action_distribution = self.sess.run( - self.action_probs, feed_dict={self.cur_observations: states}) - distance_in_action_space = np.mean( - entropy(clean_action_distribution.T, - noisy_action_distribution.T)) - self.pi_distance = distance_in_action_space - if (distance_in_action_space < - -np.log(1 - self.cur_epsilon + - self.cur_epsilon / self.num_actions)): - self.parameter_noise_sigma_val *= 1.01 - else: - self.parameter_noise_sigma_val /= 1.01 - self.parameter_noise_sigma.load( - self.parameter_noise_sigma_val, session=self.sess) - - return _postprocess_dqn(self, sample_batch) - - class QNetwork(object): def __init__(self, model, @@ -345,98 +305,31 @@ def __init__(self, q_values, observations, num_actions, stochastic, eps, self.action_prob = None -class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy): - def __init__(self, observation_space, action_space, config): - config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config) - if not isinstance(action_space, Discrete): - raise UnsupportedSpaceException( - "Action space {} is not supported for DQN.".format( - action_space)) - - self.config = config +class ExplorationStateMixin(object): + def __init__(self, obs_space, action_space, config): self.cur_epsilon = 1.0 - self.num_actions = action_space.n - - # Action inputs self.stochastic = tf.placeholder(tf.bool, (), name="stochastic") self.eps = tf.placeholder(tf.float32, (), name="eps") - self.cur_observations = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - - # Action Q network - with tf.variable_scope(Q_SCOPE) as scope: - q_values, q_logits, q_dist, _ = self._build_q_network( - self.cur_observations, observation_space, action_space) - self.q_values = q_values - self.q_func_vars = _scope_vars(scope.name) - # Noise vars for Q network except for layer normalization vars + def add_parameter_noise(self): if self.config["parameter_noise"]: - self._build_parameter_noise([ - var for var in self.q_func_vars if "LayerNorm" not in var.name - ]) - self.action_probs = tf.nn.softmax(self.q_values) - - # Action outputs - self.output_actions, self.action_prob = self._build_q_value_policy( - q_values) - - # Replay inputs - self.obs_t = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - self.act_t = tf.placeholder(tf.int32, [None], name="action") - self.rew_t = tf.placeholder(tf.float32, [None], name="reward") - self.obs_tp1 = tf.placeholder( - tf.float32, shape=(None, ) + observation_space.shape) - self.done_mask = tf.placeholder(tf.float32, [None], name="done") - self.importance_weights = tf.placeholder( - tf.float32, [None], name="weight") - - # q network evaluation - with tf.variable_scope(Q_SCOPE, reuse=True): - prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - q_t, q_logits_t, 
q_dist_t, model = self._build_q_network( - self.obs_t, observation_space, action_space) - q_batchnorm_update_ops = list( - set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - - prev_update_ops) - - # target q network evalution - with tf.variable_scope(Q_TARGET_SCOPE) as scope: - q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network( - self.obs_tp1, observation_space, action_space) - self.target_q_func_vars = _scope_vars(scope.name) - - # q scores for actions which we know were selected in the given state. - one_hot_selection = tf.one_hot(self.act_t, self.num_actions) - q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1) - q_logits_t_selected = tf.reduce_sum( - q_logits_t * tf.expand_dims(one_hot_selection, -1), 1) - - # compute estimate of best possible value starting from state at t + 1 - if config["double_q"]: - with tf.variable_scope(Q_SCOPE, reuse=True): - q_tp1_using_online_net, q_logits_tp1_using_online_net, \ - q_dist_tp1_using_online_net, _ = self._build_q_network( - self.obs_tp1, observation_space, action_space) - q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) - q_tp1_best_one_hot_selection = tf.one_hot( - q_tp1_best_using_online_net, self.num_actions) - q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) - q_dist_tp1_best = tf.reduce_sum( - q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), - 1) - else: - q_tp1_best_one_hot_selection = tf.one_hot( - tf.argmax(q_tp1, 1), self.num_actions) - q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) - q_dist_tp1_best = tf.reduce_sum( - q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), - 1) + self.sess.run(self.add_noise_op) + + def set_epsilon(self, epsilon): + self.cur_epsilon = epsilon - self.loss = self._build_q_loss(q_t_selected, q_logits_t_selected, - q_tp1_best, q_dist_tp1_best) + @override(Policy) + def get_state(self): + return [TFPolicy.get_state(self), self.cur_epsilon] + @override(Policy) + def set_state(self, state): + TFPolicy.set_state(self, state[0]) + self.set_epsilon(state[1]) + + +class TargetNetworkMixin(object): + def __init__(self, obs_space, action_space, config): # update_target_fn will be called periodically to copy Q network to # target Q network update_target_expr = [] @@ -446,166 +339,250 @@ def __init__(self, observation_space, action_space, config): update_target_expr.append(var_target.assign(var)) self.update_target_expr = tf.group(*update_target_expr) - # initialize TFPolicy - self.sess = tf.get_default_session() - self.loss_inputs = [ - (SampleBatch.CUR_OBS, self.obs_t), - (SampleBatch.ACTIONS, self.act_t), - (SampleBatch.REWARDS, self.rew_t), - (SampleBatch.NEXT_OBS, self.obs_tp1), - (SampleBatch.DONES, self.done_mask), - (PRIO_WEIGHTS, self.importance_weights), - ] - - LearningRateSchedule.__init__(self, self.config["lr"], - self.config["lr_schedule"]) - TFPolicy.__init__( - self, - observation_space, - action_space, - self.sess, - obs_input=self.cur_observations, - action_sampler=self.output_actions, - action_prob=self.action_prob, - loss=self.loss.loss, - model=model, - loss_inputs=self.loss_inputs, - update_ops=q_batchnorm_update_ops) - self.sess.run(tf.global_variables_initializer()) - - self.stats_fetches = dict({ - "cur_lr": tf.cast(self.cur_lr, tf.float64), - }, **self.loss.stats) - - @override(TFPolicy) - def optimizer(self): - return tf.train.AdamOptimizer( - learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"]) - - @override(TFPolicy) - def gradients(self, optimizer, loss): - if 
self.config["grad_norm_clipping"] is not None: - grads_and_vars = _minimize_and_clip( - optimizer, - loss, - var_list=self.q_func_vars, - clip_val=self.config["grad_norm_clipping"]) - else: - grads_and_vars = optimizer.compute_gradients( - loss, var_list=self.q_func_vars) - grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] - return grads_and_vars - - @override(TFPolicy) - def extra_compute_action_feed_dict(self): - return { - self.stochastic: True, - self.eps: self.cur_epsilon, - } - - @override(TFPolicy) - def extra_compute_grad_fetches(self): - return { - "td_error": self.loss.td_error, - LEARNER_STATS_KEY: self.stats_fetches, - } - - @override(Policy) - def get_state(self): - return [TFPolicy.get_state(self), self.cur_epsilon] - - @override(Policy) - def set_state(self, state): - TFPolicy.set_state(self, state[0]) - self.set_epsilon(state[1]) + def update_target(self): + return self.get_session().run(self.update_target_expr) - def _build_parameter_noise(self, pnet_params): - self.parameter_noise_sigma_val = 1.0 - self.parameter_noise_sigma = tf.get_variable( - initializer=tf.constant_initializer( - self.parameter_noise_sigma_val), - name="parameter_noise_sigma", - shape=(), - trainable=False, - dtype=tf.float32) - self.parameter_noise = list() - # No need to add any noise on LayerNorm parameters - for var in pnet_params: - noise_var = tf.get_variable( - name=var.name.split(":")[0] + "_noise", - shape=var.shape, - initializer=tf.constant_initializer(.0), - trainable=False) - self.parameter_noise.append(noise_var) - remove_noise_ops = list() - for var, var_noise in zip(pnet_params, self.parameter_noise): - remove_noise_ops.append(tf.assign_add(var, -var_noise)) - self.remove_noise_op = tf.group(*tuple(remove_noise_ops)) - generate_noise_ops = list() - for var_noise in self.parameter_noise: - generate_noise_ops.append( - tf.assign( - var_noise, - tf.random_normal( - shape=var_noise.shape, - stddev=self.parameter_noise_sigma))) - with tf.control_dependencies(generate_noise_ops): - add_noise_ops = list() - for var, var_noise in zip(pnet_params, self.parameter_noise): - add_noise_ops.append(tf.assign_add(var, var_noise)) - self.add_noise_op = tf.group(*tuple(add_noise_ops)) - self.pi_distance = None +class ComputeTDErrorMixin(object): def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask, importance_weights): - td_err = self.sess.run( + if not self.loss_initialized(): + return np.zeros_like(rew_t) + + td_err = self.get_session().run( self.loss.td_error, feed_dict={ - self.obs_t: [np.array(ob) for ob in obs_t], - self.act_t: act_t, - self.rew_t: rew_t, - self.obs_tp1: [np.array(ob) for ob in obs_tp1], - self.done_mask: done_mask, - self.importance_weights: importance_weights + self.get_placeholder(SampleBatch.CUR_OBS): [ + np.array(ob) for ob in obs_t + ], + self.get_placeholder(SampleBatch.ACTIONS): act_t, + self.get_placeholder(SampleBatch.REWARDS): rew_t, + self.get_placeholder(SampleBatch.NEXT_OBS): [ + np.array(ob) for ob in obs_tp1 + ], + self.get_placeholder(SampleBatch.DONES): done_mask, + self.get_placeholder(PRIO_WEIGHTS): importance_weights, }) return td_err - def add_parameter_noise(self): - if self.config["parameter_noise"]: - self.sess.run(self.add_noise_op) - def update_target(self): - return self.sess.run(self.update_target_expr) - - def set_epsilon(self, epsilon): - self.cur_epsilon = epsilon - - def _build_q_network(self, obs, obs_space, action_space): - qnet = QNetwork( - ModelCatalog.get_model({ - "obs": obs, - "is_training": 
self._get_is_training_placeholder(), - }, obs_space, action_space, self.num_actions, - self.config["model"]), self.num_actions, - self.config["dueling"], self.config["hiddens"], - self.config["noisy"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"], self.config["sigma0"], - self.config["parameter_noise"]) - return qnet.value, qnet.logits, qnet.dist, qnet.model - - def _build_q_value_policy(self, q_values): - policy = QValuePolicy( - q_values, self.cur_observations, self.num_actions, self.stochastic, - self.eps, self.config["soft_q"], self.config["softmax_temp"]) - return policy.action, policy.action_prob - - def _build_q_loss(self, q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best): - return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, - q_dist_tp1_best, self.importance_weights, self.rew_t, - self.done_mask, self.config["gamma"], - self.config["n_step"], self.config["num_atoms"], - self.config["v_min"], self.config["v_max"]) +def postprocess_trajectory(policy, + sample_batch, + other_agent_batches=None, + episode=None): + if policy.config["parameter_noise"]: + # adjust the sigma of parameter space noise + states = [list(x) for x in sample_batch.columns(["obs"])][0] + + noisy_action_distribution = policy.get_session().run( + policy.action_probs, feed_dict={policy.cur_observations: states}) + policy.get_session().run(policy.remove_noise_op) + clean_action_distribution = policy.get_session().run( + policy.action_probs, feed_dict={policy.cur_observations: states}) + distance_in_action_space = np.mean( + entropy(clean_action_distribution.T, noisy_action_distribution.T)) + policy.pi_distance = distance_in_action_space + if (distance_in_action_space < + -np.log(1 - policy.cur_epsilon + + policy.cur_epsilon / policy.num_actions)): + policy.parameter_noise_sigma_val *= 1.01 + else: + policy.parameter_noise_sigma_val /= 1.01 + policy.parameter_noise_sigma.load( + policy.parameter_noise_sigma_val, session=policy.get_session()) + + return _postprocess_dqn(policy, sample_batch) + + +def build_q_networks(policy, input_dict, observation_space, action_space, + config): + + if not isinstance(action_space, Discrete): + raise UnsupportedSpaceException( + "Action space {} is not supported for DQN.".format(action_space)) + + # Action Q network + with tf.variable_scope(Q_SCOPE) as scope: + q_values, q_logits, q_dist, _ = _build_q_network( + policy, input_dict[SampleBatch.CUR_OBS], observation_space, + action_space) + policy.q_values = q_values + policy.q_func_vars = _scope_vars(scope.name) + + # Noise vars for Q network except for layer normalization vars + if config["parameter_noise"]: + _build_parameter_noise( + policy, + [var for var in policy.q_func_vars if "LayerNorm" not in var.name]) + policy.action_probs = tf.nn.softmax(policy.q_values) + + # Action outputs + qvp = QValuePolicy(q_values, input_dict[SampleBatch.CUR_OBS], + action_space.n, policy.stochastic, policy.eps, + policy.config["soft_q"], policy.config["softmax_temp"]) + policy.output_actions, policy.action_prob = qvp.action, qvp.action_prob + + return policy.output_actions, policy.action_prob + + +def _build_parameter_noise(policy, pnet_params): + policy.parameter_noise_sigma_val = 1.0 + policy.parameter_noise_sigma = tf.get_variable( + initializer=tf.constant_initializer(policy.parameter_noise_sigma_val), + name="parameter_noise_sigma", + shape=(), + trainable=False, + dtype=tf.float32) + policy.parameter_noise = list() + # No need to add any noise on LayerNorm parameters + for var in pnet_params: + 
noise_var = tf.get_variable( + name=var.name.split(":")[0] + "_noise", + shape=var.shape, + initializer=tf.constant_initializer(.0), + trainable=False) + policy.parameter_noise.append(noise_var) + remove_noise_ops = list() + for var, var_noise in zip(pnet_params, policy.parameter_noise): + remove_noise_ops.append(tf.assign_add(var, -var_noise)) + policy.remove_noise_op = tf.group(*tuple(remove_noise_ops)) + generate_noise_ops = list() + for var_noise in policy.parameter_noise: + generate_noise_ops.append( + tf.assign( + var_noise, + tf.random_normal( + shape=var_noise.shape, + stddev=policy.parameter_noise_sigma))) + with tf.control_dependencies(generate_noise_ops): + add_noise_ops = list() + for var, var_noise in zip(pnet_params, policy.parameter_noise): + add_noise_ops.append(tf.assign_add(var, var_noise)) + policy.add_noise_op = tf.group(*tuple(add_noise_ops)) + policy.pi_distance = None + + +def build_q_losses(policy, batch_tensors): + # q network evaluation + with tf.variable_scope(Q_SCOPE, reuse=True): + prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) + q_t, q_logits_t, q_dist_t, model = _build_q_network( + policy, batch_tensors[SampleBatch.CUR_OBS], + policy.observation_space, policy.action_space) + policy.q_batchnorm_update_ops = list( + set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) + + # target q network evalution + with tf.variable_scope(Q_TARGET_SCOPE) as scope: + q_tp1, q_logits_tp1, q_dist_tp1, _ = _build_q_network( + policy, batch_tensors[SampleBatch.NEXT_OBS], + policy.observation_space, policy.action_space) + policy.target_q_func_vars = _scope_vars(scope.name) + + # q scores for actions which we know were selected in the given state. + one_hot_selection = tf.one_hot(batch_tensors[SampleBatch.ACTIONS], + policy.action_space.n) + q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1) + q_logits_t_selected = tf.reduce_sum( + q_logits_t * tf.expand_dims(one_hot_selection, -1), 1) + + # compute estimate of best possible value starting from state at t + 1 + if policy.config["double_q"]: + with tf.variable_scope(Q_SCOPE, reuse=True): + q_tp1_using_online_net, q_logits_tp1_using_online_net, \ + q_dist_tp1_using_online_net, _ = _build_q_network( + policy, + batch_tensors[SampleBatch.NEXT_OBS], + policy.observation_space, policy.action_space) + q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) + q_tp1_best_one_hot_selection = tf.one_hot(q_tp1_best_using_online_net, + policy.action_space.n) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1) + else: + q_tp1_best_one_hot_selection = tf.one_hot( + tf.argmax(q_tp1, 1), policy.action_space.n) + q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1) + q_dist_tp1_best = tf.reduce_sum( + q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1) + + policy.loss = _build_q_loss( + q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best, + batch_tensors[SampleBatch.REWARDS], batch_tensors[SampleBatch.DONES], + batch_tensors[PRIO_WEIGHTS], policy.config) + + return policy.loss.loss + + +def adam_optimizer(policy, config): + return tf.train.AdamOptimizer( + learning_rate=policy.cur_lr, epsilon=config["adam_epsilon"]) + + +def clip_gradients(policy, optimizer, loss): + if policy.config["grad_norm_clipping"] is not None: + grads_and_vars = _minimize_and_clip( + optimizer, + loss, + var_list=policy.q_func_vars, + 
clip_val=policy.config["grad_norm_clipping"]) + else: + grads_and_vars = optimizer.compute_gradients( + loss, var_list=policy.q_func_vars) + grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None] + return grads_and_vars + + +def exploration_setting_inputs(policy): + return { + policy.stochastic: True, + policy.eps: policy.cur_epsilon, + } + + +def build_q_stats(policy, batch_tensors): + return dict({ + "cur_lr": tf.cast(policy.cur_lr, tf.float64), + }, **policy.loss.stats) + + +def setup_early_mixins(policy, obs_space, action_space, config): + LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) + ExplorationStateMixin.__init__(policy, obs_space, action_space, config) + + +def setup_late_mixins(policy, obs_space, action_space, config): + TargetNetworkMixin.__init__(policy, obs_space, action_space, config) + + +def _build_q_network(policy, obs, obs_space, action_space): + config = policy.config + qnet = QNetwork( + ModelCatalog.get_model({ + "obs": obs, + "is_training": policy._get_is_training_placeholder(), + }, obs_space, action_space, action_space.n, config["model"]), + action_space.n, config["dueling"], config["hiddens"], config["noisy"], + config["num_atoms"], config["v_min"], config["v_max"], + config["sigma0"], config["parameter_noise"]) + return qnet.value, qnet.logits, qnet.dist, qnet.model + + +def _build_q_value_policy(policy, q_values): + policy = QValuePolicy(q_values, policy.cur_observations, + policy.num_actions, policy.stochastic, policy.eps, + policy.config["soft_q"], + policy.config["softmax_temp"]) + return policy.action, policy.action_prob + + +def _build_q_loss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, rewards, dones, importance_weights, config): + return QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, + q_dist_tp1_best, importance_weights, rewards, + tf.cast(dones, tf.float32), config["gamma"], config["n_step"], + config["num_atoms"], config["v_min"], config["v_max"]) def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): @@ -706,3 +683,27 @@ def _scope_vars(scope, trainable_only=False): tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only else tf.GraphKeys.VARIABLES, scope=scope if isinstance(scope, str) else scope.name) + + +DQNTFPolicy = build_tf_policy( + name="DQNTFPolicy", + get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, + make_action_sampler=build_q_networks, + loss_fn=build_q_losses, + stats_fn=build_q_stats, + postprocess_fn=postprocess_trajectory, + optimizer_fn=adam_optimizer, + gradients_fn=clip_gradients, + extra_action_feed_fn=exploration_setting_inputs, + extra_action_fetches_fn=lambda policy: {"q_values": policy.q_values}, + extra_learn_fetches_fn=lambda policy: {"td_error": policy.loss.td_error}, + update_ops_fn=lambda policy: policy.q_batchnorm_update_ops, + before_init=setup_early_mixins, + after_init=setup_late_mixins, + obs_include_prev_action_reward=False, + mixins=[ + ExplorationStateMixin, + TargetNetworkMixin, + ComputeTDErrorMixin, + LearningRateSchedule, + ]) diff --git a/python/ray/rllib/agents/ppo/appo_policy.py b/python/ray/rllib/agents/ppo/appo_policy.py index 9f213063ab94..56e473a61af5 100644 --- a/python/ray/rllib/agents/ppo/appo_policy.py +++ b/python/ray/rllib/agents/ppo/appo_policy.py @@ -365,12 +365,15 @@ def __init__(self): tf.get_variable_scope().name) def value(self, ob, *args): - feed_dict = {self._obs_input: [ob], self.model.seq_lens: [1]} + feed_dict = { + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + 
self.model.seq_lens: [1] + } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self._sess.run(self.value_function, feed_dict) + vf = self.get_session().run(self.value_function, feed_dict) return vf[0] diff --git a/python/ray/rllib/agents/ppo/ppo_policy.py b/python/ray/rllib/agents/ppo/ppo_policy.py index 5a17d6c6d60c..4b391cab2cdc 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_policy.py @@ -216,7 +216,7 @@ def update_kl(self, sampled_kl): self.kl_coeff_val *= 1.5 elif sampled_kl < 0.5 * self.kl_target: self.kl_coeff_val *= 0.5 - self.kl_coeff.load(self.kl_coeff_val, session=self._sess) + self.kl_coeff.load(self.kl_coeff_val, session=self.get_session()) return self.kl_coeff_val @@ -240,28 +240,26 @@ def __init__(self, obs_space, action_space, config): "a custom LSTM model that overrides the " "value_function() method.") with tf.variable_scope("value_function"): - self.value_function = ModelCatalog.get_model({ - "obs": self._obs_input, - "prev_actions": self._prev_action_input, - "prev_rewards": self._prev_reward_input, - "is_training": self._get_is_training_placeholder(), - }, obs_space, action_space, 1, vf_config).outputs + self.value_function = ModelCatalog.get_model( + self.get_obs_input_dict(), obs_space, action_space, 1, + vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: - self.value_function = tf.zeros(shape=tf.shape(self._obs_input)[:1]) + self.value_function = tf.zeros( + shape=tf.shape(self.get_placeholder(SampleBatch.CUR_OBS))[:1]) def _value(self, ob, prev_action, prev_reward, *args): feed_dict = { - self._obs_input: [ob], - self._prev_action_input: [prev_action], - self._prev_reward_input: [prev_reward], + self.get_placeholder(SampleBatch.CUR_OBS): [ob], + self.get_placeholder(SampleBatch.PREV_ACTIONS): [prev_action], + self.get_placeholder(SampleBatch.PREV_REWARDS): [prev_reward], self.model.seq_lens: [1] } assert len(args) == len(self.model.state_in), \ (args, self.model.state_in) for k, v in zip(self.model.state_in, args): feed_dict[k] = v - vf = self._sess.run(self.value_function, feed_dict) + vf = self.get_session().run(self.value_function, feed_dict) return vf[0] diff --git a/python/ray/rllib/evaluation/tf_policy_template.py b/python/ray/rllib/evaluation/tf_policy_template.py deleted file mode 100644 index 36f482f18bf8..000000000000 --- a/python/ray/rllib/evaluation/tf_policy_template.py +++ /dev/null @@ -1,146 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.tf_policy import TFPolicy -from ray.rllib.utils.annotations import override, DeveloperAPI - - -@DeveloperAPI -def build_tf_policy(name, - loss_fn, - get_default_config=None, - stats_fn=None, - grad_stats_fn=None, - extra_action_fetches_fn=None, - postprocess_fn=None, - optimizer_fn=None, - gradients_fn=None, - before_init=None, - before_loss_init=None, - after_init=None, - make_action_sampler=None, - mixins=None, - get_batch_divisibility_req=None): - """Helper function for creating a dynamic tf policy at runtime. 
- - Arguments: - name (str): name of the policy (e.g., "PPOTFPolicy") - loss_fn (func): function that returns a loss tensor the policy, - and dict of experience tensor placeholders - get_default_config (func): optional function that returns the default - config to merge with any overrides - stats_fn (func): optional function that returns a dict of - TF fetches given the policy and batch input tensors - grad_stats_fn (func): optional function that returns a dict of - TF fetches given the policy and loss gradient tensors - extra_action_fetches_fn (func): optional function that returns - a dict of TF fetches given the policy object - postprocess_fn (func): optional experience postprocessing function - that takes the same args as Policy.postprocess_trajectory() - optimizer_fn (func): optional function that returns a tf.Optimizer - given the policy and config - gradients_fn (func): optional function that returns a list of gradients - given a tf optimizer and loss tensor. If not specified, this - defaults to optimizer.compute_gradients(loss) - before_init (func): optional function to run at the beginning of - policy init that takes the same arguments as the policy constructor - before_loss_init (func): optional function to run prior to loss - init that takes the same arguments as the policy constructor - after_init (func): optional function to run at the end of policy init - that takes the same arguments as the policy constructor - make_action_sampler (func): optional function that returns a - tuple of action and action prob tensors. The function takes - (policy, input_dict, obs_space, action_space, config) as its - arguments - mixins (list): list of any class mixins for the returned policy class. - These mixins will be applied in order and will have higher - precedence than the DynamicTFPolicy class - get_batch_divisibility_req (func): optional function that returns - the divisibility requirement for sample batches - - Returns: - a DynamicTFPolicy instance that uses the specified args - """ - - if not name.endswith("TFPolicy"): - raise ValueError("Name should match *TFPolicy", name) - - base = DynamicTFPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base - - class policy_cls(base): - def __init__(self, - obs_space, - action_space, - config, - existing_inputs=None): - if get_default_config: - config = dict(get_default_config(), **config) - - if before_init: - before_init(self, obs_space, action_space, config) - - def before_loss_init_wrapper(policy, obs_space, action_space, - config): - if before_loss_init: - before_loss_init(policy, obs_space, action_space, config) - if extra_action_fetches_fn is None: - self._extra_action_fetches = {} - else: - self._extra_action_fetches = extra_action_fetches_fn(self) - - DynamicTFPolicy.__init__( - self, - obs_space, - action_space, - config, - loss_fn, - stats_fn=stats_fn, - grad_stats_fn=grad_stats_fn, - before_loss_init=before_loss_init_wrapper, - existing_inputs=existing_inputs) - - if after_init: - after_init(self, obs_space, action_space, config) - - @override(Policy) - def postprocess_trajectory(self, - sample_batch, - other_agent_batches=None, - episode=None): - if not postprocess_fn: - return sample_batch - return postprocess_fn(self, sample_batch, other_agent_batches, - episode) - - @override(TFPolicy) - def optimizer(self): - if optimizer_fn: - return optimizer_fn(self, self.config) - else: - return TFPolicy.optimizer(self) - - @override(TFPolicy) - def gradients(self, optimizer, loss): - if gradients_fn: - 
return gradients_fn(self, optimizer, loss) - else: - return TFPolicy.gradients(self, optimizer, loss) - - @override(TFPolicy) - def extra_compute_action_fetches(self): - return dict( - TFPolicy.extra_compute_action_fetches(self), - **self._extra_action_fetches) - - policy_cls.__name__ = name - policy_cls.__qualname__ = name - return policy_cls diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py index 691fc1186272..afa72a0af709 100644 --- a/python/ray/rllib/policy/dynamic_tf_policy.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -37,11 +37,13 @@ def __init__(self, config, loss_fn, stats_fn=None, + update_ops_fn=None, grad_stats_fn=None, before_loss_init=None, make_action_sampler=None, existing_inputs=None, - get_batch_divisibility_req=None): + get_batch_divisibility_req=None, + obs_include_prev_action_reward=True): """Initialize a dynamic TF policy. Arguments: @@ -54,6 +56,8 @@ def __init__(self, TF fetches given the policy and batch input tensors grad_stats_fn (func): optional function that returns a dict of TF fetches given the policy and loss gradient tensors + update_ops_fn (func): optional function that returns a list + overriding the update ops to run when applying gradients before_loss_init (func): optional function to run prior to loss init that takes the same arguments as __init__ make_action_sampler (func): optional function that returns a @@ -65,30 +69,39 @@ def __init__(self, defining new ones get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches + obs_include_prev_action_reward (bool): whether to include the + previous action and reward in the model input """ self.config = config self._loss_fn = loss_fn self._stats_fn = stats_fn self._grad_stats_fn = grad_stats_fn + self._update_ops_fn = update_ops_fn + self._obs_include_prev_action_reward = obs_include_prev_action_reward # Setup standard placeholders + prev_actions = None + prev_rewards = None if existing_inputs is not None: obs = existing_inputs[SampleBatch.CUR_OBS] - prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] - prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] + if self._obs_include_prev_action_reward: + prev_actions = existing_inputs[SampleBatch.PREV_ACTIONS] + prev_rewards = existing_inputs[SampleBatch.PREV_REWARDS] else: obs = tf.placeholder( tf.float32, shape=[None] + list(obs_space.shape), name="observation") - prev_actions = ModelCatalog.get_action_placeholder(action_space) - prev_rewards = tf.placeholder( - tf.float32, [None], name="prev_reward") + if self._obs_include_prev_action_reward: + prev_actions = ModelCatalog.get_action_placeholder( + action_space) + prev_rewards = tf.placeholder( + tf.float32, [None], name="prev_reward") - input_dict = { - "obs": obs, - "prev_actions": prev_actions, - "prev_rewards": prev_rewards, + self.input_dict = { + SampleBatch.CUR_OBS: obs, + SampleBatch.PREV_ACTIONS: prev_actions, + SampleBatch.PREV_REWARDS: prev_rewards, "is_training": self._get_is_training_placeholder(), } @@ -100,7 +113,7 @@ def __init__(self, self.dist_class = None self.action_dist = None action_sampler, action_prob = make_action_sampler( - self, input_dict, obs_space, action_space, config) + self, self.input_dict, obs_space, action_space, config) else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) @@ -117,7 +130,7 @@ def __init__(self, existing_state_in = [] existing_seq_lens = None self.model = ModelCatalog.get_model( - 
input_dict, + self.input_dict, obs_space, action_space, logit_dim, @@ -158,6 +171,13 @@ def __init__(self, if not existing_inputs: self._initialize_loss() + def get_obs_input_dict(self): + """Returns the obs input dict used to build policy models. + + This dict includes the obs, prev actions, prev rewards, etc. tensors. + """ + return self.input_dict + @override(TFPolicy) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" @@ -190,10 +210,8 @@ def copy(self, existing_inputs): self.action_space, self.config, existing_inputs=input_dict) - loss = instance._loss_fn(instance, input_dict) - if instance._stats_fn: - instance._stats_fetches.update( - instance._stats_fn(instance, input_dict)) + + loss = instance._do_loss_init(input_dict) TFPolicy._initialize_loss( instance, loss, [(k, existing_inputs[i]) for i, (k, _) in enumerate(self._loss_inputs)]) @@ -216,14 +234,18 @@ def fake_array(tensor): return np.zeros(shape, dtype=tensor.dtype.as_numpy_dtype) dummy_batch = { - SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), - SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), SampleBatch.CUR_OBS: fake_array(self._obs_input), SampleBatch.NEXT_OBS: fake_array(self._obs_input), - SampleBatch.ACTIONS: fake_array(self._prev_action_input), - SampleBatch.REWARDS: np.array([0], dtype=np.float32), SampleBatch.DONES: np.array([False], dtype=np.bool), + SampleBatch.ACTIONS: fake_array( + ModelCatalog.get_action_placeholder(self.action_space)), + SampleBatch.REWARDS: np.array([0], dtype=np.float32), } + if self._obs_include_prev_action_reward: + dummy_batch.update({ + SampleBatch.PREV_ACTIONS: fake_array(self._prev_action_input), + SampleBatch.PREV_REWARDS: fake_array(self._prev_reward_input), + }) state_init = self.get_initial_state() for i, h in enumerate(state_init): dummy_batch["state_in_{}".format(i)] = np.expand_dims(h, 0) @@ -238,16 +260,24 @@ def fake_array(tensor): postprocessed_batch = self.postprocess_trajectory( SampleBatch(dummy_batch)) - batch_tensors = UsageTrackingDict({ - SampleBatch.PREV_ACTIONS: self._prev_action_input, - SampleBatch.PREV_REWARDS: self._prev_reward_input, - SampleBatch.CUR_OBS: self._obs_input, - }) - loss_inputs = [ - (SampleBatch.PREV_ACTIONS, self._prev_action_input), - (SampleBatch.PREV_REWARDS, self._prev_reward_input), - (SampleBatch.CUR_OBS, self._obs_input), - ] + if self._obs_include_prev_action_reward: + batch_tensors = UsageTrackingDict({ + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.PREV_ACTIONS, self._prev_action_input), + (SampleBatch.PREV_REWARDS, self._prev_reward_input), + (SampleBatch.CUR_OBS, self._obs_input), + ] + else: + batch_tensors = UsageTrackingDict({ + SampleBatch.CUR_OBS: self._obs_input, + }) + loss_inputs = [ + (SampleBatch.CUR_OBS, self._obs_input), + ] for k, v in postprocessed_batch.items(): if k in batch_tensors: @@ -264,12 +294,18 @@ def fake_array(tensor): "Initializing loss function with dummy input:\n\n{}\n".format( summarize(batch_tensors))) - loss = self._loss_fn(self, batch_tensors) - if self._stats_fn: - self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + loss = self._do_loss_init(batch_tensors) for k in sorted(batch_tensors.accessed_keys): loss_inputs.append((k, batch_tensors[k])) TFPolicy._initialize_loss(self, loss, loss_inputs) if self._grad_stats_fn: self._stats_fetches.update(self._grad_stats_fn(self, 
self._grads)) self._sess.run(tf.global_variables_initializer()) + + def _do_loss_init(self, batch_tensors): + loss = self._loss_fn(self, batch_tensors) + if self._stats_fn: + self._stats_fetches.update(self._stats_fn(self, batch_tensors)) + if self._update_ops_fn: + self._update_ops = self._update_ops_fn(self) + return loss diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index bbb5795e52ab..ed234f809512 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -139,6 +139,39 @@ def __init__(self, raise ValueError( "seq_lens tensor must be given if state inputs are defined") + def get_placeholder(self, name): + """Returns the given action or loss input placeholder by name. + + If the loss has not been initialized and a loss input placeholder is + requested, an error is raised. + """ + + obs_inputs = { + SampleBatch.CUR_OBS: self._obs_input, + SampleBatch.PREV_ACTIONS: self._prev_action_input, + SampleBatch.PREV_REWARDS: self._prev_reward_input, + } + if name in obs_inputs: + return obs_inputs[name] + + if not self.loss_initialized(): + raise RuntimeError( + "You cannot call policy.get_placeholder() for non-obs inputs " + "before the loss has been initialized. To avoid this, use " + "policy.loss_initialized() to check whether this is the " + "case, or move the call to later (e.g., from stats_fn to " + "grad_stats_fn).") + + return self._loss_input_dict[name] + + def get_session(self): + """Returns a reference to the TF session for this policy.""" + return self._sess + + def loss_initialized(self): + """Returns whether the loss function has been initialized.""" + return self._loss is not None + def _initialize_loss(self, loss, loss_inputs): self._loss_inputs = loss_inputs self._loss_input_dict = dict(self._loss_inputs) @@ -172,7 +205,7 @@ def _initialize_loss(self, loss, loss_inputs): self._grads_and_vars) if log_once("loss_used"): - logger.debug( + logger.info( "These tensors were used in the loss_fn:\n\n{}\n".format( summarize(self._loss_input_dict))) @@ -195,21 +228,21 @@ def compute_actions(self, @override(Policy) def compute_gradients(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "compute_gradients") fetches = self._build_compute_gradients(builder, postprocessed_batch) return builder.get(fetches) @override(Policy) def apply_gradients(self, gradients): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "apply_gradients") fetches = self._build_apply_gradients(builder, gradients) builder.get(fetches) @override(Policy) def learn_on_batch(self, postprocessed_batch): - assert self._loss is not None, "Loss not initialized" + assert self.loss_initialized() builder = TFRunBuilder(self._sess, "learn_on_batch") fetches = self._build_learn_on_batch(builder, postprocessed_batch) return builder.get(fetches) diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py index 36f482f18bf8..7f10958cdee7 100644 --- a/python/ray/rllib/policy/tf_policy_template.py +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -3,7 +3,7 @@ from __future__ import print_function from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.tf_policy import TFPolicy from 
ray.rllib.utils.annotations import override, DeveloperAPI @@ -12,39 +12,60 @@ def build_tf_policy(name, loss_fn, get_default_config=None, - stats_fn=None, - grad_stats_fn=None, - extra_action_fetches_fn=None, postprocess_fn=None, + stats_fn=None, + update_ops_fn=None, optimizer_fn=None, gradients_fn=None, + grad_stats_fn=None, + extra_action_fetches_fn=None, + extra_action_feed_fn=None, + extra_learn_fetches_fn=None, + extra_learn_feed_fn=None, before_init=None, before_loss_init=None, after_init=None, make_action_sampler=None, mixins=None, - get_batch_divisibility_req=None): + get_batch_divisibility_req=None, + obs_include_prev_action_reward=True): """Helper function for creating a dynamic tf policy at runtime. + Functions will be run in this order to initialize the policy: + 1. Placeholder setup: postprocess_fn + 2. Loss init: loss_fn, stats_fn, update_ops_fn + 3. Optimizer init: optimizer_fn, gradients_fn, grad_stats_fn + + This means that you can e.g., depend on any policy attributes created in + the running of `loss_fn` in later functions such as `stats_fn`. + Arguments: name (str): name of the policy (e.g., "PPOTFPolicy") loss_fn (func): function that returns a loss tensor the policy, - and dict of experience tensor placeholders + and dict of experience tensor placeholdes get_default_config (func): optional function that returns the default config to merge with any overrides - stats_fn (func): optional function that returns a dict of - TF fetches given the policy and batch input tensors - grad_stats_fn (func): optional function that returns a dict of - TF fetches given the policy and loss gradient tensors - extra_action_fetches_fn (func): optional function that returns - a dict of TF fetches given the policy object postprocess_fn (func): optional experience postprocessing function that takes the same args as Policy.postprocess_trajectory() + stats_fn (func): optional function that returns a dict of + TF fetches given the policy and batch input tensors + update_ops_fn (func): optional function that returns a list overriding + the update ops to run when applying gradients optimizer_fn (func): optional function that returns a tf.Optimizer given the policy and config gradients_fn (func): optional function that returns a list of gradients given a tf optimizer and loss tensor. 
If not specified, this defaults to optimizer.compute_gradients(loss) + grad_stats_fn (func): optional function that returns a dict of + TF fetches given the policy and loss gradient tensors + extra_action_fetches_fn (func): optional function that returns + a dict of TF fetches given the policy object + extra_action_feed_fn (func): optional function that returns a feed dict + to also feed to TF when computing actions + extra_learn_fetches_fn (func): optional function that returns a dict of + extra values to fetch and return when learning on a batch + extra_learn_feed_fn (func): optional function that returns a feed dict + to also feed to TF when learning on a batch before_init (func): optional function to run at the beginning of policy init that takes the same arguments as the policy constructor before_loss_init (func): optional function to run prior to loss @@ -60,6 +81,8 @@ def build_tf_policy(name, precedence than the DynamicTFPolicy class get_batch_divisibility_req (func): optional function that returns the divisibility requirement for sample batches + obs_include_prev_action_reward (bool): whether to include the + previous action and reward in the model input Returns: a DynamicTFPolicy instance that uses the specified args @@ -105,8 +128,11 @@ def before_loss_init_wrapper(policy, obs_space, action_space, loss_fn, stats_fn=stats_fn, grad_stats_fn=grad_stats_fn, + update_ops_fn=update_ops_fn, before_loss_init=before_loss_init_wrapper, - existing_inputs=existing_inputs) + make_action_sampler=make_action_sampler, + existing_inputs=existing_inputs, + obs_include_prev_action_reward=obs_include_prev_action_reward) if after_init: after_init(self, obs_space, action_space, config) @@ -141,6 +167,30 @@ def extra_compute_action_fetches(self): TFPolicy.extra_compute_action_fetches(self), **self._extra_action_fetches) + @override(TFPolicy) + def extra_compute_action_feed_dict(self): + if extra_action_feed_fn: + return extra_action_feed_fn(self) + else: + return TFPolicy.extra_compute_action_feed_dict(self) + + @override(TFPolicy) + def extra_compute_grad_fetches(self): + if extra_learn_fetches_fn: + # auto-add empty learner stats dict if needed + return dict({ + LEARNER_STATS_KEY: {} + }, **extra_learn_fetches_fn(self)) + else: + return TFPolicy.extra_compute_grad_fetches(self) + + @override(TFPolicy) + def extra_compute_grad_feed_dict(self): + if extra_learn_feed_fn: + return extra_learn_feed_fn(self) + else: + return TFPolicy.extra_compute_grad_feed_dict(self) + policy_cls.__name__ = name policy_cls.__qualname__ = name return policy_cls From d86ee8c83e8cd5c6711421ffd689f147184c2e4c Mon Sep 17 00:00:00 2001 From: Akshat Gokhale Date: Sun, 2 Jun 2019 12:05:48 +0530 Subject: [PATCH 061/118] fetching objects in parallel in _get_arguments_for_execution (#4775) --- python/ray/worker.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/ray/worker.py b/python/ray/worker.py index c886159aafec..7786c742d9b1 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -782,18 +782,27 @@ def _get_arguments_for_execution(self, function_name, serialized_args): RayError: This exception is raised if a task that created one of the arguments failed. 
""" - arguments = [] + arguments = [None] * len(serialized_args) + object_ids = [] + object_indices = [] + for (i, arg) in enumerate(serialized_args): if isinstance(arg, ObjectID): - # get the object from the local object store - argument = self.get_object([arg])[0] - if isinstance(argument, RayError): - raise argument + object_ids.append(arg) + object_indices.append(i) else: # pass the argument by value - argument = arg + arguments[i] = arg + + # Get the objects from the local object store. + if len(object_ids) > 0: + values = self.get_object(object_ids) + for i, value in enumerate(values): + if isinstance(value, RayError): + raise value + else: + arguments[object_indices[i]] = value - arguments.append(argument) return arguments def _store_outputs_in_object_store(self, object_ids, outputs): From 99eae05cf617b2a6f8797d3e4a6b07978f90b4c5 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 3 Jun 2019 06:47:39 +0800 Subject: [PATCH 062/118] [tune] Disallow setting resources_per_trial when it is already configured (#4880) * disallow it * import fix * fix example * fix test * fix tests * Update mock.py * fix * make less convoluted * fix tests --- python/ray/rllib/agents/mock.py | 4 ++++ python/ray/rllib/examples/custom_train_fn.py | 16 ++++++---------- python/ray/tune/trainable.py | 3 +-- python/ray/tune/trial.py | 20 +++++++++++++++++--- 4 files changed, 28 insertions(+), 15 deletions(-) diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py index 8cd8c08b4695..0b7d77c2a76f 100644 --- a/python/ray/rllib/agents/mock.py +++ b/python/ray/rllib/agents/mock.py @@ -20,6 +20,10 @@ class _MockTrainer(Trainer): "num_workers": 0, }) + @classmethod + def default_resource_request(cls, config): + return None + def _init(self, config, env_creator): self.info = None self.restored = False diff --git a/python/ray/rllib/examples/custom_train_fn.py b/python/ray/rllib/examples/custom_train_fn.py index 3ae418ee7252..cc7b55e70d5d 100644 --- a/python/ray/rllib/examples/custom_train_fn.py +++ b/python/ray/rllib/examples/custom_train_fn.py @@ -40,13 +40,9 @@ def my_train_fn(config, reporter): if __name__ == "__main__": ray.init() - tune.run( - my_train_fn, - resources_per_trial={ - "cpu": 1, - }, - config={ - "lr": 0.01, - "num_workers": 0, - }, - ) + config = { + "lr": 0.01, + "num_workers": 0, + } + resources = PPOTrainer.default_resource_request(config).to_json() + tune.run(my_train_fn, resources_per_trial=resources, config=config) diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index c10934896bcc..bb70e2b39434 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -21,7 +21,6 @@ TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, EPISODES_THIS_ITER, EPISODES_TOTAL, TRAINING_ITERATION, RESULT_DUPLICATE) -from ray.tune.trial import Resources logger = logging.getLogger(__name__) @@ -96,7 +95,7 @@ def default_resource_request(cls, config): allocation, so the user does not need to. 
""" - return Resources(cpu=1, gpu=0) + return None @classmethod def resource_help(cls, config): diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 91ea941b8cf0..cb9351f9adf8 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -139,6 +139,9 @@ def subtract(cls, original, to_remove): return Resources(cpu, gpu, extra_cpu, extra_gpu, new_custom_res, extra_custom_res) + def to_json(self): + return resources_to_json(self) + def json_to_resources(data): if data is None or data == "null": @@ -275,9 +278,20 @@ def __init__(self, self.config = config or {} self.local_dir = local_dir # This remains unexpanded for syncing. self.experiment_tag = experiment_tag - self.resources = ( - resources - or self._get_trainable_cls().default_resource_request(self.config)) + trainable_cls = self._get_trainable_cls() + if trainable_cls and hasattr(trainable_cls, + "default_resource_request"): + default_resources = trainable_cls.default_resource_request( + self.config) + if default_resources: + if resources: + raise ValueError( + "Resources for {} have been automatically set to {} " + "by its `default_resource_request()` method. Please " + "clear the `resources_per_trial` option.".format( + trainable_cls, default_resources)) + resources = default_resources + self.resources = resources or Resources(cpu=1, gpu=0) self.stopping_criterion = stopping_criterion or {} self.upload_dir = upload_dir self.loggers = loggers From 7501ee51db88e6aa642b068b329b813c0712f3d9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 3 Jun 2019 06:49:24 +0800 Subject: [PATCH 063/118] [rllib] Rename PolicyEvaluator => RolloutWorker (#4820) --- ci/jenkins_tests/run_rllib_tests.sh | 7 +- doc/source/rllib-concepts.rst | 2 +- doc/source/rllib-config.svg | 2 +- doc/source/rllib-examples.rst | 2 +- doc/source/rllib-training.rst | 18 +- python/ray/rllib/__init__.py | 3 +- python/ray/rllib/agents/a3c/a2c.py | 26 +- python/ray/rllib/agents/a3c/a3c.py | 59 +- python/ray/rllib/agents/ddpg/apex.py | 2 +- python/ray/rllib/agents/ddpg/ddpg.py | 4 +- python/ray/rllib/agents/ddpg/ddpg_policy.py | 2 +- python/ray/rllib/agents/dqn/apex.py | 2 +- python/ray/rllib/agents/dqn/dqn.py | 50 +- python/ray/rllib/agents/es/es.py | 10 +- python/ray/rllib/agents/impala/impala.py | 10 +- python/ray/rllib/agents/marwil/marwil.py | 9 +- python/ray/rllib/agents/pg/pg.py | 2 +- python/ray/rllib/agents/ppo/ppo.py | 14 +- python/ray/rllib/agents/qmix/apex.py | 2 +- python/ray/rllib/agents/trainer.py | 245 ++---- python/ray/rllib/agents/trainer_template.py | 37 +- python/ray/rllib/env/base_env.py | 2 +- python/ray/rllib/evaluation/__init__.py | 20 +- python/ray/rllib/evaluation/interface.py | 2 +- python/ray/rllib/evaluation/metrics.py | 20 +- .../ray/rllib/evaluation/policy_evaluator.py | 805 +----------------- python/ray/rllib/evaluation/rollout_worker.py | 794 +++++++++++++++++ python/ray/rllib/evaluation/worker_set.py | 214 +++++ .../rllib/examples/multiagent_two_trainers.py | 2 +- ...w.py => rollout_worker_custom_workflow.py} | 10 +- python/ray/rllib/offline/io_context.py | 12 +- python/ray/rllib/offline/json_reader.py | 2 +- .../ray/rllib/offline/off_policy_estimator.py | 6 +- python/ray/rllib/optimizers/aso_aggregator.py | 32 +- python/ray/rllib/optimizers/aso_learner.py | 6 +- .../rllib/optimizers/aso_multi_gpu_learner.py | 14 +- .../rllib/optimizers/aso_tree_aggregator.py | 55 +- .../optimizers/async_gradients_optimizer.py | 19 +- .../optimizers/async_replay_optimizer.py | 39 +- .../optimizers/async_samples_optimizer.py | 25 +- 
.../rllib/optimizers/multi_gpu_optimizer.py | 31 +- .../ray/rllib/optimizers/policy_optimizer.py | 74 +- .../optimizers/sync_batch_replay_optimizer.py | 19 +- .../rllib/optimizers/sync_replay_optimizer.py | 25 +- .../optimizers/sync_samples_optimizer.py | 27 +- python/ray/rllib/policy/dynamic_tf_policy.py | 2 +- python/ray/rllib/policy/policy.py | 2 +- python/ray/rllib/policy/tf_policy_template.py | 9 +- .../ray/rllib/policy/torch_policy_template.py | 19 +- python/ray/rllib/rollout.py | 8 +- .../{mock_evaluator.py => mock_worker.py} | 2 +- python/ray/rllib/tests/test_external_env.py | 16 +- .../tests/test_external_multi_agent_env.py | 15 +- python/ray/rllib/tests/test_filters.py | 6 +- .../ray/rllib/tests/test_multi_agent_env.py | 46 +- python/ray/rllib/tests/test_optimizers.py | 63 +- python/ray/rllib/tests/test_perf.py | 6 +- ...cy_evaluator.py => test_rollout_worker.py} | 47 +- python/ray/rllib/utils/actors.py | 8 +- 59 files changed, 1538 insertions(+), 1474 deletions(-) create mode 100644 python/ray/rllib/evaluation/rollout_worker.py create mode 100644 python/ray/rllib/evaluation/worker_set.py rename python/ray/rllib/examples/{policy_evaluator_custom_workflow.py => rollout_worker_custom_workflow.py} (90%) rename python/ray/rllib/tests/{mock_evaluator.py => mock_worker.py} (98%) rename python/ray/rllib/tests/{test_policy_evaluator.py => test_rollout_worker.py} (94%) diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index 78fbf6a3ab46..a97bf5517ea2 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -302,7 +302,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_checkpoint_restore.py docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_policy_evaluator.py + /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_rollout_worker.py docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/tests/test_nested_spaces.py @@ -389,6 +389,9 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_loss.py --iters=2 +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2 @@ -396,7 +399,7 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_torch_policy.py --iters=2 docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - /ray/ci/suppress_output python /ray/python/ray/rllib/examples/policy_evaluator_custom_workflow.py + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_metrics_and_callbacks.py --num-iters=2 diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index 
2f9603b69f58..b7b3ff823774 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -453,7 +453,7 @@ Policy Evaluation Given an environment and policy, policy evaluation produces `batches `__ of experiences. This is your classic "environment interaction loop". Efficient policy evaluation can be burdensome to get right, especially when leveraging vectorization, RNNs, or when operating in a multi-agent environment. RLlib provides a `RolloutWorker `__ class that manages all of this, and this class is used in most RLlib algorithms. -You can use rollout workers standalone to produce batches of experiences. This can be done by calling ``worker.sample()`` on a worker instance, or ``worker.sample.remote()`` in parallel on worker instances created as Ray actors (see ``RolloutWorkers.create_remote``). +You can use rollout workers standalone to produce batches of experiences. This can be done by calling ``worker.sample()`` on a worker instance, or ``worker.sample.remote()`` in parallel on worker instances created as Ray actors (see `WorkerSet `__). Here is an example of creating a set of rollout workers and using them gather experiences in parallel. The trajectories are concatenated, the policy learns on the trajectory batch, and then we broadcast the policy weights to the workers for the next round of rollouts: diff --git a/doc/source/rllib-config.svg b/doc/source/rllib-config.svg index 04331f5f3021..b3a011eee1fb 100644 --- a/doc/source/rllib-config.svg +++ b/doc/source/rllib-config.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/doc/source/rllib-examples.rst b/doc/source/rllib-examples.rst index f26e078ea32d..13bfdc68bfc1 100644 --- a/doc/source/rllib-examples.rst +++ b/doc/source/rllib-examples.rst @@ -22,7 +22,7 @@ Training Workflows Example of how to adjust the configuration of an environment over time. - `Custom metrics `__: Example of how to output custom training metrics to TensorBoard. -- `Using policy evaluators directly for control over the whole training workflow `__: +- `Using rollout workers directly for control over the whole training workflow `__: Example of how to use RLlib's lower-level building blocks to implement a fully customized training workflow. Custom Envs and Models diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index ef4f292954d6..824ef4c3dd88 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -178,13 +178,13 @@ Custom Training Workflows In the `basic training example `__, Tune will call ``train()`` on your trainer once per iteration and report the new training results. Sometimes, it is desirable to have full control over training, but still run inside Tune. Tune supports `custom trainable functions `__ that can be used to implement `custom training workflows (example) `__. -For even finer-grained control over training, you can use RLlib's lower-level `building blocks `__ directly to implement `fully customized training workflows `__. +For even finer-grained control over training, you can use RLlib's lower-level `building blocks `__ directly to implement `fully customized training workflows `__. Accessing Policy State ~~~~~~~~~~~~~~~~~~~~~~ -It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *policy evaluators* (Ray actors) in the cluster. 
However, you can easily get and update this state between calls to ``train()`` via ``trainer.optimizer.foreach_evaluator()`` or ``trainer.optimizer.foreach_evaluator_with_index()``. These functions take a lambda function that is applied with the evaluator as an arg. You can also return values from these functions and those will be returned as a list. +It is common to need to access a trainer's internal state, e.g., to set or get internal weights. In RLlib trainer state is replicated across multiple *rollout workers* (Ray actors) in the cluster. However, you can easily get and update this state between calls to ``train()`` via ``trainer.workers.foreach_worker()`` or ``trainer.workers.foreach_worker_with_index()``. These functions take a lambda function that is applied with the worker as an arg. You can also return values from these functions and those will be returned as a list. -You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.local_evaluator``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.local_evaluator.policy_map["default_policy"].get_weights()``: +You can also access just the "master" copy of the trainer state through ``trainer.get_policy()`` or ``trainer.workers.local_worker()``, but note that updates here may not be immediately reflected in remote replicas if you have configured ``num_workers > 0``. For example, to access the weights of a local TF policy, you can run ``trainer.get_policy().get_weights()``. This is also equivalent to ``trainer.workers.local_worker().policy_map["default_policy"].get_weights()``: .. 
code-block:: python @@ -192,13 +192,13 @@ You can also access just the "master" copy of the trainer state through ``traine trainer.get_policy().get_weights() # Same as above - trainer.local_evaluator.policy_map["default_policy"].get_weights() + trainer.workers.local_worker().policy_map["default_policy"].get_weights() - # Get list of weights of each evaluator, including remote replicas - trainer.optimizer.foreach_evaluator(lambda ev: ev.get_policy().get_weights()) + # Get list of weights of each worker, including remote replicas + trainer.workers.foreach_worker(lambda ev: ev.get_policy().get_weights()) # Same as above - trainer.optimizer.foreach_evaluator_with_index(lambda ev, i: ev.get_policy().get_weights()) + trainer.workers.foreach_worker_with_index(lambda ev, i: ev.get_policy().get_weights()) Global Coordination ~~~~~~~~~~~~~~~~~~~ @@ -299,7 +299,7 @@ Approach 1: Use the Trainer API and update the environment between calls to ``tr phase = 1 else: phase = 0 - trainer.optimizer.foreach_evaluator( + trainer.workers.foreach_worker( lambda ev: ev.foreach_env( lambda env: env.set_phase(phase))) @@ -333,7 +333,7 @@ Approach 2: Use the callbacks API to update the environment on new training resu else: phase = 0 trainer = info["trainer"] - trainer.optimizer.foreach_evaluator( + trainer.workers.foreach_worker( lambda ev: ev.foreach_env( lambda env: env.set_phase(phase))) diff --git a/python/ray/rllib/__init__.py b/python/ray/rllib/__init__.py index 92844e485ff3..0824e999503f 100644 --- a/python/ray/rllib/__init__.py +++ b/python/ray/rllib/__init__.py @@ -11,7 +11,7 @@ from ray.rllib.evaluation.policy_graph import PolicyGraph from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.vector_env import VectorEnv @@ -55,6 +55,7 @@ def _register_all(): "PolicyGraph", "TFPolicy", "TFPolicyGraph", + "RolloutWorker", "PolicyEvaluator", "SampleBatch", "BaseEnv", diff --git a/python/ray/rllib/agents/a3c/a2c.py b/python/ray/rllib/agents/a3c/a2c.py index e1834503016d..0b6592e741df 100644 --- a/python/ray/rllib/agents/a3c/a2c.py +++ b/python/ray/rllib/agents/a3c/a2c.py @@ -2,9 +2,10 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG as A3C_CONFIG -from ray.rllib.optimizers import SyncSamplesOptimizer -from ray.rllib.utils.annotations import override +from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG, \ + validate_config, get_policy_class +from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.utils import merge_dicts A2C_DEFAULT_CONFIG = merge_dicts( @@ -16,16 +17,9 @@ }, ) - -class A2CTrainer(A3CTrainer): - """Synchronous variant of the A3CTrainer.""" - - _name = "A2C" - _default_config = A2C_DEFAULT_CONFIG - - @override(A3CTrainer) - def _make_optimizer(self): - return SyncSamplesOptimizer( - self.local_evaluator, - self.remote_evaluators, - train_batch_size=self.config["train_batch_size"]) +A2CTrainer = build_trainer( + name="A2C", + default_config=A2C_DEFAULT_CONFIG, + default_policy=A3CTFPolicy, + get_policy_class=get_policy_class, + validate_config=validate_config) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index 
56d7a09daa0f..c269df2fc6e5 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -2,12 +2,10 @@ from __future__ import division from __future__ import print_function -import time - from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import AsyncGradientsOptimizer -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -38,43 +36,28 @@ # yapf: enable -class A3CTrainer(Trainer): - """A3C implementations in TensorFlow and PyTorch.""" +def get_policy_class(config): + if config["use_pytorch"]: + from ray.rllib.agents.a3c.a3c_torch_policy import \ + A3CTorchPolicy + return A3CTorchPolicy + else: + return A3CTFPolicy - _name = "A3C" - _default_config = DEFAULT_CONFIG - _policy = A3CTFPolicy - @override(Trainer) - def _init(self, config, env_creator): - if config["use_pytorch"]: - from ray.rllib.agents.a3c.a3c_torch_policy import \ - A3CTorchPolicy - policy_cls = A3CTorchPolicy - else: - policy_cls = self._policy +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") - if config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - self.local_evaluator = self.make_local_evaluator( - env_creator, policy_cls) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_cls, config["num_workers"]) - self.optimizer = self._make_optimizer() +def make_async_optimizer(workers, config): + return AsyncGradientsOptimizer(workers, **config["optimizer"]) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - start = time.time() - while time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result - def _make_optimizer(self): - return AsyncGradientsOptimizer(self.local_evaluator, - self.remote_evaluators, - **self.config["optimizer"]) +A3CTrainer = build_trainer( + name="A3C", + default_config=DEFAULT_CONFIG, + default_policy=A3CTFPolicy, + get_policy_class=get_policy_class, + validate_config=validate_config, + make_policy_optimizer=make_async_optimizer) diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 24edbb226e5d..5ea732f17508 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -48,7 +48,7 @@ def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( + self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index 66d3810e5e93..a9676335eb3f 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -171,9 +171,9 @@ def _train(self): if pure_expl_steps: # tell workers whether they should do pure exploration only_explore = self.global_timestep < pure_expl_steps - self.local_evaluator.foreach_trainable_policy( + 
self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.set_pure_exploration_phase(only_explore)) - for e in self.remote_evaluators: + for e in self.workers.remote_workers(): e.foreach_trainable_policy.remote( lambda p, _: p.set_pure_exploration_phase(only_explore)) return super(DDPGTrainer, self)._train() diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy.py b/python/ray/rllib/agents/ddpg/ddpg_policy.py index b80cfce4cdaa..bb5fc25ef8af 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy.py @@ -515,7 +515,7 @@ def make_uniform_random_actions(): stochastic_actions = tf.cond( # need to condition on noise_scale > 0 because zeroing - # noise_scale is how evaluator signals no noise should be used + # noise_scale is how a worker signals no noise should be used # (this is ugly and should be fixed by adding an "eval_mode" # config flag or something) tf.logical_and(enable_pure_exploration, noise_scale > 0), diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 27bde322a946..129839a27119 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -51,7 +51,7 @@ def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( + self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 7fdb6f66b433..15379e3fb394 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -196,26 +196,26 @@ def on_episode_end(info): config["callbacks"]["on_episode_end"] = tune.function( on_episode_end) - self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy) - - def create_remote_evaluators(): - return self.make_remote_evaluators(env_creator, self._policy, - config["num_workers"]) - if config["optimizer_class"] != "AsyncReplayOptimizer": - self.remote_evaluators = create_remote_evaluators() + self.workers = self._make_workers( + env_creator, + self._policy, + config, + num_workers=self.config["num_workers"]) + workers_needed = 0 else: # Hack to workaround https://github.com/ray-project/ray/issues/2541 - self.remote_evaluators = None + self.workers = self._make_workers( + env_creator, self._policy, config, num_workers=0) + workers_needed = self.config["num_workers"] self.optimizer = getattr(optimizers, config["optimizer_class"])( - self.local_evaluator, self.remote_evaluators, - **config["optimizer"]) - # Create the remote evaluators *after* the replay actors - if self.remote_evaluators is None: - self.remote_evaluators = create_remote_evaluators() - self.optimizer._set_evaluators(self.remote_evaluators) + self.workers, **config["optimizer"]) + + # Create the remote workers *after* the replay actors + if workers_needed > 0: + self.workers.add_workers(workers_needed) + self.optimizer._set_workers(self.workers.remote_workers()) self.last_target_update_ts = 0 self.num_target_updates = 0 @@ -226,9 +226,9 @@ def _train(self): # Update worker explorations exp_vals = [self.exploration0.value(self.global_timestep)] - self.local_evaluator.foreach_trainable_policy( + self.workers.local_worker().foreach_trainable_policy( lambda p, _: 
p.set_epsilon(exp_vals[0])) - for i, e in enumerate(self.remote_evaluators): + for i, e in enumerate(self.workers.remote_workers()): exp_val = self.explorations[i].value(self.global_timestep) e.foreach_trainable_policy.remote( lambda p, _: p.set_epsilon(exp_val)) @@ -245,8 +245,8 @@ def _train(self): if self.config["per_worker_exploration"]: # Only collect metrics from the third of workers with lowest eps result = self.collect_metrics( - selected_evaluators=self.remote_evaluators[ - -len(self.remote_evaluators) // 3:]) + selected_workers=self.workers.remote_workers()[ + -len(self.workers.remote_workers()) // 3:]) else: result = self.collect_metrics() @@ -263,7 +263,7 @@ def _train(self): def update_target_if_needed(self): if self.global_timestep - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( + self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) self.last_target_update_ts = self.global_timestep self.num_target_updates += 1 @@ -275,11 +275,13 @@ def global_timestep(self): def _evaluate(self): logger.info("Evaluating current policy for {} episodes".format( self.config["evaluation_num_episodes"])) - self.evaluation_ev.restore(self.local_evaluator.save()) - self.evaluation_ev.foreach_policy(lambda p, _: p.set_epsilon(0)) + self.evaluation_workers.local_worker().restore( + self.workers.local_worker().save()) + self.evaluation_workers.local_worker().foreach_policy( + lambda p, _: p.set_epsilon(0)) for _ in range(self.config["evaluation_num_episodes"]): - self.evaluation_ev.sample() - metrics = collect_metrics(self.evaluation_ev) + self.evaluation_workers.local_worker().sample() + metrics = collect_metrics(self.evaluation_workers.local_worker()) return {"evaluation": metrics} def _make_exploration_schedule(self, worker_index): diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index e167129c6a93..f5338a632e86 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -192,7 +192,7 @@ def _init(self, config, env_creator): # Create the actors. logger.info("Creating actors.") - self.workers = [ + self._workers = [ Worker.remote(config, policy_params, env_creator, noise_id) for _ in range(config["num_workers"]) ] @@ -270,7 +270,7 @@ def _train(self): # Now sync the filters FilterManager.synchronize({ DEFAULT_POLICY_ID: self.policy.get_filter() - }, self.workers) + }, self._workers) info = { "weights_norm": np.square(theta).sum(), @@ -296,7 +296,7 @@ def compute_action(self, observation): @override(Trainer) def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 - for w in self.workers: + for w in self._workers: w.__ray_terminate__.remote() def _collect_results(self, theta_id, min_episodes, min_timesteps): @@ -307,7 +307,7 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps): "Collected {} episodes {} timesteps so far this iter".format( num_episodes, num_timesteps)) rollout_ids = [ - worker.do_rollouts.remote(theta_id) for worker in self.workers + worker.do_rollouts.remote(theta_id) for worker in self._workers ] # Get the results of the rollouts. 
for result in ray_get_and_free(rollout_ids): @@ -334,4 +334,4 @@ def __setstate__(self, state): self.policy.set_filter(state["filter"]) FilterManager.synchronize({ DEFAULT_POLICY_ID: self.policy.get_filter() - }, self.workers) + }, self._workers) diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index 838f2975ce67..e025a4817f8f 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -113,18 +113,16 @@ def _init(self, config, env_creator): if k not in config["optimizer"]: config["optimizer"][k] = config[k] policy_cls = self._get_policy() - self.local_evaluator = self.make_local_evaluator( - self.env_creator, policy_cls) + self.workers = self._make_workers( + self.env_creator, policy_cls, self.config, num_workers=0) if self.config["num_aggregation_workers"] > 0: # Create co-located aggregator actors first for placement pref aggregators = TreeAggregator.precreate_aggregators( self.config["num_aggregation_workers"]) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy_cls, config["num_workers"]) - self.optimizer = AsyncSamplesOptimizer(self.local_evaluator, - self.remote_evaluators, + self.workers.add_workers(config["num_workers"]) + self.optimizer = AsyncSamplesOptimizer(self.workers, **config["optimizer"]) if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index d6c6eadeaa9c..b8e01806ca29 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -48,13 +48,10 @@ class MARWILTrainer(Trainer): @override(Trainer) def _init(self, config, env_creator): - self.local_evaluator = self.make_local_evaluator( - env_creator, self._policy) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, self._policy, config["num_workers"]) + self.workers = self._make_workers(env_creator, self._policy, config, + config["num_workers"]) self.optimizer = SyncBatchReplayOptimizer( - self.local_evaluator, - self.remote_evaluators, + self.workers, learning_starts=config["learning_starts"], buffer_size=config["replay_buffer_size"], train_batch_size=config["train_batch_size"], diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index 299cdcac3de4..71e2ab3fbd69 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -29,7 +29,7 @@ def get_policy_class(config): PGTrainer = build_trainer( - name="PGTrainer", + name="PG", default_config=DEFAULT_CONFIG, default_policy=PGTFPolicy, get_policy_class=get_policy_class) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index daf43d14821d..a21c3d28fd50 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -63,17 +63,15 @@ # yapf: enable -def choose_policy_optimizer(local_evaluator, remote_evaluators, config): +def choose_policy_optimizer(workers, config): if config["simple_optimizer"]: return SyncSamplesOptimizer( - local_evaluator, - remote_evaluators, + workers, num_sgd_iter=config["num_sgd_iter"], train_batch_size=config["train_batch_size"]) return LocalMultiGPUOptimizer( - local_evaluator, - remote_evaluators, + workers, sgd_batch_size=config["sgd_minibatch_size"], num_sgd_iter=config["num_sgd_iter"], num_gpus=config["num_gpus"], @@ -87,7 +85,7 @@ def choose_policy_optimizer(local_evaluator, remote_evaluators, config): def 
update_kl(trainer, fetches): if "kl" in fetches: # single-agent - trainer.local_evaluator.for_policy( + trainer.workers.local_worker().for_policy( lambda pi: pi.update_kl(fetches["kl"])) else: @@ -98,7 +96,7 @@ def update(pi, pi_id): logger.debug("No data for {}, not updating kl".format(pi_id)) # multi-agent - trainer.local_evaluator.foreach_trainable_policy(update) + trainer.workers.local_worker().foreach_trainable_policy(update) def warn_about_obs_filter(trainer): @@ -155,7 +153,7 @@ def validate_config(config): PPOTrainer = build_trainer( - name="PPOTrainer", + name="PPO", default_config=DEFAULT_CONFIG, default_policy=PPOTFPolicy, make_policy_optimizer=choose_policy_optimizer, diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index f43a5ac121eb..65c91d655af2 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -50,7 +50,7 @@ def update_target_if_needed(self): # Ape-X updates based on num steps trained, not sampled if self.optimizer.num_steps_trained - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.foreach_trainable_policy( + self.workers.local_worker().foreach_trainable_policy( lambda p, _: p.update_target()) self.last_target_update_ts = self.optimizer.num_steps_trained self.num_target_updates += 1 diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index fb20f56baa21..f08b23e93fd7 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -10,18 +10,14 @@ import six import time import tempfile -from types import FunctionType import ray from ray.exceptions import RayError -from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput from ray.rllib.models import MODEL_DEFAULTS -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \ - _validate_multiagent_config from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer +from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI from ray.rllib.utils import FilterManager, deep_update, merge_dicts from ray.rllib.utils.memory import ray_get_and_free @@ -46,7 +42,7 @@ # === Debugging === # Whether to write episode stats and videos to the agent log dir "monitor": False, - # Set the ray.rllib.* log level for the agent process and its evaluators. + # Set the ray.rllib.* log level for the agent process and its workers. # Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also # periodically print out summaries of relevant internal dataflow (this is # also printed out once at startup at the INFO level). 
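A hypothetical sketch of driving the new WorkerSet directly; the CartPole env, the worker count and the printed summary are invented, while WorkerSet, A3CTFPolicy, collect_metrics and the local_worker()/remote_workers() accessors are the ones introduced or renamed by this patch:

    import gym
    import ray
    from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG
    from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
    from ray.rllib.evaluation.metrics import collect_metrics
    from ray.rllib.evaluation.worker_set import WorkerSet

    ray.init()
    workers = WorkerSet(
        lambda env_config: gym.make("CartPole-v0"),  # env_creator
        A3CTFPolicy,                                 # policy class
        DEFAULT_CONFIG,                              # trainer config
        num_workers=2)

    # Sample in parallel on the remote workers.
    batches = ray.get([w.sample.remote() for w in workers.remote_workers()])
    print("sampled", sum(b.count for b in batches), "env steps")

    # Broadcast the local policy weights to every remote replica.
    weights = ray.put(workers.local_worker().get_weights())
    for w in workers.remote_workers():
        w.set_weights.remote(weights)

    # Gather rollout metrics from the local and remote workers.
    print(collect_metrics(workers.local_worker(), workers.remote_workers()))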
@@ -60,7 +56,7 @@ "on_episode_start": None, # arg: {"env": .., "episode": ...} "on_episode_step": None, # arg: {"env": .., "episode": ...} "on_episode_end": None, # arg: {"env": .., "episode": ...} - "on_sample_end": None, # arg: {"samples": .., "evaluator": ...} + "on_sample_end": None, # arg: {"samples": .., "worker": ...} "on_train_result": None, # arg: {"trainer": ..., "result": ...} "on_postprocess_traj": None, # arg: { # "agent_id": ..., "episode": ..., @@ -153,7 +149,7 @@ "synchronize_filters": True, # Configure TF for single-process operation by default "tf_session_args": { - # note: overriden by `local_evaluator_tf_session_args` + # note: overriden by `local_tf_session_args` "intra_op_parallelism_threads": 2, "inter_op_parallelism_threads": 2, "gpu_options": { @@ -165,8 +161,8 @@ }, "allow_soft_placement": True, # required by PPO multi-gpu }, - # Override the following tf session args on the local evaluator - "local_evaluator_tf_session_args": { + # Override the following tf session args on the local worker + "local_tf_session_args": { # Allow a higher level of parallelism by default, but not unlimited # since that can cause crashes with many concurrent drivers. "intra_op_parallelism_threads": 8, @@ -188,6 +184,8 @@ # but optimal value could be obtained by measuring your environment # step / reset and model inference perf. "remote_env_batch_wait_ms": 0, + # Minimum time per iteration + "min_iter_time_s": 0, # === Offline Datasets === # Specify how to generate experiences: @@ -229,7 +227,7 @@ # === Multiagent === "multiagent": { # Map from policy ids to tuples of (policy_cls, obs_space, - # act_space, config). See policy_evaluator.py for more info. + # act_space, config). See rollout_worker.py for more info. "policies": {}, # Function mapping agent ids to policy ids. "policy_mapping_fn": None, @@ -292,7 +290,7 @@ def __init__(self, config=None, env=None, logger_creator=None): config = config or {} - # Vars to synchronize to evaluators on each train call + # Vars to synchronize to workers on each train call self.global_vars = {"timestep": 0} # Trainers allow env ids to be passed directly to the constructor. 
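A hypothetical on_sample_end callback wired to the renamed "worker" argument above; the printout and the PG trainer used to exercise it are only for illustration:

    import ray
    from ray import tune
    from ray.rllib.agents.pg import PGTrainer


    def on_sample_end(info):
        # info["worker"] is the RolloutWorker that produced info["samples"]
        # (this key used to be called "evaluator").
        print("sampled a batch of", info["samples"].count, "steps")


    ray.init()
    trainer = PGTrainer(
        env="CartPole-v0",
        config={"callbacks": {"on_sample_end": tune.function(on_sample_end)}})
    trainer.train()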
@@ -337,9 +335,10 @@ def train(self): if self._has_policy_optimizer(): self.global_vars["timestep"] = self.optimizer.num_steps_sampled - self.optimizer.local_evaluator.set_global_vars(self.global_vars) - for ev in self.optimizer.remote_evaluators: - ev.set_global_vars.remote(self.global_vars) + self.optimizer.workers.local_worker().set_global_vars( + self.global_vars) + for w in self.optimizer.workers.remote_workers(): + w.set_global_vars.remote(self.global_vars) logger.debug("updated global vars: {}".format(self.global_vars)) result = None @@ -366,17 +365,18 @@ def train(self): raise RuntimeError("Failed to recover from worker crash") if (self.config.get("observation_filter", "NoFilter") != "NoFilter" - and hasattr(self, "local_evaluator")): + and hasattr(self, "workers") + and isinstance(self.workers, WorkerSet)): FilterManager.synchronize( - self.local_evaluator.filters, - self.remote_evaluators, + self.workers.local_worker().filters, + self.workers.remote_workers(), update_remote=self.config["synchronize_filters"]) logger.debug("synchronized filters: {}".format( - self.local_evaluator.filters)) + self.workers.local_worker().filters)) if self._has_policy_optimizer(): result["num_healthy_workers"] = len( - self.optimizer.remote_evaluators) + self.optimizer.workers.remote_workers()) if self.config["evaluation_interval"]: if self._iteration % self.config["evaluation_interval"] == 0: @@ -441,25 +441,17 @@ def get_scope(): }) logger.debug( "using evaluation_config: {}".format(extra_config)) - # Make local evaluation evaluators - self.evaluation_ev = self.make_local_evaluator( - self.env_creator, self._policy, extra_config=extra_config) + self.evaluation_workers = self._make_workers( + self.env_creator, + self._policy, + merge_dicts(self.config, extra_config), + num_workers=0) self.evaluation_metrics = self._evaluate() @override(Trainable) def _stop(self): - # Call stop on all evaluators to release resources - if hasattr(self, "local_evaluator"): - self.local_evaluator.stop() - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.stop.remote() - - # workaround for https://github.com/ray-project/ray/issues/1516 - if hasattr(self, "remote_evaluators"): - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - + if hasattr(self, "workers"): + self.workers.stop() if hasattr(self, "optimizer"): self.optimizer.stop() @@ -475,6 +467,15 @@ def _restore(self, checkpoint_path): extra_data = pickle.load(open(checkpoint_path, "rb")) self.__setstate__(extra_data) + @DeveloperAPI + def _make_workers(self, env_creator, policy, config, num_workers): + return WorkerSet( + env_creator, + policy, + config, + num_workers=num_workers, + logdir=self.logdir) + @DeveloperAPI def _init(self, config, env_creator): """Subclasses should override this for custom initialization.""" @@ -498,11 +499,12 @@ def _evaluate(self): logger.info("Evaluating current policy for {} episodes".format( self.config["evaluation_num_episodes"])) - self.evaluation_ev.restore(self.local_evaluator.save()) + self.evaluation_workers.local_worker().restore( + self.workers.local_worker().save()) for _ in range(self.config["evaluation_num_episodes"]): - self.evaluation_ev.sample() + self.evaluation_workers.local_worker().sample() - metrics = collect_metrics(self.evaluation_ev) + metrics = collect_metrics(self.evaluation_workers.local_worker()) return {"evaluation": metrics} @PublicAPI @@ -540,9 +542,9 @@ def compute_action(self, if state is None: state = [] - preprocessed = 
self.local_evaluator.preprocessors[policy_id].transform( - observation) - filtered_obs = self.local_evaluator.filters[policy_id]( + preprocessed = self.workers.local_worker().preprocessors[ + policy_id].transform(observation) + filtered_obs = self.workers.local_worker().filters[policy_id]( preprocessed, update=False) if state: return self.get_policy(policy_id).compute_single_action( @@ -590,7 +592,7 @@ def get_policy(self, policy_id=DEFAULT_POLICY_ID): policy_id (str): id of policy to return. """ - return self.local_evaluator.get_policy(policy_id) + return self.workers.local_worker().get_policy(policy_id) @PublicAPI def get_weights(self, policies=None): @@ -600,7 +602,7 @@ def get_weights(self, policies=None): policies (list): Optional list of policies to return weights for, or None for all policies. """ - return self.local_evaluator.get_weights(policies) + return self.workers.local_worker().get_weights(policies) @PublicAPI def set_weights(self, weights): @@ -609,42 +611,7 @@ def set_weights(self, weights): Arguments: weights (dict): Map of policy ids to weights to set. """ - self.local_evaluator.set_weights(weights) - - @DeveloperAPI - def make_local_evaluator(self, env_creator, policy, extra_config=None): - """Convenience method to return configured local evaluator.""" - - return self._make_evaluator( - PolicyEvaluator, - env_creator, - policy, - 0, - merge_dicts( - # important: allow local tf to use more CPUs for optimization - merge_dicts( - self.config, { - "tf_session_args": self. - config["local_evaluator_tf_session_args"] - }), - extra_config or {})) - - @DeveloperAPI - def make_remote_evaluators(self, env_creator, policy, count): - """Convenience method to return a number of remote evaluators.""" - - remote_args = { - "num_cpus": self.config["num_cpus_per_worker"], - "num_gpus": self.config["num_gpus_per_worker"], - "resources": self.config["custom_resources_per_worker"], - } - - cls = PolicyEvaluator.as_remote(**remote_args).remote - - return [ - self._make_evaluator(cls, env_creator, policy, i + 1, self.config) - for i in range(count) - ] + self.workers.local_worker().set_weights(weights) @DeveloperAPI def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): @@ -660,7 +627,7 @@ def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): >>> trainer.train() >>> trainer.export_policy_model("/tmp/export_dir") """ - self.local_evaluator.export_policy_model(export_dir, policy_id) + self.workers.local_worker().export_policy_model(export_dir, policy_id) @DeveloperAPI def export_policy_checkpoint(self, @@ -680,19 +647,19 @@ def export_policy_checkpoint(self, >>> trainer.train() >>> trainer.export_policy_checkpoint("/tmp/export_dir") """ - self.local_evaluator.export_policy_checkpoint( + self.workers.local_worker().export_policy_checkpoint( export_dir, filename_prefix, policy_id) @DeveloperAPI - def collect_metrics(self, selected_evaluators=None): - """Collects metrics from the remote evaluators of this agent. + def collect_metrics(self, selected_workers=None): + """Collects metrics from the remote workers of this agent. This is the same data as returned by a call to train(). 
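A short, hypothetical session against the public Trainer methods touched above, all of which now resolve through workers.local_worker(); the PPO trainer, CartPole env and single train() step are illustrative only:

    import gym
    import ray
    from ray.rllib.agents.ppo import PPOTrainer

    ray.init()
    trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 1})
    trainer.train()

    # Weights round-trip through the local worker.
    trainer.set_weights(trainer.get_weights())

    # Single observation -> action, run through the local worker's
    # preprocessor and observation filter.
    obs = gym.make("CartPole-v0").reset()
    print(trainer.compute_action(obs))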
""" return self.optimizer.collect_metrics( self.config["collect_metrics_timeout"], min_history=self.config["metrics_smoothing_episodes"], - selected_evaluators=selected_evaluators) + selected_workers=selected_workers) @classmethod def resource_help(cls, config): @@ -742,118 +709,34 @@ def _try_recover(self): logger.info("Health checking all workers...") checks = [] - for ev in self.optimizer.remote_evaluators: + for ev in self.optimizer.workers.remote_workers(): _, obj_id = ev.sample_with_count.remote() checks.append(obj_id) - healthy_evaluators = [] + healthy_workers = [] for i, obj_id in enumerate(checks): - ev = self.optimizer.remote_evaluators[i] + w = self.optimizer.workers.remote_workers()[i] try: ray_get_and_free(obj_id) - healthy_evaluators.append(ev) + healthy_workers.append(w) logger.info("Worker {} looks healthy".format(i + 1)) except RayError: logger.exception("Blacklisting worker {}".format(i + 1)) try: - ev.__ray_terminate__.remote() + w.__ray_terminate__.remote() except Exception: logger.exception("Error terminating unhealthy worker") - if len(healthy_evaluators) < 1: + if len(healthy_workers) < 1: raise RuntimeError( "Not enough healthy workers remain to continue.") - self.optimizer.reset(healthy_evaluators) + self.optimizer.reset(healthy_workers) def _has_policy_optimizer(self): return hasattr(self, "optimizer") and isinstance( self.optimizer, PolicyOptimizer) - def _make_evaluator(self, cls, env_creator, policy, worker_index, config): - def session_creator(): - logger.debug("Creating TF session {}".format( - config["tf_session_args"])) - return tf.Session( - config=tf.ConfigProto(**config["tf_session_args"])) - - if isinstance(config["input"], FunctionType): - input_creator = config["input"] - elif config["input"] == "sampler": - input_creator = (lambda ioctx: ioctx.default_sampler_input()) - elif isinstance(config["input"], dict): - input_creator = (lambda ioctx: ShuffledInput( - MixedInput(config["input"], ioctx), config[ - "shuffle_buffer_size"])) - else: - input_creator = (lambda ioctx: ShuffledInput( - JsonReader(config["input"], ioctx), config[ - "shuffle_buffer_size"])) - - if isinstance(config["output"], FunctionType): - output_creator = config["output"] - elif config["output"] is None: - output_creator = (lambda ioctx: NoopOutput()) - elif config["output"] == "logdir": - output_creator = (lambda ioctx: JsonWriter( - ioctx.log_dir, - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"])) - else: - output_creator = (lambda ioctx: JsonWriter( - config["output"], - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"])) - - if config["input"] == "sampler": - input_evaluation = [] - else: - input_evaluation = config["input_evaluation"] - - # Fill in the default policy if 'None' is specified in multiagent - if self.config["multiagent"]["policies"]: - tmp = self.config["multiagent"]["policies"] - _validate_multiagent_config(tmp, allow_none_graph=True) - for k, v in tmp.items(): - if v[0] is None: - tmp[k] = (policy, v[1], v[2], v[3]) - policy = tmp - - return cls( - env_creator, - policy, - policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], - policies_to_train=self.config["multiagent"]["policies_to_train"], - tf_session_creator=(session_creator - if config["tf_session_args"] else None), - batch_steps=config["sample_batch_size"], - batch_mode=config["batch_mode"], - episode_horizon=config["horizon"], - preprocessor_pref=config["preprocessor_pref"], 
- sample_async=config["sample_async"], - compress_observations=config["compress_observations"], - num_envs=config["num_envs_per_worker"], - observation_filter=config["observation_filter"], - clip_rewards=config["clip_rewards"], - clip_actions=config["clip_actions"], - env_config=config["env_config"], - model_config=config["model"], - policy_config=config, - worker_index=worker_index, - monitor_path=self.logdir if config["monitor"] else None, - log_dir=self.logdir, - log_level=config["log_level"], - callbacks=config["callbacks"], - input_creator=input_creator, - input_evaluation=input_evaluation, - output_creator=output_creator, - remote_worker_envs=config["remote_worker_envs"], - remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"], - soft_horizon=config["soft_horizon"], - _fake_sampler=config.get("_fake_sampler", False)) - @override(Trainable) def _export_model(self, export_formats, export_dir): ExportFormat.validate(export_formats) @@ -870,17 +753,17 @@ def _export_model(self, export_formats, export_dir): def __getstate__(self): state = {} - if hasattr(self, "local_evaluator"): - state["evaluator"] = self.local_evaluator.save() + if hasattr(self, "workers"): + state["worker"] = self.workers.local_worker().save() if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"): state["optimizer"] = self.optimizer.save() return state def __setstate__(self, state): - if "evaluator" in state: - self.local_evaluator.restore(state["evaluator"]) - remote_state = ray.put(state["evaluator"]) - for r in self.remote_evaluators: + if "worker" in state: + self.workers.local_worker().restore(state["worker"]) + remote_state = ray.put(state["worker"]) + for r in self.workers.remote_workers(): r.restore.remote(remote_state) if "optimizer" in state: self.optimizer.restore(state["optimizer"]) diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index aae8e35f64f8..6af9e1c781e0 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import time + from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer from ray.rllib.utils.annotations import override, DeveloperAPI @@ -25,8 +27,7 @@ def build_trainer(name, default_config (dict): the default config dict of the algorithm, otherwises uses the Trainer default config make_policy_optimizer (func): optional function that returns a - PolicyOptimizer instance given - (local_evaluator, remote_evaluators, config) + PolicyOptimizer instance given (WorkerSet, config) validate_config (func): optional callback that checks a given config for correctness. It may mutate the config as needed. get_policy_class (func): optional callback that takes a config and @@ -44,8 +45,7 @@ def build_trainer(name, a Trainer instance that uses the specified args. 
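To make the template concrete, a hypothetical trainer assembled from pieces that appear verbatim elsewhere in this patch (build_trainer, A3CTFPolicy, SyncSamplesOptimizer, merge_dicts); the names MySyncA3C, MY_CONFIG and make_sync_optimizer are invented:

    from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG as A3C_CONFIG
    from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
    from ray.rllib.agents.trainer_template import build_trainer
    from ray.rllib.optimizers import SyncSamplesOptimizer
    from ray.rllib.utils import merge_dicts

    MY_CONFIG = merge_dicts(A3C_CONFIG, {"min_iter_time_s": 1})


    def make_sync_optimizer(workers, config):
        # make_policy_optimizer now receives the WorkerSet plus the config,
        # rather than separate local/remote evaluator handles.
        return SyncSamplesOptimizer(
            workers, train_batch_size=config["train_batch_size"])


    MySyncA3C = build_trainer(
        name="MySyncA3C",
        default_config=MY_CONFIG,
        default_policy=A3CTFPolicy,
        make_policy_optimizer=make_sync_optimizer)

    # Variants can be derived without repeating all of the arguments.
    MyLongIterA3C = MySyncA3C.with_updates(
        default_config=merge_dicts(MY_CONFIG, {"min_iter_time_s": 10}))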
""" - if not name.endswith("Trainer"): - raise ValueError("Algorithm name should have *Trainer suffix", name) + original_kwargs = locals().copy() class trainer_cls(Trainer): _name = name @@ -59,19 +59,15 @@ def _init(self, config, env_creator): policy = default_policy else: policy = get_policy_class(config) - self.local_evaluator = self.make_local_evaluator( - env_creator, policy) - self.remote_evaluators = self.make_remote_evaluators( - env_creator, policy, config["num_workers"]) + self.workers = self._make_workers(env_creator, policy, config, + self.config["num_workers"]) if make_policy_optimizer: - self.optimizer = make_policy_optimizer( - self.local_evaluator, self.remote_evaluators, config) + self.optimizer = make_policy_optimizer(self.workers, config) else: optimizer_config = dict( config["optimizer"], **{"train_batch_size": config["train_batch_size"]}) - self.optimizer = SyncSamplesOptimizer(self.local_evaluator, - self.remote_evaluators, + self.optimizer = SyncSamplesOptimizer(self.workers, **optimizer_config) @override(Trainer) @@ -79,9 +75,15 @@ def _train(self): if before_train_step: before_train_step(self) prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - if after_optimizer_step: - after_optimizer_step(self, fetches) + + start = time.time() + while True: + fetches = self.optimizer.step() + if after_optimizer_step: + after_optimizer_step(self, fetches) + if time.time() - start > self.config["min_iter_time_s"]: + break + res = self.collect_metrics() res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - @@ -91,6 +93,11 @@ def _train(self): after_train_result(self, res) return res + @staticmethod + def with_updates(**overrides): + return build_trainer(**dict(original_kwargs, **overrides)) + + trainer_cls.with_updates = with_updates trainer_cls.__name__ = name trainer_cls.__qualname__ = name return trainer_cls diff --git a/python/ray/rllib/env/base_env.py b/python/ray/rllib/env/base_env.py index 5db799c3282d..a36c3e228e66 100644 --- a/python/ray/rllib/env/base_env.py +++ b/python/ray/rllib/env/base_env.py @@ -21,7 +21,7 @@ class BaseEnv(object): can be sent back via send_actions(). All other env types can be adapted to BaseEnv. 
RLlib handles these - conversions internally in PolicyEvaluator, for example: + conversions internally in RolloutWorker, for example: gym.Env => rllib.VectorEnv => rllib.BaseEnv rllib.MultiAgentEnv => rllib.BaseEnv diff --git a/python/ray/rllib/evaluation/__init__.py b/python/ray/rllib/evaluation/__init__.py index 7e56bb7479a0..f743cca64772 100644 --- a/python/ray/rllib/evaluation/__init__.py +++ b/python/ray/rllib/evaluation/__init__.py @@ -1,4 +1,5 @@ from ray.rllib.evaluation.episode import MultiAgentEpisode +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.evaluation.interface import EvaluatorInterface from ray.rllib.evaluation.policy_graph import PolicyGraph @@ -12,8 +13,19 @@ from ray.rllib.evaluation.metrics import collect_metrics __all__ = [ - "EvaluatorInterface", "PolicyEvaluator", "PolicyGraph", "TFPolicyGraph", - "TorchPolicyGraph", "SampleBatch", "MultiAgentBatch", "SampleBatchBuilder", - "MultiAgentSampleBatchBuilder", "SyncSampler", "AsyncSampler", - "compute_advantages", "collect_metrics", "MultiAgentEpisode" + "EvaluatorInterface", + "RolloutWorker", + "PolicyGraph", + "TFPolicyGraph", + "TorchPolicyGraph", + "SampleBatch", + "MultiAgentBatch", + "SampleBatchBuilder", + "MultiAgentSampleBatchBuilder", + "SyncSampler", + "AsyncSampler", + "compute_advantages", + "collect_metrics", + "MultiAgentEpisode", + "PolicyEvaluator", ] diff --git a/python/ray/rllib/evaluation/interface.py b/python/ray/rllib/evaluation/interface.py index 6bc626da1175..06fa9f94ec97 100644 --- a/python/ray/rllib/evaluation/interface.py +++ b/python/ray/rllib/evaluation/interface.py @@ -11,7 +11,7 @@ class EvaluatorInterface(object): """This is the interface between policy optimizers and policy evaluation. 
- See also: PolicyEvaluator + See also: RolloutWorker """ @DeveloperAPI diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index d8b3122fed4b..341327608db3 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -39,27 +39,23 @@ def get_learner_stats(grad_info): @DeveloperAPI -def collect_metrics(local_evaluator=None, - remote_evaluators=[], - timeout_seconds=180): - """Gathers episode metrics from PolicyEvaluator instances.""" +def collect_metrics(local_worker=None, remote_workers=[], timeout_seconds=180): + """Gathers episode metrics from RolloutWorker instances.""" episodes, num_dropped = collect_episodes( - local_evaluator, remote_evaluators, timeout_seconds=timeout_seconds) + local_worker, remote_workers, timeout_seconds=timeout_seconds) metrics = summarize_episodes(episodes, episodes, num_dropped) return metrics @DeveloperAPI -def collect_episodes(local_evaluator=None, - remote_evaluators=[], +def collect_episodes(local_worker=None, remote_workers=[], timeout_seconds=180): """Gathers new episodes metrics tuples from the given evaluators.""" - if remote_evaluators: + if remote_workers: pending = [ - a.apply.remote(lambda ev: ev.get_metrics()) - for a in remote_evaluators + a.apply.remote(lambda ev: ev.get_metrics()) for a in remote_workers ] collected, _ = ray.wait( pending, num_returns=len(pending), timeout=timeout_seconds * 1.0) @@ -73,8 +69,8 @@ def collect_episodes(local_evaluator=None, metric_lists = [] num_metric_batches_dropped = 0 - if local_evaluator: - metric_lists.append(local_evaluator.get_metrics()) + if local_worker: + metric_lists.append(local_worker.get_metrics()) episodes = [] for metrics in metric_lists: episodes.extend(metrics) diff --git a/python/ray/rllib/evaluation/policy_evaluator.py b/python/ray/rllib/evaluation/policy_evaluator.py index 40df71006a8c..18dec8abc80b 100644 --- a/python/ray/rllib/evaluation/policy_evaluator.py +++ b/python/ray/rllib/evaluation/policy_evaluator.py @@ -2,805 +2,8 @@ from __future__ import division from __future__ import print_function -import gym -import logging -import pickle +from ray.rllib.utils import renamed_class +from ray.rllib.evaluation import RolloutWorker -import ray -from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari -from ray.rllib.env.base_env import BaseEnv -from ray.rllib.env.env_context import EnvContext -from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv -from ray.rllib.env.vector_env import VectorEnv -from ray.rllib.evaluation.interface import EvaluatorInterface -from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler -from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.tf_policy import TFPolicy -from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader -from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator -from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator -from ray.rllib.models import ModelCatalog -from ray.rllib.models.preprocessors import NoPreprocessor -from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.debug import disable_log_once_globally, log_once, \ - summarize, enable_periodic_logging -from ray.rllib.utils.filter import get_filter 
-from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf - -tf = try_import_tf() -logger = logging.getLogger(__name__) - -# Handle to the current evaluator, which will be set to the most recently -# created PolicyEvaluator in this process. This can be helpful to access in -# custom env or policy classes for debugging or advanced use cases. -_global_evaluator = None - - -@DeveloperAPI -def get_global_evaluator(): - """Returns a handle to the active policy evaluator in this process.""" - - global _global_evaluator - return _global_evaluator - - -@DeveloperAPI -class PolicyEvaluator(EvaluatorInterface): - """Common ``PolicyEvaluator`` implementation that wraps a ``Policy``. - - This class wraps a policy instance and an environment class to - collect experiences from the environment. You can create many replicas of - this class as Ray actors to scale RL training. - - This class supports vectorized and multi-agent policy evaluation (e.g., - VectorEnv, MultiAgentEnv, etc.) - - Examples: - >>> # Create a policy evaluator and using it to collect experiences. - >>> evaluator = PolicyEvaluator( - ... env_creator=lambda _: gym.make("CartPole-v0"), - ... policy=PGTFPolicy) - >>> print(evaluator.sample()) - SampleBatch({ - "obs": [[...]], "actions": [[...]], "rewards": [[...]], - "dones": [[...]], "new_obs": [[...]]}) - - >>> # Creating policy evaluators using optimizer_cls.make(). - >>> optimizer = SyncSamplesOptimizer.make( - ... evaluator_cls=PolicyEvaluator, - ... evaluator_args={ - ... "env_creator": lambda _: gym.make("CartPole-v0"), - ... "policy": PGTFPolicy, - ... }, - ... num_workers=10) - >>> for _ in range(10): optimizer.step() - - >>> # Creating a multi-agent policy evaluator - >>> evaluator = PolicyEvaluator( - ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), - ... policies={ - ... # Use an ensemble of two policies for car agents - ... "car_policy1": - ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), - ... "car_policy2": - ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}), - ... # Use a single shared policy for all traffic lights - ... "traffic_light_policy": - ... (PGTFPolicy, Box(...), Discrete(...), {}), - ... }, - ... policy_mapping_fn=lambda agent_id: - ... random.choice(["car_policy1", "car_policy2"]) - ... 
if agent_id.startswith("car_") else "traffic_light_policy") - >>> print(evaluator.sample()) - MultiAgentBatch({ - "car_policy1": SampleBatch(...), - "car_policy2": SampleBatch(...), - "traffic_light_policy": SampleBatch(...)}) - """ - - @DeveloperAPI - @classmethod - def as_remote(cls, num_cpus=None, num_gpus=None, resources=None): - return ray.remote( - num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls) - - @DeveloperAPI - def __init__(self, - env_creator, - policy, - policy_mapping_fn=None, - policies_to_train=None, - tf_session_creator=None, - batch_steps=100, - batch_mode="truncate_episodes", - episode_horizon=None, - preprocessor_pref="deepmind", - sample_async=False, - compress_observations=False, - num_envs=1, - observation_filter="NoFilter", - clip_rewards=None, - clip_actions=True, - env_config=None, - model_config=None, - policy_config=None, - worker_index=0, - monitor_path=None, - log_dir=None, - log_level=None, - callbacks=None, - input_creator=lambda ioctx: ioctx.default_sampler_input(), - input_evaluation=frozenset([]), - output_creator=lambda ioctx: NoopOutput(), - remote_worker_envs=False, - remote_env_batch_wait_ms=0, - soft_horizon=False, - _fake_sampler=False): - """Initialize a policy evaluator. - - Arguments: - env_creator (func): Function that returns a gym.Env given an - EnvContext wrapped configuration. - policy (class|dict): Either a class implementing - Policy, or a dictionary of policy id strings to - (Policy, obs_space, action_space, config) tuples. If a - dict is specified, then we are in multi-agent mode and a - policy_mapping_fn should also be set. - policy_mapping_fn (func): A function that maps agent ids to - policy ids in multi-agent mode. This function will be called - each time a new agent appears in an episode, to bind that agent - to a policy for the duration of the episode. - policies_to_train (list): Optional whitelist of policies to train, - or None for all policies. - tf_session_creator (func): A function that returns a TF session. - This is optional and only useful with TFPolicy. - batch_steps (int): The target number of env transitions to include - in each sample batch returned from this evaluator. - batch_mode (str): One of the following batch modes: - "truncate_episodes": Each call to sample() will return a batch - of at most `batch_steps * num_envs` in size. The batch will - be exactly `batch_steps * num_envs` in size if - postprocessing does not change batch sizes. Episodes may be - truncated in order to meet this size requirement. - "complete_episodes": Each call to sample() will return a batch - of at least `batch_steps * num_envs` in size. Episodes will - not be truncated, but multiple episodes may be packed - within one batch to meet the batch size. Note that when - `num_envs > 1`, episode steps will be buffered until the - episode completes, and hence batches may contain - significant amounts of off-policy data. - episode_horizon (int): Whether to stop episodes at this horizon. - preprocessor_pref (str): Whether to prefer RLlib preprocessors - ("rllib") or deepmind ("deepmind") when applicable. - sample_async (bool): Whether to compute samples asynchronously in - the background, which improves throughput but can cause samples - to be slightly off-policy. - compress_observations (bool): If true, compress the observations. - They can be decompressed with rllib/utils/compression. - num_envs (int): If more than one, will create multiple envs - and vectorize the computation of actions. 
This has no effect if - if the env already implements VectorEnv. - observation_filter (str): Name of observation filter to use. - clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to - experience postprocessing. Setting to None means clip for Atari - only. - clip_actions (bool): Whether to clip action values to the range - specified by the policy action space. - env_config (dict): Config to pass to the env creator. - model_config (dict): Config to use when creating the policy model. - policy_config (dict): Config to pass to the policy. In the - multi-agent case, this config will be merged with the - per-policy configs specified by `policy`. - worker_index (int): For remote evaluators, this should be set to a - non-zero and unique value. This index is passed to created envs - through EnvContext so that envs can be configured per worker. - monitor_path (str): Write out episode stats and videos to this - directory if specified. - log_dir (str): Directory where logs can be placed. - log_level (str): Set the root log level on creation. - callbacks (dict): Dict of custom debug callbacks. - input_creator (func): Function that returns an InputReader object - for loading previous generated experiences. - input_evaluation (list): How to evaluate the policy performance. - This only makes sense to set when the input is reading offline - data. The possible values include: - - "is": the step-wise importance sampling estimator. - - "wis": the weighted step-wise is estimator. - - "simulation": run the environment in the background, but - use this data for evaluation only and never for learning. - output_creator (func): Function that returns an OutputWriter object - for saving generated experiences. - remote_worker_envs (bool): If using num_envs > 1, whether to create - those new envs in remote processes instead of in the current - process. This adds overheads, but can make sense if your envs - remote_env_batch_wait_ms (float): Timeout that remote workers - are waiting when polling environments. 0 (continue when at - least one env is ready) is a reasonable default, but optimal - value could be obtained by measuring your environment - step / reset and model inference perf. - soft_horizon (bool): Calculate rewards but don't reset the - environment when the horizon is hit. - _fake_sampler (bool): Use a fake (inf speed) sampler for testing. - """ - - global _global_evaluator - _global_evaluator = self - - if log_level: - logging.getLogger("ray.rllib").setLevel(log_level) - - if worker_index > 1: - disable_log_once_globally() # only need 1 evaluator to log - elif log_level == "DEBUG": - enable_periodic_logging() - - env_context = EnvContext(env_config or {}, worker_index) - policy_config = policy_config or {} - self.policy_config = policy_config - self.callbacks = callbacks or {} - self.worker_index = worker_index - model_config = model_config or {} - policy_mapping_fn = (policy_mapping_fn - or (lambda agent_id: DEFAULT_POLICY_ID)) - if not callable(policy_mapping_fn): - raise ValueError( - "Policy mapping function not callable. 
If you're using Tune, " - "make sure to escape the function with tune.function() " - "to prevent it from being evaluated as an expression.") - self.env_creator = env_creator - self.sample_batch_size = batch_steps * num_envs - self.batch_mode = batch_mode - self.compress_observations = compress_observations - self.preprocessing_enabled = True - self.last_batch = None - self._fake_sampler = _fake_sampler - - self.env = _validate_env(env_creator(env_context)) - if isinstance(self.env, MultiAgentEnv) or \ - isinstance(self.env, BaseEnv): - - def wrap(env): - return env # we can't auto-wrap these env types - elif is_atari(self.env) and \ - not model_config.get("custom_preprocessor") and \ - preprocessor_pref == "deepmind": - - # Deepmind wrappers already handle all preprocessing - self.preprocessing_enabled = False - - if clip_rewards is None: - clip_rewards = True - - def wrap(env): - env = wrap_deepmind( - env, - dim=model_config.get("dim"), - framestack=model_config.get("framestack")) - if monitor_path: - env = _monitor(env, monitor_path) - return env - else: - - def wrap(env): - if monitor_path: - env = _monitor(env, monitor_path) - return env - - self.env = wrap(self.env) - - def make_env(vector_index): - return wrap( - env_creator( - env_context.copy_with_overrides( - vector_index=vector_index, remote=remote_worker_envs))) - - self.tf_sess = None - policy_dict = _validate_and_canonicalize(policy, self.env) - self.policies_to_train = policies_to_train or list(policy_dict.keys()) - if _has_tensorflow_graph(policy_dict): - if (ray.is_initialized() - and ray.worker._mode() != ray.worker.LOCAL_MODE - and not ray.get_gpu_ids()): - logger.info("Creating policy evaluation worker {}".format( - worker_index) + - " on CPU (please ignore any CUDA init errors)") - with tf.Graph().as_default(): - if tf_session_creator: - self.tf_sess = tf_session_creator() - else: - self.tf_sess = tf.Session( - config=tf.ConfigProto( - gpu_options=tf.GPUOptions(allow_growth=True))) - with self.tf_sess.as_default(): - self.policy_map, self.preprocessors = \ - self._build_policy_map(policy_dict, policy_config) - else: - self.policy_map, self.preprocessors = self._build_policy_map( - policy_dict, policy_config) - - self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID} - if self.multiagent: - if not ((isinstance(self.env, MultiAgentEnv) - or isinstance(self.env, ExternalMultiAgentEnv)) - or isinstance(self.env, BaseEnv)): - raise ValueError( - "Have multiple policies {}, but the env ".format( - self.policy_map) + - "{} is not a subclass of BaseEnv, MultiAgentEnv or " - "ExternalMultiAgentEnv?".format(self.env)) - - self.filters = { - policy_id: get_filter(observation_filter, - policy.observation_space.shape) - for (policy_id, policy) in self.policy_map.items() - } - if self.worker_index == 0: - logger.info("Built filter map: {}".format(self.filters)) - - # Always use vector env for consistency even if num_envs = 1 - self.async_env = BaseEnv.to_base_env( - self.env, - make_env=make_env, - num_envs=num_envs, - remote_envs=remote_worker_envs, - remote_env_batch_wait_ms=remote_env_batch_wait_ms) - self.num_envs = num_envs - - if self.batch_mode == "truncate_episodes": - unroll_length = batch_steps - pack_episodes = True - elif self.batch_mode == "complete_episodes": - unroll_length = float("inf") # never cut episodes - pack_episodes = False # sampler will return 1 episode per poll - else: - raise ValueError("Unsupported batch mode: {}".format( - self.batch_mode)) - - self.io_context = IOContext(log_dir, 
policy_config, worker_index, self) - self.reward_estimators = [] - for method in input_evaluation: - if method == "simulation": - logger.warning( - "Requested 'simulation' input evaluation method: " - "will discard all sampler outputs and keep only metrics.") - sample_async = True - elif method == "is": - ise = ImportanceSamplingEstimator.create(self.io_context) - self.reward_estimators.append(ise) - elif method == "wis": - wise = WeightedImportanceSamplingEstimator.create( - self.io_context) - self.reward_estimators.append(wise) - else: - raise ValueError( - "Unknown evaluation method: {}".format(method)) - - if sample_async: - self.sampler = AsyncSampler( - self.async_env, - self.policy_map, - policy_mapping_fn, - self.preprocessors, - self.filters, - clip_rewards, - unroll_length, - self.callbacks, - horizon=episode_horizon, - pack=pack_episodes, - tf_sess=self.tf_sess, - clip_actions=clip_actions, - blackhole_outputs="simulation" in input_evaluation, - soft_horizon=soft_horizon) - self.sampler.start() - else: - self.sampler = SyncSampler( - self.async_env, - self.policy_map, - policy_mapping_fn, - self.preprocessors, - self.filters, - clip_rewards, - unroll_length, - self.callbacks, - horizon=episode_horizon, - pack=pack_episodes, - tf_sess=self.tf_sess, - clip_actions=clip_actions, - soft_horizon=soft_horizon) - - self.input_reader = input_creator(self.io_context) - assert isinstance(self.input_reader, InputReader), self.input_reader - self.output_writer = output_creator(self.io_context) - assert isinstance(self.output_writer, OutputWriter), self.output_writer - - logger.debug("Created evaluator with env {} ({}), policies {}".format( - self.async_env, self.env, self.policy_map)) - - @override(EvaluatorInterface) - def sample(self): - """Evaluate the current policies and return a batch of experiences. - - Return: - SampleBatch|MultiAgentBatch from evaluating the current policies. - """ - - if self._fake_sampler and self.last_batch is not None: - return self.last_batch - - if log_once("sample_start"): - logger.info("Generating sample batch of size {}".format( - self.sample_batch_size)) - - batches = [self.input_reader.next()] - steps_so_far = batches[0].count - - # In truncate_episodes mode, never pull more than 1 batch per env. - # This avoids over-running the target batch size. - if self.batch_mode == "truncate_episodes": - max_batches = self.num_envs - else: - max_batches = float("inf") - - while steps_so_far < self.sample_batch_size and len( - batches) < max_batches: - batch = self.input_reader.next() - steps_so_far += batch.count - batches.append(batch) - batch = batches[0].concat_samples(batches) - - if self.callbacks.get("on_sample_end"): - self.callbacks["on_sample_end"]({ - "evaluator": self, - "samples": batch - }) - - # Always do writes prior to compression for consistency and to allow - # for better compression inside the writer. 
- self.output_writer.write(batch) - - # Do off-policy estimation if needed - if self.reward_estimators: - for sub_batch in batch.split_by_episode(): - for estimator in self.reward_estimators: - estimator.process(sub_batch) - - if log_once("sample_end"): - logger.info("Completed sample batch:\n\n{}\n".format( - summarize(batch))) - - if self.compress_observations == "bulk": - batch.compress(bulk=True) - elif self.compress_observations: - batch.compress() - - if self._fake_sampler: - self.last_batch = batch - return batch - - @DeveloperAPI - @ray.method(num_return_vals=2) - def sample_with_count(self): - """Same as sample() but returns the count as a separate future.""" - batch = self.sample() - return batch, batch.count - - @override(EvaluatorInterface) - def get_weights(self, policies=None): - if policies is None: - policies = self.policy_map.keys() - return { - pid: policy.get_weights() - for pid, policy in self.policy_map.items() if pid in policies - } - - @override(EvaluatorInterface) - def set_weights(self, weights): - for pid, w in weights.items(): - self.policy_map[pid].set_weights(w) - - @override(EvaluatorInterface) - def compute_gradients(self, samples): - if log_once("compute_gradients"): - logger.info("Compute gradients on:\n\n{}\n".format( - summarize(samples))) - if isinstance(samples, MultiAgentBatch): - grad_out, info_out = {}, {} - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "compute_gradients") - for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: - continue - grad_out[pid], info_out[pid] = ( - self.policy_map[pid]._build_compute_gradients( - builder, batch)) - grad_out = {k: builder.get(v) for k, v in grad_out.items()} - info_out = {k: builder.get(v) for k, v in info_out.items()} - else: - for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: - continue - grad_out[pid], info_out[pid] = ( - self.policy_map[pid].compute_gradients(batch)) - else: - grad_out, info_out = ( - self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples)) - info_out["batch_count"] = samples.count - if log_once("grad_out"): - logger.info("Compute grad info:\n\n{}\n".format( - summarize(info_out))) - return grad_out, info_out - - @override(EvaluatorInterface) - def apply_gradients(self, grads): - if log_once("apply_gradients"): - logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) - if isinstance(grads, dict): - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "apply_gradients") - outputs = { - pid: self.policy_map[pid]._build_apply_gradients( - builder, grad) - for pid, grad in grads.items() - } - return {k: builder.get(v) for k, v in outputs.items()} - else: - return { - pid: self.policy_map[pid].apply_gradients(g) - for pid, g in grads.items() - } - else: - return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) - - @override(EvaluatorInterface) - def learn_on_batch(self, samples): - if log_once("learn_on_batch"): - logger.info( - "Training on concatenated sample batches:\n\n{}\n".format( - summarize(samples))) - if isinstance(samples, MultiAgentBatch): - info_out = {} - to_fetch = {} - if self.tf_sess is not None: - builder = TFRunBuilder(self.tf_sess, "learn_on_batch") - else: - builder = None - for pid, batch in samples.policy_batches.items(): - if pid not in self.policies_to_train: - continue - policy = self.policy_map[pid] - if builder and hasattr(policy, "_build_learn_on_batch"): - to_fetch[pid] = policy._build_learn_on_batch( - builder, batch) - else: 
- info_out[pid] = policy.learn_on_batch(batch) - info_out.update({k: builder.get(v) for k, v in to_fetch.items()}) - else: - info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch( - samples) - if log_once("learn_out"): - logger.info("Training output:\n\n{}\n".format(summarize(info_out))) - return info_out - - @DeveloperAPI - def get_metrics(self): - """Returns a list of new RolloutMetric objects from evaluation.""" - - out = self.sampler.get_metrics() - for m in self.reward_estimators: - out.extend(m.get_metrics()) - return out - - @DeveloperAPI - def foreach_env(self, func): - """Apply the given function to each underlying env instance.""" - - envs = self.async_env.get_unwrapped() - if not envs: - return [func(self.async_env)] - else: - return [func(e) for e in envs] - - @DeveloperAPI - def get_policy(self, policy_id=DEFAULT_POLICY_ID): - """Return policy for the specified id, or None. - - Arguments: - policy_id (str): id of policy to return. - """ - - return self.policy_map.get(policy_id) - - @DeveloperAPI - def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): - """Apply the given function to the specified policy.""" - - return func(self.policy_map[policy_id]) - - @DeveloperAPI - def foreach_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple.""" - - return [func(policy, pid) for pid, policy in self.policy_map.items()] - - @DeveloperAPI - def foreach_trainable_policy(self, func): - """Apply the given function to each (policy, policy_id) tuple. - - This only applies func to policies in `self.policies_to_train`.""" - - return [ - func(policy, pid) for pid, policy in self.policy_map.items() - if pid in self.policies_to_train - ] - - @DeveloperAPI - def sync_filters(self, new_filters): - """Changes self's filter to given and rebases any accumulated delta. - - Args: - new_filters (dict): Filters with new state to update local copy. - """ - assert all(k in new_filters for k in self.filters) - for k in self.filters: - self.filters[k].sync(new_filters[k]) - - @DeveloperAPI - def get_filters(self, flush_after=False): - """Returns a snapshot of filters. - - Args: - flush_after (bool): Clears the filter buffer state. 
- - Returns: - return_filters (dict): Dict for serializable filters - """ - return_filters = {} - for k, f in self.filters.items(): - return_filters[k] = f.as_serializable() - if flush_after: - f.clear_buffer() - return return_filters - - @DeveloperAPI - def save(self): - filters = self.get_filters(flush_after=True) - state = { - pid: self.policy_map[pid].get_state() - for pid in self.policy_map - } - return pickle.dumps({"filters": filters, "state": state}) - - @DeveloperAPI - def restore(self, objs): - objs = pickle.loads(objs) - self.sync_filters(objs["filters"]) - for pid, state in objs["state"].items(): - self.policy_map[pid].set_state(state) - - @DeveloperAPI - def set_global_vars(self, global_vars): - self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars)) - - @DeveloperAPI - def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): - self.policy_map[policy_id].export_model(export_dir) - - @DeveloperAPI - def export_policy_checkpoint(self, - export_dir, - filename_prefix="model", - policy_id=DEFAULT_POLICY_ID): - self.policy_map[policy_id].export_checkpoint(export_dir, - filename_prefix) - - @DeveloperAPI - def stop(self): - self.async_env.stop() - - def _build_policy_map(self, policy_dict, policy_config): - policy_map = {} - preprocessors = {} - for name, (cls, obs_space, act_space, - conf) in sorted(policy_dict.items()): - logger.debug("Creating policy for {}".format(name)) - merged_conf = merge_dicts(policy_config, conf) - if self.preprocessing_enabled: - preprocessor = ModelCatalog.get_preprocessor_for_space( - obs_space, merged_conf.get("model")) - preprocessors[name] = preprocessor - obs_space = preprocessor.observation_space - else: - preprocessors[name] = NoPreprocessor(obs_space) - if isinstance(obs_space, gym.spaces.Dict) or \ - isinstance(obs_space, gym.spaces.Tuple): - raise ValueError( - "Found raw Tuple|Dict space as input to policy. 
" - "Please preprocess these observations with a " - "Tuple|DictFlatteningPreprocessor.") - if tf: - with tf.variable_scope(name): - policy_map[name] = cls(obs_space, act_space, merged_conf) - else: - policy_map[name] = cls(obs_space, act_space, merged_conf) - if self.worker_index == 0: - logger.info("Built policy map: {}".format(policy_map)) - logger.info("Built preprocessor map: {}".format(preprocessors)) - return policy_map, preprocessors - - def __del__(self): - if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler): - self.sampler.shutdown = True - - -def _validate_and_canonicalize(policy, env): - if isinstance(policy, dict): - _validate_multiagent_config(policy) - return policy - elif not issubclass(policy, Policy): - raise ValueError("policy must be a rllib.Policy class") - else: - if (isinstance(env, MultiAgentEnv) - and not hasattr(env, "observation_space")): - raise ValueError( - "MultiAgentEnv must have observation_space defined if run " - "in a single-agent configuration.") - return { - DEFAULT_POLICY_ID: (policy, env.observation_space, - env.action_space, {}) - } - - -def _validate_multiagent_config(policy, allow_none_graph=False): - for k, v in policy.items(): - if not isinstance(k, str): - raise ValueError("policy keys must be strs, got {}".format( - type(k))) - if not isinstance(v, tuple) or len(v) != 4: - raise ValueError( - "policy values must be tuples of " - "(cls, obs_space, action_space, config), got {}".format(v)) - if allow_none_graph and v[0] is None: - pass - elif not issubclass(v[0], Policy): - raise ValueError("policy tuple value 0 must be a rllib.Policy " - "class or None, got {}".format(v[0])) - if not isinstance(v[1], gym.Space): - raise ValueError( - "policy tuple value 1 (observation_space) must be a " - "gym.Space, got {}".format(type(v[1]))) - if not isinstance(v[2], gym.Space): - raise ValueError("policy tuple value 2 (action_space) must be a " - "gym.Space, got {}".format(type(v[2]))) - if not isinstance(v[3], dict): - raise ValueError("policy tuple value 3 (config) must be a dict, " - "got {}".format(type(v[3]))) - - -def _validate_env(env): - # allow this as a special case (assumed gym.Env) - if hasattr(env, "observation_space") and hasattr(env, "action_space"): - return env - - allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv] - if not any(isinstance(env, tpe) for tpe in allowed_types): - raise ValueError( - "Returned env should be an instance of gym.Env, MultiAgentEnv, " - "ExternalEnv, VectorEnv, or BaseEnv. 
The provided env creator " - "function returned {} ({}).".format(env, type(env))) - return env - - -def _monitor(env, path): - return gym.wrappers.Monitor(env, path, resume=True) - - -def _has_tensorflow_graph(policy_dict): - for policy, _, _, _ in policy_dict.values(): - if issubclass(policy, TFPolicy): - return True - return False +PolicyEvaluator = renamed_class( + RolloutWorker, old_name="rllib.evaluation.PolicyEvaluator") diff --git a/python/ray/rllib/evaluation/rollout_worker.py b/python/ray/rllib/evaluation/rollout_worker.py new file mode 100644 index 000000000000..3be01a42907b --- /dev/null +++ b/python/ray/rllib/evaluation/rollout_worker.py @@ -0,0 +1,794 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gym +import logging +import pickle + +import ray +from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.external_env import ExternalEnv +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv +from ray.rllib.env.vector_env import VectorEnv +from ray.rllib.evaluation.interface import EvaluatorInterface +from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler +from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader +from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator +from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator +from ray.rllib.models import ModelCatalog +from ray.rllib.models.preprocessors import NoPreprocessor +from ray.rllib.utils import merge_dicts +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.debug import disable_log_once_globally, log_once, \ + summarize, enable_periodic_logging +from ray.rllib.utils.filter import get_filter +from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() +logger = logging.getLogger(__name__) + +# Handle to the current rollout worker, which will be set to the most recently +# created RolloutWorker in this process. This can be helpful to access in +# custom env or policy classes for debugging or advanced use cases. +_global_worker = None + + +@DeveloperAPI +def get_global_worker(): + """Returns a handle to the active rollout worker in this process.""" + + global _global_worker + return _global_worker + + +@DeveloperAPI +class RolloutWorker(EvaluatorInterface): + """Common experience collection class. + + This class wraps a policy instance and an environment class to + collect experiences from the environment. You can create many replicas of + this class as Ray actors to scale RL training. + + This class supports vectorized and multi-agent policy evaluation (e.g., + VectorEnv, MultiAgentEnv, etc.) + + Examples: + >>> # Create a rollout worker and using it to collect experiences. + >>> worker = RolloutWorker( + ... env_creator=lambda _: gym.make("CartPole-v0"), + ... policy=PGTFPolicy) + >>> print(worker.sample()) + SampleBatch({ + "obs": [[...]], "actions": [[...]], "rewards": [[...]], + "dones": [[...]], "new_obs": [[...]]}) + + >>> # Creating a multi-agent rollout worker + >>> worker = RolloutWorker( + ... 
env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), + ... policies={ + ... # Use an ensemble of two policies for car agents + ... "car_policy1": + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), + ... "car_policy2": + ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.95}), + ... # Use a single shared policy for all traffic lights + ... "traffic_light_policy": + ... (PGTFPolicy, Box(...), Discrete(...), {}), + ... }, + ... policy_mapping_fn=lambda agent_id: + ... random.choice(["car_policy1", "car_policy2"]) + ... if agent_id.startswith("car_") else "traffic_light_policy") + >>> print(worker.sample()) + MultiAgentBatch({ + "car_policy1": SampleBatch(...), + "car_policy2": SampleBatch(...), + "traffic_light_policy": SampleBatch(...)}) + """ + + @DeveloperAPI + @classmethod + def as_remote(cls, num_cpus=None, num_gpus=None, resources=None): + return ray.remote( + num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls) + + @DeveloperAPI + def __init__(self, + env_creator, + policy, + policy_mapping_fn=None, + policies_to_train=None, + tf_session_creator=None, + batch_steps=100, + batch_mode="truncate_episodes", + episode_horizon=None, + preprocessor_pref="deepmind", + sample_async=False, + compress_observations=False, + num_envs=1, + observation_filter="NoFilter", + clip_rewards=None, + clip_actions=True, + env_config=None, + model_config=None, + policy_config=None, + worker_index=0, + monitor_path=None, + log_dir=None, + log_level=None, + callbacks=None, + input_creator=lambda ioctx: ioctx.default_sampler_input(), + input_evaluation=frozenset([]), + output_creator=lambda ioctx: NoopOutput(), + remote_worker_envs=False, + remote_env_batch_wait_ms=0, + soft_horizon=False, + _fake_sampler=False): + """Initialize a rollout worker. + + Arguments: + env_creator (func): Function that returns a gym.Env given an + EnvContext wrapped configuration. + policy (class|dict): Either a class implementing + Policy, or a dictionary of policy id strings to + (Policy, obs_space, action_space, config) tuples. If a + dict is specified, then we are in multi-agent mode and a + policy_mapping_fn should also be set. + policy_mapping_fn (func): A function that maps agent ids to + policy ids in multi-agent mode. This function will be called + each time a new agent appears in an episode, to bind that agent + to a policy for the duration of the episode. + policies_to_train (list): Optional whitelist of policies to train, + or None for all policies. + tf_session_creator (func): A function that returns a TF session. + This is optional and only useful with TFPolicy. + batch_steps (int): The target number of env transitions to include + in each sample batch returned from this worker. + batch_mode (str): One of the following batch modes: + "truncate_episodes": Each call to sample() will return a batch + of at most `batch_steps * num_envs` in size. The batch will + be exactly `batch_steps * num_envs` in size if + postprocessing does not change batch sizes. Episodes may be + truncated in order to meet this size requirement. + "complete_episodes": Each call to sample() will return a batch + of at least `batch_steps * num_envs` in size. Episodes will + not be truncated, but multiple episodes may be packed + within one batch to meet the batch size. Note that when + `num_envs > 1`, episode steps will be buffered until the + episode completes, and hence batches may contain + significant amounts of off-policy data. + episode_horizon (int): Whether to stop episodes at this horizon. 
+ preprocessor_pref (str): Whether to prefer RLlib preprocessors + ("rllib") or deepmind ("deepmind") when applicable. + sample_async (bool): Whether to compute samples asynchronously in + the background, which improves throughput but can cause samples + to be slightly off-policy. + compress_observations (bool): If true, compress the observations. + They can be decompressed with rllib/utils/compression. + num_envs (int): If more than one, will create multiple envs + and vectorize the computation of actions. This has no effect if + if the env already implements VectorEnv. + observation_filter (str): Name of observation filter to use. + clip_rewards (bool): Whether to clip rewards to [-1, 1] prior to + experience postprocessing. Setting to None means clip for Atari + only. + clip_actions (bool): Whether to clip action values to the range + specified by the policy action space. + env_config (dict): Config to pass to the env creator. + model_config (dict): Config to use when creating the policy model. + policy_config (dict): Config to pass to the policy. In the + multi-agent case, this config will be merged with the + per-policy configs specified by `policy`. + worker_index (int): For remote workers, this should be set to a + non-zero and unique value. This index is passed to created envs + through EnvContext so that envs can be configured per worker. + monitor_path (str): Write out episode stats and videos to this + directory if specified. + log_dir (str): Directory where logs can be placed. + log_level (str): Set the root log level on creation. + callbacks (dict): Dict of custom debug callbacks. + input_creator (func): Function that returns an InputReader object + for loading previous generated experiences. + input_evaluation (list): How to evaluate the policy performance. + This only makes sense to set when the input is reading offline + data. The possible values include: + - "is": the step-wise importance sampling estimator. + - "wis": the weighted step-wise is estimator. + - "simulation": run the environment in the background, but + use this data for evaluation only and never for learning. + output_creator (func): Function that returns an OutputWriter object + for saving generated experiences. + remote_worker_envs (bool): If using num_envs > 1, whether to create + those new envs in remote processes instead of in the current + process. This adds overheads, but can make sense if your envs + remote_env_batch_wait_ms (float): Timeout that remote workers + are waiting when polling environments. 0 (continue when at + least one env is ready) is a reasonable default, but optimal + value could be obtained by measuring your environment + step / reset and model inference perf. + soft_horizon (bool): Calculate rewards but don't reset the + environment when the horizon is hit. + _fake_sampler (bool): Use a fake (inf speed) sampler for testing. 
+ """ + + global _global_worker + _global_worker = self + + if log_level: + logging.getLogger("ray.rllib").setLevel(log_level) + + if worker_index > 1: + disable_log_once_globally() # only need 1 worker to log + elif log_level == "DEBUG": + enable_periodic_logging() + + env_context = EnvContext(env_config or {}, worker_index) + policy_config = policy_config or {} + self.policy_config = policy_config + self.callbacks = callbacks or {} + self.worker_index = worker_index + model_config = model_config or {} + policy_mapping_fn = (policy_mapping_fn + or (lambda agent_id: DEFAULT_POLICY_ID)) + if not callable(policy_mapping_fn): + raise ValueError( + "Policy mapping function not callable. If you're using Tune, " + "make sure to escape the function with tune.function() " + "to prevent it from being evaluated as an expression.") + self.env_creator = env_creator + self.sample_batch_size = batch_steps * num_envs + self.batch_mode = batch_mode + self.compress_observations = compress_observations + self.preprocessing_enabled = True + self.last_batch = None + self._fake_sampler = _fake_sampler + + self.env = _validate_env(env_creator(env_context)) + if isinstance(self.env, MultiAgentEnv) or \ + isinstance(self.env, BaseEnv): + + def wrap(env): + return env # we can't auto-wrap these env types + elif is_atari(self.env) and \ + not model_config.get("custom_preprocessor") and \ + preprocessor_pref == "deepmind": + + # Deepmind wrappers already handle all preprocessing + self.preprocessing_enabled = False + + if clip_rewards is None: + clip_rewards = True + + def wrap(env): + env = wrap_deepmind( + env, + dim=model_config.get("dim"), + framestack=model_config.get("framestack")) + if monitor_path: + env = _monitor(env, monitor_path) + return env + else: + + def wrap(env): + if monitor_path: + env = _monitor(env, monitor_path) + return env + + self.env = wrap(self.env) + + def make_env(vector_index): + return wrap( + env_creator( + env_context.copy_with_overrides( + vector_index=vector_index, remote=remote_worker_envs))) + + self.tf_sess = None + policy_dict = _validate_and_canonicalize(policy, self.env) + self.policies_to_train = policies_to_train or list(policy_dict.keys()) + if _has_tensorflow_graph(policy_dict): + if (ray.is_initialized() + and ray.worker._mode() != ray.worker.LOCAL_MODE + and not ray.get_gpu_ids()): + logger.info("Creating policy evaluation worker {}".format( + worker_index) + + " on CPU (please ignore any CUDA init errors)") + with tf.Graph().as_default(): + if tf_session_creator: + self.tf_sess = tf_session_creator() + else: + self.tf_sess = tf.Session( + config=tf.ConfigProto( + gpu_options=tf.GPUOptions(allow_growth=True))) + with self.tf_sess.as_default(): + self.policy_map, self.preprocessors = \ + self._build_policy_map(policy_dict, policy_config) + else: + self.policy_map, self.preprocessors = self._build_policy_map( + policy_dict, policy_config) + + self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID} + if self.multiagent: + if not ((isinstance(self.env, MultiAgentEnv) + or isinstance(self.env, ExternalMultiAgentEnv)) + or isinstance(self.env, BaseEnv)): + raise ValueError( + "Have multiple policies {}, but the env ".format( + self.policy_map) + + "{} is not a subclass of BaseEnv, MultiAgentEnv or " + "ExternalMultiAgentEnv?".format(self.env)) + + self.filters = { + policy_id: get_filter(observation_filter, + policy.observation_space.shape) + for (policy_id, policy) in self.policy_map.items() + } + if self.worker_index == 0: + logger.info("Built filter map: 
{}".format(self.filters)) + + # Always use vector env for consistency even if num_envs = 1 + self.async_env = BaseEnv.to_base_env( + self.env, + make_env=make_env, + num_envs=num_envs, + remote_envs=remote_worker_envs, + remote_env_batch_wait_ms=remote_env_batch_wait_ms) + self.num_envs = num_envs + + if self.batch_mode == "truncate_episodes": + unroll_length = batch_steps + pack_episodes = True + elif self.batch_mode == "complete_episodes": + unroll_length = float("inf") # never cut episodes + pack_episodes = False # sampler will return 1 episode per poll + else: + raise ValueError("Unsupported batch mode: {}".format( + self.batch_mode)) + + self.io_context = IOContext(log_dir, policy_config, worker_index, self) + self.reward_estimators = [] + for method in input_evaluation: + if method == "simulation": + logger.warning( + "Requested 'simulation' input evaluation method: " + "will discard all sampler outputs and keep only metrics.") + sample_async = True + elif method == "is": + ise = ImportanceSamplingEstimator.create(self.io_context) + self.reward_estimators.append(ise) + elif method == "wis": + wise = WeightedImportanceSamplingEstimator.create( + self.io_context) + self.reward_estimators.append(wise) + else: + raise ValueError( + "Unknown evaluation method: {}".format(method)) + + if sample_async: + self.sampler = AsyncSampler( + self.async_env, + self.policy_map, + policy_mapping_fn, + self.preprocessors, + self.filters, + clip_rewards, + unroll_length, + self.callbacks, + horizon=episode_horizon, + pack=pack_episodes, + tf_sess=self.tf_sess, + clip_actions=clip_actions, + blackhole_outputs="simulation" in input_evaluation, + soft_horizon=soft_horizon) + self.sampler.start() + else: + self.sampler = SyncSampler( + self.async_env, + self.policy_map, + policy_mapping_fn, + self.preprocessors, + self.filters, + clip_rewards, + unroll_length, + self.callbacks, + horizon=episode_horizon, + pack=pack_episodes, + tf_sess=self.tf_sess, + clip_actions=clip_actions, + soft_horizon=soft_horizon) + + self.input_reader = input_creator(self.io_context) + assert isinstance(self.input_reader, InputReader), self.input_reader + self.output_writer = output_creator(self.io_context) + assert isinstance(self.output_writer, OutputWriter), self.output_writer + + logger.debug( + "Created rollout worker with env {} ({}), policies {}".format( + self.async_env, self.env, self.policy_map)) + + @override(EvaluatorInterface) + def sample(self): + """Evaluate the current policies and return a batch of experiences. + + Return: + SampleBatch|MultiAgentBatch from evaluating the current policies. + """ + + if self._fake_sampler and self.last_batch is not None: + return self.last_batch + + if log_once("sample_start"): + logger.info("Generating sample batch of size {}".format( + self.sample_batch_size)) + + batches = [self.input_reader.next()] + steps_so_far = batches[0].count + + # In truncate_episodes mode, never pull more than 1 batch per env. + # This avoids over-running the target batch size. 
+ if self.batch_mode == "truncate_episodes": + max_batches = self.num_envs + else: + max_batches = float("inf") + + while steps_so_far < self.sample_batch_size and len( + batches) < max_batches: + batch = self.input_reader.next() + steps_so_far += batch.count + batches.append(batch) + batch = batches[0].concat_samples(batches) + + if self.callbacks.get("on_sample_end"): + self.callbacks["on_sample_end"]({"worker": self, "samples": batch}) + + # Always do writes prior to compression for consistency and to allow + # for better compression inside the writer. + self.output_writer.write(batch) + + # Do off-policy estimation if needed + if self.reward_estimators: + for sub_batch in batch.split_by_episode(): + for estimator in self.reward_estimators: + estimator.process(sub_batch) + + if log_once("sample_end"): + logger.info("Completed sample batch:\n\n{}\n".format( + summarize(batch))) + + if self.compress_observations == "bulk": + batch.compress(bulk=True) + elif self.compress_observations: + batch.compress() + + if self._fake_sampler: + self.last_batch = batch + return batch + + @DeveloperAPI + @ray.method(num_return_vals=2) + def sample_with_count(self): + """Same as sample() but returns the count as a separate future.""" + batch = self.sample() + return batch, batch.count + + @override(EvaluatorInterface) + def get_weights(self, policies=None): + if policies is None: + policies = self.policy_map.keys() + return { + pid: policy.get_weights() + for pid, policy in self.policy_map.items() if pid in policies + } + + @override(EvaluatorInterface) + def set_weights(self, weights): + for pid, w in weights.items(): + self.policy_map[pid].set_weights(w) + + @override(EvaluatorInterface) + def compute_gradients(self, samples): + if log_once("compute_gradients"): + logger.info("Compute gradients on:\n\n{}\n".format( + summarize(samples))) + if isinstance(samples, MultiAgentBatch): + grad_out, info_out = {}, {} + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "compute_gradients") + for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue + grad_out[pid], info_out[pid] = ( + self.policy_map[pid]._build_compute_gradients( + builder, batch)) + grad_out = {k: builder.get(v) for k, v in grad_out.items()} + info_out = {k: builder.get(v) for k, v in info_out.items()} + else: + for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue + grad_out[pid], info_out[pid] = ( + self.policy_map[pid].compute_gradients(batch)) + else: + grad_out, info_out = ( + self.policy_map[DEFAULT_POLICY_ID].compute_gradients(samples)) + info_out["batch_count"] = samples.count + if log_once("grad_out"): + logger.info("Compute grad info:\n\n{}\n".format( + summarize(info_out))) + return grad_out, info_out + + @override(EvaluatorInterface) + def apply_gradients(self, grads): + if log_once("apply_gradients"): + logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) + if isinstance(grads, dict): + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "apply_gradients") + outputs = { + pid: self.policy_map[pid]._build_apply_gradients( + builder, grad) + for pid, grad in grads.items() + } + return {k: builder.get(v) for k, v in outputs.items()} + else: + return { + pid: self.policy_map[pid].apply_gradients(g) + for pid, g in grads.items() + } + else: + return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) + + @override(EvaluatorInterface) + def learn_on_batch(self, samples): + if 
log_once("learn_on_batch"): + logger.info( + "Training on concatenated sample batches:\n\n{}\n".format( + summarize(samples))) + if isinstance(samples, MultiAgentBatch): + info_out = {} + to_fetch = {} + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "learn_on_batch") + else: + builder = None + for pid, batch in samples.policy_batches.items(): + if pid not in self.policies_to_train: + continue + policy = self.policy_map[pid] + if builder and hasattr(policy, "_build_learn_on_batch"): + to_fetch[pid] = policy._build_learn_on_batch( + builder, batch) + else: + info_out[pid] = policy.learn_on_batch(batch) + info_out.update({k: builder.get(v) for k, v in to_fetch.items()}) + else: + info_out = self.policy_map[DEFAULT_POLICY_ID].learn_on_batch( + samples) + if log_once("learn_out"): + logger.info("Training output:\n\n{}\n".format(summarize(info_out))) + return info_out + + @DeveloperAPI + def get_metrics(self): + """Returns a list of new RolloutMetric objects from evaluation.""" + + out = self.sampler.get_metrics() + for m in self.reward_estimators: + out.extend(m.get_metrics()) + return out + + @DeveloperAPI + def foreach_env(self, func): + """Apply the given function to each underlying env instance.""" + + envs = self.async_env.get_unwrapped() + if not envs: + return [func(self.async_env)] + else: + return [func(e) for e in envs] + + @DeveloperAPI + def get_policy(self, policy_id=DEFAULT_POLICY_ID): + """Return policy for the specified id, or None. + + Arguments: + policy_id (str): id of policy to return. + """ + + return self.policy_map.get(policy_id) + + @DeveloperAPI + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): + """Apply the given function to the specified policy.""" + + return func(self.policy_map[policy_id]) + + @DeveloperAPI + def foreach_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple.""" + + return [func(policy, pid) for pid, policy in self.policy_map.items()] + + @DeveloperAPI + def foreach_trainable_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple. + + This only applies func to policies in `self.policies_to_train`.""" + + return [ + func(policy, pid) for pid, policy in self.policy_map.items() + if pid in self.policies_to_train + ] + + @DeveloperAPI + def sync_filters(self, new_filters): + """Changes self's filter to given and rebases any accumulated delta. + + Args: + new_filters (dict): Filters with new state to update local copy. + """ + assert all(k in new_filters for k in self.filters) + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + @DeveloperAPI + def get_filters(self, flush_after=False): + """Returns a snapshot of filters. + + Args: + flush_after (bool): Clears the filter buffer state. 
+ + Returns: + return_filters (dict): Dict for serializable filters + """ + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.clear_buffer() + return return_filters + + @DeveloperAPI + def save(self): + filters = self.get_filters(flush_after=True) + state = { + pid: self.policy_map[pid].get_state() + for pid in self.policy_map + } + return pickle.dumps({"filters": filters, "state": state}) + + @DeveloperAPI + def restore(self, objs): + objs = pickle.loads(objs) + self.sync_filters(objs["filters"]) + for pid, state in objs["state"].items(): + self.policy_map[pid].set_state(state) + + @DeveloperAPI + def set_global_vars(self, global_vars): + self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars)) + + @DeveloperAPI + def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_model(export_dir) + + @DeveloperAPI + def export_policy_checkpoint(self, + export_dir, + filename_prefix="model", + policy_id=DEFAULT_POLICY_ID): + self.policy_map[policy_id].export_checkpoint(export_dir, + filename_prefix) + + @DeveloperAPI + def stop(self): + self.async_env.stop() + + def _build_policy_map(self, policy_dict, policy_config): + policy_map = {} + preprocessors = {} + for name, (cls, obs_space, act_space, + conf) in sorted(policy_dict.items()): + logger.debug("Creating policy for {}".format(name)) + merged_conf = merge_dicts(policy_config, conf) + if self.preprocessing_enabled: + preprocessor = ModelCatalog.get_preprocessor_for_space( + obs_space, merged_conf.get("model")) + preprocessors[name] = preprocessor + obs_space = preprocessor.observation_space + else: + preprocessors[name] = NoPreprocessor(obs_space) + if isinstance(obs_space, gym.spaces.Dict) or \ + isinstance(obs_space, gym.spaces.Tuple): + raise ValueError( + "Found raw Tuple|Dict space as input to policy. 
" + "Please preprocess these observations with a " + "Tuple|DictFlatteningPreprocessor.") + if tf: + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + else: + policy_map[name] = cls(obs_space, act_space, merged_conf) + if self.worker_index == 0: + logger.info("Built policy map: {}".format(policy_map)) + logger.info("Built preprocessor map: {}".format(preprocessors)) + return policy_map, preprocessors + + def __del__(self): + if hasattr(self, "sampler") and isinstance(self.sampler, AsyncSampler): + self.sampler.shutdown = True + + +def _validate_and_canonicalize(policy, env): + if isinstance(policy, dict): + _validate_multiagent_config(policy) + return policy + elif not issubclass(policy, Policy): + raise ValueError("policy must be a rllib.Policy class") + else: + if (isinstance(env, MultiAgentEnv) + and not hasattr(env, "observation_space")): + raise ValueError( + "MultiAgentEnv must have observation_space defined if run " + "in a single-agent configuration.") + return { + DEFAULT_POLICY_ID: (policy, env.observation_space, + env.action_space, {}) + } + + +def _validate_multiagent_config(policy, allow_none_graph=False): + for k, v in policy.items(): + if not isinstance(k, str): + raise ValueError("policy keys must be strs, got {}".format( + type(k))) + if not isinstance(v, tuple) or len(v) != 4: + raise ValueError( + "policy values must be tuples of " + "(cls, obs_space, action_space, config), got {}".format(v)) + if allow_none_graph and v[0] is None: + pass + elif not issubclass(v[0], Policy): + raise ValueError("policy tuple value 0 must be a rllib.Policy " + "class or None, got {}".format(v[0])) + if not isinstance(v[1], gym.Space): + raise ValueError( + "policy tuple value 1 (observation_space) must be a " + "gym.Space, got {}".format(type(v[1]))) + if not isinstance(v[2], gym.Space): + raise ValueError("policy tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) + if not isinstance(v[3], dict): + raise ValueError("policy tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) + + +def _validate_env(env): + # allow this as a special case (assumed gym.Env) + if hasattr(env, "observation_space") and hasattr(env, "action_space"): + return env + + allowed_types = [gym.Env, MultiAgentEnv, ExternalEnv, VectorEnv, BaseEnv] + if not any(isinstance(env, tpe) for tpe in allowed_types): + raise ValueError( + "Returned env should be an instance of gym.Env, MultiAgentEnv, " + "ExternalEnv, VectorEnv, or BaseEnv. 
The provided env creator " + "function returned {} ({}).".format(env, type(env))) + return env + + +def _monitor(env, path): + return gym.wrappers.Monitor(env, path, resume=True) + + +def _has_tensorflow_graph(policy_dict): + for policy, _, _, _ in policy_dict.values(): + if issubclass(policy, TFPolicy): + return True + return False diff --git a/python/ray/rllib/evaluation/worker_set.py b/python/ray/rllib/evaluation/worker_set.py new file mode 100644 index 000000000000..90d3c13c217e --- /dev/null +++ b/python/ray/rllib/evaluation/worker_set.py @@ -0,0 +1,214 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +from types import FunctionType + +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.evaluation.rollout_worker import RolloutWorker, \ + _validate_multiagent_config +from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ + ShuffledInput +from ray.rllib.utils import merge_dicts, try_import_tf +from ray.rllib.utils.memory import ray_get_and_free + +tf = try_import_tf() + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class WorkerSet(object): + """Represents a set of RolloutWorkers. + + There must be one local worker copy, and zero or more remote workers. + """ + + def __init__(self, + env_creator, + policy, + trainer_config=None, + num_workers=0, + logdir=None, + _setup=True): + """Create a new WorkerSet and initialize its workers. + + Arguments: + env_creator (func): Function that returns env given env config. + policy (cls): rllib.policy.Policy class. + trainer_config (dict): Optional dict that extends the common + config of the Trainer class. + num_workers (int): Number of remote rollout workers to create. + logdir (str): Optional logging directory for workers. + _setup (bool): Whether to setup workers. This is only for testing. 
+ """ + + if not trainer_config: + from ray.rllib.agents.trainer import COMMON_CONFIG + trainer_config = COMMON_CONFIG + + self._env_creator = env_creator + self._policy = policy + self._remote_config = trainer_config + self._num_workers = num_workers + self._logdir = logdir + + if _setup: + self._local_config = merge_dicts( + trainer_config, + {"tf_session_args": trainer_config["local_tf_session_args"]}) + + # Always create a local worker + self._local_worker = self._make_worker( + RolloutWorker, env_creator, policy, 0, self._local_config) + + # Create a number of remote workers + self._remote_workers = [] + self.add_workers(num_workers) + + def local_worker(self): + """Return the local rollout worker.""" + return self._local_worker + + def remote_workers(self): + """Return a list of remote rollout workers.""" + return self._remote_workers + + def add_workers(self, num_workers): + """Create and add a number of remote workers to this worker set.""" + remote_args = { + "num_cpus": self._remote_config["num_cpus_per_worker"], + "num_gpus": self._remote_config["num_gpus_per_worker"], + "resources": self._remote_config["custom_resources_per_worker"], + } + cls = RolloutWorker.as_remote(**remote_args).remote + self._remote_workers.extend([ + self._make_worker(cls, self._env_creator, self._policy, i + 1, + self._remote_config) for i in range(num_workers) + ]) + + def reset(self, new_remote_workers): + """Called to change the set of remote workers.""" + self._remote_workers = new_remote_workers + + def stop(self): + """Stop all rollout workers.""" + self.local_worker().stop() + for w in self.remote_workers(): + w.stop.remote() + w.__ray_terminate__.remote() + + @DeveloperAPI + def foreach_worker(self, func): + """Apply the given function to each worker instance.""" + + local_result = [func(self.local_worker())] + remote_results = ray_get_and_free( + [w.apply.remote(func) for w in self.remote_workers()]) + return local_result + remote_results + + @DeveloperAPI + def foreach_worker_with_index(self, func): + """Apply the given function to each worker instance. + + The index will be passed as the second arg to the given function. 
+ """ + + local_result = [func(self.local_worker(), 0)] + remote_results = ray_get_and_free([ + w.apply.remote(func, i + 1) + for i, w in enumerate(self.remote_workers()) + ]) + return local_result + remote_results + + @staticmethod + def _from_existing(local_worker, remote_workers=None): + workers = WorkerSet(None, None, {}, _setup=False) + workers._local_worker = local_worker + workers._remote_workers = remote_workers or [] + return workers + + def _make_worker(self, cls, env_creator, policy, worker_index, config): + def session_creator(): + logger.debug("Creating TF session {}".format( + config["tf_session_args"])) + return tf.Session( + config=tf.ConfigProto(**config["tf_session_args"])) + + if isinstance(config["input"], FunctionType): + input_creator = config["input"] + elif config["input"] == "sampler": + input_creator = (lambda ioctx: ioctx.default_sampler_input()) + elif isinstance(config["input"], dict): + input_creator = (lambda ioctx: ShuffledInput( + MixedInput(config["input"], ioctx), config[ + "shuffle_buffer_size"])) + else: + input_creator = (lambda ioctx: ShuffledInput( + JsonReader(config["input"], ioctx), config[ + "shuffle_buffer_size"])) + + if isinstance(config["output"], FunctionType): + output_creator = config["output"] + elif config["output"] is None: + output_creator = (lambda ioctx: NoopOutput()) + elif config["output"] == "logdir": + output_creator = (lambda ioctx: JsonWriter( + ioctx.log_dir, + ioctx, + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + else: + output_creator = (lambda ioctx: JsonWriter( + config["output"], + ioctx, + max_file_size=config["output_max_file_size"], + compress_columns=config["output_compress_columns"])) + + if config["input"] == "sampler": + input_evaluation = [] + else: + input_evaluation = config["input_evaluation"] + + # Fill in the default policy if 'None' is specified in multiagent + if config["multiagent"]["policies"]: + tmp = config["multiagent"]["policies"] + _validate_multiagent_config(tmp, allow_none_graph=True) + for k, v in tmp.items(): + if v[0] is None: + tmp[k] = (policy, v[1], v[2], v[3]) + policy = tmp + + return cls( + env_creator, + policy, + policy_mapping_fn=config["multiagent"]["policy_mapping_fn"], + policies_to_train=config["multiagent"]["policies_to_train"], + tf_session_creator=(session_creator + if config["tf_session_args"] else None), + batch_steps=config["sample_batch_size"], + batch_mode=config["batch_mode"], + episode_horizon=config["horizon"], + preprocessor_pref=config["preprocessor_pref"], + sample_async=config["sample_async"], + compress_observations=config["compress_observations"], + num_envs=config["num_envs_per_worker"], + observation_filter=config["observation_filter"], + clip_rewards=config["clip_rewards"], + clip_actions=config["clip_actions"], + env_config=config["env_config"], + model_config=config["model"], + policy_config=config, + worker_index=worker_index, + monitor_path=self._logdir if config["monitor"] else None, + log_dir=self._logdir, + log_level=config["log_level"], + callbacks=config["callbacks"], + input_creator=input_creator, + input_evaluation=input_evaluation, + output_creator=output_creator, + remote_worker_envs=config["remote_worker_envs"], + remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"], + soft_horizon=config["soft_horizon"], + _fake_sampler=config.get("_fake_sampler", False)) diff --git a/python/ray/rllib/examples/multiagent_two_trainers.py b/python/ray/rllib/examples/multiagent_two_trainers.py index 
68c0e742e857..cdac4a2fde71 100644 --- a/python/ray/rllib/examples/multiagent_two_trainers.py +++ b/python/ray/rllib/examples/multiagent_two_trainers.py @@ -75,7 +75,7 @@ def policy_mapping_fn(agent_id): }) # disable DQN exploration when used by the PPO trainer - ppo_trainer.optimizer.foreach_evaluator( + ppo_trainer.workers.foreach_worker( lambda ev: ev.for_policy( lambda pi: pi.set_epsilon(0.0), policy_id="dqn_policy")) diff --git a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py b/python/ray/rllib/examples/rollout_worker_custom_workflow.py similarity index 90% rename from python/ray/rllib/examples/policy_evaluator_custom_workflow.py rename to python/ray/rllib/examples/rollout_worker_custom_workflow.py index a8d80da994d2..fd1adc851e5d 100644 --- a/python/ray/rllib/examples/policy_evaluator_custom_workflow.py +++ b/python/ray/rllib/examples/rollout_worker_custom_workflow.py @@ -1,4 +1,4 @@ -"""Example of using policy evaluator classes directly to implement training. +"""Example of using rollout worker classes directly to implement training. Instead of using the built-in Trainer classes provided by RLlib, here we define a custom Policy class and manually coordinate distributed sample @@ -15,7 +15,7 @@ import ray from ray import tune from ray.rllib.policy import Policy -from ray.rllib.evaluation import PolicyEvaluator, SampleBatch +from ray.rllib.evaluation import RolloutWorker, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics parser = argparse.ArgumentParser() @@ -67,8 +67,8 @@ def training_workflow(config, reporter): env = gym.make("CartPole-v0") policy = CustomPolicy(env.observation_space, env.action_space, {}) workers = [ - PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"), - CustomPolicy) + RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"), + CustomPolicy) for _ in range(config["num_workers"]) ] @@ -97,7 +97,7 @@ def training_workflow(config, reporter): # Do some arbitrary updates based on the T2 batch policy.update_some_value(sum(T2["rewards"])) - reporter(**collect_metrics(remote_evaluators=workers)) + reporter(**collect_metrics(remote_workers=workers)) if __name__ == "__main__": diff --git a/python/ray/rllib/offline/io_context.py b/python/ray/rllib/offline/io_context.py index 187c02f9ca0d..58f7f03c5407 100644 --- a/python/ray/rllib/offline/io_context.py +++ b/python/ray/rllib/offline/io_context.py @@ -18,20 +18,16 @@ class IOContext(object): config (dict): Configuration of the agent. worker_index (int): When there are multiple workers created, this uniquely identifies the current worker. - evaluator (PolicyEvaluator): policy evaluator object reference. + worker (RolloutWorker): rollout worker object reference. 
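As an aside, the renamed field is what a custom input function sees. A rough sketch follows (the function name and output path are made up for illustration; JsonReader is the reader used by the string/list case of _make_worker above):

from ray.rllib.offline import JsonReader

def my_input_creator(ioctx):
    # ioctx.worker is the owning RolloutWorker (formerly ioctx.evaluator);
    # worker_index and log_dir come from the trainer config as before.
    print("building reader for worker index", ioctx.worker_index)
    return JsonReader("/tmp/rllib-demo-out", ioctx)  # hypothetical path

# Passed via the trainer config, e.g. {"input": my_input_creator}, which the
# FunctionType branch of _make_worker picks up unchanged.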
""" @PublicAPI - def __init__(self, - log_dir=None, - config=None, - worker_index=0, - evaluator=None): + def __init__(self, log_dir=None, config=None, worker_index=0, worker=None): self.log_dir = log_dir or os.getcwd() self.config = config or {} self.worker_index = worker_index - self.evaluator = evaluator + self.worker = worker @PublicAPI def default_sampler_input(self): - return self.evaluator.sampler + return self.worker.sampler diff --git a/python/ray/rllib/offline/json_reader.py b/python/ray/rllib/offline/json_reader.py index 55a002fb3ce6..35d28669d9a5 100644 --- a/python/ray/rllib/offline/json_reader.py +++ b/python/ray/rllib/offline/json_reader.py @@ -88,7 +88,7 @@ def _postprocess_if_needed(self, batch): if isinstance(batch, SampleBatch): out = [] for sub_batch in batch.split_by_episode(): - out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID] + out.append(self.ioctx.worker.policy_map[DEFAULT_POLICY_ID] .postprocess_trajectory(sub_batch)) return SampleBatch.concat_samples(out) else: diff --git a/python/ray/rllib/offline/off_policy_estimator.py b/python/ray/rllib/offline/off_policy_estimator.py index 7534e667f0bf..9d369f715cff 100644 --- a/python/ray/rllib/offline/off_policy_estimator.py +++ b/python/ray/rllib/offline/off_policy_estimator.py @@ -33,14 +33,14 @@ def __init__(self, policy, gamma): @classmethod def create(cls, ioctx): """Create an off-policy estimator from a IOContext.""" - gamma = ioctx.evaluator.policy_config["gamma"] + gamma = ioctx.worker.policy_config["gamma"] # Grab a reference to the current model - keys = list(ioctx.evaluator.policy_map.keys()) + keys = list(ioctx.worker.policy_map.keys()) if len(keys) > 1: raise NotImplementedError( "Off-policy estimation is not implemented for multi-agent. " "You can set `input_evaluation: []` to resolve this.") - policy = ioctx.evaluator.get_policy(keys[0]) + policy = ioctx.worker.get_policy(keys[0]) return cls(policy, gamma) @DeveloperAPI diff --git a/python/ray/rllib/optimizers/aso_aggregator.py b/python/ray/rllib/optimizers/aso_aggregator.py index c2ecb6ed194b..bc7c75bbf0e1 100644 --- a/python/ray/rllib/optimizers/aso_aggregator.py +++ b/python/ray/rllib/optimizers/aso_aggregator.py @@ -14,7 +14,7 @@ class Aggregator(object): - """An aggregator collects and processes samples from evaluators. + """An aggregator collects and processes samples from workers. This class is used to abstract away the strategy for sample collection. For example, you may want to use a tree of actors to collect samples. The @@ -22,21 +22,21 @@ class Aggregator(object): as concatenating and decompressing sample batches. Attributes: - local_evaluator: local PolicyEvaluator copy + local_worker: local RolloutWorker copy """ def iter_train_batches(self): """Returns a generator over batches ready to learn on. Iterating through this generator will also send out weight updates to - remote evaluators as needed. + remote workers as needed. This call may block until results are available. 
""" raise NotImplementedError def broadcast_new_weights(self): - """Broadcast a new set of weights from the local evaluator.""" + """Broadcast a new set of weights from the local workers.""" raise NotImplementedError def should_broadcast(self): @@ -47,19 +47,19 @@ def stats(self): """Returns runtime statistics for debugging.""" raise NotImplementedError - def reset(self, remote_evaluators): - """Called to change the set of remote evaluators being used.""" + def reset(self, remote_workers): + """Called to change the set of remote workers being used.""" raise NotImplementedError class AggregationWorkerBase(object): """Aggregators should extend from this class.""" - def __init__(self, initial_weights_obj_id, remote_evaluators, + def __init__(self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size): self.broadcasted_weights = initial_weights_obj_id - self.remote_evaluators = remote_evaluators + self.remote_workers = remote_workers self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size @@ -73,7 +73,7 @@ def __init__(self, initial_weights_obj_id, remote_evaluators, # Kick off async background sampling self.sample_tasks = TaskPool() - for ev in self.remote_evaluators: + for ev in self.remote_workers: ev.set_weights.remote(self.broadcasted_weights) for _ in range(max_sample_requests_in_flight_per_worker): self.sample_tasks.add(ev, ev.sample.remote()) @@ -138,8 +138,8 @@ def stats(self): } @override(Aggregator) - def reset(self, remote_evaluators): - self.sample_tasks.reset_evaluators(remote_evaluators) + def reset(self, remote_workers): + self.sample_tasks.reset_workers(remote_workers) def _augment_with_replay(self, sample_futures): def can_replay(): @@ -164,25 +164,25 @@ class SimpleAggregator(AggregationWorkerBase, Aggregator): """Simple single-threaded implementation of an Aggregator.""" def __init__(self, - local_evaluator, - remote_evaluators, + workers, max_sample_requests_in_flight_per_worker=2, replay_proportion=0.0, replay_buffer_num_slots=0, train_batch_size=500, sample_batch_size=50, broadcast_interval=5): - self.local_evaluator = local_evaluator + self.workers = workers + self.local_worker = workers.local_worker() self.broadcast_interval = broadcast_interval self.broadcast_new_weights() AggregationWorkerBase.__init__( - self, self.broadcasted_weights, remote_evaluators, + self, self.broadcasted_weights, self.workers.remote_workers(), max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size) @override(Aggregator) def broadcast_new_weights(self): - self.broadcasted_weights = ray.put(self.local_evaluator.get_weights()) + self.broadcasted_weights = ray.put(self.local_worker.get_weights()) self.num_sent_since_broadcast = 0 @override(Aggregator) diff --git a/python/ray/rllib/optimizers/aso_learner.py b/python/ray/rllib/optimizers/aso_learner.py index 3bf87f660730..74980bdf0a00 100644 --- a/python/ray/rllib/optimizers/aso_learner.py +++ b/python/ray/rllib/optimizers/aso_learner.py @@ -25,11 +25,11 @@ class LearnerThread(threading.Thread): improves overall throughput. 
""" - def __init__(self, local_evaluator, minibatch_buffer_size, num_sgd_iter, + def __init__(self, local_worker, minibatch_buffer_size, num_sgd_iter, learner_queue_size): threading.Thread.__init__(self) self.learner_queue_size = WindowStat("size", 50) - self.local_evaluator = local_evaluator + self.local_worker = local_worker self.inqueue = queue.Queue(maxsize=learner_queue_size) self.outqueue = queue.Queue() self.minibatch_buffer = MinibatchBuffer( @@ -52,7 +52,7 @@ def step(self): batch, _ = self.minibatch_buffer.get() with self.grad_timer: - fetches = self.local_evaluator.learn_on_batch(batch) + fetches = self.local_worker.learn_on_batch(batch) self.weights_updated = True self.stats = get_learner_stats(fetches) diff --git a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py index b5040e45584c..78058da44ef4 100644 --- a/python/ray/rllib/optimizers/aso_multi_gpu_learner.py +++ b/python/ray/rllib/optimizers/aso_multi_gpu_learner.py @@ -31,7 +31,7 @@ class TFMultiGPULearner(LearnerThread): """ def __init__(self, - local_evaluator, + local_worker, num_gpus=1, lr=0.0005, train_batch_size=500, @@ -41,7 +41,7 @@ def __init__(self, learner_queue_size=16, num_data_load_threads=16, _fake_gpus=False): - LearnerThread.__init__(self, local_evaluator, minibatch_buffer_size, + LearnerThread.__init__(self, local_worker, minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.lr = lr self.train_batch_size = train_batch_size @@ -59,16 +59,16 @@ def __init__(self, assert self.train_batch_size % len(self.devices) == 0 assert self.train_batch_size >= len(self.devices), "batch too small" - if set(self.local_evaluator.policy_map.keys()) != {DEFAULT_POLICY_ID}: + if set(self.local_worker.policy_map.keys()) != {DEFAULT_POLICY_ID}: raise NotImplementedError("Multi-gpu mode for multi-agent") - self.policy = self.local_evaluator.policy_map[DEFAULT_POLICY_ID] + self.policy = self.local_worker.policy_map[DEFAULT_POLICY_ID] # per-GPU graph copies created below must share vars with the policy # reuse is set to AUTO_REUSE because Adam nodes are created after # all of the device copies are created. self.par_opt = [] - with self.local_evaluator.tf_sess.graph.as_default(): - with self.local_evaluator.tf_sess.as_default(): + with self.local_worker.tf_sess.graph.as_default(): + with self.local_worker.tf_sess.as_default(): with tf.variable_scope(DEFAULT_POLICY_ID, reuse=tf.AUTO_REUSE): if self.policy._state_inputs: rnn_inputs = self.policy._state_inputs + [ @@ -87,7 +87,7 @@ def __init__(self, 999999, # it will get rounded down self.policy.copy)) - self.sess = self.local_evaluator.tf_sess + self.sess = self.local_worker.tf_sess self.sess.run(tf.global_variables_initializer()) self.idle_optimizers = queue.Queue() diff --git a/python/ray/rllib/optimizers/aso_tree_aggregator.py b/python/ray/rllib/optimizers/aso_tree_aggregator.py index cf51bce25352..75677e31372b 100644 --- a/python/ray/rllib/optimizers/aso_tree_aggregator.py +++ b/python/ray/rllib/optimizers/aso_tree_aggregator.py @@ -22,15 +22,14 @@ class TreeAggregator(Aggregator): """A hierarchical experiences aggregator. - The given set of remote evaluators is divided into subsets and assigned to + The given set of remote workers is divided into subsets and assigned to one of several aggregation workers. These aggregation workers collate experiences into batches of size `train_batch_size` and we collect them in this class when `iter_train_batches` is called. 
""" def __init__(self, - local_evaluator, - remote_evaluators, + workers, num_aggregation_workers, max_sample_requests_in_flight_per_worker=2, replay_proportion=0.0, @@ -38,8 +37,7 @@ def __init__(self, train_batch_size=500, sample_batch_size=50, broadcast_interval=5): - self.local_evaluator = local_evaluator - self.remote_evaluators = remote_evaluators + self.workers = workers self.num_aggregation_workers = num_aggregation_workers self.max_sample_requests_in_flight_per_worker = \ max_sample_requests_in_flight_per_worker @@ -48,7 +46,8 @@ def __init__(self, self.sample_batch_size = sample_batch_size self.train_batch_size = train_batch_size self.broadcast_interval = broadcast_interval - self.broadcasted_weights = ray.put(local_evaluator.get_weights()) + self.broadcasted_weights = ray.put( + workers.local_worker().get_weights()) self.num_batches_processed = 0 self.num_broadcasts = 0 self.num_sent_since_broadcast = 0 @@ -58,26 +57,27 @@ def init(self, aggregators): """Deferred init so that we can pass in previously created workers.""" assert len(aggregators) == self.num_aggregation_workers, aggregators - if len(self.remote_evaluators) < self.num_aggregation_workers: + if len(self.workers.remote_workers()) < self.num_aggregation_workers: raise ValueError( "The number of aggregation workers should not exceed the " "number of total evaluation workers ({} vs {})".format( - self.num_aggregation_workers, len(self.remote_evaluators))) + self.num_aggregation_workers, + len(self.workers.remote_workers()))) - assigned_evaluators = collections.defaultdict(list) - for i, ev in enumerate(self.remote_evaluators): - assigned_evaluators[i % self.num_aggregation_workers].append(ev) + assigned_workers = collections.defaultdict(list) + for i, ev in enumerate(self.workers.remote_workers()): + assigned_workers[i % self.num_aggregation_workers].append(ev) - self.workers = aggregators - for i, worker in enumerate(self.workers): - worker.init.remote( - self.broadcasted_weights, assigned_evaluators[i], - self.max_sample_requests_in_flight_per_worker, - self.replay_proportion, self.replay_buffer_num_slots, - self.train_batch_size, self.sample_batch_size) + self.aggregators = aggregators + for i, agg in enumerate(self.aggregators): + agg.init.remote(self.broadcasted_weights, assigned_workers[i], + self.max_sample_requests_in_flight_per_worker, + self.replay_proportion, + self.replay_buffer_num_slots, + self.train_batch_size, self.sample_batch_size) self.agg_tasks = TaskPool() - for agg in self.workers: + for agg in self.aggregators: agg.set_weights.remote(self.broadcasted_weights) self.agg_tasks.add(agg, agg.get_train_batches.remote()) @@ -96,7 +96,8 @@ def iter_train_batches(self): @override(Aggregator) def broadcast_new_weights(self): - self.broadcasted_weights = ray.put(self.local_evaluator.get_weights()) + self.broadcasted_weights = ray.put( + self.workers.local_worker().get_weights()) self.num_sent_since_broadcast = 0 self.num_broadcasts += 1 @@ -112,8 +113,8 @@ def stats(self): } @override(Aggregator) - def reset(self, remote_evaluators): - raise NotImplementedError("changing number of remote evaluators") + def reset(self, remote_workers): + raise NotImplementedError("changing number of remote workers") @staticmethod def precreate_aggregators(n): @@ -125,16 +126,16 @@ class AggregationWorker(AggregationWorkerBase): def __init__(self): self.initialized = False - def init(self, initial_weights_obj_id, remote_evaluators, + def init(self, initial_weights_obj_id, remote_workers, 
max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size): """Deferred init that assigns sub-workers to this aggregator.""" - logger.info("Assigned evaluators {} to aggregation worker {}".format( - remote_evaluators, self)) - assert remote_evaluators + logger.info("Assigned workers {} to aggregation worker {}".format( + remote_workers, self)) + assert remote_workers AggregationWorkerBase.__init__( - self, initial_weights_obj_id, remote_evaluators, + self, initial_weights_obj_id, remote_workers, max_sample_requests_in_flight_per_worker, replay_proportion, replay_buffer_num_slots, train_batch_size, sample_batch_size) self.initialized = True diff --git a/python/ray/rllib/optimizers/async_gradients_optimizer.py b/python/ray/rllib/optimizers/async_gradients_optimizer.py index 2b46e1259956..05f266b66238 100644 --- a/python/ray/rllib/optimizers/async_gradients_optimizer.py +++ b/python/ray/rllib/optimizers/async_gradients_optimizer.py @@ -14,30 +14,30 @@ class AsyncGradientsOptimizer(PolicyOptimizer): """An asynchronous RL optimizer, e.g. for implementing A3C. This optimizer asynchronously pulls and applies gradients from remote - evaluators, sending updated weights back as needed. This pipelines the + workers, sending updated weights back as needed. This pipelines the gradient computations on the remote workers. """ - def __init__(self, local_evaluator, remote_evaluators, grads_per_step=100): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + def __init__(self, workers, grads_per_step=100): + PolicyOptimizer.__init__(self, workers) self.apply_timer = TimerStat() self.wait_timer = TimerStat() self.dispatch_timer = TimerStat() self.grads_per_step = grads_per_step self.learner_stats = {} - if not self.remote_evaluators: + if not self.workers.remote_workers(): raise ValueError( - "Async optimizer requires at least 1 remote evaluator") + "Async optimizer requires at least 1 remote workers") @override(PolicyOptimizer) def step(self): - weights = ray.put(self.local_evaluator.get_weights()) + weights = ray.put(self.workers.local_worker().get_weights()) pending_gradients = {} num_gradients = 0 # Kick off the first wave of async tasks - for e in self.remote_evaluators: + for e in self.workers.remote_workers(): e.set_weights.remote(weights) future = e.compute_gradients.remote(e.sample.remote()) pending_gradients[future] = e @@ -56,13 +56,14 @@ def step(self): if gradient is not None: with self.apply_timer: - self.local_evaluator.apply_gradients(gradient) + self.workers.local_worker().apply_gradients(gradient) self.num_steps_sampled += info["batch_count"] self.num_steps_trained += info["batch_count"] if num_gradients < self.grads_per_step: with self.dispatch_timer: - e.set_weights.remote(self.local_evaluator.get_weights()) + e.set_weights.remote( + self.workers.local_worker().get_weights()) future = e.compute_gradients.remote(e.sample.remote()) pending_gradients[future] = e diff --git a/python/ray/rllib/optimizers/async_replay_optimizer.py b/python/ray/rllib/optimizers/async_replay_optimizer.py index d66f942ae532..0b99cef2df53 100644 --- a/python/ray/rllib/optimizers/async_replay_optimizer.py +++ b/python/ray/rllib/optimizers/async_replay_optimizer.py @@ -36,20 +36,19 @@ class AsyncReplayOptimizer(PolicyOptimizer): """Main event loop of the Ape-X optimizer (async sampling with replay). This class coordinates the data transfers between the learner thread, - remote evaluators (Ape-X actors), and replay buffer actors. 
+ remote workers (Ape-X actors), and replay buffer actors. This has two modes of operation: - normal replay: replays independent samples. - batch replay: simplified mode where entire sample batches are replayed. This supports RNNs, but not prioritization. - This optimizer requires that policy evaluators return an additional + This optimizer requires that rollout workers return an additional "td_error" array in the info return of compute_gradients(). This error term will be used for sample prioritization.""" def __init__(self, - local_evaluator, - remote_evaluators, + workers, learning_starts=1000, buffer_size=10000, prioritized_replay=True, @@ -62,7 +61,7 @@ def __init__(self, max_weight_sync_delay=400, debug=False, batch_replay=False): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + PolicyOptimizer.__init__(self, workers) self.debug = debug self.batch_replay = batch_replay @@ -71,7 +70,7 @@ def __init__(self, self.prioritized_replay_eps = prioritized_replay_eps self.max_weight_sync_delay = max_weight_sync_delay - self.learner = LearnerThread(self.local_evaluator) + self.learner = LearnerThread(self.workers.local_worker()) self.learner.start() if self.batch_replay: @@ -111,13 +110,13 @@ def __init__(self, # Kick off async background sampling self.sample_tasks = TaskPool() - if self.remote_evaluators: - self._set_evaluators(self.remote_evaluators) + if self.workers.remote_workers(): + self._set_workers(self.workers.remote_workers()) @override(PolicyOptimizer) def step(self): assert self.learner.is_alive() - assert len(self.remote_evaluators) > 0 + assert len(self.workers.remote_workers()) > 0 start = time.time() sample_timesteps, train_timesteps = self._step() time_delta = time.time() - start @@ -138,9 +137,9 @@ def stop(self): self.learner.stopped = True @override(PolicyOptimizer) - def reset(self, remote_evaluators): - self.remote_evaluators = remote_evaluators - self.sample_tasks.reset_evaluators(remote_evaluators) + def reset(self, remote_workers): + self.workers.reset(remote_workers) + self.sample_tasks.reset_workers(remote_workers) @override(PolicyOptimizer) def stats(self): @@ -175,10 +174,10 @@ def stats(self): return dict(PolicyOptimizer.stats(self), **stats) # For https://github.com/ray-project/ray/issues/2541 only - def _set_evaluators(self, remote_evaluators): - self.remote_evaluators = remote_evaluators - weights = self.local_evaluator.get_weights() - for ev in self.remote_evaluators: + def _set_workers(self, remote_workers): + self.workers.reset(remote_workers) + weights = self.workers.local_worker().get_weights() + for ev in self.workers.remote_workers(): ev.set_weights.remote(weights) self.steps_since_update[ev] = 0 for _ in range(SAMPLE_QUEUE_DEPTH): @@ -207,7 +206,7 @@ def _step(self): self.learner.weights_updated = False with self.timers["put_weights"]: weights = ray.put( - self.local_evaluator.get_weights()) + self.workers.local_worker().get_weights()) ev.set_weights.remote(weights) self.num_weight_syncs += 1 self.steps_since_update[ev] = 0 @@ -380,10 +379,10 @@ class LearnerThread(threading.Thread): improves overall throughput. 
""" - def __init__(self, local_evaluator): + def __init__(self, local_worker): threading.Thread.__init__(self) self.learner_queue_size = WindowStat("size", 50) - self.local_evaluator = local_evaluator + self.local_worker = local_worker self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE) self.outqueue = queue.Queue() self.queue_timer = TimerStat() @@ -403,7 +402,7 @@ def step(self): if replay is not None: prio_dict = {} with self.grad_timer: - grad_out = self.local_evaluator.learn_on_batch(replay) + grad_out = self.local_worker.learn_on_batch(replay) for pid, info in grad_out.items(): prio_dict[pid] = ( replay.policy_batches[pid].data.get("batch_indexes"), diff --git a/python/ray/rllib/optimizers/async_samples_optimizer.py b/python/ray/rllib/optimizers/async_samples_optimizer.py index e2ff320e618c..1e3afb8fb2c3 100644 --- a/python/ray/rllib/optimizers/async_samples_optimizer.py +++ b/python/ray/rllib/optimizers/async_samples_optimizer.py @@ -24,12 +24,11 @@ class AsyncSamplesOptimizer(PolicyOptimizer): """Main event loop of the IMPALA architecture. This class coordinates the data transfers between the learner thread - and remote evaluators (IMPALA actors). + and remote workers (IMPALA actors). """ def __init__(self, - local_evaluator, - remote_evaluators, + workers, train_batch_size=500, sample_batch_size=50, num_envs_per_worker=1, @@ -45,7 +44,7 @@ def __init__(self, learner_queue_size=16, num_aggregation_workers=0, _fake_gpus=False): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + PolicyOptimizer.__init__(self, workers) self._stats_start_time = time.time() self._last_stats_time = {} @@ -62,7 +61,7 @@ def __init__(self, "{} vs {}".format(num_data_loader_buffers, minibatch_buffer_size)) self.learner = TFMultiGPULearner( - self.local_evaluator, + self.workers.local_worker(), lr=lr, num_gpus=num_gpus, train_batch_size=train_batch_size, @@ -72,7 +71,7 @@ def __init__(self, learner_queue_size=learner_queue_size, _fake_gpus=_fake_gpus) else: - self.learner = LearnerThread(self.local_evaluator, + self.learner = LearnerThread(self.workers.local_worker(), minibatch_buffer_size, num_sgd_iter, learner_queue_size) self.learner.start() @@ -84,8 +83,7 @@ def __init__(self, if num_aggregation_workers > 0: self.aggregator = TreeAggregator( - self.local_evaluator, - self.remote_evaluators, + workers, num_aggregation_workers, replay_proportion=replay_proportion, max_sample_requests_in_flight_per_worker=( @@ -96,8 +94,7 @@ def __init__(self, broadcast_interval=broadcast_interval) else: self.aggregator = SimpleAggregator( - self.local_evaluator, - self.remote_evaluators, + workers, replay_proportion=replay_proportion, max_sample_requests_in_flight_per_worker=( max_sample_requests_in_flight_per_worker), @@ -127,7 +124,7 @@ def get_mean_stats_and_reset(self): @override(PolicyOptimizer) def step(self): - if len(self.remote_evaluators) == 0: + if len(self.workers.remote_workers()) == 0: raise ValueError("Config num_workers=0 means training will hang!") assert self.learner.is_alive() with self._optimizer_step_timer: @@ -146,9 +143,9 @@ def stop(self): self.learner.stopped = True @override(PolicyOptimizer) - def reset(self, remote_evaluators): - self.remote_evaluators = remote_evaluators - self.aggregator.reset(remote_evaluators) + def reset(self, remote_workers): + self.workers.reset(remote_workers) + self.aggregator.reset(remote_workers) @override(PolicyOptimizer) def stats(self): diff --git a/python/ray/rllib/optimizers/multi_gpu_optimizer.py 
b/python/ray/rllib/optimizers/multi_gpu_optimizer.py index a25553c40111..65d7842d82c7 100644 --- a/python/ray/rllib/optimizers/multi_gpu_optimizer.py +++ b/python/ray/rllib/optimizers/multi_gpu_optimizer.py @@ -28,7 +28,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): """A synchronous optimizer that uses multiple local GPUs. - Samples are pulled synchronously from multiple remote evaluators, + Samples are pulled synchronously from multiple remote workers, concatenated, and then split across the memory of multiple local GPUs. A number of SGD passes are then taken over the in-memory data. For more details, see `multi_gpu_impl.LocalSyncParallelOptimizer`. @@ -42,8 +42,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): """ def __init__(self, - local_evaluator, - remote_evaluators, + workers, sgd_batch_size=128, num_sgd_iter=10, sample_batch_size=200, @@ -52,7 +51,7 @@ def __init__(self, num_gpus=0, standardize_fields=[], straggler_mitigation=False): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + PolicyOptimizer.__init__(self, workers) self.batch_size = sgd_batch_size self.num_sgd_iter = num_sgd_iter @@ -79,8 +78,8 @@ def __init__(self, logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices)) - self.policies = dict( - self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p))) + self.policies = dict(self.workers.local_worker() + .foreach_trainable_policy(lambda p, i: (i, p))) logger.debug("Policies to train: {}".format(self.policies)) for policy_id, policy in self.policies.items(): if not isinstance(policy, TFPolicy): @@ -92,8 +91,8 @@ def __init__(self, # reuse is set to AUTO_REUSE because Adam nodes are created after # all of the device copies are created. self.optimizers = {} - with self.local_evaluator.tf_sess.graph.as_default(): - with self.local_evaluator.tf_sess.as_default(): + with self.workers.local_worker().tf_sess.graph.as_default(): + with self.workers.local_worker().tf_sess.as_default(): for policy_id, policy in self.policies.items(): with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE): if policy._state_inputs: @@ -109,25 +108,25 @@ def __init__(self, for _, v in policy._loss_inputs], rnn_inputs, self.per_device_batch_size, policy.copy)) - self.sess = self.local_evaluator.tf_sess + self.sess = self.workers.local_worker().tf_sess self.sess.run(tf.global_variables_initializer()) @override(PolicyOptimizer) def step(self): with self.update_weights_timer: - if self.remote_evaluators: - weights = ray.put(self.local_evaluator.get_weights()) - for e in self.remote_evaluators: + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): e.set_weights.remote(weights) with self.sample_timer: - if self.remote_evaluators: + if self.workers.remote_workers(): if self.straggler_mitigation: samples = collect_samples_straggler_mitigation( - self.remote_evaluators, self.train_batch_size) + self.workers.remote_workers(), self.train_batch_size) else: samples = collect_samples( - self.remote_evaluators, self.sample_batch_size, + self.workers.remote_workers(), self.sample_batch_size, self.num_envs_per_worker, self.train_batch_size) if samples.count > self.train_batch_size * 2: logger.info( @@ -139,7 +138,7 @@ def step(self): else: samples = [] while sum(s.count for s in samples) < self.train_batch_size: - samples.append(self.local_evaluator.sample()) + samples.append(self.workers.local_worker().sample()) samples = SampleBatch.concat_samples(samples) # Handle everything as if 
multiagent diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index f67ea9cdc073..29287e96440d 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -6,7 +6,6 @@ from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes -from ray.rllib.utils.memory import ray_get_and_free logger = logging.getLogger(__name__) @@ -21,34 +20,21 @@ class PolicyOptimizer(object): used for PPO. These optimizers are all pluggable, and it is possible to mix and match as needed. - In order for an algorithm to use an RLlib optimizer, it must implement - the PolicyEvaluator interface and pass a PolicyEvaluator class or set of - PolicyEvaluators to its PolicyOptimizer of choice. The PolicyOptimizer - uses these Evaluators to sample from the environment and compute model - gradient updates. - Attributes: config (dict): The JSON configuration passed to this optimizer. - local_evaluator (PolicyEvaluator): The embedded evaluator instance. - remote_evaluators (list): List of remote evaluator replicas, or []. + workers (WorkerSet): The set of rollout workers to use. num_steps_trained (int): Number of timesteps trained on so far. num_steps_sampled (int): Number of timesteps sampled so far. - evaluator_resources (dict): Optional resource requests to set for - evaluators created by this optimizer. """ @DeveloperAPI - def __init__(self, local_evaluator, remote_evaluators=None): + def __init__(self, workers): """Create an optimizer instance. Args: - local_evaluator (Evaluator): Local evaluator instance, required. - remote_evaluators (list): A list of Ray actor handles to remote - evaluators instances. If empty, the optimizer should fall back - to using only the local evaluator. + workers (WorkerSet): The set of rollout workers to use. """ - self.local_evaluator = local_evaluator - self.remote_evaluators = remote_evaluators or [] + self.workers = workers self.episode_history = [] # Counters that should be updated by sub-classes @@ -100,23 +86,23 @@ def stop(self): def collect_metrics(self, timeout_seconds, min_history=100, - selected_evaluators=None): - """Returns evaluator and optimizer stats. + selected_workers=None): + """Returns worker and optimizer stats. Arguments: - timeout_seconds (int): Max wait time for a evaluator before - dropping its results. This usually indicates a hung evaluator. + timeout_seconds (int): Max wait time for a worker before + dropping its results. This usually indicates a hung worker. min_history (int): Min history length to smooth results over. - selected_evaluators (list): Override the list of remote evaluators + selected_workers (list): Override the list of remote workers to collect metrics from. Returns: - res (dict): A training result dict from evaluator metrics with + res (dict): A training result dict from worker metrics with `info` replaced with stats from self. 
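In caller terms, the renamed keyword is the only visible change. A small usage sketch, assuming `optimizer` is any PolicyOptimizer built from a WorkerSet `workers` and using an arbitrary timeout:

# Previously: optimizer.collect_metrics(180, selected_evaluators=...)
result = optimizer.collect_metrics(
    timeout_seconds=180,
    selected_workers=workers.remote_workers()[:2])
print(result["episode_reward_mean"])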
""" episodes, num_dropped = collect_episodes( - self.local_evaluator, - selected_evaluators or self.remote_evaluators, + self.workers.local_worker(), + selected_workers or self.workers.remote_workers(), timeout_seconds=timeout_seconds) orig_episodes = list(episodes) missing = min_history - len(episodes) @@ -130,30 +116,28 @@ def collect_metrics(self, return res @DeveloperAPI - def reset(self, remote_evaluators): - """Called to change the set of remote evaluators being used.""" - - self.remote_evaluators = remote_evaluators + def reset(self, remote_workers): + """Called to change the set of remote workers being used.""" + self.workers.reset(remote_workers) @DeveloperAPI - def foreach_evaluator(self, func): - """Apply the given function to each evaluator instance.""" - - local_result = [func(self.local_evaluator)] - remote_results = ray_get_and_free( - [ev.apply.remote(func) for ev in self.remote_evaluators]) - return local_result + remote_results + def foreach_worker(self, func): + """Apply the given function to each worker instance.""" + return self.workers.foreach_worker(func) @DeveloperAPI - def foreach_evaluator_with_index(self, func): - """Apply the given function to each evaluator instance. + def foreach_worker_with_index(self, func): + """Apply the given function to each worker instance. The index will be passed as the second arg to the given function. """ + return self.workers.foreach_worker_with_index(func) + + def foreach_evaluator(self, func): + raise DeprecationWarning( + "foreach_evaluator has been renamed to foreach_worker") - local_result = [func(self.local_evaluator, 0)] - remote_results = ray_get_and_free([ - ev.apply.remote(func, i + 1) - for i, ev in enumerate(self.remote_evaluators) - ]) - return local_result + remote_results + def foreach_evaluator_with_index(self, func): + raise DeprecationWarning( + "foreach_evaluator_with_index has been renamed to " + "foreach_worker_with_index") diff --git a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py index e13d71c6e4cd..e2b4865da5ee 100644 --- a/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_batch_replay_optimizer.py @@ -20,12 +20,11 @@ class SyncBatchReplayOptimizer(PolicyOptimizer): This enables RNN support. 
Does not currently support prioritization.""" def __init__(self, - local_evaluator, - remote_evaluators, + workers, learning_starts=1000, buffer_size=10000, train_batch_size=32): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + PolicyOptimizer.__init__(self, workers) self.replay_starts = learning_starts self.max_buffer_size = buffer_size @@ -45,17 +44,17 @@ def __init__(self, @override(PolicyOptimizer) def step(self): with self.update_weights_timer: - if self.remote_evaluators: - weights = ray.put(self.local_evaluator.get_weights()) - for e in self.remote_evaluators: + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): e.set_weights.remote(weights) with self.sample_timer: - if self.remote_evaluators: + if self.workers.remote_workers(): batches = ray_get_and_free( - [e.sample.remote() for e in self.remote_evaluators]) + [e.sample.remote() for e in self.workers.remote_workers()]) else: - batches = [self.local_evaluator.sample()] + batches = [self.workers.local_worker().sample()] # Handle everything as if multiagent tmp = [] @@ -105,7 +104,7 @@ def _optimize(self): samples.append(random.choice(self.replay_buffer)) samples = SampleBatch.concat_samples(samples) with self.grad_timer: - info_dict = self.local_evaluator.learn_on_batch(samples) + info_dict = self.workers.local_worker().learn_on_batch(samples) for policy_id, info in info_dict.items(): self.learner_stats[policy_id] = get_learner_stats(info) self.grad_timer.push_units_processed(samples.count) diff --git a/python/ray/rllib/optimizers/sync_replay_optimizer.py b/python/ray/rllib/optimizers/sync_replay_optimizer.py index 27858f3527c1..881e02f90c74 100644 --- a/python/ray/rllib/optimizers/sync_replay_optimizer.py +++ b/python/ray/rllib/optimizers/sync_replay_optimizer.py @@ -25,13 +25,12 @@ class SyncReplayOptimizer(PolicyOptimizer): """Variant of the local sync optimizer that supports replay (for DQN). - This optimizer requires that policy evaluators return an additional + This optimizer requires that rollout workers return an additional "td_error" array in the info return of compute_gradients(). 
This error term will be used for sample prioritization.""" def __init__(self, - local_evaluator, - remote_evaluators, + workers, learning_starts=1000, buffer_size=10000, prioritized_replay=True, @@ -43,7 +42,7 @@ def __init__(self, prioritized_replay_eps=1e-6, train_batch_size=32, sample_batch_size=4): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + PolicyOptimizer.__init__(self, workers) self.replay_starts = learning_starts # linearly annealing beta used in Rainbow paper @@ -82,18 +81,20 @@ def new_buffer(): @override(PolicyOptimizer) def step(self): with self.update_weights_timer: - if self.remote_evaluators: - weights = ray.put(self.local_evaluator.get_weights()) - for e in self.remote_evaluators: + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): e.set_weights.remote(weights) with self.sample_timer: - if self.remote_evaluators: + if self.workers.remote_workers(): batch = SampleBatch.concat_samples( - ray_get_and_free( - [e.sample.remote() for e in self.remote_evaluators])) + ray_get_and_free([ + e.sample.remote() + for e in self.workers.remote_workers() + ])) else: - batch = self.local_evaluator.sample() + batch = self.workers.local_worker().sample() # Handle everything as if multiagent if isinstance(batch, SampleBatch): @@ -135,7 +136,7 @@ def _optimize(self): samples = self._replay() with self.grad_timer: - info_dict = self.local_evaluator.learn_on_batch(samples) + info_dict = self.workers.local_worker().learn_on_batch(samples) for policy_id, info in info_dict.items(): self.learner_stats[policy_id] = get_learner_stats(info) replay_buffer = self.replay_buffers[policy_id] diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index a49b290d3e2c..0f79062a337d 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -19,16 +19,12 @@ class SyncSamplesOptimizer(PolicyOptimizer): """A simple synchronous RL optimizer. In each step, this optimizer pulls samples from a number of remote - evaluators, concatenates them, and then updates a local model. The updated - model weights are then broadcast to all remote evaluators. + workers, concatenates them, and then updates a local model. The updated + model weights are then broadcast to all remote workers. 
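The synchronous pattern this docstring describes, with timers and stats stripped out (a paraphrase of step(), assuming `workers` is a WorkerSet; the real code uses ray_get_and_free instead of ray.get):

import ray
from ray.rllib.evaluation import SampleBatch

def sync_step(workers, train_batch_size, num_sgd_iter=1):
    # Broadcast current weights to the remote workers, if any.
    if workers.remote_workers():
        weights = ray.put(workers.local_worker().get_weights())
        for w in workers.remote_workers():
            w.set_weights.remote(weights)

    # Gather at least train_batch_size timesteps of experience.
    samples = []
    while sum(s.count for s in samples) < train_batch_size:
        if workers.remote_workers():
            samples.extend(
                ray.get([w.sample.remote() for w in workers.remote_workers()]))
        else:
            samples.append(workers.local_worker().sample())
    batch = SampleBatch.concat_samples(samples)

    # Update the local model; the new weights go out on the next call.
    for _ in range(num_sgd_iter):
        fetches = workers.local_worker().learn_on_batch(batch)
    return fetches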
""" - def __init__(self, - local_evaluator, - remote_evaluators, - num_sgd_iter=1, - train_batch_size=1): - PolicyOptimizer.__init__(self, local_evaluator, remote_evaluators) + def __init__(self, workers, num_sgd_iter=1, train_batch_size=1): + PolicyOptimizer.__init__(self, workers) self.update_weights_timer = TimerStat() self.sample_timer = TimerStat() @@ -41,27 +37,28 @@ def __init__(self, @override(PolicyOptimizer) def step(self): with self.update_weights_timer: - if self.remote_evaluators: - weights = ray.put(self.local_evaluator.get_weights()) - for e in self.remote_evaluators: + if self.workers.remote_workers(): + weights = ray.put(self.workers.local_worker().get_weights()) + for e in self.workers.remote_workers(): e.set_weights.remote(weights) with self.sample_timer: samples = [] while sum(s.count for s in samples) < self.train_batch_size: - if self.remote_evaluators: + if self.workers.remote_workers(): samples.extend( ray_get_and_free([ - e.sample.remote() for e in self.remote_evaluators + e.sample.remote() + for e in self.workers.remote_workers() ])) else: - samples.append(self.local_evaluator.sample()) + samples.append(self.workers.local_worker().sample()) samples = SampleBatch.concat_samples(samples) self.sample_timer.push_units_processed(samples.count) with self.grad_timer: for i in range(self.num_sgd_iter): - fetches = self.local_evaluator.learn_on_batch(samples) + fetches = self.workers.local_worker().learn_on_batch(samples) self.learner_stats = get_learner_stats(fetches) if self.num_sgd_iter > 1: logger.debug("{} {}".format(i, fetches)) diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py index afa72a0af709..0240f275de37 100644 --- a/python/ray/rllib/policy/dynamic_tf_policy.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -142,7 +142,7 @@ def __init__(self, action_prob = self.action_dist.sampled_action_prob() # Phase 1 init - sess = tf.get_default_session() + sess = tf.get_default_session() or tf.Session() if get_batch_divisibility_req: batch_divisibility_req = get_batch_divisibility_req(self) else: diff --git a/python/ray/rllib/policy/policy.py b/python/ray/rllib/policy/policy.py index 6f456e608007..e12cafef2cc4 100644 --- a/python/ray/rllib/policy/policy.py +++ b/python/ray/rllib/policy/policy.py @@ -36,7 +36,7 @@ def __init__(self, observation_space, action_space, config): """Initialize the graph. This is the standard constructor for policies. The policy - class you pass into PolicyEvaluator will be constructed with + class you pass into RolloutWorker will be constructed with these arguments. 
Args: diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py index 7f10958cdee7..b7f33fcb0887 100644 --- a/python/ray/rllib/policy/tf_policy_template.py +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -88,9 +88,7 @@ def build_tf_policy(name, a DynamicTFPolicy instance that uses the specified args """ - if not name.endswith("TFPolicy"): - raise ValueError("Name should match *TFPolicy", name) - + original_kwargs = locals().copy() base = DynamicTFPolicy while mixins: @@ -191,6 +189,11 @@ def extra_compute_grad_feed_dict(self): else: return TFPolicy.extra_compute_grad_feed_dict(self) + @staticmethod + def with_updates(**overrides): + return build_tf_policy(**dict(original_kwargs, **overrides)) + + policy_cls.with_updates = with_updates policy_cls.__name__ = name policy_cls.__qualname__ = name return policy_cls diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 19e943600210..1f4185f9c12e 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -24,7 +24,7 @@ def build_torch_policy(name, """Helper function for creating a torch policy at runtime. Arguments: - name (str): name of the policy (e.g., "PPOTFPolicy") + name (str): name of the policy (e.g., "PPOTorchPolicy") loss_fn (func): function that returns a loss tensor the policy, and dict of experience tensor placeholders get_default_config (func): optional function that returns the default @@ -55,9 +55,7 @@ def build_torch_policy(name, a TorchPolicy instance that uses the specified args """ - if not name.endswith("TorchPolicy"): - raise ValueError("Name should match *TorchPolicy", name) - + original_kwargs = locals().copy() base = TorchPolicy while mixins: @@ -66,7 +64,7 @@ class new_base(mixins.pop(), base): base = new_base - class graph_cls(base): + class policy_cls(base): def __init__(self, obs_space, action_space, config): if get_default_config: config = dict(get_default_config(), **config) @@ -130,6 +128,11 @@ def extra_grad_info(self, batch_tensors): else: return TorchPolicy.extra_grad_info(self, batch_tensors) - graph_cls.__name__ = name - graph_cls.__qualname__ = name - return graph_cls + @staticmethod + def with_updates(**overrides): + return build_torch_policy(**dict(original_kwargs, **overrides)) + + policy_cls.with_updates = with_updates + policy_cls.__name__ = name + policy_cls.__qualname__ = name + return policy_cls diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py index efa5743c0a54..d8292739f923 100755 --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -120,14 +120,14 @@ def default_policy_agent_mapping(unused_agent_id): def rollout(agent, env_name, num_steps, out=None, no_render=True): policy_agent_mapping = default_policy_agent_mapping - if hasattr(agent, "local_evaluator"): - env = agent.local_evaluator.env + if hasattr(agent, "workers"): + env = agent.workers.local_worker().env multiagent = isinstance(env, MultiAgentEnv) - if agent.local_evaluator.multiagent: + if agent.workers.local_worker().multiagent: policy_agent_mapping = agent.config["multiagent"][ "policy_mapping_fn"] - policy_map = agent.local_evaluator.policy_map + policy_map = agent.workers.local_worker().policy_map state_init = {p: m.get_initial_state() for p, m in policy_map.items()} use_lstm = {p: len(s) > 0 for p, s in state_init.items()} action_init = { diff --git a/python/ray/rllib/tests/mock_evaluator.py 
b/python/ray/rllib/tests/mock_worker.py similarity index 98% rename from python/ray/rllib/tests/mock_evaluator.py rename to python/ray/rllib/tests/mock_worker.py index e11b097e7119..b6b2e9773c30 100644 --- a/python/ray/rllib/tests/mock_evaluator.py +++ b/python/ray/rllib/tests/mock_worker.py @@ -8,7 +8,7 @@ from ray.rllib.utils.filter import MeanStdFilter -class _MockEvaluator(object): +class _MockWorker(object): def __init__(self, sample_count=10): self._weights = np.array([-10, -10, -10, -10]) self._grad = np.array([1, 1, 1, 1]) diff --git a/python/ray/rllib/tests/test_external_env.py b/python/ray/rllib/tests/test_external_env.py index 3b2158959267..24281e757fb9 100644 --- a/python/ray/rllib/tests/test_external_env.py +++ b/python/ray/rllib/tests/test_external_env.py @@ -11,10 +11,10 @@ import ray from ray.rllib.agents.dqn import DQNTrainer from ray.rllib.agents.pg import PGTrainer -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.external_env import ExternalEnv -from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy, - MockEnv) +from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy, + MockEnv) from ray.tune.registry import register_env @@ -119,7 +119,7 @@ def run(self): class TestExternalEnv(unittest.TestCase): def testExternalEnvCompleteEpisodes(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), policy=MockPolicy, batch_steps=40, @@ -129,7 +129,7 @@ def testExternalEnvCompleteEpisodes(self): self.assertEqual(batch.count, 50) def testExternalEnvTruncateEpisodes(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), policy=MockPolicy, batch_steps=40, @@ -139,7 +139,7 @@ def testExternalEnvTruncateEpisodes(self): self.assertEqual(batch.count, 40) def testExternalEnvOffPolicy(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), policy=MockPolicy, batch_steps=40, @@ -151,7 +151,7 @@ def testExternalEnvOffPolicy(self): self.assertEqual(batch["actions"][-1], 42) def testExternalEnvBadActions(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), policy=BadPolicy, sample_async=True, @@ -196,7 +196,7 @@ def testTrainCartpoleMulti(self): raise Exception("failed to improve reward") def testExternalEnvHorizonNotSupported(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), policy=MockPolicy, episode_horizon=20, diff --git a/python/ray/rllib/tests/test_external_multi_agent_env.py b/python/ray/rllib/tests/test_external_multi_agent_env.py index fcb3de634cbe..be232c0bfb67 100644 --- a/python/ray/rllib/tests/test_external_multi_agent_env.py +++ b/python/ray/rllib/tests/test_external_multi_agent_env.py @@ -10,9 +10,10 @@ import ray from ray.rllib.agents.pg.pg_policy import PGTFPolicy from ray.rllib.optimizers import SyncSamplesOptimizer -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv -from ray.rllib.tests.test_policy_evaluator import MockPolicy +from ray.rllib.tests.test_rollout_worker import MockPolicy from ray.rllib.tests.test_external_env import make_simple_serving from ray.rllib.tests.test_multi_agent_env 
import BasicMultiAgent, MultiCartpole from ray.rllib.evaluation.metrics import collect_metrics @@ -23,7 +24,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase): def testExternalMultiAgentEnvCompleteEpisodes(self): agents = 4 - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), policy=MockPolicy, batch_steps=40, @@ -35,7 +36,7 @@ def testExternalMultiAgentEnvCompleteEpisodes(self): def testExternalMultiAgentEnvTruncateEpisodes(self): agents = 4 - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), policy=MockPolicy, batch_steps=40, @@ -49,7 +50,7 @@ def testExternalMultiAgentEnvSample(self): agents = 2 act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -70,12 +71,12 @@ def testTrainExternalMultiCartpoleManyPolicies(self): policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MultiCartpole(n), policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) - optimizer = SyncSamplesOptimizer(ev, []) + optimizer = SyncSamplesOptimizer(WorkerSet._from_existing(ev)) for i in range(100): optimizer.step() result = collect_metrics(ev) diff --git a/python/ray/rllib/tests/test_filters.py b/python/ray/rllib/tests/test_filters.py index f039c6c09019..1446809eb9fc 100644 --- a/python/ray/rllib/tests/test_filters.py +++ b/python/ray/rllib/tests/test_filters.py @@ -8,7 +8,7 @@ import ray from ray.rllib.utils.filter import RunningStat, MeanStdFilter from ray.rllib.utils import FilterManager -from ray.rllib.tests.mock_evaluator import _MockEvaluator +from ray.rllib.tests.mock_worker import _MockWorker class RunningStatTest(unittest.TestCase): @@ -89,8 +89,8 @@ def testSynchronize(self): filt1.clear_buffer() self.assertEqual(filt1.buffer.n, 0) - RemoteEvaluator = ray.remote(_MockEvaluator) - remote_e = RemoteEvaluator.remote(sample_count=10) + RemoteWorker = ray.remote(_MockWorker) + remote_e = RemoteWorker.remote(sample_count=10) remote_e.sample.remote() FilterManager.synchronize({ diff --git a/python/ray/rllib/tests/test_multi_agent_env.py b/python/ray/rllib/tests/test_multi_agent_env.py index be4bfcd3428f..e69ba6b1f53d 100644 --- a/python/ray/rllib/tests/test_multi_agent_env.py +++ b/python/ray/rllib/tests/test_multi_agent_env.py @@ -12,11 +12,11 @@ from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer, AsyncGradientsOptimizer) -from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2, - MockPolicy) -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.tests.test_rollout_worker import (MockEnv, MockEnv2, MockPolicy) +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.tune.registry import register_env @@ -327,7 +327,7 @@ def testVectorizeRoundRobin(self): def testMultiAgentSample(self): act_space = gym.spaces.Discrete(2) obs_space = 
gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -345,7 +345,7 @@ def testMultiAgentSample(self): def testMultiAgentSampleSyncRemote(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -362,7 +362,7 @@ def testMultiAgentSampleSyncRemote(self): def testMultiAgentSampleAsyncRemote(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -378,7 +378,7 @@ def testMultiAgentSampleAsyncRemote(self): def testMultiAgentSampleWithHorizon(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -393,7 +393,7 @@ def testMultiAgentSampleWithHorizon(self): def testSampleFromEarlyDoneEnv(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(2) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: EarlyDoneMultiAgent(), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -409,7 +409,7 @@ def testSampleFromEarlyDoneEnv(self): def testMultiAgentSampleRoundRobin(self): act_space = gym.spaces.Discrete(2) obs_space = gym.spaces.Discrete(10) - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), policy={ "p0": (MockPolicy, obs_space, act_space, {}), @@ -458,7 +458,7 @@ def compute_actions(self, def get_initial_state(self): return [{}] # empty dict - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=StatefulPolicy, batch_steps=5) @@ -503,7 +503,7 @@ def compute_actions(self, single_env = gym.make("CartPole-v0") obs_space = single_env.observation_space act_space = single_env.action_space - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MultiCartpole(2), policy={ "p0": (ModelBasedPolicy, obs_space, act_space, {}), @@ -587,7 +587,7 @@ def _testWithOptimizer(self, optimizer_cls): "p1": (PGTFPolicy, obs_space, act_space, {}), "p2": (DQNTFPolicy, obs_space, act_space, dqn_config), } - ev = PolicyEvaluator( + worker = RolloutWorker( env_creator=lambda _: MultiCartpole(n), policy=policies, policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2], @@ -597,29 +597,30 @@ def _testWithOptimizer(self, optimizer_cls): def policy_mapper(agent_id): return ["p1", "p2"][agent_id % 2] - remote_evs = [ - PolicyEvaluator.as_remote().remote( + remote_workers = [ + RolloutWorker.as_remote().remote( env_creator=lambda _: MultiCartpole(n), policy=policies, policy_mapping_fn=policy_mapper, batch_steps=50) ] else: - remote_evs = [] - optimizer = optimizer_cls(ev, remote_evs) + remote_workers = [] + workers = WorkerSet._from_existing(worker, remote_workers) + optimizer = optimizer_cls(workers) for i in range(200): - ev.foreach_policy(lambda p, _: p.set_epsilon( + worker.foreach_policy(lambda p, _: p.set_epsilon( max(0.02, 1 - i * .02)) if isinstance(p, DQNTFPolicy) else None) optimizer.step() - result = collect_metrics(ev, remote_evs) + result = collect_metrics(worker, remote_workers) if i % 20 == 0: def do_update(p): if isinstance(p, DQNTFPolicy): 
p.update_target() - ev.foreach_policy(lambda p, _: do_update(p)) + worker.foreach_policy(lambda p, _: do_update(p)) print("Iter {}, rew {}".format(i, result["policy_reward_mean"])) print("Total reward", result["episode_reward_mean"]) @@ -647,15 +648,16 @@ def testTrainMultiCartpoleManyPolicies(self): policies["pg_{}".format(i)] = (PGTFPolicy, obs_space, act_space, {}) policy_ids = list(policies.keys()) - ev = PolicyEvaluator( + worker = RolloutWorker( env_creator=lambda _: MultiCartpole(n), policy=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) - optimizer = SyncSamplesOptimizer(ev, []) + workers = WorkerSet._from_existing(worker, []) + optimizer = SyncSamplesOptimizer(workers) for i in range(100): optimizer.step() - result = collect_metrics(ev) + result = collect_metrics(worker) print("Iteration {}, rew {}".format(i, result["policy_reward_mean"])) print("Total reward", result["episode_reward_mean"]) diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index f851cfc33f12..a87a295ccf1d 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -11,10 +11,11 @@ from ray.rllib.agents.ppo import PPOTrainer from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy from ray.rllib.evaluation import SampleBatch -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator -from ray.rllib.tests.mock_evaluator import _MockEvaluator +from ray.rllib.tests.mock_worker import _MockWorker from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -26,11 +27,11 @@ def tearDown(self): def testBasic(self): ray.init(num_cpus=4) - local = _MockEvaluator() - remotes = ray.remote(_MockEvaluator) - remote_evaluators = [remotes.remote() for i in range(5)] - test_optimizer = AsyncGradientsOptimizer( - local, remote_evaluators, grads_per_step=10) + local = _MockWorker() + remotes = ray.remote(_MockWorker) + remote_workers = [remotes.remote() for i in range(5)] + workers = WorkerSet._from_existing(local, remote_workers) + test_optimizer = AsyncGradientsOptimizer(workers, grads_per_step=10) test_optimizer.step() self.assertTrue(all(local.get_weights() == 0)) @@ -117,30 +118,28 @@ def setUpClass(cls): def testSimple(self): local, remotes = self._make_evs() - optimizer = AsyncSamplesOptimizer(local, remotes) + workers = WorkerSet._from_existing(local, remotes) + optimizer = AsyncSamplesOptimizer(workers) self._wait_for(optimizer, 1000, 1000) def testMultiGPU(self): local, remotes = self._make_evs() - optimizer = AsyncSamplesOptimizer( - local, remotes, num_gpus=2, _fake_gpus=True) + workers = WorkerSet._from_existing(local, remotes) + optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiGPUParallelLoad(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - local, - remotes, - num_gpus=2, - num_data_loader_buffers=2, - _fake_gpus=True) + workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiplePasses(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) optimizer = 
AsyncSamplesOptimizer( - local, - remotes, + workers, minibatch_buffer_size=10, num_sgd_iter=10, sample_batch_size=10, @@ -151,9 +150,9 @@ def testMultiplePasses(self): def testReplay(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - local, - remotes, + workers, replay_buffer_num_slots=100, replay_proportion=10, sample_batch_size=10, @@ -168,9 +167,9 @@ def testReplay(self): def testReplayAndMultiplePasses(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - local, - remotes, + workers, minibatch_buffer_size=10, num_sgd_iter=10, replay_buffer_num_slots=100, @@ -189,45 +188,43 @@ def testReplayAndMultiplePasses(self): def testMultiTierAggregationBadConf(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) aggregators = TreeAggregator.precreate_aggregators(4) - optimizer = AsyncSamplesOptimizer( - local, remotes, num_aggregation_workers=4) + optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=4) self.assertRaises(ValueError, lambda: optimizer.aggregator.init(aggregators)) def testMultiTierAggregation(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) aggregators = TreeAggregator.precreate_aggregators(1) - optimizer = AsyncSamplesOptimizer( - local, remotes, num_aggregation_workers=1) + optimizer = AsyncSamplesOptimizer(workers, num_aggregation_workers=1) optimizer.aggregator.init(aggregators) self._wait_for(optimizer, 1000, 1000) def testRejectBadConfigs(self): local, remotes = self._make_evs() + workers = WorkerSet._from_existing(local, remotes) self.assertRaises( ValueError, lambda: AsyncSamplesOptimizer( local, remotes, num_data_loader_buffers=2, minibatch_buffer_size=4)) optimizer = AsyncSamplesOptimizer( - local, - remotes, + workers, num_gpus=2, train_batch_size=100, sample_batch_size=50, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( - local, - remotes, + workers, num_gpus=2, train_batch_size=100, sample_batch_size=25, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( - local, - remotes, + workers, num_gpus=2, train_batch_size=100, sample_batch_size=74, @@ -238,12 +235,12 @@ def _make_evs(self): def make_sess(): return tf.Session(config=tf.ConfigProto(device_count={"CPU": 2})) - local = PolicyEvaluator( + local = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=PPOTFPolicy, tf_session_creator=make_sess) remotes = [ - PolicyEvaluator.as_remote().remote( + RolloutWorker.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), policy=PPOTFPolicy, tf_session_creator=make_sess) diff --git a/python/ray/rllib/tests/test_perf.py b/python/ray/rllib/tests/test_perf.py index e31530f44ced..6ed02a0ff2c7 100644 --- a/python/ray/rllib/tests/test_perf.py +++ b/python/ray/rllib/tests/test_perf.py @@ -7,8 +7,8 @@ import unittest import ray -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator -from ray.rllib.tests.test_policy_evaluator import MockPolicy +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.tests.test_rollout_worker import MockPolicy class TestPerf(unittest.TestCase): @@ -17,7 +17,7 @@ class TestPerf(unittest.TestCase): # 03/01/19: Samples per second 8610.164353268685 def testBaselinePerformance(self): for _ in range(20): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda 
_: gym.make("CartPole-v0"), policy=MockPolicy, batch_steps=100) diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_rollout_worker.py similarity index 94% rename from python/ray/rllib/tests/test_policy_evaluator.py rename to python/ray/rllib/tests/test_rollout_worker.py index dc0dcaff6782..45b2fa01551f 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_rollout_worker.py @@ -12,7 +12,7 @@ import ray from ray.rllib.agents.pg import PGTrainer from ray.rllib.agents.a3c import A2CTrainer -from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator +from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.metrics import collect_metrics from ray.rllib.policy.policy import Policy from ray.rllib.evaluation.postprocessing import compute_advantages @@ -129,9 +129,9 @@ def get_unwrapped(self): return self.envs -class TestPolicyEvaluator(unittest.TestCase): +class TestRolloutWorker(unittest.TestCase): def testBasic(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch = ev.sample() for key in [ @@ -155,7 +155,7 @@ def to_prev(vec): self.assertGreater(batch["advantages"][0], 1) def testBatchIds(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy) batch1 = ev.sample() batch2 = ev.sample() @@ -213,11 +213,10 @@ def testQueryEvaluators(self): "sample_batch_size": 5, "num_envs_per_worker": 2, }) - results = pg.optimizer.foreach_evaluator( - lambda ev: ev.sample_batch_size) - results2 = pg.optimizer.foreach_evaluator_with_index( + results = pg.workers.foreach_worker(lambda ev: ev.sample_batch_size) + results2 = pg.workers.foreach_worker_with_index( lambda ev, i: (i, ev.sample_batch_size)) - results3 = pg.optimizer.foreach_evaluator( + results3 = pg.workers.foreach_worker( lambda ev: ev.foreach_env(lambda env: 1)) self.assertEqual(results, [10, 10, 10]) self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)]) @@ -225,7 +224,7 @@ def testQueryEvaluators(self): def testRewardClipping(self): # clipping on - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv2(episode_length=10), policy=MockPolicy, clip_rewards=True, @@ -235,7 +234,7 @@ def testRewardClipping(self): self.assertEqual(result["episode_reward_mean"], 1000) # clipping off - ev2 = PolicyEvaluator( + ev2 = RolloutWorker( env_creator=lambda _: MockEnv2(episode_length=10), policy=MockPolicy, clip_rewards=False, @@ -245,7 +244,7 @@ def testRewardClipping(self): self.assertEqual(result2["episode_reward_mean"], 1000) def testHardHorizon(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=10), policy=MockPolicy, batch_mode="complete_episodes", @@ -259,7 +258,7 @@ def testHardHorizon(self): self.assertEqual(sum(samples["dones"]), 3) def testSoftHorizon(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=10), policy=MockPolicy, batch_mode="complete_episodes", @@ -273,11 +272,11 @@ def testSoftHorizon(self): self.assertEqual(sum(samples["dones"]), 1) def testMetrics(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=10), policy=MockPolicy, batch_mode="complete_episodes") - remote_ev = PolicyEvaluator.as_remote().remote( + remote_ev = RolloutWorker.as_remote().remote( env_creator=lambda _: MockEnv(episode_length=10), policy=MockPolicy, 
batch_mode="complete_episodes") @@ -288,7 +287,7 @@ def testMetrics(self): self.assertEqual(result["episode_reward_mean"], 10) def testAsync(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), sample_async=True, policy=MockPolicy) @@ -298,7 +297,7 @@ def testAsync(self): self.assertGreater(batch["advantages"][0], 1) def testAutoVectorization(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), policy=MockPolicy, batch_mode="truncate_episodes", @@ -321,7 +320,7 @@ def testAutoVectorization(self): self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7]) def testBatchesLargerWhenVectorized(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=8), policy=MockPolicy, batch_mode="truncate_episodes", @@ -336,7 +335,7 @@ def testBatchesLargerWhenVectorized(self): self.assertEqual(result["episodes_this_iter"], 4) def testVectorEnvSupport(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8), policy=MockPolicy, batch_mode="truncate_episodes", @@ -353,7 +352,7 @@ def testVectorEnvSupport(self): self.assertEqual(result["episodes_this_iter"], 8) def testTruncateEpisodes(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(10), policy=MockPolicy, batch_steps=15, @@ -362,7 +361,7 @@ def testTruncateEpisodes(self): self.assertEqual(batch.count, 15) def testCompleteEpisodes(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(10), policy=MockPolicy, batch_steps=5, @@ -371,7 +370,7 @@ def testCompleteEpisodes(self): self.assertEqual(batch.count, 10) def testCompleteEpisodesPacking(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: MockEnv(10), policy=MockPolicy, batch_steps=15, @@ -383,7 +382,7 @@ def testCompleteEpisodesPacking(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) def testFilterSync(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy, sample_async=True, @@ -396,7 +395,7 @@ def testFilterSync(self): self.assertNotEqual(obs_f.buffer.n, 0) def testGetFilters(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy, sample_async=True, @@ -411,7 +410,7 @@ def testGetFilters(self): self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n) def testSyncFilter(self): - ev = PolicyEvaluator( + ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy, sample_async=True, diff --git a/python/ray/rllib/utils/actors.py b/python/ray/rllib/utils/actors.py index b0e712f69d2c..8907aa5c9966 100644 --- a/python/ray/rllib/utils/actors.py +++ b/python/ray/rllib/utils/actors.py @@ -58,15 +58,15 @@ def completed_prefetch(self, blocking_wait=False, max_yield=999): remaining.append((worker, obj_id)) self._fetching = remaining - def reset_evaluators(self, evaluators): - """Notify that some evaluators may be removed.""" + def reset_workers(self, workers): + """Notify that some workers may be removed.""" for obj_id, ev in self._tasks.copy().items(): - if ev not in evaluators: + if ev not in workers: del self._tasks[obj_id] del self._objects[obj_id] ok = [] for ev, obj_id in self._fetching: - if ev in evaluators: + if ev in workers: ok.append((ev, obj_id)) self._fetching = ok From 084b22181edab2e21a9c56f9b4aea31513c5b723 Mon Sep 17 00:00:00 2001 From: Richard 
Liaw Date: Sun, 2 Jun 2019 17:45:57 -0700 Subject: [PATCH 064/118] Fix local cluster yaml (#4918) --- python/ray/autoscaler/local/example-full.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/autoscaler/local/example-full.yaml b/python/ray/autoscaler/local/example-full.yaml index 1d152807c3b8..3324b5004b58 100644 --- a/python/ray/autoscaler/local/example-full.yaml +++ b/python/ray/autoscaler/local/example-full.yaml @@ -21,6 +21,7 @@ file_mounts: {} setup_commands: [] head_setup_commands: [] worker_setup_commands: [] +initialization_commands: [] setup_commands: - source activate ray && pip install -U ray head_start_ray_commands: From 89722ff00349203424cde8d468c9f1815799f807 Mon Sep 17 00:00:00 2001 From: Hersh Godse Date: Sun, 2 Jun 2019 22:13:40 -0700 Subject: [PATCH 065/118] [tune] Directional metrics for components (#4120) (#4915) --- ci/long_running_tests/workloads/pbt.py | 3 +- doc/source/tune-schedulers.rst | 8 +- .../tune/examples/async_hyperband_example.py | 3 +- python/ray/tune/examples/ax_example.py | 2 +- python/ray/tune/examples/bayesopt_example.py | 8 +- python/ray/tune/examples/genetic_example.py | 2 +- python/ray/tune/examples/hyperband_example.py | 3 +- python/ray/tune/examples/hyperopt_example.py | 8 +- python/ray/tune/examples/mnist_pytorch.py | 3 +- .../tune/examples/mnist_pytorch_trainable.py | 2 +- python/ray/tune/examples/nevergrad_example.py | 8 +- python/ray/tune/examples/pbt_example.py | 3 +- python/ray/tune/examples/pbt_ppo_example.py | 3 +- .../examples/pbt_tune_cifar10_with_keras.py | 3 +- python/ray/tune/examples/sigopt_example.py | 8 +- python/ray/tune/examples/skopt_example.py | 13 +-- .../ray/tune/examples/tune_cifar10_gluon.py | 3 +- .../examples/tune_mnist_async_hyperband.py | 3 +- python/ray/tune/examples/tune_mnist_keras.py | 3 +- .../tune/examples/tune_mnist_ray_hyperband.py | 5 +- python/ray/tune/schedulers/async_hyperband.py | 29 +++++-- python/ray/tune/schedulers/hyperband.py | 40 ++++++--- .../tune/schedulers/median_stopping_rule.py | 35 +++++--- python/ray/tune/schedulers/pbt.py | 31 +++++-- python/ray/tune/suggest/bayesopt.py | 32 +++++-- python/ray/tune/suggest/hyperopt.py | 32 +++++-- python/ray/tune/suggest/nevergrad.py | 36 ++++++-- python/ray/tune/suggest/sigopt.py | 32 +++++-- python/ray/tune/suggest/skopt.py | 34 ++++++-- .../tune/tests/test_experiment_analysis.py | 3 +- python/ray/tune/tests/test_trial_scheduler.py | 87 +++++++++++++------ 31 files changed, 354 insertions(+), 131 deletions(-) diff --git a/ci/long_running_tests/workloads/pbt.py b/ci/long_running_tests/workloads/pbt.py index 86473d86ec4d..5e63596c4a79 100644 --- a/ci/long_running_tests/workloads/pbt.py +++ b/ci/long_running_tests/workloads/pbt.py @@ -37,7 +37,8 @@ pbt = PopulationBasedTraining( time_attr="training_iteration", - reward_attr="episode_reward_mean", + metric="episode_reward_mean", + mode="max", perturbation_interval=10, hyperparam_mutations={ "lr": [0.1, 0.01, 0.001, 0.0001], diff --git a/doc/source/tune-schedulers.rst b/doc/source/tune-schedulers.rst index cbb105ff5806..2f6957f3f475 100644 --- a/doc/source/tune-schedulers.rst +++ b/doc/source/tune-schedulers.rst @@ -7,7 +7,7 @@ By default, Tune schedules trials in serial order with the ``FIFOScheduler`` cla tune.run( ... , scheduler=AsyncHyperBandScheduler()) -Tune includes distributed implementations of early stopping algorithms such as `Median Stopping Rule `__, `HyperBand `__, and an `asynchronous version of HyperBand `__. 
These algorithms are very resource efficient and can outperform Bayesian Optimization methods in `many cases `__. Currently, all schedulers take in a ``reward_attr``, which is assumed to be maximized. +Tune includes distributed implementations of early stopping algorithms such as `Median Stopping Rule `__, `HyperBand `__, and an `asynchronous version of HyperBand `__. These algorithms are very resource efficient and can outperform Bayesian Optimization methods in `many cases `__. All schedulers take in a ``metric``, which is a value returned in the result dict of your Trainable and is maximized or minimized according to ``mode``. Current Available Trial Schedulers: @@ -25,7 +25,8 @@ Tune includes a distributed implementation of `Population Based Training (PBT) < pbt_scheduler = PopulationBasedTraining( time_attr='time_total_s', - reward_attr='mean_accuracy', + metric='mean_accuracy', + mode='max', perturbation_interval=600.0, hyperparam_mutations={ "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5], @@ -52,7 +53,8 @@ The `asynchronous version of HyperBand 0, "grace_period must be positive!" assert reduction_factor > 1, "Reduction Factor not valid!" assert brackets > 0, "brackets must be positive!" + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + FIFOScheduler.__init__(self) self._reduction_factor = reduction_factor self._max_t = max_t @@ -63,7 +76,11 @@ def __init__(self, ] self._counter = 0 # for self._num_stopped = 0 - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. self._time_attr = time_attr def on_trial_add(self, trial_runner, trial): @@ -80,7 +97,7 @@ def on_trial_result(self, trial_runner, trial, result): else: bracket = self._trial_info[trial.trial_id] action = bracket.on_result(trial, result[self._time_attr], - result[self._reward_attr]) + self._metric_op * result[self._metric]) if action == TrialScheduler.STOP: self._num_stopped += 1 return action @@ -88,7 +105,7 @@ def on_trial_result(self, trial_runner, trial, result): def on_trial_complete(self, trial_runner, trial, result): bracket = self._trial_info[trial.trial_id] bracket.on_result(trial, result[self._time_attr], - result[self._reward_attr]) + self._metric_op * result[self._metric]) del self._trial_info[trial.trial_id] def on_trial_remove(self, trial_runner, trial): diff --git a/python/ray/tune/schedulers/hyperband.py b/python/ray/tune/schedulers/hyperband.py index c9bdde8ab4aa..b1ca1deec11b 100644 --- a/python/ray/tune/schedulers/hyperband.py +++ b/python/ray/tune/schedulers/hyperband.py @@ -43,8 +43,9 @@ class HyperBandScheduler(FIFOScheduler): To use this implementation of HyperBand with Tune, all you need to do is specify the max length of time a trial can run `max_t`, the time - units `time_attr`, and the name of the reported objective value - `reward_attr`. We automatically determine reasonable values for the other + units `time_attr`, the name of the reported objective value `metric`, + and if `metric` is to be maximized or minimized (`mode`). + We automatically determine reasonable values for the other HyperBand parameters based on the given values. 
For example, to limit trials to 10 minutes and early stop based on the @@ -62,9 +63,10 @@ class HyperBandScheduler(FIFOScheduler): Note that you can pass in something non-temporal such as `training_iteration` as a measure of progress, the only requirement is that the attribute should increase monotonically. - reward_attr (str): The training result objective value attribute. As - with `time_attr`, this may refer to any objective value. Stopping + metric (str): The training result objective value attribute. Stopping procedures will use this attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. max_t (int): max time units per trial. Trials will be stopped after max_t time units (determined by time_attr) have passed. The scheduler will terminate trials after this time has passed. @@ -74,16 +76,28 @@ class HyperBandScheduler(FIFOScheduler): def __init__(self, time_attr="training_iteration", - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", max_t=81): assert max_t > 0, "Max (time_attr) not valid!" + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + FIFOScheduler.__init__(self) self._eta = 3 self._s_max_1 = 5 self._max_t_attr = max_t # bracket max trials self._get_n0 = lambda s: int( - np.ceil(self._s_max_1/(s+1) * self._eta**s)) + np.ceil(self._s_max_1 / (s + 1) * self._eta**s)) # bracket initial iterations self._get_r0 = lambda s: int((max_t * self._eta**(-s))) self._hyperbands = [[]] # list of hyperband iterations @@ -92,7 +106,11 @@ def __init__(self, # Tracks state for new trial add self._state = {"bracket": None, "band_idx": 0} self._num_stopped = 0 - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. 
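As a usage sketch of the interface introduced here (not taken from the patch itself; `my_trainable` and the `mean_loss` key are placeholder names for a registered trainable and the value it reports):

    from ray import tune
    from ray.tune.schedulers import HyperBandScheduler

    # "mean_loss" is reported by the trainable and should be minimized, so
    # mode="min"; before this change the value had to be negated and exposed
    # through a reward attribute that was always maximized.
    hyperband = HyperBandScheduler(
        time_attr="training_iteration",
        metric="mean_loss",
        mode="min",
        max_t=100)
    tune.run(my_trainable, num_samples=20, scheduler=hyperband)

Passing the old `reward_attr` keyword still works, but only through the deprecation shim added above, which forces `mode="max"`.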
self._time_attr = time_attr def on_trial_add(self, trial_runner, trial): @@ -173,7 +191,8 @@ def _process_bracket(self, trial_runner, bracket, trial): bracket.cleanup_full(trial_runner) return TrialScheduler.STOP - good, bad = bracket.successive_halving(self._reward_attr) + good, bad = bracket.successive_halving(self._metric, + self._metric_op) # kill bad trials self._num_stopped += len(bad) for t in bad: @@ -322,7 +341,7 @@ def filled(self): return len(self._live_trials) == self._n - def successive_halving(self, reward_attr): + def successive_halving(self, metric, metric_op): assert self._halves > 0 self._halves -= 1 self._n /= self._eta @@ -332,7 +351,8 @@ def successive_halving(self, reward_attr): self._r = int(min(self._r, self._max_t_attr - self._cumul_r)) self._cumul_r += self._r sorted_trials = sorted( - self._live_trials, key=lambda t: self._live_trials[t][reward_attr]) + self._live_trials, + key=lambda t: metric_op * self._live_trials[t][metric]) good, bad = sorted_trials[-self._n:], sorted_trials[:-self._n] return good, bad diff --git a/python/ray/tune/schedulers/median_stopping_rule.py b/python/ray/tune/schedulers/median_stopping_rule.py index e554a69f099d..36276273c941 100644 --- a/python/ray/tune/schedulers/median_stopping_rule.py +++ b/python/ray/tune/schedulers/median_stopping_rule.py @@ -22,9 +22,10 @@ class MedianStoppingRule(FIFOScheduler): Note that you can pass in something non-temporal such as `training_iteration` as a measure of progress, the only requirement is that the attribute should increase monotonically. - reward_attr (str): The training result objective value attribute. As - with `time_attr`, this may refer to any objective value that - is supposed to increase with time. + metric (str): The training result objective value attribute. Stopping + procedures will use this attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. grace_period (float): Only stop trials at least this old in time. The units are the same as the attribute named by `time_attr`. min_samples_required (int): Min samples to compute median over. @@ -37,18 +38,34 @@ class MedianStoppingRule(FIFOScheduler): def __init__(self, time_attr="time_total_s", - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", grace_period=60.0, min_samples_required=3, hard_stop=True, verbose=True): + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + FIFOScheduler.__init__(self) self._stopped_trials = set() self._completed_trials = set() self._results = collections.defaultdict(list) self._grace_period = grace_period self._min_samples_required = min_samples_required - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. 
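The same folding trick is used by every scheduler touched in this patch: the reported value is multiplied by `_metric_op` once, and all later comparisons keep assuming that larger is better. A small worked illustration (numbers invented for the example):

    # mode="min", metric="mean_loss"  =>  _metric_op = -1.0
    result_a = {"mean_loss": 0.30}
    result_b = {"mean_loss": 0.10}
    score_a = -1.0 * result_a["mean_loss"]   # -0.30
    score_b = -1.0 * result_b["mean_loss"]   # -0.10
    # score_b > score_a: the lower-loss trial ranks higher in every comparison,
    # e.g. the sort inside Bracket.successive_halving() above.

So minimization never needs a separate code path inside the schedulers.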
self._time_attr = time_attr self._hard_stop = hard_stop self._verbose = verbose @@ -110,11 +127,9 @@ def _running_result(self, trial, t_max=float("inf")): results = self._results[trial] # TODO(ekl) we could do interpolation to be more precise, but for now # assume len(results) is large and the time diffs are roughly equal - return np.mean([ - r[self._reward_attr] for r in results - if r[self._time_attr] <= t_max - ]) + return self._metric_op * np.mean( + [r[self._metric] for r in results if r[self._time_attr] <= t_max]) def _best_result(self, trial): results = self._results[trial] - return max(r[self._reward_attr] for r in results) + return max(self._metric_op * r[self._metric] for r in results) diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index e5d793fd113a..b6d1b4e80838 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -120,9 +120,10 @@ class PopulationBasedTraining(FIFOScheduler): Note that you can pass in something non-temporal such as `training_iteration` as a measure of progress, the only requirement is that the attribute should increase monotonically. - reward_attr (str): The training result objective value attribute. As - with `time_attr`, this may refer to any objective value. Stopping + metric (str): The training result objective value attribute. Stopping procedures will use this attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. perturbation_interval (float): Models will be considered for perturbation at this interval of `time_attr`. Note that perturbation incurs checkpoint overhead, so you shouldn't set this @@ -149,7 +150,8 @@ class PopulationBasedTraining(FIFOScheduler): Example: >>> pbt = PopulationBasedTraining( >>> time_attr="training_iteration", - >>> reward_attr="episode_reward_mean", + >>> metric="episode_reward_mean", + >>> mode="max", >>> perturbation_interval=10, # every 10 `time_attr` units >>> # (training_iterations in this case) >>> hyperparam_mutations={ @@ -165,7 +167,9 @@ class PopulationBasedTraining(FIFOScheduler): def __init__(self, time_attr="time_total_s", - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", perturbation_interval=60.0, hyperparam_mutations={}, resample_probability=0.25, @@ -175,8 +179,23 @@ def __init__(self, raise TuneError( "You must specify at least one of `hyperparam_mutations` or " "`custom_explore_fn` to use PBT.") + + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + FIFOScheduler.__init__(self) - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. 
self._time_attr = time_attr self._perturbation_interval = perturbation_interval self._hyperparam_mutations = hyperparam_mutations @@ -199,7 +218,7 @@ def on_trial_result(self, trial_runner, trial, result): if time - state.last_perturbation_time < self._perturbation_interval: return TrialScheduler.CONTINUE # avoid checkpoint overhead - score = result[self._reward_attr] + score = self._metric_op * result[self._metric] state.last_score = score state.last_perturbation_time = time lower_quantile, upper_quantile = self._quantiles() diff --git a/python/ray/tune/suggest/bayesopt.py b/python/ray/tune/suggest/bayesopt.py index 029ee4ddd11d..7b253041122f 100644 --- a/python/ray/tune/suggest/bayesopt.py +++ b/python/ray/tune/suggest/bayesopt.py @@ -3,6 +3,7 @@ from __future__ import print_function import copy +import logging try: # Python 3 only -- needed for lint test. import bayes_opt as byo except ImportError: @@ -10,6 +11,8 @@ from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class BayesOptSearch(SuggestionAlgorithm): """A wrapper around BayesOpt to provide trial suggestions. @@ -22,8 +25,9 @@ class BayesOptSearch(SuggestionAlgorithm): this space which will be used to run trials. max_concurrent (int): Number of maximum concurrent trials. Defaults to 10. - reward_attr (str): The training result objective value attribute. - This refers to an increasing value. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. utility_kwargs (dict): Parameters to define the utility function. Must provide values for the keys `kind`, `kappa`, and `xi`. random_state (int): Used to initialize BayesOpt. @@ -35,13 +39,15 @@ class BayesOptSearch(SuggestionAlgorithm): >>> 'height': (-100, 100), >>> } >>> algo = BayesOptSearch( - >>> space, max_concurrent=4, reward_attr="neg_mean_loss") + >>> space, max_concurrent=4, metric="mean_loss", mode="min") """ def __init__(self, space, max_concurrent=10, - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", utility_kwargs=None, random_state=1, verbose=0, @@ -52,8 +58,22 @@ def __init__(self, assert type(max_concurrent) is int and max_concurrent > 0 assert utility_kwargs is not None, ( "Must define arguments for the utiliy function!") + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + self._max_concurrent = max_concurrent - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. 
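A construction sketch for the updated BayesOpt wrapper (not from the patch; the `utility_kwargs` values shown are common bayes_opt settings, not values required by Tune):

    from ray.tune.suggest.bayesopt import BayesOptSearch

    space = {
        "width": (0, 20),
        "height": (-100, 100),
    }
    algo = BayesOptSearch(
        space,
        max_concurrent=4,
        metric="mean_loss",
        mode="min",
        utility_kwargs={"kind": "ucb", "kappa": 2.5, "xi": 0.0})

With `mode="min"`, the value registered with the underlying optimizer (in `on_trial_complete` just below) is `-1. * result["mean_loss"]`, so BayesOpt can keep maximizing its target.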
self._live_trial_mapping = {} self.optimizer = byo.BayesianOptimization( @@ -85,7 +105,7 @@ def on_trial_complete(self, if result: self.optimizer.register( params=self._live_trial_mapping[trial_id], - target=result[self._reward_attr]) + target=self._metric_op * result[self._metric]) del self._live_trial_mapping[trial_id] diff --git a/python/ray/tune/suggest/hyperopt.py b/python/ray/tune/suggest/hyperopt.py index 533e320c051b..efae5d0310ea 100644 --- a/python/ray/tune/suggest/hyperopt.py +++ b/python/ray/tune/suggest/hyperopt.py @@ -15,6 +15,8 @@ from ray.tune.error import TuneError from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class HyperOptSearch(SuggestionAlgorithm): """A wrapper around HyperOpt to provide trial suggestions. @@ -30,8 +32,9 @@ class HyperOptSearch(SuggestionAlgorithm): parameters generated in the variant generation process. max_concurrent (int): Number of maximum concurrent trials. Defaults to 10. - reward_attr (str): The training result objective value attribute. - This refers to an increasing value. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. points_to_evaluate (list): Initial parameter suggestions to be run first. This is for when you already have some good parameters you want hyperopt to run first to help the TPE algorithm @@ -52,21 +55,38 @@ class HyperOptSearch(SuggestionAlgorithm): >>> 'activation': 0, # The index of "relu" >>> }] >>> algo = HyperOptSearch( - >>> space, max_concurrent=4, reward_attr="neg_mean_loss", + >>> space, max_concurrent=4, metric="mean_loss", mode="min", >>> points_to_evaluate=current_best_params) """ def __init__(self, space, max_concurrent=10, - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", points_to_evaluate=None, **kwargs): assert hpo is not None, "HyperOpt must be installed!" from hyperopt.fmin import generate_trials_to_calculate assert type(max_concurrent) is int and max_concurrent > 0 + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + self._max_concurrent = max_concurrent - self._reward_attr = reward_attr + self._metric = metric + # hyperopt internally minimizes, so "max" => -1 + if mode == "max": + self._metric_op = -1. + elif mode == "min": + self._metric_op = 1. 
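Note the inverted sign relative to the schedulers: HyperOpt minimizes internally, so `mode="max"` maps to a multiplier of -1 and the value handed to TPE is treated as a loss (the same convention applies to the nevergrad and skopt wrappers below). A small illustration with invented numbers:

    # mode="max", metric="episode_reward_mean"  =>  _metric_op = -1.0
    result = {"episode_reward_mean": 120.0}
    loss = -1.0 * result["episode_reward_mean"]   # -120.0 is what TPE sees
    # mode="min", metric="mean_loss"             =>  _metric_op = +1.0
    # and the reported loss is passed through unchanged.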
self.algo = hpo.tpe.suggest self.domain = hpo.Domain(lambda spc: spc, space) if points_to_evaluate is None: @@ -151,7 +171,7 @@ def on_trial_complete(self, del self._live_trial_mapping[trial_id] def _to_hyperopt_result(self, result): - return {"loss": -result[self._reward_attr], "status": "ok"} + return {"loss": self._metric_op * result[self._metric], "status": "ok"} def _get_hyperopt_trial(self, trial_id): if trial_id not in self._live_trial_mapping: diff --git a/python/ray/tune/suggest/nevergrad.py b/python/ray/tune/suggest/nevergrad.py index 284311a3646f..2ad8eed37d6e 100644 --- a/python/ray/tune/suggest/nevergrad.py +++ b/python/ray/tune/suggest/nevergrad.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import logging try: import nevergrad as ng except ImportError: @@ -9,6 +10,8 @@ from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class NevergradSearch(SuggestionAlgorithm): """A wrapper around Nevergrad to provide trial suggestions. @@ -28,15 +31,16 @@ class NevergradSearch(SuggestionAlgorithm): (see nevergrad v0.2.0+). max_concurrent (int): Number of maximum concurrent trials. Defaults to 10. - reward_attr (str): The training result objective value attribute. - This refers to an increasing value. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. Example: >>> from nevergrad.optimization import optimizerlib >>> instrumentation = 1 >>> optimizer = optimizerlib.OnePlusOne(instrumentation, budget=100) >>> algo = NevergradSearch(optimizer, ["lr"], max_concurrent=4, - >>> reward_attr="neg_mean_loss") + >>> metric="mean_loss", mode="min") Note: In nevergrad v0.2.0+, optimizers can be instrumented. @@ -49,7 +53,7 @@ class NevergradSearch(SuggestionAlgorithm): >>> instrumentation = inst.Instrumentation(lr=lr) >>> optimizer = optimizerlib.OnePlusOne(instrumentation, budget=100) >>> algo = NevergradSearch(optimizer, None, max_concurrent=4, - >>> reward_attr="neg_mean_loss") + >>> metric="mean_loss", mode="min") """ @@ -57,13 +61,30 @@ def __init__(self, optimizer, parameter_names, max_concurrent=10, - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", **kwargs): assert ng is not None, "Nevergrad must be installed!" assert type(max_concurrent) is int and max_concurrent > 0 + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + self._max_concurrent = max_concurrent self._parameters = parameter_names - self._reward_attr = reward_attr + self._metric = metric + # nevergrad.tell internally minimizes, so "max" => -1 + if mode == "max": + self._metric_op = -1. + elif mode == "min": + self._metric_op = 1. 
self._nevergrad_opt = optimizer self._live_trial_mapping = {} super(NevergradSearch, self).__init__(**kwargs) @@ -119,7 +140,8 @@ def on_trial_complete(self, """ ng_trial_info = self._live_trial_mapping.pop(trial_id) if result: - self._nevergrad_opt.tell(ng_trial_info, -result[self._reward_attr]) + self._nevergrad_opt.tell(ng_trial_info, + self._metric_op * result[self._metric]) def _num_live_trials(self): return len(self._live_trial_mapping) diff --git a/python/ray/tune/suggest/sigopt.py b/python/ray/tune/suggest/sigopt.py index 72d2d0afca75..9aaf593f13c6 100644 --- a/python/ray/tune/suggest/sigopt.py +++ b/python/ray/tune/suggest/sigopt.py @@ -4,6 +4,7 @@ import copy import os +import logging try: import sigopt as sgo except ImportError: @@ -11,6 +12,8 @@ from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + class SigOptSearch(SuggestionAlgorithm): """A wrapper around SigOpt to provide trial suggestions. @@ -25,8 +28,9 @@ class SigOptSearch(SuggestionAlgorithm): name (str): Name of experiment. Required by SigOpt. max_concurrent (int): Number of maximum concurrent trials supported based on the user's SigOpt plan. Defaults to 1. - reward_attr (str): The training result objective value attribute. - This refers to an increasing value. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. Example: >>> space = [ @@ -49,21 +53,37 @@ class SigOptSearch(SuggestionAlgorithm): >>> ] >>> algo = SigOptSearch( >>> space, name="SigOpt Example Experiment", - >>> max_concurrent=1, reward_attr="neg_mean_loss") + >>> max_concurrent=1, metric="mean_loss", mode="min") """ def __init__(self, space, name="Default Tune Experiment", max_concurrent=1, - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", **kwargs): assert sgo is not None, "SigOpt must be installed!" assert type(max_concurrent) is int and max_concurrent > 0 assert "SIGOPT_KEY" in os.environ, \ "SigOpt API key must be stored as environ variable at SIGOPT_KEY" + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) + self._max_concurrent = max_concurrent - self._reward_attr = reward_attr + self._metric = metric + if mode == "max": + self._metric_op = 1. + elif mode == "min": + self._metric_op = -1. 
self._live_trial_mapping = {} # Create a connection with SigOpt API, requires API key @@ -108,7 +128,7 @@ def on_trial_complete(self, if result: self.conn.experiments(self.experiment.id).observations().create( suggestion=self._live_trial_mapping[trial_id].id, - value=result[self._reward_attr], + value=self._metric_op * result[self._metric], ) # Update the experiment object self.experiment = self.conn.experiments(self.experiment.id).fetch() diff --git a/python/ray/tune/suggest/skopt.py b/python/ray/tune/suggest/skopt.py index 26457c5fa9ee..f60a0856ed14 100644 --- a/python/ray/tune/suggest/skopt.py +++ b/python/ray/tune/suggest/skopt.py @@ -2,6 +2,7 @@ from __future__ import division from __future__ import print_function +import logging try: import skopt as sko except ImportError: @@ -9,6 +10,8 @@ from ray.tune.suggest.suggestion import SuggestionAlgorithm +logger = logging.getLogger(__name__) + def _validate_warmstart(parameter_names, points_to_evaluate, evaluated_rewards): @@ -52,8 +55,9 @@ class SkOptSearch(SuggestionAlgorithm): the dimension of the optimizer output. max_concurrent (int): Number of maximum concurrent trials. Defaults to 10. - reward_attr (str): The training result objective value attribute. - This refers to an increasing value. + metric (str): The training result objective value attribute. + mode (str): One of {min, max}. Determines whether objective is + minimizing or maximizing the metric attribute. points_to_evaluate (list of lists): A list of points you'd like to run first before sampling from the optimiser, e.g. these could be parameter configurations you already know work well to help @@ -73,7 +77,8 @@ class SkOptSearch(SuggestionAlgorithm): >>> algo = SkOptSearch(optimizer, >>> ["width", "height"], >>> max_concurrent=4, - >>> reward_attr="neg_mean_loss", + >>> metric="mean_loss", + >>> mode="min", >>> points_to_evaluate=current_best_params) """ @@ -81,7 +86,9 @@ def __init__(self, optimizer, parameter_names, max_concurrent=10, - reward_attr="episode_reward_mean", + reward_attr=None, + metric="episode_reward_mean", + mode="max", points_to_evaluate=None, evaluated_rewards=None, **kwargs): @@ -91,6 +98,15 @@ def __init__(self, assert type(max_concurrent) is int and max_concurrent > 0 _validate_warmstart(parameter_names, points_to_evaluate, evaluated_rewards) + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" + + if reward_attr is not None: + mode = "max" + metric = reward_attr + logger.warning( + "`reward_attr` is deprecated and will be removed in a future " + "version of Tune. " + "Setting `metric={}` and `mode=max`.".format(reward_attr)) self._initial_points = [] if points_to_evaluate and evaluated_rewards: @@ -99,7 +115,12 @@ def __init__(self, self._initial_points = points_to_evaluate self._max_concurrent = max_concurrent self._parameters = parameter_names - self._reward_attr = reward_attr + self._metric = metric + # Skopt internally minimizes, so "max" => -1 + if mode == "max": + self._metric_op = -1. + elif mode == "min": + self._metric_op = 1. 
self._skopt_opt = optimizer self._live_trial_mapping = {} super(SkOptSearch, self).__init__(**kwargs) @@ -131,7 +152,8 @@ def on_trial_complete(self, """ skopt_trial_info = self._live_trial_mapping.pop(trial_id) if result: - self._skopt_opt.tell(skopt_trial_info, -result[self._reward_attr]) + self._skopt_opt.tell(skopt_trial_info, + self._metric_op * result[self._metric]) def _num_live_trials(self): return len(self._live_trial_mapping) diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 96b08c0a6685..a0721abc5d29 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -36,7 +36,8 @@ def tearDown(self): def run_test_exp(self): ahb = AsyncHyperBandScheduler( time_attr="training_iteration", - reward_attr=self.metric, + metric=self.metric, + mode="max", grace_period=5, max_t=100) diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 1f8ab8025f1f..1c33375549f0 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -135,33 +135,43 @@ def testMedianStoppingSoftStop(self): rule.on_trial_result(None, t3, result(2, 260)), TrialScheduler.PAUSE) - def testAlternateMetrics(self): - def result2(t, rew): - return dict(training_iteration=t, neg_mean_loss=rew) - + def _test_metrics(self, result_func, metric, mode): rule = MedianStoppingRule( grace_period=0, min_samples_required=1, time_attr="training_iteration", - reward_attr="neg_mean_loss") + metric=metric, + mode=mode) t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 for i in range(10): self.assertEqual( - rule.on_trial_result(None, t1, result2(i, i * 100)), + rule.on_trial_result(None, t1, result_func(i, i * 100)), TrialScheduler.CONTINUE) for i in range(5): self.assertEqual( - rule.on_trial_result(None, t2, result2(i, 450)), + rule.on_trial_result(None, t2, result_func(i, 450)), TrialScheduler.CONTINUE) - rule.on_trial_complete(None, t1, result2(10, 1000)) + rule.on_trial_complete(None, t1, result_func(10, 1000)) self.assertEqual( - rule.on_trial_result(None, t2, result2(5, 450)), + rule.on_trial_result(None, t2, result_func(5, 450)), TrialScheduler.CONTINUE) self.assertEqual( - rule.on_trial_result(None, t2, result2(6, 0)), + rule.on_trial_result(None, t2, result_func(6, 0)), TrialScheduler.CONTINUE) + def testAlternateMetrics(self): + def result2(t, rew): + return dict(training_iteration=t, neg_mean_loss=rew) + + self._test_metrics(result2, "neg_mean_loss", "max") + + def testAlternateMetricsMin(self): + def result2(t, rew): + return dict(training_iteration=t, mean_loss=-rew) + + self._test_metrics(result2, "mean_loss", "min") + class _MockTrialExecutor(TrialExecutor): def start_trial(self, trial, checkpoint_obj=None): @@ -495,14 +505,9 @@ def testAddAfterHalving(self): TrialScheduler.PAUSE, sched.on_trial_result(mock_runner, t, result(new_units, 12))) - def testAlternateMetrics(self): - """Checking that alternate metrics will pass.""" - - def result2(t, rew): - return dict(time_total_s=t, neg_mean_loss=rew) - + def _test_metrics(self, result_func, metric, mode): sched = HyperBandScheduler( - time_attr="time_total_s", reward_attr="neg_mean_loss") + time_attr="time_total_s", metric=metric, mode=mode) stats = self.default_statistics() for i in range(stats["max_trials"]): @@ -518,13 +523,29 @@ def result2(t, rew): # Provides results from 0 to 8 in 
order, keeping the last one running for i, trl in enumerate(big_bracket.current_trials()): - action = sched.on_trial_result(runner, trl, result2(1, i)) + action = sched.on_trial_result(runner, trl, result_func(1, i)) runner.process_action(trl, action) new_length = len(big_bracket.current_trials()) self.assertEqual(action, TrialScheduler.CONTINUE) self.assertEqual(new_length, self.downscale(current_length, sched)) + def testAlternateMetrics(self): + """Checking that alternate metrics will pass.""" + + def result2(t, rew): + return dict(time_total_s=t, neg_mean_loss=rew) + + self._test_metrics(result2, "neg_mean_loss", "max") + + def testAlternateMetricsMin(self): + """Checking that alternate metrics will pass.""" + + def result2(t, rew): + return dict(time_total_s=t, mean_loss=-rew) + + self._test_metrics(result2, "mean_loss", "min") + def testJumpingTime(self): sched, mock_runner = self.schedulerSetup(81) big_bracket = sched._hyperbands[0][-1] @@ -1015,14 +1036,12 @@ def testAsyncHBUsesPercentile(self): scheduler.on_trial_result(None, t3, result(2, 260)), TrialScheduler.STOP) - def testAlternateMetrics(self): - def result2(t, rew): - return dict(training_iteration=t, neg_mean_loss=rew) - + def _test_metrics(self, result_func, metric, mode): scheduler = AsyncHyperBandScheduler( grace_period=1, time_attr="training_iteration", - reward_attr="neg_mean_loss", + metric=metric, + mode=mode, brackets=1) t1 = Trial("PPO") # mean is 450, max 900, t_max=10 t2 = Trial("PPO") # mean is 450, max 450, t_max=5 @@ -1030,20 +1049,32 @@ def result2(t, rew): scheduler.on_trial_add(None, t2) for i in range(10): self.assertEqual( - scheduler.on_trial_result(None, t1, result2(i, i * 100)), + scheduler.on_trial_result(None, t1, result_func(i, i * 100)), TrialScheduler.CONTINUE) for i in range(5): self.assertEqual( - scheduler.on_trial_result(None, t2, result2(i, 450)), + scheduler.on_trial_result(None, t2, result_func(i, 450)), TrialScheduler.CONTINUE) - scheduler.on_trial_complete(None, t1, result2(10, 1000)) + scheduler.on_trial_complete(None, t1, result_func(10, 1000)) self.assertEqual( - scheduler.on_trial_result(None, t2, result2(5, 450)), + scheduler.on_trial_result(None, t2, result_func(5, 450)), TrialScheduler.CONTINUE) self.assertEqual( - scheduler.on_trial_result(None, t2, result2(6, 0)), + scheduler.on_trial_result(None, t2, result_func(6, 0)), TrialScheduler.CONTINUE) + def testAlternateMetrics(self): + def result2(t, rew): + return dict(training_iteration=t, neg_mean_loss=rew) + + self._test_metrics(result2, "neg_mean_loss", "max") + + def testAlternateMetricsMin(self): + def result2(t, rew): + return dict(training_iteration=t, mean_loss=-rew) + + self._test_metrics(result2, "mean_loss", "min") + if __name__ == "__main__": unittest.main(verbosity=2) From b674c4a5ba6cc62d4535e2e016b248edc28b53e5 Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Mon, 3 Jun 2019 19:59:43 +0800 Subject: [PATCH 066/118] [Core Worker] implement ObjectInterface and add test framework (#4899) --- .travis.yml | 3 + BUILD.bazel | 7 +- src/ray/common/buffer.h | 22 +- src/ray/core_worker/common.h | 4 +- src/ray/core_worker/context.cc | 81 +++++++ src/ray/core_worker/context.h | 48 ++++ src/ray/core_worker/core_worker.cc | 39 ++++ src/ray/core_worker/core_worker.h | 34 ++- src/ray/core_worker/core_worker_test.cc | 220 ++++++++++++++++++- src/ray/core_worker/object_interface.cc | 111 +++++++++- src/ray/core_worker/object_interface.h | 7 +- src/ray/gcs/redis_module/ray_redis_module.cc | 4 
+- src/ray/status.h | 11 + src/ray/test/run_core_worker_tests.sh | 47 ++++ 14 files changed, 611 insertions(+), 27 deletions(-) create mode 100644 src/ray/core_worker/context.cc create mode 100644 src/ray/core_worker/context.h create mode 100644 src/ray/core_worker/core_worker.cc create mode 100644 src/ray/test/run_core_worker_tests.sh diff --git a/.travis.yml b/.travis.yml index 7402e6fb6e78..2266833f8414 100644 --- a/.travis.yml +++ b/.travis.yml @@ -148,6 +148,9 @@ install: - ./ci/suppress_output bazel build //:stats_test -c opt - ./bazel-bin/stats_test + # core worker test. + - ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh + # Raylet tests. - ./ci/suppress_output bash src/ray/test/run_object_manager_tests.sh - ./ci/suppress_output bazel test --build_tests_only --test_lang_filters=cc //:all diff --git a/BUILD.bazel b/BUILD.bazel index 0bdbe5741cf8..2b75d5d04e77 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -77,6 +77,7 @@ cc_library( "src/ray/raylet/mock_gcs_client.cc", "src/ray/raylet/monitor_main.cc", "src/ray/raylet/*_test.cc", + "src/ray/raylet/main.cc", ], ), hdrs = glob([ @@ -122,15 +123,18 @@ cc_library( deps = [ ":ray_common", ":ray_util", + ":raylet_lib", ], ) -cc_test( +# This test is run by src/ray/test/run_core_worker_tests.sh +cc_binary( name = "core_worker_test", srcs = ["src/ray/core_worker/core_worker_test.cc"], copts = COPTS, deps = [ ":core_worker_lib", + ":gcs", "@com_google_googletest//:gtest_main", ], ) @@ -320,6 +324,7 @@ cc_library( ":node_manager_fbs", ":ray_util", "@boost//:asio", + "@plasma//:plasma_client", ], ) diff --git a/src/ray/common/buffer.h b/src/ray/common/buffer.h index 358d903799c7..4340c74a8d38 100644 --- a/src/ray/common/buffer.h +++ b/src/ray/common/buffer.h @@ -3,6 +3,11 @@ #include #include +#include "plasma/client.h" + +namespace arrow { +class Buffer; +} namespace ray { @@ -15,7 +20,7 @@ class Buffer { /// Size of this buffer. virtual size_t Size() const = 0; - virtual ~Buffer() {} + virtual ~Buffer(){}; bool operator==(const Buffer &rhs) const { return this->Data() == rhs.Data() && this->Size() == rhs.Size(); @@ -40,6 +45,21 @@ class LocalMemoryBuffer : public Buffer { size_t size_; }; +/// Represents a byte buffer for plasma object. +class PlasmaBuffer : public Buffer { + public: + PlasmaBuffer(std::shared_ptr buffer) : buffer_(buffer) {} + + uint8_t *Data() const override { return const_cast(buffer_->data()); } + + size_t Size() const override { return buffer_->size(); } + + private: + /// shared_ptr to arrow buffer which can potentially hold a reference + /// for the object (when it's a plasma::PlasmaBuffer). + std::shared_ptr buffer_; +}; + } // namespace ray #endif // RAY_COMMON_BUFFER_H diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index b53c35b25fa8..ad9485e53826 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -45,13 +45,13 @@ class TaskArg { bool IsPassedByReference() const { return id_ != nullptr; } /// Get the reference object ID. - ObjectID &GetReference() { + const ObjectID &GetReference() const { RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference."; return *id_; } /// Get the value. 
- std::shared_ptr GetValue() { + std::shared_ptr GetValue() const { RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value."; return data_; } diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc new file mode 100644 index 000000000000..fedcfc6625d9 --- /dev/null +++ b/src/ray/core_worker/context.cc @@ -0,0 +1,81 @@ + +#include "context.h" + +namespace ray { + +/// per-thread context for core worker. +struct WorkerThreadContext { + WorkerThreadContext() + : current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {} + + int GetNextTaskIndex() { return ++task_index; } + + int GetNextPutIndex() { return ++put_index; } + + const TaskID &GetCurrentTaskID() const { return current_task_id; } + + void SetCurrentTask(const TaskID &task_id) { + current_task_id = task_id; + task_index = 0; + put_index = 0; + } + + void SetCurrentTask(const raylet::TaskSpecification &spec) { + SetCurrentTask(spec.TaskId()); + } + + private: + /// The task ID for current task. + TaskID current_task_id; + + /// Number of tasks that have been submitted from current task. + int task_index; + + /// Number of objects that have been put from current task. + int put_index; +}; + +thread_local std::unique_ptr WorkerContext::thread_context_ = + nullptr; + +WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id) + : worker_type(worker_type), + worker_id(worker_type == WorkerType::DRIVER + ? ClientID::FromBinary(driver_id.Binary()) + : ClientID::FromRandom()), + current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) { + // For worker main thread which initializes the WorkerContext, + // set task_id according to whether current worker is a driver. + // (For other threads it's set to randmom ID via GetThreadContext). + GetThreadContext().SetCurrentTask( + (worker_type == WorkerType::DRIVER) ? 
TaskID::FromRandom() : TaskID::Nil()); +} + +const WorkerType WorkerContext::GetWorkerType() const { return worker_type; } + +const ClientID &WorkerContext::GetWorkerID() const { return worker_id; } + +int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); } + +int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); } + +const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; } + +const TaskID &WorkerContext::GetCurrentTaskID() const { + return GetThreadContext().GetCurrentTaskID(); +} + +void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) { + current_driver_id = spec.DriverId(); + GetThreadContext().SetCurrentTask(spec); +} + +WorkerThreadContext &WorkerContext::GetThreadContext() { + if (thread_context_ == nullptr) { + thread_context_ = std::unique_ptr(new WorkerThreadContext()); + } + + return *thread_context_; +} + +} // namespace ray diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h new file mode 100644 index 000000000000..6e0cf3f9f2cf --- /dev/null +++ b/src/ray/core_worker/context.h @@ -0,0 +1,48 @@ +#ifndef RAY_CORE_WORKER_CONTEXT_H +#define RAY_CORE_WORKER_CONTEXT_H + +#include "common.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +struct WorkerThreadContext; + +class WorkerContext { + public: + WorkerContext(WorkerType worker_type, const DriverID &driver_id); + + const WorkerType GetWorkerType() const; + + const ClientID &GetWorkerID() const; + + const DriverID &GetCurrentDriverID() const; + + const TaskID &GetCurrentTaskID() const; + + void SetCurrentTask(const raylet::TaskSpecification &spec); + + int GetNextTaskIndex(); + + int GetNextPutIndex(); + + private: + /// Type of the worker. + const WorkerType worker_type; + + /// ID for this worker. + const ClientID worker_id; + + /// Driver ID for this worker. + DriverID current_driver_id; + + private: + static WorkerThreadContext &GetThreadContext(); + + /// Per-thread worker context. + static thread_local std::unique_ptr thread_context_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_CONTEXT_H diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc new file mode 100644 index 000000000000..82f2d885ec58 --- /dev/null +++ b/src/ray/core_worker/core_worker.cc @@ -0,0 +1,39 @@ +#include "core_worker.h" +#include "context.h" + +namespace ray { + +CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language, + const std::string &store_socket, const std::string &raylet_socket, + DriverID driver_id) + : worker_type_(worker_type), + language_(language), + worker_context_(worker_type, driver_id), + store_socket_(store_socket), + raylet_socket_(raylet_socket), + task_interface_(*this), + object_interface_(*this), + task_execution_interface_(*this) {} + +Status CoreWorker::Connect() { + // connect to plasma. + RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_)); + + // connect to raylet. + ::Language lang = ::Language::PYTHON; + if (language_ == ray::Language::JAVA) { + lang = ::Language::JAVA; + } + + // TODO: currently RayletClient would crash in its constructor if it cannot + // connect to Raylet after a number of retries, this needs to be changed + // so that the worker (java/python .etc) can retrieve and handle the error + // instead of crashing. 
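// The construction below is what actually registers this process with the
// raylet, using the IDs tracked by WorkerContext above (drivers reuse the
// driver ID as their worker ID, workers get a random one). A usage sketch of
// the resulting public API, illustrative only; socket paths and the driver
// ID are placeholders:
//
//   CoreWorker worker(WorkerType::DRIVER, Language::PYTHON,
//                     "/tmp/store_socket", "/tmp/raylet_socket",
//                     DriverID::FromRandom());
//   RAY_CHECK(worker.Connect().ok());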
+ raylet_client_ = std::unique_ptr<RayletClient>( + new RayletClient(raylet_socket_, worker_context_.GetWorkerID(), + (worker_type_ == ray::WorkerType::WORKER), + worker_context_.GetCurrentDriverID(), lang)); + return Status::OK(); +} + +} // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 96e51dbc4532..951b55451f09 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -2,8 +2,10 @@ #define RAY_CORE_WORKER_CORE_WORKER_H #include "common.h" +#include "context.h" #include "object_interface.h" #include "ray/common/buffer.h" +#include "ray/raylet/raylet_client.h" #include "task_execution.h" #include "task_interface.h" @@ -18,15 +20,12 @@ class CoreWorker { /// /// \param[in] worker_type Type of this worker. /// \param[in] language Language of this worker. - CoreWorker(const WorkerType worker_type, const Language language) - : worker_type_(worker_type), - language_(language), - task_interface_(*this), - object_interface_(*this), - task_execution_interface_(*this) {} + CoreWorker(const WorkerType worker_type, const Language language, + const std::string &store_socket, const std::string &raylet_socket, + DriverID driver_id = DriverID::Nil()); - /// Connect this worker to Raylet. - Status Connect() { return Status::OK(); } + /// Connect to raylet. + Status Connect(); /// Type of this worker. enum WorkerType WorkerType() const { return worker_type_; } @@ -53,6 +52,21 @@ class CoreWorker { /// Language of this worker. const enum Language language_; + /// Worker context per thread. + WorkerContext worker_context_; + + /// Plasma store socket name. + std::string store_socket_; + + /// Raylet socket name. + std::string raylet_socket_; + + /// Plasma store client. + plasma::PlasmaClient store_client_; + + /// Raylet client. + std::unique_ptr<RayletClient> raylet_client_; + /// The `CoreWorkerTaskInterface` instance. CoreWorkerTaskInterface task_interface_; @@ -61,6 +75,10 @@ class CoreWorker { /// The `CoreWorkerTaskExecutionInterface` instance.
CoreWorkerTaskExecutionInterface task_execution_interface_; + + friend class CoreWorkerTaskInterface; + friend class CoreWorkerObjectInterface; + friend class CoreWorkerTaskExecutionInterface; }; } // namespace ray diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index 6711c874a973..e440aae24d67 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -1,20 +1,137 @@ +#include #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "context.h" #include "core_worker.h" #include "ray/common/buffer.h" +#include "ray/raylet/raylet_client.h" + +#include +#include +#include + +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" namespace ray { +std::string store_executable; +std::string raylet_executable; + +ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } + +static void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); + redisFree(context); +} + class CoreWorkerTest : public ::testing::Test { public: - CoreWorkerTest() : core_worker_(WorkerType::WORKER, Language::PYTHON) {} + CoreWorkerTest(int num_nodes) { + RAY_CHECK(num_nodes >= 0); + if (num_nodes > 0) { + raylet_socket_names_.resize(num_nodes); + raylet_store_socket_names_.resize(num_nodes); + } + + // start plasma store. + for (auto &store_socket : raylet_store_socket_names_) { + store_socket = StartStore(); + } + + // start raylet on each node + for (int i = 0; i < num_nodes; i++) { + raylet_socket_names_[i] = StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", + "127.0.0.1", "\"CPU,4.0\""); + } + } + + ~CoreWorkerTest() { + for (const auto &raylet_socket : raylet_socket_names_) { + StopRaylet(raylet_socket); + } + + for (const auto &store_socket : raylet_store_socket_names_) { + StopStore(store_socket); + } + } + + std::string StartStore() { + std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex(); + std::string store_pid = store_socket_name + ".pid"; + std::string plasma_command = store_executable + " -m 10000000 -s " + + store_socket_name + + " 1> /dev/null 2> /dev/null & echo $! 
> " + store_pid; + RAY_LOG(INFO) << plasma_command; + RAY_CHECK(system(plasma_command.c_str()) == 0); + usleep(200 * 1000); + return store_socket_name; + } + + void StopStore(std::string store_socket_name) { + std::string store_pid = store_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + store_pid + "`"; + RAY_LOG(INFO) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + store_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + store_socket_name + ".pid").c_str()) == 0); + } + + std::string StartRaylet(std::string store_socket_name, std::string node_ip_address, + std::string redis_address, std::string resource) { + std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex(); + std::string ray_start_cmd = raylet_executable; + ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name) + .append(" --store_socket_name=" + store_socket_name) + .append(" --object_manager_port=0 --node_manager_port=0") + .append(" --node_ip_address=" + node_ip_address) + .append(" --redis_address=" + redis_address) + .append(" --redis_port=6379") + .append(" --num_initial_workers=0") + .append(" --maximum_startup_concurrency=10") + .append(" --static_resource_list=" + resource) + .append(" --python_worker_command=NoneCmd") + .append(" & echo $! > " + raylet_socket_name + ".pid"); + + RAY_LOG(INFO) << "Ray Start command: " << ray_start_cmd; + RAY_CHECK(system(ray_start_cmd.c_str()) == 0); + usleep(200 * 1000); + return raylet_socket_name; + } + + void StopRaylet(std::string raylet_socket_name) { + std::string raylet_pid = raylet_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + raylet_pid + "`"; + RAY_LOG(INFO) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); + } + + void SetUp() { flushall_redis(); } + + void TearDown() {} protected: - CoreWorker core_worker_; + std::vector raylet_socket_names_; + std::vector raylet_store_socket_names_; }; -TEST_F(CoreWorkerTest, TestTaskArg) { +class ZeroNodeTest : public CoreWorkerTest { + public: + ZeroNodeTest() : CoreWorkerTest(0) {} +}; + +class SingleNodeTest : public CoreWorkerTest { + public: + SingleNodeTest() : CoreWorkerTest(1) {} +}; + +TEST_F(ZeroNodeTest, TestTaskArg) { // Test by-reference argument. ObjectID id = ObjectID::FromRandom(); TaskArg by_ref = TaskArg::PassByReference(id); @@ -30,9 +147,100 @@ TEST_F(CoreWorkerTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -TEST_F(CoreWorkerTest, TestAttributeGetters) { - ASSERT_EQ(core_worker_.WorkerType(), WorkerType::WORKER); - ASSERT_EQ(core_worker_.Language(), Language::PYTHON); +TEST_F(ZeroNodeTest, TestAttributeGetters) { + CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", "", + DriverID::FromRandom()); + ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER); + ASSERT_EQ(core_worker.Language(), Language::PYTHON); +} + +TEST_F(ZeroNodeTest, TestWorkerContext) { + auto driver_id = DriverID::FromRandom(); + + WorkerContext context(WorkerType::WORKER, driver_id); + ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); + ASSERT_EQ(context.GetNextTaskIndex(), 1); + ASSERT_EQ(context.GetNextTaskIndex(), 2); + ASSERT_EQ(context.GetNextPutIndex(), 1); + ASSERT_EQ(context.GetNextPutIndex(), 2); + + auto thread_func = [&context]() { + // Verify that task_index, put_index are thread-local. 
+ ASSERT_TRUE(!context.GetCurrentTaskID().IsNil()); + ASSERT_EQ(context.GetNextTaskIndex(), 1); + ASSERT_EQ(context.GetNextPutIndex(), 1); + }; + + std::thread async_thread(thread_func); + async_thread.join(); + + // Verify that these fields are thread-local. + ASSERT_EQ(context.GetNextTaskIndex(), 3); + ASSERT_EQ(context.GetNextPutIndex(), 3); +} + +TEST_F(SingleNodeTest, TestObjectInterface) { + CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + RAY_CHECK_OK(core_worker.Connect()); + + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + std::vector buffers; + buffers.emplace_back(array1, sizeof(array1)); + buffers.emplace_back(array2, sizeof(array2)); + + std::vector ids(buffers.size()); + for (int i = 0; i < ids.size(); i++) { + core_worker.Objects().Put(buffers[i], &ids[i]); + } + + // Test Get(). + std::vector> results; + core_worker.Objects().Get(ids, 0, &results); + + ASSERT_EQ(results.size(), 2); + for (int i = 0; i < ids.size(); i++) { + ASSERT_EQ(results[i]->Size(), buffers[i].Size()); + ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); + } + + // Test Wait(). + ObjectID non_existent_id = ObjectID::FromRandom(); + std::vector all_ids(ids); + all_ids.push_back(non_existent_id); + + std::vector wait_results; + core_worker.Objects().Wait(all_ids, 2, -1, &wait_results); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + core_worker.Objects().Wait(all_ids, 3, 100, &wait_results); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + // Test Delete(). + // clear the reference held by PlasmaBuffer. + results.clear(); + core_worker.Objects().Delete(ids, true, false); + + // Note that Delete() calls RayletClient::FreeObjects and would not + // wait for objects being deleted, so wait a while for plasma store + // to process the command. 
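  // A rough sketch of a polling alternative to the fixed sleep below
  // (illustrative only; it assumes repeated short-timeout Get() calls on
  // deleted objects are cheap and simply leave null buffers in the results):
  //
  //   std::vector<std::shared_ptr<Buffer>> check;
  //   for (int attempt = 0; attempt < 10; attempt++) {
  //     core_worker.Objects().Get(ids, /*timeout_ms=*/100, &check);
  //     if (!check[0] && !check[1]) break;  // Both objects are gone from the store.
  //   }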
+ usleep(200 * 1000); + core_worker.Objects().Get(ids, 0, &results); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); } } // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + ray::store_executable = std::string(argv[1]); + ray::raylet_executable = std::string(argv[2]); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index d5d5d6f883f6..c966192610c7 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,25 +1,128 @@ #include "object_interface.h" +#include "context.h" +#include "core_worker.h" +#include "ray/ray_config.h" namespace ray { -Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID *object_id) { +CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker) + : core_worker_(core_worker) {} + +Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) { + ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(), + core_worker_.worker_context_.GetNextPutIndex()); + *object_id = put_id; + + auto plasma_id = put_id.ToPlasmaId(); + std::shared_ptr data; + RAY_ARROW_RETURN_NOT_OK( + core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); return Status::OK(); } Status CoreWorkerObjectInterface::Get(const std::vector &ids, - int64_t timeout_ms, std::vector *results) { + int64_t timeout_ms, + std::vector> *results) { + (*results).resize(ids.size(), nullptr); + + bool was_blocked = false; + + std::unordered_map unready; + for (int i = 0; i < ids.size(); i++) { + unready.insert({ids[i], i}); + } + + int num_attempts = 0; + bool should_break = false; + int64_t remaining_timeout = timeout_ms; + // Repeat until we get all objects. + while (!unready.empty() && !should_break) { + std::vector unready_ids; + for (const auto &entry : unready) { + unready_ids.push_back(entry.first); + } + + // For the initial fetch, we only fetch the objects, do not reconstruct them. + bool fetch_only = num_attempts == 0; + if (!fetch_only) { + // If fetch_only is false, this worker will be blocked. + was_blocked = true; + } + + // TODO: can call `fetchOrReconstruct` in batches as an optimization. + RAY_CHECK_OK(core_worker_.raylet_client_->FetchOrReconstruct( + unready_ids, fetch_only, core_worker_.worker_context_.GetCurrentTaskID())); + + // Get the objects from the object store, and parse the result. 
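    // Worked example of the timeout bookkeeping below (illustrative numbers):
    // with timeout_ms = 250 and get_timeout_milliseconds() = 100, successive
    // iterations use plasma Get() timeouts of 100, 100 and 50 ms; after the
    // third one remaining_timeout reaches 0, should_break is set, and the loop
    // exits even if some objects are still missing. With timeout_ms < 0 the
    // loop keeps retrying with 100 ms timeouts until every object is found.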
+ int64_t get_timeout; + if (remaining_timeout >= 0) { + get_timeout = + std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); + remaining_timeout -= get_timeout; + should_break = remaining_timeout <= 0; + } else { + get_timeout = RayConfig::instance().get_timeout_milliseconds(); + } + + std::vector plasma_ids; + for (const auto &id : unready_ids) { + plasma_ids.push_back(id.ToPlasmaId()); + } + + std::vector object_buffers; + auto status = + core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); + + for (int i = 0; i < object_buffers.size(); i++) { + if (object_buffers[i].data != nullptr) { + const auto &object_id = unready_ids[i]; + (*results)[unready[object_id]] = + std::make_shared(object_buffers[i].data); + unready.erase(object_id); + } + } + + num_attempts += 1; + // TODO: log a message if attempted too many times. + } + + if (was_blocked) { + RAY_CHECK_OK(core_worker_.raylet_client_->NotifyUnblocked( + core_worker_.worker_context_.GetCurrentTaskID())); + } + return Status::OK(); } Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, int num_objects, int64_t timeout_ms, std::vector *results) { - return Status::OK(); + WaitResultPair result_pair; + auto status = core_worker_.raylet_client_->Wait( + object_ids, num_objects, timeout_ms, false, + core_worker_.worker_context_.GetCurrentTaskID(), &result_pair); + std::unordered_set ready_ids; + for (const auto &entry : result_pair.first) { + ready_ids.insert(entry); + } + + // TODO: change RayletClient::Wait() to return a bit set, so that we don't need + // to do this translation. + (*results).resize(object_ids.size()); + for (int i = 0; i < object_ids.size(); i++) { + (*results)[i] = ready_ids.count(object_ids[i]) > 0; + } + + return status; } Status CoreWorkerObjectInterface::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - return Status::OK(); + return core_worker_.raylet_client_->FreeObjects(object_ids, local_only, + delete_creating_tasks); } } // namespace ray diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 424c123ee543..f14c5297c456 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -2,6 +2,7 @@ #define RAY_CORE_WORKER_OBJECT_INTERFACE_H #include "common.h" +#include "plasma/client.h" #include "ray/common/buffer.h" #include "ray/id.h" #include "ray/status.h" @@ -13,14 +14,14 @@ class CoreWorker; /// The interface that contains all `CoreWorker` methods that are related to object store. class CoreWorkerObjectInterface { public: - CoreWorkerObjectInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + CoreWorkerObjectInterface(CoreWorker &core_worker); /// Put an object into object store. /// /// \param[in] buffer Data buffer of the object. /// \param[out] object_id Generated ID of the object. /// \return Status. - Status Put(const Buffer &buffer, const ObjectID *object_id); + Status Put(const Buffer &buffer, ObjectID *object_id); /// Get a list of objects from the object store. /// @@ -29,7 +30,7 @@ class CoreWorkerObjectInterface { /// \param[out] results Result list of objects data. /// \return Status. Status Get(const std::vector &ids, int64_t timeout_ms, - std::vector *results); + std::vector> *results); /// Wait for a list of objects to appear in the object store. 
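  /// Example (illustrative, mirroring the unit test): wait for any two of three
  /// objects with no timeout, then read the per-object flags:
  ///   std::vector<bool> ready;
  ///   RAY_CHECK_OK(worker.Objects().Wait(ids, /*num_objects=*/2, /*timeout_ms=*/-1, &ready));
  ///   // ready[i] is true exactly for the objects that were found.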
/// diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 6a7742c6b5a4..13450a4b7642 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -791,7 +791,7 @@ int TableCancelNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleString return REDISMODULE_OK; } -Status is_nil(bool *out, const std::string &data) { +Status IsNil(bool *out, const std::string &data) { if (data.size() != kUniqueIDSize) { return Status::RedisError("Size of data doesn't match size of UniqueID"); } @@ -836,7 +836,7 @@ int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg static_cast(update->test_state_bitmask()); bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(is_nil(&is_nil_result, update->test_raylet_id()->str())); + REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); if (!is_nil_result) { do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); } diff --git a/src/ray/status.h b/src/ray/status.h index fb6252b34667..340ffb3112cc 100644 --- a/src/ray/status.h +++ b/src/ray/status.h @@ -54,6 +54,17 @@ // This macro is used to replace the "ARROW_CHECK_OK" macro. #define RAY_ARROW_CHECK_OK(s) RAY_ARROW_CHECK_OK_PREPEND(s, "Bad status") +// If arrow status is not ok, return a ray IOError status +// with the error message. +#define RAY_ARROW_RETURN_NOT_OK(s) \ + do { \ + ::arrow::Status _s = (s); \ + if (RAY_PREDICT_FALSE(!_s.ok())) { \ + return ray::Status::IOError(_s.message()); \ + ; \ + } \ + } while (0) + namespace ray { enum class StatusCode : char { diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh new file mode 100644 index 000000000000..5f1dd2eda69f --- /dev/null +++ b/src/ray/test/run_core_worker_tests.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +# This needs to be run in the root directory. + +# Cause the script to exit if a single command fails. +set -e +set -x + +bazel build "//:core_worker_test" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" + +# Get the directory in which this script is executing. +SCRIPT_DIR="`dirname \"$0\"`" +RAY_ROOT="$SCRIPT_DIR/../../.." +# Makes $RAY_ROOT an absolute path. +RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`" +if [ -z "$RAY_ROOT" ] ; then + exit 1 +fi +# Ensure we're in the right directory. +if [ ! -d "$RAY_ROOT/python" ]; then + echo "Unable to find root Ray directory. Has this script moved?" + exit 1 +fi + +REDIS_MODULE="./bazel-bin/libray_redis_module.so" +LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" +STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" +RAYLET_EXEC="./bazel-bin/raylet" + +# Allow cleanup commands to fail. +bazel run //:redis-cli -- -p 6379 shutdown || true +sleep 1s +bazel run //:redis-cli -- -p 6380 shutdown || true +sleep 1s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 & +sleep 2s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & +sleep 2s +# Run tests. +./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC +sleep 1s +bazel run //:redis-cli -- -p 6379 shutdown +bazel run //:redis-cli -- -p 6380 shutdown +sleep 1s + +# Include raylet integration test once it's ready. 
+# ./bazel-bin/object_manager_integration_test $STORE_EXEC From c2253d2313f5a43c20658319063e1713ad695e67 Mon Sep 17 00:00:00 2001 From: Timon Ruban Date: Tue, 4 Jun 2019 03:45:15 +0200 Subject: [PATCH 067/118] [tune] Make PBT Quantile fraction configurable (#4912) --- python/ray/tune/schedulers/pbt.py | 33 ++++++++++++++----- python/ray/tune/tests/test_trial_scheduler.py | 1 + 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index b6d1b4e80838..f8f7ce7dc8f8 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -19,10 +19,6 @@ logger = logging.getLogger(__name__) -# Parameters are transferred from the top PBT_QUANTILE fraction of trials to -# the bottom PBT_QUANTILE fraction. -PBT_QUANTILE = 0.25 - class PBTTrialState(object): """Internal PBT state tracked per-trial.""" @@ -134,6 +130,10 @@ class PopulationBasedTraining(FIFOScheduler): A function specifies the distribution of a continuous parameter. You must specify at least one of `hyperparam_mutations` or `custom_explore_fn`. + quantile_fraction (float): Parameters are transferred from the top + `quantile_fraction` fraction of trials to the bottom + `quantile_fraction` fraction. Needs to be between 0 and 0.5. + Setting it to 0 essentially implies doing no exploitation at all. resample_probability (float): The probability of resampling from the original distribution when applying `hyperparam_mutations`. If not resampled, the value will be perturbed by a factor of 1.2 or 0.8 @@ -172,6 +172,7 @@ def __init__(self, mode="max", perturbation_interval=60.0, hyperparam_mutations={}, + quantile_fraction=0.25, resample_probability=0.25, custom_explore_fn=None, log_config=True): @@ -180,6 +181,11 @@ def __init__(self, "You must specify at least one of `hyperparam_mutations` or " "`custom_explore_fn` to use PBT.") + if quantile_fraction > 0.5 or quantile_fraction < 0: + raise TuneError( + "You must set `quantile_fraction` to a value between 0 and" + "0.5. Current value: '{}'".format(quantile_fraction)) + assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!" if reward_attr is not None: @@ -199,6 +205,7 @@ def __init__(self, self._time_attr = time_attr self._perturbation_interval = perturbation_interval self._hyperparam_mutations = hyperparam_mutations + self._quantile_fraction = quantile_fraction self._resample_probability = resample_probability self._trial_state = {} self._custom_explore_fn = custom_explore_fn @@ -247,6 +254,7 @@ def _log_config_on_step(self, trial_state, new_state, trial, For each step, logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config]. + """ trial_name, trial_to_clone_name = (trial_state.orig_tag, new_state.orig_tag) @@ -277,7 +285,9 @@ def _log_config_on_step(self, trial_state, new_state, trial, def _exploit(self, trial_executor, trial, trial_to_clone): """Transfers perturbed state from trial_to_clone -> trial. - If specified, also logs the updated hyperparam state.""" + If specified, also logs the updated hyperparam state. + + """ trial_state = self._trial_state[trial] new_state = self._trial_state[trial_to_clone] @@ -318,7 +328,9 @@ def _exploit(self, trial_executor, trial, trial_to_clone): def _quantiles(self): """Returns trials in the lower and upper `quantile` of the population. - If there is not enough data to compute this, returns empty lists.""" + If there is not enough data to compute this, returns empty lists. 
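 (Worked example of the arithmetic below: with 20 trials and ``quantile_fraction=0.25``, ceil(20 * 0.25) = 5, so the bottom five and top five trials in score-sorted order form the lower and upper quantiles; capping at half the population keeps the two sets disjoint.)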
+ + """ trials = [] for trial, state in self._trial_state.items(): @@ -329,14 +341,19 @@ def _quantiles(self): if len(trials) <= 1: return [], [] else: - return (trials[:int(math.ceil(len(trials) * PBT_QUANTILE))], - trials[int(math.floor(-len(trials) * PBT_QUANTILE)):]) + num_trials_in_quantile = int( + math.ceil(len(trials) * self._quantile_fraction)) + if num_trials_in_quantile > len(trials) / 2: + num_trials_in_quantile = int(math.floor(len(trials) / 2)) + return (trials[:num_trials_in_quantile], + trials[-num_trials_in_quantile:]) def choose_trial_to_run(self, trial_runner): """Ensures all trials get fair share of time (as defined by time_attr). This enables the PBT scheduler to support a greater number of concurrent trials than can fit in the cluster at any given time. + """ candidates = [] diff --git a/python/ray/tune/tests/test_trial_scheduler.py b/python/ray/tune/tests/test_trial_scheduler.py index 1c33375549f0..427ca53633fb 100644 --- a/python/ray/tune/tests/test_trial_scheduler.py +++ b/python/ray/tune/tests/test_trial_scheduler.py @@ -627,6 +627,7 @@ def basicSetup(self, resample_prob=0.0, explore=None, log_config=False): time_attr="training_iteration", perturbation_interval=10, resample_probability=resample_prob, + quantile_fraction=0.25, hyperparam_mutations={ "id_factor": [100], "float_factor": lambda: 100.0, From d10628376994fdea2eec3269201e1e589f4f2dad Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 5 Jun 2019 14:19:09 +0800 Subject: [PATCH 068/118] Better organize ray_common module (#4898) --- BUILD.bazel | 30 +++++++++---------- .../java/org/ray/runtime/util/IdUtil.java | 4 +-- python/ray/includes/common.pxd | 6 ++-- python/ray/includes/ray_config.pxd | 2 +- python/ray/includes/unique_ids.pxd | 4 +-- python/ray/includes/unique_ids.pxi | 8 ++--- src/ray/common/client_connection.cc | 3 +- src/ray/common/client_connection.h | 4 +-- src/ray/common/common_protocol.h | 2 +- src/ray/{ => common}/constants.h | 0 src/ray/{ => common}/id.cc | 8 ++--- src/ray/{ => common}/id.h | 2 +- src/ray/{ => common}/id_def.h | 0 src/ray/{ => common}/ray_config.h | 0 src/ray/{ => common}/ray_config_def.h | 0 src/ray/{ => common}/status.cc | 2 +- src/ray/{ => common}/status.h | 0 src/ray/core_worker/common.h | 2 +- src/ray/core_worker/object_interface.cc | 2 +- src/ray/core_worker/object_interface.h | 4 +-- src/ray/core_worker/task_execution.h | 2 +- src/ray/core_worker/task_interface.h | 4 +-- src/ray/gcs/client.cc | 2 +- src/ray/gcs/client.h | 4 +-- src/ray/gcs/client_test.cc | 2 +- src/ray/gcs/redis_context.cc | 2 +- src/ray/gcs/redis_context.h | 4 +-- src/ray/gcs/redis_module/ray_redis_module.cc | 4 +-- src/ray/gcs/tables.cc | 2 +- src/ray/gcs/tables.h | 6 ++-- src/ray/object_manager/connection_pool.h | 4 +-- src/ray/object_manager/object_buffer_pool.cc | 2 +- src/ray/object_manager/object_buffer_pool.h | 4 +-- src/ray/object_manager/object_directory.h | 4 +-- src/ray/object_manager/object_manager.h | 4 +-- .../object_manager_client_connection.h | 4 +-- .../object_store_notification_manager.cc | 2 +- .../object_store_notification_manager.h | 4 +-- .../test/object_manager_stress_test.cc | 2 +- .../test/object_manager_test.cc | 4 +-- src/ray/raylet/actor_registration.h | 2 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 2 +- src/ray/raylet/lineage_cache.h | 4 +-- src/ray/raylet/main.cc | 4 +-- src/ray/raylet/mock_gcs_client.h | 4 +-- src/ray/raylet/monitor.cc | 4 +-- src/ray/raylet/monitor.h | 2 +- src/ray/raylet/monitor_main.cc | 2 +- src/ray/raylet/node_manager.cc | 4 +-- 
.../raylet/object_manager_integration_test.cc | 2 +- src/ray/raylet/raylet.cc | 4 +-- src/ray/raylet/raylet_client.cc | 2 +- src/ray/raylet/raylet_client.h | 2 +- src/ray/raylet/reconstruction_policy.h | 2 +- src/ray/raylet/scheduling_queue.cc | 2 +- src/ray/raylet/task_dependency_manager.h | 2 +- src/ray/raylet/task_execution_spec.h | 2 +- src/ray/raylet/task_spec.h | 2 +- src/ray/raylet/worker.h | 2 +- src/ray/raylet/worker_pool.cc | 4 +-- src/ray/util/util.h | 2 +- 61 files changed, 102 insertions(+), 103 deletions(-) rename src/ray/{ => common}/constants.h (100%) rename src/ray/{ => common}/id.cc (96%) rename src/ray/{ => common}/id.h (99%) rename src/ray/{ => common}/id_def.h (100%) rename src/ray/{ => common}/ray_config.h (100%) rename src/ray/{ => common}/ray_config_def.h (100%) rename src/ray/{ => common}/status.cc (98%) rename src/ray/{ => common}/status.h (100%) diff --git a/BUILD.bazel b/BUILD.bazel index 2b75d5d04e77..90b0f536a10b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -281,16 +281,13 @@ cc_library( name = "ray_util", srcs = glob( [ - "src/ray/*.cc", "src/ray/util/*.cc", ], exclude = [ - "src/ray/util/logging_test.cc", - "src/ray/util/signal_test.cc", + "src/ray/util/*_test.cc", ], ), hdrs = glob([ - "src/ray/*.h", "src/ray/util/*.h", ]), copts = COPTS, @@ -306,22 +303,25 @@ cc_library( cc_library( name = "ray_common", - srcs = [ - "src/ray/common/client_connection.cc", - "src/ray/common/common_protocol.cc", - ], - hdrs = [ - "src/ray/common/buffer.h", - "src/ray/common/client_connection.h", - "src/ray/common/common_protocol.h", - ], + srcs = glob( + [ + "src/ray/common/*.cc", + ], + exclude = [ + "src/ray/common/*_test.cc", + ], + ), + hdrs = glob( + [ + "src/ray/common/*.h", + ], + ), copts = COPTS, includes = [ "src/ray/gcs/format", ], deps = [ ":gcs_fbs", - ":node_manager_fbs", ":ray_util", "@boost//:asio", "@plasma//:plasma_client", @@ -468,7 +468,7 @@ cc_binary( srcs = [ "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h", "src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc", - "src/ray/id.h", + "src/ray/common/id.h", "src/ray/raylet/raylet_client.h", "src/ray/util/logging.h", "@bazel_tools//tools/jdk:jni_header", diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 04f75500b29d..67df09fa11ea 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -13,7 +13,7 @@ /** * Helper method for different Ids. 
* Note: any changes to these methods must be synced with C++ helper functions - * in src/ray/id.h + * in src/ray/common/id.h */ public class IdUtil { public static final int OBJECT_INDEX_POS = 16; @@ -161,7 +161,7 @@ public static long murmurHashCode(BaseId id) { } /** - * This method is the same as `Hash()` method of `ID` class in ray/src/ray/id.h + * This method is the same as `Hash()` method of `ID` class in ray/src/ray/common/id.h */ private static long murmurHash64A(byte[] data, int length, int seed) { final long m = 0xc6a4a7935bd1e995L; diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index bdb4316fcc4e..4c2cd8437a66 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -12,7 +12,7 @@ from ray.includes.unique_ids cimport ( ) -cdef extern from "ray/status.h" namespace "ray" nogil: +cdef extern from "ray/common/status.h" namespace "ray" nogil: cdef cppclass StatusCode: pass @@ -68,7 +68,7 @@ cdef extern from "ray/status.h" namespace "ray" nogil: cdef CRayStatus RayStatus_Invalid "Status::Invalid"() -cdef extern from "ray/status.h" namespace "ray::StatusCode" nogil: +cdef extern from "ray/common/status.h" namespace "ray::StatusCode" nogil: cdef StatusCode StatusCode_OK "OK" cdef StatusCode StatusCode_OutOfMemory "OutOfMemory" cdef StatusCode StatusCode_KeyError "KeyError" @@ -80,7 +80,7 @@ cdef extern from "ray/status.h" namespace "ray::StatusCode" nogil: cdef StatusCode StatusCode_RedisError "RedisError" -cdef extern from "ray/id.h" namespace "ray" nogil: +cdef extern from "ray/common/id.h" namespace "ray" nogil: const CTaskID GenerateTaskId(const CDriverID &driver_id, const CTaskID &parent_task_id, int parent_task_counter) diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index bae0995cebff..41adec160c19 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -3,7 +3,7 @@ from libcpp.string cimport string as c_string from libcpp.unordered_map cimport unordered_map -cdef extern from "ray/ray_config.h" nogil: +cdef extern from "ray/common/ray_config.h" nogil: cdef cppclass RayConfig "RayConfig": @staticmethod RayConfig &instance() diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 8bf369c649b7..1bce0f4ba2d5 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -2,7 +2,7 @@ from libcpp cimport bool as c_bool from libcpp.string cimport string as c_string from libc.stdint cimport uint8_t, int64_t -cdef extern from "ray/id.h" namespace "ray" nogil: +cdef extern from "ray/common/id.h" namespace "ray" nogil: cdef cppclass CBaseID[T]: @staticmethod T from_random() @@ -113,7 +113,7 @@ cdef extern from "ray/id.h" namespace "ray" nogil: c_bool is_put() - int64_t ObjectIndex() const + int64_t ObjectIndex() const CTaskID TaskId() const diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index cd3c58003fed..98a0a291351d 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -34,7 +34,7 @@ def check_id(b, size=kUniqueIDSize): str(size)) -cdef extern from "ray/constants.h" nogil: +cdef extern from "ray/common/constants.h" nogil: cdef int64_t kUniqueIDSize cdef int64_t kMaxTaskPuts @@ -109,7 +109,7 @@ cdef class UniqueID(BaseID): def nil(cls): return cls(CUniqueID.Nil().Binary()) - + @classmethod def from_random(cls): return cls(os.urandom(CUniqueID.Size())) @@ -122,7 +122,7 @@ cdef class UniqueID(BaseID): def hex(self): return 
decode(self.data.Hex()) - + def is_nil(self): return self.data.IsNil() @@ -148,7 +148,7 @@ cdef class ObjectID(BaseID): def hex(self): return decode(self.data.Hex()) - + def is_nil(self): return self.data.IsNil() diff --git a/src/ray/common/client_connection.cc b/src/ray/common/client_connection.cc index b36382bbb99a..b5b26042689d 100644 --- a/src/ray/common/client_connection.cc +++ b/src/ray/common/client_connection.cc @@ -4,8 +4,7 @@ #include #include -#include "ray/ray_config.h" -#include "ray/raylet/format/node_manager_generated.h" +#include "ray/common/ray_config.h" #include "ray/util/util.h" namespace ray { diff --git a/src/ray/common/client_connection.h b/src/ray/common/client_connection.h index fe895fafd4d4..936b3d577ed5 100644 --- a/src/ray/common/client_connection.h +++ b/src/ray/common/client_connection.h @@ -8,8 +8,8 @@ #include #include -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/common/common_protocol.h b/src/ray/common/common_protocol.h index 63a0bf8c259c..1dbd6bbc2c9b 100644 --- a/src/ray/common/common_protocol.h +++ b/src/ray/common/common_protocol.h @@ -5,7 +5,7 @@ #include -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/util/logging.h" /// Convert an unique ID to a flatbuffer string. diff --git a/src/ray/constants.h b/src/ray/common/constants.h similarity index 100% rename from src/ray/constants.h rename to src/ray/common/constants.h diff --git a/src/ray/id.cc b/src/ray/common/id.cc similarity index 96% rename from src/ray/id.cc rename to src/ray/common/id.cc index 4c9ce4dc9244..57e41d97d10c 100644 --- a/src/ray/id.cc +++ b/src/ray/common/id.cc @@ -1,4 +1,4 @@ -#include "ray/id.h" +#include "ray/common/id.h" #include @@ -6,11 +6,11 @@ #include #include -#include "ray/constants.h" -#include "ray/status.h" +#include "ray/common/constants.h" +#include "ray/common/status.h" extern "C" { -#include "thirdparty/sha256.h" +#include "ray/thirdparty/sha256.h" } // Definitions for computing hash digests. 
diff --git a/src/ray/id.h b/src/ray/common/id.h similarity index 99% rename from src/ray/id.h rename to src/ray/common/id.h index 7153a95f7750..3b2d244cfb46 100644 --- a/src/ray/id.h +++ b/src/ray/common/id.h @@ -11,7 +11,7 @@ #include #include "plasma/common.h" -#include "ray/constants.h" +#include "ray/common/constants.h" #include "ray/util/logging.h" #include "ray/util/visibility.h" diff --git a/src/ray/id_def.h b/src/ray/common/id_def.h similarity index 100% rename from src/ray/id_def.h rename to src/ray/common/id_def.h diff --git a/src/ray/ray_config.h b/src/ray/common/ray_config.h similarity index 100% rename from src/ray/ray_config.h rename to src/ray/common/ray_config.h diff --git a/src/ray/ray_config_def.h b/src/ray/common/ray_config_def.h similarity index 100% rename from src/ray/ray_config_def.h rename to src/ray/common/ray_config_def.h diff --git a/src/ray/status.cc b/src/ray/common/status.cc similarity index 98% rename from src/ray/status.cc rename to src/ray/common/status.cc index 4be0de442bb1..01abacde6214 100644 --- a/src/ray/status.cc +++ b/src/ray/common/status.cc @@ -12,7 +12,7 @@ // Adapted from Apache Arrow, Apache Kudu, TensorFlow -#include "ray/status.h" +#include "ray/common/status.h" #include diff --git a/src/ray/status.h b/src/ray/common/status.h similarity index 100% rename from src/ray/status.h rename to src/ray/common/status.h diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index ad9485e53826..8317bf181207 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -4,7 +4,7 @@ #include #include "ray/common/buffer.h" -#include "ray/id.h" +#include "ray/common/id.h" namespace ray { diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index c966192610c7..0b94c9d4a747 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,7 +1,7 @@ #include "object_interface.h" #include "context.h" #include "core_worker.h" -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" namespace ray { diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index f14c5297c456..8a9e20c48c6e 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -4,8 +4,8 @@ #include "common.h" #include "plasma/client.h" #include "ray/common/buffer.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index 308b1e6868d6..c4de937ee439 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -3,7 +3,7 @@ #include "common.h" #include "ray/common/buffer.h" -#include "ray/status.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index f667d8d5a06f..e23f049d341d 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -3,8 +3,8 @@ #include "common.h" #include "ray/common/buffer.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 642f5f2cf156..d9b5087c4719 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -1,7 +1,7 @@ #include "ray/gcs/client.h" +#include "ray/common/ray_config.h" #include "ray/gcs/redis_context.h" 
-#include "ray/ray_config.h" static void GetRedisShards(redisContext *context, std::vector &addresses, std::vector &ports) { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 7a5d8ef0ee9c..d47d9a6e8b24 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -4,10 +4,10 @@ #include #include +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/gcs/asio.h" #include "ray/gcs/tables.h" -#include "ray/id.h" -#include "ray/status.h" #include "ray/util/logging.h" namespace ray { diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c203a4a9482a..1b43bcc23c08 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -5,9 +5,9 @@ extern "C" { #include "ray/thirdparty/hiredis/hiredis.h" } +#include "ray/common/ray_config.h" #include "ray/gcs/client.h" #include "ray/gcs/tables.h" -#include "ray/ray_config.h" namespace ray { diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 42d921d932d7..fe5ba3d1d134 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -14,7 +14,7 @@ extern "C" { } // TODO(pcm): Integrate into the C++ tree. -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" namespace { diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 264f61b1ceaa..fc42e5cd98c2 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -5,8 +5,8 @@ #include #include -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/util/logging.h" #include "ray/gcs/format/gcs_generated.h" diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 13450a4b7642..23e611e400df 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -2,9 +2,9 @@ #include #include "ray/common/common_protocol.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" -#include "ray/id.h" -#include "ray/status.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 3a381313fd21..ffc44daa049a 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -1,8 +1,8 @@ #include "ray/gcs/tables.h" #include "ray/common/common_protocol.h" +#include "ray/common/ray_config.h" #include "ray/gcs/client.h" -#include "ray/ray_config.h" #include "ray/util/util.h" namespace { diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index af739dc2ed32..af42509bda96 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -6,9 +6,9 @@ #include #include -#include "ray/constants.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/constants.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/util/logging.h" #include "ray/gcs/format/gcs_generated.h" diff --git a/src/ray/object_manager/connection_pool.h b/src/ray/object_manager/connection_pool.h index 628769d38b1e..820cf591e70d 100644 --- a/src/ray/object_manager/connection_pool.h +++ b/src/ray/object_manager/connection_pool.h @@ -12,8 +12,8 @@ #include #include -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include #include "ray/object_manager/format/object_manager_generated.h" diff --git a/src/ray/object_manager/object_buffer_pool.cc b/src/ray/object_manager/object_buffer_pool.cc index fba426d732c3..ee2a2319c0f4 
100644 --- a/src/ray/object_manager/object_buffer_pool.cc +++ b/src/ray/object_manager/object_buffer_pool.cc @@ -1,6 +1,6 @@ #include "ray/object_manager/object_buffer_pool.h" -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/util/logging.h" namespace ray { diff --git a/src/ray/object_manager/object_buffer_pool.h b/src/ray/object_manager/object_buffer_pool.h index 78814710ae3c..67288800498d 100644 --- a/src/ray/object_manager/object_buffer_pool.h +++ b/src/ray/object_manager/object_buffer_pool.h @@ -12,8 +12,8 @@ #include "plasma/client.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/object_manager/object_directory.h b/src/ray/object_manager/object_directory.h index 96a2d726e241..be21e52807da 100644 --- a/src/ray/object_manager/object_directory.h +++ b/src/ray/object_manager/object_directory.h @@ -9,10 +9,10 @@ #include "plasma/client.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/gcs/client.h" -#include "ray/id.h" #include "ray/object_manager/format/object_manager_generated.h" -#include "ray/status.h" namespace ray { diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index cb0cff83f349..6318250ae3e8 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -16,8 +16,8 @@ #include "plasma/client.h" #include "ray/common/client_connection.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/object_manager/connection_pool.h" #include "ray/object_manager/format/object_manager_generated.h" diff --git a/src/ray/object_manager/object_manager_client_connection.h b/src/ray/object_manager/object_manager_client_connection.h index f878da67d0d3..c08b99748408 100644 --- a/src/ray/object_manager/object_manager_client_connection.h +++ b/src/ray/object_manager/object_manager_client_connection.h @@ -11,8 +11,8 @@ #include #include "ray/common/client_connection.h" -#include "ray/id.h" -#include "ray/ray_config.h" +#include "ray/common/id.h" +#include "ray/common/ray_config.h" namespace ray { diff --git a/src/ray/object_manager/object_store_notification_manager.cc b/src/ray/object_manager/object_store_notification_manager.cc index 5245a94ace3a..6f813ea4595d 100644 --- a/src/ray/object_manager/object_store_notification_manager.cc +++ b/src/ray/object_manager/object_store_notification_manager.cc @@ -5,7 +5,7 @@ #include #include -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/common/common_protocol.h" #include "ray/object_manager/object_store_notification_manager.h" diff --git a/src/ray/object_manager/object_store_notification_manager.h b/src/ray/object_manager/object_store_notification_manager.h index 92db89573647..912c5dc884d0 100644 --- a/src/ray/object_manager/object_store_notification_manager.h +++ b/src/ray/object_manager/object_store_notification_manager.h @@ -11,8 +11,8 @@ #include "plasma/client.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" #include "ray/object_manager/object_directory.h" diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 6d7c0be0f856..f1169605134a 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -5,7 +5,7 @@ #include 
"gtest/gtest.h" -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/object_manager/object_manager.h" diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 983a8fa7bc05..012c306938d6 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -3,14 +3,14 @@ #include "gtest/gtest.h" -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/object_manager/object_manager.h" namespace { std::string store_executable; int64_t wait_timeout_ms; -} +} // namespace namespace ray { diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 38533e7dbe42..8d7ce2a449ec 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -3,8 +3,8 @@ #include +#include "ray/common/id.h" #include "ray/gcs/format/gcs_generated.h" -#include "ray/id.h" #include "ray/raylet/task.h" namespace ray { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index ba9fef4f44d6..2afcba18c356 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -2,7 +2,7 @@ #include -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/raylet/raylet_client.h" #include "ray/util/logging.h" diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 045ded107474..02d98b8cffe6 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -8,8 +8,8 @@ #include "ray/common/common_protocol.h" #include "ray/raylet/task.h" #include "ray/gcs/tables.h" -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" // clang-format on namespace ray { diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index f5b8e898558e..003e48370dbf 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -1,9 +1,9 @@ #include -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" #include "ray/raylet/raylet.h" #include "ray/stats/stats.h" -#include "ray/status.h" #include "gflags/gflags.h" diff --git a/src/ray/raylet/mock_gcs_client.h b/src/ray/raylet/mock_gcs_client.h index f84b57dbc363..a08e74f66e5e 100644 --- a/src/ray/raylet/mock_gcs_client.h +++ b/src/ray/raylet/mock_gcs_client.h @@ -12,8 +12,8 @@ #include #include -#include "ray/id.h" -#include "ray/status.h" +#include "ray/common/id.h" +#include "ray/common/status.h" namespace ray { diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 171b4dc9439e..a87257cadda4 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -1,7 +1,7 @@ #include "ray/raylet/monitor.h" -#include "ray/ray_config.h" -#include "ray/status.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" #include "ray/util/util.h" namespace ray { diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index bb698b07f674..c69cc9f003e0 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -4,8 +4,8 @@ #include #include +#include "ray/common/id.h" #include "ray/gcs/client.h" -#include "ray/id.h" namespace ray { diff --git a/src/ray/raylet/monitor_main.cc b/src/ray/raylet/monitor_main.cc index ed0c05aa7b81..6f3c57136ca7 100644 --- a/src/ray/raylet/monitor_main.cc +++ b/src/ray/raylet/monitor_main.cc 
@@ -1,6 +1,6 @@ #include -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" #include "ray/raylet/monitor.h" #include "ray/util/util.h" diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 1b582d7617cb..e3fd9a0df09f 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2,10 +2,10 @@ #include -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/common/common_protocol.h" -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/stats/stats.h" diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index a774e1409195..1b043ca58c2b 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -3,7 +3,7 @@ #include "gtest/gtest.h" -#include "ray/status.h" +#include "ray/common/status.h" #include "ray/raylet/raylet.h" diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 3b0ebd5b691f..dd9e5fac318e 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -5,7 +5,7 @@ #include #include -#include "ray/status.h" +#include "ray/common/status.h" namespace { @@ -37,7 +37,7 @@ static const std::vector object_manager_message_enum = GenerateEnumNames(ray::object_manager::protocol::EnumNamesMessageType(), static_cast(ray::object_manager::protocol::MessageType::MIN), static_cast(ray::object_manager::protocol::MessageType::MAX)); -} +} // namespace namespace ray { diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index ac312b79d13e..801bb9112241 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -13,7 +13,7 @@ #include #include "ray/common/common_protocol.h" -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" #include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/task_spec.h" #include "ray/util/logging.h" diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 0bdd076b5577..8b4dfad5b37a 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -6,8 +6,8 @@ #include #include +#include "ray/common/status.h" #include "ray/raylet/task_spec.h" -#include "ray/status.h" using ray::ActorCheckpointID; using ray::ActorID; diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index fc441de8a65d..cd969cc2706e 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -7,8 +7,8 @@ #include +#include "ray/common/id.h" #include "ray/gcs/tables.h" -#include "ray/id.h" #include "ray/util/util.h" #include "ray/object_manager/object_directory.h" diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 29af345b8391..85295e403769 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -2,8 +2,8 @@ #include +#include "ray/common/status.h" #include "ray/stats/stats.h" -#include "ray/status.h" namespace { diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 9691b3d64170..3788a5eae7ae 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -2,7 +2,7 @@ #define RAY_RAYLET_TASK_DEPENDENCY_MANAGER_H // clang-format off -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" 
#include "ray/raylet/reconstruction_policy.h" diff --git a/src/ray/raylet/task_execution_spec.h b/src/ray/raylet/task_execution_spec.h index fc26e7e275a8..6fc3b833ace8 100644 --- a/src/ray/raylet/task_execution_spec.h +++ b/src/ray/raylet/task_execution_spec.h @@ -3,7 +3,7 @@ #include -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/raylet/format/node_manager_generated.h" namespace ray { diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 6bb2cdad972c..d557c188ae68 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -6,8 +6,8 @@ #include #include +#include "ray/common/id.h" #include "ray/gcs/format/gcs_generated.h" -#include "ray/id.h" #include "ray/raylet/scheduling_resources.h" extern "C" { diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 4860342e3578..cb0797dddce8 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -4,7 +4,7 @@ #include #include "ray/common/client_connection.h" -#include "ray/id.h" +#include "ray/common/id.h" #include "ray/raylet/scheduling_resources.h" namespace ray { diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 27e7fea05311..43698c53f0d8 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,9 +5,9 @@ #include #include -#include "ray/ray_config.h" +#include "ray/common/ray_config.h" +#include "ray/common/status.h" #include "ray/stats/stats.h" -#include "ray/status.h" #include "ray/util/logging.h" namespace { diff --git a/src/ray/util/util.h b/src/ray/util/util.h index ba34cb7339e9..f86870858c5c 100644 --- a/src/ray/util/util.h +++ b/src/ray/util/util.h @@ -4,7 +4,7 @@ #include #include -#include "ray/status.h" +#include "ray/common/status.h" /// Return the number of milliseconds since the steady clock epoch. NOTE: The /// returned timestamp may be used for accurately measuring intervals but has From ffaae1c5be80d5b5d507a02b3ea50e53f54c09bd Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Wed, 5 Jun 2019 12:09:44 +0200 Subject: [PATCH 069/118] Fix error --- python/ray/rllib/agents/impala/vtrace_policy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index d241b1e91008..7fd137bae08b 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -274,8 +274,8 @@ def make_time_major(tensor, drop_last=False): with tf.name_scope('kl_divergence'): # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(unpacked_outputs) - behaviour_dist = MultiCategorical(unpacked_behaviour_logits) + model_dist = MultiCategorical(self.model.outputs, output_hidden_shape) + behaviour_dist = MultiCategorical(behaviour_logits, output_hidden_shape) kls = model_dist.kl(behaviour_dist) if len(kls) > 1: From 2702b15b04f3e8a84f65f98ccb6e7300755a217f Mon Sep 17 00:00:00 2001 From: Timon Ruban Date: Wed, 5 Jun 2019 18:04:36 +0200 Subject: [PATCH 070/118] [tune] Add requirements-dev.txt and update docs for contributing (#4925) * Add requirements-dev.txt and update docs. * Update doc/source/tune-contrib.rst Co-Authored-By: Richard Liaw * Unpin everything except for yapf. 
--- doc/source/tune-contrib.rst | 15 ++++++++++----- python/ray/tune/requirements-dev.txt | 9 +++++++++ 2 files changed, 19 insertions(+), 5 deletions(-) create mode 100644 python/ray/tune/requirements-dev.txt diff --git a/doc/source/tune-contrib.rst b/doc/source/tune-contrib.rst index f945ee679353..4774791e333d 100644 --- a/doc/source/tune-contrib.rst +++ b/doc/source/tune-contrib.rst @@ -15,10 +15,9 @@ We welcome (and encourage!) all forms of contributions to Tune, including and no Setting up a development environment ------------------------------------ -If you have Ray installed via pip (``pip install -U ray``), you can develop Tune locally without needing to compile Ray. +If you have Ray installed via pip (``pip install -U [link to wheel]`` - you can find the link to the latest wheel `here `__), you can develop Tune locally without needing to compile Ray. - -First, you will need your own [fork](https://help.github.com/en/articles/fork-a-repo) to work on the code. Press the Fork button on the `ray project page `__. +First, you will need your own `fork `__ to work on the code. Press the Fork button on the `ray project page `__. Then, clone the project to your machine and connect your repository to the upstream (main project) ray repository. .. code-block:: shell @@ -28,10 +27,16 @@ Then, clone the project to your machine and connect your repository to the upstr git remote add upstream https://github.com/ray-project/ray.git +Before continuing, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master `__ and have the latest `wheel `__ installed.) + Then, run `[path to ray directory]/python/ray/setup-dev.py` `(also here on Github) `__ script. This sets up links between the ``tune`` dir (among other directories) in your local repo and the one bundled with the ``ray`` package. -When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master `__ and have the latest `wheel `__ installed.) +As a last step make sure to install all packages required for development of tune. This can be done by running: + +.. code-block:: shell + + pip install -r [path to ray directory]/python/ray/tune/requirements-dev.txt What can I work on? @@ -89,7 +94,7 @@ burden and speedup review process. Documentation should be documented in `Google style `__ format. We also have tests for code formatting and linting that need to pass before merge. -Install `yapf==0.23, flake8, flake8-quotes`. You can run the following locally: +Install `yapf==0.23, flake8, flake8-quotes` (these are also in the `requirements-dev.txt` found in ``python/ray/tune``). You can run the following locally: .. 
code-block:: shell diff --git a/python/ray/tune/requirements-dev.txt b/python/ray/tune/requirements-dev.txt new file mode 100644 index 000000000000..9d3d3ddab12f --- /dev/null +++ b/python/ray/tune/requirements-dev.txt @@ -0,0 +1,9 @@ +flake8 +flake8-quotes +gym +opencv-python +pandas +requests +tabulate +tensorflow +yapf==0.23.0 From 82b3972d4279e6f2fd3e3eebb848207289250137 Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Thu, 6 Jun 2019 13:30:55 +0200 Subject: [PATCH 071/118] Fix compute actions return value --- python/ray/rllib/policy/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/rllib/policy/policy.py b/python/ray/rllib/policy/policy.py index e12cafef2cc4..1593dc860a51 100644 --- a/python/ray/rllib/policy/policy.py +++ b/python/ray/rllib/policy/policy.py @@ -122,7 +122,7 @@ def compute_single_action(self, info_batch = [info] if episode is not None: episodes = [episode] - [action], state_out, info = self.compute_actions( + [action], state_out, info, *_ = self.compute_actions( [obs], [[s] for s in state], prev_action_batch=prev_action_batch, prev_reward_batch=prev_reward_batch, From a0f14e9e6c6c5a10178472ab7b453187cfc73a49 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 6 Jun 2019 11:20:06 -0700 Subject: [PATCH 072/118] Bump version from 0.7.1 to 0.8.0.dev1. (#4937) --- python/ray/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 421b1c6838ac..03792e5eb48a 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -96,7 +96,7 @@ from ray.runtime_context import _get_runtime_context # noqa: E402 # Ray version string. -__version__ = "0.7.1" +__version__ = "0.8.0.dev1" __all__ = [ "global_state", From c3f8fc1c44e4fe2219a3356cc2b28527d2de196b Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 6 Jun 2019 17:22:45 -0700 Subject: [PATCH 073/118] Update version number in documentation after release 0.7.0 -> 0.7.1 and 0.8.0.dev0 -> 0.8.0.dev1. (#4941) --- README.rst | 2 +- .../run_perf_integration.sh | 2 +- .../application_cluster_template.yaml | 4 ++-- ci/stress_tests/stress_testing_config.yaml | 2 +- doc/source/installation.rst | 16 ++++++++-------- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 2 +- python/ray/autoscaler/aws/example-full.yaml | 6 +++--- .../ray/autoscaler/aws/example-gpu-docker.yaml | 6 +++--- python/ray/autoscaler/gcp/example-full.yaml | 6 +++--- .../ray/autoscaler/gcp/example-gpu-docker.yaml | 6 +++--- src/ray/raylet/main.cc | 2 +- 12 files changed, 28 insertions(+), 28 deletions(-) diff --git a/README.rst b/README.rst index aed1d2e81ae9..bf210f4cb97c 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,7 @@ .. image:: https://readthedocs.org/projects/ray/badge/?version=latest :target: http://ray.readthedocs.io/en/latest/?badge=latest -.. image:: https://img.shields.io/badge/pypi-0.7.0-blue.svg +.. 
image:: https://img.shields.io/badge/pypi-0.7.1-blue.svg :target: https://pypi.org/project/ray/ | diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index f723d5122981..7962b21075c0 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/stress_tests/application_cluster_template.yaml b/ci/stress_tests/application_cluster_template.yaml index 541419da55af..d6ccf4769b04 100644 --- a/ci/stress_tests/application_cluster_template.yaml +++ b/ci/stress_tests/application_cluster_template.yaml @@ -90,8 +90,8 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_<<>>/bin:$PATH"' >> ~/.bashrc - - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl - - rllib || pip install -U ray-0.8.0.dev0-<<>>-manylinux1_x86_64.whl[rllib] + - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-<<>>-manylinux1_x86_64.whl + - rllib || pip install -U ray-0.8.0.dev1-<<>>-manylinux1_x86_64.whl[rllib] - pip install tensorflow-gpu==1.12.0 - echo "sudo halt" | at now + 60 minutes # Consider uncommenting these if you also want to run apt-get commands during setup diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index f71ae8f2dc18..793c1338432d 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -101,7 +101,7 @@ setup_commands: # - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl - echo "sudo halt" | at now + 60 minutes # Custom commands that will be run on the head node after common setup. diff --git a/doc/source/installation.rst b/doc/source/installation.rst index ad92cb347e83..b7cb27c831b6 100644 --- a/doc/source/installation.rst +++ b/doc/source/installation.rst @@ -33,14 +33,14 @@ Here are links to the latest wheels (which are built off of master). To install =================== =================== -.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-manylinux1_x86_64.whl -.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl -.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl -.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl -.. 
_`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp37-cp37m-macosx_10_6_intel.whl -.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-macosx_10_6_intel.whl -.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-macosx_10_6_intel.whl -.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27m-macosx_10_6_intel.whl +.. _`Linux Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp37-cp37m-manylinux1_x86_64.whl +.. _`Linux Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl +.. _`Linux Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl +.. _`Linux Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl +.. _`MacOS Python 3.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp37-cp37m-macosx_10_6_intel.whl +.. _`MacOS Python 3.6`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-macosx_10_6_intel.whl +.. _`MacOS Python 3.5`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-macosx_10_6_intel.whl +.. _`MacOS Python 2.7`: https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27m-macosx_10_6_intel.whl Building Ray from source diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 664370eb0479..1d174ed72f92 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index b0cf426c1b1d..6e098d5218f6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open diff --git a/python/ray/autoscaler/aws/example-full.yaml b/python/ray/autoscaler/aws/example-full.yaml index 7399450aeedb..b3ecd22e7d5d 100644 --- a/python/ray/autoscaler/aws/example-full.yaml +++ b/python/ray/autoscaler/aws/example-full.yaml @@ -113,9 +113,9 @@ setup_commands: # has your Ray repo pre-cloned. Then, you can replace the pip installs # below with a git checkout (and possibly a recompile). 
- echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Consider uncommenting these if you also want to run apt-get commands during setup # - sudo pkill -9 apt-get || true # - sudo pkill -9 dpkg || true diff --git a/python/ray/autoscaler/aws/example-gpu-docker.yaml b/python/ray/autoscaler/aws/example-gpu-docker.yaml index 79fdc055b091..b63030a48344 100644 --- a/python/ray/autoscaler/aws/example-gpu-docker.yaml +++ b/python/ray/autoscaler/aws/example-gpu-docker.yaml @@ -105,9 +105,9 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: diff --git a/python/ray/autoscaler/gcp/example-full.yaml b/python/ray/autoscaler/gcp/example-full.yaml index 4ab2093dd865..c307f1b10103 100644 --- a/python/ray/autoscaler/gcp/example-full.yaml +++ b/python/ray/autoscaler/gcp/example-full.yaml @@ -127,9 +127,9 @@ setup_commands: && echo 'export PATH="$HOME/anaconda3/bin:$PATH"' >> ~/.profile # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. 
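The same dev-wheel bump is repeated across each autoscaler example in this patch. A quick way to confirm that a node actually picked up the new wheel after its setup commands ran is to compare the reported version string against the value set in ``python/ray/__init__.py`` above (an illustrative check, not part of the patch):

.. code-block:: python

    # Illustrative sanity check (not part of the patch): run on a node after
    # the setup_commands finish to verify the 0.8.0.dev1 wheel was installed.
    import ray

    assert ray.__version__ == "0.8.0.dev1", (
        "unexpected Ray version: {}".format(ray.__version__))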
diff --git a/python/ray/autoscaler/gcp/example-gpu-docker.yaml b/python/ray/autoscaler/gcp/example-gpu-docker.yaml index 75e0497094cb..5bb5eb9fb980 100644 --- a/python/ray/autoscaler/gcp/example-gpu-docker.yaml +++ b/python/ray/autoscaler/gcp/example-gpu-docker.yaml @@ -140,9 +140,9 @@ setup_commands: # - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_p36/bin:$PATH"' >> ~/.bashrc # Install ray - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp27-cp27mu-manylinux1_x86_64.whl - - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp35-cp35m-manylinux1_x86_64.whl - # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev0-cp36-cp36m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl + - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp35-cp35m-manylinux1_x86_64.whl + # - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl # Custom commands that will be run on the head node after common setup. head_setup_commands: diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 003e48370dbf..e75981d8b752 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -69,7 +69,7 @@ int main(int argc, char *argv[]) { // Initialize stats. const ray::stats::TagsType global_tags = { {ray::stats::JobNameKey, "raylet"}, - {ray::stats::VersionKey, "0.7.0"}, + {ray::stats::VersionKey, "0.7.1"}, {ray::stats::NodeAddressKey, node_ip_address}}; ray::stats::Init(stat_address, global_tags, disable_stats, enable_stdout_exporter); From cbc67fc75097b3fb977c20cccdb613e57fdd4e29 Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Thu, 6 Jun 2019 18:18:24 -0700 Subject: [PATCH 074/118] [doc] Update developer docs with bazel instructions (#4944) --- doc/source/development.rst | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/doc/source/development.rst b/doc/source/development.rst index 1fdc65fa35cf..ecbed6c31f9e 100644 --- a/doc/source/development.rst +++ b/doc/source/development.rst @@ -29,8 +29,11 @@ recompile much more quickly by doing .. code-block:: shell - cd ray/build - make -j8 + cd ray + bazel build //:ray_pkg + +This command is not enough to recompile all C++ unit tests. To do so, see +`Testing locally`_. Debugging --------- @@ -144,6 +147,14 @@ When running tests, usually only the first test failure matters. A single test failure often triggers the failure of subsequent tests in the same script. +To compile and run all C++ tests, you can run: + +.. 
code-block:: shell + + cd ray + bazel test $(bazel query 'kind(cc_test, ...)') + + Linting ------- From 5eff47b657dd8ed74fc1ff4b5e2339d60924d49e Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Fri, 7 Jun 2019 16:11:37 +0800 Subject: [PATCH 075/118] [C++] Add hash table to Redis-Module (#4911) --- BUILD.bazel | 2 +- doc/source/conf.py | 2 +- java/BUILD.bazel | 2 +- python/ray/gcs_utils.py | 4 +- python/ray/monitor.py | 6 +- python/ray/state.py | 22 +- python/ray/worker.py | 2 +- src/ray/gcs/client.cc | 3 + src/ray/gcs/client.h | 2 + src/ray/gcs/client_test.cc | 172 ++++++++++++++- src/ray/gcs/format/gcs.fbs | 8 +- src/ray/gcs/redis_module/ray_redis_module.cc | 217 +++++++++++++++---- src/ray/gcs/tables.cc | 162 +++++++++++++- src/ray/gcs/tables.h | 159 +++++++++++++- src/ray/object_manager/object_directory.cc | 16 +- 15 files changed, 686 insertions(+), 93 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 90b0f536a10b..36f02e292fa1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -535,7 +535,7 @@ flatbuffer_py_library( "ErrorTableData.py", "ErrorType.py", "FunctionTableData.py", - "GcsTableEntry.py", + "GcsEntry.py", "HeartbeatBatchTableData.py", "HeartbeatTableData.py", "Language.py", diff --git a/doc/source/conf.py b/doc/source/conf.py index b0ae3416d4ab..98fb3e0d02dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -29,7 +29,7 @@ "ray.core.generated.EntryType", "ray.core.generated.ErrorTableData", "ray.core.generated.ErrorType", - "ray.core.generated.GcsTableEntry", + "ray.core.generated.GcsEntry", "ray.core.generated.HeartbeatBatchTableData", "ray.core.generated.HeartbeatTableData", "ray.core.generated.Language", diff --git a/java/BUILD.bazel b/java/BUILD.bazel index f86df8d40f96..f3ae6f063304 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -160,7 +160,7 @@ flatbuffers_generated_files = [ "ErrorTableData.java", "ErrorType.java", "FunctionTableData.java", - "GcsTableEntry.java", + "GcsEntry.java", "HeartbeatBatchTableData.java", "HeartbeatTableData.java", "Language.java", diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 15eec6c81136..cadd197ec73f 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -9,7 +9,7 @@ from ray.core.generated.ClientTableData import ClientTableData from ray.core.generated.DriverTableData import DriverTableData from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsTableEntry import GcsTableEntry +from ray.core.generated.GcsEntry import GcsEntry from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData from ray.core.generated.HeartbeatTableData import HeartbeatTableData from ray.core.generated.Language import Language @@ -25,7 +25,7 @@ "ClientTableData", "DriverTableData", "ErrorTableData", - "GcsTableEntry", + "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", "Language", diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 09a154d7b548..c9e0424b3eb8 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,8 +101,7 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatBatchTableData. @@ -208,8 +207,7 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. 
data: The message data. """ - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) driver_data = gcs_entries.Entries(0) message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( driver_data, 0) diff --git a/python/ray/state.py b/python/ray/state.py index 6b2c8a4ef8bc..14ba49987ec4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -41,7 +41,7 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry(message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) ordered_client_ids = [] @@ -248,8 +248,7 @@ def _object_table(self, object_id): object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entry.EntriesLength() > 0 @@ -307,8 +306,7 @@ def _task_table(self, task_id): "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) assert gcs_entries.EntriesLength() == 1 @@ -431,8 +429,7 @@ def _profile_table(self, batch_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) profile_events = [] for i in range(gcs_entries.EntriesLength()): @@ -815,9 +812,8 @@ def available_resources(self): ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = ( - ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - data, 0)) + gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( + data, 0)) heartbeat_data = gcs_entries.Entries(0) message = (ray.gcs_utils.HeartbeatTableData. 
GetRootAsHeartbeatTableData(heartbeat_data, 0)) @@ -871,8 +867,7 @@ def _error_messages(self, driver_id): if message is None: return [] - gcs_entries = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) error_messages = [] for i in range(gcs_entries.EntriesLength()): error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( @@ -934,8 +929,7 @@ def actor_checkpoint_info(self, actor_id): ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( - message, 0) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) entry = ( ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( gcs_entry.Entries(0), 0)) diff --git a/python/ray/worker.py b/python/ray/worker.py index 7786c742d9b1..7505120574a6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1656,7 +1656,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsTableEntry.GetRootAsGcsTableEntry( + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( msg["data"], 0) assert gcs_entry.EntriesLength() == 1 error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index d9b5087c4719..3d1c6602740c 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -120,6 +120,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, profile_table_.reset(new ProfileTable(shard_contexts_, this)); actor_checkpoint_table_.reset(new ActorCheckpointTable(shard_contexts_, this)); actor_checkpoint_id_table_.reset(new ActorCheckpointIdTable(shard_contexts_, this)); + resource_table_.reset(new DynamicResourceTable({primary_context_}, this)); command_type_ = command_type; // TODO(swang): Call the client table's Connect() method here. To do this, @@ -229,6 +230,8 @@ ActorCheckpointIdTable &AsyncGcsClient::actor_checkpoint_id_table() { return *actor_checkpoint_id_table_; } +DynamicResourceTable &AsyncGcsClient::resource_table() { return *resource_table_; } + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index d47d9a6e8b24..c9f5b4bca624 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -62,6 +62,7 @@ class RAY_EXPORT AsyncGcsClient { ProfileTable &profile_table(); ActorCheckpointTable &actor_checkpoint_table(); ActorCheckpointIdTable &actor_checkpoint_id_table(); + DynamicResourceTable &resource_table(); // We also need something to export generic code to run on workers from the // driver (to set the PYTHONPATH) @@ -94,6 +95,7 @@ class RAY_EXPORT AsyncGcsClient { std::unique_ptr client_table_; std::unique_ptr actor_checkpoint_table_; std::unique_ptr actor_checkpoint_id_table_; + std::unique_ptr resource_table_; // The following contexts write to the data shard std::vector> shard_contexts_; std::vector> shard_asio_async_clients_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 1b43bcc23c08..4eb34a95328a 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -657,13 +657,12 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. 
auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::REMOVE); + ASSERT_EQ(change_mode, GcsChangeMode::REMOVE); } ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. @@ -894,10 +893,9 @@ void TestSetSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [object_id2, managers2]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. @@ -1111,10 +1109,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const GcsTableNotificationMode notification_mode, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, const std::vector &data) { - ASSERT_EQ(notification_mode, GcsTableNotificationMode::APPEND_OR_ADD); + ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because notifications @@ -1307,6 +1304,161 @@ TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { TestClientTableMarkDisconnected(driver_id_, client_); } +void TestHashTable(const DriverID &driver_id, + std::shared_ptr client) { + const int expected_count = 14; + ClientID client_id = ClientID::FromRandom(); + // Prepare the first resource map: data_map1. + auto cpu_data = std::make_shared(); + cpu_data->resource_name = "CPU"; + cpu_data->resource_capacity = 100; + auto gpu_data = std::make_shared(); + gpu_data->resource_name = "GPU"; + gpu_data->resource_capacity = 2; + DynamicResourceTable::DataMap data_map1; + data_map1.emplace("CPU", cpu_data); + data_map1.emplace("GPU", gpu_data); + // Prepare the second resource map: data_map2 which decreases CPU, + // increases GPU and add a new CUSTOM compared to data_map1. 
+ auto data_cpu = std::make_shared(); + data_cpu->resource_name = "CPU"; + data_cpu->resource_capacity = 50; + auto data_gpu = std::make_shared(); + data_gpu->resource_name = "GPU"; + data_gpu->resource_capacity = 10; + auto data_custom = std::make_shared(); + data_custom->resource_name = "CUSTOM"; + data_custom->resource_capacity = 2; + DynamicResourceTable::DataMap data_map2; + data_map2.emplace("CPU", data_cpu); + data_map2.emplace("GPU", data_gpu); + data_map2.emplace("CUSTOM", data_custom); + data_map2["CPU"]->resource_capacity = 50; + // This is a common comparison function for the test. + auto compare_test = [](const DynamicResourceTable::DataMap &data1, + const DynamicResourceTable::DataMap &data2) { + ASSERT_EQ(data1.size(), data2.size()); + for (const auto &data : data1) { + auto iter = data2.find(data.first); + ASSERT_TRUE(iter != data2.end()); + ASSERT_EQ(iter->second->resource_name, data.second->resource_name); + ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + } + }; + auto subscribe_callback = [](AsyncGcsClient *client) { + ASSERT_TRUE(true); + test->IncrementNumCallbacks(); + }; + auto notification_callback = [data_map1, data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode, + const DynamicResourceTable::DataMap &data) { + if (change_mode == GcsChangeMode::REMOVE) { + ASSERT_EQ(data.size(), 2); + ASSERT_TRUE(data.find("GPU") != data.end()); + ASSERT_TRUE(data.find("CUSTOM") != data.end() || data.find("CPU") != data.end()); + // The key "None-Existent" will not appear in the notification. + } else { + if (data.size() == 2) { + compare_test(data_map1, data); + } else if (data.size() == 3) { + compare_test(data_map2, data); + } else { + ASSERT_TRUE(false); + } + } + test->IncrementNumCallbacks(); + // It is not sure which of the notification or lookup callback will come first. + if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + // Step 0: Subscribe the change of the hash table. + RAY_CHECK_OK(client->resource_table().Subscribe( + driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->resource_table().RequestNotifications( + driver_id, client_id, client->client_table().GetLocalClientId())); + + // Step 1: Add elements to the hash table. + auto update_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback1 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1)); + + // Step 2: Decrease one element, increase one and add a new one. 
+ RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr)); + auto lookup_callback2 = [data_map2, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map2, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2)); + std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); + auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id, + const std::vector &callback_data) { + for (int i = 0; i < callback_data.size(); ++i) { + // All deleting keys exist in this argument even if the key doesn't exist. + ASSERT_EQ(callback_data[i], delete_keys[i]); + } + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys, + remove_callback)); + DynamicResourceTable::DataMap data_map3(data_map2); + data_map3.erase("GPU"); + data_map3.erase("CUSTOM"); + auto lookup_callback3 = [data_map3, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map3, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3)); + + // Step 3: Reset the the resources to data_map1. + RAY_CHECK_OK( + client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + auto lookup_callback4 = [data_map1, compare_test]( + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + compare_test(data_map1, callback_data); + test->IncrementNumCallbacks(); + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4)); + + // Step 4: Removing all elements will remove the home Hash table from GCS. + RAY_CHECK_OK(client->resource_table().RemoveEntries( + driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); + auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { + ASSERT_EQ(callback_data.size(), 0); + test->IncrementNumCallbacks(); + // It is not sure which of notification or lookup callback will come first. + if (test->NumCallbacks() == expected_count) { + test->Stop(); + } + }; + RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5)); + test->Start(); + ASSERT_EQ(test->NumCallbacks(), expected_count); +} + +TEST_F(TestGcsWithAsio, TestHashTable) { + test = this; + TestHashTable(driver_id_, client_); +} + #undef TEST_MACRO } // namespace gcs diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index b81f388d88c5..614c80b27672 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -22,6 +22,7 @@ enum TablePrefix:int { TASK_LEASE, ACTOR_CHECKPOINT, ACTOR_CHECKPOINT_ID, + NODE_RESOURCE, } // The channel that Add operations to the Table should be published on, if any. 
@@ -37,6 +38,7 @@ enum TablePubsub:int { ERROR_INFO, TASK_LEASE, DRIVER, + NODE_RESOURCE, } // Enum for the entry type in the ClientTable @@ -113,13 +115,13 @@ table ResourcePair { value: double; } -enum GcsTableNotificationMode:int { +enum GcsChangeMode:int { APPEND_OR_ADD = 0, REMOVE, } -table GcsTableEntry { - notification_mode: GcsTableNotificationMode; +table GcsEntry { + change_mode: GcsChangeMode; id: string; entries: [string]; } diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 23e611e400df..e059787472f1 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -179,32 +179,20 @@ flatbuffers::Offset RedisStringToFlatbuf( return fbb.CreateString(redis_string_str, redis_string_size); } -/// Publish a notification for an entry update at a key. This publishes a -/// notification to all subscribers of the table, as well as every client that -/// has requested notifications for this key. +/// Helper method to publish formatted data to target channel. /// /// \param pubsub_channel_str The pubsub channel name that notifications for /// this key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key that the notification is about. -/// \param mode the update mode, such as append or remove. -/// \param data The appended/removed data. +/// \param data_buffer The data to publish, which is a GcsEntry buffer. /// \return OK if there is no error during a publish. -int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, - RedisModuleString *id, GcsTableNotificationMode notification_mode, - RedisModuleString *data) { - // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = - CreateGcsTableEntry(fbb, notification_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - +int PublishDataHelper(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, RedisModuleString *data_buffer) { // Write the data back to any subscribers that are listening to all table // notifications. - RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "sb", pubsub_channel_str, - fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", pubsub_channel_str, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -221,8 +209,8 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st // will be garbage collected by redis. auto channel = RedisModule_CreateString(ctx, client_channel.data(), client_channel.size()); - RedisModuleCallReply *reply = RedisModule_Call( - ctx, "PUBLISH", "sb", channel, fbb.GetBufferPointer(), fbb.GetSize()); + RedisModuleCallReply *reply = + RedisModule_Call(ctx, "PUBLISH", "ss", channel, data_buffer); if (reply == NULL) { return RedisModule_ReplyWithError(ctx, "error during PUBLISH"); } @@ -231,6 +219,31 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return RedisModule_ReplyWithSimpleString(ctx, "OK"); } +/// Publish a notification for an entry update at a key. This publishes a +/// notification to all subscribers of the table, as well as every client that +/// has requested notifications for this key. 
+/// +/// \param pubsub_channel_str The pubsub channel name that notifications for +/// this key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key that the notification is about. +/// \param mode the update mode, such as append or remove. +/// \param data The appended/removed data. +/// \return OK if there is no error during a publish. +int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_str, + RedisModuleString *id, GcsChangeMode change_mode, + RedisModuleString *data) { + // Serialize the notification to send. + flatbuffers::FlatBufferBuilder fbb; + auto data_flatbuf = RedisStringToFlatbuf(fbb, data); + auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), + fbb.CreateVector(&data_flatbuf, 1)); + fbb.Finish(message); + auto data_buffer = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); +} + // RAY.TABLE_ADD: // TableAdd_RedisCommand: the actual command handler. // (helper) TableAdd_DoWrite: performs the write to redis state. @@ -266,8 +279,8 @@ int TableAdd_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -366,8 +379,8 @@ int TableAppend_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, int /*a if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - GcsTableNotificationMode::APPEND_OR_ADD, data); + return PublishTableUpdate(ctx, pubsub_channel_str, id, GcsChangeMode::APPEND_OR_ADD, + data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -419,10 +432,9 @@ int Set_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv, bool is_add) { if (pubsub_channel != TablePubsub::NO_PUBLISH) { // All other pubsub channels write the data back directly onto the // channel. - return PublishTableUpdate(ctx, pubsub_channel_str, id, - is_add ? GcsTableNotificationMode::APPEND_OR_ADD - : GcsTableNotificationMode::REMOVE, - data); + return PublishTableUpdate( + ctx, pubsub_channel_str, id, + is_add ? GcsChangeMode::APPEND_OR_ADD : GcsChangeMode::REMOVE, data); } else { return RedisModule_ReplyWithSimpleString(ctx, "OK"); } @@ -518,7 +530,125 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar return RedisModule_ReplyWithSimpleString(ctx, "OK"); } -/// A helper function to create and finish a GcsTableEntry, based on the +int Hash_DoPublish(RedisModuleCtx *ctx, RedisModuleString **argv) { + RedisModuleString *pubsub_channel_str = argv[2]; + RedisModuleString *id = argv[3]; + RedisModuleString *data = argv[4]; + // Publish a message on the requested pubsub channel if necessary. + TablePubsub pubsub_channel; + REPLY_AND_RETURN_IF_NOT_OK(ParseTablePubsub(&pubsub_channel, pubsub_channel_str)); + if (pubsub_channel != TablePubsub::NO_PUBLISH) { + // All other pubsub channels write the data back directly onto the + // channel. 
+ return PublishDataHelper(ctx, pubsub_channel_str, id, data); + } else { + return RedisModule_ReplyWithSimpleString(ctx, "OK"); + } +} + +/// Do the hash table write operation. This is called from by HashUpdate_RedisCommand. +/// +/// \param change_mode Output the mode of the operation: APPEND_OR_ADD or REMOVE. +/// \param deleted_data Output data if the deleted data is not the same as required. +int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, + GcsChangeMode *change_mode, RedisModuleString **changed_data) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + RedisModuleString *prefix_str = argv[1]; + RedisModuleString *id = argv[3]; + RedisModuleString *update_data = argv[4]; + + RedisModuleKey *key; + REPLY_AND_RETURN_IF_NOT_OK(OpenPrefixedKey( + &key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE, nullptr)); + int type = RedisModule_KeyType(key); + REPLY_AND_RETURN_IF_FALSE( + type == REDISMODULE_KEYTYPE_HASH || type == REDISMODULE_KEYTYPE_EMPTY, + "HashUpdate_DoWrite: entries must be a hash or an empty hash"); + + size_t update_data_len = 0; + const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); + + auto data_vec = flatbuffers::GetRoot(update_data_buf); + *change_mode = data_vec->change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { + // This code path means they are updating command. + size_t total_size = data_vec->entries()->size(); + REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); + for (int i = 0; i < total_size; i += 2) { + // Reconstruct a key-value pair from a flattened list. + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + RedisModuleString *entry_value = + RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), + data_vec->entries()->Get(i + 1)->size()); + // Returning 0 if key exists(still updated), 1 if the key is created. + RAY_IGNORE_EXPR( + RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); + } + *changed_data = update_data; + } else { + // This code path means the command wants to remove the entries. + size_t total_size = data_vec->entries()->size(); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data; + for (int i = 0; i < total_size; i++) { + RedisModuleString *entry_key = RedisModule_CreateString( + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, + REDISMODULE_HASH_DELETE, NULL); + if (deleted_num != 0) { + // The corresponding key is removed. + data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), + data_vec->entries()->Get(i)->size())); + } + } + auto message = + CreateGcsEntry(fbb, data_vec->change_mode(), + fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), + fbb.CreateVector(data)); + fbb.Finish(message); + *changed_data = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + auto size = RedisModule_ValueLength(key); + if (size == 0) { + REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, + "ERR Failed to delete empty hash."); + } + } + return REDISMODULE_OK; +} + +/// Update entries for a hash table. +/// +/// This is called from a client with the command: +// +/// RAY.HASH_UPDATE +/// +/// \param table_prefix The prefix string for keys in this table. 
+/// \param pubsub_channel The pubsub channel name that notifications for this +/// key should be published to. When publishing to a specific client, the +/// channel name should be :. +/// \param id The ID of the key to remove from. +/// \param data The GcsEntry flatbugger data used to update this hash table. +/// 1). For deletion, this is a list of keys. +/// 2). For updating, this is a list of pairs with each key followed by the value. +/// \return OK if the remove succeeds, or an error message string if the remove +/// fails. +int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) { + GcsChangeMode mode; + RedisModuleString *changed_data = nullptr; + if (HashUpdate_DoWrite(ctx, argv, argc, &mode, &changed_data) != REDISMODULE_OK) { + return REDISMODULE_ERR; + } + // Replace the data with the changed data to do the publish. + std::vector new_argv(argv, argv + argc); + new_argv[4] = changed_data; + return Hash_DoPublish(ctx, new_argv.data()); +} + +/// A helper function to create and finish a GcsEntry, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -528,7 +658,7 @@ int SetRemove_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int ar /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsTableEntry. +/// \param fbb A flatbuffer builder used to build the GcsEntry. Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, RedisModuleString *prefix_str, RedisModuleString *entry_id, flatbuffers::FlatBufferBuilder &fbb) { @@ -539,12 +669,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); auto data = fbb.CreateString(data_buf, data_len); - auto message = CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(&data, 1)); + auto message = + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_LIST: + case REDISMODULE_KEYTYPE_HASH: case REDISMODULE_KEYTYPE_SET: { RedisModule_CloseKey(table_key); // Close the key before executing the command. NOTE(swang): According to @@ -561,10 +692,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, case REDISMODULE_KEYTYPE_SET: reply = RedisModule_Call(ctx, "SMEMBERS", "s", table_key_str); break; + case REDISMODULE_KEYTYPE_HASH: + reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); + break; } // Build the flatbuffer from the set of log entries. 
if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { - return Status::RedisError("Empty list or wrong type"); + return Status::RedisError("Empty list/set/hash or wrong type"); } std::vector> data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { @@ -574,13 +708,13 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, data.push_back(fbb.CreateString(element_str, len)); } auto message = - CreateGcsTableEntry(fbb, GcsTableNotificationMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsTableEntry( - fbb, GcsTableNotificationMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), + auto message = CreateGcsEntry( + fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(std::vector>())); fbb.Finish(message); } break; @@ -637,6 +771,7 @@ static Status DeleteKeyHelper(RedisModuleCtx *ctx, RedisModuleString *prefix_str return Status::RedisError("Key does not exist."); } auto key_type = RedisModule_KeyType(delete_key); + // Set/Hash will delete itself when the length is 0. if (key_type == REDISMODULE_KEYTYPE_STRING || key_type == REDISMODULE_KEYTYPE_LIST) { // Current Table or Log only has this two types of entries. RAY_RETURN_NOT_OK( @@ -873,6 +1008,7 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int // Wrap all Redis commands with Redis' auto memory management. AUTO_MEMORY(TableAdd_RedisCommand); +AUTO_MEMORY(HashUpdate_RedisCommand); AUTO_MEMORY(TableAppend_RedisCommand); AUTO_MEMORY(SetAdd_RedisCommand); AUTO_MEMORY(SetRemove_RedisCommand); @@ -929,6 +1065,11 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.hash_update", HashUpdate_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.table_request_notifications", TableRequestNotifications_RedisCommand, "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index ffc44daa049a..e20384a04bdc 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -92,7 +92,7 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, std::vector results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { DataT result; @@ -114,9 +114,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, - const GcsTableNotificationMode notification_mode, + const GcsChangeMode change_mode, const std::vector &data) { - RAY_CHECK(notification_mode != GcsTableNotificationMode::REMOVE); + RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; return Subscribe(driver_id, client_id, subscribe_wrapper, done); @@ -141,7 +141,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. 
if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); + auto root = flatbuffers::GetRoot(data.data()); ID id; if (root->id()->size() > 0) { id = from_flatbuf(*root->id()); @@ -153,7 +153,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien data_root->UnPackTo(&result); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->notification_mode(), results); + subscribe(client_, id, root->change_mode(), results); } } }; @@ -339,6 +339,155 @@ std::string Set::DebugString() const { return result.str(); } +template +Status Hash::Update(const DriverID &driver_id, const ID &id, + const DataMap &data_map, const HashCallback &done) { + num_adds_++; + auto callback = [this, id, data_map, done](const CallbackReply &reply) { + if (done != nullptr) { + (done)(client_, id, data_map); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(data_map.size() * 2); + for (auto const &pair : data_map) { + // Add the key. + data_vec.push_back(fbb.CreateString(pair.first)); + flatbuffers::FlatBufferBuilder fbb_data; + fbb_data.ForceDefaults(true); + fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); + std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), + fbb_data.GetSize()); + // Add the value. + data_vec.push_back(fbb.CreateString(data)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) { + num_removes_++; + auto callback = [this, id, keys, remove_callback](const CallbackReply &reply) { + if (remove_callback != nullptr) { + (remove_callback)(client_, id, keys); + } + }; + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(keys.size()); + // Add the keys. 
+ for (auto const &key : keys) { + data_vec.push_back(fbb.CreateString(key)); + } + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), + fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); +} + +template +std::string Hash::DebugString() const { + std::stringstream result; + result << "num lookups: " << num_lookups_ << ", num adds: " << num_adds_ + << ", num removes: " << num_removes_; + return result.str(); +} + +template +Status Hash::Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) { + num_lookups_++; + auto callback = [this, id, lookup](const CallbackReply &reply) { + if (lookup != nullptr) { + DataMap results; + if (!reply.IsNil()) { + const auto data = reply.ReadAsString(); + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + results.emplace(key, std::move(result)); + } + } + lookup(client_, id, results); + } + }; + std::vector nil; + return GetRedisContext(id)->RunAsync("RAY.TABLE_LOOKUP", id, nil.data(), nil.size(), + prefix_, pubsub_channel_, std::move(callback)); +} + +template +Status Hash::Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) { + RAY_CHECK(subscribe_callback_index_ == -1) + << "Client called Subscribe twice on the same table"; + auto callback = [this, subscribe, done](const CallbackReply &reply) { + const auto data = reply.ReadAsPubsubData(); + if (data.empty()) { + // No notification data is provided. This is the callback for the + // initial subscription request. + if (done != nullptr) { + done(client_); + } + } else { + // Data is provided. This is the callback for a message. + if (subscribe != nullptr) { + // Parse the notification. 
+ auto root = flatbuffers::GetRoot(data.data()); + DataMap data_map; + ID id; + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); + } + if (root->change_mode() == GcsChangeMode::REMOVE) { + for (size_t i = 0; i < root->entries()->size(); i++) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + data_map.emplace(key, std::shared_ptr()); + } + } else { + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + data_map.emplace(key, std::move(result)); + } + } + subscribe(client_, id, root->change_mode(), data_map); + } + } + }; + + subscribe_callback_index_ = 1; + for (auto &context : shard_contexts_) { + RAY_RETURN_NOT_OK(context->SubscribeAsync(client_id, pubsub_channel_, callback, + &subscribe_callback_index_)); + } + return Status::OK(); +} + Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); @@ -696,6 +845,9 @@ template class Log; template class Table; template class Table; +template class Log; +template class Hash; + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index af42509bda96..6a1d502a7f54 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -75,9 +75,9 @@ class Log : public LogInterface, virtual public PubsubInterface { using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -214,7 +214,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// to subscribe to all modifications, or to subscribe only to keys that it /// requests notifications for. This may only be called once per Log /// instance. This function is different from public version due to - /// an additional parameter notification_mode in NotificationCallback. Therefore this + /// an additional parameter change_mode in NotificationCallback. Therefore this /// function supports notifications of remove operations. /// /// \param driver_id The ID of the job (= driver). @@ -451,6 +451,157 @@ class Set : private Log, using Log::num_lookups_; }; +template +class HashInterface { + public: + using DataT = typename Data::NativeTableType; + using DataMap = std::unordered_map>; + // Reuse Log's SubscriptionCallback when Subscribe is successfully called. + using SubscriptionCallback = typename Log::SubscriptionCallback; + + /// The callback function used by function Update & Lookup. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. + /// \param data Map data contains the change to the Hash Table. + /// \return Void + using HashCallback = + std::function; + + /// The callback function used by function RemoveEntries. + /// + /// \param client The client on which the RemoveEntries is called. + /// \param id The ID of the Hash Table whose entries are removed. 
+  /// \param keys The keys that are removed from this Hash Table.
+  /// \return Void
+  using HashRemoveCallback = std::function<void(AsyncGcsClient *client, const ID &id,
+                                                const std::vector<std::string> &keys)>;
+
+  /// The notification function used by function Subscribe.
+  ///
+  /// \param client The client on which the Subscribe is called.
+  /// \param id The ID of the Hash Table that was changed.
+  /// \param change_mode The mode that identifies whether the data was removed or
+  /// updated.
+  /// \param data Map data contains the change to the Hash Table.
+  /// \return Void
+  using HashNotificationCallback =
+      std::function<void(AsyncGcsClient *client, const ID &id,
+                         const GcsChangeMode change_mode, const DataMap &data)>;
+
+  /// Add entries to a hash table.
+  ///
+  /// \param driver_id The ID of the job (= driver).
+  /// \param id The ID of the data that is added to the GCS.
+  /// \param pairs Map data to add to the hash table.
+  /// \param done HashCallback that is called once the request data has been written to
+  /// the GCS.
+  /// \return Status
+  virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
+                        const HashCallback &done) = 0;
+
+  /// Remove entries from the hash table.
+  ///
+  /// \param driver_id The ID of the job (= driver).
+  /// \param id The ID of the data that is removed from the GCS.
+  /// \param keys The entry keys of the hash table.
+  /// \param remove_callback HashRemoveCallback that is called once the data has been
+  /// written to the GCS, whether or not the keys existed in the hash table.
+  /// \return Status
+  virtual Status RemoveEntries(const DriverID &driver_id, const ID &id,
+                               const std::vector<std::string> &keys,
+                               const HashRemoveCallback &remove_callback) = 0;
+
+  /// Lookup the map data of a hash table.
+  ///
+  /// \param driver_id The ID of the job (= driver).
+  /// \param id The ID of the data that is looked up in the GCS.
+  /// \param lookup HashCallback that is called after lookup. If the callback is
+  /// called with an empty hash table, then there was no data at the given ID.
+  /// \return Status
+  virtual Status Lookup(const DriverID &driver_id, const ID &id,
+                        const HashCallback &lookup) = 0;
+
+  /// Subscribe to any Update or Remove operations to this hash table.
+  ///
+  /// \param driver_id The ID of the driver.
+  /// \param client_id The client to listen to. If this is nil, then a
+  /// message for each Update to the table will be received. Else, only
+  /// messages for the given client will be received. In the latter
+  /// case, the client may request notifications on specific keys in the
+  /// table via `RequestNotifications`.
+  /// \param subscribe HashNotificationCallback that is called on each received message.
+  /// \param done SubscriptionCallback that is called when subscription is complete and
+  /// we are ready to receive messages.
+ /// \return Status + virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) = 0; + + virtual ~HashInterface(){}; +}; + +template +class Hash : private Log, + public HashInterface, + virtual public PubsubInterface { + public: + using DataT = typename Log::DataT; + using DataMap = std::unordered_map>; + using HashCallback = typename HashInterface::HashCallback; + using HashRemoveCallback = typename HashInterface::HashRemoveCallback; + using HashNotificationCallback = + typename HashInterface::HashNotificationCallback; + using SubscriptionCallback = typename Log::SubscriptionCallback; + + Hash(const std::vector> &contexts, AsyncGcsClient *client) + : Log(contexts, client) {} + + using Log::RequestNotifications; + using Log::CancelNotifications; + + Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + const HashCallback &done) override; + + Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + const HashNotificationCallback &subscribe, + const SubscriptionCallback &done) override; + + Status Lookup(const DriverID &driver_id, const ID &id, + const HashCallback &lookup) override; + + Status RemoveEntries(const DriverID &driver_id, const ID &id, + const std::vector &keys, + const HashRemoveCallback &remove_callback) override; + + /// Returns debug string for class. + /// + /// \return string. + std::string DebugString() const; + + protected: + using Log::shard_contexts_; + using Log::client_; + using Log::pubsub_channel_; + using Log::prefix_; + using Log::subscribe_callback_index_; + using Log::GetRedisContext; + + int64_t num_adds_ = 0; + int64_t num_removes_ = 0; + using Log::num_lookups_; +}; + +class DynamicResourceTable : public Hash { + public: + DynamicResourceTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Hash(contexts, client) { + pubsub_channel_ = TablePubsub::NODE_RESOURCE; + prefix_ = TablePrefix::NODE_RESOURCE; + }; + + virtual ~DynamicResourceTable(){}; +}; + class ObjectTable : public Set { public: ObjectTable(const std::vector> &contexts, diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 1f05559f4b87..d2496dceb8bf 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -11,16 +11,16 @@ namespace { /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. -void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, +void UpdateObjectLocations(const GcsChangeMode change_mode, const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. - // with GcsTableNotificationMode, we can determine whether the update mode is + // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. 
for (const auto &object_table_data : location_updates) { ClientID client_id = ClientID::FromBinary(object_table_data.manager); - if (notification_mode != GcsTableNotificationMode::REMOVE) { + if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { client_ids->erase(client_id); @@ -41,7 +41,7 @@ void UpdateObjectLocations(const GcsTableNotificationMode notification_mode, void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this]( gcs::AsyncGcsClient *client, const ObjectID &object_id, - const GcsTableNotificationMode notification_mode, + const GcsChangeMode change_mode, const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); @@ -54,8 +54,7 @@ void ObjectDirectory::RegisterBackend() { it->second.subscribed = true; // Update entries for this object. - UpdateObjectLocations(notification_mode, location_updates, - gcs_client_->client_table(), + UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), &it->second.current_object_locations); // Copy the callbacks so that the callbacks can unsubscribe without interrupting // looping over the callbacks. @@ -135,8 +134,7 @@ void ObjectDirectory::HandleClientRemoved(const ClientID &client_id) { if (listener.second.current_object_locations.count(client_id) > 0) { // If the subscribed object has the removed client as a location, update // its locations with an empty update so that the location will be removed. - UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, {}, - gcs_client_->client_table(), + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, {}, gcs_client_->client_table(), &listener.second.current_object_locations); // Re-call all the subscribed callbacks for the object, since its // locations have changed. @@ -213,7 +211,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; - UpdateObjectLocations(GcsTableNotificationMode::APPEND_OR_ADD, location_updates, + UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, gcs_client_->client_table(), &client_ids); // It is safe to call the callback directly since this is already running // in the GCS client's lookup callback stack. 
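The Hash interface introduced by the patch above is easiest to understand from the caller's side. The following sketch is illustrative only and is not part of the patch: the helper name WriteAndWatch, the include paths, the "CPU"/"GPU" keys, and the assumption that the table is keyed by ClientID (as the new DynamicResourceTable is) are all placeholders rather than code from the repository.

#include <memory>
#include <string>
#include <unordered_map>

#include "ray/gcs/tables.h"

namespace ray {
namespace gcs {

// Illustrative sketch: drive a Hash table keyed by ClientID, such as the
// DynamicResourceTable added above. DataT stands for the table's flatbuffer
// NativeTableType.
template <typename HashTable, typename DataT>
Status WriteAndWatch(HashTable &table, const DriverID &driver_id,
                     const ClientID &node_id) {
  using DataMap = std::unordered_map<std::string, std::shared_ptr<DataT>>;

  // Subscribe once to every Update/RemoveEntries on the table. REMOVE
  // notifications carry the removed keys with null values; APPEND_OR_ADD
  // notifications carry the updated key/value pairs.
  RAY_RETURN_NOT_OK(table.Subscribe(
      driver_id, ClientID::Nil(),
      [](AsyncGcsClient *client, const ClientID &id,
         const GcsChangeMode change_mode, const DataMap &data) {
        // Caller-defined handling of the change goes here.
      },
      /*done=*/nullptr));

  // Upsert two entries under this node's ID.
  DataMap pairs;
  pairs["CPU"] = std::make_shared<DataT>();
  pairs["GPU"] = std::make_shared<DataT>();
  RAY_RETURN_NOT_OK(table.Update(driver_id, node_id, pairs, /*done=*/nullptr));

  // Remove one of the entries; the remove callback fires whether or not the
  // key existed in the hash table.
  RAY_RETURN_NOT_OK(table.RemoveEntries(driver_id, node_id, {"GPU"},
                                        /*remove_callback=*/nullptr));
  return Status::OK();
}

}  // namespace gcs
}  // namespace ray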
From 873d45b46750f01e9ded84f9ee8d05328535c75d Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Fri, 7 Jun 2019 11:35:18 -0700 Subject: [PATCH 076/118] Flush lineage cache on task submission instead of execution (#4942) --- src/ray/raylet/lineage_cache.cc | 124 ++++---------------- src/ray/raylet/lineage_cache.h | 67 +++++------ src/ray/raylet/lineage_cache_test.cc | 163 +++++++++------------------ src/ray/raylet/node_manager.cc | 64 +++-------- 4 files changed, 121 insertions(+), 297 deletions(-) diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 910c3481bf58..795f2b54a6cb 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -68,7 +68,7 @@ Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { auto tasks = task_request.uncommitted_tasks(); for (auto it = tasks->begin(); it != tasks->end(); it++) { const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED_REMOTE)); + RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); } } @@ -108,38 +108,23 @@ bool Lineage::SetEntry(const Task &task, GcsStatus status) { auto task_id = task.GetTaskSpecification().TaskId(); auto it = entries_.find(task_id); bool updated = false; - std::unordered_set old_parents; if (it != entries_.end()) { if (it->second.SetStatus(status)) { - // The task's spec may have changed, so record its old dependencies. - old_parents = it->second.GetParentTaskIds(); - // SetStatus() would check if the new status is greater, - // if it succeeds, go ahead to update the task field. - it->second.UpdateTaskData(task); + // We assume here that the new `task` has the same fields as the task + // already in the lineage cache. If this is not true, then it is + // necessary to update the task data of the existing lineage cache entry + // with LineageEntry::UpdateTaskData. updated = true; } } else { LineageEntry new_entry(task, status); it = entries_.emplace(std::make_pair(task_id, std::move(new_entry))).first; updated = true; - } - // If the task data was updated, then record which tasks it depends on. Add - // all new tasks that it depends on and remove any old tasks that it no - // longer depends on. - // TODO(swang): Updating the task data every time could be inefficient for - // tasks that have lots of dependencies and/or large specs. A flag could be - // passed in for tasks whose data has not changed. - if (updated) { + // New task data was added to the local cache, so record which tasks it + // depends on. Add all new tasks that it depends on. for (const auto &parent_id : it->second.GetParentTaskIds()) { - if (old_parents.count(parent_id) == 0) { - AddChild(parent_id, task_id); - } else { - old_parents.erase(parent_id); - } - } - for (const auto &old_parent_id : old_parents) { - RemoveChild(old_parent_id, task_id); + AddChild(parent_id, task_id); } } return updated; @@ -198,15 +183,15 @@ LineageCache::LineageCache(const ClientID &client_id, /// A helper function to add some uncommitted lineage to the local cache. void LineageCache::AddUncommittedLineage(const TaskID &task_id, - const Lineage &uncommitted_lineage, - std::unordered_set &subscribe_tasks) { + const Lineage &uncommitted_lineage) { + RAY_LOG(DEBUG) << "Adding uncommitted task " << task_id << " on " << client_id_; // If the entry is not found in the lineage to merge, then we stop since // there is nothing to copy into the merged lineage. 
auto entry = uncommitted_lineage.GetEntry(task_id); if (!entry) { return; } - RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_REMOTE); + RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED); // Insert a copy of the entry into our cache. const auto &parent_ids = entry->GetParentTaskIds(); @@ -214,90 +199,34 @@ void LineageCache::AddUncommittedLineage(const TaskID &task_id, // if the new entry has an equal or lower GCS status than the current entry // in our cache. This also prevents us from traversing the same node twice. if (lineage_.SetEntry(entry->TaskData(), entry->GetStatus())) { - subscribe_tasks.insert(task_id); + RAY_CHECK(SubscribeTask(task_id)); for (const auto &parent_id : parent_ids) { - AddUncommittedLineage(parent_id, uncommitted_lineage, subscribe_tasks); + AddUncommittedLineage(parent_id, uncommitted_lineage); } } } -bool LineageCache::AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage) { - auto task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Add waiting task " << task_id << " on " << client_id_; - - // Merge the uncommitted lineage into the lineage cache. Collect the IDs of - // tasks that we should subscribe to. These are all of the tasks that were - // included in the uncommitted lineage that we did not already have in our - // stash. - std::unordered_set subscribe_tasks; - AddUncommittedLineage(task_id, uncommitted_lineage, subscribe_tasks); - // Add the submitted task to the lineage cache as UNCOMMITTED_WAITING. It - // should be marked as UNCOMMITTED_READY once the task starts execution. - auto added = lineage_.SetEntry(task, GcsStatus::UNCOMMITTED_WAITING); - - // Do not subscribe to the waiting task itself. We just added it as - // UNCOMMITTED_WAITING, so the task is local. - subscribe_tasks.erase(task_id); - // Unsubscribe to the waiting task since we may have previously been - // subscribed to it. - UnsubscribeTask(task_id); - // Subscribe to all other tasks that were included in the uncommitted lineage - // and that were not already in the local stash. These tasks haven't been - // committed yet and will be committed by a different node, so we will not - // evict them until a notification for their commit is received. - for (const auto &task_id : subscribe_tasks) { - RAY_CHECK(SubscribeTask(task_id)); - } - - return added; -} - -bool LineageCache::AddReadyTask(const Task &task) { +bool LineageCache::CommitTask(const Task &task) { const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(DEBUG) << "Add ready task " << task_id << " on " << client_id_; + RAY_LOG(DEBUG) << "Committing task " << task_id << " on " << client_id_; - // Set the task to READY. - if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED_READY)) { - // Attempt to flush the task. + if (lineage_.SetEntry(task, GcsStatus::UNCOMMITTED) || + lineage_.GetEntry(task_id)->GetStatus() == GcsStatus::UNCOMMITTED) { + // Attempt to flush the task if the task is uncommitted. FlushTask(task_id); return true; } else { - // The task was already ready to be committed (UNCOMMITTED_READY) or - // committing (COMMITTING). - return false; - } -} - -bool LineageCache::RemoveWaitingTask(const TaskID &task_id) { - RAY_LOG(DEBUG) << "Remove waiting task " << task_id << " on " << client_id_; - auto entry = lineage_.GetEntryMutable(task_id); - if (!entry) { - // The task was already evicted. - return false; - } - - // If the task is already not in WAITING status, then exit. 
This should only - // happen when there are two copies of the task executing at the node, due to - // a spurious reconstruction. Then, either the task is already past WAITING - // status, in which case it will be committed, or it is in - // UNCOMMITTED_REMOTE, in which case it was already removed. - if (entry->GetStatus() != GcsStatus::UNCOMMITTED_WAITING) { + // The task was already committing (COMMITTING). return false; } - - // Reset the status to REMOTE. We keep the task instead of removing it - // completely in case another task is submitted locally that depends on this - // one. - entry->ResetStatus(GcsStatus::UNCOMMITTED_REMOTE); - // The task is now remote, so subscribe to the task to make sure that we'll - // eventually clean it up. - RAY_CHECK(SubscribeTask(task_id)); - return true; } void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { RAY_CHECK(!node_id.IsNil()); - lineage_.GetEntryMutable(task_id)->MarkExplicitlyForwarded(node_id); + auto entry = lineage_.GetEntryMutable(task_id); + if (entry) { + entry->MarkExplicitlyForwarded(node_id); + } } /// A helper function to get the uncommitted lineage of a task. @@ -345,7 +274,7 @@ Lineage LineageCache::GetUncommittedLineageOrDie(const TaskID &task_id, void LineageCache::FlushTask(const TaskID &task_id) { auto entry = lineage_.GetEntryMutable(task_id); RAY_CHECK(entry); - RAY_CHECK(entry->GetStatus() == GcsStatus::UNCOMMITTED_READY); + RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING); gcs::raylet::TaskTable::WriteCallback task_callback = [this]( ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { @@ -406,11 +335,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { if (!entry) { return; } - // Only evict tasks that we were subscribed to or that we were committing. - if (!(entry->GetStatus() == GcsStatus::UNCOMMITTED_REMOTE || - entry->GetStatus() == GcsStatus::COMMITTING)) { - return; - } // Entries cannot be safely evicted until their parents are all evicted. for (const auto &parent_id : entry->GetParentTaskIds()) { if (ContainsTask(parent_id)) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 02d98b8cffe6..2dff0e94a4d1 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -17,19 +17,23 @@ namespace ray { namespace raylet { /// The status of a lineage cache entry according to its status in the GCS. +/// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state +/// can become COMMITTING but not vice versa). If a task is evicted from the +/// local cache, it implicitly goes back to state `NONE`, after which it may be +/// added to the local cache again (e.g., if it is forwarded to us again). enum class GcsStatus { /// The task is not in the lineage cache. NONE = 0, - /// The task is being executed or created on a remote node. - UNCOMMITTED_REMOTE, - /// The task is waiting to be executed or created locally. - UNCOMMITTED_WAITING, - /// The task has started execution, but the entry has not been written to the - /// GCS yet. - UNCOMMITTED_READY, - /// The task has been written to the GCS and we are waiting for an - /// acknowledgement of the commit. + /// The task is uncommitted. Unless there is a failure, we will expect a + /// different node to commit this task. + UNCOMMITTED, + /// We flushed this task and are waiting for the commit acknowledgement. 
COMMITTING, + // TODO(swang): Add a COMMITTED state for tasks for which we received a + // commit acknowledgement, but which we cannot evict yet (due to an ancestor + // that has not been evicted). This is to allow a performance optimization + // that avoids unnecessary subscribes when we receive tasks that were + // already COMMITTED at the sender. }; /// \class LineageEntry @@ -220,37 +224,23 @@ class LineageCache { gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); - /// Add a task that is waiting for execution and its uncommitted lineage. - /// These entries will not be written to the GCS until set to ready. + /// Asynchronously commit a task to the GCS. /// - /// \param task The waiting task to add. + /// \param task The task to commit. It will be moved to the COMMITTING state. + /// \return Whether the task was successfully committed. This can fail if the + /// task was already in the COMMITTING state. + bool CommitTask(const Task &task); + + /// Add a task and its (estimated) uncommitted lineage to the local cache. We + /// will subscribe to commit notifications for all uncommitted tasks to + /// determine when it is safe to evict the lineage from the local cache. + /// + /// \param task_id The ID of the uncommitted task to add. /// \param uncommitted_lineage The task's uncommitted lineage. These are the /// tasks that the given task is data-dependent on, but that have not - /// been made durable in the GCS, as far the task's submitter knows. - /// \return Whether the task was successfully marked as waiting to be - /// committed. This will return false if the task is already waiting to be - /// committed (UNCOMMITTED_WAITING), ready to be committed - /// (UNCOMMITTED_READY), or committing (COMMITTING). - bool AddWaitingTask(const Task &task, const Lineage &uncommitted_lineage); - - /// Add a task that is ready for GCS writeback. This overwrites the task’s - /// mutable fields in the execution specification. - /// - /// \param task The task to set as ready. - /// \return Whether the task was successfully marked as ready to be - /// committed. This will return false if the task is already ready to be - /// committed (UNCOMMITTED_READY) or committing (COMMITTING). - bool AddReadyTask(const Task &task); - - /// Remove a task that was waiting for execution. Its uncommitted lineage - /// will remain unchanged. - /// - /// \param task_id The ID of the waiting task to remove. - /// \return Whether the task was successfully removed. This will return false - /// if the task is not waiting to be committed. Then, the waiting task has - /// already been removed (UNCOMMITTED_REMOTE), or if it's ready to be - /// committed (UNCOMMITTED_READY) or committing (COMMITTING). - bool RemoveWaitingTask(const TaskID &task_id); + /// been committed to the GCS. This must contain the given task ID. + /// \return Void. + void AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage); /// Mark a task as having been explicitly forwarded to a node. /// The lineage of the task is implicitly assumed to have also been forwarded. @@ -317,9 +307,6 @@ class LineageCache { /// Unsubscribe from notifications for a task. Returns whether the operation /// was successful (whether we were subscribed). bool UnsubscribeTask(const TaskID &task_id); - /// Add a task and its uncommitted lineage to the local stash. 
- void AddUncommittedLineage(const TaskID &task_id, const Lineage &uncommitted_lineage, - std::unordered_set &subscribe_tasks); /// The client ID, used to request notifications for specific tasks. /// TODO(swang): Move the ClientID into the generic Table implementation. diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index a61ae846a925..e5c126bcf078 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -122,15 +122,22 @@ static inline Task ExampleTask(const std::vector &arguments, return task; } +/// Helper method to create a Lineage object with a single task. +Lineage CreateSingletonLineage(const Task &task) { + Lineage singleton_lineage; + singleton_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); + return singleton_lineage; +} + std::vector InsertTaskChain(LineageCache &lineage_cache, std::vector &inserted_tasks, int chain_size, const std::vector &initial_arguments, int64_t num_returns) { - Lineage empty_lineage; std::vector arguments = initial_arguments; for (int i = 0; i < chain_size; i++) { auto task = ExampleTask(arguments, num_returns); - RAY_CHECK(lineage_cache.AddWaitingTask(task, empty_lineage)); + Lineage lineage = CreateSingletonLineage(task); + lineage_cache.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); inserted_tasks.push_back(task); arguments.clear(); for (int j = 0; j < task.GetTaskSpecification().NumReturns(); j++) { @@ -190,6 +197,34 @@ TEST_F(LineageCacheTest, TestGetUncommittedLineageOrDie) { } } +TEST_F(LineageCacheTest, TestDuplicateUncommittedLineage) { + // Insert a chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + // Check that we subscribed to each of the uncommitted tasks. + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Check that if we add the same tasks as UNCOMMITTED again, we do not issue + // duplicate subscribe requests. + Lineage duplicate_lineage; + for (const auto &task : tasks) { + duplicate_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); + } + lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Check that if we commit one of the tasks, we still do not issue any + // duplicate subscribe requests. + lineage_cache_.CommitTask(tasks.front()); + lineage_cache_.AddUncommittedLineage(task_ids.back(), duplicate_lineage); + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); +} + TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { // Insert chain of tasks. std::vector tasks; @@ -222,7 +257,7 @@ TEST_F(LineageCacheTest, TestMarkTaskAsForwarded) { ASSERT_EQ(1, uncommitted_lineage_forwarded.GetEntries().size()); } -TEST_F(LineageCacheTest, TestWritebackNoneReady) { +TEST_F(LineageCacheTest, TestWritebackReady) { // Insert a chain of dependent tasks. size_t num_tasks_flushed = 0; std::vector tasks; @@ -231,16 +266,9 @@ TEST_F(LineageCacheTest, TestWritebackNoneReady) { // Check that when no tasks have been marked as ready, we do not flush any // entries. ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); -} - -TEST_F(LineageCacheTest, TestWritebackReady) { - // Insert a chain of dependent tasks. 
- size_t num_tasks_flushed = 0; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); // Check that after marking the first task as ready, we flush only that task. - ASSERT_TRUE(lineage_cache_.AddReadyTask(tasks.front())); + ASSERT_TRUE(lineage_cache_.CommitTask(tasks.front())); num_tasks_flushed++; ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); } @@ -253,7 +281,7 @@ TEST_F(LineageCacheTest, TestWritebackOrder) { // Mark all tasks as ready. All tasks should be flushed. for (const auto &task : tasks) { - ASSERT_TRUE(lineage_cache_.AddReadyTask(task)); + ASSERT_TRUE(lineage_cache_.CommitTask(task)); } ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); @@ -272,12 +300,13 @@ TEST_F(LineageCacheTest, TestEvictChain) { Lineage uncommitted_lineage; for (const auto &task : tasks) { - uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED_REMOTE); + uncommitted_lineage.SetEntry(task, GcsStatus::UNCOMMITTED); } // Mark the last task as ready to flush. - ASSERT_TRUE(lineage_cache_.AddWaitingTask(tasks.back(), uncommitted_lineage)); + lineage_cache_.AddUncommittedLineage(tasks.back().GetTaskSpecification().TaskId(), + uncommitted_lineage); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); - ASSERT_TRUE(lineage_cache_.AddReadyTask(tasks.back())); + ASSERT_TRUE(lineage_cache_.CommitTask(tasks.back())); num_tasks_flushed++; ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); // Flush acknowledgements. The lineage cache should receive the commit for @@ -320,17 +349,20 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { auto task = ExampleTask({}, 1); parent_tasks.push_back(task); arguments.push_back(task.GetTaskSpecification().ReturnId(0)); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(task, Lineage())); + auto lineage = CreateSingletonLineage(task); + lineage_cache_.AddUncommittedLineage(task.GetTaskSpecification().TaskId(), lineage); } // Create a child task that is dependent on all of the previous tasks. auto child_task = ExampleTask(arguments, 1); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(child_task, Lineage())); + auto lineage = CreateSingletonLineage(child_task); + lineage_cache_.AddUncommittedLineage(child_task.GetTaskSpecification().TaskId(), + lineage); // Flush the child task. Make sure that it remains in the cache, since none // of its parents have been committed yet, and that the uncommitted lineage // still includes all of the parent tasks. size_t total_tasks = parent_tasks.size() + 1; - lineage_cache_.AddReadyTask(child_task); + lineage_cache_.CommitTask(child_task); mock_gcs_.Flush(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), total_tasks); ASSERT_EQ(lineage_cache_ @@ -342,7 +374,7 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { // Flush each parent task and check for eviction safety. for (const auto &parent_task : parent_tasks) { - lineage_cache_.AddReadyTask(parent_task); + lineage_cache_.CommitTask(parent_task); mock_gcs_.Flush(); total_tasks--; if (total_tasks > 1) { @@ -364,75 +396,6 @@ TEST_F(LineageCacheTest, TestEvictManyParents) { ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); } -TEST_F(LineageCacheTest, TestForwardTasksRoundTrip) { - // Insert a chain of dependent tasks. - uint64_t lineage_size = max_lineage_size_ + 1; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - - // Simulate removing each task, forwarding it to another node, then - // receiving the task back again. 
- for (auto it = tasks.begin(); it != tasks.end(); it++) { - const auto task_id = it->GetTaskSpecification().TaskId(); - // Simulate removing the task and forwarding it to another node. - auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id, ClientID::Nil()); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - // Simulate receiving the task again. Make sure we can add the task back. - flatbuffers::FlatBufferBuilder fbb; - auto uncommitted_lineage_message = uncommitted_lineage.ToFlatbuffer(fbb, task_id); - fbb.Finish(uncommitted_lineage_message); - uncommitted_lineage = Lineage( - *flatbuffers::GetRoot(fbb.GetBufferPointer())); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(*it, uncommitted_lineage)); - } -} - -TEST_F(LineageCacheTest, TestForwardTask) { - // Insert a chain of dependent tasks. - size_t num_tasks_flushed = 0; - std::vector tasks; - InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); - - // Simulate removing the task and forwarding it to another node. - auto it = tasks.begin() + 1; - auto forwarded_task = *it; - tasks.erase(it); - auto task_id_to_remove = forwarded_task.GetTaskSpecification().TaskId(); - auto uncommitted_lineage = - lineage_cache_.GetUncommittedLineageOrDie(task_id_to_remove, ClientID::Nil()); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id_to_remove)); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 3); - - // Simulate executing the remaining tasks. - for (const auto &task : tasks) { - ASSERT_TRUE(lineage_cache_.AddReadyTask(task)); - num_tasks_flushed++; - } - // Check that the first task, which has no dependencies can be flushed. The - // last task cannot be flushed since one of its dependencies has not been - // added by the remote node yet. - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - mock_gcs_.Flush(); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 2); - - // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); - RAY_CHECK_OK( - mock_gcs_.RemoteAdd(forwarded_task.GetTaskSpecification().TaskId(), task_data)); - // Check that the remote task is flushed. - num_tasks_flushed++; - ASSERT_EQ(mock_gcs_.TaskTable().size(), num_tasks_flushed); - ASSERT_EQ(mock_gcs_.SubscribedTasks().size(), 1); - - // Check that once we receive the callback for the remote task, we can now - // flush the last task. - mock_gcs_.Flush(); - ASSERT_EQ(mock_gcs_.SubscribedTasks().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), 0); - ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); -} - TEST_F(LineageCacheTest, TestEviction) { // Insert a chain of dependent tasks. uint64_t lineage_size = max_lineage_size_ + 1; @@ -440,12 +403,6 @@ TEST_F(LineageCacheTest, TestEviction) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. - for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); @@ -500,12 +457,6 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. 
- for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Check that the last task in the chain still has all tasks in its // uncommitted lineage. const auto last_task_id = tasks.back().GetTaskSpecification().TaskId(); @@ -545,19 +496,15 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { std::vector tasks; InsertTaskChain(lineage_cache_, tasks, lineage_size, std::vector(), 1); - // Simulate forwarding the chain of tasks to a remote node. - for (const auto &task : tasks) { - auto task_id = task.GetTaskSpecification().TaskId(); - ASSERT_TRUE(lineage_cache_.RemoveWaitingTask(task_id)); - } - // Add more tasks to the lineage cache that will remain local. Each of these // tasks is dependent one of the tasks that was forwarded above. for (const auto &task : tasks) { auto return_id = task.GetTaskSpecification().ReturnId(0); auto dependent_task = ExampleTask({return_id}, 1); - ASSERT_TRUE(lineage_cache_.AddWaitingTask(dependent_task, Lineage())); - ASSERT_TRUE(lineage_cache_.AddReadyTask(dependent_task)); + auto lineage = CreateSingletonLineage(dependent_task); + lineage_cache_.AddUncommittedLineage(dependent_task.GetTaskSpecification().TaskId(), + lineage); + ASSERT_TRUE(lineage_cache_.CommitTask(dependent_task)); // Once the forwarded tasks are evicted from the lineage cache, we expect // each of these dependent tasks to be flushed, since all of their // dependencies have been committed. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e3fd9a0df09f..07dca3c7ab32 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -693,11 +693,6 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // known. auto created_actor_methods = local_queues_.RemoveTasks(created_actor_method_ids); for (const auto &method : created_actor_methods) { - if (!lineage_cache_.RemoveWaitingTask(method.GetTaskSpecification().TaskId())) { - RAY_LOG(WARNING) << "Task " << method.GetTaskSpecification().TaskId() - << " already removed from the lineage cache. This is most " - "likely due to reconstruction."; - } // Maintain the invariant that if a task is in the // MethodsWaitingForActorCreation queue, then it is subscribed to its // respective actor creation task. Since the actor location is now known, @@ -1466,10 +1461,6 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ current_time_ms())); } } - // A task failing is equivalent to assigning and finishing the task, so clean - // up any leftover state as for any task dispatched and removed from the - // local queue. - lineage_cache_.AddReadyTask(task); task_dependency_manager_.TaskCanceled(spec.TaskId()); // Notify the task dependency manager that we no longer need this task's // object dependencies. TODO(swang): Ideally, we would check the return value @@ -1538,10 +1529,14 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } // Add the task and its uncommitted lineage to the lineage cache. - if (!lineage_cache_.AddWaitingTask(task, uncommitted_lineage)) { - RAY_LOG(WARNING) - << "Task " << task_id - << " already in lineage cache. This is most likely due to reconstruction."; + if (forwarded) { + lineage_cache_.AddUncommittedLineage(task_id, uncommitted_lineage); + } else { + if (!lineage_cache_.CommitTask(task)) { + RAY_LOG(WARNING) + << "Task " << task_id + << " already committed to the GCS. 
This is most likely due to reconstruction."; + } } if (spec.IsActorTask()) { @@ -1869,32 +1864,14 @@ bool NodeManager::AssignTask(const Task &task) { actor_entry->second.AddHandle(new_handle_id, execution_dependency); } - // If the task was an actor task, then record this execution to - // guarantee consistency in the case of reconstruction. - auto execution_dependency = actor_entry->second.GetExecutionDependency(); - // The execution dependency is initialized to the actor creation task's - // return value, and is subsequently updated to the assigned tasks' - // return values, so it should never be nil. - RAY_CHECK(!execution_dependency.IsNil()); - // Update the task's execution dependencies to reflect the actual - // execution order, to support deterministic reconstruction. - // NOTE(swang): The update of an actor task's execution dependencies is - // performed asynchronously. This means that if this node manager dies, - // we may lose updates that are in flight to the task table. We only - // guarantee deterministic reconstruction ordering for tasks whose - // updates are reflected in the task table. - // (SetExecutionDependencies takes a non-const so copy task in a - // on-const variable.) - assigned_task.SetExecutionDependencies({execution_dependency}); + // TODO(swang): For actors with multiple actor handles, to + // guarantee that tasks are replayed in the same order after a + // failure, we must update the task's execution dependency to be + // the actor's current execution dependency. } else { RAY_CHECK(spec.NewActorHandles().empty()); } - // We started running the task, so the task is ready to write to GCS. - if (!lineage_cache_.AddReadyTask(assigned_task)) { - RAY_LOG(WARNING) << "Task " << spec.TaskId() << " already in lineage cache." - << " This is most likely due to reconstruction."; - } // Mark the task as running. // (See design_docs/task_states.rst for the state transition diagram.) local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); @@ -2260,9 +2237,6 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, // Temporarily move the RESUBMITTED task to the SWAP queue while the // timer is active. local_queues_.QueueTasks({task}, TaskState::SWAP); - // Remove the task from the lineage cache. The task will get added back - // once it is resubmitted. - lineage_cache_.RemoveWaitingTask(task_id); } else { // The task is not for an actor and may therefore be placed on another // node immediately. Send it to the scheduling policy to be placed again. @@ -2327,17 +2301,9 @@ void NodeManager::ForwardTask( if (status.ok()) { const auto &spec = task.GetTaskSpecification(); - // If we were able to forward the task, remove the forwarded task from the - // lineage cache since the receiving node is now responsible for writing - // the task to the GCS. - if (!lineage_cache_.RemoveWaitingTask(task_id)) { - RAY_LOG(WARNING) << "Task " << task_id << " already removed from the lineage" - << " cache. This is most likely due to reconstruction."; - } else { - // Mark as forwarded so that the task and its lineage is not - // re-forwarded in the future to the receiving node. - lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - } + // Mark as forwarded so that the task and its lineage are not + // re-forwarded in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); // Notify the task dependency manager that we are no longer responsible // for executing this task. 
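To summarize the behavioral change in this patch, the raylet now commits a task's lineage when the task is submitted rather than when it starts executing, and only records uncommitted lineage for tasks forwarded from other nodes. The snippet below is a minimal illustrative sketch, not part of the diff: OnTaskSubmitted is a hypothetical helper that mirrors the NodeManager::SubmitTask change above, and the include path is assumed.

#include "ray/raylet/lineage_cache.h"

namespace ray {
namespace raylet {

// Illustrative sketch: how a raylet drives the lineage cache after this patch.
void OnTaskSubmitted(LineageCache &lineage_cache, const Task &task,
                     const Lineage &uncommitted_lineage, bool forwarded) {
  const TaskID task_id = task.GetTaskSpecification().TaskId();
  if (forwarded) {
    // Another node is responsible for committing this lineage; record it as
    // UNCOMMITTED and subscribe so the entries can be evicted once their
    // commit notifications arrive.
    lineage_cache.AddUncommittedLineage(task_id, uncommitted_lineage);
  } else {
    // Locally submitted task: flush it to the GCS right away. CommitTask
    // returns false if the task was already COMMITTING.
    if (!lineage_cache.CommitTask(task)) {
      RAY_LOG(WARNING) << "Task " << task_id << " already committed to the GCS.";
    }
  }
}

}  // namespace raylet
}  // namespace ray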
From 9e328fbe6f94a069b31cd98511244694efc96f92 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 7 Jun 2019 16:42:37 -0700 Subject: [PATCH 077/118] [rllib] Add docs on how to use TF eager execution (#4927) --- ci/jenkins_tests/run_rllib_tests.sh | 10 ++ doc/source/rllib-concepts.rst | 31 ++++++ doc/source/rllib-examples.rst | 2 + doc/source/rllib-training.rst | 7 ++ doc/source/rllib.rst | 2 + python/ray/rllib/agents/a3c/a3c_tf_policy.py | 5 +- python/ray/rllib/agents/ppo/ppo_policy.py | 10 +- python/ray/rllib/agents/trainer.py | 3 + python/ray/rllib/examples/eager_execution.py | 101 +++++++++++++++++++ python/ray/rllib/policy/dynamic_tf_policy.py | 50 +++++++++ 10 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 python/ray/rllib/examples/eager_execution.py diff --git a/ci/jenkins_tests/run_rllib_tests.sh b/ci/jenkins_tests/run_rllib_tests.sh index a97bf5517ea2..13036ae7da0f 100644 --- a/ci/jenkins_tests/run_rllib_tests.sh +++ b/ci/jenkins_tests/run_rllib_tests.sh @@ -392,6 +392,16 @@ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/rollout_worker_custom_workflow.py +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output python /ray/python/ray/rllib/examples/eager_execution.py --iters=2 + +docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + /ray/ci/suppress_output /ray/python/ray/rllib/train.py \ + --env CartPole-v0 \ + --run PPO \ + --stop '{"training_iteration": 1}' \ + --config '{"use_eager": true, "simple_optimizer": true}' + docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ /ray/ci/suppress_output python /ray/python/ray/rllib/examples/custom_tf_policy.py --iters=2 diff --git a/doc/source/rllib-concepts.rst b/doc/source/rllib-concepts.rst index b7b3ff823774..4b00f5636540 100644 --- a/doc/source/rllib-concepts.rst +++ b/doc/source/rllib-concepts.rst @@ -346,6 +346,37 @@ In PPO we run ``setup_mixins`` before the loss function is called (i.e., ``befor Finally, note that you do not have to use ``build_tf_policy`` to define a TensorFlow policy. You can alternatively subclass ``Policy``, ``TFPolicy``, or ``DynamicTFPolicy`` as convenient. +Building Policies in TensorFlow Eager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +While RLlib runs all TF operations in graph mode, you can still leverage TensorFlow eager using `tf.py_function `__. However, note that eager and non-eager tensors cannot be mixed within the ``py_function``. Here's an example of embedding eager execution within a policy loss function: + +.. code-block:: python + + def eager_loss(policy, batch_tensors): + """Example of using embedded eager execution in a custom loss. + + Here `compute_penalty` prints the actions and rewards for debugging, and + also computes a (dummy) penalty term to add to the loss. 
+ """ + + def compute_penalty(actions, rewards): + penalty = tf.reduce_mean(tf.cast(actions, tf.float32)) + if random.random() > 0.9: + print("The eagerly computed penalty is", penalty, actions, rewards) + return penalty + + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + penalty = tf.py_function( + compute_penalty, [actions, rewards], Tout=tf.float32) + + return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + +You can find a runnable file for the above eager execution example `here `__. + +There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However this currently only works for a couple algorithms. + Building Policies in PyTorch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/source/rllib-examples.rst b/doc/source/rllib-examples.rst index 13bfdc68bfc1..604abf394de1 100644 --- a/doc/source/rllib-examples.rst +++ b/doc/source/rllib-examples.rst @@ -38,6 +38,8 @@ Custom Envs and Models Example of adding batch norm layers to a custom model. - `Parametric actions `__: Example of how to handle variable-length or parametric action spaces. +- `Eager execution `__: + Example of how to leverage TensorFlow eager to simplify debugging and design of custom models and policies. Serving and Offline ------------------- diff --git a/doc/source/rllib-training.rst b/doc/source/rllib-training.rst index 824ef4c3dd88..9c365f8fb427 100644 --- a/doc/source/rllib-training.rst +++ b/doc/source/rllib-training.rst @@ -367,6 +367,13 @@ The ``"monitor": true`` config can be used to save Gym episode videos to the res openaigym.video.0.31403.video000000.meta.json openaigym.video.0.31403.video000000.mp4 +TensorFlow Eager +~~~~~~~~~~~~~~~~ + +While RLlib uses TF graph mode for all computations, you can still leverage TF eager to inspect the intermediate state of computations using `tf.py_function `__. Here's an example of using eager mode in `a custom RLlib model and loss `__. + +There is also experimental support for running the entire loss function in eager mode. This can be enabled with ``use_eager: True``, e.g., ``rllib train --env=CartPole-v0 --run=PPO --config='{"use_eager": true, "simple_optimizer": true}'``. However this currently only works for a couple algorithms. 
+ Episode Traces ~~~~~~~~~~~~~~ diff --git a/doc/source/rllib.rst b/doc/source/rllib.rst index f0571b23c20e..d0d9d715aa7c 100644 --- a/doc/source/rllib.rst +++ b/doc/source/rllib.rst @@ -101,6 +101,8 @@ Concepts and Building Custom Algorithms - `Building Policies in TensorFlow `__ + - `Building Policies in TensorFlow Eager `__ + - `Building Policies in PyTorch `__ - `Extending Existing Policies `__ diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy.py b/python/ray/rllib/agents/a3c/a3c_tf_policy.py index ed3676472850..d05f496a7945 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy.py @@ -41,8 +41,9 @@ def actor_critic_loss(policy, batch_tensors): policy.loss = A3CLoss( policy.action_dist, batch_tensors[SampleBatch.ACTIONS], batch_tensors[Postprocessing.ADVANTAGES], - batch_tensors[Postprocessing.VALUE_TARGETS], policy.vf, - policy.config["vf_loss_coeff"], policy.config["entropy_coeff"]) + batch_tensors[Postprocessing.VALUE_TARGETS], + policy.convert_to_eager(policy.vf), policy.config["vf_loss_coeff"], + policy.config["entropy_coeff"]) return policy.loss.total_loss diff --git a/python/ray/rllib/agents/ppo/ppo_policy.py b/python/ray/rllib/agents/ppo/ppo_policy.py index 4b391cab2cdc..ad79d90faa9a 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy.py +++ b/python/ray/rllib/agents/ppo/ppo_policy.py @@ -106,8 +106,10 @@ def reduce_mean_valid(t): def ppo_surrogate_loss(policy, batch_tensors): if policy.model.state_in: - max_seq_len = tf.reduce_max(policy.model.seq_lens) - mask = tf.sequence_mask(policy.model.seq_lens, max_seq_len) + max_seq_len = tf.reduce_max( + policy.convert_to_eager(policy.model.seq_lens)) + mask = tf.sequence_mask( + policy.convert_to_eager(policy.model.seq_lens), max_seq_len) mask = tf.reshape(mask, [-1]) else: mask = tf.ones_like( @@ -121,8 +123,8 @@ def ppo_surrogate_loss(policy, batch_tensors): batch_tensors[BEHAVIOUR_LOGITS], batch_tensors[SampleBatch.VF_PREDS], policy.action_dist, - policy.value_function, - policy.kl_coeff, + policy.convert_to_eager(policy.value_function), + policy.convert_to_eager(policy.kl_coeff), mask, entropy_coeff=policy.config["entropy_coeff"], clip_param=policy.config["clip_param"], diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index f08b23e93fd7..a0d48d2ef714 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -67,6 +67,9 @@ }, # Whether to attempt to continue training if a worker crashes. "ignore_worker_failures": False, + # Execute TF loss functions in eager mode. This is currently experimental + # and only really works with the basic PG algorithm. + "use_eager": False, # === Policy === # Arguments to pass to model. 
See models/catalog.py for a full list of the diff --git a/python/ray/rllib/examples/eager_execution.py b/python/ray/rllib/examples/eager_execution.py new file mode 100644 index 000000000000..a3c418a33139 --- /dev/null +++ b/python/ray/rllib/examples/eager_execution.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import random + +import ray +from ray import tune +from ray.rllib.agents.trainer_template import build_trainer +from ray.rllib.models import FullyConnectedNetwork, Model, ModelCatalog +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy_template import build_tf_policy +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--iters", type=int, default=200) + + +class EagerModel(Model): + """Example of using embedded eager execution in a custom model. + + This shows how to use tf.py_function() to execute a snippet of TF code + in eager mode. Here the `self.forward_eager` method just prints out + the intermediate tensor for debug purposes, but you can in general + perform any TF eager operation in tf.py_function(). + """ + + def _build_layers_v2(self, input_dict, num_outputs, options): + self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space, + self.action_space, num_outputs, + options) + feature_out = tf.py_function(self.forward_eager, + [self.fcnet.last_layer], tf.float32) + + with tf.control_dependencies([feature_out]): + return tf.identity(self.fcnet.outputs), feature_out + + def forward_eager(self, feature_layer): + assert tf.executing_eagerly() + if random.random() > 0.99: + print("Eagerly printing the feature layer mean value", + tf.reduce_mean(feature_layer)) + return feature_layer + + +def policy_gradient_loss(policy, batch_tensors): + """Example of using embedded eager execution in a custom loss. + + Here `compute_penalty` prints the actions and rewards for debugging, and + also computes a (dummy) penalty term to add to the loss. + + Alternatively, you can set config["use_eager"] = True, which will try to + automatically eagerify the entire loss function. However, this only works + if your loss doesn't reference any non-eager tensors. It also won't work + with the multi-GPU optimizer used by PPO. 
+ """ + + def compute_penalty(actions, rewards): + assert tf.executing_eagerly() + penalty = tf.reduce_mean(tf.cast(actions, tf.float32)) + if random.random() > 0.9: + print("The eagerly computed penalty is", penalty, actions, rewards) + return penalty + + actions = batch_tensors[SampleBatch.ACTIONS] + rewards = batch_tensors[SampleBatch.REWARDS] + penalty = tf.py_function( + compute_penalty, [actions, rewards], Tout=tf.float32) + + return penalty - tf.reduce_mean(policy.action_dist.logp(actions) * rewards) + + +# +MyTFPolicy = build_tf_policy( + name="MyTFPolicy", + loss_fn=policy_gradient_loss, +) + +# +MyTrainer = build_trainer( + name="MyCustomTrainer", + default_policy=MyTFPolicy, +) + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + ModelCatalog.register_custom_model("eager_model", EagerModel) + tune.run( + MyTrainer, + stop={"training_iteration": args.iters}, + config={ + "env": "CartPole-v0", + "num_workers": 0, + "model": { + "custom_model": "eager_model" + }, + }) diff --git a/python/ray/rllib/policy/dynamic_tf_policy.py b/python/ray/rllib/policy/dynamic_tf_policy.py index 0240f275de37..23014553bf0d 100644 --- a/python/ray/rllib/policy/dynamic_tf_policy.py +++ b/python/ray/rllib/policy/dynamic_tf_policy.py @@ -167,6 +167,8 @@ def __init__(self, batch_divisibility_req=batch_divisibility_req) # Phase 2 init + self._needs_eager_conversion = set() + self._eager_tensors = {} before_loss_init(self, obs_space, action_space, config) if not existing_inputs: self._initialize_loss() @@ -178,10 +180,26 @@ def get_obs_input_dict(self): """ return self.input_dict + def convert_to_eager(self, tensor): + """Convert a graph tensor accessed in the loss to an eager tensor. + + Experimental. + """ + if tf.executing_eagerly(): + return self._eager_tensors[tensor] + else: + self._needs_eager_conversion.add(tensor) + return tensor + @override(TFPolicy) def copy(self, existing_inputs): """Creates a copy of self using existing input placeholders.""" + if self.config["use_eager"]: + raise ValueError( + "eager not implemented for multi-GPU, try setting " + "`simple_optimizer: true`") + # Note that there might be RNN state inputs at the end of the list if self._state_inputs: num_state_inputs = len(self._state_inputs) + 1 @@ -297,6 +315,38 @@ def fake_array(tensor): loss = self._do_loss_init(batch_tensors) for k in sorted(batch_tensors.accessed_keys): loss_inputs.append((k, batch_tensors[k])) + + # XXX experimental support for automatically eagerifying the loss. + # The main limitation right now is that TF doesn't support mixing eager + # and non-eager tensors, so losses that read non-eager tensors through + # `policy` need to use `policy.convert_to_eager(tensor)`. + if self.config["use_eager"]: + if not self.model: + raise ValueError("eager not implemented in this case") + graph_tensors = list(self._needs_eager_conversion) + + def gen_loss(model_outputs, *args): + # fill in the batch tensor dict with eager ensors + eager_inputs = dict( + zip([k for (k, v) in loss_inputs], + args[:len(loss_inputs)])) + # fill in the eager versions of all accessed graph tensors + self._eager_tensors = dict( + zip(graph_tensors, args[len(loss_inputs):])) + # patch the action dist to use eager mode tensors + self.action_dist.inputs = model_outputs + return self._loss_fn(self, eager_inputs) + + # TODO(ekl) also handle the stats funcs + loss = tf.py_function( + gen_loss, + # cast works around TypeError: Cannot convert provided value + # to EagerTensor. 
Provided value: 0.0 Requested dtype: int64 + [self.model.outputs] + [ + tf.cast(v, tf.float32) for (k, v) in loss_inputs + ] + [tf.cast(t, tf.float32) for t in graph_tensors], + tf.float32) + TFPolicy._initialize_loss(self, loss, loss_inputs) if self._grad_stats_fn: self._stats_fetches.update(self._grad_stats_fn(self, self._grads)) From 77689d1116c5d5eefc775740ae7a6a090d643186 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Fri, 7 Jun 2019 16:45:36 -0700 Subject: [PATCH 078/118] [rllib] Port remainder of algorithms to build_trainer() pattern (#4920) --- doc/source/rllib-package-ref.rst | 18 +- python/ray/rllib/agents/ddpg/apex.py | 27 +- python/ray/rllib/agents/ddpg/ddpg.py | 109 +++-- python/ray/rllib/agents/ddpg/td3.py | 16 +- python/ray/rllib/agents/dqn/apex.py | 63 ++- python/ray/rllib/agents/dqn/dqn.py | 393 ++++++++---------- python/ray/rllib/agents/impala/impala.py | 130 +++--- python/ray/rllib/agents/marwil/marwil.py | 41 +- python/ray/rllib/agents/ppo/appo.py | 17 +- python/ray/rllib/agents/qmix/apex.py | 27 +- python/ray/rllib/agents/qmix/qmix.py | 25 +- python/ray/rllib/agents/trainer.py | 9 + python/ray/rllib/agents/trainer_template.py | 89 +++- python/ray/rllib/policy/tf_policy_template.py | 9 +- .../ray/rllib/policy/torch_policy_template.py | 9 +- python/ray/rllib/utils/__init__.py | 15 + 16 files changed, 511 insertions(+), 486 deletions(-) diff --git a/doc/source/rllib-package-ref.rst b/doc/source/rllib-package-ref.rst index db4b2dbfe0eb..6a4e6aed43f8 100644 --- a/doc/source/rllib-package-ref.rst +++ b/doc/source/rllib-package-ref.rst @@ -1,25 +1,11 @@ RLlib Package Reference ======================= -ray.rllib.agents +ray.rllib.policy ---------------- -.. automodule:: ray.rllib.agents +.. automodule:: ray.rllib.policy :members: - -.. autoclass:: ray.rllib.agents.a3c.A2CTrainer -.. autoclass:: ray.rllib.agents.a3c.A3CTrainer -.. autoclass:: ray.rllib.agents.ddpg.ApexDDPGTrainer -.. autoclass:: ray.rllib.agents.ddpg.DDPGTrainer -.. autoclass:: ray.rllib.agents.dqn.ApexTrainer -.. autoclass:: ray.rllib.agents.dqn.DQNTrainer -.. autoclass:: ray.rllib.agents.es.ESTrainer -.. autoclass:: ray.rllib.agents.pg.PGTrainer -.. autoclass:: ray.rllib.agents.impala.ImpalaTrainer -.. autoclass:: ray.rllib.agents.ppo.APPOTrainer -.. autoclass:: ray.rllib.agents.ppo.PPOTrainer -.. autoclass:: ray.rllib.agents.marwil.MARWILTrainer - ray.rllib.env ------------- diff --git a/python/ray/rllib/agents/ddpg/apex.py b/python/ray/rllib/agents/ddpg/apex.py index 5ea732f17508..e0731e87a809 100644 --- a/python/ray/rllib/agents/ddpg/apex.py +++ b/python/ray/rllib/agents/ddpg/apex.py @@ -2,15 +2,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, \ DEFAULT_CONFIG as DDPG_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_DDPG_DEFAULT_CONFIG = merge_dicts( DDPG_CONFIG, # see also the options in ddpg.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DDPG_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -32,23 +31,7 @@ }, ) - -class ApexDDPGTrainer(DDPGTrainer): - """DDPG variant that uses the Ape-X distributed policy optimizer. - - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - - _name = "APEX_DDPG" - _default_config = APEX_DDPG_DEFAULT_CONFIG - - @override(DDPGTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexDDPGTrainer = DDPGTrainer.with_updates( + name="APEX_DDPG", + default_config=APEX_DDPG_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index a9676335eb3f..a6b42f1ca927 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -3,9 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer, \ + update_worker_explorations from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule # yapf: disable @@ -97,6 +97,11 @@ # optimization on initial policy parameters. Note that this will be # disabled when the action noise scale is set to 0 (e.g during evaluation). "pure_exploration_steps": 1000, + # Extra configuration that disables exploration. + "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -108,6 +113,11 @@ "prioritized_replay_alpha": 0.6, # Beta parameter for sampling from prioritized replay buffer. "prioritized_replay_beta": 0.4, + # Fraction of entire training period over which the beta parameter is + # annealed + "beta_annealing_fraction": 0.2, + # Final value of beta + "final_prioritized_replay_beta": 0.4, # Epsilon to add to the TD errors when updating priorities. "prioritized_replay_eps": 1e-6, # Whether to LZ4 compress observations @@ -146,8 +156,6 @@ # to increase if your environment is particularly slow to sample, or if # you're using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -159,47 +167,56 @@ # yapf: enable -class DDPGTrainer(DQNTrainer): - """DDPG implementation in TensorFlow.""" - _name = "DDPG" - _default_config = DEFAULT_CONFIG - _policy = DDPGTFPolicy +def make_exploration_schedule(config, worker_index): + # Modification of DQN's schedule to take into account + # `exploration_ou_noise_scale` + if config["per_worker_exploration"]: + assert config["num_workers"] > 1, "This requires multiple workers" + if worker_index >= 0: + # FIXME: what do magic constants mean? 
(0.4, 7) + max_index = float(config["num_workers"] - 1) + exponent = 1 + worker_index / max_index * 7 + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + elif config["exploration_should_anneal"]: + return LinearSchedule( + schedule_timesteps=int(config["exploration_fraction"] * + config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_scale"]) + else: + # *always* add exploration noise + return ConstantSchedule(1.0) + + +def setup_ddpg_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] - @override(DQNTrainer) - def _train(self): - pure_expl_steps = self.config["pure_exploration_steps"] - if pure_expl_steps: - # tell workers whether they should do pure exploration - only_explore = self.global_timestep < pure_expl_steps - self.workers.local_worker().foreach_trainable_policy( + +def add_pure_exploration_phase(trainer): + global_timestep = trainer.optimizer.num_steps_sampled + pure_expl_steps = trainer.config["pure_exploration_steps"] + if pure_expl_steps: + # tell workers whether they should do pure exploration + only_explore = global_timestep < pure_expl_steps + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_pure_exploration_phase(only_explore)) + for e in trainer.workers.remote_workers(): + e.foreach_trainable_policy.remote( lambda p, _: p.set_pure_exploration_phase(only_explore)) - for e in self.workers.remote_workers(): - e.foreach_trainable_policy.remote( - lambda p, _: p.set_pure_exploration_phase(only_explore)) - return super(DDPGTrainer, self)._train() - - @override(DQNTrainer) - def _make_exploration_schedule(self, worker_index): - # Override DQN's schedule to take into account - # `exploration_ou_noise_scale` - if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - # FIXME: what do magic constants mean? (0.4, 7) - max_index = float(self.config["num_workers"] - 1) - exponent = 1 + worker_index / max_index * 7 - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - elif self.config["exploration_should_anneal"]: - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_scale"]) - else: - # *always* add exploration noise - return ConstantSchedule(1.0) + update_worker_explorations(trainer) + + +DDPGTrainer = GenericOffPolicyTrainer.with_updates( + name="DDPG", + default_config=DEFAULT_CONFIG, + default_policy=DDPGTFPolicy, + before_init=setup_ddpg_exploration, + before_train_step=add_pure_exploration_phase) diff --git a/python/ray/rllib/agents/ddpg/td3.py b/python/ray/rllib/agents/ddpg/td3.py index 714c39c6b2f8..ad3675294ce5 100644 --- a/python/ray/rllib/agents/ddpg/td3.py +++ b/python/ray/rllib/agents/ddpg/td3.py @@ -1,3 +1,9 @@ +"""A more stable successor to TD3. + +By default, this uses a near-identical configuration to that reported in the +TD3 paper. 
+""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -36,7 +42,6 @@ "train_batch_size": 100, "use_huber": False, "target_network_update_freq": 0, - "optimizer_class": "SyncReplayOptimizer", "num_workers": 0, "num_gpus_per_worker": 0, "per_worker_exploration": False, @@ -48,10 +53,5 @@ }, ) - -class TD3Trainer(DDPGTrainer): - """A more stable successor to TD3. By default, this uses a near-identical - configuration to that reported in the TD3 paper.""" - - _name = "TD3" - _default_config = TD3_DEFAULT_CONFIG +TD3Trainer = DDPGTrainer.with_updates( + name="TD3", default_config=TD3_DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/dqn/apex.py b/python/ray/rllib/agents/dqn/apex.py index 129839a27119..ab89256a6b95 100644 --- a/python/ray/rllib/agents/dqn/apex.py +++ b/python/ray/rllib/agents/dqn/apex.py @@ -3,15 +3,14 @@ from __future__ import print_function from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG as DQN_CONFIG +from ray.rllib.optimizers import AsyncReplayOptimizer from ray.rllib.utils import merge_dicts -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ APEX_DEFAULT_CONFIG = merge_dicts( DQN_CONFIG, # see also the options in dqn.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( DQN_CONFIG["optimizer"], { "max_weight_sync_delay": 400, @@ -36,22 +35,50 @@ # yapf: enable -class ApexTrainer(DQNTrainer): - """DQN variant that uses the Ape-X distributed policy optimizer. +def defer_make_workers(trainer, env_creator, policy, config): + # Hack to workaround https://github.com/ray-project/ray/issues/2541 + # The workers will be creatd later, after the optimizer is created + return trainer._make_workers(env_creator, policy, config, 0) - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - _name = "APEX" - _default_config = APEX_DEFAULT_CONFIG +def make_async_optimizer(workers, config): + assert len(workers.remote_workers()) == 0 + extra_config = config["optimizer"].copy() + for key in [ + "prioritized_replay", "prioritized_replay_alpha", + "prioritized_replay_beta", "prioritized_replay_eps" + ]: + if key in config: + extra_config[key] = config[key] + opt = AsyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **extra_config) + workers.add_workers(config["num_workers"]) + opt._set_workers(workers.remote_workers()) + return opt - @override(DQNTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 + +def update_target_based_on_num_steps_trained(trainer, fetches): + # Ape-X updates based on num steps trained, not sampled + if (trainer.optimizer.num_steps_trained - + trainer.state["last_target_update_ts"] > + trainer.config["target_network_update_freq"]): + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = ( + trainer.optimizer.num_steps_trained) + trainer.state["num_target_updates"] += 1 + + +APEX_TRAINER_PROPERTIES = { + "make_workers": defer_make_workers, + "make_policy_optimizer": make_async_optimizer, + "after_optimizer_step": update_target_based_on_num_steps_trained, +} + +ApexTrainer = DQNTrainer.with_updates( + name="APEX", default_config=APEX_DEFAULT_CONFIG, **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 15379e3fb394..cc418907a0b9 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -3,27 +3,17 @@ from __future__ import print_function import logging -import time from ray import tune -from ray.rllib import optimizers -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy -from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.optimizers import SyncReplayOptimizer from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import override from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule logger = logging.getLogger(__name__) -OPTIMIZER_SHARED_CONFIGS = [ - "buffer_size", "prioritized_replay", "prioritized_replay_alpha", - "prioritized_replay_beta", "schedule_max_timesteps", - "beta_annealing_fraction", "final_prioritized_replay_beta", - "prioritized_replay_eps", "sample_batch_size", "train_batch_size", - "learning_starts" -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -53,7 +43,8 @@ # 1.0 to exploration_fraction over this number of timesteps scaled by # exploration_fraction "schedule_max_timesteps": 100000, - # Number of env steps to optimize for before returning + # Minimum env steps to optimize for per train call. 
This value does + # not affect learning, only the length of iterations. "timesteps_per_iteration": 1000, # Fraction of entire training period over which the exploration rate is # annealed @@ -70,6 +61,11 @@ # If True parameter space noise will be used for exploration # See https://blog.openai.com/better-exploration-with-parameter-noise/ "parameter_noise": False, + # Extra configuration that disables exploration. + "evaluation_config": { + "exploration_fraction": 0, + "exploration_final_eps": 0, + }, # === Replay buffer === # Size of the replay buffer. Note that if async_updates is set, then @@ -115,8 +111,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -128,202 +122,175 @@ # yapf: enable -class DQNTrainer(Trainer): - """DQN implementation in TensorFlow.""" - - _name = "DQN" - _default_config = DEFAULT_CONFIG - _policy = DQNTFPolicy - _optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS - - @override(Trainer) - def _init(self, config, env_creator): - self._validate_config() - - # Update effective batch size to include n-step - adjusted_batch_size = max(config["sample_batch_size"], - config.get("n_step", 1)) - config["sample_batch_size"] = adjusted_batch_size - - self.exploration0 = self._make_exploration_schedule(-1) - self.explorations = [ - self._make_exploration_schedule(i) - for i in range(config["num_workers"]) - ] - - for k in self._optimizer_shared_configs: - if self._name != "DQN" and k in [ - "schedule_max_timesteps", "beta_annealing_fraction", - "final_prioritized_replay_beta" - ]: - # only Rainbow needs annealing prioritized_replay_beta - continue - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - - if config.get("parameter_noise", False): - if config["callbacks"]["on_episode_start"]: - start_callback = config["callbacks"]["on_episode_start"] - else: - start_callback = None - - def on_episode_start(info): - # as a callback function to sample and pose parameter space - # noise on the parameters of network - policies = info["policy"] - for pi in policies.values(): - pi.add_parameter_noise() - if start_callback: - start_callback(info) - - config["callbacks"]["on_episode_start"] = tune.function( - on_episode_start) - if config["callbacks"]["on_episode_end"]: - end_callback = config["callbacks"]["on_episode_end"] - else: - end_callback = None - - def on_episode_end(info): - # as a callback function to monitor the distance - # between noisy policy and original policy - policies = info["policy"] - episode = info["episode"] - episode.custom_metrics["policy_distance"] = policies[ - DEFAULT_POLICY_ID].pi_distance - if end_callback: - end_callback(info) - - config["callbacks"]["on_episode_end"] = tune.function( - on_episode_end) - - if config["optimizer_class"] != "AsyncReplayOptimizer": - self.workers = self._make_workers( - env_creator, - self._policy, - config, - num_workers=self.config["num_workers"]) - workers_needed = 0 +def make_optimizer(workers, config): + return SyncReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + prioritized_replay=config["prioritized_replay"], + prioritized_replay_alpha=config["prioritized_replay_alpha"], + 
prioritized_replay_beta=config["prioritized_replay_beta"], + schedule_max_timesteps=config["schedule_max_timesteps"], + beta_annealing_fraction=config["beta_annealing_fraction"], + final_prioritized_replay_beta=config["final_prioritized_replay_beta"], + prioritized_replay_eps=config["prioritized_replay_eps"], + train_batch_size=config["train_batch_size"], + sample_batch_size=config["sample_batch_size"], + **config["optimizer"]) + + +def check_config_and_setup_param_noise(config): + """Update the config based on settings. + + Rewrites sample_batch_size to take into account n_step truncation, and also + adds the necessary callbacks to support parameter space noise exploration. + """ + + # Update effective batch size to include n-step + adjusted_batch_size = max(config["sample_batch_size"], + config.get("n_step", 1)) + config["sample_batch_size"] = adjusted_batch_size + + if config.get("parameter_noise", False): + if config["batch_mode"] != "complete_episodes": + raise ValueError("Exploration with parameter space noise requires " + "batch_mode to be complete_episodes.") + if config.get("noisy", False): + raise ValueError( + "Exploration with parameter space noise and noisy network " + "cannot be used at the same time.") + if config["callbacks"]["on_episode_start"]: + start_callback = config["callbacks"]["on_episode_start"] + else: + start_callback = None + + def on_episode_start(info): + # as a callback function to sample and pose parameter space + # noise on the parameters of network + policies = info["policy"] + for pi in policies.values(): + pi.add_parameter_noise() + if start_callback: + start_callback(info) + + config["callbacks"]["on_episode_start"] = tune.function( + on_episode_start) + if config["callbacks"]["on_episode_end"]: + end_callback = config["callbacks"]["on_episode_end"] else: - # Hack to workaround https://github.com/ray-project/ray/issues/2541 - self.workers = self._make_workers( - env_creator, self._policy, config, num_workers=0) - workers_needed = self.config["num_workers"] - - self.optimizer = getattr(optimizers, config["optimizer_class"])( - self.workers, **config["optimizer"]) - - # Create the remote workers *after* the replay actors - if workers_needed > 0: - self.workers.add_workers(workers_needed) - self.optimizer._set_workers(self.workers.remote_workers()) - - self.last_target_update_ts = 0 - self.num_target_updates = 0 - - @override(Trainer) - def _train(self): - start_timestep = self.global_timestep - - # Update worker explorations - exp_vals = [self.exploration0.value(self.global_timestep)] - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.set_epsilon(exp_vals[0])) - for i, e in enumerate(self.workers.remote_workers()): - exp_val = self.explorations[i].value(self.global_timestep) - e.foreach_trainable_policy.remote( - lambda p, _: p.set_epsilon(exp_val)) - exp_vals.append(exp_val) - - # Do optimization steps - start = time.time() - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"] - ) or time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - self.update_target_if_needed() - - if self.config["per_worker_exploration"]: - # Only collect metrics from the third of workers with lowest eps - result = self.collect_metrics( - selected_workers=self.workers.remote_workers()[ - -len(self.workers.remote_workers()) // 3:]) + end_callback = None + + def on_episode_end(info): + # as a callback function to monitor the distance + # between noisy policy and original policy + policies = 
info["policy"] + episode = info["episode"] + episode.custom_metrics["policy_distance"] = policies[ + DEFAULT_POLICY_ID].pi_distance + if end_callback: + end_callback(info) + + config["callbacks"]["on_episode_end"] = tune.function(on_episode_end) + + +def get_initial_state(config): + return { + "last_target_update_ts": 0, + "num_target_updates": 0, + } + + +def make_exploration_schedule(config, worker_index): + # Use either a different `eps` per worker, or a linear schedule. + if config["per_worker_exploration"]: + assert config["num_workers"] > 1, \ + "This requires multiple workers" + if worker_index >= 0: + exponent = ( + 1 + worker_index / float(config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) else: - result = self.collect_metrics() - - result.update( - timesteps_this_iter=self.global_timestep - start_timestep, - info=dict({ - "min_exploration": min(exp_vals), - "max_exploration": max(exp_vals), - "num_target_updates": self.num_target_updates, - }, **self.optimizer.stats())) - - return result - - def update_target_if_needed(self): - if self.global_timestep - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.global_timestep - self.num_target_updates += 1 - - @property - def global_timestep(self): - return self.optimizer.num_steps_sampled - - def _evaluate(self): - logger.info("Evaluating current policy for {} episodes".format( - self.config["evaluation_num_episodes"])) - self.evaluation_workers.local_worker().restore( - self.workers.local_worker().save()) - self.evaluation_workers.local_worker().foreach_policy( - lambda p, _: p.set_epsilon(0)) - for _ in range(self.config["evaluation_num_episodes"]): - self.evaluation_workers.local_worker().sample() - metrics = collect_metrics(self.evaluation_workers.local_worker()) - return {"evaluation": metrics} - - def _make_exploration_schedule(self, worker_index): - # Use either a different `eps` per worker, or a linear schedule. 
- if self.config["per_worker_exploration"]: - assert self.config["num_workers"] > 1, \ - "This requires multiple workers" - if worker_index >= 0: - exponent = ( - 1 + - worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) - else: - # local ev should have zero exploration so that eval rollouts - # run properly - return ConstantSchedule(0.0) - return LinearSchedule( - schedule_timesteps=int(self.config["exploration_fraction"] * - self.config["schedule_max_timesteps"]), - initial_p=1.0, - final_p=self.config["exploration_final_eps"]) - - def __getstate__(self): - state = Trainer.__getstate__(self) - state.update({ - "num_target_updates": self.num_target_updates, - "last_target_update_ts": self.last_target_update_ts, - }) - return state - - def __setstate__(self, state): - Trainer.__setstate__(self, state) - self.num_target_updates = state["num_target_updates"] - self.last_target_update_ts = state["last_target_update_ts"] - - def _validate_config(self): - if self.config.get("parameter_noise", False): - if self.config["batch_mode"] != "complete_episodes": - raise ValueError( - "Exploration with parameter space noise requires " - "batch_mode to be complete_episodes.") - if self.config.get("noisy", False): - raise ValueError( - "Exploration with parameter space noise and noisy network " - "cannot be used at the same time.") + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) + return LinearSchedule( + schedule_timesteps=int( + config["exploration_fraction"] * config["schedule_max_timesteps"]), + initial_p=1.0, + final_p=config["exploration_final_eps"]) + + +def setup_exploration(trainer): + trainer.exploration0 = make_exploration_schedule(trainer.config, -1) + trainer.explorations = [ + make_exploration_schedule(trainer.config, i) + for i in range(trainer.config["num_workers"]) + ] + + +def update_worker_explorations(trainer): + global_timestep = trainer.optimizer.num_steps_sampled + exp_vals = [trainer.exploration0.value(global_timestep)] + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.set_epsilon(exp_vals[0])) + for i, e in enumerate(trainer.workers.remote_workers()): + exp_val = trainer.explorations[i].value(global_timestep) + e.foreach_trainable_policy.remote(lambda p, _: p.set_epsilon(exp_val)) + exp_vals.append(exp_val) + trainer.train_start_timestep = global_timestep + trainer.cur_exp_vals = exp_vals + + +def add_trainer_metrics(trainer, result): + global_timestep = trainer.optimizer.num_steps_sampled + result.update( + timesteps_this_iter=global_timestep - trainer.train_start_timestep, + info=dict({ + "min_exploration": min(trainer.cur_exp_vals), + "max_exploration": max(trainer.cur_exp_vals), + "num_target_updates": trainer.state["num_target_updates"], + }, **trainer.optimizer.stats())) + + +def update_target_if_needed(trainer, fetches): + global_timestep = trainer.optimizer.num_steps_sampled + if global_timestep - trainer.state["last_target_update_ts"] > \ + trainer.config["target_network_update_freq"]: + trainer.workers.local_worker().foreach_trainable_policy( + lambda p, _: p.update_target()) + trainer.state["last_target_update_ts"] = global_timestep + trainer.state["num_target_updates"] += 1 + + +def collect_metrics(trainer): + if trainer.config["per_worker_exploration"]: + # Only collect metrics from the third of workers with lowest eps + result = trainer.collect_metrics( + selected_workers=trainer.workers.remote_workers()[ + 
-len(trainer.workers.remote_workers()) // 3:]) + else: + result = trainer.collect_metrics() + return result + + +def disable_exploration(trainer): + trainer.evaluation_workers.local_worker().foreach_policy( + lambda p, _: p.set_epsilon(0)) + + +GenericOffPolicyTrainer = build_trainer( + name="GenericOffPolicyAlgorithm", + default_policy=None, + default_config=DEFAULT_CONFIG, + validate_config=check_config_and_setup_param_noise, + get_initial_state=get_initial_state, + make_policy_optimizer=make_optimizer, + before_init=setup_exploration, + before_train_step=update_worker_explorations, + after_optimizer_step=update_target_if_needed, + after_train_result=add_trainer_metrics, + collect_metrics_fn=collect_metrics, + before_evaluate_fn=disable_exploration) + +DQNTrainer = GenericOffPolicyTrainer.with_updates( + name="DQN", default_policy=DQNTFPolicy, default_config=DEFAULT_CONFIG) diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index e025a4817f8f..b9699888bfaf 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -2,33 +2,16 @@ from __future__ import division from __future__ import print_function -import time - from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.optimizers import AsyncSamplesOptimizer from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator from ray.rllib.utils.annotations import override from ray.tune.trainable import Trainable from ray.tune.trial import Resources -OPTIMIZER_SHARED_CONFIGS = [ - "lr", - "num_envs_per_worker", - "num_gpus", - "sample_batch_size", - "train_batch_size", - "replay_buffer_num_slots", - "replay_proportion", - "num_data_loader_buffers", - "max_sample_requests_in_flight_per_worker", - "broadcast_interval", - "num_sgd_iter", - "minibatch_buffer_size", - "num_aggregation_workers", -] - # yapf: disable # __sphinx_doc_begin__ DEFAULT_CONFIG = with_common_config({ @@ -100,37 +83,57 @@ # yapf: enable -class ImpalaTrainer(Trainer): - """IMPALA implementation using DeepMind's V-trace.""" - - _name = "IMPALA" - _default_config = DEFAULT_CONFIG - _policy = VTraceTFPolicy - - @override(Trainer) - def _init(self, config, env_creator): - for k in OPTIMIZER_SHARED_CONFIGS: - if k not in config["optimizer"]: - config["optimizer"][k] = config[k] - policy_cls = self._get_policy() - self.workers = self._make_workers( - self.env_creator, policy_cls, self.config, num_workers=0) - - if self.config["num_aggregation_workers"] > 0: - # Create co-located aggregator actors first for placement pref - aggregators = TreeAggregator.precreate_aggregators( - self.config["num_aggregation_workers"]) - - self.workers.add_workers(config["num_workers"]) - self.optimizer = AsyncSamplesOptimizer(self.workers, - **config["optimizer"]) - if config["entropy_coeff"] < 0: - raise DeprecationWarning("entropy_coeff must be >= 0") - - if self.config["num_aggregation_workers"] > 0: - # Assign the pre-created aggregators to the optimizer - self.optimizer.aggregator.init(aggregators) - +def choose_policy(config): + if config["vtrace"]: + return VTraceTFPolicy + else: + return A3CTFPolicy + + +def validate_config(config): + if config["entropy_coeff"] < 0: + raise DeprecationWarning("entropy_coeff must be >= 0") + + +def defer_make_workers(trainer, env_creator, policy, config): + # 
Defer worker creation to after the optimizer has been created. + return trainer._make_workers(env_creator, policy, config, 0) + + +def make_aggregators_and_optimizer(workers, config): + if config["num_aggregation_workers"] > 0: + # Create co-located aggregator actors first for placement pref + aggregators = TreeAggregator.precreate_aggregators( + config["num_aggregation_workers"]) + else: + aggregators = None + workers.add_workers(config["num_workers"]) + + optimizer = AsyncSamplesOptimizer( + workers, + lr=config["lr"], + num_envs_per_worker=config["num_envs_per_worker"], + num_gpus=config["num_gpus"], + sample_batch_size=config["sample_batch_size"], + train_batch_size=config["train_batch_size"], + replay_buffer_num_slots=config["replay_buffer_num_slots"], + replay_proportion=config["replay_proportion"], + num_data_loader_buffers=config["num_data_loader_buffers"], + max_sample_requests_in_flight_per_worker=config[ + "max_sample_requests_in_flight_per_worker"], + broadcast_interval=config["broadcast_interval"], + num_sgd_iter=config["num_sgd_iter"], + minibatch_buffer_size=config["minibatch_buffer_size"], + num_aggregation_workers=config["num_aggregation_workers"], + **config["optimizer"]) + + if aggregators: + # Assign the pre-created aggregators to the optimizer + optimizer.aggregator.init(aggregators) + return optimizer + + +class OverrideDefaultResourceRequest(object): @classmethod @override(Trainable) def default_resource_request(cls, config): @@ -143,22 +146,13 @@ def default_resource_request(cls, config): cf["num_aggregation_workers"], extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"]) - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - start = time.time() - self.optimizer.step() - while (time.time() - start < self.config["min_iter_time_s"] - or self.optimizer.num_steps_sampled == prev_steps): - self.optimizer.step() - result = self.collect_metrics() - result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - - prev_steps) - return result - - def _get_policy(self): - if self.config["vtrace"]: - policy_cls = self._policy - else: - policy_cls = A3CTFPolicy - return policy_cls + +ImpalaTrainer = build_trainer( + name="IMPALA", + default_config=DEFAULT_CONFIG, + default_policy=VTraceTFPolicy, + validate_config=validate_config, + get_policy_class=choose_policy, + make_workers=defer_make_workers, + make_policy_optimizer=make_aggregators_and_optimizer, + mixins=[OverrideDefaultResourceRequest]) diff --git a/python/ray/rllib/agents/marwil/marwil.py b/python/ray/rllib/agents/marwil/marwil.py index b8e01806ca29..29be38a84c32 100644 --- a/python/ray/rllib/agents/marwil/marwil.py +++ b/python/ray/rllib/agents/marwil/marwil.py @@ -2,10 +2,10 @@ from __future__ import division from __future__ import print_function -from ray.rllib.agents.trainer import Trainer, with_common_config +from ray.rllib.agents.trainer import with_common_config +from ray.rllib.agents.trainer_template import build_trainer from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy from ray.rllib.optimizers import SyncBatchReplayOptimizer -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -39,30 +39,17 @@ # yapf: enable -class MARWILTrainer(Trainer): - """MARWIL implementation in TensorFlow.""" +def make_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["replay_buffer_size"], + train_batch_size=config["train_batch_size"], + ) - _name = 
"MARWIL" - _default_config = DEFAULT_CONFIG - _policy = MARWILPolicy - @override(Trainer) - def _init(self, config, env_creator): - self.workers = self._make_workers(env_creator, self._policy, config, - config["num_workers"]) - self.optimizer = SyncBatchReplayOptimizer( - self.workers, - learning_starts=config["learning_starts"], - buffer_size=config["replay_buffer_size"], - train_batch_size=config["train_batch_size"], - ) - - @override(Trainer) - def _train(self): - prev_steps = self.optimizer.num_steps_sampled - fetches = self.optimizer.step() - res = self.collect_metrics() - res.update( - timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, - info=dict(fetches, **res.get("info", {}))) - return res +MARWILTrainer = build_trainer( + name="MARWIL", + default_config=DEFAULT_CONFIG, + default_policy=MARWILPolicy, + make_policy_optimizer=make_optimizer) diff --git a/python/ray/rllib/agents/ppo/appo.py b/python/ray/rllib/agents/ppo/appo.py index 0438b2714221..4b0d9945dec3 100644 --- a/python/ray/rllib/agents/ppo/appo.py +++ b/python/ray/rllib/agents/ppo/appo.py @@ -5,7 +5,6 @@ from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy from ray.rllib.agents.trainer import with_base_config from ray.rllib.agents import impala -from ray.rllib.utils.annotations import override # yapf: disable # __sphinx_doc_begin__ @@ -51,14 +50,8 @@ # __sphinx_doc_end__ # yapf: enable - -class APPOTrainer(impala.ImpalaTrainer): - """PPO surrogate loss with IMPALA-architecture.""" - - _name = "APPO" - _default_config = DEFAULT_CONFIG - _policy = AsyncPPOTFPolicy - - @override(impala.ImpalaTrainer) - def _get_policy(self): - return AsyncPPOTFPolicy +APPOTrainer = impala.ImpalaTrainer.with_updates( + name="APPO", + default_config=DEFAULT_CONFIG, + default_policy=AsyncPPOTFPolicy, + get_policy_class=lambda _: AsyncPPOTFPolicy) diff --git a/python/ray/rllib/agents/qmix/apex.py b/python/ray/rllib/agents/qmix/apex.py index 65c91d655af2..aac5d83f726a 100644 --- a/python/ray/rllib/agents/qmix/apex.py +++ b/python/ray/rllib/agents/qmix/apex.py @@ -4,15 +4,14 @@ from __future__ import division from __future__ import print_function +from ray.rllib.agents.dqn.apex import APEX_TRAINER_PROPERTIES from ray.rllib.agents.qmix.qmix import QMixTrainer, \ DEFAULT_CONFIG as QMIX_CONFIG -from ray.rllib.utils.annotations import override from ray.rllib.utils import merge_dicts APEX_QMIX_DEFAULT_CONFIG = merge_dicts( QMIX_CONFIG, # see also the options in qmix.py, which are also supported { - "optimizer_class": "AsyncReplayOptimizer", "optimizer": merge_dicts( QMIX_CONFIG["optimizer"], { @@ -34,23 +33,7 @@ }, ) - -class ApexQMixTrainer(QMixTrainer): - """QMIX variant that uses the Ape-X distributed policy optimizer. - - By default, this is configured for a large single node (32 cores). For - running in a large cluster, increase the `num_workers` config var. 
- """ - - _name = "APEX_QMIX" - _default_config = APEX_QMIX_DEFAULT_CONFIG - - @override(QMixTrainer) - def update_target_if_needed(self): - # Ape-X updates based on num steps trained, not sampled - if self.optimizer.num_steps_trained - self.last_target_update_ts > \ - self.config["target_network_update_freq"]: - self.workers.local_worker().foreach_trainable_policy( - lambda p, _: p.update_target()) - self.last_target_update_ts = self.optimizer.num_steps_trained - self.num_target_updates += 1 +ApexQMixTrainer = QMixTrainer.with_updates( + name="APEX_QMIX", + default_config=APEX_QMIX_DEFAULT_CONFIG, + **APEX_TRAINER_PROPERTIES) diff --git a/python/ray/rllib/agents/qmix/qmix.py b/python/ray/rllib/agents/qmix/qmix.py index 2ad6a3e56f95..6a5bff9d63e8 100644 --- a/python/ray/rllib/agents/qmix/qmix.py +++ b/python/ray/rllib/agents/qmix/qmix.py @@ -3,8 +3,9 @@ from __future__ import print_function from ray.rllib.agents.trainer import with_common_config -from ray.rllib.agents.dqn.dqn import DQNTrainer +from ray.rllib.agents.dqn.dqn import GenericOffPolicyTrainer from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy +from ray.rllib.optimizers import SyncBatchReplayOptimizer # yapf: disable # __sphinx_doc_begin__ @@ -71,8 +72,6 @@ # to increase if your environment is particularly slow to sample, or if # you"re using the Async or Ape-X optimizers. "num_workers": 0, - # Optimizer class to use. - "optimizer_class": "SyncBatchReplayOptimizer", # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. @@ -90,12 +89,16 @@ # yapf: enable -class QMixTrainer(DQNTrainer): - """QMix implementation in PyTorch.""" +def make_sync_batch_optimizer(workers, config): + return SyncBatchReplayOptimizer( + workers, + learning_starts=config["learning_starts"], + buffer_size=config["buffer_size"], + train_batch_size=config["train_batch_size"]) - _name = "QMIX" - _default_config = DEFAULT_CONFIG - _policy = QMixTorchPolicy - _optimizer_shared_configs = [ - "learning_starts", "buffer_size", "train_batch_size" - ] + +QMixTrainer = GenericOffPolicyTrainer.with_updates( + name="QMIX", + default_config=DEFAULT_CONFIG, + default_policy=QMixTorchPolicy, + make_policy_optimizer=make_sync_batch_optimizer) diff --git a/python/ray/rllib/agents/trainer.py b/python/ray/rllib/agents/trainer.py index a0d48d2ef714..8e123cb01458 100644 --- a/python/ray/rllib/agents/trainer.py +++ b/python/ray/rllib/agents/trainer.py @@ -189,6 +189,9 @@ "remote_env_batch_wait_ms": 0, # Minimum time per iteration "min_iter_time_s": 0, + # Minimum env steps to optimize for per train call. This value does + # not affect learning, only the length of iterations. 
+ "timesteps_per_iteration": 0, # === Offline Datasets === # Specify how to generate experiences: @@ -502,6 +505,7 @@ def _evaluate(self): logger.info("Evaluating current policy for {} episodes".format( self.config["evaluation_num_episodes"])) + self._before_evaluate() self.evaluation_workers.local_worker().restore( self.workers.local_worker().save()) for _ in range(self.config["evaluation_num_episodes"]): @@ -510,6 +514,11 @@ def _evaluate(self): metrics = collect_metrics(self.evaluation_workers.local_worker()) return {"evaluation": metrics} + @DeveloperAPI + def _before_evaluate(self): + """Pre-evaluation callback.""" + pass + @PublicAPI def compute_action(self, observation, diff --git a/python/ray/rllib/agents/trainer_template.py b/python/ray/rllib/agents/trainer_template.py index 6af9e1c781e0..ee0b4181c337 100644 --- a/python/ray/rllib/agents/trainer_template.py +++ b/python/ray/rllib/agents/trainer_template.py @@ -6,6 +6,7 @@ from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG from ray.rllib.optimizers import SyncSamplesOptimizer +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -13,25 +14,47 @@ def build_trainer(name, default_policy, default_config=None, - make_policy_optimizer=None, validate_config=None, + get_initial_state=None, get_policy_class=None, + before_init=None, + make_workers=None, + make_policy_optimizer=None, + after_init=None, before_train_step=None, after_optimizer_step=None, - after_train_result=None): + after_train_result=None, + collect_metrics_fn=None, + before_evaluate_fn=None, + mixins=None): """Helper function for defining a custom trainer. + Functions will be run in this order to initialize the trainer: + 1. Config setup: validate_config, get_initial_state, get_policy + 2. Worker setup: before_init, make_workers, make_policy_optimizer + 3. Post setup: after_init + Arguments: name (str): name of the trainer (e.g., "PPO") default_policy (cls): the default Policy class to use default_config (dict): the default config dict of the algorithm, otherwises uses the Trainer default config - make_policy_optimizer (func): optional function that returns a - PolicyOptimizer instance given (WorkerSet, config) validate_config (func): optional callback that checks a given config for correctness. It may mutate the config as needed. + get_initial_state (func): optional function that returns the initial + state dict given the trainer instance as an argument. The state + dict must be serializable so that it can be checkpointed, and will + be available as the `trainer.state` variable. get_policy_class (func): optional callback that takes a config and returns the policy class to override the default with + before_init (func): optional function to run at the start of trainer + init that takes the trainer instance as argument + make_workers (func): override the method that creates rollout workers. + This takes in (trainer, env_creator, policy, config) as args. + make_policy_optimizer (func): optional function that returns a + PolicyOptimizer instance given (WorkerSet, config) + after_init (func): optional function to run at the end of trainer init + that takes the trainer instance as argument before_train_step (func): optional callback to run before each train() call. It takes the trainer instance as an argument. after_optimizer_step (func): optional callback to run after each @@ -40,27 +63,47 @@ def build_trainer(name, after_train_result (func): optional callback to run at the end of each train() call. 
It takes the trainer instance and result dict as arguments, and may mutate the result dict as needed. + collect_metrics_fn (func): override the method used to collect metrics. + It takes the trainer instance as argumnt. + before_evaluate_fn (func): callback to run before evaluation. This + takes the trainer instance as argument. + mixins (list): list of any class mixins for the returned trainer class. + These mixins will be applied in order and will have higher + precedence than the Trainer class Returns: a Trainer instance that uses the specified args. """ original_kwargs = locals().copy() + base = add_mixins(Trainer, mixins) - class trainer_cls(Trainer): + class trainer_cls(base): _name = name _default_config = default_config or COMMON_CONFIG _policy = default_policy + def __init__(self, config=None, env=None, logger_creator=None): + Trainer.__init__(self, config, env, logger_creator) + def _init(self, config, env_creator): if validate_config: validate_config(config) + if get_initial_state: + self.state = get_initial_state(self) + else: + self.state = {} if get_policy_class is None: policy = default_policy else: policy = get_policy_class(config) - self.workers = self._make_workers(env_creator, policy, config, - self.config["num_workers"]) + if before_init: + before_init(self) + if make_workers: + self.workers = make_workers(self, env_creator, policy, config) + else: + self.workers = self._make_workers(env_creator, policy, config, + self.config["num_workers"]) if make_policy_optimizer: self.optimizer = make_policy_optimizer(self.workers, config) else: @@ -69,6 +112,8 @@ def _init(self, config, env_creator): **{"train_batch_size": config["train_batch_size"]}) self.optimizer = SyncSamplesOptimizer(self.workers, **optimizer_config) + if after_init: + after_init(self) @override(Trainer) def _train(self): @@ -81,20 +126,46 @@ def _train(self): fetches = self.optimizer.step() if after_optimizer_step: after_optimizer_step(self, fetches) - if time.time() - start > self.config["min_iter_time_s"]: + if (time.time() - start >= self.config["min_iter_time_s"] + and self.optimizer.num_steps_sampled - prev_steps >= + self.config["timesteps_per_iteration"]): break - res = self.collect_metrics() + if collect_metrics_fn: + res = collect_metrics_fn(self) + else: + res = self.collect_metrics() res.update( timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=res.get("info", {})) + if after_train_result: after_train_result(self, res) return res + @override(Trainer) + def _before_evaluate(self): + if before_evaluate_fn: + before_evaluate_fn(self) + + def __getstate__(self): + state = Trainer.__getstate__(self) + state.update(self.state) + return state + + def __setstate__(self, state): + Trainer.__setstate__(self, state) + self.state = state + @staticmethod def with_updates(**overrides): + """Build a copy of this trainer with the specified overrides. + + Arguments: + overrides (dict): use this to override any of the arguments + originally passed to build_trainer() for this policy. 
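To make the new template concrete, here is a rough sketch of how build_trainer(), add_mixins(), and with_updates() compose. The trainer names and the mixin are hypothetical, and the PGTFPolicy import path is assumed from elsewhere in this series; this is an illustration under those assumptions, not code from the patch:

    from ray.rllib.agents.trainer_template import build_trainer
    from ray.rllib.agents.pg.pg_policy import PGTFPolicy  # assumed import path

    def validate_config(config):
        # Illustrative config check only.
        if config["train_batch_size"] <= 0:
            raise ValueError("train_batch_size must be positive")

    class LoggingMixin(object):
        # Hypothetical mixin; add_mixins() applies it with higher precedence
        # than the Trainer base class.
        def log_iteration(self):
            print("train() iteration finished")

    # Unspecified hooks fall back to the Trainer defaults (COMMON_CONFIG,
    # SyncSamplesOptimizer, the standard metrics collection, and so on).
    MyPGTrainer = build_trainer(
        name="MyPG",
        default_policy=PGTFPolicy,
        validate_config=validate_config,
        mixins=[LoggingMixin])

    # A variant overrides only the arguments it cares about; everything else
    # is inherited from the original build_trainer() call via original_kwargs.
    MyPGVariant = MyPGTrainer.with_updates(
        name="MyPGVariant",
        before_train_step=lambda trainer: None)
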
+ """ return build_trainer(**dict(original_kwargs, **overrides)) trainer_cls.with_updates = with_updates diff --git a/python/ray/rllib/policy/tf_policy_template.py b/python/ray/rllib/policy/tf_policy_template.py index b7f33fcb0887..37828bfe18b0 100644 --- a/python/ray/rllib/policy/tf_policy_template.py +++ b/python/ray/rllib/policy/tf_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.tf_policy import TFPolicy +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -89,13 +90,7 @@ def build_tf_policy(name, """ original_kwargs = locals().copy() - base = DynamicTFPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(DynamicTFPolicy, mixins) class policy_cls(base): def __init__(self, diff --git a/python/ray/rllib/policy/torch_policy_template.py b/python/ray/rllib/policy/torch_policy_template.py index 1f4185f9c12e..f1b0c0c682d6 100644 --- a/python/ray/rllib/policy/torch_policy_template.py +++ b/python/ray/rllib/policy/torch_policy_template.py @@ -5,6 +5,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.utils import add_mixins from ray.rllib.utils.annotations import override, DeveloperAPI @@ -56,13 +57,7 @@ def build_torch_policy(name, """ original_kwargs = locals().copy() - base = TorchPolicy - while mixins: - - class new_base(mixins.pop(), base): - pass - - base = new_base + base = add_mixins(TorchPolicy, mixins) class policy_cls(base): def __init__(self, obs_space, action_space, config): diff --git a/python/ray/rllib/utils/__init__.py b/python/ray/rllib/utils/__init__.py index aad5590fd097..bde901e22a9c 100644 --- a/python/ray/rllib/utils/__init__.py +++ b/python/ray/rllib/utils/__init__.py @@ -27,6 +27,21 @@ def __init__(self, *args, **kw): return DeprecationWrapper +def add_mixins(base, mixins): + """Returns a new class with mixins applied in priority order.""" + + mixins = list(mixins or []) + + while mixins: + + class new_base(mixins.pop(), base): + pass + + base = new_base + + return base + + def renamed_agent(cls): """Helper class for renaming Agent => Trainer with a warning.""" From a82e8118a0013278c430339bd3ce1ac8f6bc8906 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 7 Jun 2019 21:07:27 -0700 Subject: [PATCH 079/118] Fix resource bookkeeping bug with acquiring unknown resource. (#4945) --- python/ray/tests/test_basic.py | 14 ++++++++++++-- src/ray/raylet/node_manager.cc | 6 +++--- src/ray/raylet/scheduling_resources.cc | 15 +++++++++++---- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 50aeca025362..7f1f78d1b5c4 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1754,7 +1754,7 @@ def f(n): def g(n): time.sleep(n) - time_buffer = 0.5 + time_buffer = 2 start_time = time.time() ray.get([f.remote(0.5), g.remote(0.5)]) @@ -1878,13 +1878,23 @@ def test(self): def test_zero_cpus(shutdown_only): ray.init(num_cpus=0) + # We should be able to execute a task that requires 0 CPU resources. @ray.remote(num_cpus=0) def f(): return 1 - # The task should be able to execute. ray.get(f.remote()) + # We should be able to create an actor that requires 0 CPU resources. 
+ @ray.remote(num_cpus=0) + class Actor(object): + def method(self): + pass + + a = Actor.remote() + x = a.method.remote() + ray.get(x) + def test_zero_cpus_actor(ray_start_cluster): cluster = ray_start_cluster diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 07dca3c7ab32..b710b0873b0c 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1809,9 +1809,9 @@ bool NodeManager::AssignTask(const Task &task) { cluster_resource_map_[my_client_id].Acquire(spec.GetRequiredResources()); if (spec.IsActorCreationTask()) { - // Check that we are not placing an actor creation task on a node with 0 CPUs. - RAY_CHECK(cluster_resource_map_[my_client_id].GetTotalResources().GetResourceMap().at( - kCPU_ResourceLabel) != 0); + // Check that the actor's placement resource requirements are satisfied. + RAY_CHECK(spec.GetRequiredPlacementResources().IsSubset( + cluster_resource_map_[my_client_id].GetTotalResources())); worker->SetLifetimeResourceIds(acquired_resources); } else { worker->SetTaskResourceIds(acquired_resources); diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index 895535a9a7f0..cdc17307755c 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -76,7 +76,11 @@ ResourceSet::ResourceSet() {} ResourceSet::ResourceSet( const std::unordered_map &resource_map) - : resource_capacity_(resource_map) {} + : resource_capacity_(resource_map) { + for (auto const &resource_pair : resource_map) { + RAY_CHECK(resource_pair.second > 0); + } +} ResourceSet::ResourceSet(const std::unordered_map &resource_map) { for (auto const &resource_pair : resource_map) { @@ -169,7 +173,8 @@ void ResourceSet::SubtractResourcesStrict(const ResourceSet &other) { const std::string &resource_label = resource_pair.first; const FractionalResourceQuantity &resource_capacity = resource_pair.second; RAY_CHECK(resource_capacity_.count(resource_label) == 1) - << "Attempt to acquire unknown resource: " << resource_label; + << "Attempt to acquire unknown resource: " << resource_label << " capacity " + << resource_capacity.ToDouble(); resource_capacity_[resource_label] -= resource_capacity; // Ensure that quantity is positive. Note, we have to have the check before @@ -233,8 +238,10 @@ FractionalResourceQuantity ResourceSet::GetResource( const ResourceSet ResourceSet::GetNumCpus() const { ResourceSet cpu_resource_set; - cpu_resource_set.resource_capacity_[kCPU_ResourceLabel] = - GetResource(kCPU_ResourceLabel); + const FractionalResourceQuantity cpu_quantity = GetResource(kCPU_ResourceLabel); + if (cpu_quantity > 0) { + cpu_resource_set.resource_capacity_[kCPU_ResourceLabel] = cpu_quantity; + } return cpu_resource_set; } From 85b82b2454145999d274262aa94ce889b7e82cad Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 7 Jun 2019 23:19:10 -0700 Subject: [PATCH 080/118] Update aws keys for uploading wheels to s3. 
(#4948) --- .travis.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2266833f8414..1ec95c33dfde 100644 --- a/.travis.yml +++ b/.travis.yml @@ -183,9 +183,9 @@ script: deploy: - provider: s3 - access_key_id: AKIAJ2L7XDUSZVTXI5QA + access_key_id: AKIAU6DMUCJUFL3EX3SM secret_access_key: - secure: OS9V8c/fQ9SIOP+Lg2MIz+PtCSKNQVB3mubscDRHKJcCmOp3cB6AKsC/yepbNZvvjDD/ncW2v6KJVsUEneAeDKrZQWSIpNb34yGAvWb7g4xleLxiadNtx6XEzjWaOcg+Y6409e68XeoHq/5ItopWNQ9p9NHXgsoHbZaOurPyHNskNgwBVaObCy+cCak7ifkITDk6cil0OJYnTbOe3NhcU82Fh5BZzS2+G2qNq8tGNcbfINhq0rruWIBuV5WRB/14CmBR+mou74qFSiiodH/MKbOcplx9+BxoOsTnkl7SeyybcK6DX6jxJCuhSBIjct9uT8Qdovv6mzOMkXvLkLKFkHfkTJSGBRIIZEvkPvzhlEriqTcr4tX/MV8HKs/Acz1NnlD0tNEygOr3VaiSLB0dvpz4iCeI9berqSu/jV1VI1X5iVNfChYbOMQ+OYafJMs5WdO60AMWIHy60U511FjAlbS7IubXBjfhoCItIB1xlVNI7FfKaRbNRwP5qvPenB8FUgZpv3UBg5OZDkeBXSNoLydr0w505p6s8Jqnz750TpVYI11fih5D0N3Ea57OwQr9r/rk+Z8aGeTpWj6hIgQiNkrIf2VZnWTApd+utJPw3X3txUEcnOtcdDnMsPuEIeMvIDrrFMRwzClqMNXq9MewU43wp7cCl67YmDBDKubl7Vs= + secure: J1sX71fKFPQhgWzColllxfzcF877ScBZ1cIl71krZ6SO0LKnwsCScpQck5eZOyQo/Iverwye0iKtE87qNsiRi3+V2D9iulSr18T09j7+FjPKfxAmXmjfrNafoMXTDQroSJblCri5vl+DysISPqImJkWTNaYhGJ9QakoSd5djnAopLNWj6PCR3S50baS49+nB5nSIY3jMhtUzlaBdniFPFC81Cxyuafr4pv6McGRfR/dK+ZnPhdGtMnVeIJXB+ooZKQ26mDJKBPka4jm3u1Oa72b/Atu2RO3MwxTg79LTrMxXKh2OcCqhtD2Z3lz1OltvNSunCuwY8AejCJsfSLbM9mGDoz+xhNUWmYNy48YFf+61OY8PXi8S/9Q817yb3GpLbb2l/P+KMgq9eSEiELIOwuYsDxPX5TuAg6dx0wCNgDEBJoThSQjYl6MgJrLrs7p+JBxp3giedHiy0TLa5hCVKTj3euONAXDArYnnT+DvUIOkaeTk5DClRZbZ0sUXhLy//HuT5WJvjFBJJZ0u0f4RLVb5D7DI4uMZr7+yJPDR2AXCyW9YMaBEbmEYbPaKi283jlEyn7R33+AZlnXv0THHwZ4xvjKKG3/fBSXsOUmv5wmUveEqVGDj1mKPGj9NF8iA5qMm2AaZuJpEEBVBZtSlTZt6ZG7rzAJZGNL52t7xuMo= bucket: ray-wheels acl: public_read region: us-west-2 @@ -197,9 +197,9 @@ deploy: all_branches: true condition: $LINUX_WHEELS = 1 || $MAC_WHEELS = 1 - provider: s3 - access_key_id: AKIAJ2L7XDUSZVTXI5QA + access_key_id: AKIAU6DMUCJUFL3EX3SM secret_access_key: - secure: OS9V8c/fQ9SIOP+Lg2MIz+PtCSKNQVB3mubscDRHKJcCmOp3cB6AKsC/yepbNZvvjDD/ncW2v6KJVsUEneAeDKrZQWSIpNb34yGAvWb7g4xleLxiadNtx6XEzjWaOcg+Y6409e68XeoHq/5ItopWNQ9p9NHXgsoHbZaOurPyHNskNgwBVaObCy+cCak7ifkITDk6cil0OJYnTbOe3NhcU82Fh5BZzS2+G2qNq8tGNcbfINhq0rruWIBuV5WRB/14CmBR+mou74qFSiiodH/MKbOcplx9+BxoOsTnkl7SeyybcK6DX6jxJCuhSBIjct9uT8Qdovv6mzOMkXvLkLKFkHfkTJSGBRIIZEvkPvzhlEriqTcr4tX/MV8HKs/Acz1NnlD0tNEygOr3VaiSLB0dvpz4iCeI9berqSu/jV1VI1X5iVNfChYbOMQ+OYafJMs5WdO60AMWIHy60U511FjAlbS7IubXBjfhoCItIB1xlVNI7FfKaRbNRwP5qvPenB8FUgZpv3UBg5OZDkeBXSNoLydr0w505p6s8Jqnz750TpVYI11fih5D0N3Ea57OwQr9r/rk+Z8aGeTpWj6hIgQiNkrIf2VZnWTApd+utJPw3X3txUEcnOtcdDnMsPuEIeMvIDrrFMRwzClqMNXq9MewU43wp7cCl67YmDBDKubl7Vs= + secure: J1sX71fKFPQhgWzColllxfzcF877ScBZ1cIl71krZ6SO0LKnwsCScpQck5eZOyQo/Iverwye0iKtE87qNsiRi3+V2D9iulSr18T09j7+FjPKfxAmXmjfrNafoMXTDQroSJblCri5vl+DysISPqImJkWTNaYhGJ9QakoSd5djnAopLNWj6PCR3S50baS49+nB5nSIY3jMhtUzlaBdniFPFC81Cxyuafr4pv6McGRfR/dK+ZnPhdGtMnVeIJXB+ooZKQ26mDJKBPka4jm3u1Oa72b/Atu2RO3MwxTg79LTrMxXKh2OcCqhtD2Z3lz1OltvNSunCuwY8AejCJsfSLbM9mGDoz+xhNUWmYNy48YFf+61OY8PXi8S/9Q817yb3GpLbb2l/P+KMgq9eSEiELIOwuYsDxPX5TuAg6dx0wCNgDEBJoThSQjYl6MgJrLrs7p+JBxp3giedHiy0TLa5hCVKTj3euONAXDArYnnT+DvUIOkaeTk5DClRZbZ0sUXhLy//HuT5WJvjFBJJZ0u0f4RLVb5D7DI4uMZr7+yJPDR2AXCyW9YMaBEbmEYbPaKi283jlEyn7R33+AZlnXv0THHwZ4xvjKKG3/fBSXsOUmv5wmUveEqVGDj1mKPGj9NF8iA5qMm2AaZuJpEEBVBZtSlTZt6ZG7rzAJZGNL52t7xuMo= bucket: ray-wheels acl: public_read region: us-west-2 From ec8aaf011b92e24b9764f96a093144d71fd3ebcb Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Fri, 7 Jun 2019 
23:20:29 -0700 Subject: [PATCH 081/118] Upload wheels on Travis to branchname/commit_id. (#4949) --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1ec95c33dfde..f1f292c58088 100644 --- a/.travis.yml +++ b/.travis.yml @@ -190,7 +190,7 @@ deploy: acl: public_read region: us-west-2 local_dir: .whl - upload-dir: $TRAVIS_COMMIT + upload-dir: "$TRAVIS_BRANCH/$TRAVIS_COMMIT" skip_cleanup: true on: repo: ray-project/ray From 671c0f769e480e2215f4bd94449342136599f2c0 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Sat, 8 Jun 2019 22:56:00 +0800 Subject: [PATCH 082/118] [Java] Fix serializing issues of `RaySerializer` (#4887) * Fix * Address comment. --- java/dependencies.bzl | 2 +- java/runtime/pom.xml | 2 +- .../java/org/ray/runtime/RayPyActorImpl.java | 4 +++- .../org/ray/api/test/RaySerializerTest.java | 23 +++++++++++++++++++ 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/RaySerializerTest.java diff --git a/java/dependencies.bzl b/java/dependencies.bzl index d0178ba0f8f4..7c716166d399 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -11,7 +11,7 @@ def gen_java_deps(): "com.sun.xml.bind:jaxb-impl:2.3.0", "com.typesafe:config:1.3.2", "commons-io:commons-io:2.5", - "de.ruedigermoeller:fst:2.47", + "de.ruedigermoeller:fst:2.57", "javax.xml.bind:jaxb-api:2.3.0", "org.apache.commons:commons-lang3:3.4", "org.ow2.asm:asm:6.0", diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index 1ce51971c03e..c75e2eeef13f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -54,7 +54,7 @@ de.ruedigermoeller fst - 2.47 + 2.57 org.apache.commons diff --git a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java index 2938478d22e8..f1f26d40874e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayPyActorImpl.java @@ -20,7 +20,9 @@ public class RayPyActorImpl extends RayActorImpl implements RayPyActor { */ private String className; - private RayPyActorImpl() {} + // Note that this empty constructor must be public + // since it'll be needed when deserializing. 
+ public RayPyActorImpl() {} public RayPyActorImpl(UniqueId id, String moduleName, String className) { super(id); diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java new file mode 100644 index 000000000000..33283abc7a36 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -0,0 +1,23 @@ +package org.ray.api.test; + +import org.ray.api.RayPyActor; +import org.ray.api.id.UniqueId; +import org.ray.runtime.RayPyActorImpl; +import org.ray.runtime.util.Serializer; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RaySerializerTest { + + @Test + public void testSerializePyActor() { + final UniqueId pyActorId = UniqueId.randomId(); + RayPyActor pyActor = new RayPyActorImpl(pyActorId, "test", "RaySerializerTest"); + byte[] bytes = Serializer.encode(pyActor); + RayPyActor result = Serializer.decode(bytes); + Assert.assertEquals(result.getId(), pyActorId); + Assert.assertEquals(result.getModuleName(), "test"); + Assert.assertEquals(result.getClassName(), "RaySerializerTest"); + } + +} From 4f8e100fe0417da4fe1098defbfa478088502244 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 9 Jun 2019 19:20:55 -0700 Subject: [PATCH 083/118] fix (#4950) --- python/ray/rllib/examples/saving_experiences.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/ray/rllib/examples/saving_experiences.py b/python/ray/rllib/examples/saving_experiences.py index 7a29b0fe7b0d..d2de88302d23 100644 --- a/python/ray/rllib/examples/saving_experiences.py +++ b/python/ray/rllib/examples/saving_experiences.py @@ -7,6 +7,7 @@ import gym import numpy as np +from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder from ray.rllib.offline.json_writer import JsonWriter @@ -18,6 +19,12 @@ # simulator is available, but let's do it anyways for example purposes: env = gym.make("CartPole-v0") + # RLlib uses preprocessors to implement transforms such as one-hot encoding + # and flattening of tuple and dict observations. For CartPole a no-op + # preprocessor is used, but this may be relevant for more complex envs. + prep = get_preprocessor(env.observation_space)(env.observation_space) + print("The preprocessor is", prep) + for eps_id in range(100): obs = env.reset() prev_action = np.zeros_like(env.action_space.sample()) @@ -31,7 +38,7 @@ t=t, eps_id=eps_id, agent_index=0, - obs=obs, + obs=prep.transform(obs), actions=action, action_prob=1.0, # put the true action probability here rewards=rew, @@ -39,7 +46,7 @@ prev_rewards=prev_reward, dones=done, infos=info, - new_obs=new_obs) + new_obs=prep.transform(new_obs)) obs = new_obs prev_action = action prev_reward = rew From e6baffba563d241bfb1a5fa0ad7d43e92db85587 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Mon, 10 Jun 2019 23:52:08 +0800 Subject: [PATCH 084/118] [Java] Add inner class `Builder` to build call options. 
(#4956) * Add Builder class * format * Refactor by IDE * Remove uncessary dependency --- .../ray/api/options/ActorCreationOptions.java | 34 +++++++++++++------ .../java/org/ray/api/options/CallOptions.java | 23 ++++++++++--- .../ray/api/test/ActorReconstructionTest.java | 8 ++--- .../org/ray/api/test/DynamicResourceTest.java | 3 +- .../java/org/ray/api/test/HelloWorldTest.java | 1 + .../ray/api/test/ResourcesManagementTest.java | 18 +++++----- 6 files changed, 59 insertions(+), 28 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index e4f54f0094c4..d1e92f7bb9e9 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -1,5 +1,6 @@ package org.ray.api.options; +import java.util.HashMap; import java.util.Map; /** @@ -12,19 +13,32 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - public ActorCreationOptions() { - super(); - this.maxReconstructions = NO_RECONSTRUCTION; - } - - public ActorCreationOptions(Map resources) { + private ActorCreationOptions(Map resources, int maxReconstructions) { super(resources); - this.maxReconstructions = NO_RECONSTRUCTION; + this.maxReconstructions = maxReconstructions; } + /** + * The inner class for building ActorCreationOptions. + */ + public static class Builder { - public ActorCreationOptions(Map resources, int maxReconstructions) { - super(resources); - this.maxReconstructions = maxReconstructions; + private Map resources = new HashMap<>(); + private int maxReconstructions = NO_RECONSTRUCTION; + + public Builder setResources(Map resources) { + this.resources = resources; + return this; + } + + public Builder setMaxReconstructions(int maxReconstructions) { + this.maxReconstructions = maxReconstructions; + return this; + } + + public ActorCreationOptions createActorCreationOptions() { + return new ActorCreationOptions(resources, maxReconstructions); + } } + } diff --git a/java/api/src/main/java/org/ray/api/options/CallOptions.java b/java/api/src/main/java/org/ray/api/options/CallOptions.java index 84adfc122e04..1e5b61bf16d3 100644 --- a/java/api/src/main/java/org/ray/api/options/CallOptions.java +++ b/java/api/src/main/java/org/ray/api/options/CallOptions.java @@ -1,5 +1,6 @@ package org.ray.api.options; +import java.util.HashMap; import java.util.Map; /** @@ -7,12 +8,24 @@ */ public class CallOptions extends BaseTaskOptions { - public CallOptions() { - super(); - } - - public CallOptions(Map resources) { + private CallOptions(Map resources) { super(resources); } + /** + * This inner class for building CallOptions. 
+ */ + public static class Builder { + + private Map resources = new HashMap<>(); + + public Builder setResources(Map resources) { + this.resources = resources; + return this; + } + + public CallOptions createCallOptions() { + return new CallOptions(resources); + } + } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index e575daa84f13..149c87f55931 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -3,7 +3,6 @@ import static org.ray.runtime.util.SystemUtil.pid; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.concurrent.TimeUnit; import org.ray.api.Checkpointable; @@ -47,7 +46,8 @@ public int getPid() { @Test public void testActorReconstruction() throws InterruptedException, IOException { TestUtils.skipTestUnderSingleProcess(); - ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + ActorCreationOptions options = + new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(Counter::new, options); // Call increase 3 times. for (int i = 0; i < 3; i++) { @@ -127,8 +127,8 @@ public void checkpointExpired(UniqueId actorId, UniqueId checkpointId) { @Test public void testActorCheckpointing() throws IOException, InterruptedException { TestUtils.skipTestUnderSingleProcess(); - - ActorCreationOptions options = new ActorCreationOptions(new HashMap<>(), 1); + ActorCreationOptions options = + new ActorCreationOptions.Builder().setMaxReconstructions(1).createActorCreationOptions(); RayActor actor = Ray.createActor(CheckpointableCounter::new, options); // Call increase 3 times. 
for (int i = 0; i < 3; i++) { diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index ffda0732287e..79b3eba0ed13 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,7 +23,8 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); - CallOptions op1 = new CallOptions(ImmutableMap.of("A", 10.0)); + CallOptions op1 = + new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); WaitResult result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 0); diff --git a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java index feb07fe2cd42..04883bdf8673 100644 --- a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java +++ b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java @@ -33,4 +33,5 @@ public void testHelloWorld() { String helloWorld = Ray.call(HelloWorldTest::merge, hello, world).get(); Assert.assertEquals("hello,world!", helloWorld); } + } diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index c3d0e4152e5a..dca559764b87 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -46,14 +46,16 @@ public Integer echo(Integer number) { @Test public void testMethods() { TestUtils.skipTestUnderSingleProcess(); - CallOptions callOptions1 = new CallOptions(ImmutableMap.of("CPU", 4.0)); + CallOptions callOptions1 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 4.0)).createCallOptions(); // This is a case that can satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". RayObject result1 = Ray.call(ResourcesManagementTest::echo, 100, callOptions1); Assert.assertEquals(100, (int) result1.get()); - CallOptions callOptions2 = new CallOptions(ImmutableMap.of("CPU", 4.0)); + CallOptions callOptions2 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 4.0)).createCallOptions(); // This is a case that can't satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". @@ -64,7 +66,8 @@ public void testMethods() { Assert.assertEquals(0, waitResult.getUnready().size()); try { - CallOptions callOptions3 = new CallOptions(ImmutableMap.of("CPU", 0.0)); + CallOptions callOptions3 = + new CallOptions.Builder().setResources(ImmutableMap.of("CPU", 0.0)).createCallOptions(); Assert.fail(); } catch (RuntimeException e) { // We should receive a RuntimeException indicates that we should not @@ -76,9 +79,8 @@ public void testMethods() { public void testActors() { TestUtils.skipTestUnderSingleProcess(); - ActorCreationOptions actorCreationOptions1 = - new ActorCreationOptions(ImmutableMap.of("CPU", 2.0)); - + ActorCreationOptions actorCreationOptions1 = new ActorCreationOptions.Builder() + .setResources(ImmutableMap.of("CPU", 2.0)).createActorCreationOptions(); // This is a case that can satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". 
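    // The builder-created options above request 2.0 of the 4 available CPUs,
    // so this creation is expected to succeed.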
RayActor echo1 = Ray.createActor(Echo::new, actorCreationOptions1); @@ -87,8 +89,8 @@ public void testActors() { // This is a case that can't satisfy required resources. // The static resources for test are "CPU:4,RES-A:4". - ActorCreationOptions actorCreationOptions2 = - new ActorCreationOptions(ImmutableMap.of("CPU", 8.0)); + ActorCreationOptions actorCreationOptions2 = new ActorCreationOptions.Builder() + .setResources(ImmutableMap.of("CPU", 8.0)).createActorCreationOptions(); RayActor echo2 = Ray.createActor(Echo::new, actorCreationOptions2); From 6f4899232280797eeb8d8cdee74e9f06717525e4 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 10 Jun 2019 23:04:01 -0700 Subject: [PATCH 085/118] Make release stress tests work and improve them. (#4955) --- .../application_cluster_template.yaml | 6 +- .../run_application_stress_tests.sh | 88 +++++++---- ci/stress_tests/run_stress_tests.sh | 47 ++++-- ci/stress_tests/stress_testing_config.yaml | 2 +- ...ks_and_transfers.py => test_many_tasks.py} | 0 dev/RELEASE_PROCESS.rst | 143 +++++++++--------- python/ray/autoscaler/updater.py | 7 +- 7 files changed, 168 insertions(+), 125 deletions(-) rename ci/stress_tests/{test_many_tasks_and_transfers.py => test_many_tasks.py} (100%) diff --git a/ci/stress_tests/application_cluster_template.yaml b/ci/stress_tests/application_cluster_template.yaml index d6ccf4769b04..9218c2cf7356 100644 --- a/ci/stress_tests/application_cluster_template.yaml +++ b/ci/stress_tests/application_cluster_template.yaml @@ -37,7 +37,7 @@ provider: # Availability zone(s), comma-separated, that nodes may be launched in. # Nodes are currently spread between zones by a round-robin approach, # however this implementation detail should not be relied upon. - availability_zone: us-west-2a,us-west-2b + availability_zone: us-west-2b # How Ray will authenticate with newly launched nodes. auth: @@ -90,8 +90,8 @@ file_mounts: { # List of shell commands to run to set up nodes. setup_commands: - echo 'export PATH="$HOME/anaconda3/envs/tensorflow_<<>>/bin:$PATH"' >> ~/.bashrc - - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-<<>>-manylinux1_x86_64.whl - - rllib || pip install -U ray-0.8.0.dev1-<<>>-manylinux1_x86_64.whl[rllib] + - ray || wget https://s3-us-west-2.amazonaws.com/ray-wheels/releases/<<>>/<<>>/ray-<<>>-<<>>-manylinux1_x86_64.whl + - rllib || pip install -U ray-<<>>-<<>>-manylinux1_x86_64.whl[rllib] - pip install tensorflow-gpu==1.12.0 - echo "sudo halt" | at now + 60 minutes # Consider uncommenting these if you also want to run apt-get commands during setup diff --git a/ci/stress_tests/run_application_stress_tests.sh b/ci/stress_tests/run_application_stress_tests.sh index a8ded40fa797..293530928745 100755 --- a/ci/stress_tests/run_application_stress_tests.sh +++ b/ci/stress_tests/run_application_stress_tests.sh @@ -1,4 +1,11 @@ #!/usr/bin/env bash + +# This script should be run as follows: +# ./run_application_stress_tests.sh +# For example, might be 0.7.1 +# and might be bc3b6efdb6933d410563ee70f690855c05f25483. The commit +# should be the latest commit on the branch "releases/". + # This script runs all of the application tests. # Currently includes an IMPALA stress test and a SGD stress test. # on both Python 2.7 and 3.6. @@ -10,26 +17,39 @@ # This script will exit with code 1 if the test did not run successfully. +# Show explicitly which commands are currently running. This should only be AFTER +# the private key is placed. 
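+#
+# A concrete invocation, using the example version and commit from the header
+# comment above, would be:
+#   ./run_application_stress_tests.sh 0.7.1 bc3b6efdb6933d410563ee70f690855c05f25483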
+set -x ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) RESULT_FILE=$ROOT_DIR/"results-$(date '+%Y-%m-%d_%H-%M-%S').log" -echo "Logging to" $RESULT_FILE -echo -e $RAY_AWS_SSH_KEY > /root/.ssh/ray-autoscaler_us-west-2.pem && chmod 400 /root/.ssh/ray-autoscaler_us-west-2.pem || true +touch "$RESULT_FILE" +echo "Logging to" "$RESULT_FILE" +if [[ -z "$1" ]]; then + echo "ERROR: The first argument must be the Ray version string." + exit 1 +else + RAY_VERSION=$1 +fi -# Show explicitly which commands are currently running. This should only be AFTER -# the private key is placed. -set -x +if [[ -z "$2" ]]; then + echo "ERROR: The second argument must be the commit hash to test." + exit 1 +else + RAY_COMMIT=$2 +fi -touch $RESULT_FILE +echo "Testing ray==$RAY_VERSION at commit $RAY_COMMIT." +echo "The wheels used will live under https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_COMMIT/" # This function identifies the right string for the Ray wheel. _find_wheel_str(){ local python_version=$1 # echo "PYTHON_VERSION", $python_version local wheel_str="" - if [ $python_version == "p27" ]; then + if [ "$python_version" == "p27" ]; then wheel_str="cp27-cp27mu" else wheel_str="cp36-cp36m" @@ -41,7 +61,7 @@ _find_wheel_str(){ # Actual test runtime is roughly 10 minutes. test_impala(){ local PYTHON_VERSION=$1 - local WHEEL_STR=$(_find_wheel_str $PYTHON_VERSION) + local WHEEL_STR=$(_find_wheel_str "$PYTHON_VERSION") pushd "$ROOT_DIR" local TEST_NAME="rllib_impala_$PYTHON_VERSION" @@ -50,32 +70,34 @@ test_impala(){ cat application_cluster_template.yaml | sed -e " + s/<<>>/$RAY_VERSION/g; + s/<<>>/$RAY_COMMIT/; s/<<>>/$TEST_NAME/; - s/<<>>/g3.16xlarge/; + s/<<>>/p3.16xlarge/; s/<<>>/m5.24xlarge/; s/<<>>/5/; s/<<>>/5/; s/<<>>/$PYTHON_VERSION/; - s/<<>>/$WHEEL_STR/;" > $CLUSTER + s/<<>>/$WHEEL_STR/;" > "$CLUSTER" echo "Try running IMPALA stress test." { RLLIB_DIR=../../python/ray/rllib/ - ray --logging-level=DEBUG up -y $CLUSTER && - ray rsync_up $CLUSTER $RLLIB_DIR/tuned_examples/ tuned_examples/ && + ray --logging-level=DEBUG up -y "$CLUSTER" && + ray rsync_up "$CLUSTER" $RLLIB_DIR/tuned_examples/ tuned_examples/ && sleep 1 && - ray --logging-level=DEBUG exec $CLUSTER "rllib || true" && - ray --logging-level=DEBUG exec $CLUSTER " + ray --logging-level=DEBUG exec "$CLUSTER" "rllib || true" && + ray --logging-level=DEBUG exec "$CLUSTER" " rllib train -f tuned_examples/atari-impala-large.yaml --redis-address='localhost:6379' --queue-trials" && - echo "PASS: IMPALA Test for" $PYTHON_VERSION >> $RESULT_FILE - } || echo "FAIL: IMPALA Test for" $PYTHON_VERSION >> $RESULT_FILE + echo "PASS: IMPALA Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" + } || echo "FAIL: IMPALA Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" # Tear down cluster. if [ "$DEBUG_MODE" = "" ]; then - ray down -y $CLUSTER - rm $CLUSTER + ray down -y "$CLUSTER" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi popd } @@ -93,32 +115,34 @@ test_sgd(){ cat application_cluster_template.yaml | sed -e " + s/<<>>/$RAY_VERSION/g; + s/<<>>/$RAY_COMMIT/; s/<<>>/$TEST_NAME/; - s/<<>>/g3.16xlarge/; - s/<<>>/g3.16xlarge/; + s/<<>>/p3.16xlarge/; + s/<<>>/p3.16xlarge/; s/<<>>/3/; s/<<>>/3/; s/<<>>/$PYTHON_VERSION/; - s/<<>>/$WHEEL_STR/;" > $CLUSTER + s/<<>>/$WHEEL_STR/;" > "$CLUSTER" echo "Try running SGD stress test." 
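  # The brace group below runs cluster bring-up, file sync, and the test command
  # as one unit; if any step fails, the || branch records a FAIL line in the
  # results file instead of aborting the whole script.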
{ SGD_DIR=$ROOT_DIR/../../python/ray/experimental/sgd/ - ray --logging-level=DEBUG up -y $CLUSTER && + ray --logging-level=DEBUG up -y "$CLUSTER" && # TODO: fix submit so that args work - ray rsync_up $CLUSTER $SGD_DIR/mnist_example.py mnist_example.py && + ray rsync_up "$CLUSTER" "$SGD_DIR/mnist_example.py" mnist_example.py && sleep 1 && - ray --logging-level=DEBUG exec $CLUSTER " + ray --logging-level=DEBUG exec "$CLUSTER" " python mnist_example.py --redis-address=localhost:6379 --num-iters=2000 --num-workers=8 --devices-per-worker=2 --gpu" && - echo "PASS: SGD Test for" $PYTHON_VERSION >> $RESULT_FILE - } || echo "FAIL: SGD Test for" $PYTHON_VERSION >> $RESULT_FILE + echo "PASS: SGD Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" + } || echo "FAIL: SGD Test for" "$PYTHON_VERSION" >> "$RESULT_FILE" # Tear down cluster. if [ "$DEBUG_MODE" = "" ]; then - ray down -y $CLUSTER - rm $CLUSTER + ray down -y "$CLUSTER" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi popd } @@ -130,6 +154,6 @@ do test_sgd $PYTHON_VERSION done -cat $RESULT_FILE -cat $RESULT_FILE | grep FAIL > test.log +cat "$RESULT_FILE" +cat "$RESULT_FILE" | grep FAIL > test.log [ ! -s test.log ] || exit 1 diff --git a/ci/stress_tests/run_stress_tests.sh b/ci/stress_tests/run_stress_tests.sh index 1d4d102092ee..f92e8c592d40 100755 --- a/ci/stress_tests/run_stress_tests.sh +++ b/ci/stress_tests/run_stress_tests.sh @@ -1,40 +1,61 @@ #!/usr/bin/env bash +# Show explicitly which commands are currently running. +set -x + ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd) RESULT_FILE=$ROOT_DIR/results-$(date '+%Y-%m-%d_%H-%M-%S').log -echo "Logging to" $RESULT_FILE -echo -e $RAY_AWS_SSH_KEY > /root/.ssh/ray-autoscaler_us-west-2.pem && chmod 400 /root/.ssh/ray-autoscaler_us-west-2.pem || true +touch "$RESULT_FILE" +echo "Logging to" "$RESULT_FILE" -# Show explicitly which commands are currently running. This should only be AFTER -# the private key is placed. -set -x +if [[ -z "$1" ]]; then + echo "ERROR: The first argument must be the Ray version string." + exit 1 +else + RAY_VERSION=$1 +fi -touch $RESULT_FILE +if [[ -z "$2" ]]; then + echo "ERROR: The second argument must be the commit hash to test." + exit 1 +else + RAY_COMMIT=$2 +fi + +echo "Testing ray==$RAY_VERSION at commit $RAY_COMMIT." +echo "The wheels used will live under https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_COMMIT/" run_test(){ local test_name=$1 - local CLUSTER="stress_testing_config.yaml" + local CLUSTER="stress_testing_config_temporary.yaml" + + cat stress_testing_config.yaml | + sed -e " + s/<<>>/$RAY_VERSION/g; + s/<<>>/$RAY_COMMIT/;" > "$CLUSTER" + echo "Try running $test_name." { ray up -y $CLUSTER --cluster-name "$test_name" && sleep 1 && - ray --logging-level=DEBUG submit $CLUSTER --cluster-name "$test_name" "$test_name.py" - } || echo "FAIL: $test_name" >> $RESULT_FILE + ray --logging-level=DEBUG submit "$CLUSTER" --cluster-name "$test_name" "$test_name.py" + } || echo "FAIL: $test_name" >> "$RESULT_FILE" # Tear down cluster. if [ "$DEBUG_MODE" = "" ]; then ray down -y $CLUSTER --cluster-name "$test_name" + rm "$CLUSTER" else - echo "Not tearing down cluster" $CLUSTER + echo "Not tearing down cluster" "$CLUSTER" fi } pushd "$ROOT_DIR" - run_test test_many_tasks_and_transfers + run_test test_many_tasks run_test test_dead_actors popd -cat $RESULT_FILE -[ ! -s $RESULT_FILE ] || exit 1 +cat "$RESULT_FILE" +[ ! 
-s "$RESULT_FILE" ] || exit 1 diff --git a/ci/stress_tests/stress_testing_config.yaml b/ci/stress_tests/stress_testing_config.yaml index 793c1338432d..ae878963094f 100644 --- a/ci/stress_tests/stress_testing_config.yaml +++ b/ci/stress_tests/stress_testing_config.yaml @@ -101,7 +101,7 @@ setup_commands: # - ray/ci/travis/install-bazel.sh - pip install boto3==1.4.8 cython==0.29.0 # - cd ray/python; git checkout master; git pull; pip install -e . --verbose - - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl + - pip install https://s3-us-west-2.amazonaws.com/ray-wheels/releases/<<>>/<<>>/ray-<<>>-cp36-cp36m-manylinux1_x86_64.whl - echo "sudo halt" | at now + 60 minutes # Custom commands that will be run on the head node after common setup. diff --git a/ci/stress_tests/test_many_tasks_and_transfers.py b/ci/stress_tests/test_many_tasks.py similarity index 100% rename from ci/stress_tests/test_many_tasks_and_transfers.py rename to ci/stress_tests/test_many_tasks.py diff --git a/dev/RELEASE_PROCESS.rst b/dev/RELEASE_PROCESS.rst index 62862506e1ed..3b78cef5eda5 100644 --- a/dev/RELEASE_PROCESS.rst +++ b/dev/RELEASE_PROCESS.rst @@ -6,38 +6,45 @@ This document describes the process for creating new releases. 1. **Increment the Python version:** Create a PR that increments the Python package version. See `this example`_. -2. **Download the Travis-built wheels:** Once Travis has completed the tests, - the wheels from this commit can be downloaded from S3 to do testing, etc. - The URL is structured like this: - ``https://s3-us-west-2.amazonaws.com/ray-wheels//`` - where ```` is replaced by the ID of the commit and the ```` - is the incremented version from the previous step. The ```` can - be determined by looking at the OS/Version matrix in the documentation_. - -3. **Create a release branch:** This branch should also have the same commit ID as the - previous two steps. In order to create the branch, locally checkout the commit ID - i.e. ``git checkout ``. Then checkout a new branch of the format - ``releases/``. The release number must match the increment in - the first step. Then push that branch to the ray repo: - ``git push upstream releases/``. +2. **Bump version on Ray master branch again:** Create a pull request to + increment the version of the master branch. The format of the new version is + as follows: + + New minor release (e.g., 0.7.0): Increment the minor version and append + ``.dev0`` to the version. For example, if the version of the new release is + 0.7.0, the master branch needs to be updated to 0.8.0.dev0. + + New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the + number after ``dev`` equals the micro version. For example, if the version + of the new release is 0.7.1, the master branch needs to be updated to + 0.8.0.dev1. + + This can be merged as soon as step 1 is complete. + +3. **Create a release branch:** Create the branch from the version bump PR (the + one from step 1, not step 2). In order to create the branch, locally checkout + the commit ID i.e., ``git checkout ``. Then checkout a new branch of + the format ``releases/``. Then push that branch to the ray + repo: ``git push upstream releases/``. 4. **Testing:** Before a release is created, significant testing should be done. - Run the scripts `ci/stress_tests/run_stress_tests.sh`_ and - `ci/stress_tests/run_application_stress_tests.sh`_ and make sure they - pass. 
You **MUST** modify the autoscaler config file and replace - ``<>`` and ``<>`` with the appropriate - values to test the correct wheels. This will use the autoscaler to start a bunch of - machines and run some tests. Any new stress tests should be added to this - script so that they will be run automatically for future release testing. - -5. **Resolve release-blockers:** Should any release blocking issues arise, - there are two ways these issues are resolved: A PR to patch the issue or a - revert commit that removes the breaking change from the release. In the case - of a PR, that PR should be created against master. Once it is merged, the - release master should ``git cherry-pick`` the commit to the release branch. - If the decision is to revert a commit that caused the release blocker, the - release master should ``git revert`` the commit to be reverted on the - release branch. Push these changes directly to the release branch. + Run the following scripts + + .. code-block:: bash + + ray/ci/stress_tests/run_stress_tests.sh + ray/ci/stress_tests/run_application_stress_tests.sh + + and make sure they pass. If they pass, it will be obvious that they passed. + This will use the autoscaler to start a bunch of machines and run some tests. + +5. **Resolve release-blockers:** If a release blocking issue arises, there are + two ways the issue can be resolved: 1) Fix the issue on the master branch and + cherry-pick the relevant commit (using ``git cherry-pick``) onto the release + branch. 2) Revert the commit that introduced the bug on the release branch + (using ``git revert``), but not on the master. + + These changes should then be pushed directly to the release branch. 6. **Download all the wheels:** Now the release is ready to begin final testing. The wheels are automatically uploaded to S3, even on the release @@ -47,20 +54,20 @@ This document describes the process for creating new releases. export RAY_HASH=... # e.g., 618147f57fb40368448da3b2fb4fd213828fa12b export RAY_VERSION=... 
# e.g., 0.7.0 - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-manylinux1_x86_64.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-macosx_10_6_intel.whl - pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27mu-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-manylinux1_x86_64.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp27-cp27m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp35-cp35m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp36-cp36m-macosx_10_6_intel.whl + pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/releases/$RAY_VERSION/$RAY_HASH/ray-$RAY_VERSION-cp37-cp37m-macosx_10_6_intel.whl 7. **Final Testing:** Send a link to the wheels to the other contributors and - core members of the Ray project. Make sure the wheels are tested on Ubuntu, - Mac OSX 10.12, and Mac OSX 10.13+. This testing should verify that the - wheels are correct and that all release blockers have been resolved. Should - a new release blocker be found, repeat steps 5-7. + core members of the Ray project. Make sure the wheels are tested on Ubuntu + and MacOS (ideally multiple versions of Ubuntu and MacOS). This testing + should verify that the wheels are correct and that all release blockers have + been resolved. Should a new release blocker be found, repeat steps 5-7. 8. **Upload to PyPI Test:** Upload the wheels to the PyPI test site using ``twine`` (ask Robert to add you as a maintainer to the PyPI project). You'll @@ -68,11 +75,11 @@ This document describes the process for creating new releases. .. code-block:: bash - twine upload --repository-url https://test.pypi.org/legacy/ray/.whl/* + twine upload --repository-url https://test.pypi.org/legacy/ ray/.whl/* assuming that you've downloaded the wheels from the ``ray-wheels`` S3 bucket and put them in ``ray/.whl``, that you've installed ``twine`` through - ``pip``, and that you've made PyPI accounts. + ``pip``, and that you've created both PyPI accounts. 
Test that you can install the wheels with pip from the PyPI test repository with @@ -86,7 +93,7 @@ This document describes the process for creating new releases. installed by checking ``ray.__version__`` and ``ray.__file__``. Do this at least for MacOS and for Linux, as well as for Python 2 and Python - 3. Also do this for different versions of MacOS. + 3. 9. **Upload to PyPI:** Now that you've tested the wheels on the PyPI test repository, they can be uploaded to the main PyPI repository. Be careful, @@ -107,41 +114,31 @@ This document describes the process for creating new releases. finds the correct Ray version, and successfully runs some simple scripts on both MacOS and Linux as well as Python 2 and Python 3. -10. **Create a GitHub release:** Create a GitHub release through the `GitHub website`_. - The release should be created at the commit from the previous - step. This should include **release notes**. Copy the style and formatting - used by previous releases. Create a draft of the release notes containing - information about substantial changes/updates/bugfixes and their PR number. - Once you have a draft, make sure you solicit feedback from other Ray - developers before publishing. Use the following to get started: +10. **Create a GitHub release:** Create a GitHub release through the + `GitHub website`_. The release should be created at the commit from the + previous step. This should include **release notes**. Copy the style and + formatting used by previous releases. Create a draft of the release notes + containing information about substantial changes/updates/bugfixes and their + PR numbers. Once you have a draft, make sure you solicit feedback from other + Ray developers before publishing. Use the following to get started: .. code-block:: bash git pull origin master --tags git log $(git describe --tags --abbrev=0)..HEAD --pretty=format:"%s" | sort -11. **Bump version on Ray master branch:** Create a pull request to increment the - version of the master branch. The format of the new version is as follows: - - New minor release (e.g., 0.7.0): Increment the minor version and append ``.dev0`` to - the version. For example, if the version of the new release is 0.7.0, the master - branch needs to be updated to 0.8.0.dev0. `Example PR for minor release` - - New micro release (e.g., 0.7.1): Increment the ``dev`` number, such that the number - after ``dev`` equals the micro version. For example, if the version of the new - release is 0.7.1, the master branch needs to be updated to 0.8.0.dev1. +11. **Update version numbers throughout codebase:** Suppose we just released + 0.7.1. The previous release version number (in this case 0.7.0) and the + previous dev version number (in this case 0.8.0.dev0) appear in many places + throughout the code base including the installation documentation, the + example autoscaler config files, and the testing scripts. Search for all of + the occurrences of these version numbers and update them to use the new + release and dev version numbers. **NOTE:** Not all of the version numbers + should be replaced. For example, ``0.7.0`` appears in this file but should + not be updated. -12. **Update version numbers throughout codebase:** Suppose we just released 0.7.1. The - previous release version number (in this case 0.7.0) and the previous dev version number - (in this case 0.8.0.dev0) appear in many places throughout the code base including - the installation documentation, the example autoscaler config files, and the testing - scripts. 
Search for all of the occurrences of these version numbers and update them to - use the new release and dev version numbers. +12. **Improve the release process:** Find some way to improve the release + process so that whoever manages the release next will have an easier time. -.. _documentation: https://ray.readthedocs.io/en/latest/installation.html#trying-snapshots-from-master -.. _`documentation for building wheels`: https://github.com/ray-project/ray/blob/master/python/README-building-wheels.md -.. _`ci/stress_tests/run_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_stress_tests.sh -.. _`ci/stress_tests/run_application_stress_tests.sh`: https://github.com/ray-project/ray/blob/master/ci/stress_tests/run_application_stress_tests.sh .. _`this example`: https://github.com/ray-project/ray/pull/4226 .. _`GitHub website`: https://github.com/ray-project/ray/releases -.. _`Example PR for minor release`: https://github.com/ray-project/ray/pull/4845 diff --git a/python/ray/autoscaler/updater.py b/python/ray/autoscaler/updater.py index c86750fe399d..d42bf041ac8c 100644 --- a/python/ray/autoscaler/updater.py +++ b/python/ray/autoscaler/updater.py @@ -165,9 +165,10 @@ def wait_for_ssh(self, deadline): logger.debug("NodeUpdater: " "{}: Waiting for SSH...".format(self.node_id)) - with open("/dev/null", "w") as redirect: - self.ssh_cmd( - "uptime", connect_timeout=5, redirect=redirect) + # Setting redirect=False allows the user to see errors like + # unix_listener: path "/tmp/rkn_ray_ssh_sockets/..." too long + # for Unix domain socket. + self.ssh_cmd("uptime", connect_timeout=5, redirect=False) return True From 1e2b64958054be4252721fe8d544ab4998468fb7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 10 Jun 2019 23:46:37 -0700 Subject: [PATCH 086/118] Use proper session directory for debug_string.txt (#4960) --- python/ray/services.py | 1 + src/ray/raylet/main.cc | 3 +++ src/ray/raylet/node_manager.cc | 3 ++- src/ray/raylet/node_manager.h | 2 ++ 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/ray/services.py b/python/ray/services.py index 00ae4e1a2b09..2c843f7bbbc7 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1194,6 +1194,7 @@ def start_raylet(redis_address, "--java_worker_command={}".format(java_worker_command), "--redis_password={}".format(redis_password or ""), "--temp_dir={}".format(temp_dir), + "--session_dir={}".format(session_dir), ] process_info = start_ray_process( command, diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index e75981d8b752..c6e581cec9b7 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -22,6 +22,7 @@ DEFINE_string(python_worker_command, "", "Python worker command."); DEFINE_string(java_worker_command, "", "Java worker command."); DEFINE_string(redis_password, "", "The password of redis."); DEFINE_string(temp_dir, "", "Temporary directory."); +DEFINE_string(session_dir, "", "The path of this ray session directory."); DEFINE_bool(disable_stats, false, "Whether disable the stats."); DEFINE_string(stat_address, "127.0.0.1:8888", "The address that we report metrics to."); DEFINE_bool(enable_stdout_exporter, false, @@ -61,6 +62,7 @@ int main(int argc, char *argv[]) { const std::string java_worker_command = FLAGS_java_worker_command; const std::string redis_password = FLAGS_redis_password; const std::string temp_dir = FLAGS_temp_dir; + const std::string session_dir = FLAGS_session_dir; const bool disable_stats = FLAGS_disable_stats; const std::string stat_address = 
FLAGS_stat_address; const bool enable_stdout_exporter = FLAGS_enable_stdout_exporter; @@ -132,6 +134,7 @@ int main(int argc, char *argv[]) { node_manager_config.max_lineage_size = RayConfig::instance().max_lineage_size(); node_manager_config.store_socket_name = store_socket_name; node_manager_config.temp_dir = temp_dir; + node_manager_config.session_dir = session_dir; // Configuration for the object manager. ray::ObjectManagerConfig object_manager_config; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index b710b0873b0c..5a97239faaf8 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2334,7 +2334,8 @@ void NodeManager::ForwardTask( void NodeManager::DumpDebugState() const { std::fstream fs; - fs.open(temp_dir_ + "/debug_state.txt", std::fstream::out | std::fstream::trunc); + fs.open(initial_config_.session_dir + "/debug_state.txt", + std::fstream::out | std::fstream::trunc); fs << DebugString(); fs.close(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 576ffbc23f72..3f7e4d7da97c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -48,6 +48,8 @@ struct NodeManagerConfig { std::string store_socket_name; /// The path to the ray temp dir. std::string temp_dir; + /// The path of this ray session dir. + std::string session_dir; }; class NodeManager { From ebb3b3b92833f9a910efcf0bba2a801e7461e42c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 10 Jun 2019 23:49:04 -0700 Subject: [PATCH 087/118] [core] Use int64_t instead of int to keep track of fractional resources (#4959) --- src/ray/raylet/scheduling_resources.cc | 3 ++- src/ray/raylet/scheduling_resources.h | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index cdc17307755c..923e6aad9d85 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -18,7 +18,8 @@ FractionalResourceQuantity::FractionalResourceQuantity(double resource_quantity) RAY_CHECK(resource_quantity >= 0) << "Resource capacity, " << resource_quantity << ", should be nonnegative."; - resource_quantity_ = static_cast(resource_quantity * kResourceConversionFactor); + resource_quantity_ = + static_cast(resource_quantity * kResourceConversionFactor); } const FractionalResourceQuantity FractionalResourceQuantity::operator+( diff --git a/src/ray/raylet/scheduling_resources.h b/src/ray/raylet/scheduling_resources.h index 9f64ddae6b45..9e3a2a64ce4a 100644 --- a/src/ray/raylet/scheduling_resources.h +++ b/src/ray/raylet/scheduling_resources.h @@ -58,7 +58,7 @@ class FractionalResourceQuantity { private: /// The resource quantity represented as 1/kResourceConversionFactor-th of a /// unit. 
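  /// A 64-bit integer is used so that large capacities scaled by
  /// kResourceConversionFactor still fit without overflow.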
- int resource_quantity_; + int64_t resource_quantity_; }; /// \class ResourceSet From 472c36ed1eded314f9198b2b8a0bc0bf30b7f703 Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Wed, 12 Jun 2019 10:10:12 +0800 Subject: [PATCH 088/118] [core worker] add task submission & execution interface (#4922) --- BUILD.bazel | 11 +- src/ray/core_worker/common.h | 4 +- src/ray/core_worker/context.cc | 3 +- src/ray/core_worker/context.h | 2 +- src/ray/core_worker/core_worker.cc | 30 ++- src/ray/core_worker/core_worker.h | 25 ++- src/ray/core_worker/core_worker_test.cc | 268 ++++++++++++++++++++++-- src/ray/core_worker/mock_worker.cc | 66 ++++++ src/ray/core_worker/object_interface.cc | 40 ++-- src/ray/core_worker/object_interface.h | 9 +- src/ray/core_worker/task_execution.cc | 77 ++++++- src/ray/core_worker/task_execution.h | 25 ++- src/ray/core_worker/task_interface.cc | 129 +++++++++++- src/ray/core_worker/task_interface.h | 66 +++++- src/ray/raylet/node_manager.cc | 2 +- src/ray/test/run_core_worker_tests.sh | 5 +- 16 files changed, 680 insertions(+), 82 deletions(-) create mode 100644 src/ray/core_worker/mock_worker.cc diff --git a/BUILD.bazel b/BUILD.bazel index 36f02e292fa1..27ab40ef74d1 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -114,6 +114,7 @@ cc_library( ], exclude = [ "src/ray/core_worker/*_test.cc", + "src/ray/core_worker/mock_worker.cc", ], ), hdrs = glob([ @@ -127,7 +128,15 @@ cc_library( ], ) -# This test is run by src/ray/test/run_core_worker_tests.sh +cc_binary( + name = "mock_worker", + srcs = ["src/ray/core_worker/mock_worker.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + ], +) + cc_binary( name = "core_worker_test", srcs = ["src/ray/core_worker/core_worker_test.cc"], diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 8317bf181207..b11fabfe46f8 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -12,12 +12,12 @@ namespace ray { enum class WorkerType { WORKER, DRIVER }; /// Language of Ray tasks and workers. -enum class Language { PYTHON, JAVA }; +enum class WorkerLanguage { PYTHON, JAVA }; /// Information about a remote function. struct RayFunction { /// Language of the remote function. - const Language language; + const WorkerLanguage language; /// Function descriptor of the remote function. const std::vector function_descriptor; }; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index fedcfc6625d9..660330e5cee3 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -1,5 +1,5 @@ -#include "context.h" +#include "ray/core_worker/context.h" namespace ray { @@ -23,7 +23,6 @@ struct WorkerThreadContext { void SetCurrentTask(const raylet::TaskSpecification &spec) { SetCurrentTask(spec.TaskId()); } - private: /// The task ID for current task. 
TaskID current_task_id; diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 6e0cf3f9f2cf..932d02891b6a 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -1,7 +1,7 @@ #ifndef RAY_CORE_WORKER_CONTEXT_H #define RAY_CORE_WORKER_CONTEXT_H -#include "common.h" +#include "ray/core_worker/common.h" #include "ray/raylet/task_spec.h" namespace ray { diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 82f2d885ec58..033409196d9b 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1,9 +1,10 @@ -#include "core_worker.h" -#include "context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/context.h" namespace ray { -CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language, +CoreWorker::CoreWorker(const enum WorkerType worker_type, + const enum WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, DriverID driver_id) : worker_type_(worker_type), @@ -11,20 +12,28 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language la worker_context_(worker_type, driver_id), store_socket_(store_socket), raylet_socket_(raylet_socket), + is_initialized_(false), task_interface_(*this), object_interface_(*this), - task_execution_interface_(*this) {} + task_execution_interface_(*this) { + switch (language_) { + case ray::WorkerLanguage::JAVA: + task_language_ = ::Language::JAVA; + break; + case ray::WorkerLanguage::PYTHON: + task_language_ = ::Language::PYTHON; + break; + default: + RAY_LOG(FATAL) << "Unsupported worker language: " << static_cast(language_); + break; + } +} Status CoreWorker::Connect() { // connect to plasma. RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_)); // connect to raylet. - ::Language lang = ::Language::PYTHON; - if (language_ == ray::Language::JAVA) { - lang = ::Language::JAVA; - } - // TODO: currently RayletClient would crash in its constructor if it cannot // connect to Raylet after a number of retries, this needs to be changed // so that the worker (java/python .etc) can retrieve and handle the error @@ -32,7 +41,8 @@ Status CoreWorker::Connect() { raylet_client_ = std::unique_ptr( new RayletClient(raylet_socket_, worker_context_.GetWorkerID(), (worker_type_ == ray::WorkerType::WORKER), - worker_context_.GetCurrentDriverID(), lang)); + worker_context_.GetCurrentDriverID(), task_language_)); + is_initialized_ = true; return Status::OK(); } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 951b55451f09..c038b76ce53f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -1,13 +1,13 @@ #ifndef RAY_CORE_WORKER_CORE_WORKER_H #define RAY_CORE_WORKER_CORE_WORKER_H -#include "common.h" -#include "context.h" -#include "object_interface.h" #include "ray/common/buffer.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/object_interface.h" +#include "ray/core_worker/task_execution.h" +#include "ray/core_worker/task_interface.h" #include "ray/raylet/raylet_client.h" -#include "task_execution.h" -#include "task_interface.h" namespace ray { @@ -20,7 +20,7 @@ class CoreWorker { /// /// \param[in] worker_type Type of this worker. /// \param[in] langauge Language of this worker. 
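  /// \param[in] store_socket Path of the plasma store socket to connect to.
  /// \param[in] raylet_socket Path of the raylet socket to connect to.
  /// \param[in] driver_id ID of the driver that this worker belongs to.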
- CoreWorker(const WorkerType worker_type, const Language language, + CoreWorker(const WorkerType worker_type, const WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, DriverID driver_id = DriverID::Nil()); @@ -31,7 +31,7 @@ class CoreWorker { enum WorkerType WorkerType() const { return worker_type_; } /// Language of this worker. - enum Language Language() const { return language_; } + enum WorkerLanguage Language() const { return language_; } /// Return the `CoreWorkerTaskInterface` that contains the methods related to task /// submisson. @@ -50,7 +50,10 @@ class CoreWorker { const enum WorkerType worker_type_; /// Language of this worker. - const enum Language language_; + const enum WorkerLanguage language_; + + /// Language of this worker as specified in flatbuf (used by task spec). + ::Language task_language_; /// Worker context per thread. WorkerContext worker_context_; @@ -64,9 +67,15 @@ class CoreWorker { /// Plasma store client. plasma::PlasmaClient store_client_; + /// Mutex to protect store_client_. + std::mutex store_client_mutex_; + /// Raylet client. std::unique_ptr raylet_client_; + /// Whether this worker has been initialized. + bool is_initialized_; + /// The `CoreWorkerTaskInterface` instance. CoreWorkerTaskInterface task_interface_; diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index e440aae24d67..fedfb9c2356b 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -2,9 +2,9 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "context.h" -#include "core_worker.h" #include "ray/common/buffer.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" #include "ray/raylet/raylet_client.h" #include @@ -18,6 +18,7 @@ namespace ray { std::string store_executable; std::string raylet_executable; +std::string mock_worker_executable; ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } @@ -32,6 +33,9 @@ static void flushall_redis(void) { class CoreWorkerTest : public ::testing::Test { public: CoreWorkerTest(int num_nodes) { + // flush redis first. + flushall_redis(); + RAY_CHECK(num_nodes >= 0); if (num_nodes > 0) { raylet_socket_names_.resize(num_nodes); @@ -43,10 +47,12 @@ class CoreWorkerTest : public ::testing::Test { store_socket = StartStore(); } - // start raylet on each node + // start raylet on each node. Assign each node with different resources so that + // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { - raylet_socket_names_[i] = StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", - "127.0.0.1", "\"CPU,4.0\""); + raylet_socket_names_[i] = + StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", "127.0.0.1", + "\"CPU,4.0,resource" + std::to_string(i) + ",10\""); } } @@ -66,7 +72,7 @@ class CoreWorkerTest : public ::testing::Test { std::string plasma_command = store_executable + " -m 10000000 -s " + store_socket_name + " 1> /dev/null 2> /dev/null & echo $! 
> " + store_pid; - RAY_LOG(INFO) << plasma_command; + RAY_LOG(DEBUG) << plasma_command; RAY_CHECK(system(plasma_command.c_str()) == 0); usleep(200 * 1000); return store_socket_name; @@ -75,7 +81,7 @@ class CoreWorkerTest : public ::testing::Test { void StopStore(std::string store_socket_name) { std::string store_pid = store_socket_name + ".pid"; std::string kill_9 = "kill -9 `cat " + store_pid + "`"; - RAY_LOG(INFO) << kill_9; + RAY_LOG(DEBUG) << kill_9; ASSERT_TRUE(system(kill_9.c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + store_socket_name).c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + store_socket_name + ".pid").c_str()) == 0); @@ -91,13 +97,14 @@ class CoreWorkerTest : public ::testing::Test { .append(" --node_ip_address=" + node_ip_address) .append(" --redis_address=" + redis_address) .append(" --redis_port=6379") - .append(" --num_initial_workers=0") + .append(" --num_initial_workers=1") .append(" --maximum_startup_concurrency=10") .append(" --static_resource_list=" + resource) - .append(" --python_worker_command=NoneCmd") + .append(" --python_worker_command=\"" + mock_worker_executable + " " + + store_socket_name + " " + raylet_socket_name + "\"") .append(" & echo $! > " + raylet_socket_name + ".pid"); - RAY_LOG(INFO) << "Ray Start command: " << ray_start_cmd; + RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd; RAY_CHECK(system(ray_start_cmd.c_str()) == 0); usleep(200 * 1000); return raylet_socket_name; @@ -106,16 +113,134 @@ class CoreWorkerTest : public ::testing::Test { void StopRaylet(std::string raylet_socket_name) { std::string raylet_pid = raylet_socket_name + ".pid"; std::string kill_9 = "kill -9 `cat " + raylet_pid + "`"; - RAY_LOG(INFO) << kill_9; + RAY_LOG(DEBUG) << kill_9; ASSERT_TRUE(system(kill_9.c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0); ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); } - void SetUp() { flushall_redis(); } + void SetUp() {} void TearDown() {} + void TestNormalTask(const std::unordered_map &resources) { + CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + + RAY_CHECK_OK(driver.Connect()); + + // Test pass by value. + { + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + + auto buffer1 = std::make_shared(array1, sizeof(array1)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + std::vector args; + args.emplace_back(TaskArg::PassByValue(buffer1)); + + TaskOptions options; + + std::vector return_ids; + RAY_CHECK_OK(driver.Tasks().SubmitTask(func, args, options, &return_ids)); + + ASSERT_EQ(return_ids.size(), 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + } + + // Test pass by reference. 
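+    // Here the argument is first Put() into the object store and only its
+    // ObjectID travels in the task spec; the executing worker fetches the
+    // value before running the task.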
+ { + uint8_t array1[] = {10, 11, 12, 13, 14, 15}; + auto buffer1 = std::make_shared(array1, sizeof(array1)); + + ObjectID object_id; + RAY_CHECK_OK(driver.Objects().Put(*buffer1, &object_id)); + + std::vector args; + args.emplace_back(TaskArg::PassByReference(object_id)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + TaskOptions options; + + std::vector return_ids; + RAY_CHECK_OK(driver.Tasks().SubmitTask(func, args, options, &return_ids)); + + ASSERT_EQ(return_ids.size(), 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + } + } + + void TestActorTask(const std::unordered_map &resources) { + CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + RAY_CHECK_OK(driver.Connect()); + + std::unique_ptr actor_handle; + + // Test creating actor. + { + uint8_t array[] = {1, 2, 3}; + auto buffer = std::make_shared(array, sizeof(array)); + + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + std::vector args; + args.emplace_back(TaskArg::PassByValue(buffer)); + + ActorCreationOptions actor_options{0, resources}; + + // Create an actor. + RAY_CHECK_OK(driver.Tasks().CreateActor(func, args, actor_options, &actor_handle)); + } + + // Test submitting a task for that actor. + { + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + auto buffer1 = std::make_shared(array1, sizeof(array1)); + auto buffer2 = std::make_shared(array2, sizeof(array2)); + + ObjectID object_id; + RAY_CHECK_OK(driver.Objects().Put(*buffer1, &object_id)); + + // Create arguments with PassByRef and PassByValue. + std::vector args; + args.emplace_back(TaskArg::PassByReference(object_id)); + args.emplace_back(TaskArg::PassByValue(buffer2)); + + TaskOptions options{1, resources}; + std::vector return_ids; + RayFunction func{ray::WorkerLanguage::PYTHON, {}}; + RAY_CHECK_OK(driver.Tasks().SubmitActorTask(*actor_handle, func, args, options, + &return_ids)); + RAY_CHECK(return_ids.size() == 1); + + std::vector> results; + RAY_CHECK_OK(driver.Objects().Get(return_ids, -1, &results)); + + ASSERT_EQ(results.size(), 1); + ASSERT_EQ(results[0]->Size(), buffer1->Size() + buffer2->Size()); + ASSERT_EQ(memcmp(results[0]->Data(), buffer1->Data(), buffer1->Size()), 0); + ASSERT_EQ( + memcmp(results[0]->Data() + buffer1->Size(), buffer2->Data(), buffer2->Size()), + 0); + } + } + protected: std::vector raylet_socket_names_; std::vector raylet_store_socket_names_; @@ -131,6 +256,11 @@ class SingleNodeTest : public CoreWorkerTest { SingleNodeTest() : CoreWorkerTest(1) {} }; +class TwoNodeTest : public CoreWorkerTest { + public: + TwoNodeTest() : CoreWorkerTest(2) {} +}; + TEST_F(ZeroNodeTest, TestTaskArg) { // Test by-reference argument. 
ObjectID id = ObjectID::FromRandom(); @@ -148,10 +278,10 @@ TEST_F(ZeroNodeTest, TestTaskArg) { } TEST_F(ZeroNodeTest, TestAttributeGetters) { - CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", "", + CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", "", DriverID::FromRandom()); ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER); - ASSERT_EQ(core_worker.Language(), Language::PYTHON); + ASSERT_EQ(core_worker.Language(), WorkerLanguage::PYTHON); } TEST_F(ZeroNodeTest, TestWorkerContext) { @@ -180,7 +310,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { } TEST_F(SingleNodeTest, TestObjectInterface) { - CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, + CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); RAY_CHECK_OK(core_worker.Connect()); @@ -193,16 +323,16 @@ TEST_F(SingleNodeTest, TestObjectInterface) { buffers.emplace_back(array2, sizeof(array2)); std::vector ids(buffers.size()); - for (int i = 0; i < ids.size(); i++) { - core_worker.Objects().Put(buffers[i], &ids[i]); + for (size_t i = 0; i < ids.size(); i++) { + RAY_CHECK_OK(core_worker.Objects().Put(buffers[i], &ids[i])); } // Test Get(). std::vector> results; - core_worker.Objects().Get(ids, 0, &results); + RAY_CHECK_OK(core_worker.Objects().Get(ids, -1, &results)); ASSERT_EQ(results.size(), 2); - for (int i = 0; i < ids.size(); i++) { + for (size_t i = 0; i < ids.size(); i++) { ASSERT_EQ(results[i]->Size(), buffers[i].Size()); ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); } @@ -213,34 +343,126 @@ TEST_F(SingleNodeTest, TestObjectInterface) { all_ids.push_back(non_existent_id); std::vector wait_results; - core_worker.Objects().Wait(all_ids, 2, -1, &wait_results); + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 2, -1, &wait_results)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); - core_worker.Objects().Wait(all_ids, 3, 100, &wait_results); + RAY_CHECK_OK(core_worker.Objects().Wait(all_ids, 3, 100, &wait_results)); ASSERT_EQ(wait_results.size(), 3); ASSERT_EQ(wait_results, std::vector({true, true, false})); // Test Delete(). // clear the reference held by PlasmaBuffer. results.clear(); - core_worker.Objects().Delete(ids, true, false); + RAY_CHECK_OK(core_worker.Objects().Delete(ids, true, false)); // Note that Delete() calls RayletClient::FreeObjects and would not // wait for objects being deleted, so wait a while for plasma store // to process the command. 
usleep(200 * 1000); - core_worker.Objects().Get(ids, 0, &results); + RAY_CHECK_OK(core_worker.Objects().Get(ids, 0, &results)); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); +} + +TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { + CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[0], raylet_socket_names_[0], + DriverID::FromRandom()); + RAY_CHECK_OK(worker1.Connect()); + + CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON, + raylet_store_socket_names_[1], raylet_socket_names_[1], + DriverID::FromRandom()); + RAY_CHECK_OK(worker2.Connect()); + + uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; + uint8_t array2[] = {10, 11, 12, 13, 14, 15}; + + std::vector buffers; + buffers.emplace_back(array1, sizeof(array1)); + buffers.emplace_back(array2, sizeof(array2)); + + std::vector ids(buffers.size()); + for (size_t i = 0; i < ids.size(); i++) { + RAY_CHECK_OK(worker1.Objects().Put(buffers[i], &ids[i])); + } + + // Test Get() from remote node. + std::vector> results; + RAY_CHECK_OK(worker2.Objects().Get(ids, -1, &results)); + + ASSERT_EQ(results.size(), 2); + for (size_t i = 0; i < ids.size(); i++) { + ASSERT_EQ(results[i]->Size(), buffers[i].Size()); + ASSERT_EQ(memcmp(results[i]->Data(), buffers[i].Data(), buffers[i].Size()), 0); + } + + // Test Wait() from remote node. + ObjectID non_existent_id = ObjectID::FromRandom(); + std::vector all_ids(ids); + all_ids.push_back(non_existent_id); + + std::vector wait_results; + RAY_CHECK_OK(worker2.Objects().Wait(all_ids, 2, -1, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + RAY_CHECK_OK(worker2.Objects().Wait(all_ids, 3, 100, &wait_results)); + ASSERT_EQ(wait_results.size(), 3); + ASSERT_EQ(wait_results, std::vector({true, true, false})); + + // Test Delete() from all machines. + // clear the reference held by PlasmaBuffer. + results.clear(); + RAY_CHECK_OK(worker2.Objects().Delete(ids, false, false)); + + // Note that Delete() calls RayletClient::FreeObjects and would not + // wait for objects being deleted, so wait a while for plasma store + // to process the command. + usleep(1000 * 1000); + // Verify objects are deleted from both machines. 
+ RAY_CHECK_OK(worker2.Objects().Get(ids, 0, &results)); + ASSERT_EQ(results.size(), 2); + ASSERT_TRUE(!results[0]); + ASSERT_TRUE(!results[1]); + + RAY_CHECK_OK(worker1.Objects().Get(ids, 0, &results)); ASSERT_EQ(results.size(), 2); ASSERT_TRUE(!results[0]); ASSERT_TRUE(!results[1]); } +TEST_F(SingleNodeTest, TestNormalTaskLocal) { + std::unordered_map resources; + TestNormalTask(resources); +} + +TEST_F(TwoNodeTest, TestNormalTaskCrossNodes) { + std::unordered_map resources; + resources.emplace("resource1", 1); + TestNormalTask(resources); +} + +TEST_F(SingleNodeTest, TestActorTaskLocal) { + std::unordered_map resources; + TestActorTask(resources); +} + +TEST_F(TwoNodeTest, TestActorTaskCrossNodes) { + std::unordered_map resources; + resources.emplace("resource1", 1); + TestActorTask(resources); +} + } // namespace ray int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 4); ray::store_executable = std::string(argv[1]); ray::raylet_executable = std::string(argv[2]); + ray::mock_worker_executable = std::string(argv[3]); return RUN_ALL_TESTS(); } diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc new file mode 100644 index 000000000000..205fcfce961d --- /dev/null +++ b/src/ray/core_worker/mock_worker.cc @@ -0,0 +1,66 @@ +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/task_execution.h" + +namespace ray { + +/// A mock C++ worker used by core_worker_test.cc to verify the task submission/execution +/// interfaces in both single node and cross-nodes scenarios. As the raylet client can +/// only +/// be called by a real worker process, core_worker_test.cc has to use this program binary +/// to start the actual worker process, in the test, the task submission interfaces are +/// called +/// in core_worker_test, and task execution interfaces are called in this file, see that +/// test +/// for more details on how this class is used. +class MockWorker { + public: + MockWorker(const std::string &store_socket, const std::string &raylet_socket) + : worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket, + DriverID::FromRandom()) { + RAY_CHECK_OK(worker_.Connect()); + } + + void Run() { + auto executor_func = [this](const RayFunction &ray_function, + const std::vector> &args, + const TaskID &task_id, int num_returns) { + + // Note that this doesn't include dummy object id. + RAY_CHECK(num_returns >= 0); + + // Merge all the content from input args. + std::vector buffer; + for (const auto &arg : args) { + buffer.insert(buffer.end(), arg->Data(), arg->Data() + arg->Size()); + } + + LocalMemoryBuffer memory_buffer(buffer.data(), buffer.size()); + + // Write the merged content to each of return ids. + for (int i = 0; i < num_returns; i++) { + ObjectID id = ObjectID::ForTaskReturn(task_id, i + 1); + RAY_CHECK_OK(worker_.Objects().Put(memory_buffer, id)); + } + return Status::OK(); + }; + + // Start executing tasks. 
+ worker_.Execution().Run(executor_func); + } + + private: + CoreWorker worker_; +}; + +} // namespace ray + +int main(int argc, char **argv) { + RAY_CHECK(argc == 3); + auto store_socket = std::string(argv[1]); + auto raylet_socket = std::string(argv[2]); + + ray::MockWorker worker(store_socket, raylet_socket); + worker.Run(); + return 0; +} diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index 0b94c9d4a747..5ab5d33330d7 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -1,7 +1,7 @@ -#include "object_interface.h" -#include "context.h" -#include "core_worker.h" +#include "ray/core_worker/object_interface.h" #include "ray/common/ray_config.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" namespace ray { @@ -12,14 +12,25 @@ Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(), core_worker_.worker_context_.GetNextPutIndex()); *object_id = put_id; + return Put(buffer, put_id); +} - auto plasma_id = put_id.ToPlasmaId(); +Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID &object_id) { + auto plasma_id = object_id.ToPlasmaId(); std::shared_ptr data; - RAY_ARROW_RETURN_NOT_OK( - core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + { + std::unique_lock guard(core_worker_.store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK( + core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + } + memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); + + { + std::unique_lock guard(core_worker_.store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); + RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); + } return Status::OK(); } @@ -31,7 +42,7 @@ Status CoreWorkerObjectInterface::Get(const std::vector &ids, bool was_blocked = false; std::unordered_map unready; - for (int i = 0; i < ids.size(); i++) { + for (size_t i = 0; i < ids.size(); i++) { unready.insert({ids[i], i}); } @@ -73,10 +84,13 @@ Status CoreWorkerObjectInterface::Get(const std::vector &ids, } std::vector object_buffers; - auto status = - core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); + { + std::unique_lock guard(core_worker_.store_client_mutex_); + auto status = + core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); + } - for (int i = 0; i < object_buffers.size(); i++) { + for (size_t i = 0; i < object_buffers.size(); i++) { if (object_buffers[i].data != nullptr) { const auto &object_id = unready_ids[i]; (*results)[unready[object_id]] = @@ -112,7 +126,7 @@ Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, // TODO: change RayletClient::Wait() to return a bit set, so that we don't need // to do this translation. 
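// One possible shape for that interface, sketched here for illustration only
// (today the client hands back the set of ready IDs, and the loop below
// rebuilds a bool vector from it):
//
//   Status Wait(const std::vector<ObjectID> &object_ids, int num_objects,
//               int64_t timeout_ms, std::vector<bool> *results);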
(*results).resize(object_ids.size()); - for (int i = 0; i < object_ids.size(); i++) { + for (size_t i = 0; i < object_ids.size(); i++) { (*results)[i] = ready_ids.count(object_ids[i]) > 0; } diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 8a9e20c48c6e..431b3f825ac9 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -1,11 +1,11 @@ #ifndef RAY_CORE_WORKER_OBJECT_INTERFACE_H #define RAY_CORE_WORKER_OBJECT_INTERFACE_H -#include "common.h" #include "plasma/client.h" #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" namespace ray { @@ -23,6 +23,13 @@ class CoreWorkerObjectInterface { /// \return Status. Status Put(const Buffer &buffer, ObjectID *object_id); + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + Status Put(const Buffer &buffer, const ObjectID &object_id); + /// Get a list of objects from the object store. /// /// \param[in] ids IDs of the objects to get. diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index aea48b4de34a..fc22fce96c97 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -1,7 +1,80 @@ -#include "task_execution.h" +#include "ray/core_worker/task_execution.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" namespace ray { -void CoreWorkerTaskExecutionInterface::Start(const TaskExecutor &executor) {} +Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) { + RAY_CHECK(core_worker_.is_initialized_); + + while (true) { + std::unique_ptr task_spec; + auto status = core_worker_.raylet_client_->GetTask(&task_spec); + if (!status.ok()) { + RAY_LOG(ERROR) << "Get task failed with error: " + << ray::Status::IOError(status.message()); + return status; + } + + const auto &spec = *task_spec; + core_worker_.worker_context_.SetCurrentTask(spec); + + WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA) + ? WorkerLanguage::JAVA + : WorkerLanguage::PYTHON; + RayFunction func{language, spec.FunctionDescriptor()}; + + std::vector> args; + RAY_CHECK_OK(BuildArgsForExecutor(spec, &args)); + + auto num_returns = spec.NumReturns(); + if (spec.IsActorCreationTask() || spec.IsActorTask()) { + RAY_CHECK(num_returns > 0); + // Decrease to account for the dummy object id. + num_returns--; + } + + status = executor(func, args, spec.TaskId(), num_returns); + // TODO: + // 1. Check and handle failure. + // 2. Save or load checkpoint. + } + + // should never reach here. + return Status::OK(); +} + +Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor( + const raylet::TaskSpecification &spec, std::vector> *args) { + auto num_args = spec.NumArgs(); + (*args).resize(num_args); + + std::vector object_ids_to_fetch; + std::vector indices; + + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + if (count > 0) { + // pass by reference. + RAY_CHECK(count == 1); + object_ids_to_fetch.push_back(spec.ArgId(i, 0)); + indices.push_back(i); + } else { + // pass by value. 
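// (By-value arguments are embedded inline in the task spec, so the bytes are
//  wrapped in a LocalMemoryBuffer around spec.ArgVal(i) instead of being
//  fetched from the store; the const_cast on the next line is needed only
//  because the buffer constructor takes a non-const pointer.)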
+ (*args)[i] = std::make_shared( + const_cast(spec.ArgVal(i)), spec.ArgValLength(i)); + } + } + + std::vector> results; + auto status = core_worker_.object_interface_.Get(object_ids_to_fetch, -1, &results); + if (status.ok()) { + for (size_t i = 0; i < results.size(); i++) { + (*args)[indices[i]] = results[i]; + } + } + + return status; +} } // namespace ray diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index c4de937ee439..e2fe2148a3ab 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -1,14 +1,18 @@ #ifndef RAY_CORE_WORKER_TASK_EXECUTION_H #define RAY_CORE_WORKER_TASK_EXECUTION_H -#include "common.h" #include "ray/common/buffer.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" namespace ray { class CoreWorker; +namespace raylet { +class TaskSpecification; +} + /// The interface that contains all `CoreWorker` methods that are related to task /// execution. class CoreWorkerTaskExecutionInterface { @@ -20,13 +24,26 @@ class CoreWorkerTaskExecutionInterface { /// \param ray_function[in] Information about the function to execute. /// \param args[in] Arguments of the task. /// \return Status. - using TaskExecutor = std::function &args)>; + using TaskExecutor = std::function> &args, + const TaskID &task_id, int num_returns)>; /// Start receving and executes tasks in a infinite loop. - void Start(const TaskExecutor &executor); + /// \return Status. + Status Run(const TaskExecutor &executor); private: + /// Build arguments for task executor. This would loop through all the arguments + /// in task spec, and for each of them that's passed by reference (ObjectID), + /// fetch its content from store and; for arguments that are passed by value, + /// just copy their content. + /// + /// \param spec[in] Task specification. + /// \param args[out] The arguments for passing to task executor. + /// + Status BuildArgsForExecutor(const raylet::TaskSpecification &spec, + std::vector> *args); + /// Reference to the parent CoreWorker instance. 
CoreWorker &core_worker_; }; diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index ab8b8950c298..c19b1e23a7f9 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -1,4 +1,7 @@ -#include "task_interface.h" +#include "ray/raylet/task.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/task_interface.h" namespace ray { @@ -6,13 +9,61 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, std::vector *return_ids) { - return Status::OK(); + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + auto num_returns = task_options.num_returns; + (*return_ids).resize(num_returns); + for (int i = 0; i < num_returns; i++) { + (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); + } + + auto task_arguments = BuildTaskArguments(args); + auto language = ToTaskLanguage(function.language); + + ray::raylet::TaskSpecification spec(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index, + task_arguments, num_returns, task_options.resources, + language, function.function_descriptor); + + std::vector execution_dependencies; + return core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); } Status CoreWorkerTaskInterface::CreateActor( const RayFunction &function, const std::vector &args, - const ActorCreationOptions &actor_creation_options, ActorHandle *actor_handle) { - return Status::OK(); + const ActorCreationOptions &actor_creation_options, + std::unique_ptr *actor_handle) { + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + std::vector return_ids; + return_ids.push_back(ObjectID::ForTaskReturn(task_id, 1)); + ActorID actor_creation_id = ActorID::FromBinary(return_ids[0].Binary()); + + *actor_handle = std::unique_ptr( + new ActorHandle(actor_creation_id, ActorHandleID::Nil())); + (*actor_handle)->IncreaseTaskCounter(); + (*actor_handle)->SetActorCursor(return_ids[0]); + + auto task_arguments = BuildTaskArguments(args); + auto language = ToTaskLanguage(function.language); + + // Note that the caller is supposed to specify required placement resources + // correctly via actor_creation_options.resources. 
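// From the caller's perspective that amounts to something like the sketch
// below. Illustrative only: `Tasks()` is assumed here as the CoreWorker
// accessor for this interface, by analogy with Objects() and Execution(), and
// `function`/`args` stand for a previously built RayFunction and argument list.
//
//   std::unordered_map<std::string, double> resources{{"CPU", 1.0}};
//   ActorCreationOptions options(/*max_reconstructions=*/0, resources);
//   std::unique_ptr<ActorHandle> actor;
//   RAY_CHECK_OK(worker.Tasks().CreateActor(function, args, options, &actor));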
+ ray::raylet::TaskSpecification spec( + context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions, + ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1, + actor_creation_options.resources, actor_creation_options.resources, language, + function.function_descriptor); + + std::vector execution_dependencies; + return core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); } Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, @@ -20,7 +71,75 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, const std::vector &args, const TaskOptions &task_options, std::vector *return_ids) { - return Status::OK(); + auto &context = core_worker_.worker_context_; + auto next_task_index = context.GetNextTaskIndex(); + const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + context.GetCurrentTaskID(), next_task_index); + + // add one for actor cursor object id. + auto num_returns = task_options.num_returns + 1; + (*return_ids).resize(num_returns); + for (int i = 0; i < num_returns; i++) { + (*return_ids)[i] = ObjectID::ForTaskReturn(task_id, i + 1); + } + + auto actor_creation_dummy_object_id = + ObjectID::FromBinary(actor_handle.ActorID().Binary()); + + auto task_arguments = BuildTaskArguments(args); + auto language = ToTaskLanguage(function.language); + + std::vector new_actor_handles; + ray::raylet::TaskSpecification spec( + context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(), + actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(), new_actor_handles, + task_arguments, num_returns, task_options.resources, task_options.resources, + language, function.function_descriptor); + + std::vector execution_dependencies; + execution_dependencies.push_back(actor_handle.ActorCursor()); + + auto actor_cursor = (*return_ids).back(); + actor_handle.SetActorCursor(actor_cursor); + actor_handle.ClearNewActorHandles(); + + auto status = core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); + + // remove cursor from return ids. 
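// (The cursor is an ordering token rather than user data: SetActorCursor()
//  above records this task's extra return object so that the next
//  SubmitActorTask() on the same handle lists it in execution_dependencies,
//  which is what serializes tasks on a single actor. It is therefore stripped
//  from the IDs handed back to the caller below.)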
+ (*return_ids).pop_back(); + return status; +} + +std::vector> +CoreWorkerTaskInterface::BuildTaskArguments(const std::vector &args) { + std::vector> task_arguments; + for (const auto &arg : args) { + if (arg.IsPassedByReference()) { + std::vector references{arg.GetReference()}; + task_arguments.push_back( + std::make_shared(references)); + } else { + auto data = arg.GetValue(); + task_arguments.push_back( + std::make_shared(data->Data(), data->Size())); + } + } + return task_arguments; +} + +::Language CoreWorkerTaskInterface::ToTaskLanguage(WorkerLanguage language) { + switch (language) { + case ray::WorkerLanguage::JAVA: + return ::Language::JAVA; + break; + case ray::WorkerLanguage::PYTHON: + return ::Language::PYTHON; + break; + default: + RAY_LOG(FATAL) << "invalid language specified: " << static_cast(language); + break; + } } } // namespace ray diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index e23f049d341d..06bd5409a8dd 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -1,10 +1,12 @@ #ifndef RAY_CORE_WORKER_TASK_INTERFACE_H #define RAY_CORE_WORKER_TASK_INTERFACE_H -#include "common.h" +#include + #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/status.h" +#include "ray/core_worker/common.h" namespace ray { @@ -12,6 +14,10 @@ class CoreWorker; /// Options of a non-actor-creation task. struct TaskOptions { + TaskOptions() {} + TaskOptions(int num_returns, const std::unordered_map &resources) + : num_returns(num_returns), resources(resources) {} + /// Number of returns of this task. const int num_returns = 1; /// Resources required by this task. @@ -20,6 +26,11 @@ struct TaskOptions { /// Options of an actor creation task. struct ActorCreationOptions { + ActorCreationOptions() {} + ActorCreationOptions(uint64_t max_reconstructions, + const std::unordered_map &resources) + : max_reconstructions(max_reconstructions), resources(resources) {} + /// Maximum number of times that the actor should be reconstructed when it dies /// unexpectedly. It must be non-negative. If it's 0, the actor won't be reconstructed. const uint64_t max_reconstructions = 0; @@ -31,19 +42,46 @@ struct ActorCreationOptions { class ActorHandle { public: ActorHandle(const ActorID &actor_id, const ActorHandleID &actor_handle_id) - : actor_id_(actor_id), actor_handle_id_(actor_handle_id) {} + : actor_id_(actor_id), + actor_handle_id_(actor_handle_id), + actor_cursor_(ObjectID::FromBinary(actor_id.Binary())), + task_counter_(0) {} /// ID of the actor. - const class ActorID &ActorID() const { return actor_id_; } + const ray::ActorID &ActorID() const { return actor_id_; }; /// ID of this actor handle. - const class ActorHandleID &ActorHandleID() const { return actor_handle_id_; } + const ray::ActorHandleID &ActorHandleID() const { return actor_handle_id_; }; + + private: + /// Cursor of this actor. + const ObjectID &ActorCursor() const { return actor_cursor_; }; + + /// Set actor cursor. + void SetActorCursor(const ObjectID &actor_cursor) { actor_cursor_ = actor_cursor; }; + + /// Increase task counter. + int IncreaseTaskCounter() { return task_counter_++; } + + std::list GetNewActorHandle() { + // TODO: implement this. + return std::list(); + } + + void ClearNewActorHandles() { /* TODO: implement this. */ + } private: /// ID of the actor. - const class ActorID actor_id_; + const ray::ActorID actor_id_; /// ID of this actor handle. 
- const class ActorHandleID actor_handle_id_; + const ray::ActorHandleID actor_handle_id_; + /// ID of this actor cursor. + ObjectID actor_cursor_; + /// Counter for tasks from this handle. + int task_counter_; + + friend class CoreWorkerTaskInterface; }; /// The interface that contains all `CoreWorker` methods that are related to task @@ -71,7 +109,7 @@ class CoreWorkerTaskInterface { /// \return Status. Status CreateActor(const RayFunction &function, const std::vector &args, const ActorCreationOptions &actor_creation_options, - ActorHandle *actor_handle); + std::unique_ptr *actor_handle); /// Submit an actor task. /// @@ -89,6 +127,20 @@ class CoreWorkerTaskInterface { private: /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + private: + /// Build the arguments for a task spec. + /// + /// \param[in] args Arguments of a task. + /// \return Arguments as required by task spec. + std::vector> BuildTaskArguments( + const std::vector &args); + + /// Translate from WorkLanguage to Language type (required by taks spec). + /// + /// \param[in] language Language for a task. + /// \return Translated task language. + ::Language ToTaskLanguage(WorkerLanguage language); }; } // namespace ray diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 5a97239faaf8..7b0df3522166 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2143,7 +2143,7 @@ void NodeManager::HandleObjectLocal(const ObjectID &object_id) { // Notify the task dependency manager that this object is local. const auto ready_task_ids = task_dependency_manager_.HandleObjectLocal(object_id); RAY_LOG(DEBUG) << "Object local " << object_id << ", " - << " on " << gcs_client_->client_table().GetLocalClientId() + << " on " << gcs_client_->client_table().GetLocalClientId() << ", " << ready_task_ids.size() << " tasks ready"; // Transition the tasks whose dependencies are now fulfilled to the ready state. if (ready_task_ids.size() > 0) { diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh index 5f1dd2eda69f..104b19ff19cb 100644 --- a/src/ray/test/run_core_worker_tests.sh +++ b/src/ray/test/run_core_worker_tests.sh @@ -6,7 +6,7 @@ set -e set -x -bazel build "//:core_worker_test" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" +bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" # Get the directory in which this script is executing. SCRIPT_DIR="`dirname \"$0\"`" @@ -26,6 +26,7 @@ REDIS_MODULE="./bazel-bin/libray_redis_module.so" LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" RAYLET_EXEC="./bazel-bin/raylet" +MOCK_WORKER_EXEC="./bazel-bin/mock_worker" # Allow cleanup commands to fail. bazel run //:redis-cli -- -p 6379 shutdown || true @@ -37,7 +38,7 @@ sleep 2s bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & sleep 2s # Run tests. 
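# The test binary now takes three positional arguments (the plasma store,
# raylet, and mock worker executables), matching the RAY_CHECK(argc == 4)
# added to core_worker_test.cc above:
#
#   ./bazel-bin/core_worker_test <plasma_store_server> <raylet> <mock_worker>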
-./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC +./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $MOCK_WORKER_EXEC sleep 1s bazel run //:redis-cli -- -p 6379 shutdown bazel run //:redis-cli -- -p 6380 shutdown From e0e52f1871b529ba6a8d86f1547dc636c4c0f703 Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Wed, 12 Jun 2019 07:38:34 +0200 Subject: [PATCH 089/118] [sgd] Add non-distributed PyTorch runner (#4933) * Add non-distributed PyTorch runner * use dist.is_available() instead of checking OS * Nicer exception * Fix bug in choosing port * Refactor some code * Address comments * Address comments --- .../sgd/pytorch/distributed_pytorch_runner.py | 131 ++++++++++++++++++ .../sgd/pytorch/pytorch_runner.py | 105 +++++--------- .../sgd/pytorch/pytorch_trainer.py | 111 ++++++++------- python/ray/experimental/sgd/pytorch/utils.py | 2 +- .../experimental/sgd/tests/test_pytorch.py | 20 +-- 5 files changed, 237 insertions(+), 132 deletions(-) create mode 100644 python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py diff --git a/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py new file mode 100644 index 000000000000..160544633353 --- /dev/null +++ b/python/ray/experimental/sgd/pytorch/distributed_pytorch_runner.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import os +import torch.distributed as dist +import torch.utils.data + +from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner + +logger = logging.getLogger(__name__) + + +class DistributedPyTorchRunner(PyTorchRunner): + """Manages a distributed PyTorch model replica.""" + + def __init__(self, + model_creator, + data_creator, + optimizer_creator, + config=None, + batch_size=16, + backend="gloo"): + """Initializes the runner. + + Args: + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. + optimizer_creator (torch.nn.Module, dict -> loss, optimizer): + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): batch size used by one replica for an update. + backend (string): see pytorch_trainer.py. + """ + super(DistributedPyTorchRunner, self).__init__( + model_creator, data_creator, optimizer_creator, config, batch_size) + self.backend = backend + + def setup(self, url, world_rank, world_size): + """Connects to the distributed PyTorch backend and initializes the model. + + Args: + url (str): the URL used to connect to distributed PyTorch. + world_rank (int): the index of the runner. + world_size (int): the total number of runners. 
+ """ + self._setup_distributed_pytorch(url, world_rank, world_size) + self._setup_training() + + def _setup_distributed_pytorch(self, url, world_rank, world_size): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + with self._timers["setup_proc"]: + self.world_rank = world_rank + logger.debug( + "Connecting to {} world_rank: {} world_size: {}".format( + url, world_rank, world_size)) + logger.debug("using {}".format(self.backend)) + dist.init_process_group( + backend=self.backend, + init_method=url, + rank=world_rank, + world_size=world_size) + + def _setup_training(self): + logger.debug("Creating model") + self.model = self.model_creator(self.config) + if torch.cuda.is_available(): + self.model = torch.nn.parallel.DistributedDataParallel( + self.model.cuda()) + else: + self.model = torch.nn.parallel.DistributedDataParallelCPU( + self.model) + + logger.debug("Creating optimizer") + self.criterion, self.optimizer = self.optimizer_creator( + self.model, self.config) + if torch.cuda.is_available(): + self.criterion = self.criterion.cuda() + + logger.debug("Creating dataset") + self.training_set, self.validation_set = self.data_creator(self.config) + + # TODO: make num_workers configurable + self.train_sampler = torch.utils.data.distributed.DistributedSampler( + self.training_set) + self.train_loader = torch.utils.data.DataLoader( + self.training_set, + batch_size=self.batch_size, + shuffle=(self.train_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.train_sampler) + + self.validation_sampler = ( + torch.utils.data.distributed.DistributedSampler( + self.validation_set)) + self.validation_loader = torch.utils.data.DataLoader( + self.validation_set, + batch_size=self.batch_size, + shuffle=(self.validation_sampler is None), + num_workers=2, + pin_memory=False, + sampler=self.validation_sampler) + + def step(self): + """Runs a training epoch and updates the model parameters.""" + logger.debug("Starting step") + self.train_sampler.set_epoch(self.epoch) + return super(DistributedPyTorchRunner, self).step() + + def get_state(self): + """Returns the state of the runner.""" + return { + "epoch": self.epoch, + "model": self.model.module.state_dict(), + "optimizer": self.optimizer.state_dict(), + "stats": self.stats() + } + + def set_state(self, state): + """Sets the state of the model.""" + # TODO: restore timer stats + self.model.module.load_state_dict(state["model"]) + self.optimizer.load_state_dict(state["optimizer"]) + self.epoch = state["stats"]["epoch"] + + def shutdown(self): + """Attempts to shut down the worker.""" + super(DistributedPyTorchRunner, self).shutdown() + dist.destroy_process_group() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_runner.py b/python/ray/experimental/sgd/pytorch/pytorch_runner.py index 5fe4ba1009f9..1663b2c64f0e 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_runner.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_runner.py @@ -3,9 +3,7 @@ from __future__ import print_function import logging -import os import torch -import torch.distributed as dist import torch.utils.data import ray @@ -15,28 +13,23 @@ class PyTorchRunner(object): - """Manages a distributed PyTorch model replica""" + """Manages a PyTorch model for training.""" def __init__(self, model_creator, data_creator, optimizer_creator, config=None, - batch_size=16, - backend="gloo"): + batch_size=16): """Initializes the runner. Args: - model_creator (dict -> torch.nn.Module): creates the model using - the config. 
- data_creator (dict -> Dataset, Dataset): creates the training and - validation data sets using the config. + model_creator (dict -> torch.nn.Module): see pytorch_trainer.py. + data_creator (dict -> Dataset, Dataset): see pytorch_trainer.py. optimizer_creator (torch.nn.Module, dict -> loss, optimizer): - creates the loss and optimizer using the model and the config. - config (dict): configuration passed to 'model_creator', - 'data_creator', and 'optimizer_creator'. - batch_size (int): batch size used in an update. - backend (string): backend used by distributed PyTorch. + see pytorch_trainer.py. + config (dict): see pytorch_trainer.py. + batch_size (int): see pytorch_trainer.py. """ self.model_creator = model_creator @@ -44,7 +37,6 @@ def __init__(self, self.optimizer_creator = optimizer_creator self.config = {} if config is None else config self.batch_size = batch_size - self.backend = backend self.verbose = True self.epoch = 0 @@ -56,82 +48,45 @@ def __init__(self, ] } - def setup(self, url, world_rank, world_size): - """Connects to the distributed PyTorch backend and initializes the model. - - Args: - url (str): the URL used to connect to distributed PyTorch. - world_rank (int): the index of the runner. - world_size (int): the total number of runners. - """ - self._setup_distributed_pytorch(url, world_rank, world_size) - self._setup_training() - - def _setup_distributed_pytorch(self, url, world_rank, world_size): - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - with self._timers["setup_proc"]: - self.world_rank = world_rank - logger.debug( - "Connecting to {} world_rank: {} world_size: {}".format( - url, world_rank, world_size)) - logger.debug("using {}".format(self.backend)) - dist.init_process_group( - backend=self.backend, - init_method=url, - rank=world_rank, - world_size=world_size) - - def _setup_training(self): + def setup(self): + """Initializes the model.""" logger.debug("Creating model") self.model = self.model_creator(self.config) if torch.cuda.is_available(): - self.model = torch.nn.parallel.DistributedDataParallel( - self.model.cuda()) - else: - self.model = torch.nn.parallel.DistributedDataParallelCPU( - self.model) + self.model = self.model.cuda() logger.debug("Creating optimizer") self.criterion, self.optimizer = self.optimizer_creator( self.model, self.config) - if torch.cuda.is_available(): self.criterion = self.criterion.cuda() logger.debug("Creating dataset") self.training_set, self.validation_set = self.data_creator(self.config) - - # TODO: make num_workers configurable - self.train_sampler = torch.utils.data.distributed.DistributedSampler( - self.training_set) self.train_loader = torch.utils.data.DataLoader( self.training_set, batch_size=self.batch_size, - shuffle=(self.train_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.train_sampler) + pin_memory=False) - self.validation_sampler = ( - torch.utils.data.distributed.DistributedSampler( - self.validation_set)) self.validation_loader = torch.utils.data.DataLoader( self.validation_set, batch_size=self.batch_size, - shuffle=(self.validation_sampler is None), + shuffle=True, num_workers=2, - pin_memory=False, - sampler=self.validation_sampler) + pin_memory=False) def get_node_ip(self): - """Returns the IP address of the current node""" + """Returns the IP address of the current node.""" return ray.services.get_node_ip_address() - def step(self): - """Runs a training epoch and updates the model parameters""" - logger.debug("Starting step") - self.train_sampler.set_epoch(self.epoch) + def 
find_free_port(self): + """Finds a free port on the current node.""" + return utils.find_free_port() + def step(self): + """Runs a training epoch and updates the model parameters.""" logger.debug("Begin Training Epoch {}".format(self.epoch + 1)) with self._timers["training"]: train_stats = utils.train(self.train_loader, self.model, @@ -144,7 +99,7 @@ def step(self): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" with self._timers["validation"]: validation_stats = utils.validate(self.validation_loader, self.model, self.criterion) @@ -153,7 +108,7 @@ def validate(self): return validation_stats def stats(self): - """Returns a dictionary of statistics collected""" + """Returns a dictionary of statistics collected.""" stats = {"epoch": self.epoch} for k, t in self._timers.items(): stats[k + "_time_mean"] = t.mean @@ -162,7 +117,7 @@ def stats(self): return stats def get_state(self): - """Returns the state of the runner""" + """Returns the state of the runner.""" return { "epoch": self.epoch, "model": self.model.state_dict(), @@ -171,12 +126,20 @@ def get_state(self): } def set_state(self, state): - """Sets the state of the model""" + """Sets the state of the model.""" # TODO: restore timer stats self.model.load_state_dict(state["model"]) self.optimizer.load_state_dict(state["optimizer"]) self.epoch = state["stats"]["epoch"] def shutdown(self): - """Attempts to shut down the worker""" - dist.destroy_process_group() + """Attempts to shut down the worker.""" + del self.validation_loader + del self.validation_set + del self.train_loader + del self.training_set + del self.criterion + del self.optimizer + del self.model + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py index 073ad3d34042..0e0c5d8436a1 100644 --- a/python/ray/experimental/sgd/pytorch/pytorch_trainer.py +++ b/python/ray/experimental/sgd/pytorch/pytorch_trainer.py @@ -3,13 +3,15 @@ from __future__ import print_function import numpy as np -import sys import torch +import torch.distributed as dist import logging import ray from ray.experimental.sgd.pytorch.pytorch_runner import PyTorchRunner +from ray.experimental.sgd.pytorch.distributed_pytorch_runner import ( + DistributedPyTorchRunner) from ray.experimental.sgd.pytorch import utils logger = logging.getLogger(__name__) @@ -51,10 +53,11 @@ def __init__(self, """ # TODO: add support for mixed precision # TODO: add support for callbacks - if sys.platform == "darwin": - raise Exception( - ("Distributed PyTorch is not supported on macOS. For more " - "information, see " + if num_replicas > 1 and not dist.is_available(): + raise ValueError( + ("Distributed PyTorch is not supported on macOS. " + "To run without distributed PyTorch, set 'num_replicas=1'. 
" + "For more information, see " "https://github.com/pytorch/examples/issues/467.")) self.model_creator = model_creator @@ -68,40 +71,55 @@ def __init__(self, if backend == "auto": backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo" - Runner = ray.remote( - num_cpus=resources_per_replica.num_cpus, - num_gpus=resources_per_replica.num_gpus, - resources=resources_per_replica.resources)(PyTorchRunner) - - batch_size_per_replica = batch_size // num_replicas - if batch_size % num_replicas > 0: - new_batch_size = batch_size_per_replica * num_replicas - logger.warn( - ("Changing batch size from {old_batch_size} to " - "{new_batch_size} to evenly distribute batches across " - "{num_replicas} replicas.").format( - old_batch_size=batch_size, - new_batch_size=new_batch_size, - num_replicas=num_replicas)) - - self.workers = [ - Runner.remote(model_creator, data_creator, optimizer_creator, - self.config, batch_size_per_replica, backend) - for i in range(num_replicas) - ] - - ip = ray.get(self.workers[0].get_node_ip.remote()) - port = utils.find_free_port() - address = "tcp://{ip}:{port}".format(ip=ip, port=port) - - # Get setup tasks in order to throw errors on failure - ray.get([ - worker.setup.remote(address, i, len(self.workers)) - for i, worker in enumerate(self.workers) - ]) + if num_replicas == 1: + # Generate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)(PyTorchRunner) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size) + ] + # Get setup tasks in order to throw errors on failure + ray.get(self.workers[0].setup.remote()) + else: + # Geneate actor class + Runner = ray.remote( + num_cpus=resources_per_replica.num_cpus, + num_gpus=resources_per_replica.num_gpus, + resources=resources_per_replica.resources)( + DistributedPyTorchRunner) + # Compute batch size per replica + batch_size_per_replica = batch_size // num_replicas + if batch_size % num_replicas > 0: + new_batch_size = batch_size_per_replica * num_replicas + logger.warn( + ("Changing batch size from {old_batch_size} to " + "{new_batch_size} to evenly distribute batches across " + "{num_replicas} replicas.").format( + old_batch_size=batch_size, + new_batch_size=new_batch_size, + num_replicas=num_replicas)) + # Start workers + self.workers = [ + Runner.remote(model_creator, data_creator, optimizer_creator, + self.config, batch_size_per_replica, backend) + for i in range(num_replicas) + ] + # Compute URL for initializing distributed PyTorch + ip = ray.get(self.workers[0].get_node_ip.remote()) + port = ray.get(self.workers[0].find_free_port.remote()) + address = "tcp://{ip}:{port}".format(ip=ip, port=port) + # Get setup tasks in order to throw errors on failure + ray.get([ + worker.setup.remote(address, i, len(self.workers)) + for i, worker in enumerate(self.workers) + ]) def train(self): - """Runs a training epoch""" + """Runs a training epoch.""" with self.optimizer_timer: worker_stats = ray.get([w.step.remote() for w in self.workers]) @@ -111,7 +129,7 @@ def train(self): return train_stats def validate(self): - """Evaluates the model on the validation data set""" + """Evaluates the model on the validation data set.""" worker_stats = ray.get([w.validate.remote() for w in self.workers]) validation_stats = worker_stats[0].copy() validation_stats["validation_loss"] = np.mean( @@ -119,32 +137,25 @@ def validate(self): return validation_stats 
def get_model(self): - """Returns the learned model""" + """Returns the learned model.""" model = self.model_creator(self.config) state = ray.get(self.workers[0].get_state.remote()) - - # Remove module. prefix added by distrbuted pytorch - state_dict = { - k.replace("module.", ""): v - for k, v in state["model"].items() - } - - model.load_state_dict(state_dict) + model.load_state_dict(state["model"]) return model def save(self, ckpt): - """Saves the model at the provided checkpoint""" + """Saves the model at the provided checkpoint.""" state = ray.get(self.workers[0].get_state.remote()) torch.save(state, ckpt) def restore(self, ckpt): - """Restores the model from the provided checkpoint""" + """Restores the model from the provided checkpoint.""" state = torch.load(ckpt) state_id = ray.put(state) ray.get([worker.set_state.remote(state_id) for worker in self.workers]) def shutdown(self): - """Shuts down workers and releases resources""" + """Shuts down workers and releases resources.""" for worker in self.workers: worker.shutdown.remote() worker.__ray_terminate__.remote() diff --git a/python/ray/experimental/sgd/pytorch/utils.py b/python/ray/experimental/sgd/pytorch/utils.py index f7c6e4abac97..5be26b331cfd 100644 --- a/python/ray/experimental/sgd/pytorch/utils.py +++ b/python/ray/experimental/sgd/pytorch/utils.py @@ -196,7 +196,7 @@ def find_free_port(): class AverageMeter(object): - """Computes and stores the average and current value""" + """Computes and stores the average and current value.""" def __init__(self): self.reset() diff --git a/python/ray/experimental/sgd/tests/test_pytorch.py b/python/ray/experimental/sgd/tests/test_pytorch.py index faff23f8a809..aa0596aa158c 100644 --- a/python/ray/experimental/sgd/tests/test_pytorch.py +++ b/python/ray/experimental/sgd/tests/test_pytorch.py @@ -4,9 +4,9 @@ import os import pytest -import sys import tempfile import torch +import torch.distributed as dist from ray.tests.conftest import ray_start_2_cpus # noqa: F401 from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources @@ -15,14 +15,14 @@ model_creator, optimizer_creator, data_creator) -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", reason="Doesn't work on macOS.") -def test_train(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_train(ray_start_2_cpus, num_replicas): # noqa: F811 trainer = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) train_loss1 = trainer.train()["train_loss"] validation_loss1 = trainer.validate()["validation_loss"] @@ -37,14 +37,14 @@ def test_train(ray_start_2_cpus): # noqa: F811 assert validation_loss2 <= validation_loss1 -@pytest.mark.skipif( # noqa: F811 - sys.platform == "darwin", reason="Doesn't work on macOS.") -def test_save_and_restore(ray_start_2_cpus): # noqa: F811 +@pytest.mark.parametrize( # noqa: F811 + "num_replicas", [1, 2] if dist.is_available() else [1]) +def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811 trainer1 = PyTorchTrainer( model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) trainer1.train() @@ -59,7 +59,7 @@ def test_save_and_restore(ray_start_2_cpus): # noqa: F811 model_creator, data_creator, optimizer_creator, - num_replicas=2, + num_replicas=num_replicas, resources_per_replica=Resources(num_cpus=1)) 
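# restore() loads the checkpoint written by save() (a plain torch.save of worker
# 0's state dict), ships it to every replica via ray.put, and each replica's
# set_state() reinstates the model, optimizer, and epoch counter, so trainer2
# resumes exactly where trainer1 left off.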
trainer2.restore(filename) From 89ca5eeb29d5be0753f8d2e64826b7f51002a0de Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 12 Jun 2019 11:13:39 -0700 Subject: [PATCH 090/118] Flush all tasks from local lineage cache after a node failure (#4964) --- src/ray/raylet/lineage_cache.cc | 14 +++++++ src/ray/raylet/lineage_cache.h | 7 ++++ src/ray/raylet/lineage_cache_test.cc | 55 +++++++++++++++++++++++++++- src/ray/raylet/node_manager.cc | 5 +++ 4 files changed, 80 insertions(+), 1 deletion(-) diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 795f2b54a6cb..fcae3f8e04c9 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -221,6 +221,20 @@ bool LineageCache::CommitTask(const Task &task) { } } +void LineageCache::FlushAllUncommittedTasks() { + size_t num_flushed = 0; + for (const auto &entry : lineage_.GetEntries()) { + // Flush all tasks that have not yet committed. + if (entry.second.GetStatus() == GcsStatus::UNCOMMITTED) { + RAY_CHECK(UnsubscribeTask(entry.first)); + FlushTask(entry.first); + num_flushed++; + } + } + + RAY_LOG(DEBUG) << "Flushed " << num_flushed << " uncommitted tasks"; +} + void LineageCache::MarkTaskAsForwarded(const TaskID &task_id, const ClientID &node_id) { RAY_CHECK(!node_id.IsNil()); auto entry = lineage_.GetEntryMutable(task_id); diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 2dff0e94a4d1..5436fa372fa4 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -231,6 +231,13 @@ class LineageCache { /// task was already in the COMMITTING state. bool CommitTask(const Task &task); + /// Flush all tasks in the local cache that are not already being + /// committed. This is equivalent to all tasks in the UNCOMMITTED + /// state. + /// + /// \return Void. + void FlushAllUncommittedTasks(); + /// Add a task and its (estimated) uncommitted lineage to the local cache. We /// will subscribe to commit notifications for all uncommitted tasks to /// determine when it is safe to evict the lineage from the local cache. diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index e5c126bcf078..43e64e400292 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -26,8 +26,22 @@ class MockGcs : public gcs::TableInterface, std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; + auto callback = done; + // If we requested notifications for this task ID, send the notification as + // part of the callback. + if (subscribed_tasks_.count(task_id) == 1) { + callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, + const protocol::TaskT &data) { + done(client, task_id, data); + // If we're subscribed to the task to be added, also send a + // subscription notification. 
+ notification_callback_(client, task_id, data); + }; + } + callbacks_.push_back( - std::pair(done, task_id)); + std::pair(callback, task_id)); + num_task_adds_++; return ray::Status::OK(); } @@ -78,28 +92,34 @@ class MockGcs : public gcs::TableInterface, const int NumRequestedNotifications() const { return num_requested_notifications_; } + const int NumTaskAdds() const { return num_task_adds_; } + private: std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; int num_requested_notifications_ = 0; + int num_task_adds_ = 0; }; class LineageCacheTest : public ::testing::Test { public: LineageCacheTest() : max_lineage_size_(10), + num_notifications_(0), mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); + num_notifications_++; }); } protected: uint64_t max_lineage_size_; + uint64_t num_notifications_; MockGcs mock_gcs_; LineageCache lineage_cache_; }; @@ -529,6 +549,39 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { ASSERT_EQ(lineage_cache_.GetLineage().GetChildrenSize(), 0); } +TEST_F(LineageCacheTest, TestFlushAllUncommittedTasks) { + // Insert a chain of tasks. + std::vector tasks; + auto return_values = + InsertTaskChain(lineage_cache_, tasks, 3, std::vector(), 1); + std::vector task_ids; + for (const auto &task : tasks) { + task_ids.push_back(task.GetTaskSpecification().TaskId()); + } + // Check that we subscribed to each of the uncommitted tasks. + ASSERT_EQ(mock_gcs_.NumRequestedNotifications(), task_ids.size()); + + // Flush all uncommitted tasks and make sure we add all tasks to + // the task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + // Flush again and make sure there are no new tasks added to the + // task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); + + // Flush all GCS notifications. + mock_gcs_.Flush(); + // Make sure that we unsubscribed to the uncommitted tasks before + // we flushed them. + ASSERT_EQ(num_notifications_, 0); + + // Flush again and make sure there are no new tasks added to the + // task table. + lineage_cache_.FlushAllUncommittedTasks(); + ASSERT_EQ(mock_gcs_.NumTaskAdds(), tasks.size()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 7b0df3522166..75377a13c73d 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -475,6 +475,11 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // Notify the object directory that the client has been removed so that it // can remove it from any cached locations. object_directory_->HandleClientRemoved(client_id); + + // Flush all uncommitted tasks from the local lineage cache. This is to + // guarantee that all tasks get flushed eventually, in case one of the tasks + // in our local cache was supposed to be flushed by the node that died. 
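// Concretely: if a task in this cache was forwarded to the node that just died,
// the dead node was the one responsible for writing it to the GCS; flushing
// every entry still in the UNCOMMITTED state restores that guarantee from the
// surviving nodes. The call is safe to repeat, since flushed entries leave the
// UNCOMMITTED state (the new TestFlushAllUncommittedTasks above checks exactly
// this idempotence).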
+ lineage_cache_.FlushAllUncommittedTasks(); } void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { From d2f5b71c3bc93e5c9727d3be534ffce2f61277af Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Wed, 12 Jun 2019 15:02:12 -0700 Subject: [PATCH 091/118] Remove typing from setup.py install_requirements. (#4971) --- python/setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/setup.py b/python/setup.py index 10eea24f1310..db8676042de9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -148,8 +148,6 @@ def find_version(*filepath): # NOTE: Don't upgrade the version of six! Doing so causes installation # problems. See https://github.com/ray-project/ray/issues/4169. "six >= 1.0.0", - # The typing module is required by modin. - "typing", "flatbuffers", "faulthandler;python_version<'3.3'", ] From ef1af49efd715d2a156d8cde95dd07a084d14851 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 13 Jun 2019 20:52:41 +0800 Subject: [PATCH 092/118] [Java] Fix bug of `BaseID` in multi-threading case. (#4974) --- java/api/src/main/java/org/ray/api/id/BaseId.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index 3c5e1e3a3619..e08955d5a93e 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -41,13 +41,14 @@ public ByteBuffer toByteBuffer() { */ public boolean isNil() { if (isNilCache == null) { - isNilCache = true; + boolean localIsNil = true; for (int i = 0; i < size(); ++i) { if (id[i] != (byte) 0xff) { - isNilCache = false; + localIsNil = false; break; } } + isNilCache = localIsNil; } return isNilCache; } From fa1d4c9807c4e5bb4c7abb273f6451eb9d822a68 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 13 Jun 2019 15:07:46 -0700 Subject: [PATCH 093/118] [rllib] Fix DDPG example (#4973) --- python/ray/rllib/policy/tf_policy.py | 2 +- python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml | 1 - python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml | 1 - python/ray/rllib/tuned_examples/pendulum-ddpg.yaml | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index ed234f809512..ef0de42e2f0c 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -205,7 +205,7 @@ def _initialize_loss(self, loss, loss_inputs): self._grads_and_vars) if log_once("loss_used"): - logger.info( + logger.debug( "These tensors were used in the loss_fn:\n\n{}\n".format( summarize(self._loss_input_dict))) diff --git a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml index 6a4bd52e77fe..0513f7bf6ef1 100644 --- a/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/halfcheetah-ddpg.yaml @@ -47,7 +47,6 @@ halfcheetah-ddpg: # === Parallelism === num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False diff --git a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml index 3a8f61229224..87ce8eff58cc 100644 --- a/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml @@ -47,7 +47,6 @@ mountaincarcontinuous-ddpg: # === Parallelism === 
num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False diff --git a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml index 59891a86b6bc..a2ad295fb4c0 100644 --- a/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml +++ b/python/ray/rllib/tuned_examples/pendulum-ddpg.yaml @@ -47,7 +47,6 @@ pendulum-ddpg: # === Parallelism === num_workers: 0 num_gpus_per_worker: 0 - optimizer_class: "SyncReplayOptimizer" per_worker_exploration: False worker_side_prioritization: False From 3c92b2ee4d78c3b099207d7157c06bcd5f0473cf Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 14 Jun 2019 14:52:32 +0800 Subject: [PATCH 094/118] Upgrade CI clang-format to 6.0 (#4976) --- .travis.yml | 7 +- ci/travis/check-git-clang-format-output.sh | 2 +- src/ray/common/id.cc | 4 +- src/ray/core_worker/context.cc | 1 + src/ray/core_worker/mock_worker.cc | 1 - src/ray/core_worker/task_interface.cc | 3 +- src/ray/core_worker/task_interface.h | 1 + src/ray/gcs/client.cc | 4 +- src/ray/gcs/client_test.cc | 100 ++-- src/ray/gcs/redis_context.cc | 20 +- src/ray/gcs/redis_module/ray_redis_module.cc | 2 +- src/ray/gcs/redis_module/redismodule.h | 439 ++++++++++-------- src/ray/gcs/tables.cc | 14 +- src/ray/object_manager/object_directory.cc | 60 +-- .../test/object_manager_stress_test.cc | 31 +- .../test/object_manager_test.cc | 38 +- src/ray/raylet/client_connection_test.cc | 18 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 10 +- src/ray/raylet/lineage_cache.cc | 7 +- src/ray/raylet/main.cc | 2 +- src/ray/raylet/monitor.cc | 4 +- src/ray/raylet/node_manager.cc | 137 +++--- .../raylet/object_manager_integration_test.cc | 21 +- src/ray/raylet/raylet.cc | 36 +- src/ray/raylet/scheduling_queue.cc | 18 +- src/ray/raylet/scheduling_queue.h | 10 +- src/ray/raylet/scheduling_resources.cc | 4 +- src/ray/util/logging.cc | 3 +- src/ray/util/logging.h | 20 +- src/ray/util/macros.h | 2 +- 30 files changed, 555 insertions(+), 464 deletions(-) diff --git a/.travis.yml b/.travis.yml index f1f292c58088..1888fa4ce03f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,6 @@ language: generic +# Use Ubuntu 16.04 +dist: xenial matrix: include: @@ -35,11 +37,8 @@ matrix: - os: linux env: LINT=1 PYTHONWARNINGS=ignore before_install: - # In case we ever want to use a different version of clang-format: - #- wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - #- echo "deb http://apt.llvm.org/trusty/ llvm-toolchain-trusty main" | sudo tee -a /etc/apt/sources.list > /dev/null - sudo apt-get update -qq - - sudo apt-get install -qq clang-format-3.8 + - sudo apt-get install -qq clang-format-6.0 install: [] script: - ./ci/travis/check-git-clang-format-output.sh diff --git a/ci/travis/check-git-clang-format-output.sh b/ci/travis/check-git-clang-format-output.sh index 4209811cd21c..6d83044c4877 100755 --- a/ci/travis/check-git-clang-format-output.sh +++ b/ci/travis/check-git-clang-format-output.sh @@ -8,7 +8,7 @@ else base_commit="$TRAVIS_BRANCH" echo "Running clang-format against branch $base_commit, with hash $(git rev-parse $base_commit)" fi -output="$(ci/travis/git-clang-format --binary clang-format-3.8 --commit $base_commit --diff --exclude '(.*thirdparty/|.*redismodule.h|.*.js|.*.java)')" +output="$(ci/travis/git-clang-format --binary clang-format --commit $base_commit --diff --exclude '(.*thirdparty/|.*redismodule.h|.*.js|.*.java)')" if [ "$output" == "no modified files to format" ] || [ 
"$output" == "clang-format did not modify any files" ] ; then echo "clang-format passed." exit 0 diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 57e41d97d10c..3928d4adfcb7 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -105,8 +105,8 @@ ObjectID ObjectID::ForPut(const TaskID &task_id, int64_t put_index) { } ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" - << return_index; + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) + << "index=" << return_index; ObjectID object_id; std::memcpy(object_id.id_, task_id.Binary().c_str(), task_id.Size()); object_id.index_ = return_index; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 660330e5cee3..717c52e07076 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -23,6 +23,7 @@ struct WorkerThreadContext { void SetCurrentTask(const raylet::TaskSpecification &spec) { SetCurrentTask(spec.TaskId()); } + private: /// The task ID for current task. TaskID current_task_id; diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc index 205fcfce961d..95d11bb259a8 100644 --- a/src/ray/core_worker/mock_worker.cc +++ b/src/ray/core_worker/mock_worker.cc @@ -25,7 +25,6 @@ class MockWorker { auto executor_func = [this](const RayFunction &ray_function, const std::vector> &args, const TaskID &task_id, int num_returns) { - // Note that this doesn't include dummy object id. RAY_CHECK(num_returns >= 0); diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index c19b1e23a7f9..00f15237f1c3 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -1,7 +1,6 @@ -#include "ray/raylet/task.h" +#include "ray/core_worker/task_interface.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" -#include "ray/core_worker/task_interface.h" namespace ray { diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index 06bd5409a8dd..2ec3b1329cbc 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/raylet/task.h" namespace ray { diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 3d1c6602740c..c9b1e138575d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -23,8 +23,8 @@ static void GetRedisShards(redisContext *context, std::vector &addr } RAY_CHECK(num_attempts < RayConfig::instance().redis_db_connect_retries()) << "No entry found for NumRedisShards"; - RAY_CHECK(reply->type == REDIS_REPLY_STRING) << "Expected string, found Redis type " - << reply->type << " for NumRedisShards"; + RAY_CHECK(reply->type == REDIS_REPLY_STRING) + << "Expected string, found Redis type " << reply->type << " for NumRedisShards"; int num_redis_shards = atoi(reply->str); RAY_CHECK(num_redis_shards >= 1) << "Expected at least one Redis shard, " << "found " << num_redis_shards; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 4eb34a95328a..c7dc02e50651 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -150,8 +150,8 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. 
auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); @@ -241,8 +241,8 @@ void TestLogAppendAt(const DriverID &driver_id, /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { appended_managers.push_back(entry.node_manager_id); @@ -282,8 +282,8 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that lookup returns the added object entries. auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -296,8 +296,9 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli auto data = std::make_shared(); data->manager = manager; // Check that we added the correct object entries. - auto remove_entry_callback = [object_id, data]( - gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { + auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, + const ObjectID &id, + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); @@ -308,8 +309,8 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -350,8 +351,8 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const TaskID &id, + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -445,8 +446,8 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, for (const auto &object_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -657,8 +658,9 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. 
auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -736,8 +738,9 @@ void TestTableSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id2, task_specs2]( - gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, + const TaskID &id, + const protocol::TaskT &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. @@ -751,7 +754,7 @@ void TestTableSubscribeId(const DriverID &driver_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const TaskID &id) { + gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -819,8 +822,8 @@ void TestLogSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( - gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. @@ -893,8 +896,9 @@ void TestSetSubscribeId(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [object_id2, managers2]( - gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); @@ -966,8 +970,9 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [task_id, task_specs]( - gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { + auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, + const TaskID &id, + const protocol::TaskT &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. @@ -1036,8 +1041,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. 
auto notification_callback = [random_driver_id, driver_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + gcs::AsyncGcsClient *client, const UniqueID &id, + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1109,8 +1114,9 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + gcs::AsyncGcsClient *client, const ObjectID &id, + const GcsChangeMode change_mode, + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1291,11 +1297,12 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, RAY_CHECK_OK(client->client_table().MarkDisconnected(dead_client_id)); // Make sure we only get a notification for the removal of the client we // marked as dead. - client->client_table().RegisterClientRemovedCallback([dead_client_id]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); - test->Stop(); - }); + client->client_table().RegisterClientRemovedCallback( + [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + test->Stop(); + }); test->Start(); } @@ -1350,8 +1357,9 @@ void TestHashTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto notification_callback = [data_map1, data_map2, compare_test]( - AsyncGcsClient *client, const ClientID &id, const GcsChangeMode change_mode, - const DynamicResourceTable::DataMap &data) { + AsyncGcsClient *client, const ClientID &id, + const GcsChangeMode change_mode, + const DynamicResourceTable::DataMap &data) { if (change_mode == GcsChangeMode::REMOVE) { ASSERT_EQ(data.size(), 2); ASSERT_TRUE(data.find("GPU") != data.end()); @@ -1380,16 +1388,16 @@ void TestHashTable(const DriverID &driver_id, // Step 1: Add elements to the hash table. auto update_callback1 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, - const DynamicResourceTable::DataMap &callback_data) { + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); auto lookup_callback1 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, - const DynamicResourceTable::DataMap &callback_data) { + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); }; @@ -1398,8 +1406,8 @@ void TestHashTable(const DriverID &driver_id, // Step 2: Decrease one element, increase one and add a new one. 
RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr)); auto lookup_callback2 = [data_map2, compare_test]( - AsyncGcsClient *client, const ClientID &id, - const DynamicResourceTable::DataMap &callback_data) { + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map2, callback_data); test->IncrementNumCallbacks(); }; @@ -1419,8 +1427,8 @@ void TestHashTable(const DriverID &driver_id, data_map3.erase("GPU"); data_map3.erase("CUSTOM"); auto lookup_callback3 = [data_map3, compare_test]( - AsyncGcsClient *client, const ClientID &id, - const DynamicResourceTable::DataMap &callback_data) { + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map3, callback_data); test->IncrementNumCallbacks(); }; @@ -1430,8 +1438,8 @@ void TestHashTable(const DriverID &driver_id, RAY_CHECK_OK( client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); auto lookup_callback4 = [data_map1, compare_test]( - AsyncGcsClient *client, const ClientID &id, - const DynamicResourceTable::DataMap &callback_data) { + AsyncGcsClient *client, const ClientID &id, + const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); }; diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index fe5ba3d1d134..ae6cb6088cec 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -48,28 +48,28 @@ namespace gcs { CallbackReply::CallbackReply(redisReply *redis_reply) { RAY_CHECK(nullptr != redis_reply); - RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) << "Got an error in redis reply: " - << redis_reply->str; + RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) + << "Got an error in redis reply: " << redis_reply->str; this->redis_reply_ = redis_reply; } bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; } int64_t CallbackReply::ReadAsInteger() const { - RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; return static_cast(redis_reply_->integer); } std::string CallbackReply::ReadAsString() const { - RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; return std::string(redis_reply_->str, redis_reply_->len); } Status CallbackReply::ReadAsStatus() const { - RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; const std::string status_str(redis_reply_->str, redis_reply_->len); if ("OK" == status_str) { return Status::OK(); @@ -79,8 +79,8 @@ Status CallbackReply::ReadAsStatus() const { } std::string CallbackReply::ReadAsPubsubData() const { - RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) << "Unexpected type: " - << redis_reply_->type; + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) + << "Unexpected type: " << redis_reply_->type; std::string data = ""; // Parse the published message. 
diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e059787472f1..e291b7ffdb32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -1004,7 +1004,7 @@ int DebugString_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int std::string debug_string = DebugString(); return RedisModule_ReplyWithStringBuffer(ctx, debug_string.data(), debug_string.size()); } -}; +}; // namespace internal_redis_commands // Wrap all Redis commands with Redis' auto memory management. AUTO_MEMORY(TableAdd_RedisCommand); diff --git a/src/ray/gcs/redis_module/redismodule.h b/src/ray/gcs/redis_module/redismodule.h index 186e284c0e04..23721beef07d 100644 --- a/src/ray/gcs/redis_module/redismodule.h +++ b/src/ray/gcs/redis_module/redismodule.h @@ -1,9 +1,9 @@ #ifndef REDISMODULE_H #define REDISMODULE_H -#include #include #include +#include /* ---------------- Defines common between core and modules --------------- */ @@ -15,8 +15,8 @@ #define REDISMODULE_APIVER_1 1 /* API flags and constants */ -#define REDISMODULE_READ (1<<0) -#define REDISMODULE_WRITE (1<<1) +#define REDISMODULE_READ (1 << 0) +#define REDISMODULE_WRITE (1 << 1) #define REDISMODULE_LIST_HEAD 0 #define REDISMODULE_LIST_TAIL 1 @@ -45,30 +45,31 @@ #define REDISMODULE_NO_EXPIRE -1 /* Sorted set API flags. */ -#define REDISMODULE_ZADD_XX (1<<0) -#define REDISMODULE_ZADD_NX (1<<1) -#define REDISMODULE_ZADD_ADDED (1<<2) -#define REDISMODULE_ZADD_UPDATED (1<<3) -#define REDISMODULE_ZADD_NOP (1<<4) +#define REDISMODULE_ZADD_XX (1 << 0) +#define REDISMODULE_ZADD_NX (1 << 1) +#define REDISMODULE_ZADD_ADDED (1 << 2) +#define REDISMODULE_ZADD_UPDATED (1 << 3) +#define REDISMODULE_ZADD_NOP (1 << 4) /* Hash API flags. */ -#define REDISMODULE_HASH_NONE 0 -#define REDISMODULE_HASH_NX (1<<0) -#define REDISMODULE_HASH_XX (1<<1) -#define REDISMODULE_HASH_CFIELDS (1<<2) -#define REDISMODULE_HASH_EXISTS (1<<3) +#define REDISMODULE_HASH_NONE 0 +#define REDISMODULE_HASH_NX (1 << 0) +#define REDISMODULE_HASH_XX (1 << 1) +#define REDISMODULE_HASH_CFIELDS (1 << 2) +#define REDISMODULE_HASH_EXISTS (1 << 3) /* A special pointer that we can use between the core and the module to signal * field deletion, and that is impossible to be a valid pointer. */ -#define REDISMODULE_HASH_DELETE ((RedisModuleString*)(long)1) +#define REDISMODULE_HASH_DELETE ((RedisModuleString *)(long)1) /* Error messages. 
*/ -#define REDISMODULE_ERRORMSG_WRONGTYPE "WRONGTYPE Operation against a key holding the wrong kind of value" +#define REDISMODULE_ERRORMSG_WRONGTYPE \ + "WRONGTYPE Operation against a key holding the wrong kind of value" -#define REDISMODULE_POSITIVE_INFINITE (1.0/0.0) -#define REDISMODULE_NEGATIVE_INFINITE (-1.0/0.0) +#define REDISMODULE_POSITIVE_INFINITE (1.0 / 0.0) +#define REDISMODULE_NEGATIVE_INFINITE (-1.0 / 0.0) -#define REDISMODULE_NOT_USED(V) ((void) V) +#define REDISMODULE_NOT_USED(V) ((void)V) /* ------------------------- End of common defines ------------------------ */ @@ -86,95 +87,142 @@ typedef struct RedisModuleType RedisModuleType; typedef struct RedisModuleDigest RedisModuleDigest; typedef struct RedisModuleBlockedClient RedisModuleBlockedClient; -typedef int (*RedisModuleCmdFunc) (RedisModuleCtx *ctx, RedisModuleString **argv, int argc); +typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, + int argc); typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver); typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value); -typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, void *value); +typedef void (*RedisModuleTypeRewriteFunc)(RedisModuleIO *aof, RedisModuleString *key, + void *value); typedef size_t (*RedisModuleTypeMemUsageFunc)(void *value); typedef void (*RedisModuleTypeDigestFunc)(RedisModuleDigest *digest, void *value); typedef void (*RedisModuleTypeFreeFunc)(void *value); #define REDISMODULE_TYPE_METHOD_VERSION 1 typedef struct RedisModuleTypeMethods { - uint64_t version; - RedisModuleTypeLoadFunc rdb_load; - RedisModuleTypeSaveFunc rdb_save; - RedisModuleTypeRewriteFunc aof_rewrite; - RedisModuleTypeMemUsageFunc mem_usage; - RedisModuleTypeDigestFunc digest; - RedisModuleTypeFreeFunc free; + uint64_t version; + RedisModuleTypeLoadFunc rdb_load; + RedisModuleTypeSaveFunc rdb_save; + RedisModuleTypeRewriteFunc aof_rewrite; + RedisModuleTypeMemUsageFunc mem_usage; + RedisModuleTypeDigestFunc digest; + RedisModuleTypeFreeFunc free; } RedisModuleTypeMethods; #define REDISMODULE_GET_API(name) \ - RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_ ## name)) + RedisModule_GetApi("RedisModule_" #name, ((void **)&RedisModule_##name)) #define REDISMODULE_API_FUNC(x) (*x) - void *REDISMODULE_API_FUNC(RedisModule_Alloc)(size_t bytes); void *REDISMODULE_API_FUNC(RedisModule_Realloc)(void *ptr, size_t bytes); void REDISMODULE_API_FUNC(RedisModule_Free)(void *ptr); void *REDISMODULE_API_FUNC(RedisModule_Calloc)(size_t nmemb, size_t size); char *REDISMODULE_API_FUNC(RedisModule_Strdup)(const char *str); int REDISMODULE_API_FUNC(RedisModule_GetApi)(const char *, void *); -int REDISMODULE_API_FUNC(RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, RedisModuleCmdFunc cmdfunc, const char *strflags, int firstkey, int lastkey, int keystep); -int REDISMODULE_API_FUNC(RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, const char *name, int ver, int apiver); +int REDISMODULE_API_FUNC(RedisModule_CreateCommand)(RedisModuleCtx *ctx, const char *name, + RedisModuleCmdFunc cmdfunc, + const char *strflags, int firstkey, + int lastkey, int keystep); +int REDISMODULE_API_FUNC(RedisModule_SetModuleAttribs)(RedisModuleCtx *ctx, + const char *name, int ver, + int apiver); int REDISMODULE_API_FUNC(RedisModule_WrongArity)(RedisModuleCtx *ctx); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, long long ll); +int 
REDISMODULE_API_FUNC(RedisModule_ReplyWithLongLong)(RedisModuleCtx *ctx, + long long ll); int REDISMODULE_API_FUNC(RedisModule_GetSelectedDb)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_SelectDb)(RedisModuleCtx *ctx, int newid); -void *REDISMODULE_API_FUNC(RedisModule_OpenKey)(RedisModuleCtx *ctx, RedisModuleString *keyname, int mode); +void *REDISMODULE_API_FUNC(RedisModule_OpenKey)(RedisModuleCtx *ctx, + RedisModuleString *keyname, int mode); void REDISMODULE_API_FUNC(RedisModule_CloseKey)(RedisModuleKey *kp); int REDISMODULE_API_FUNC(RedisModule_KeyType)(RedisModuleKey *kp); size_t REDISMODULE_API_FUNC(RedisModule_ValueLength)(RedisModuleKey *kp); -int REDISMODULE_API_FUNC(RedisModule_ListPush)(RedisModuleKey *kp, int where, RedisModuleString *ele); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ListPop)(RedisModuleKey *key, int where); -RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_Call)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...); -const char *REDISMODULE_API_FUNC(RedisModule_CallReplyProto)(RedisModuleCallReply *reply, size_t *len); +int REDISMODULE_API_FUNC(RedisModule_ListPush)(RedisModuleKey *kp, int where, + RedisModuleString *ele); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ListPop)(RedisModuleKey *key, + int where); +RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_Call)(RedisModuleCtx *ctx, + const char *cmdname, + const char *fmt, ...); +const char *REDISMODULE_API_FUNC(RedisModule_CallReplyProto)(RedisModuleCallReply *reply, + size_t *len); void REDISMODULE_API_FUNC(RedisModule_FreeCallReply)(RedisModuleCallReply *reply); int REDISMODULE_API_FUNC(RedisModule_CallReplyType)(RedisModuleCallReply *reply); long long REDISMODULE_API_FUNC(RedisModule_CallReplyInteger)(RedisModuleCallReply *reply); size_t REDISMODULE_API_FUNC(RedisModule_CallReplyLength)(RedisModuleCallReply *reply); -RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_CallReplyArrayElement)(RedisModuleCallReply *reply, size_t idx); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateString)(RedisModuleCtx *ctx, const char *ptr, size_t len); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromLongLong)(RedisModuleCtx *ctx, long long ll); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromString)(RedisModuleCtx *ctx, const RedisModuleString *str); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringPrintf)(RedisModuleCtx *ctx, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_FreeString)(RedisModuleCtx *ctx, RedisModuleString *str); -const char *REDISMODULE_API_FUNC(RedisModule_StringPtrLen)(const RedisModuleString *str, size_t *len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithError)(RedisModuleCtx *ctx, const char *err); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, const char *msg); +RedisModuleCallReply *REDISMODULE_API_FUNC(RedisModule_CallReplyArrayElement)( + RedisModuleCallReply *reply, size_t idx); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateString)(RedisModuleCtx *ctx, + const char *ptr, + size_t len); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromLongLong)( + RedisModuleCtx *ctx, long long ll); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromString)( + RedisModuleCtx *ctx, const RedisModuleString *str); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringPrintf)( + RedisModuleCtx *ctx, const char *fmt, ...); +void 
REDISMODULE_API_FUNC(RedisModule_FreeString)(RedisModuleCtx *ctx, + RedisModuleString *str); +const char *REDISMODULE_API_FUNC(RedisModule_StringPtrLen)(const RedisModuleString *str, + size_t *len); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithError)(RedisModuleCtx *ctx, + const char *err); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithSimpleString)(RedisModuleCtx *ctx, + const char *msg); int REDISMODULE_API_FUNC(RedisModule_ReplyWithArray)(RedisModuleCtx *ctx, long len); void REDISMODULE_API_FUNC(RedisModule_ReplySetArrayLength)(RedisModuleCtx *ctx, long len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, const char *buf, size_t len); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithString)(RedisModuleCtx *ctx, RedisModuleString *str); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithStringBuffer)(RedisModuleCtx *ctx, + const char *buf, size_t len); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithString)(RedisModuleCtx *ctx, + RedisModuleString *str); int REDISMODULE_API_FUNC(RedisModule_ReplyWithNull)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_ReplyWithDouble)(RedisModuleCtx *ctx, double d); -int REDISMODULE_API_FUNC(RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, RedisModuleCallReply *reply); -int REDISMODULE_API_FUNC(RedisModule_StringToLongLong)(const RedisModuleString *str, long long *ll); -int REDISMODULE_API_FUNC(RedisModule_StringToDouble)(const RedisModuleString *str, double *d); +int REDISMODULE_API_FUNC(RedisModule_ReplyWithCallReply)(RedisModuleCtx *ctx, + RedisModuleCallReply *reply); +int REDISMODULE_API_FUNC(RedisModule_StringToLongLong)(const RedisModuleString *str, + long long *ll); +int REDISMODULE_API_FUNC(RedisModule_StringToDouble)(const RedisModuleString *str, + double *d); void REDISMODULE_API_FUNC(RedisModule_AutoMemory)(RedisModuleCtx *ctx); -int REDISMODULE_API_FUNC(RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, const char *fmt, ...); +int REDISMODULE_API_FUNC(RedisModule_Replicate)(RedisModuleCtx *ctx, const char *cmdname, + const char *fmt, ...); int REDISMODULE_API_FUNC(RedisModule_ReplicateVerbatim)(RedisModuleCtx *ctx); -const char *REDISMODULE_API_FUNC(RedisModule_CallReplyStringPtr)(RedisModuleCallReply *reply, size_t *len); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromCallReply)(RedisModuleCallReply *reply); +const char *REDISMODULE_API_FUNC(RedisModule_CallReplyStringPtr)( + RedisModuleCallReply *reply, size_t *len); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_CreateStringFromCallReply)( + RedisModuleCallReply *reply); int REDISMODULE_API_FUNC(RedisModule_DeleteKey)(RedisModuleKey *key); -int REDISMODULE_API_FUNC(RedisModule_StringSet)(RedisModuleKey *key, RedisModuleString *str); -char *REDISMODULE_API_FUNC(RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, int mode); +int REDISMODULE_API_FUNC(RedisModule_StringSet)(RedisModuleKey *key, + RedisModuleString *str); +char *REDISMODULE_API_FUNC(RedisModule_StringDMA)(RedisModuleKey *key, size_t *len, + int mode); int REDISMODULE_API_FUNC(RedisModule_StringTruncate)(RedisModuleKey *key, size_t newlen); mstime_t REDISMODULE_API_FUNC(RedisModule_GetExpire)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_SetExpire)(RedisModuleKey *key, mstime_t expire); -int REDISMODULE_API_FUNC(RedisModule_ZsetAdd)(RedisModuleKey *key, double score, RedisModuleString *ele, int *flagsptr); -int REDISMODULE_API_FUNC(RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, RedisModuleString *ele, int 
*flagsptr, double *newscore); -int REDISMODULE_API_FUNC(RedisModule_ZsetScore)(RedisModuleKey *key, RedisModuleString *ele, double *score); -int REDISMODULE_API_FUNC(RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, int *deleted); +int REDISMODULE_API_FUNC(RedisModule_ZsetAdd)(RedisModuleKey *key, double score, + RedisModuleString *ele, int *flagsptr); +int REDISMODULE_API_FUNC(RedisModule_ZsetIncrby)(RedisModuleKey *key, double score, + RedisModuleString *ele, int *flagsptr, + double *newscore); +int REDISMODULE_API_FUNC(RedisModule_ZsetScore)(RedisModuleKey *key, + RedisModuleString *ele, double *score); +int REDISMODULE_API_FUNC(RedisModule_ZsetRem)(RedisModuleKey *key, RedisModuleString *ele, + int *deleted); void REDISMODULE_API_FUNC(RedisModule_ZsetRangeStop)(RedisModuleKey *key); -int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex); -int REDISMODULE_API_FUNC(RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, double min, double max, int minex, int maxex); -int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max); -int REDISMODULE_API_FUNC(RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, RedisModuleString *min, RedisModuleString *max); -RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ZsetRangeCurrentElement)(RedisModuleKey *key, double *score); +int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInScoreRange)(RedisModuleKey *key, + double min, double max, + int minex, int maxex); +int REDISMODULE_API_FUNC(RedisModule_ZsetLastInScoreRange)(RedisModuleKey *key, + double min, double max, + int minex, int maxex); +int REDISMODULE_API_FUNC(RedisModule_ZsetFirstInLexRange)(RedisModuleKey *key, + RedisModuleString *min, + RedisModuleString *max); +int REDISMODULE_API_FUNC(RedisModule_ZsetLastInLexRange)(RedisModuleKey *key, + RedisModuleString *min, + RedisModuleString *max); +RedisModuleString *REDISMODULE_API_FUNC(RedisModule_ZsetRangeCurrentElement)( + RedisModuleKey *key, double *score); int REDISMODULE_API_FUNC(RedisModule_ZsetRangeNext)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_ZsetRangePrev)(RedisModuleKey *key); int REDISMODULE_API_FUNC(RedisModule_ZsetRangeEndReached)(RedisModuleKey *key); @@ -184,31 +232,49 @@ int REDISMODULE_API_FUNC(RedisModule_IsKeysPositionRequest)(RedisModuleCtx *ctx) void REDISMODULE_API_FUNC(RedisModule_KeyAtPos)(RedisModuleCtx *ctx, int pos); unsigned long long REDISMODULE_API_FUNC(RedisModule_GetClientId)(RedisModuleCtx *ctx); void *REDISMODULE_API_FUNC(RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes); -RedisModuleType *REDISMODULE_API_FUNC(RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods); -int REDISMODULE_API_FUNC(RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value); +RedisModuleType *REDISMODULE_API_FUNC(RedisModule_CreateDataType)( + RedisModuleCtx *ctx, const char *name, int encver, + RedisModuleTypeMethods *typemethods); +int REDISMODULE_API_FUNC(RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, + RedisModuleType *mt, + void *value); RedisModuleType *REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetType)(RedisModuleKey *key); void *REDISMODULE_API_FUNC(RedisModule_ModuleTypeGetValue)(RedisModuleKey *key); void REDISMODULE_API_FUNC(RedisModule_SaveUnsigned)(RedisModuleIO *io, uint64_t value); uint64_t REDISMODULE_API_FUNC(RedisModule_LoadUnsigned)(RedisModuleIO 
*io); void REDISMODULE_API_FUNC(RedisModule_SaveSigned)(RedisModuleIO *io, int64_t value); int64_t REDISMODULE_API_FUNC(RedisModule_LoadSigned)(RedisModuleIO *io); -void REDISMODULE_API_FUNC(RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_SaveString)(RedisModuleIO *io, RedisModuleString *s); -void REDISMODULE_API_FUNC(RedisModule_SaveStringBuffer)(RedisModuleIO *io, const char *str, size_t len); +void REDISMODULE_API_FUNC(RedisModule_EmitAOF)(RedisModuleIO *io, const char *cmdname, + const char *fmt, ...); +void REDISMODULE_API_FUNC(RedisModule_SaveString)(RedisModuleIO *io, + RedisModuleString *s); +void REDISMODULE_API_FUNC(RedisModule_SaveStringBuffer)(RedisModuleIO *io, + const char *str, size_t len); RedisModuleString *REDISMODULE_API_FUNC(RedisModule_LoadString)(RedisModuleIO *io); -char *REDISMODULE_API_FUNC(RedisModule_LoadStringBuffer)(RedisModuleIO *io, size_t *lenptr); +char *REDISMODULE_API_FUNC(RedisModule_LoadStringBuffer)(RedisModuleIO *io, + size_t *lenptr); void REDISMODULE_API_FUNC(RedisModule_SaveDouble)(RedisModuleIO *io, double value); double REDISMODULE_API_FUNC(RedisModule_LoadDouble)(RedisModuleIO *io); void REDISMODULE_API_FUNC(RedisModule_SaveFloat)(RedisModuleIO *io, float value); float REDISMODULE_API_FUNC(RedisModule_LoadFloat)(RedisModuleIO *io); -void REDISMODULE_API_FUNC(RedisModule_Log)(RedisModuleCtx *ctx, const char *level, const char *fmt, ...); -void REDISMODULE_API_FUNC(RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, const char *fmt, ...); -int REDISMODULE_API_FUNC(RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, RedisModuleString *str, const char *buf, size_t len); -void REDISMODULE_API_FUNC(RedisModule_RetainString)(RedisModuleCtx *ctx, RedisModuleString *str); -int REDISMODULE_API_FUNC(RedisModule_StringCompare)(RedisModuleString *a, RedisModuleString *b); +void REDISMODULE_API_FUNC(RedisModule_Log)(RedisModuleCtx *ctx, const char *level, + const char *fmt, ...); +void REDISMODULE_API_FUNC(RedisModule_LogIOError)(RedisModuleIO *io, const char *levelstr, + const char *fmt, ...); +int REDISMODULE_API_FUNC(RedisModule_StringAppendBuffer)(RedisModuleCtx *ctx, + RedisModuleString *str, + const char *buf, size_t len); +void REDISMODULE_API_FUNC(RedisModule_RetainString)(RedisModuleCtx *ctx, + RedisModuleString *str); +int REDISMODULE_API_FUNC(RedisModule_StringCompare)(RedisModuleString *a, + RedisModuleString *b); RedisModuleCtx *REDISMODULE_API_FUNC(RedisModule_GetContextFromIO)(RedisModuleIO *io); -RedisModuleBlockedClient *REDISMODULE_API_FUNC(RedisModule_BlockClient)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(void*), long long timeout_ms); -int REDISMODULE_API_FUNC(RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, void *privdata); +RedisModuleBlockedClient *REDISMODULE_API_FUNC(RedisModule_BlockClient)( + RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, + RedisModuleCmdFunc timeout_callback, void (*free_privdata)(void *), + long long timeout_ms); +int REDISMODULE_API_FUNC(RedisModule_UnblockClient)(RedisModuleBlockedClient *bc, + void *privdata); int REDISMODULE_API_FUNC(RedisModule_IsBlockedReplyRequest)(RedisModuleCtx *ctx); int REDISMODULE_API_FUNC(RedisModule_IsBlockedTimeoutRequest)(RedisModuleCtx *ctx); void *REDISMODULE_API_FUNC(RedisModule_GetBlockedClientPrivateData)(RedisModuleCtx *ctx); @@ -216,115 +282,116 @@ int 
REDISMODULE_API_FUNC(RedisModule_AbortBlock)(RedisModuleBlockedClient *bc); long long REDISMODULE_API_FUNC(RedisModule_Milliseconds)(void); /* This is included inline inside each Redis module. */ -static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) __attribute__((unused)); +static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) + __attribute__((unused)); static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int apiver) { - void *getapifuncptr = ((void**)ctx)[0]; - RedisModule_GetApi = (int (*)(const char *, void *)) (unsigned long)getapifuncptr; - REDISMODULE_GET_API(Alloc); - REDISMODULE_GET_API(Calloc); - REDISMODULE_GET_API(Free); - REDISMODULE_GET_API(Realloc); - REDISMODULE_GET_API(Strdup); - REDISMODULE_GET_API(CreateCommand); - REDISMODULE_GET_API(SetModuleAttribs); - REDISMODULE_GET_API(WrongArity); - REDISMODULE_GET_API(ReplyWithLongLong); - REDISMODULE_GET_API(ReplyWithError); - REDISMODULE_GET_API(ReplyWithSimpleString); - REDISMODULE_GET_API(ReplyWithArray); - REDISMODULE_GET_API(ReplySetArrayLength); - REDISMODULE_GET_API(ReplyWithStringBuffer); - REDISMODULE_GET_API(ReplyWithString); - REDISMODULE_GET_API(ReplyWithNull); - REDISMODULE_GET_API(ReplyWithCallReply); - REDISMODULE_GET_API(ReplyWithDouble); - REDISMODULE_GET_API(ReplySetArrayLength); - REDISMODULE_GET_API(GetSelectedDb); - REDISMODULE_GET_API(SelectDb); - REDISMODULE_GET_API(OpenKey); - REDISMODULE_GET_API(CloseKey); - REDISMODULE_GET_API(KeyType); - REDISMODULE_GET_API(ValueLength); - REDISMODULE_GET_API(ListPush); - REDISMODULE_GET_API(ListPop); - REDISMODULE_GET_API(StringToLongLong); - REDISMODULE_GET_API(StringToDouble); - REDISMODULE_GET_API(Call); - REDISMODULE_GET_API(CallReplyProto); - REDISMODULE_GET_API(FreeCallReply); - REDISMODULE_GET_API(CallReplyInteger); - REDISMODULE_GET_API(CallReplyType); - REDISMODULE_GET_API(CallReplyLength); - REDISMODULE_GET_API(CallReplyArrayElement); - REDISMODULE_GET_API(CallReplyStringPtr); - REDISMODULE_GET_API(CreateStringFromCallReply); - REDISMODULE_GET_API(CreateString); - REDISMODULE_GET_API(CreateStringFromLongLong); - REDISMODULE_GET_API(CreateStringFromString); - REDISMODULE_GET_API(CreateStringPrintf); - REDISMODULE_GET_API(FreeString); - REDISMODULE_GET_API(StringPtrLen); - REDISMODULE_GET_API(AutoMemory); - REDISMODULE_GET_API(Replicate); - REDISMODULE_GET_API(ReplicateVerbatim); - REDISMODULE_GET_API(DeleteKey); - REDISMODULE_GET_API(StringSet); - REDISMODULE_GET_API(StringDMA); - REDISMODULE_GET_API(StringTruncate); - REDISMODULE_GET_API(GetExpire); - REDISMODULE_GET_API(SetExpire); - REDISMODULE_GET_API(ZsetAdd); - REDISMODULE_GET_API(ZsetIncrby); - REDISMODULE_GET_API(ZsetScore); - REDISMODULE_GET_API(ZsetRem); - REDISMODULE_GET_API(ZsetRangeStop); - REDISMODULE_GET_API(ZsetFirstInScoreRange); - REDISMODULE_GET_API(ZsetLastInScoreRange); - REDISMODULE_GET_API(ZsetFirstInLexRange); - REDISMODULE_GET_API(ZsetLastInLexRange); - REDISMODULE_GET_API(ZsetRangeCurrentElement); - REDISMODULE_GET_API(ZsetRangeNext); - REDISMODULE_GET_API(ZsetRangePrev); - REDISMODULE_GET_API(ZsetRangeEndReached); - REDISMODULE_GET_API(HashSet); - REDISMODULE_GET_API(HashGet); - REDISMODULE_GET_API(IsKeysPositionRequest); - REDISMODULE_GET_API(KeyAtPos); - REDISMODULE_GET_API(GetClientId); - REDISMODULE_GET_API(PoolAlloc); - REDISMODULE_GET_API(CreateDataType); - REDISMODULE_GET_API(ModuleTypeSetValue); - REDISMODULE_GET_API(ModuleTypeGetType); - REDISMODULE_GET_API(ModuleTypeGetValue); - 
REDISMODULE_GET_API(SaveUnsigned); - REDISMODULE_GET_API(LoadUnsigned); - REDISMODULE_GET_API(SaveSigned); - REDISMODULE_GET_API(LoadSigned); - REDISMODULE_GET_API(SaveString); - REDISMODULE_GET_API(SaveStringBuffer); - REDISMODULE_GET_API(LoadString); - REDISMODULE_GET_API(LoadStringBuffer); - REDISMODULE_GET_API(SaveDouble); - REDISMODULE_GET_API(LoadDouble); - REDISMODULE_GET_API(SaveFloat); - REDISMODULE_GET_API(LoadFloat); - REDISMODULE_GET_API(EmitAOF); - REDISMODULE_GET_API(Log); - REDISMODULE_GET_API(LogIOError); - REDISMODULE_GET_API(StringAppendBuffer); - REDISMODULE_GET_API(RetainString); - REDISMODULE_GET_API(StringCompare); - REDISMODULE_GET_API(GetContextFromIO); - REDISMODULE_GET_API(BlockClient); - REDISMODULE_GET_API(UnblockClient); - REDISMODULE_GET_API(IsBlockedReplyRequest); - REDISMODULE_GET_API(IsBlockedTimeoutRequest); - REDISMODULE_GET_API(GetBlockedClientPrivateData); - REDISMODULE_GET_API(AbortBlock); - REDISMODULE_GET_API(Milliseconds); - - RedisModule_SetModuleAttribs(ctx,name,ver,apiver); - return REDISMODULE_OK; + void *getapifuncptr = ((void **)ctx)[0]; + RedisModule_GetApi = (int (*)(const char *, void *))(unsigned long)getapifuncptr; + REDISMODULE_GET_API(Alloc); + REDISMODULE_GET_API(Calloc); + REDISMODULE_GET_API(Free); + REDISMODULE_GET_API(Realloc); + REDISMODULE_GET_API(Strdup); + REDISMODULE_GET_API(CreateCommand); + REDISMODULE_GET_API(SetModuleAttribs); + REDISMODULE_GET_API(WrongArity); + REDISMODULE_GET_API(ReplyWithLongLong); + REDISMODULE_GET_API(ReplyWithError); + REDISMODULE_GET_API(ReplyWithSimpleString); + REDISMODULE_GET_API(ReplyWithArray); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(ReplyWithStringBuffer); + REDISMODULE_GET_API(ReplyWithString); + REDISMODULE_GET_API(ReplyWithNull); + REDISMODULE_GET_API(ReplyWithCallReply); + REDISMODULE_GET_API(ReplyWithDouble); + REDISMODULE_GET_API(ReplySetArrayLength); + REDISMODULE_GET_API(GetSelectedDb); + REDISMODULE_GET_API(SelectDb); + REDISMODULE_GET_API(OpenKey); + REDISMODULE_GET_API(CloseKey); + REDISMODULE_GET_API(KeyType); + REDISMODULE_GET_API(ValueLength); + REDISMODULE_GET_API(ListPush); + REDISMODULE_GET_API(ListPop); + REDISMODULE_GET_API(StringToLongLong); + REDISMODULE_GET_API(StringToDouble); + REDISMODULE_GET_API(Call); + REDISMODULE_GET_API(CallReplyProto); + REDISMODULE_GET_API(FreeCallReply); + REDISMODULE_GET_API(CallReplyInteger); + REDISMODULE_GET_API(CallReplyType); + REDISMODULE_GET_API(CallReplyLength); + REDISMODULE_GET_API(CallReplyArrayElement); + REDISMODULE_GET_API(CallReplyStringPtr); + REDISMODULE_GET_API(CreateStringFromCallReply); + REDISMODULE_GET_API(CreateString); + REDISMODULE_GET_API(CreateStringFromLongLong); + REDISMODULE_GET_API(CreateStringFromString); + REDISMODULE_GET_API(CreateStringPrintf); + REDISMODULE_GET_API(FreeString); + REDISMODULE_GET_API(StringPtrLen); + REDISMODULE_GET_API(AutoMemory); + REDISMODULE_GET_API(Replicate); + REDISMODULE_GET_API(ReplicateVerbatim); + REDISMODULE_GET_API(DeleteKey); + REDISMODULE_GET_API(StringSet); + REDISMODULE_GET_API(StringDMA); + REDISMODULE_GET_API(StringTruncate); + REDISMODULE_GET_API(GetExpire); + REDISMODULE_GET_API(SetExpire); + REDISMODULE_GET_API(ZsetAdd); + REDISMODULE_GET_API(ZsetIncrby); + REDISMODULE_GET_API(ZsetScore); + REDISMODULE_GET_API(ZsetRem); + REDISMODULE_GET_API(ZsetRangeStop); + REDISMODULE_GET_API(ZsetFirstInScoreRange); + REDISMODULE_GET_API(ZsetLastInScoreRange); + REDISMODULE_GET_API(ZsetFirstInLexRange); + REDISMODULE_GET_API(ZsetLastInLexRange); + 
REDISMODULE_GET_API(ZsetRangeCurrentElement); + REDISMODULE_GET_API(ZsetRangeNext); + REDISMODULE_GET_API(ZsetRangePrev); + REDISMODULE_GET_API(ZsetRangeEndReached); + REDISMODULE_GET_API(HashSet); + REDISMODULE_GET_API(HashGet); + REDISMODULE_GET_API(IsKeysPositionRequest); + REDISMODULE_GET_API(KeyAtPos); + REDISMODULE_GET_API(GetClientId); + REDISMODULE_GET_API(PoolAlloc); + REDISMODULE_GET_API(CreateDataType); + REDISMODULE_GET_API(ModuleTypeSetValue); + REDISMODULE_GET_API(ModuleTypeGetType); + REDISMODULE_GET_API(ModuleTypeGetValue); + REDISMODULE_GET_API(SaveUnsigned); + REDISMODULE_GET_API(LoadUnsigned); + REDISMODULE_GET_API(SaveSigned); + REDISMODULE_GET_API(LoadSigned); + REDISMODULE_GET_API(SaveString); + REDISMODULE_GET_API(SaveStringBuffer); + REDISMODULE_GET_API(LoadString); + REDISMODULE_GET_API(LoadStringBuffer); + REDISMODULE_GET_API(SaveDouble); + REDISMODULE_GET_API(LoadDouble); + REDISMODULE_GET_API(SaveFloat); + REDISMODULE_GET_API(LoadFloat); + REDISMODULE_GET_API(EmitAOF); + REDISMODULE_GET_API(Log); + REDISMODULE_GET_API(LogIOError); + REDISMODULE_GET_API(StringAppendBuffer); + REDISMODULE_GET_API(RetainString); + REDISMODULE_GET_API(StringCompare); + REDISMODULE_GET_API(GetContextFromIO); + REDISMODULE_GET_API(BlockClient); + REDISMODULE_GET_API(UnblockClient); + REDISMODULE_GET_API(IsBlockedReplyRequest); + REDISMODULE_GET_API(IsBlockedTimeoutRequest); + REDISMODULE_GET_API(GetBlockedClientPrivateData); + REDISMODULE_GET_API(AbortBlock); + REDISMODULE_GET_API(Milliseconds); + + RedisModule_SetModuleAttribs(ctx, name, ver, apiver); + return REDISMODULE_OK; } #else diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index e20384a04bdc..33f1615580a6 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -674,8 +674,8 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { auto connected_client_id = ClientID::FromBinary(data.client_id); - RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " - << client_id_; + RAY_CHECK(client_id_ == connected_client_id) + << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } @@ -704,8 +704,8 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { // Callback for a notification from the client table. 
auto notification_callback = [this]( - AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + AsyncGcsClient *client, const UniqueID &log_key, + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); std::unordered_map connected_nodes; std::unordered_map disconnected_nodes; @@ -797,8 +797,8 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { + ray::gcs::AsyncGcsClient *client, const UniqueID &id, + const ActorCheckpointIdDataT &data) { std::shared_ptr copy = std::make_shared(data); copy->timestamps.push_back(current_sys_time_ms()); @@ -817,7 +817,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( - ray::gcs::AsyncGcsClient *client, const UniqueID &id) { + ray::gcs::AsyncGcsClient *client, const UniqueID &id) { std::shared_ptr data = std::make_shared(); data->actor_id = id.Binary(); diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index d2496dceb8bf..5b6794a505d3 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -39,36 +39,36 @@ void UpdateObjectLocations(const GcsChangeMode change_mode, } // namespace void ObjectDirectory::RegisterBackend() { - auto object_notification_callback = [this]( - gcs::AsyncGcsClient *client, const ObjectID &object_id, - const GcsChangeMode change_mode, - const std::vector &location_updates) { - // Objects are added to this map in SubscribeObjectLocations. - auto it = listeners_.find(object_id); - // Do nothing for objects we are not listening for. - if (it == listeners_.end()) { - return; - } - - // Once this flag is set to true, it should never go back to false. - it->second.subscribed = true; - - // Update entries for this object. - UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), - &it->second.current_object_locations); - // Copy the callbacks so that the callbacks can unsubscribe without interrupting - // looping over the callbacks. - auto callbacks = it->second.callbacks; - // Call all callbacks associated with the object id locations we have - // received. This notifies the client even if the list of locations is - // empty, since this may indicate that the objects have been evicted from - // all nodes. - for (const auto &callback_pair : callbacks) { - // It is safe to call the callback directly since this is already running - // in the subscription callback stack. - callback_pair.second(object_id, it->second.current_object_locations); - } - }; + auto object_notification_callback = + [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, + const GcsChangeMode change_mode, + const std::vector &location_updates) { + // Objects are added to this map in SubscribeObjectLocations. + auto it = listeners_.find(object_id); + // Do nothing for objects we are not listening for. + if (it == listeners_.end()) { + return; + } + + // Once this flag is set to true, it should never go back to false. + it->second.subscribed = true; + + // Update entries for this object. 
+ UpdateObjectLocations(change_mode, location_updates, gcs_client_->client_table(), + &it->second.current_object_locations); + // Copy the callbacks so that the callbacks can unsubscribe without interrupting + // looping over the callbacks. + auto callbacks = it->second.callbacks; + // Call all callbacks associated with the object id locations we have + // received. This notifies the client even if the list of locations is + // empty, since this may indicate that the objects have been evicted from + // all nodes. + for (const auto &callback_pair : callbacks) { + // It is safe to call the callback directly since this is already running + // in the subscription callback stack. + callback_pair.second(object_id, it->second.current_object_locations); + } + }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index f1169605134a..55aa59124a99 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -70,11 +70,11 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. 
auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -240,16 +240,17 @@ class StressTestObjectManager : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 012c306938d6..ee6c78d8ed42 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -64,11 +64,11 @@ class MockServer { void HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. 
auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -219,16 +219,17 @@ class TestObjectManager : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { @@ -291,9 +292,10 @@ class TestObjectManager : public TestObjectManagerBase { UniqueID sub_id = ray::UniqueID::FromRandom(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( - sub_id, object_1, [this, sub_id, object_1, object_2]( - const ray::ObjectID &object_id, - const std::unordered_set &clients) { + sub_id, object_1, + [this, sub_id, object_1, object_2]( + const ray::ObjectID &object_id, + const std::unordered_set &clients) { if (!clients.empty()) { TestWaitWhileSubscribed(sub_id, object_1, object_2); } diff --git a/src/ray/raylet/client_connection_test.cc b/src/ray/raylet/client_connection_test.cc index 952088bb2f4b..d6359ae281dd 100644 --- a/src/ray/raylet/client_connection_test.cc +++ b/src/ray/raylet/client_connection_test.cc @@ -73,9 +73,9 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) { ClientHandler client_handler = [](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; std::shared_ptr reader = NULL; @@ -120,9 +120,9 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) { ClientHandler client_handler = [](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; auto writer = LocalClientConnection::Create( client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); @@ -142,9 +142,9 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) { ClientHandler client_handler = [](LocalClientConnection &client) {}; - MessageHandler noop_handler = []( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) {}; + MessageHandler noop_handler = + [](std::shared_ptr client, int64_t message_type, + const uint8_t *message) {}; auto writer = LocalClientConnection::Create( client_handler, noop_handler, std::move(in_), "writer", {}, error_message_type_); diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 2afcba18c356..319d29d4a93a 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc 
+++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -307,15 +307,15 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpo * Method: nativeSetResource * Signature: (JLjava/lang/String;D[B)V */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(JNIEnv *env, jclass, - jlong client, jstring resourceName, jdouble capacity, jbyteArray nodeId) { +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( + JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, + jbyteArray nodeId) { auto raylet_client = reinterpret_cast(client); UniqueIdFromJByteArray node_id(env, nodeId); const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - auto status = raylet_client->SetResource(native_resource_name, - static_cast(capacity), node_id.GetId()); + auto status = raylet_client->SetResource( + native_resource_name, static_cast(capacity), node_id.GetId()); env->ReleaseStringUTFChars(resourceName, native_resource_name); ThrowRayExceptionIfNotOK(env, status); } diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index fcae3f8e04c9..32dddada5244 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -290,10 +290,9 @@ void LineageCache::FlushTask(const TaskID &task_id) { RAY_CHECK(entry); RAY_CHECK(entry->GetStatus() < GcsStatus::COMMITTING); - gcs::raylet::TaskTable::WriteCallback task_callback = [this]( - ray::gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &data) { - HandleEntryCommitted(id); - }; + gcs::raylet::TaskTable::WriteCallback task_callback = + [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, + const protocol::TaskT &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... flatbuffers::FlatBufferBuilder fbb; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index c6e581cec9b7..eca282a53309 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -174,7 +174,7 @@ int main(int argc, char *argv[]) { // instead of returning immediately. // We should stop the service and remove the local socket file. auto handler = [&main_service, &raylet_socket_name, &server, &gcs_client]( - const boost::system::error_code &error, int signal_number) { + const boost::system::error_code &error, int signal_number) { auto shutdown_callback = [&server, &main_service]() { server.reset(); main_service.stop(); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index a87257cadda4..62ecb00b819f 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -48,8 +48,8 @@ void Monitor::Tick() { auto client_id = it->first; RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( - gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + gcs::AsyncGcsClient *client, const ClientID &id, + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { if (client_id.Binary() == data.client_id && diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 75377a13c73d..671a7a7982b5 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -179,45 +179,42 @@ ray::Status NodeManager::RegisterGcs() { }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. 
- auto node_manager_client_removed = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ClientRemoved(data); - }; + auto node_manager_client_removed = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests - auto node_manager_resource_createupdated = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ResourceCreateUpdated(data); - }; + auto node_manager_resource_createupdated = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests - auto node_manager_resource_deleted = [this]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ClientTableDataT &data) { - ResourceDeleted(data); - }; + auto node_manager_resource_deleted = + [this](gcs::AsyncGcsClient *client, const UniqueID &id, + const ClientTableDataT &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. - const auto &heartbeat_batch_added = [this]( - gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { - HeartbeatBatchAdded(heartbeat_batch); - }; + const auto &heartbeat_batch_added = + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const HeartbeatBatchTableDataT &heartbeat_batch) { + HeartbeatBatchAdded(heartbeat_batch); + }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( DriverID::Nil(), ClientID::Nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto driver_table_handler = [this]( - gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { - HandleDriverTableUpdate(client_id, driver_data); - }; + const auto driver_table_handler = + [this](gcs::AsyncGcsClient *client, const DriverID &client_id, + const std::vector &driver_data) { + HandleDriverTableUpdate(client_id, driver_data); + }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( DriverID::Nil(), ClientID::Nil(), driver_table_handler, nullptr)); @@ -2202,53 +2199,55 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, /// TODO(rkn): Should we check that the node manager is remote and not local? /// TODO(rkn): Should we check if the remote node manager is known to be dead? // Attempt to forward the task. - ForwardTask(task, node_manager_id, [this, node_manager_id](ray::Status error, - const Task &task) { - const TaskID task_id = task.GetTaskSpecification().TaskId(); - RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " - << node_manager_id; - - // Mark the failed task as pending to let other raylets know that we still - // have the task. TaskDependencyManager::TaskPending() is assumed to be - // idempotent. - task_dependency_manager_.TaskPending(task); - - // Actor tasks can only be executed at the actor's location, so they are - // retried after a timeout. All other tasks that fail to be forwarded are - // deemed to be placeable again. 
- if (task.GetTaskSpecification().IsActorTask()) { - // The task is for an actor on another node. Create a timer to resubmit - // the task in a little bit. TODO(rkn): Really this should be a - // unique_ptr instead of a shared_ptr. However, it's a little harder to - // move unique_ptrs into lambdas. - auto retry_timer = std::make_shared(io_service_); - auto retry_duration = boost::posix_time::milliseconds( - RayConfig::instance().node_manager_forward_task_retry_timeout_milliseconds()); - retry_timer->expires_from_now(retry_duration); - retry_timer->async_wait( - [this, task_id, retry_timer](const boost::system::error_code &error) { - // Timer killing will receive the boost::asio::error::operation_aborted, - // we only handle the timeout event. - RAY_CHECK(!error); - RAY_LOG(INFO) << "Resubmitting task " << task_id - << " because ForwardTask failed."; - // Remove the RESUBMITTED task from the SWAP queue. - TaskState state; - const auto task = local_queues_.RemoveTask(task_id, &state); - RAY_CHECK(state == TaskState::SWAP); - // Submit the task again. - SubmitTask(task, Lineage()); - }); - // Temporarily move the RESUBMITTED task to the SWAP queue while the - // timer is active. - local_queues_.QueueTasks({task}, TaskState::SWAP); - } else { - // The task is not for an actor and may therefore be placed on another - // node immediately. Send it to the scheduling policy to be placed again. - local_queues_.QueueTasks({task}, TaskState::PLACEABLE); - ScheduleTasks(cluster_resource_map_); - } - }); + ForwardTask( + task, node_manager_id, + [this, node_manager_id](ray::Status error, const Task &task) { + const TaskID task_id = task.GetTaskSpecification().TaskId(); + RAY_LOG(INFO) << "Failed to forward task " << task_id << " to node manager " + << node_manager_id; + + // Mark the failed task as pending to let other raylets know that we still + // have the task. TaskDependencyManager::TaskPending() is assumed to be + // idempotent. + task_dependency_manager_.TaskPending(task); + + // Actor tasks can only be executed at the actor's location, so they are + // retried after a timeout. All other tasks that fail to be forwarded are + // deemed to be placeable again. + if (task.GetTaskSpecification().IsActorTask()) { + // The task is for an actor on another node. Create a timer to resubmit + // the task in a little bit. TODO(rkn): Really this should be a + // unique_ptr instead of a shared_ptr. However, it's a little harder to + // move unique_ptrs into lambdas. + auto retry_timer = std::make_shared(io_service_); + auto retry_duration = boost::posix_time::milliseconds( + RayConfig::instance() + .node_manager_forward_task_retry_timeout_milliseconds()); + retry_timer->expires_from_now(retry_duration); + retry_timer->async_wait( + [this, task_id, retry_timer](const boost::system::error_code &error) { + // Timer killing will receive the boost::asio::error::operation_aborted, + // we only handle the timeout event. + RAY_CHECK(!error); + RAY_LOG(INFO) << "Resubmitting task " << task_id + << " because ForwardTask failed."; + // Remove the RESUBMITTED task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + // Submit the task again. + SubmitTask(task, Lineage()); + }); + // Temporarily move the RESUBMITTED task to the SWAP queue while the + // timer is active. + local_queues_.QueueTasks({task}, TaskState::SWAP); + } else { + // The task is not for an actor and may therefore be placed on another + // node immediately. 
Send it to the scheduling policy to be placed again. + local_queues_.QueueTasks({task}, TaskState::PLACEABLE); + ScheduleTasks(cluster_resource_map_); + } + }); } void NodeManager::ForwardTask( diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 1b043ca58c2b..0f411e8c581d 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -136,16 +136,17 @@ class TestObjectManagerIntegration : public TestObjectManagerBase { void WaitConnections() { client_id_1 = gcs_client_1->client_table().GetLocalClientId(); client_id_2 = gcs_client_2->client_table().GetLocalClientId(); - gcs_client_1->client_table().RegisterClientAddedCallback([this]( - gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); - if (parsed_id == client_id_1 || parsed_id == client_id_2) { - num_connected_clients += 1; - } - if (num_connected_clients == 2) { - StartTests(); - } - }); + gcs_client_1->client_table().RegisterClientAddedCallback( + [this](gcs::AsyncGcsClient *client, const ClientID &id, + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); + if (parsed_id == client_id_1 || parsed_id == client_id_2) { + num_connected_clients += 1; + } + if (num_connected_clients == 2) { + StartTests(); + } + }); } void StartTests() { diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index dd9e5fac318e..80630d372a61 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -128,13 +128,15 @@ void Raylet::DoAcceptNodeManager() { void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { if (!error) { - ClientHandler client_handler = [this]( - TcpClientConnection &client) { node_manager_.ProcessNewNodeManager(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessNodeManagerMessage(*client, message_type, message); - }; + ClientHandler client_handler = + [this](TcpClientConnection &client) { + node_manager_.ProcessNewNodeManager(client); + }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessNodeManagerMessage(*client, message_type, message); + }; // Accept a new TCP client and dispatch it to the node manager. auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(node_manager_socket_), "node manager", @@ -154,11 +156,11 @@ void Raylet::DoAcceptObjectManager() { void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) { ClientHandler client_handler = [this](TcpClientConnection &client) { object_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - object_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + object_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new TCP client and dispatch it to the node manager. 
auto new_connection = TcpClientConnection::Create( client_handler, message_handler, std::move(object_manager_socket_), @@ -177,11 +179,11 @@ void Raylet::HandleAccept(const boost::system::error_code &error) { // TODO: typedef these handlers. ClientHandler client_handler = [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = [this]( - std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessClientMessage(client, message_type, message); - }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessClientMessage(client, message_type, message); + }; // Accept a new local client and dispatch it to the node manager. auto new_connection = LocalClientConnection::Create( client_handler, message_handler, std::move(socket_), "worker", diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 85295e403769..73f0e2ef803a 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -233,8 +233,13 @@ std::vector SchedulingQueue::RemoveTasks(std::unordered_set &task_ std::vector removed_tasks; // Try to find the tasks to remove from the queues. for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_ids, &removed_tasks); } @@ -248,8 +253,13 @@ Task SchedulingQueue::RemoveTask(const TaskID &task_id, TaskState *removed_task_ std::unordered_set task_id_set = {task_id}; // Try to find the task to remove in the queues. for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, TaskState::RUNNING, - TaskState::INFEASIBLE, TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { RemoveTasksFromQueue(task_state, task_id_set, &removed_tasks); if (task_id_set.empty()) { diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 4fd07e5ca606..465f2a4341a0 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -149,9 +149,13 @@ class SchedulingQueue { /// Create a scheduling queue. 
SchedulingQueue() : ready_queue_(std::make_shared()) { for (const auto &task_state : { - TaskState::PLACEABLE, TaskState::WAITING, TaskState::READY, - TaskState::RUNNING, TaskState::INFEASIBLE, - TaskState::WAITING_FOR_ACTOR_CREATION, TaskState::SWAP, + TaskState::PLACEABLE, + TaskState::WAITING, + TaskState::READY, + TaskState::RUNNING, + TaskState::INFEASIBLE, + TaskState::WAITING_FOR_ACTOR_CREATION, + TaskState::SWAP, }) { if (task_state == TaskState::READY) { task_queues_[static_cast(task_state)] = ready_queue_; diff --git a/src/ray/raylet/scheduling_resources.cc b/src/ray/raylet/scheduling_resources.cc index 923e6aad9d85..c80282601256 100644 --- a/src/ray/raylet/scheduling_resources.cc +++ b/src/ray/raylet/scheduling_resources.cc @@ -15,8 +15,8 @@ FractionalResourceQuantity::FractionalResourceQuantity(double resource_quantity) // We check for nonnegativeity due to the implicit conversion to // FractionalResourceQuantity from ints/doubles when we do logical // comparisons. - RAY_CHECK(resource_quantity >= 0) << "Resource capacity, " << resource_quantity - << ", should be nonnegative."; + RAY_CHECK(resource_quantity >= 0) + << "Resource capacity, " << resource_quantity << ", should be nonnegative."; resource_quantity_ = static_cast(resource_quantity * kResourceConversionFactor); diff --git a/src/ray/util/logging.cc b/src/ray/util/logging.cc index 97c871d8cb7f..1e2a95408f13 100644 --- a/src/ray/util/logging.cc +++ b/src/ray/util/logging.cc @@ -186,8 +186,7 @@ bool RayLog::IsLevelEnabled(RayLogLevel log_level) { RayLog::RayLog(const char *file_name, int line_number, RayLogLevel severity) // glog does not have DEBUG level, we can handle it using is_enabled_. - : logging_provider_(nullptr), - is_enabled_(severity >= severity_threshold_) { + : logging_provider_(nullptr), is_enabled_(severity >= severity_threshold_) { #ifdef RAY_USE_GLOG if (is_enabled_) { logging_provider_ = diff --git a/src/ray/util/logging.h b/src/ray/util/logging.h index d37ab9a73897..39428eba9583 100644 --- a/src/ray/util/logging.h +++ b/src/ray/util/logging.h @@ -16,19 +16,19 @@ enum class RayLogLevel { DEBUG = -1, INFO = 0, WARNING = 1, ERROR = 2, FATAL = 3 #define RAY_IGNORE_EXPR(expr) ((void)(expr)) -#define RAY_CHECK(condition) \ - (condition) ? RAY_IGNORE_EXPR(0) \ - : ::ray::Voidify() & \ - ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::FATAL) \ - << " Check failed: " #condition " " +#define RAY_CHECK(condition) \ + (condition) \ + ? RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::FATAL) \ + << " Check failed: " #condition " " #ifdef NDEBUG -#define RAY_DCHECK(condition) \ - (condition) ? RAY_IGNORE_EXPR(0) \ - : ::ray::Voidify() & \ - ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::ERROR) \ - << " Debug check failed: " #condition " " +#define RAY_DCHECK(condition) \ + (condition) \ + ? 
RAY_IGNORE_EXPR(0) \ + : ::ray::Voidify() & ::ray::RayLog(__FILE__, __LINE__, ray::RayLogLevel::ERROR) \ + << " Debug check failed: " #condition " " #else #define RAY_DCHECK(condition) RAY_CHECK(condition) diff --git a/src/ray/util/macros.h b/src/ray/util/macros.h index dbf85fe399e5..f105c4bd2b5f 100644 --- a/src/ray/util/macros.h +++ b/src/ray/util/macros.h @@ -8,7 +8,7 @@ void operator=(const TypeName &) = delete #endif -#define RAY_UNUSED(x) (void) x +#define RAY_UNUSED(x) (void)x // // GCC can be told that a certain branch is not likely to be taken (for From 37abdb283f71f5722013c7fd45baa1ec4676488b Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Fri, 14 Jun 2019 18:35:32 +0800 Subject: [PATCH 095/118] [Core worker] add store & task provider (#4966) --- BUILD.bazel | 4 + src/ray/core_worker/common.h | 31 ++++ src/ray/core_worker/core_worker.cc | 44 +++--- src/ray/core_worker/core_worker.h | 29 ++-- src/ray/core_worker/core_worker_test.cc | 22 ++- src/ray/core_worker/mock_worker.cc | 4 +- src/ray/core_worker/object_interface.cc | 127 +++------------- src/ray/core_worker/object_interface.h | 5 + .../store_provider/plasma_store_provider.cc | 139 ++++++++++++++++++ .../store_provider/plasma_store_provider.h | 76 ++++++++++ .../store_provider/store_provider.h | 64 ++++++++ src/ray/core_worker/task_execution.cc | 57 ++++--- src/ray/core_worker/task_execution.h | 7 +- src/ray/core_worker/task_interface.cc | 40 ++--- src/ray/core_worker/task_interface.h | 14 +- .../core_worker/transport/raylet_transport.cc | 32 ++++ .../core_worker/transport/raylet_transport.h | 44 ++++++ src/ray/core_worker/transport/transport.h | 41 ++++++ src/ray/test/run_core_worker_tests.sh | 3 - 19 files changed, 565 insertions(+), 218 deletions(-) create mode 100644 src/ray/core_worker/store_provider/plasma_store_provider.cc create mode 100644 src/ray/core_worker/store_provider/plasma_store_provider.h create mode 100644 src/ray/core_worker/store_provider/store_provider.h create mode 100644 src/ray/core_worker/transport/raylet_transport.cc create mode 100644 src/ray/core_worker/transport/raylet_transport.h create mode 100644 src/ray/core_worker/transport/transport.h diff --git a/BUILD.bazel b/BUILD.bazel index 27ab40ef74d1..47e795011ab4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -111,6 +111,8 @@ cc_library( srcs = glob( [ "src/ray/core_worker/*.cc", + "src/ray/core_worker/store_provider/*.cc", + "src/ray/core_worker/transport/*.cc", ], exclude = [ "src/ray/core_worker/*_test.cc", @@ -119,6 +121,8 @@ cc_library( ), hdrs = glob([ "src/ray/core_worker/*.h", + "src/ray/core_worker/store_provider/*.h", + "src/ray/core_worker/transport/*.h", ]), copts = COPTS, deps = [ diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index b11fabfe46f8..3fda406613ef 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -5,6 +5,8 @@ #include "ray/common/buffer.h" #include "ray/common/id.h" +#include "ray/raylet/raylet_client.h" +#include "ray/raylet/task_spec.h" namespace ray { @@ -66,6 +68,35 @@ class TaskArg { const std::shared_ptr data_; }; +/// Task specification, which includes the immutable information about the task +/// which are determined at the submission time. +/// TODO(zhijunfu): this can be removed after everything is moved to protobuf. 
+class TaskSpec { + public: + TaskSpec(const raylet::TaskSpecification &task_spec, + const std::vector &dependencies) + : task_spec_(task_spec), dependencies_(dependencies) {} + + TaskSpec(const raylet::TaskSpecification &&task_spec, + const std::vector &&dependencies) + : task_spec_(task_spec), dependencies_(dependencies) {} + + const raylet::TaskSpecification &GetTaskSpecification() const { return task_spec_; } + + const std::vector &GetDependencies() const { return dependencies_; } + + private: + /// Raylet task specification. + raylet::TaskSpecification task_spec_; + + /// Dependencies. + std::vector dependencies_; +}; + +enum class StoreProviderType { PLASMA }; + +enum class TaskTransportType { RAYLET }; + } // namespace ray #endif // RAY_CORE_WORKER_COMMON_H diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 033409196d9b..bcc1bdd963db 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -9,41 +9,39 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type, DriverID driver_id) : worker_type_(worker_type), language_(language), - worker_context_(worker_type, driver_id), store_socket_(store_socket), raylet_socket_(raylet_socket), - is_initialized_(false), + worker_context_(worker_type, driver_id), + raylet_client_(raylet_socket_, worker_context_.GetWorkerID(), + (worker_type_ == ray::WorkerType::WORKER), + worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)), task_interface_(*this), object_interface_(*this), task_execution_interface_(*this) { - switch (language_) { + // TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot + // connect to Raylet after a number of retries, this needs to be changed + // so that the worker (java/python .etc) can retrieve and handle the error + // instead of crashing. + auto status = store_client_.Connect(store_socket_); + if (!status.ok()) { + RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct" + << " core worker: " << status.message(); + throw std::runtime_error(status.message()); + } +} + +::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) { + switch (language) { case ray::WorkerLanguage::JAVA: - task_language_ = ::Language::JAVA; + return ::Language::JAVA; break; case ray::WorkerLanguage::PYTHON: - task_language_ = ::Language::PYTHON; + return ::Language::PYTHON; break; default: - RAY_LOG(FATAL) << "Unsupported worker language: " << static_cast(language_); + RAY_LOG(FATAL) << "invalid language specified: " << static_cast(language); break; } } -Status CoreWorker::Connect() { - // connect to plasma. - RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_)); - - // connect to raylet. - // TODO: currently RayletClient would crash in its constructor if it cannot - // connect to Raylet after a number of retries, this needs to be changed - // so that the worker (java/python .etc) can retrieve and handle the error - // instead of crashing. - raylet_client_ = std::unique_ptr( - new RayletClient(raylet_socket_, worker_context_.GetWorkerID(), - (worker_type_ == ray::WorkerType::WORKER), - worker_context_.GetCurrentDriverID(), task_language_)); - is_initialized_ = true; - return Status::OK(); -} - } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index c038b76ce53f..e03a8700be81 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -20,13 +20,12 @@ class CoreWorker { /// /// \param[in] worker_type Type of this worker. 
/// \param[in] langauge Language of this worker. + /// + /// NOTE(zhijunfu): the constructor would throw if a failure happens. CoreWorker(const WorkerType worker_type, const WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, DriverID driver_id = DriverID::Nil()); - /// Connect to raylet. - Status Connect(); - /// Type of this worker. enum WorkerType WorkerType() const { return worker_type_; } @@ -46,23 +45,26 @@ class CoreWorker { CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; } private: + /// Translate from WorkLanguage to Language type (required by raylet client). + /// + /// \param[in] language Language for a task. + /// \return Translated task language. + ::Language ToTaskLanguage(WorkerLanguage language); + /// Type of this worker. const enum WorkerType worker_type_; /// Language of this worker. const enum WorkerLanguage language_; - /// Language of this worker as specified in flatbuf (used by task spec). - ::Language task_language_; - - /// Worker context per thread. - WorkerContext worker_context_; - /// Plasma store socket name. - std::string store_socket_; + const std::string store_socket_; /// raylet socket name. - std::string raylet_socket_; + const std::string raylet_socket_; + + /// Worker context. + WorkerContext worker_context_; /// Plasma store client. plasma::PlasmaClient store_client_; @@ -71,10 +73,7 @@ class CoreWorker { std::mutex store_client_mutex_; /// Raylet client. - std::unique_ptr raylet_client_; - - /// Whether this worker has been initialized. - bool is_initialized_; + RayletClient raylet_client_; /// The `CoreWorkerTaskInterface` instance. CoreWorkerTaskInterface task_interface_; diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index fedfb9c2356b..6e4ecc161fb4 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -128,8 +128,6 @@ class CoreWorkerTest : public ::testing::Test { raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); - RAY_CHECK_OK(driver.Connect()); - // Test pass by value. 
{ uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; @@ -187,7 +185,6 @@ class CoreWorkerTest : public ::testing::Test { CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); - RAY_CHECK_OK(driver.Connect()); std::unique_ptr actor_handle; @@ -277,13 +274,6 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -TEST_F(ZeroNodeTest, TestAttributeGetters) { - CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", "", - DriverID::FromRandom()); - ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER); - ASSERT_EQ(core_worker.Language(), WorkerLanguage::PYTHON); -} - TEST_F(ZeroNodeTest, TestWorkerContext) { auto driver_id = DriverID::FromRandom(); @@ -313,7 +303,6 @@ TEST_F(SingleNodeTest, TestObjectInterface) { CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); - RAY_CHECK_OK(core_worker.Connect()); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -370,12 +359,10 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], DriverID::FromRandom()); - RAY_CHECK_OK(worker1.Connect()); CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[1], raylet_socket_names_[1], DriverID::FromRandom()); - RAY_CHECK_OK(worker2.Connect()); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -456,6 +443,15 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) { TestActorTask(resources); } +TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) { + try { + CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", + raylet_socket_names_[0], DriverID::FromRandom()); + } catch (const std::exception &e) { + std::cout << "Caught exception when constructing core worker: " << e.what(); + } +} + } // namespace ray int main(int argc, char **argv) { diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc index 95d11bb259a8..a331a0b6ae12 100644 --- a/src/ray/core_worker/mock_worker.cc +++ b/src/ray/core_worker/mock_worker.cc @@ -17,9 +17,7 @@ class MockWorker { public: MockWorker(const std::string &store_socket, const std::string &raylet_socket) : worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket, - DriverID::FromRandom()) { - RAY_CHECK_OK(worker_.Connect()); - } + DriverID::FromRandom()) {} void Run() { auto executor_func = [this](const RayFunction &ray_function, diff --git a/src/ray/core_worker/object_interface.cc b/src/ray/core_worker/object_interface.cc index 5ab5d33330d7..81777117cd14 100644 --- a/src/ray/core_worker/object_interface.cc +++ b/src/ray/core_worker/object_interface.cc @@ -2,11 +2,18 @@ #include "ray/common/ray_config.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" +#include "ray/core_worker/store_provider/plasma_store_provider.h" namespace ray { CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker) - : core_worker_(core_worker) {} + : core_worker_(core_worker) { + store_providers_.emplace( + static_cast(StoreProviderType::PLASMA), + std::unique_ptr(new CoreWorkerPlasmaStoreProvider( + core_worker_.store_client_, core_worker_.store_client_mutex_, + core_worker_.raylet_client_))); +} Status CoreWorkerObjectInterface::Put(const Buffer &buffer, 
ObjectID *object_id) { ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(), @@ -16,127 +23,31 @@ Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) } Status CoreWorkerObjectInterface::Put(const Buffer &buffer, const ObjectID &object_id) { - auto plasma_id = object_id.ToPlasmaId(); - std::shared_ptr data; - { - std::unique_lock guard(core_worker_.store_client_mutex_); - RAY_ARROW_RETURN_NOT_OK( - core_worker_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); - } - - memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); - - { - std::unique_lock guard(core_worker_.store_client_mutex_); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Seal(plasma_id)); - RAY_ARROW_RETURN_NOT_OK(core_worker_.store_client_.Release(plasma_id)); - } - return Status::OK(); + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Put(buffer, object_id); } Status CoreWorkerObjectInterface::Get(const std::vector &ids, int64_t timeout_ms, std::vector> *results) { - (*results).resize(ids.size(), nullptr); - - bool was_blocked = false; - - std::unordered_map unready; - for (size_t i = 0; i < ids.size(); i++) { - unready.insert({ids[i], i}); - } - - int num_attempts = 0; - bool should_break = false; - int64_t remaining_timeout = timeout_ms; - // Repeat until we get all objects. - while (!unready.empty() && !should_break) { - std::vector unready_ids; - for (const auto &entry : unready) { - unready_ids.push_back(entry.first); - } - - // For the initial fetch, we only fetch the objects, do not reconstruct them. - bool fetch_only = num_attempts == 0; - if (!fetch_only) { - // If fetch_only is false, this worker will be blocked. - was_blocked = true; - } - - // TODO: can call `fetchOrReconstruct` in batches as an optimization. - RAY_CHECK_OK(core_worker_.raylet_client_->FetchOrReconstruct( - unready_ids, fetch_only, core_worker_.worker_context_.GetCurrentTaskID())); - - // Get the objects from the object store, and parse the result. - int64_t get_timeout; - if (remaining_timeout >= 0) { - get_timeout = - std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); - remaining_timeout -= get_timeout; - should_break = remaining_timeout <= 0; - } else { - get_timeout = RayConfig::instance().get_timeout_milliseconds(); - } - - std::vector plasma_ids; - for (const auto &id : unready_ids) { - plasma_ids.push_back(id.ToPlasmaId()); - } - - std::vector object_buffers; - { - std::unique_lock guard(core_worker_.store_client_mutex_); - auto status = - core_worker_.store_client_.Get(plasma_ids, get_timeout, &object_buffers); - } - - for (size_t i = 0; i < object_buffers.size(); i++) { - if (object_buffers[i].data != nullptr) { - const auto &object_id = unready_ids[i]; - (*results)[unready[object_id]] = - std::make_shared(object_buffers[i].data); - unready.erase(object_id); - } - } - - num_attempts += 1; - // TODO: log a message if attempted too many times. 
- } - - if (was_blocked) { - RAY_CHECK_OK(core_worker_.raylet_client_->NotifyUnblocked( - core_worker_.worker_context_.GetCurrentTaskID())); - } - - return Status::OK(); + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Get( + ids, timeout_ms, core_worker_.worker_context_.GetCurrentTaskID(), results); } Status CoreWorkerObjectInterface::Wait(const std::vector &object_ids, int num_objects, int64_t timeout_ms, std::vector *results) { - WaitResultPair result_pair; - auto status = core_worker_.raylet_client_->Wait( - object_ids, num_objects, timeout_ms, false, - core_worker_.worker_context_.GetCurrentTaskID(), &result_pair); - std::unordered_set ready_ids; - for (const auto &entry : result_pair.first) { - ready_ids.insert(entry); - } - - // TODO: change RayletClient::Wait() to return a bit set, so that we don't need - // to do this translation. - (*results).resize(object_ids.size()); - for (size_t i = 0; i < object_ids.size(); i++) { - (*results)[i] = ready_ids.count(object_ids[i]) > 0; - } - - return status; + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Wait(object_ids, num_objects, timeout_ms, + core_worker_.worker_context_.GetCurrentTaskID(), + results); } Status CoreWorkerObjectInterface::Delete(const std::vector &object_ids, bool local_only, bool delete_creating_tasks) { - return core_worker_.raylet_client_->FreeObjects(object_ids, local_only, - delete_creating_tasks); + auto type = static_cast(StoreProviderType::PLASMA); + return store_providers_[type]->Delete(object_ids, local_only, delete_creating_tasks); } } // namespace ray diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index 431b3f825ac9..35403675f164 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -6,10 +6,12 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/store_provider/store_provider.h" namespace ray { class CoreWorker; +class CoreWorkerStoreProvider; /// The interface that contains all `CoreWorker` methods that are related to object store. class CoreWorkerObjectInterface { @@ -62,6 +64,9 @@ class CoreWorkerObjectInterface { private: /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + /// All the store providers supported. 
+ std::unordered_map> store_providers_; }; } // namespace ray diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc new file mode 100644 index 000000000000..b5dd91d82881 --- /dev/null +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -0,0 +1,139 @@ +#include "ray/core_worker/store_provider/plasma_store_provider.h" +#include "ray/common/ray_config.h" +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "ray/core_worker/object_interface.h" + +namespace ray { + +CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( + plasma::PlasmaClient &store_client, std::mutex &store_client_mutex, + RayletClient &raylet_client) + : store_client_(store_client), + store_client_mutex_(store_client_mutex), + raylet_client_(raylet_client) {} + +Status CoreWorkerPlasmaStoreProvider::Put(const Buffer &buffer, + const ObjectID &object_id) { + auto plasma_id = object_id.ToPlasmaId(); + std::shared_ptr data; + { + std::unique_lock guard(store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK( + store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data)); + } + + memcpy(data->mutable_data(), buffer.Data(), buffer.Size()); + + { + std::unique_lock guard(store_client_mutex_); + RAY_ARROW_RETURN_NOT_OK(store_client_.Seal(plasma_id)); + RAY_ARROW_RETURN_NOT_OK(store_client_.Release(plasma_id)); + } + return Status::OK(); +} + +Status CoreWorkerPlasmaStoreProvider::Get(const std::vector &ids, + int64_t timeout_ms, const TaskID &task_id, + std::vector> *results) { + (*results).resize(ids.size(), nullptr); + + bool was_blocked = false; + + std::unordered_map unready; + for (size_t i = 0; i < ids.size(); i++) { + unready.insert({ids[i], i}); + } + + int num_attempts = 0; + bool should_break = false; + int64_t remaining_timeout = timeout_ms; + // Repeat until we get all objects. + while (!unready.empty() && !should_break) { + std::vector unready_ids; + for (const auto &entry : unready) { + unready_ids.push_back(entry.first); + } + + // For the initial fetch, we only fetch the objects, do not reconstruct them. + bool fetch_only = num_attempts == 0; + if (!fetch_only) { + // If fetch_only is false, this worker will be blocked. + was_blocked = true; + } + + // TODO(zhijunfu): can call `fetchOrReconstruct` in batches as an optimization. + RAY_CHECK_OK(raylet_client_.FetchOrReconstruct(unready_ids, fetch_only, task_id)); + + // Get the objects from the object store, and parse the result. + int64_t get_timeout; + if (remaining_timeout >= 0) { + get_timeout = + std::min(remaining_timeout, RayConfig::instance().get_timeout_milliseconds()); + remaining_timeout -= get_timeout; + should_break = remaining_timeout <= 0; + } else { + get_timeout = RayConfig::instance().get_timeout_milliseconds(); + } + + std::vector plasma_ids; + for (const auto &id : unready_ids) { + plasma_ids.push_back(id.ToPlasmaId()); + } + + std::vector object_buffers; + { + std::unique_lock guard(store_client_mutex_); + auto status = store_client_.Get(plasma_ids, get_timeout, &object_buffers); + } + + for (size_t i = 0; i < object_buffers.size(); i++) { + if (object_buffers[i].data != nullptr) { + const auto &object_id = unready_ids[i]; + (*results)[unready[object_id]] = + std::make_shared(object_buffers[i].data); + unready.erase(object_id); + } + } + + num_attempts += 1; + // TODO(zhijunfu): log a message if attempted too many times. 
+ } + + if (was_blocked) { + RAY_CHECK_OK(raylet_client_.NotifyUnblocked(task_id)); + } + + return Status::OK(); +} + +Status CoreWorkerPlasmaStoreProvider::Wait(const std::vector &object_ids, + int num_objects, int64_t timeout_ms, + const TaskID &task_id, + std::vector *results) { + WaitResultPair result_pair; + auto status = raylet_client_.Wait(object_ids, num_objects, timeout_ms, false, task_id, + &result_pair); + std::unordered_set ready_ids; + for (const auto &entry : result_pair.first) { + ready_ids.insert(entry); + } + + // TODO(zhijunfu): change RayletClient::Wait() to return a bit set, so that we don't + // need + // to do this translation. + (*results).resize(object_ids.size()); + for (size_t i = 0; i < object_ids.size(); i++) { + (*results)[i] = ready_ids.count(object_ids[i]) > 0; + } + + return status; +} + +Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector &object_ids, + bool local_only, + bool delete_creating_tasks) { + return raylet_client_.FreeObjects(object_ids, local_only, delete_creating_tasks); +} + +} // namespace ray diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h new file mode 100644 index 000000000000..0dfce1eb1e45 --- /dev/null +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -0,0 +1,76 @@ +#ifndef RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H +#define RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H + +#include "plasma/client.h" +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/core_worker/store_provider/store_provider.h" +#include "ray/raylet/raylet_client.h" + +namespace ray { + +class CoreWorker; + +/// The class provides implementations for accessing plasma store, which includes both +/// local and remote store, remote access is done via raylet. +class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider { + public: + CoreWorkerPlasmaStoreProvider(plasma::PlasmaClient &store_client, + std::mutex &store_client_mutex, + RayletClient &raylet_client); + + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + Status Put(const Buffer &buffer, const ObjectID &object_id) override; + + /// Get a list of objects from the object store. + /// + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results Result list of objects data. + /// \return Status. + Status Get(const std::vector &ids, int64_t timeout_ms, const TaskID &task_id, + std::vector> *results) override; + + /// Wait for a list of objects to appear in the object store. + /// + /// \param[in] IDs of the objects to wait for. + /// \param[in] num_returns Number of objects that should appear. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results A bitset that indicates each object has appeared or not. + /// \return Status. + Status Wait(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, const TaskID &task_id, + std::vector *results) override; + + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. 
+ /// \param[in] local_only Whether only delete the objects in local node, or all nodes in + /// the cluster. + /// \param[in] delete_creating_tasks Whether also delete the tasks that + /// created these objects. \return Status. + Status Delete(const std::vector &object_ids, bool local_only, + bool delete_creating_tasks) override; + + private: + /// Plasma store client. + plasma::PlasmaClient &store_client_; + + /// Mutex to protect store_client_. + std::mutex &store_client_mutex_; + + /// Raylet client. + RayletClient &raylet_client_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_PLASMA_STORE_PROVIDER_H diff --git a/src/ray/core_worker/store_provider/store_provider.h b/src/ray/core_worker/store_provider/store_provider.h new file mode 100644 index 000000000000..f1521edf1626 --- /dev/null +++ b/src/ray/core_worker/store_provider/store_provider.h @@ -0,0 +1,64 @@ +#ifndef RAY_CORE_WORKER_STORE_PROVIDER_H +#define RAY_CORE_WORKER_STORE_PROVIDER_H + +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" + +namespace ray { + +/// Provider interface for store access. Store provider should inherit from this class and +/// provide implementions for the methods. The actual store provider may use a plasma +/// store or local memory store in worker process, or possibly other types of storage. + +class CoreWorkerStoreProvider { + public: + CoreWorkerStoreProvider() {} + + virtual ~CoreWorkerStoreProvider() {} + + /// Put an object with specified ID into object store. + /// + /// \param[in] buffer Data buffer of the object. + /// \param[in] object_id Object ID specified by user. + /// \return Status. + virtual Status Put(const Buffer &buffer, const ObjectID &object_id) = 0; + + /// Get a list of objects from the object store. + /// + /// \param[in] ids IDs of the objects to get. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results Result list of objects data. + /// \return Status. + virtual Status Get(const std::vector &ids, int64_t timeout_ms, + const TaskID &task_id, + std::vector> *results) = 0; + + /// Wait for a list of objects to appear in the object store. + /// + /// \param[in] IDs of the objects to wait for. + /// \param[in] num_returns Number of objects that should appear. + /// \param[in] timeout_ms Timeout in milliseconds, wait infinitely if it's negative. + /// \param[in] task_id ID for the current task. + /// \param[out] results A bitset that indicates each object has appeared or not. + /// \return Status. + virtual Status Wait(const std::vector &object_ids, int num_objects, + int64_t timeout_ms, const TaskID &task_id, + std::vector *results) = 0; + + /// Delete a list of objects from the object store. + /// + /// \param[in] object_ids IDs of the objects to delete. + /// \param[in] local_only Whether only delete the objects in local node, or all nodes in + /// the cluster. + /// \param[in] delete_creating_tasks Whether also delete the tasks that + /// created these objects. \return Status. 
+ virtual Status Delete(const std::vector &object_ids, bool local_only, + bool delete_creating_tasks) = 0; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_STORE_PROVIDER_H diff --git a/src/ray/core_worker/task_execution.cc b/src/ray/core_worker/task_execution.cc index fc22fce96c97..701ae3124c97 100644 --- a/src/ray/core_worker/task_execution.cc +++ b/src/ray/core_worker/task_execution.cc @@ -1,43 +1,54 @@ #include "ray/core_worker/task_execution.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" +#include "ray/core_worker/transport/raylet_transport.h" namespace ray { -Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) { - RAY_CHECK(core_worker_.is_initialized_); +CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface( + CoreWorker &core_worker) + : core_worker_(core_worker) { + task_receivers.emplace( + static_cast(TaskTransportType::RAYLET), + std::unique_ptr( + new CoreWorkerRayletTaskReceiver(core_worker_.raylet_client_))); +} +Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) { while (true) { - std::unique_ptr task_spec; - auto status = core_worker_.raylet_client_->GetTask(&task_spec); + std::vector tasks; + auto status = + task_receivers[static_cast(TaskTransportType::RAYLET)]->GetTasks(&tasks); if (!status.ok()) { - RAY_LOG(ERROR) << "Get task failed with error: " + RAY_LOG(ERROR) << "Getting task failed with error: " << ray::Status::IOError(status.message()); return status; } - const auto &spec = *task_spec; - core_worker_.worker_context_.SetCurrentTask(spec); + for (const auto &task : tasks) { + const auto &spec = task.GetTaskSpecification(); + core_worker_.worker_context_.SetCurrentTask(spec); - WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA) - ? WorkerLanguage::JAVA - : WorkerLanguage::PYTHON; - RayFunction func{language, spec.FunctionDescriptor()}; + WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA) + ? WorkerLanguage::JAVA + : WorkerLanguage::PYTHON; + RayFunction func{language, spec.FunctionDescriptor()}; - std::vector> args; - RAY_CHECK_OK(BuildArgsForExecutor(spec, &args)); + std::vector> args; + RAY_CHECK_OK(BuildArgsForExecutor(spec, &args)); - auto num_returns = spec.NumReturns(); - if (spec.IsActorCreationTask() || spec.IsActorTask()) { - RAY_CHECK(num_returns > 0); - // Decrease to account for the dummy object id. - num_returns--; - } + auto num_returns = spec.NumReturns(); + if (spec.IsActorCreationTask() || spec.IsActorTask()) { + RAY_CHECK(num_returns > 0); + // Decrease to account for the dummy object id. + num_returns--; + } - status = executor(func, args, spec.TaskId(), num_returns); - // TODO: - // 1. Check and handle failure. - // 2. Save or load checkpoint. + status = executor(func, args, spec.TaskId(), num_returns); + // TODO(zhijunfu): + // 1. Check and handle failure. + // 2. Save or load checkpoint. + } } // should never reach here. diff --git a/src/ray/core_worker/task_execution.h b/src/ray/core_worker/task_execution.h index e2fe2148a3ab..f4b44b9e131d 100644 --- a/src/ray/core_worker/task_execution.h +++ b/src/ray/core_worker/task_execution.h @@ -4,6 +4,7 @@ #include "ray/common/buffer.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/transport/transport.h" namespace ray { @@ -17,8 +18,7 @@ class TaskSpecification; /// execution. 
class CoreWorkerTaskExecutionInterface { public: - CoreWorkerTaskExecutionInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} - + CoreWorkerTaskExecutionInterface(CoreWorker &core_worker); /// The callback provided app-language workers that executes tasks. /// /// \param ray_function[in] Information about the function to execute. @@ -46,6 +46,9 @@ class CoreWorkerTaskExecutionInterface { /// Reference to the parent CoreWorker instance. CoreWorker &core_worker_; + + /// All the task task receivers supported. + std::unordered_map> task_receivers; }; } // namespace ray diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index 00f15237f1c3..6a91bd6b2101 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -1,9 +1,19 @@ #include "ray/core_worker/task_interface.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" +#include "ray/core_worker/task_interface.h" +#include "ray/core_worker/transport/raylet_transport.h" namespace ray { +CoreWorkerTaskInterface::CoreWorkerTaskInterface(CoreWorker &core_worker) + : core_worker_(core_worker) { + task_submitters_.emplace( + static_cast(TaskTransportType::RAYLET), + std::unique_ptr( + new CoreWorkerRayletTaskSubmitter(core_worker_.raylet_client_))); +} + Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, const std::vector &args, const TaskOptions &task_options, @@ -20,7 +30,7 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, } auto task_arguments = BuildTaskArguments(args); - auto language = ToTaskLanguage(function.language); + auto language = core_worker_.ToTaskLanguage(function.language); ray::raylet::TaskSpecification spec(context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, @@ -28,7 +38,8 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, language, function.function_descriptor); std::vector execution_dependencies; - return core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); + TaskSpec task(std::move(spec), execution_dependencies); + return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } Status CoreWorkerTaskInterface::CreateActor( @@ -50,7 +61,7 @@ Status CoreWorkerTaskInterface::CreateActor( (*actor_handle)->SetActorCursor(return_ids[0]); auto task_arguments = BuildTaskArguments(args); - auto language = ToTaskLanguage(function.language); + auto language = core_worker_.ToTaskLanguage(function.language); // Note that the caller is supposed to specify required placement resources // correctly via actor_creation_options.resources. 
@@ -62,7 +73,8 @@ Status CoreWorkerTaskInterface::CreateActor( function.function_descriptor); std::vector execution_dependencies; - return core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); + TaskSpec task(std::move(spec), execution_dependencies); + return task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); } Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, @@ -86,7 +98,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, ObjectID::FromBinary(actor_handle.ActorID().Binary()); auto task_arguments = BuildTaskArguments(args); - auto language = ToTaskLanguage(function.language); + auto language = core_worker_.ToTaskLanguage(function.language); std::vector new_actor_handles; ray::raylet::TaskSpecification spec( @@ -103,7 +115,9 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, actor_handle.SetActorCursor(actor_cursor); actor_handle.ClearNewActorHandles(); - auto status = core_worker_.raylet_client_->SubmitTask(execution_dependencies, spec); + TaskSpec task(std::move(spec), execution_dependencies); + auto status = + task_submitters_[static_cast(TaskTransportType::RAYLET)]->SubmitTask(task); // remove cursor from return ids. (*return_ids).pop_back(); @@ -127,18 +141,4 @@ CoreWorkerTaskInterface::BuildTaskArguments(const std::vector &args) { return task_arguments; } -::Language CoreWorkerTaskInterface::ToTaskLanguage(WorkerLanguage language) { - switch (language) { - case ray::WorkerLanguage::JAVA: - return ::Language::JAVA; - break; - case ray::WorkerLanguage::PYTHON: - return ::Language::PYTHON; - break; - default: - RAY_LOG(FATAL) << "invalid language specified: " << static_cast(language); - break; - } -} - } // namespace ray diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index 2ec3b1329cbc..e59934f9b51d 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -7,6 +7,7 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" +#include "ray/core_worker/transport/transport.h" #include "ray/raylet/task.h" namespace ray { @@ -65,11 +66,11 @@ class ActorHandle { int IncreaseTaskCounter() { return task_counter_++; } std::list GetNewActorHandle() { - // TODO: implement this. + // TODO(zhijunfu): implement this. return std::list(); } - void ClearNewActorHandles() { /* TODO: implement this. */ + void ClearNewActorHandles() { /* TODO(zhijunfu): implement this. */ } private: @@ -89,7 +90,7 @@ class ActorHandle { /// submission. class CoreWorkerTaskInterface { public: - CoreWorkerTaskInterface(CoreWorker &core_worker) : core_worker_(core_worker) {} + CoreWorkerTaskInterface(CoreWorker &core_worker); /// Submit a normal task. /// @@ -137,11 +138,8 @@ class CoreWorkerTaskInterface { std::vector> BuildTaskArguments( const std::vector &args); - /// Translate from WorkLanguage to Language type (required by taks spec). - /// - /// \param[in] language Language for a task. - /// \return Translated task language. - ::Language ToTaskLanguage(WorkerLanguage language); + /// All the task submitters supported. 
+ std::unordered_map> task_submitters_; }; } // namespace ray diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc new file mode 100644 index 000000000000..14906acfe0bf --- /dev/null +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -0,0 +1,32 @@ + +#include "ray/core_worker/transport/raylet_transport.h" + +namespace ray { + +CoreWorkerRayletTaskSubmitter::CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client) + : raylet_client_(raylet_client) {} + +Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpec &task) { + return raylet_client_.SubmitTask(task.GetDependencies(), task.GetTaskSpecification()); +} + +CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(RayletClient &raylet_client) + : raylet_client_(raylet_client) {} + +Status CoreWorkerRayletTaskReceiver::GetTasks(std::vector *tasks) { + std::unique_ptr task_spec; + auto status = raylet_client_.GetTask(&task_spec); + if (!status.ok()) { + RAY_LOG(ERROR) << "Get task from raylet failed with error: " + << ray::Status::IOError(status.message()); + return status; + } + + std::vector dependencies; + RAY_CHECK((*tasks).empty()); + (*tasks).emplace_back(*task_spec, dependencies); + + return Status::OK(); +} + +} // namespace ray diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h new file mode 100644 index 000000000000..03bf82f29886 --- /dev/null +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -0,0 +1,44 @@ +#ifndef RAY_CORE_WORKER_RAYLET_TRANSPORT_H +#define RAY_CORE_WORKER_RAYLET_TRANSPORT_H + +#include + +#include "ray/core_worker/transport/transport.h" +#include "ray/raylet/raylet_client.h" + +namespace ray { + +/// In raylet task submitter and receiver, a task is submitted to raylet, and possibly +/// gets forwarded to another raylet on which node the task should be executed, and +/// then a worker on that node gets this task and starts executing it. + +class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter { + public: + CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client); + + /// Submit a task for execution to raylet. + /// + /// \param[in] task The task spec to submit. + /// \return Status. + virtual Status SubmitTask(const TaskSpec &task) override; + + private: + /// Raylet client. + RayletClient &raylet_client_; +}; + +class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver { + public: + CoreWorkerRayletTaskReceiver(RayletClient &raylet_client); + + // Get tasks for execution from raylet. + virtual Status GetTasks(std::vector *tasks) override; + + private: + /// Raylet client. + RayletClient &raylet_client_; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_RAYLET_TRANSPORT_H diff --git a/src/ray/core_worker/transport/transport.h b/src/ray/core_worker/transport/transport.h new file mode 100644 index 000000000000..44be74b989c7 --- /dev/null +++ b/src/ray/core_worker/transport/transport.h @@ -0,0 +1,41 @@ +#ifndef RAY_CORE_WORKER_TRANSPORT_H +#define RAY_CORE_WORKER_TRANSPORT_H + +#include + +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "ray/common/status.h" +#include "ray/core_worker/common.h" +#include "ray/raylet/task_spec.h" + +namespace ray { + +/// Interfaces for task submitter and receiver. They are separate classes but should be +/// used in pairs - one type of task submitter should be used together with task +/// with the same type, so these classes are put together in this same file. 
+/// +/// Task submitter/receiver should inherit from these classes and provide implementions +/// for the methods. The actual task submitter/receiver can submit/get tasks via raylet, +/// or directly to/from another worker. + +/// This class is responsible to submit tasks. +class CoreWorkerTaskSubmitter { + public: + /// Submit a task for execution. + /// + /// \param[in] task The task spec to submit. + /// \return Status. + virtual Status SubmitTask(const TaskSpec &task) = 0; +}; + +/// This class receives tasks for execution. +class CoreWorkerTaskReceiver { + public: + // Get tasks for execution. + virtual Status GetTasks(std::vector *tasks) = 0; +}; + +} // namespace ray + +#endif // RAY_CORE_WORKER_TRANSPORT_H diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh index 104b19ff19cb..7668b92ac272 100644 --- a/src/ray/test/run_core_worker_tests.sh +++ b/src/ray/test/run_core_worker_tests.sh @@ -43,6 +43,3 @@ sleep 1s bazel run //:redis-cli -- -p 6379 shutdown bazel run //:redis-cli -- -p 6380 shutdown sleep 1s - -# Include raylet integration test once it's ready. -# ./bazel-bin/object_manager_integration_test $STORE_EXEC From 1b86e551fb8b01e1c5959b019652e614d7b169db Mon Sep 17 00:00:00 2001 From: Tianhong Dai Date: Sat, 15 Jun 2019 08:22:36 +0800 Subject: [PATCH 096/118] Fix bugs in the a3c code template. (#4984) --- doc/source/example-a3c.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/example-a3c.rst b/doc/source/example-a3c.rst index 3037b2b6b132..f8a8bfb4c1f3 100644 --- a/doc/source/example-a3c.rst +++ b/doc/source/example-a3c.rst @@ -127,7 +127,7 @@ global model parameters. The main training script looks like the following. obs = 0 # Start simulations on actors - agents = [Runner(env_name, i) for i in range(num_workers)] + agents = [Runner.remote(env_name, i) for i in range(num_workers)] # Start gradient calculation tasks on each actor parameters = policy.get_weights() From 05e274807092ee263d0c98283ad6fcfbea45de29 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Sat, 15 Jun 2019 11:01:27 -0700 Subject: [PATCH 097/118] Inherit Function Docstrings and other metedata (#4985) --- python/ray/remote_function.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 44d2777a2900..9ff6994b8d42 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -4,6 +4,7 @@ import copy import logging +from functools import wraps from ray.function_manager import FunctionDescriptor import ray.signature @@ -74,15 +75,18 @@ def __init__(self, function, num_cpus, num_gpus, resources, self._last_driver_id_exported_for = None + # Override task.remote's signature and docstring + @wraps(function) + def _remote_proxy(*args, **kwargs): + return self._remote(args=args, kwargs=kwargs) + + self.remote = _remote_proxy + def __call__(self, *args, **kwargs): raise Exception("Remote functions cannot be called directly. 
Instead " "of running '{}()', try '{}.remote()'.".format( self._function_name, self._function_name)) - def remote(self, *args, **kwargs): - """This runs immediately when a remote function is called.""" - return self._remote(args=args, kwargs=kwargs) - def _submit(self, args=None, kwargs=None, From b08765a08b3060c9d45e14406814f3537bd812c2 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Mon, 17 Jun 2019 13:34:23 +0800 Subject: [PATCH 098/118] Fix a crash when unknown worker registering to raylet (#4992) --- src/ray/raylet/worker_pool.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 43698c53f0d8..d4ac4cf4ecce 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -164,7 +164,10 @@ void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { state.registered_workers.insert(std::move(worker)); auto it = state.starting_worker_processes.find(pid); - RAY_CHECK(it != state.starting_worker_processes.end()); + if (it == state.starting_worker_processes.end()) { + RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid; + return; + } it->second--; if (it->second == 0) { state.starting_worker_processes.erase(it); From 2bf92e02e22a04f9895d7701a41985f288906808 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Mon, 17 Jun 2019 19:00:50 +0800 Subject: [PATCH 099/118] [gRPC] Use gRPC for inter-node-manager communication (#4968) --- .bazelrc | 2 + .travis.yml | 2 +- BUILD.bazel | 26 ++ bazel/ray_deps_build_all.bzl | 4 + bazel/ray_deps_setup.bzl | 8 + ci/suppress_output | 2 +- ci/travis/install-bazel.sh | 2 +- java/BUILD.bazel | 7 +- java/test/pom.xml | 15 ++ .../test_perf_integration.py | 23 ++ src/ray/protobuf/node_manager.proto | 24 ++ src/ray/raylet/node_manager.cc | 237 +++++++----------- src/ray/raylet/node_manager.h | 43 ++-- src/ray/raylet/raylet.cc | 36 +-- src/ray/raylet/raylet.h | 6 - src/ray/raylet/task.cc | 10 +- src/ray/raylet/task.h | 3 + src/ray/rpc/client_call.h | 169 +++++++++++++ src/ray/rpc/grpc_server.cc | 70 ++++++ src/ray/rpc/grpc_server.h | 92 +++++++ src/ray/rpc/node_manager_client.h | 56 +++++ src/ray/rpc/node_manager_server.h | 71 ++++++ src/ray/rpc/server_call.h | 233 +++++++++++++++++ src/ray/rpc/util.h | 33 +++ 24 files changed, 957 insertions(+), 217 deletions(-) create mode 100644 src/ray/protobuf/node_manager.proto create mode 100644 src/ray/rpc/client_call.h create mode 100644 src/ray/rpc/grpc_server.cc create mode 100644 src/ray/rpc/grpc_server.h create mode 100644 src/ray/rpc/node_manager_client.h create mode 100644 src/ray/rpc/node_manager_server.h create mode 100644 src/ray/rpc/server_call.h create mode 100644 src/ray/rpc/util.h diff --git a/.bazelrc b/.bazelrc index 488b33101594..3e3c3b6c4fa4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,3 +2,5 @@ build --compilation_mode=opt build --action_env=PATH build --action_env=PYTHON_BIN_PATH +# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341 +build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD" diff --git a/.travis.yml b/.travis.yml index 1888fa4ce03f..9a4fb66d84a0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -126,7 +126,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-dependencies.sh # This command should be kept in sync with ray/python/README-building-wheels.md. 
- - ./python/build-wheel-macos.sh + - ./ci/suppress_output ./python/build-wheel-macos.sh script: - if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi diff --git a/BUILD.bazel b/BUILD.bazel index 47e795011ab4..da36eec0cf57 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,12 +1,37 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] +# Node manager gRPC lib. +grpc_proto_library( + name = "node_manager_grpc_lib", + srcs = ["src/ray/protobuf/node_manager.proto"], +) + +# Node manager server and client. +cc_library( + name = "node_manager_rpc_lib", + srcs = glob([ + "src/ray/rpc/*.cc", + ]), + hdrs = glob([ + "src/ray/rpc/*.h", + ]), + copts = COPTS, + deps = [ + ":node_manager_grpc_lib", + ":ray_common", + "@boost//:asio", + "@com_github_grpc_grpc//:grpc++", + ], +) + cc_binary( name = "raylet", srcs = ["src/ray/raylet/main.cc"], @@ -89,6 +114,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", + ":node_manager_rpc_lib", ":object_manager", ":ray_common", ":ray_util", diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 5598d5820e35..3e1e1838a59a 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories") load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + def ray_deps_build_all(): gen_java_deps() @@ -10,3 +12,5 @@ def ray_deps_build_all(): boost_deps() prometheus_cpp_repositories() python_configure(name = "local_config_python") + grpc_deps() + diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index b3cd21b9b3b1..e6dc21585699 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -101,3 +101,11 @@ def ray_deps_setup(): # `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged. urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"], ) + + http_archive( + name = "com_github_grpc_grpc", + urls = [ + "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + ], + strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + ) diff --git a/ci/suppress_output b/ci/suppress_output index 623559d11cbc..0f32b1a88b37 100755 --- a/ci/suppress_output +++ b/ci/suppress_output @@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1 CODE=$? 
if [ $CODE != 0 ]; then - cat $TMPFILE + tail -n 2000 $TMPFILE echo "FAILED $CODE" kill $WATCHDOG_PID exit $CODE diff --git a/ci/travis/install-bazel.sh b/ci/travis/install-bazel.sh index c9614f7722ef..5b6d9572952e 100755 --- a/ci/travis/install-bazel.sh +++ b/ci/travis/install-bazel.sh @@ -16,7 +16,7 @@ else exit 1 fi -URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh" +URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh" wget -O install.sh $URL chmod +x install.sh ./install.sh --user diff --git a/java/BUILD.bazel b/java/BUILD.bazel index f3ae6f063304..80ccabccfc12 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -94,11 +94,14 @@ define_java_module( ":org_ray_ray_api", ":org_ray_ray_runtime", "@plasma//:org_apache_arrow_arrow_plasma", + "@maven//:com_google_guava_guava", + "@maven//:com_sun_xml_bind_jaxb_core", + "@maven//:com_sun_xml_bind_jaxb_impl", + "@maven//:commons_io_commons_io", + "@maven//:javax_xml_bind_jaxb_api", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_slf4j_slf4j_api", "@maven//:org_testng_testng", - "@maven//:com_google_guava_guava", - "@maven//:commons_io_commons_io", ], ) diff --git a/java/test/pom.xml b/java/test/pom.xml index 10f7ea4b3313..6a3a31d2032e 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -32,11 +32,26 @@ guava 27.0.1-jre + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + commons-io commons-io 2.5 + + javax.xml.bind + jaxb-api + 2.3.0 + org.apache.commons commons-lang3 diff --git a/python/ray/tests/perf_integration_tests/test_perf_integration.py b/python/ray/tests/perf_integration_tests/test_perf_integration.py index 2ce2a305a0e8..ff34fe4125fa 100644 --- a/python/ray/tests/perf_integration_tests/test_perf_integration.py +++ b/python/ray/tests/perf_integration_tests/test_perf_integration.py @@ -6,6 +6,7 @@ import pytest import ray +from ray.tests.conftest import _ray_start_cluster num_tasks_submitted = [10**n for n in range(0, 6)] num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted] @@ -41,3 +42,25 @@ def test_task_submission(benchmark, num_tasks): warmup() benchmark(benchmark_task_submission, num_tasks) ray.shutdown() + + +def benchmark_task_forward(f, num_tasks): + ray.get([f.remote() for _ in range(num_tasks)]) + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "num_tasks", [10**3, 10**4], + ids=[str(num) + "_tasks" for num in [10**3, 10**4]]) +def test_task_forward(benchmark, num_tasks): + with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster: + cluster.add_node(resources={"my_resource": 100}) + ray.init(redis_address=cluster.redis_address) + + @ray.remote(resources={"my_resource": 0.001}) + def f(): + return 1 + + # Warm up + ray.get([f.remote() for _ in range(100)]) + benchmark(benchmark_task_forward, f, num_tasks) diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto new file mode 100644 index 000000000000..8a82da1c77fd --- /dev/null +++ b/src/ray/protobuf/node_manager.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package ray.rpc; + +message ForwardTaskRequest { + // The ID of the task to be forwarded. + bytes task_id = 1; + // The tasks in the uncommitted lineage of the forwarded task. This + // should include task_id. + // TODO(hchen): Currently, `uncommitted_tasks` are represented as + // flatbutters-serialized bytes. 
This is because the flatbuffers-defined Task data + // structure is being used in many places. We should move Task and all related data + // strucutres to protobuf. + repeated bytes uncommitted_tasks = 2; +} + +message ForwardTaskReply { +} + +// Service for inter-node-manager communication. +service NodeManagerService { + // Forward a task and its uncommitted lineage to the remote node manager. + rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 671a7a7982b5..a0bde1ff0655 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, lineage_cache_(gcs_client_->client_table().GetLocalClientId(), gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), - remote_clients_(), - remote_server_connections_(), - actor_registry_() { + actor_registry_(), + node_manager_server_(config.node_manager_port, io_service, *this), + client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, [this](const ObjectID &object_id) { HandleObjectMissing(object_id); })); RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); + // Run the node manger rpc server. + node_manager_server_.Run(); } ray::Status NodeManager::RegisterGcs() { @@ -366,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { return; } - // TODO(atumanov): make remote client lookup O(1) - if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) == - remote_clients_.end()) { - remote_clients_.push_back(client_id); - } else { - // NodeManager connection to this client was already established. - RAY_LOG(DEBUG) << "Received a new client connection that already exists: " + auto entry = remote_node_manager_clients_.find(client_id); + if (entry != remote_node_manager_clients_.end()) { + RAY_LOG(DEBUG) << "Received notification of a new client that already exists: " << client_id; return; } - // Establish a new NodeManager connection to this GCS client. - auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address, - client_data.node_manager_port); - if (!status.ok()) { - // This is not a fatal error for raylet, but it should not happen. - // We need to broadcase this message. - std::string type = "raylet_connection_error"; - std::ostringstream error_message; - error_message << "Failed to connect to ray node " << client_id - << " with status: " << status.ToString() - << ". This may be since the node was recently removed."; - // We use the nil DriverID to broadcast the message to all drivers. - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), type, error_message.str(), current_time_ms())); - return; - } + // Initialize a rpc client to the new node manager. 
+ std::unique_ptr client( + new rpc::NodeManagerClient(client_data.node_manager_address, + client_data.node_manager_port, client_call_manager_)); + remote_node_manager_clients_.emplace(client_id, std::move(client)); ResourceSet resources_total(client_data.resources_total_label, client_data.resources_total_capacity); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port) { - // Establish a new NodeManager connection to this GCS client. - RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at " - << client_address << ":" << client_port; - - boost::asio::ip::tcp::socket socket(io_service_); - RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port)); - - // The client is connected, now send a connect message to remote node manager. - auto server_conn = TcpServerConnection::Create(std::move(socket)); - - // Prepare client connection info buffer - flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_)); - fbb.Finish(message); - // Send synchronously. - // TODO(swang): Make this a WriteMessageAsync. - RAY_RETURN_NOT_OK(server_conn->WriteMessage( - static_cast(protocol::MessageType::ConnectClient), fbb.GetSize(), - fbb.GetBufferPointer())); - - remote_server_connections_.emplace(client_id, std::move(server_conn)); - return ray::Status::OK(); -} - void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. @@ -440,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // check that it is actually removed, or log a warning otherwise, but that may // not be necessary. - // Remove the client from the list of remote clients. - std::remove(remote_clients_.begin(), remote_clients_.end(), client_id); - // Remove the client from the resource map. cluster_resource_map_.erase(client_id); - // Remove the remote server connection. - const auto connection_entry = remote_server_connections_.find(client_id); - if (connection_entry != remote_server_connections_.end()) { - connection_entry->second->Close(); - remote_server_connections_.erase(connection_entry); + // Remove the node manager client. 
+ const auto client_entry = remote_node_manager_clients_.find(client_id); + if (client_entry != remote_node_manager_clients_.end()) { + remote_node_manager_clients_.erase(client_entry); } else { RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client " << client_id << "."; @@ -1241,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client node_manager_client.ProcessMessages(); } -void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, - const uint8_t *message_data) { - const auto message_type_value = static_cast(message_type); - RAY_LOG(DEBUG) << "[NodeManager] Message " - << protocol::EnumNameMessageType(message_type_value) << "(" - << message_type << ") from node manager"; - switch (message_type_value) { - case protocol::MessageType::ConnectClient: { - auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->client_id()); - node_manager_client.SetClientID(client_id); - } break; - case protocol::MessageType::ForwardTaskRequest: { - auto message = flatbuffers::GetRoot(message_data); - TaskID task_id = from_flatbuf(*message->task_id()); - - Lineage uncommitted_lineage(*message); - const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); - RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() - << " on node " << gcs_client_->client_table().GetLocalClientId() - << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); - SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); - } break; - case protocol::MessageType::DisconnectClient: { - // TODO(rkn): We need to do some cleanup here. - RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. " - << "We need to do some cleanup here."; - // Do not process any more messages from this node manager. - return; - } break; - default: - RAY_LOG(FATAL) << "Received unexpected message type " << message_type; - } - node_manager_client.ProcessMessages(); +void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) { + // Get the forwarded task and its uncommitted lineage from the request. + TaskID task_id = TaskID::FromBinary(request.task_id()); + Lineage uncommitted_lineage; + for (int i = 0; i < request.uncommitted_tasks_size(); i++) { + const std::string &task_message = request.uncommitted_tasks(i); + const Task task(*flatbuffers::GetRoot( + reinterpret_cast(task_message.data()))); + RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED)); + } + const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); + RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() + << " on node " << gcs_client_->client_table().GetLocalClientId() + << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); + SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); + done_callback(Status::OK()); } void NodeManager::ProcessSetResourceRequest( @@ -2253,6 +2192,16 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, void NodeManager::ForwardTask( const Task &task, const ClientID &node_id, const std::function &on_error) { + // Lookup node manager client for this node_id and use it to send the request. 
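Both halves of this exchange go through the protoc-generated ForwardTaskRequest class: HandleForwardTask above reads the fields, and ForwardTask below fills them in. A minimal sketch of those accessors, using the standard names protoc derives from the two fields in node_manager.proto (all values here are placeholders, not taken from the patch):

    #include "src/ray/protobuf/node_manager.pb.h"

    void ForwardTaskRequestSketch() {
      ray::rpc::ForwardTaskRequest request;
      // bytes task_id = 1; normally filled with TaskID::Binary().
      request.set_task_id("<binary task id>");
      // repeated bytes uncommitted_tasks = 2; each entry is a flatbuffers-serialized Task.
      request.add_uncommitted_tasks("<serialized task>");
      for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
        const std::string &task_bytes = request.uncommitted_tasks(i);
        (void)task_bytes;  // Parsed back into a Task on the receiving side.
      }
    }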
+ auto client_entry = remote_node_manager_clients_.find(node_id); + if (client_entry == remote_node_manager_clients_.end()) { + // TODO(atumanov): caller must handle failure to ensure tasks are not lost. + RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id; + on_error(ray::Status::IOError("Node manager client not found"), task); + return; + } + auto &client = client_entry->second; + const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -2272,68 +2221,61 @@ void NodeManager::ForwardTask( // Increment forward count for the forwarded task. lineage_cache_entry_task.IncrementNumForwards(); - flatbuffers::FlatBufferBuilder fbb; - auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id); - fbb.Finish(request); - RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from " << gcs_client_->client_table().GetLocalClientId() << " to " << node_id << " spillback=" << lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards(); - // Lookup remote server connection for this node_id and use it to send the request. - auto it = remote_server_connections_.find(node_id); - if (it == remote_server_connections_.end()) { - // TODO(atumanov): caller must handle failure to ensure tasks are not lost. - RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - on_error(ray::Status::IOError("NodeManager connection not found"), task); - return; + // Prepare the request message. + rpc::ForwardTaskRequest request; + request.set_task_id(task_id.Binary()); + for (auto &entry : uncommitted_lineage.GetEntries()) { + request.add_uncommitted_tasks(entry.second.TaskData().Serialize()); } - auto &server_conn = it->second; + // Move the FORWARDING task to the SWAP queue so that we remember that we // have it queued locally. Once the ForwardTaskRequest has been sent, the // task will get re-queued, depending on whether the message succeeded or // not. local_queues_.QueueTasks({task}, TaskState::SWAP); - server_conn->WriteMessageAsync( - static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) { - // Remove the FORWARDING task from the SWAP queue. - TaskState state; - const auto task = local_queues_.RemoveTask(task_id, &state); - RAY_CHECK(state == TaskState::SWAP); - - if (status.ok()) { - const auto &spec = task.GetTaskSpecification(); - // Mark as forwarded so that the task and its lineage are not - // re-forwarded in the future to the receiving node. - lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - - // Notify the task dependency manager that we are no longer responsible - // for executing this task. - task_dependency_manager_.TaskCanceled(task_id); - // Preemptively push any local arguments to the receiving node. For now, we - // only do this with actor tasks, since actor tasks must be executed by a - // specific process and therefore have affinity to the receiving node. - if (spec.IsActorTask()) { - // Iterate through the object's arguments. NOTE(swang): We do not include - // the execution dependencies here since those cannot be transferred - // between nodes. - for (int i = 0; i < spec.NumArgs(); ++i) { - int count = spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - ObjectID argument_id = spec.ArgId(i, j); - // If the argument is local, then push it to the receiving node. 
- if (task_dependency_manager_.CheckObjectLocal(argument_id)) { - object_manager_.Push(argument_id, node_id); - } - } + client->ForwardTask(request, [this, on_error, task_id, node_id]( + Status status, const rpc::ForwardTaskReply &reply) { + // Remove the FORWARDING task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); + + if (status.ok()) { + const auto &spec = task.GetTaskSpecification(); + // Mark as forwarded so that the task and its lineage are not + // re-forwarded in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); + + // Notify the task dependency manager that we are no longer responsible + // for executing this task. + task_dependency_manager_.TaskCanceled(task_id); + // Preemptively push any local arguments to the receiving node. For now, we + // only do this with actor tasks, since actor tasks must be executed by a + // specific process and therefore have affinity to the receiving node. + if (spec.IsActorTask()) { + // Iterate through the object's arguments. NOTE(swang): We do not include + // the execution dependencies here since those cannot be transferred + // between nodes. + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + for (int j = 0; j < count; j++) { + ObjectID argument_id = spec.ArgId(i, j); + // If the argument is local, then push it to the receiving node. + if (task_dependency_manager_.CheckObjectLocal(argument_id)) { + object_manager_.Push(argument_id, node_id); } } - } else { - on_error(status, task); } - }); + } + } else { + on_error(status, task); + } + }); } void NodeManager::DumpDebugState() const { @@ -2368,10 +2310,11 @@ std::string NodeManager::DebugString() const { result << "\n- num dead actors: " << statistical_data.dead_actors; result << "\n- max num handles: " << statistical_data.max_num_handles; - result << "\nRemoteConnections:"; - for (auto &pair : remote_server_connections_) { - result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString(); + result << "\nRemote node manager clients: "; + for (const auto &entry : remote_node_manager_clients_) { + result << "\n" << entry.first; } + result << "\nDebugString() time ms: " << (current_time_ms() - now_ms); return result.str(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 3f7e4d7da97c..61613358330c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -4,6 +4,9 @@ #include // clang-format off +#include "ray/rpc/client_call.h" +#include "ray/rpc/node_manager_server.h" +#include "ray/rpc/node_manager_client.h" #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" @@ -52,7 +55,7 @@ struct NodeManagerConfig { std::string session_dir; }; -class NodeManager { +class NodeManager : public rpc::NodeManagerServiceHandler { public: /// Create a node manager. /// @@ -86,15 +89,6 @@ class NodeManager { /// \return Void. void ProcessNewNodeManager(TcpClientConnection &node_manager_client); - /// Handle a message from a remote node manager. - /// - /// \param node_manager_client The connection to the remote node manager. - /// \param message_type The type of the message. - /// \param message The message contents. - /// \return Void. - void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, const uint8_t *message); - /// Subscribe to the relevant GCS tables and set up handlers. 
/// /// \return Status indicating whether this was done successfully or not. @@ -108,6 +102,9 @@ class NodeManager { /// Record metrics. void RecordMetrics() const; + /// Get the port of the node manager rpc server. + int GetServerPort() const { return node_manager_server_.GetPort(); } + private: /// Methods for handling clients. @@ -450,15 +447,10 @@ class NodeManager { void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); - /// connect to a remote node manager. - /// - /// \param client_id The client ID for the remote node manager. - /// \param client_address The IP address for the remote node manager. - /// \param client_port The listening port for the remote node manager. - /// \return True if the connect succeeds. - ray::Status ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port); + /// Handle a `ForwardTask` request. + void HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) override; // GCS client ID for this node. ClientID client_id_; @@ -505,9 +497,6 @@ class NodeManager { TaskDependencyManager task_dependency_manager_; /// The lineage cache for the GCS object and task tables. LineageCache lineage_cache_; - std::vector remote_clients_; - std::unordered_map> - remote_server_connections_; /// A mapping from actor ID to registration information about that actor /// (including which node manager owns it). std::unordered_map actor_registry_; @@ -515,6 +504,16 @@ class NodeManager { /// This map stores actor ID to the ID of the checkpoint that will be used to /// restore the actor. std::unordered_map checkpoint_id_to_restore_; + + /// The RPC server. + rpc::NodeManagerServer node_manager_server_; + + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. + rpc::ClientCallManager client_call_manager_; + + /// Map from node ids to clients of the remote node managers. + std::unordered_map> + remote_node_manager_clients_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 80630d372a61..473e6c263ffe 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), object_manager_config.object_manager_port)), - object_manager_socket_(main_service), - node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint( - boost::asio::ip::tcp::v4(), - node_manager_config.node_manager_port)), - node_manager_socket_(main_service) { + object_manager_socket_(main_service) { // Start listening for clients. DoAccept(); DoAcceptObjectManager(); - DoAcceptNodeManager(); RAY_CHECK_OK(RegisterGcs( node_ip_address, socket_name_, object_manager_config.store_socket_name, @@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, client_info.raylet_socket_name = raylet_socket_name; client_info.object_store_socket_name = object_store_socket_name; client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port(); + client_info.node_manager_port = node_manager_.GetServerPort(); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { client_info.resources_total_label.push_back(resource_pair.first); @@ -120,33 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, return Status::OK(); } -void Raylet::DoAcceptNodeManager() { - node_manager_acceptor_.async_accept(node_manager_socket_, - boost::bind(&Raylet::HandleAcceptNodeManager, this, - boost::asio::placeholders::error)); -} - -void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { - if (!error) { - ClientHandler client_handler = - [this](TcpClientConnection &client) { - node_manager_.ProcessNewNodeManager(client); - }; - MessageHandler message_handler = - [this](std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessNodeManagerMessage(*client, message_type, message); - }; - // Accept a new TCP client and dispatch it to the node manager. - auto new_connection = TcpClientConnection::Create( - client_handler, message_handler, std::move(node_manager_socket_), "node manager", - node_manager_message_enum, - static_cast(protocol::MessageType::DisconnectClient)); - } - // We're ready to accept another client. - DoAcceptNodeManager(); -} - void Raylet::DoAcceptObjectManager() { object_manager_acceptor_.async_accept( object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this, diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 84274ea6ecfe..26fe74b2b622 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -63,8 +63,6 @@ class Raylet { void DoAcceptObjectManager(); /// Handle an accepted tcp client connection. void HandleAcceptObjectManager(const boost::system::error_code &error); - void DoAcceptNodeManager(); - void HandleAcceptNodeManager(const boost::system::error_code &error); friend class TestObjectManagerIntegration; @@ -88,10 +86,6 @@ class Raylet { boost::asio::ip::tcp::acceptor object_manager_acceptor_; /// The socket to listen on for new object manager tcp clients. boost::asio::ip::tcp::socket object_manager_socket_; - /// An acceptor for new tcp clients. - boost::asio::ip::tcp::acceptor node_manager_acceptor_; - /// The socket to listen on for new tcp clients. - boost::asio::ip::tcp::socket node_manager_socket_; }; } // namespace raylet diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc index 5d6a02186ced..9d8036411303 100644 --- a/src/ray/raylet/task.cc +++ b/src/ray/raylet/task.cc @@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) { ComputeDependencies(); } +const std::string Task::Serialize() const { + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(ToFlatbuffer(fbb)); + return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); +} + std::string SerializeTaskAsString(const std::vector *dependencies, const TaskSpecification *task_spec) { - flatbuffers::FlatBufferBuilder fbb; std::vector execution_dependencies(*dependencies); TaskExecutionSpecification execution_spec(std::move(execution_dependencies)); Task task(execution_spec, *task_spec); - fbb.Finish(task.ToFlatbuffer(fbb)); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); + return task.Serialize(); } } // namespace raylet diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index b942e2bf2c03..10cdfe5110f4 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -84,6 +84,9 @@ class Task { /// \param task Task structure with updated dynamic information. 
void CopyTaskExecutionSpec(const Task &task); + /// Serialize this task as a string. + const std::string Serialize() const; + private: void ComputeDependencies(); diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h new file mode 100644 index 000000000000..725652cb5ebc --- /dev/null +++ b/src/ray/rpc/client_call.h @@ -0,0 +1,169 @@ +#ifndef RAY_RPC_CLIENT_CALL_H +#define RAY_RPC_CLIENT_CALL_H + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/util.h" + +namespace ray { +namespace rpc { + +/// Represents an outgoing gRPC request. +/// +/// The lifecycle of a `ClientCall` is as follows. +/// +/// When a client submits a new gRPC request, a new `ClientCall` object will be created +/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of +/// `CompletionQueue`. +/// +/// When the reply is received, `ClientCallMangager` will get the address of this object +/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then +/// delete this object. +/// +/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use +/// template. This allows the users (e.g., `ClientCallMangager`) not having to use +/// template as well. +class ClientCall { + public: + /// The callback to be called by `ClientCallManager` when the reply of this request is + /// received. + virtual void OnReplyReceived() = 0; +}; + +class ClientCallManager; + +/// Reprents the client callback function of a particular rpc method. +/// +/// \tparam Reply Type of the reply message. +template +using ClientCallback = std::function; + +/// Implementaion of the `ClientCall`. It represents a `ClientCall` for a particular +/// RPC method. +/// +/// \tparam Reply Type of the Reply message. +template +class ClientCallImpl : public ClientCall { + public: + void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); } + + private: + /// Constructor. + /// + /// \param[in] callback The callback function to handle the reply. + ClientCallImpl(const ClientCallback &callback) : callback_(callback) {} + + /// The reply message. + Reply reply_; + + /// The callback function to handle the reply. + ClientCallback callback_; + + /// The response reader. + std::unique_ptr> response_reader_; + + /// gRPC status of this request. + grpc::Status status_; + + /// Context for the client. It could be used to convey extra information to + /// the server and/or tweak certain RPC behaviors. + grpc::ClientContext context_; + + friend class ClientCallManager; +}; + +/// Peprents the generic signature of a `FooService::Stub::PrepareAsyncBar` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using PrepareAsyncFunction = std::unique_ptr> ( + GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request, + grpc::CompletionQueue *cq); + +/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of +/// `ClientCall` objects. +/// +/// It maintains a thread that keeps polling events from `CompletionQueue`, and post +/// the callback function to the main event loop when a reply is received. +/// +/// Mutiple clients can share one `ClientCallManager`. +class ClientCallManager { + public: + /// Constructor. 
+ /// + /// \param[in] main_service The main event loop, to which the callback functions will be + /// posted. + ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) { + // Start the polling thread. + std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this); + polling_thread.detach(); + } + + ~ClientCallManager() { cq_.Shutdown(); } + + /// Create a new `ClientCall` and send request. + /// + /// \param[in] stub The gRPC-generated stub. + /// \param[in] prepare_async_function Pointer to the gRPC-generated + /// `FooService::Stub::PrepareAsyncBar` function. + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// + /// \tparam GrpcService Type of the gRPC-generated service class. + /// \tparam Request Type of the request message. + /// \tparam Reply Type of the reply message. + template + ClientCall *CreateCall( + typename GrpcService::Stub &stub, + const PrepareAsyncFunction prepare_async_function, + const Request &request, const ClientCallback &callback) { + // Create a new `ClientCall` object. This object will eventuall be deleted in the + // `ClientCallManager::PollEventsFromCompletionQueue` when reply is received. + auto call = new ClientCallImpl(callback); + // Send request. + call->response_reader_ = + (stub.*prepare_async_function)(&call->context_, request, &cq_); + call->response_reader_->StartCall(); + call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call); + return call; + } + + private: + /// This function runs in a background thread. It keeps polling events from the + /// `CompletionQueue`, and dispaches the event to the callbacks via the `ClientCall` + /// objects. + void PollEventsFromCompletionQueue() { + void *got_tag; + bool ok = false; + // Keep reading events from the `CompletionQueue` until it's shutdown. + while (cq_.Next(&got_tag, &ok)) { + ClientCall *call = reinterpret_cast(got_tag); + if (ok) { + // Post the callback to the main event loop. + main_service_.post([call]() { + call->OnReplyReceived(); + // The call is finished, we can delete the `ClientCall` object now. + delete call; + }); + } else { + delete call; + } + } + } + + /// The main event loop, to which the callback functions will be posted. + boost::asio::io_service &main_service_; + + /// The gRPC `CompletionQueue` object used to poll events. + grpc::CompletionQueue cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc new file mode 100644 index 000000000000..feb788da7692 --- /dev/null +++ b/src/ray/rpc/grpc_server.cc @@ -0,0 +1,70 @@ +#include "ray/rpc/grpc_server.h" + +namespace ray { +namespace rpc { + +void GrpcServer::Run() { + std::string server_address("0.0.0.0:" + std::to_string(port_)); + + grpc::ServerBuilder builder; + // TODO(hchen): Add options for authentication. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + // Allow subclasses to register concrete services. + RegisterServices(builder); + // Get hold of the completion queue used for the asynchronous communication + // with the gRPC runtime. + cq_ = builder.AddCompletionQueue(); + // Build and start server. + server_ = builder.BuildAndStart(); + RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; + + // Allow subclasses to initialize the server call factories. 
+ InitServerCallFactories(&server_call_factories_and_concurrencies_); + for (auto &entry : server_call_factories_and_concurrencies_) { + for (int i = 0; i < entry.second; i++) { + // Create and request calls from the factory. + entry.first->CreateCall(); + } + } + // Start a thread that polls incoming requests. + std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this); + polling_thread.detach(); +} + +void GrpcServer::PollEventsFromCompletionQueue() { + void *tag; + bool ok; + // Keep reading events from the `CompletionQueue` until it's shutdown. + while (cq_->Next(&tag, &ok)) { + ServerCall *server_call = static_cast(tag); + // `ok == false` indicates that the server has been shut down. + // We should delete the call object in this case. + bool delete_call = !ok; + if (ok) { + switch (server_call->GetState()) { + case ServerCallState::PENDING: + // We've received a new incoming request. Now this call object is used to + // track this request. So we need to create another call to handle next + // incoming request. + server_call->GetFactory().CreateCall(); + server_call->SetState(ServerCallState::PROCESSING); + main_service_.post([server_call] { server_call->HandleRequest(); }); + break; + case ServerCallState::SENDING_REPLY: + // The reply has been sent, this call can be deleted now. + // This event is triggered by `ServerCallImpl::SendReply`. + delete_call = true; + break; + default: + RAY_LOG(FATAL) << "Shouldn't reach here."; + break; + } + } + if (delete_call) { + delete server_call; + } + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h new file mode 100644 index 000000000000..4953f470610f --- /dev/null +++ b/src/ray/rpc/grpc_server.h @@ -0,0 +1,92 @@ +#ifndef RAY_RPC_GRPC_SERVER_H +#define RAY_RPC_GRPC_SERVER_H + +#include + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/server_call.h" + +namespace ray { +namespace rpc { + +/// Base class that represents an abstract gRPC server. +/// +/// A `GrpcServer` listens on a specific port. It owns +/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, +/// 2) and a thread that polls events from the `ServerCompletionQueue`. +/// +/// Subclasses can register one or multiple services to a `GrpcServer`, see +/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide +/// which kinds of requests this server should accept. +class GrpcServer { + public: + /// Constructor. + /// + /// \param[in] name Name of this server, used for logging and debugging purpose. + /// \param[in] port The port to bind this server to. If it's 0, a random available port + /// will be chosen. + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcServer(const std::string &name, const uint32_t port, + boost::asio::io_service &main_service) + : name_(name), port_(port), main_service_(main_service) {} + + /// Destruct this gRPC server. + ~GrpcServer() { + server_->Shutdown(); + cq_->Shutdown(); + } + + /// Initialize and run this server. + void Run(); + + /// Get the port of this gRPC server. + int GetPort() const { return port_; } + + protected: + /// Subclasses should implement this method and register one or multiple gRPC services + /// to the given `ServerBuilder`. + /// + /// \param[in] builder The `ServerBuilder` instance to register services to. 
+ virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. + virtual void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) = 0; + + /// This function runs in a background thread. It keeps polling events from the + /// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances + /// via the `ServerCall` objects. + void PollEventsFromCompletionQueue(); + + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; + /// Name of this server, used for logging and debugging purpose. + const std::string name_; + /// Port of this server. + int port_; + /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that + /// gRPC server can accept. + std::vector, int>> + server_call_factories_and_concurrencies_; + /// The `ServerCompletionQueue` object used for polling events. + std::unique_ptr cq_; + /// The `Server` object. + std::unique_ptr server_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/node_manager_client.h b/src/ray/rpc/node_manager_client.h new file mode 100644 index 000000000000..005c75db40d2 --- /dev/null +++ b/src/ray/rpc/node_manager_client.h @@ -0,0 +1,56 @@ +#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H +#define RAY_RPC_NODE_MANAGER_CLIENT_H + +#include + +#include + +#include "ray/common/status.h" +#include "ray/rpc/client_call.h" +#include "ray/util/logging.h" +#include "src/ray/protobuf/node_manager.grpc.pb.h" +#include "src/ray/protobuf/node_manager.pb.h" + +namespace ray { +namespace rpc { + +/// Client used for communicating with a remote node manager server. +class NodeManagerClient { + public: + /// Constructor. + /// + /// \param[in] address Address of the node manager server. + /// \param[in] port Port of the node manager server. + /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. + NodeManagerClient(const std::string &address, const int port, + ClientCallManager &client_call_manager) + : client_call_manager_(client_call_manager) { + std::shared_ptr channel = grpc::CreateChannel( + address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); + stub_ = NodeManagerService::NewStub(channel); + }; + + /// Forward a task and its uncommitted lineage. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + void ForwardTask(const ForwardTaskRequest &request, + const ClientCallback &callback) { + client_call_manager_ + .CreateCall( + *stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request, + callback); + } + + private: + /// The gRPC-generated stub. + std::unique_ptr stub_; + + /// The `ClientCallManager` used for managing requests. 
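Putting the client-side pieces together: a minimal, hypothetical driver for the NodeManagerClient defined in this header, assuming a node manager server is already listening at the given address and port (both placeholders) and that the caller owns the boost::asio event loop. This is a sketch of how ClientCallManager, the client, and the reply callback compose, not code from this patch:

    #include <boost/asio.hpp>

    #include "ray/rpc/client_call.h"
    #include "ray/rpc/node_manager_client.h"
    #include "ray/util/logging.h"

    int main() {
      boost::asio::io_service main_service;
      // Keep the event loop alive until shutdown; a real program would manage this.
      boost::asio::io_service::work work(main_service);

      // One ClientCallManager (and its polling thread) can be shared by many clients.
      ray::rpc::ClientCallManager call_manager(main_service);
      ray::rpc::NodeManagerClient client("127.0.0.1", 40000, call_manager);

      ray::rpc::ForwardTaskRequest request;
      request.set_task_id("<binary task id>");  // Placeholder; normally TaskID::Binary().

      client.ForwardTask(request, [](ray::Status status,
                                     const ray::rpc::ForwardTaskReply &reply) {
        // Invoked on the main_service thread once the reply (or an error) arrives.
        RAY_LOG(INFO) << "ForwardTask finished: " << status.ToString();
      });

      // Reply callbacks are posted to main_service, so the loop must be running.
      main_service.run();
      return 0;
    }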
+ ClientCallManager &client_call_manager_; +}; + +} // namespace rpc +} // namespace ray + +#endif // RAY_RPC_NODE_MANAGER_CLIENT_H diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h new file mode 100644 index 000000000000..afaea299ea89 --- /dev/null +++ b/src/ray/rpc/node_manager_server.h @@ -0,0 +1,71 @@ +#ifndef RAY_RPC_NODE_MANAGER_SERVER_H +#define RAY_RPC_NODE_MANAGER_SERVER_H + +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/server_call.h" + +#include "src/ray/protobuf/node_manager.grpc.pb.h" +#include "src/ray/protobuf/node_manager.pb.h" + +namespace ray { +namespace rpc { + +/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`. +class NodeManagerServiceHandler { + public: + /// Handle a `ForwardTask` request. + /// The implementation can handle this request asynchronously. When hanling is done, the + /// `done_callback` should be called. + /// + /// \param[in] request The request message. + /// \param[out] reply The reply message. + /// \param[in] done_callback The callback to be called when the request is done. + virtual void HandleForwardTask(const ForwardTaskRequest &request, + ForwardTaskReply *reply, + RequestDoneCallback done_callback) = 0; +}; + +/// The `GrpcServer` for `NodeManagerService`. +class NodeManagerServer : public GrpcServer { + public: + /// Constructor. + /// + /// \param[in] port See super class. + /// \param[in] main_service See super class. + /// \param[in] handler The service handler that actually handle the requests. + NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, + NodeManagerServiceHandler &service_handler) + : GrpcServer("NodeManager", port, main_service), + service_handler_(service_handler){}; + + void RegisterServices(grpc::ServerBuilder &builder) override { + /// Register `NodeManagerService`. + builder.RegisterService(&service_); + } + + void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) override { + // Initialize the factory for `ForwardTask` requests. + std::unique_ptr forward_task_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestForwardTask, + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); + + // Set `ForwardTask`'s accept concurrency to 100. + server_call_factories_and_concurrencies->emplace_back( + std::move(forward_task_call_factory), 100); + } + + private: + /// The grpc async service object. + NodeManagerService::AsyncService service_; + /// The service handler that actually handle the requests. + NodeManagerServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h new file mode 100644 index 000000000000..e06278260ab6 --- /dev/null +++ b/src/ray/rpc/server_call.h @@ -0,0 +1,233 @@ +#ifndef RAY_RPC_SERVER_CALL_H +#define RAY_RPC_SERVER_CALL_H + +#include + +#include "ray/common/status.h" +#include "ray/rpc/util.h" + +namespace ray { +namespace rpc { + +/// Represents the callback function to be called when a `ServiceHandler` finishes +/// handling a request. +using RequestDoneCallback = std::function; + +/// Represents state of a `ServerCall`. +enum class ServerCallState { + /// The call is created and waiting for an incoming request. + PENDING, + /// Request is received and being processed. + PROCESSING, + /// Request processing is done, and reply is being sent to client. 
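The mirror image on the server side: a hypothetical minimal handler and driver for the NodeManagerServer defined above in node_manager_server.h, assuming port 0 (which, per the GrpcServer constructor, lets the server pick a free port) and a handler that simply acknowledges every forwarded task. Again a sketch under those assumptions, not code from this patch:

    #include <boost/asio.hpp>

    #include "ray/rpc/node_manager_server.h"
    #include "ray/util/logging.h"

    class NoopHandler : public ray::rpc::NodeManagerServiceHandler {
     public:
      void HandleForwardTask(const ray::rpc::ForwardTaskRequest &request,
                             ray::rpc::ForwardTaskReply *reply,
                             ray::rpc::RequestDoneCallback done_callback) override {
        // Handling may be deferred; the reply is only sent back to the client once
        // done_callback runs, which is what allows asynchronous handlers.
        done_callback(ray::Status::OK());
      }
    };

    int main() {
      boost::asio::io_service main_service;
      boost::asio::io_service::work work(main_service);  // Keep the loop alive.

      NoopHandler handler;
      ray::rpc::NodeManagerServer server(0, main_service, handler);
      server.Run();
      RAY_LOG(INFO) << "Node manager RPC server listening on port " << server.GetPort();

      // Handler invocations are posted to main_service, so it must keep running.
      main_service.run();
      return 0;
    }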
+ SENDING_REPLY +}; + +class ServerCallFactory; + +/// Reprensents an incoming request of a gRPC server. +/// +/// The lifecycle and state transition of a `ServerCall` is as follows: +/// +/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED] +/// +/// (1) The `GrpcServer` creates a `ServerCall` and use it as the tag to accept requests +/// gRPC `CompletionQueue`. Now the state is `PENDING`. +/// (2) When a request is received, an event will be gotten from the `CompletionQueue`. +/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call +/// `ServerCall::HandleRequest`. +/// (3) When the `ServiceHandler` finishes handling the request, `ServerCallImpl::Finish` +/// will be called, and the state becomes `SENDING_REPLY`. +/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`. +/// `GrpcServer` will then delete this call. +/// +/// NOTE(hchen): Compared to `ServerCallImpl`, this abstract interface doesn't use +/// template. This allows the users (e.g., `GrpcServer`) not having to use +/// template as well. +class ServerCall { + public: + /// Get the state of this `ServerCall`. + virtual ServerCallState GetState() const = 0; + + /// Set state of this `ServerCall`. + virtual void SetState(const ServerCallState &new_state) = 0; + + /// Handle the requst. This is the callback function to be called by + /// `GrpcServer` when the request is received. + virtual void HandleRequest() = 0; + + /// Get the factory that created this `ServerCall`. + virtual const ServerCallFactory &GetFactory() const = 0; +}; + +/// The factory that creates a particular kind of `ServerCall` objects. +class ServerCallFactory { + public: + /// Create a new `ServerCall` and request gRPC runtime to start accepting the + /// corresonding type of requests. + /// + /// \return Pointer to the `ServerCall` object. + virtual ServerCall *CreateCall() const = 0; +}; + +/// Represents the generic signature of a `FooServiceHandler::HandleBar()` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *, + RequestDoneCallback); + +/// Implementation of `ServerCall`. It represents `ServerCall` for a particular +/// RPC method. +/// +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +class ServerCallImpl : public ServerCall { + public: + /// Constructor. + /// + /// \param[in] factory The factory which created this call. + /// \param[in] service_handler The service handler that handles the request. + /// \param[in] handle_request_function Pointer to the service handler function. 
+ ServerCallImpl( + const ServerCallFactory &factory, ServiceHandler &service_handler, + HandleRequestFunction handle_request_function) + : state_(ServerCallState::PENDING), + factory_(factory), + service_handler_(service_handler), + handle_request_function_(handle_request_function), + response_writer_(&context_) {} + + ServerCallState GetState() const override { return state_; } + + void SetState(const ServerCallState &new_state) override { state_ = new_state; } + + void HandleRequest() override { + state_ = ServerCallState::PROCESSING; + (service_handler_.*handle_request_function_)(request_, &reply_, + [this](Status status) { + // When the handler is done with the + // request, tell gRPC to finish this + // request. + SendReply(status); + }); + } + + const ServerCallFactory &GetFactory() const override { return factory_; } + + private: + /// Tell gRPC to finish this request. + void SendReply(Status status) { + state_ = ServerCallState::SENDING_REPLY; + response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this); + } + + /// State of this call. + ServerCallState state_; + + /// The factory which created this call. + const ServerCallFactory &factory_; + + /// The service handler that handles the request. + ServiceHandler &service_handler_; + + /// Pointer to the service handler function. + HandleRequestFunction handle_request_function_; + + /// Context for the request, allowing to tweak aspects of it such as the use + /// of compression, authentication, as well as to send metadata back to the client. + grpc::ServerContext context_; + + /// The reponse writer. + grpc::ServerAsyncResponseWriter response_writer_; + + /// The request message. + Request request_; + + /// The reply message. + Reply reply_; + + template + friend class ServerCallFactoryImpl; +}; + +/// Represents the generic signature of a `FooService::AsyncService::RequestBar()` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using RequestCallFunction = void (GrpcService::AsyncService::*)( + grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter *, + grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); + +/// Implementation of `ServerCallFactory` +/// +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +class ServerCallFactoryImpl : public ServerCallFactory { + using AsyncService = typename GrpcService::AsyncService; + + public: + /// Constructor. + /// + /// \param[in] service The gRPC-generated `AsyncService`. + /// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod` + // function. + /// \param[in] service_handler The service handler that handles the request. + /// \param[in] handle_request_function Pointer to the service handler function. + /// \param[in] cq The `CompletionQueue`. 
+ ServerCallFactoryImpl( + AsyncService &service, + RequestCallFunction request_call_function, + ServiceHandler &service_handler, + HandleRequestFunction handle_request_function, + const std::unique_ptr &cq) + : service_(service), + request_call_function_(request_call_function), + service_handler_(service_handler), + handle_request_function_(handle_request_function), + cq_(cq) {} + + ServerCall *CreateCall() const override { + // Create a new `ServerCall`. This object will eventually be deleted by + // `GrpcServer::PollEventsFromCompletionQueue`. + auto call = new ServerCallImpl( + *this, service_handler_, handle_request_function_); + /// Request gRPC runtime to starting accepting this kind of request, using the call as + /// the tag. + (service_.*request_call_function_)(&call->context_, &call->request_, + &call->response_writer_, cq_.get(), cq_.get(), + call); + return call; + } + + private: + /// The gRPC-generated `AsyncService`. + AsyncService &service_; + + /// Pointer to the `AsyncService::RequestMethod` function. + RequestCallFunction request_call_function_; + + /// The service handler that handles the request. + ServiceHandler &service_handler_; + + /// Pointer to the service handler function. + HandleRequestFunction handle_request_function_; + + /// The `CompletionQueue`. + const std::unique_ptr &cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h new file mode 100644 index 000000000000..6ecc6c3c4a34 --- /dev/null +++ b/src/ray/rpc/util.h @@ -0,0 +1,33 @@ +#ifndef RAY_RPC_UTIL_H +#define RAY_RPC_UTIL_H + +#include + +#include "ray/common/status.h" + +namespace ray { +namespace rpc { + +/// Helper function that converts a ray status to gRPC status. +inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { + if (ray_status.ok()) { + return grpc::Status::OK; + } else { + // TODO(hchen): Use more specific error code. + return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message()); + } +} + +/// Helper function that converts a gRPC status to ray status. 
+inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { + if (grpc_status.ok()) { + return Status::OK(); + } else { + return Status::IOError(grpc_status.error_message()); + } +} + +} // namespace rpc +} // namespace ray + +#endif From 7bda5edc16d40880b16ac04a5421dbfe79f4ccb2 Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Wed, 19 Jun 2019 11:36:21 +0800 Subject: [PATCH 100/118] Fix Java CI failure (#4995) --- .../src/main/java/org/ray/api/id/BaseId.java | 2 +- .../src/main/java/org/ray/api/TestUtils.java | 15 +++++++++++++++ .../org/ray/api/test/DynamicResourceTest.java | 17 +++++++++++++---- .../main/java/org/ray/api/test/WaitTest.java | 5 +++++ 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index e08955d5a93e..c13f0436f94d 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -48,7 +48,7 @@ public boolean isNil() { break; } } - isNilCache = localIsNil; + isNilCache = localIsNil; } return isNilCache; } diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 9b3bbf233856..3636c93e4909 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,8 +1,10 @@ package org.ray.api; import java.util.function.Supplier; +import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; +import org.testng.Assert; import org.testng.SkipException; public class TestUtils { @@ -42,4 +44,17 @@ public static boolean waitForCondition(Supplier condition, int timeoutM } return false; } + + @RayRemote + private static String hi() { + return "hi"; + } + + /** + * Warm up the cluster. + */ + public static void warmUpCluster() { + RayObject obj = Ray.call(TestUtils::hi); + Assert.assertEquals(obj.get(), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index 79b3eba0ed13..71766c6cf2bf 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,6 +23,10 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); + + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + CallOptions op1 = new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); @@ -30,16 +34,21 @@ public void testSetResource() { Assert.assertEquals(result.getReady().size(), 0); Ray.setResource("A", 10.0); + boolean resourceReady = TestUtils.waitForCondition(() -> { + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + if (nodes.size() != 1) { + return false; + } + return (0 == Double.compare(10.0, nodes.get(0).resources.get("A"))); + }, 2000); - // Assert node info. - List nodes = Ray.getRuntimeContext().getAllNodeInfo(); - Assert.assertEquals(nodes.size(), 1); - Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); + Assert.assertTrue(resourceReady); // Assert ray call result. 
result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 1); Assert.assertEquals(Ray.get(obj.getId()), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index e82b99d364ba..bccc50a50bdf 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -5,6 +5,7 @@ import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; +import org.ray.api.TestUtils; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; import org.testng.Assert; @@ -28,6 +29,9 @@ private static String delayedHi() { } private static void testWait() { + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + RayObject obj1 = Ray.call(WaitTest::hi); RayObject obj2 = Ray.call(WaitTest::delayedHi); @@ -71,4 +75,5 @@ public void testWaitForEmpty() { Assert.assertTrue(true); } } + } From e59e8074dd139ce8c2b816d10885fa024a43f639 Mon Sep 17 00:00:00 2001 From: Andrew Berger Date: Thu, 20 Jun 2019 18:33:40 -0400 Subject: [PATCH 101/118] fix handling of non-integral timeout values in signal.receive (#5002) --- python/ray/experimental/signal.py | 14 +++++++++++-- python/ray/tests/test_signal.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index f2a0d81ca343..25ec072d3fc7 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import logging + from collections import defaultdict import ray @@ -13,6 +15,8 @@ # in node_manager.cc ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" +logger = logging.getLogger(__name__) + class Signal(object): """Base class for Ray signals.""" @@ -125,10 +129,16 @@ def receive(sources, timeout=None): for s in sources: task_id_to_sources[_get_task_id(s).hex()].append(s) + if timeout < 1e-3: + logger.warning("Timeout too small. Using 1ms minimum") + timeout = 1e-3 + + timeout_ms = int(1000 * timeout) + # Construct the redis query. query = "XREAD BLOCK " - # Multiply by 1000x since timeout is in sec and redis expects ms. - query += str(1000 * timeout) + # redis expects ms. 
+ query += str(timeout_ms) query += " STREAMS " query += " ".join([task_id for task_id in task_id_to_sources]) query += " " diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index fe2e74379245..176fbd45bcaa 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -353,3 +353,36 @@ def f(sources): assert len(result_list) == 1 result_list = ray.get(f.remote([a])) assert len(result_list) == 1 + + +def test_non_integral_receive_timeout(ray_start_regular): + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=0.1) + + assert len(result_list) == 1 + + +def test_small_receive_timeout(ray_start_regular): + """ Test that receive handles timeout smaller than the 1ms min + """ + # 0.1 ms + small_timeout = 1e-4 + + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=small_timeout) + + assert len(result_list) == 1 From 1d17125333aae0733f909c0fab97bbc34012380f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 20 Jun 2019 18:07:44 -0700 Subject: [PATCH 102/118] temp fix for build (#5006) --- python/ray/rllib/tests/test_optimizers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index a87a295ccf1d..d27270c20965 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -125,14 +125,14 @@ def testSimple(self): def testMultiGPU(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) - optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True) + optimizer = AsyncSamplesOptimizer(workers, num_gpus=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiGPUParallelLoad(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True) + workers, num_gpus=1, num_data_loader_buffers=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiplePasses(self): @@ -211,21 +211,21 @@ def testRejectBadConfigs(self): num_data_loader_buffers=2, minibatch_buffer_size=4)) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=50, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=25, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=74, _fake_gpus=True) From 31b6da12f91a97182a5c2094656cfd37e1e7736e Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Fri, 21 Jun 2019 12:59:49 +0800 Subject: [PATCH 103/118] [tune] Tutorial UX Changes (#4990) * add integration, iris, ASHA, recursive changes, set reuse_actors=True, and enable Analysis as a return object * docstring * fix up example * fix * cleanup tests * experiment analysis --- .../ray/tune/analysis/experiment_analysis.py | 94 ++++++++++++++----- python/ray/tune/examples/track_example.py | 4 +- 
python/ray/tune/examples/tune_mnist_keras.py | 8 +- python/ray/tune/examples/utils.py | 36 +++---- python/ray/tune/experiment.py | 8 ++ python/ray/tune/integration/__init__.py | 0 python/ray/tune/integration/keras.py | 34 +++++++ python/ray/tune/schedulers/__init__.py | 6 +- python/ray/tune/schedulers/async_hyperband.py | 2 + .../tune/tests/test_experiment_analysis.py | 62 +++++++----- python/ray/tune/tests/test_trial_runner.py | 8 ++ python/ray/tune/trial.py | 25 +++-- python/ray/tune/tune.py | 11 ++- 13 files changed, 211 insertions(+), 87 deletions(-) create mode 100644 python/ray/tune/integration/__init__.py create mode 100644 python/ray/tune/integration/keras.py diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 0164ec2b1a2e..a3c246aba161 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -47,7 +47,14 @@ class ExperimentAnalysis(object): >>> experiment_path="~/tune_results/my_exp") """ - def __init__(self, experiment_path): + def __init__(self, experiment_path, trials=None): + """Initializer. + + Args: + experiment_path (str): Path to where experiment is located. + trials (list|None): List of trials that can be accessed via + `analysis.trials`. + """ experiment_path = os.path.expanduser(experiment_path) if not os.path.isdir(experiment_path): raise TuneError( @@ -55,7 +62,8 @@ def __init__(self, experiment_path): experiment_state_paths = glob.glob( os.path.join(experiment_path, "experiment_state*.json")) if not experiment_state_paths: - raise TuneError("No experiment state found!") + raise TuneError( + "No experiment state found in {}!".format(experiment_path)) experiment_filename = max( list(experiment_state_paths)) # if more than one, pick latest with open(os.path.join(experiment_path, experiment_filename)) as f: @@ -65,10 +73,27 @@ def __init__(self, experiment_path): raise TuneError("Experiment state invalid; no checkpoints found.") self._checkpoints = self._experiment_state["checkpoints"] self._scrubbed_checkpoints = unnest_checkpoints(self._checkpoints) + self.trials = trials + self._dataframe = None + + def get_all_trial_dataframes(self): + trial_dfs = {} + for checkpoint in self._checkpoints: + logdir = checkpoint["logdir"] + progress = max(glob.glob(os.path.join(logdir, "progress.csv"))) + trial_dfs[checkpoint["trial_id"]] = pd.read_csv(progress) + return trial_dfs + + def dataframe(self, refresh=False): + """Returns a pandas.DataFrame object constructed from the trials. - def dataframe(self): - """Returns a pandas.DataFrame object constructed from the trials.""" - return pd.DataFrame(self._scrubbed_checkpoints) + Args: + refresh (bool): Clears the cache which may have an existing copy. 
+ + """ + if self._dataframe is None or refresh: + self._dataframe = pd.DataFrame(self._scrubbed_checkpoints) + return self._dataframe def stats(self): """Returns a dictionary of the statistics of the experiment.""" @@ -87,22 +112,45 @@ def trial_dataframe(self, trial_id): return pd.read_csv(progress) raise ValueError("Trial id {} not found".format(trial_id)) - def get_best_trainable(self, metric, trainable_cls): - """Returns the best Trainable based on the experiment metric.""" - return trainable_cls(config=self.get_best_config(metric)) - - def get_best_config(self, metric): - """Retrieve the best config from the best trial.""" - return self._get_best_trial(metric)["config"] - - def _get_best_trial(self, metric): - """Retrieve the best trial based on the experiment metric.""" - return max( + def get_best_trainable(self, metric, trainable_cls, mode="max"): + """Returns the best Trainable based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + + """ + return trainable_cls(config=self.get_best_config(metric, mode=mode)) + + def get_best_config(self, metric, mode="max"): + """Retrieve the best config from the best trial. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + + """ + return self.get_best_info(metric, flatten=False, mode=mode)["config"] + + def get_best_logdir(self, metric, mode="max"): + df = self.dataframe() + if mode == "max": + return df.iloc[df[metric].idxmax()].logdir + elif mode == "min": + return df.iloc[df[metric].idxmin()].logdir + + def get_best_info(self, metric, mode="max", flatten=True): + """Retrieve the best trial based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + flatten (bool): Assumes trial info is flattened, where + nested entries are concatenated like `info:metric`. 
+ """ + optimize_op = max if mode == "max" else min + if flatten: + return optimize_op( + self._scrubbed_checkpoints, key=lambda d: d.get(metric, 0)) + return optimize_op( self._checkpoints, key=lambda d: d["last_result"].get(metric, 0)) - - def _get_sorted_trials(self, metric): - """Retrive trials in sorted order based on the experiment metric.""" - return sorted( - self._checkpoints, - key=lambda d: d["last_result"].get(metric, 0), - reverse=True) diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py index 1ccec39462d0..751f0ed44fa9 100644 --- a/python/ray/tune/examples/track_example.py +++ b/python/ray/tune/examples/track_example.py @@ -9,7 +9,7 @@ from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) from ray.tune import track -from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data +from ray.tune.examples.utils import TuneReporterCallback, get_mnist_data parser = argparse.ArgumentParser() parser.add_argument( @@ -63,7 +63,7 @@ def train_mnist(args): batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(track.metric)]) + callbacks=[TuneReporterCallback(track.metric)]) track.shutdown() diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index 5357d86af19e..ecd3c34bc042 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -9,8 +9,8 @@ from keras.models import Sequential from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) -from ray.tune.examples.utils import (TuneKerasCallback, get_mnist_data, - set_keras_threads) +from ray.tune.integration.keras import TuneReporterCallback +from ray.tune.examples.utils import get_mnist_data, set_keras_threads parser = argparse.ArgumentParser() parser.add_argument( @@ -52,7 +52,7 @@ def train_mnist(config, reporter): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(reporter)]) + callbacks=[TuneReporterCallback(reporter)]) if __name__ == "__main__": @@ -63,7 +63,7 @@ def train_mnist(config, reporter): ray.init() sched = AsyncHyperBandScheduler( - time_attr="timesteps_total", + time_attr="training_iteration", metric="mean_accuracy", mode="max", max_t=400, diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index a5ab1dbdb6a1..f40707a014fc 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -5,24 +5,9 @@ import keras from keras.datasets import mnist from keras import backend as K - - -class TuneKerasCallback(keras.callbacks.Callback): - def __init__(self, reporter, logs={}): - self.reporter = reporter - self.iteration = 0 - super(TuneKerasCallback, self).__init__() - - def on_train_end(self, epoch, logs={}): - self.reporter( - timesteps_total=self.iteration, - done=1, - mean_accuracy=logs.get("acc")) - - def on_batch_end(self, batch, logs={}): - self.iteration += 1 - self.reporter( - timesteps_total=self.iteration, mean_accuracy=logs["acc"]) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import OneHotEncoder def get_mnist_data(): @@ -53,6 +38,16 @@ def get_mnist_data(): return x_train, y_train, x_test, y_test, input_shape +def get_iris_data(test_size=0.2): + iris_data = load_iris() + x = iris_data.data + y = iris_data.target.reshape(-1, 1) + encoder = OneHotEncoder(sparse=False) + y = encoder.fit_transform(y) + train_x, 
test_x, train_y, test_y = train_test_split(x, y) + return train_x, train_y, test_x, test_y + + def set_keras_threads(threads): # We set threads here to avoid contention, as Keras # is heavily parallelized across multiple cores. @@ -61,3 +56,8 @@ def set_keras_threads(threads): config=K.tf.ConfigProto( intra_op_parallelism_threads=threads, inter_op_parallelism_threads=threads))) + + +def TuneKerasCallback(*args, **kwargs): + raise DeprecationWarning("TuneKerasCallback is now " + "tune.integration.keras.TuneReporterCallback.") diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 5f3e46aabd0a..95cb12043f8f 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -176,6 +176,14 @@ def _register_if_needed(cls, run_object): else: raise TuneError("Improper 'run' - not string nor trainable.") + @property + def local_dir(self): + return self.spec.get("local_dir") + + @property + def checkpoint_dir(self): + return os.path.join(self.spec["local_dir"], self.name) + def convert_to_experiment_list(experiments): """Produces a list of Experiment objects. diff --git a/python/ray/tune/integration/__init__.py b/python/ray/tune/integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py new file mode 100644 index 000000000000..197a7eef9841 --- /dev/null +++ b/python/ray/tune/integration/keras.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import keras +from ray.tune import track + + +class TuneReporterCallback(keras.callbacks.Callback): + def __init__(self, reporter=None, freq="batch", logs={}): + self.reporter = reporter or track.log + self.iteration = 0 + if freq not in ["batch", "epoch"]: + raise ValueError("{} not supported as a frequency.".format(freq)) + self.freq = freq + super(TuneReporterCallback, self).__init__() + + def on_batch_end(self, batch, logs={}): + if not self.freq == "batch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) + + def on_epoch_end(self, batch, logs={}): + if not self.freq == "epoch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 50bb447437e4..34655372f40a 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,11 +4,13 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler -from ray.tune.schedulers.async_hyperband import AsyncHyperBandScheduler +from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, + ASHAScheduler) from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule from ray.tune.schedulers.pbt import PopulationBasedTraining __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", - "MedianStoppingRule", "FIFOScheduler", "PopulationBasedTraining" + "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", + "PopulationBasedTraining" ] diff --git a/python/ray/tune/schedulers/async_hyperband.py 
b/python/ray/tune/schedulers/async_hyperband.py index 487eb350efcf..0370d03d3b50 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -168,6 +168,8 @@ def debug_str(self): return "Bracket: " + iters +ASHAScheduler = AsyncHyperBandScheduler + if __name__ == "__main__": sched = AsyncHyperBandScheduler( grace_period=1, max_t=10, reduction_factor=2) diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index a0721abc5d29..7b613a6fdea2 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -11,9 +11,7 @@ import ray from ray.tune import run, sample_from -from ray.tune.analysis import ExperimentAnalysis from ray.tune.examples.async_hyperband_example import MyTrainableClass -from ray.tune.schedulers import AsyncHyperBandScheduler class ExperimentAnalysisSuite(unittest.TestCase): @@ -27,35 +25,22 @@ def setUp(self): self.test_path = os.path.join(self.test_dir, self.test_name) self.run_test_exp() - self.ea = ExperimentAnalysis(self.test_path) - def tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) ray.shutdown() def run_test_exp(self): - ahb = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric=self.metric, - mode="max", - grace_period=5, - max_t=100) - - run(MyTrainableClass, + self.ea = run( + MyTrainableClass, name=self.test_name, - scheduler=ahb, local_dir=self.test_dir, - **{ - "stop": { - "training_iteration": 1 - }, - "num_samples": 10, - "config": { - "width": sample_from( - lambda spec: 10 + int(90 * random.random())), - "height": sample_from( - lambda spec: int(100 * random.random())), - }, + return_trials=False, + stop={"training_iteration": 1}, + num_samples=self.num_samples, + config={ + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from(lambda spec: int(100 * random.random())), }) def testDataframe(self): @@ -87,7 +72,7 @@ def testBestConfig(self): self.assertTrue("height" in best_config) def testBestTrial(self): - best_trial = self.ea._get_best_trial(self.metric) + best_trial = self.ea.get_best_info(self.metric, flatten=False) self.assertTrue(isinstance(best_trial, dict)) self.assertTrue("local_dir" in best_trial) @@ -99,6 +84,18 @@ def testBestTrial(self): self.assertTrue("last_result" in best_trial) self.assertTrue(self.metric in best_trial["last_result"]) + min_trial = self.ea.get_best_info( + self.metric, mode="min", flatten=False) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertLess(min_trial["last_result"][self.metric], + best_trial["last_result"][self.metric]) + + flat_trial = self.ea.get_best_info(self.metric, flatten=True) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertTrue(self.metric in flat_trial) + def testCheckpoints(self): checkpoints = self.ea._checkpoints @@ -121,6 +118,21 @@ def testRunnerData(self): self.assertEqual(runner_data["_metadata_checkpoint_dir"], os.path.expanduser(self.test_path)) + def testBestLogdir(self): + logdir = self.ea.get_best_logdir(self.metric) + self.assertTrue(logdir.startswith(self.test_path)) + logdir2 = self.ea.get_best_logdir(self.metric, mode="min") + self.assertTrue(logdir2.startswith(self.test_path)) + self.assertNotEquals(logdir, logdir2) + + def testAllDataframes(self): + dataframes = self.ea.get_all_trial_dataframes() + self.assertTrue(len(dataframes) == self.num_samples) + + self.assertTrue(isinstance(dataframes, dict)) + for df in 
dataframes.values(): + self.assertEqual(df.training_iteration.max(), 1) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 37022ceab615..64b8e9761488 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -441,6 +441,14 @@ def f(): self.assertRaises(TuneError, f) + def testNestedStoppingReturn(self): + def train(config, reporter): + for i in range(10): + reporter(test={"test1": {"test2": i}}) + + [trial] = tune.run(train, stop={"test": {"test1": {"test2": 6}}}) + self.assertEqual(trial.last_result["training_iteration"], 7) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index cb9351f9adf8..1a44575c716e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -181,6 +181,21 @@ def has_trainable(trainable_name): ray.tune.registry.TRAINABLE_CLASS, trainable_name) +def recursive_criteria_check(result, criteria): + for criteria, stop_value in criteria.items(): + if criteria not in result: + raise TuneError( + "Stopping criteria {} not provided in result {}.".format( + criteria, result)) + elif isinstance(result[criteria], dict) and isinstance( + stop_value, dict): + if recursive_criteria_check(result[criteria], stop_value): + return True + elif result[criteria] >= stop_value: + return True + return False + + class Checkpoint(object): """Describes a checkpoint of trial state. @@ -425,15 +440,7 @@ def should_stop(self, result): if result.get(DONE): return True - for criteria, stop_value in self.stopping_criterion.items(): - if criteria not in result: - raise TuneError( - "Stopping criteria {} not provided in result {}.".format( - criteria, result)) - if result[criteria] >= stop_value: - return True - - return False + return recursive_criteria_check(result, self.stopping_criterion) def should_checkpoint(self): """Whether this trial is due for checkpointing.""" diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 0d84b665167a..db302f6bd5e6 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -4,11 +4,11 @@ import click import logging -import os import time from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment +from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.ray_trial_executor import RayTrialExecutor @@ -39,7 +39,7 @@ def _make_scheduler(args): def _find_checkpoint_dir(exp): # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier. # Right now it is resolved somewhere far down the trial generation process - return os.path.join(exp.spec["local_dir"], exp.name) + return exp.checkpoint_dir def _prompt_restore(checkpoint_dir, resume): @@ -89,9 +89,10 @@ def run(run_or_experiment, verbose=2, resume=False, queue_trials=False, - reuse_actors=False, + reuse_actors=True, trial_executor=None, raise_on_failed_trial=True, + return_trials=True, ray_auto_init=True): """Executes training. 
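For illustration, a minimal sketch of the new entry point: with return_trials=False, tune.run hands back an ExperimentAnalysis instead of a list of trials. The trainable, metric, and experiment name below are placeholders, not part of the patch:

    from ray import tune
    from ray.tune.schedulers import ASHAScheduler  # alias added in this patch

    def train_fn(config, reporter):
        for step in range(10):
            reporter(mean_accuracy=config["lr"] * step)

    analysis = tune.run(
        train_fn,
        name="my_exp",
        stop={"training_iteration": 5},
        num_samples=4,
        config={"lr": tune.grid_search([0.01, 0.1])},
        scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
        return_trials=False)

    print(analysis.get_best_config("mean_accuracy", mode="max"))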
@@ -273,7 +274,9 @@ def run(run_or_experiment, else: logger.error("Trials did not complete: %s", errored_trials) - return runner.get_trials() + if return_trials: + return runner.get_trials() + return ExperimentAnalysis(experiment.checkpoint_dir) def run_experiments(experiments, From 3b23d94cb8811f380c059461f48ac4fc538cb858 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 20 Jun 2019 22:22:37 -0700 Subject: [PATCH 104/118] Fix valgrind build by installing new version of valgrind (#5008) --- .travis.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.travis.yml b/.travis.yml index 9a4fb66d84a0..e5631fd9f0b1 100644 --- a/.travis.yml +++ b/.travis.yml @@ -74,8 +74,12 @@ matrix: - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` - if [ $RAY_CI_PYTHON_AFFECTED != "1" ]; then exit; fi + # Install a newer version of valgrind, the one that comes with + # Ubuntu 16.04 is broken (Illegal instruction) + - sudo add-apt-repository -y ppa:msulikowski/valgrind - sudo apt-get update -qq - sudo apt-get install -qq valgrind + install: - if [ $RAY_CI_PYTHON_AFFECTED != "1" ]; then exit; fi From a7f84b536f7948a483bc8347df1c2dc0c2b83425 Mon Sep 17 00:00:00 2001 From: Joey Jiang <452084368@qq.com> Date: Fri, 21 Jun 2019 17:08:25 +0800 Subject: [PATCH 105/118] Fix no cpus test (#5009) --- python/ray/tests/conftest.py | 8 ++++++++ python/ray/tests/test_actor.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 2e670fb0a84d..f7c93fd50c2e 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -56,6 +56,14 @@ def _ray_start(**kwargs): ray.shutdown() +# The following fixture will start ray with 0 cpu. +@pytest.fixture +def ray_start_no_cpu(request): + param = getattr(request, "param", {}) + with _ray_start(num_cpus=0, **param) as res: + yield res + + # The following fixture will start ray with 1 cpu. 
@pytest.fixture def ray_start_regular(request): diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index dd726e00f27b..932f7b090bf7 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -842,7 +842,7 @@ def f(): assert actor_id not in resulting_ids -def test_actors_on_nodes_with_no_cpus(ray_start_regular): +def test_actors_on_nodes_with_no_cpus(ray_start_no_cpu): @ray.remote class Foo(object): def method(self): From 2e342ef71fd12066913d019140bc6ea43d74d6eb Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Fri, 21 Jun 2019 11:04:40 -0700 Subject: [PATCH 106/118] Fix tensorflow-1.14 installation in jenkins (#5007) --- docker/examples/Dockerfile | 2 ++ docker/tune_test/Dockerfile | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index 6883c5a64a0e..04df50a822fa 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,6 +5,8 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 6e098d5218f6..41ef63390266 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -7,6 +7,8 @@ FROM ray-project/base-deps RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git From e33d0eac68771ee7e2321be2b8bf9f73dc48a9da Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Sun, 23 Jun 2019 18:08:33 +0800 Subject: [PATCH 107/118] Add dynamic worker options for worker command. (#4970) * Add fields for fbs * WIP * Fix complition errors * Add java part * FIx * Fix * Fix * Fix lint * Refine API * address comments and add test * Fix * Address comment. * Address comments. * Fix linting * Refine * Fix lint * WIP: address comment. * Fix java * Fix py * Refin * Fix * Fix * Fix linting * Fix lint * Address comments * WIP * Fix * Fix * minor refine * Fix lint * Fix raylet test. * Fix lint * Update src/ray/raylet/worker_pool.h Co-Authored-By: Hao Chen * Update java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java Co-Authored-By: Hao Chen * Address comments. * Address comments. * Fix test. * Update src/ray/raylet/worker_pool.h Co-Authored-By: Hao Chen * Address comments. * Address comments. * Fix * Fix lint * Fix lint * Fix * Address comments. 
* Fix linting --- .../ray/api/options/ActorCreationOptions.java | 15 ++- .../org/ray/runtime/AbstractRayRuntime.java | 9 +- .../ray/runtime/raylet/RayletClientImpl.java | 18 +++- .../org/ray/runtime/runner/RunManager.java | 3 + .../java/org/ray/runtime/task/TaskSpec.java | 8 +- .../ray/api/test/WorkerJvmOptionsTest.java | 31 ++++++ python/ray/services.py | 3 + src/ray/common/constants.h | 2 + src/ray/gcs/format/gcs.fbs | 5 + src/ray/raylet/node_manager.cc | 21 ++-- src/ray/raylet/task_spec.cc | 12 ++- src/ray/raylet/task_spec.h | 6 +- src/ray/raylet/worker_pool.cc | 98 ++++++++++++++++--- src/ray/raylet/worker_pool.h | 56 +++++++---- src/ray/raylet/worker_pool_test.cc | 65 ++++++++++-- 15 files changed, 290 insertions(+), 62 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index d1e92f7bb9e9..2e14ca8584dd 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -13,9 +13,14 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - private ActorCreationOptions(Map resources, int maxReconstructions) { + public final String jvmOptions; + + private ActorCreationOptions(Map resources, + int maxReconstructions, + String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.jvmOptions = jvmOptions; } /** @@ -25,6 +30,7 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private String jvmOptions = ""; public Builder setResources(Map resources) { this.resources = resources; @@ -36,8 +42,13 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } + public Builder setJvmOptions(String jvmOptions) { + this.jvmOptions = jvmOptions; + return this; + } + public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions); + return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index fbd03bf10483..26a8d6e541ba 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -35,6 +35,7 @@ import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.IdUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -363,8 +364,13 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes } int maxActorReconstruction = 0; + List dynamicWorkerOptions = ImmutableList.of(); if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; + String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; + if (!StringUtil.isNullOrEmpty(jvmOptions)) { + dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); + } } TaskLanguage language; @@ -393,7 +399,8 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes numReturns, resources, language, - functionDescriptor + 
functionDescriptor, + dynamicWorkerOptions ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 01b9e4675016..c369e6f2cab8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -190,9 +190,16 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); + + // Deserialize dynamic worker options. + List dynamicWorkerOptions = new ArrayList<>(); + for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { + dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); + } + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -275,6 +282,12 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); } + int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; + for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { + dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); + } + int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); + int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, @@ -293,7 +306,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { requiredResourcesOffset, requiredPlacementResourcesOffset, language, - functionDescriptorOffset); + functionDescriptorOffset, + dynamicWorkerOptionsOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 15240e43e234..773499fcf5cf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -319,6 +319,9 @@ private String buildWorkerCommandRaylet() { cmd.addAll(rayConfig.jvmParameters); + // jvm options + cmd.add("RAY_WORKER_OPTION_0"); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 3473a9bdb3cc..060ca6fff4c3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -63,6 +63,8 @@ public class TaskSpec { // Language of this task. public final TaskLanguage language; + public final List dynamicWorkerOptions; + // Descriptor of the remote function. // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language // is Python, the type is PyFunctionDescriptor. 
@@ -93,7 +95,8 @@ public TaskSpec( int numReturns, Map resources, TaskLanguage language, - FunctionDescriptor functionDescriptor) { + FunctionDescriptor functionDescriptor, + List dynamicWorkerOptions) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -106,6 +109,8 @@ public TaskSpec( this.newActorHandles = newActorHandles; this.args = args; this.numReturns = numReturns; + this.dynamicWorkerOptions = dynamicWorkerOptions; + returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); @@ -157,6 +162,7 @@ public String toString() { ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + + ", dynamicWorkerOptions=" + dynamicWorkerOptions + ", executionDependencies=" + executionDependencies + '}'; } diff --git a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java new file mode 100644 index 000000000000..90a2817a8366 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java @@ -0,0 +1,31 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class WorkerJvmOptionsTest extends BaseTest { + + @RayRemote + public static class Echo { + String getOptions() { + return System.getProperty("test.suffix"); + } + } + + @Test + public void testJvmOptions() { + TestUtils.skipTestUnderSingleProcess(); + ActorCreationOptions options = new ActorCreationOptions.Builder() + .setJvmOptions("-Dtest.suffix=suffix") + .createActorCreationOptions(); + RayActor actor = Ray.createActor(Echo::new, options); + RayObject obj = Ray.call(Echo::getOptions, actor); + Assert.assertEquals(obj.get(), "suffix"); + } +} diff --git a/python/ray/services.py b/python/ray/services.py index 2c843f7bbbc7..14e13620eea2 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1233,6 +1233,7 @@ def build_java_worker_command( assert java_worker_options is not None command = "java " + if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) @@ -1253,6 +1254,8 @@ def build_java_worker_command( # Put `java_worker_options` in the last, so it can overwrite the # above options. command += java_worker_options + " " + + command += "RAY_WORKER_OPTION_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index c92e6a74aa5d..1f50b8025d57 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -36,4 +36,6 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. 
constexpr char kTaskTablePrefix[] = "TaskTable"; +constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; + #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 614c80b27672..90476da73425 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -106,6 +106,11 @@ table TaskInfo { // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] function_descriptor: [string]; + // The dynamic options used in the worker command when starting the worker process for + // an actor creation task. If the list isn't empty, the options will be used to replace + // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the + // worker command. + dynamic_worker_options: [string]; } table ResourcePair { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a0bde1ff0655..fc364539ccce 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -83,7 +83,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, config.worker_commands), + config.maximum_startup_concurrency, gcs_client_, + config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, @@ -1723,18 +1724,6 @@ bool NodeManager::AssignTask(const Task &task) { std::shared_ptr worker = worker_pool_.PopWorker(spec); if (worker == nullptr) { // There are no workers that can execute this task. - if (!spec.IsActorTask()) { - // There are no more non-actor workers available to execute this task. - // Start a new worker. - worker_pool_.StartWorkerProcess(spec.GetLanguage()); - // Push an error message to the user if the worker pool tells us that it is - // getting too big. - const std::string warning_message = worker_pool_.WarningAboutSize(); - if (warning_message != "") { - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms())); - } - } // We couldn't assign this task, as no worker available. return false; } @@ -2205,6 +2194,12 @@ void NodeManager::ForwardTask( const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); + if (worker_pool_.HasPendingWorkerForTask(spec.GetLanguage(), task_id)) { + // There is a worker being starting for this task, + // so we shouldn't forward this task to another node. + return; + } + // Get and serialize the task's unforwarded, uncommitted lineage. Lineage uncommitted_lineage; if (lineage_cache_.ContainsTask(task_id)) { diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index eeab29272126..1d722de18f73 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -80,12 +80,12 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor) + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options) : spec_() { flatbuffers::FlatBufferBuilder fbb; TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); - // Add argument object IDs. 
std::vector> arguments; for (auto &argument : task_arguments) { @@ -101,7 +101,8 @@ TaskSpecification::TaskSpecification( ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor)); + string_vec_to_flatbuf(fbb, function_descriptor), + string_vec_to_flatbuf(fbb, dynamic_worker_options)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -258,6 +259,11 @@ std::vector TaskSpecification::NewActorHandles() const { return ids_from_flatbuf(*message->new_actor_handles()); } +std::vector TaskSpecification::DynamicWorkerOptions() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return string_vec_from_flatbuf(*message->dynamic_worker_options()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index d557c188ae68..8a08e9974ef2 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -128,6 +128,7 @@ class TaskSpecification { /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. + /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, @@ -138,7 +139,8 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor); + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options = {}); /// Deserialize a task specification from a string. /// @@ -214,6 +216,8 @@ class TaskSpecification { ObjectID ActorDummyObject() const; std::vector NewActorHandles() const; + std::vector DynamicWorkerOptions() const; + private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index d4ac4cf4ecce..719378216fb7 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,10 +5,12 @@ #include #include +#include "ray/common/constants.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/stats/stats.h" #include "ray/util/logging.h" +#include "ray/util/util.h" namespace { @@ -41,11 +43,12 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. 
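For illustration, the placeholder mechanism described in gcs.fbs above (and implemented in StartWorkerProcess below) amounts to a token substitution over the configured worker command. A rough Python model, not part of the patch, with a made-up command template:

    def build_worker_command(template, dynamic_options):
        # Replace "RAY_WORKER_OPTION_<i>" tokens with the actor's dynamic options;
        # if no options were given, the placeholders are simply dropped.
        args, option_index = [], 0
        for token in template:
            if token == "RAY_WORKER_OPTION_{}".format(option_index):
                if dynamic_options:
                    args.append(dynamic_options[option_index])
                    option_index += 1
            else:
                args.append(token)
        return args

    template = ["java", "-Dray.redis.address=127.0.0.1:6379", "RAY_WORKER_OPTION_0",
                "org.ray.runtime.runner.worker.DefaultWorker"]
    print(build_worker_command(template, ["-Dtest.suffix=suffix"]))
    # ['java', '-Dray.redis.address=127.0.0.1:6379', '-Dtest.suffix=suffix',
    #  'org.ray.runtime.runner.worker.DefaultWorker']
    print(build_worker_command(template, []))  # placeholder dropped when no options set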
WorkerPool::WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), + gcs_client_(std::move(gcs_client)), last_warning_multiple_(0) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); @@ -98,7 +101,8 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -void WorkerPool::StartWorkerProcess(const Language &language) { +int WorkerPool::StartWorkerProcess(const Language &language, + const std::vector &dynamic_options) { auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -108,7 +112,7 @@ void WorkerPool::StartWorkerProcess(const Language &language) { RAY_LOG(DEBUG) << "Worker not started, " << state.starting_worker_processes.size() << " worker processes of language type " << static_cast(language) << " pending registration"; - return; + return -1; } // Either there are no workers pending registration or the worker start is being forced. RAY_LOG(DEBUG) << "Starting new worker process, current pool has " @@ -117,8 +121,20 @@ void WorkerPool::StartWorkerProcess(const Language &language) { // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; + size_t dynamic_option_index = 0; for (auto const &token : state.worker_command) { - worker_command_args.push_back(token.c_str()); + const auto option_placeholder = + kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); + + if (token == option_placeholder) { + if (!dynamic_options.empty()) { + RAY_CHECK(dynamic_option_index < dynamic_options.size()); + worker_command_args.push_back(dynamic_options[dynamic_option_index].c_str()); + ++dynamic_option_index; + } + } else { + worker_command_args.push_back(token.c_str()); + } } worker_command_args.push_back(nullptr); @@ -126,14 +142,14 @@ void WorkerPool::StartWorkerProcess(const Language &language) { if (pid < 0) { // Failure case. RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); - return; } else if (pid > 0) { // Parent process case. RAY_LOG(DEBUG) << "Started worker process with pid " << pid; state.starting_worker_processes.emplace( std::make_pair(pid, num_workers_per_process_)); - return; + return pid; } + return -1; } pid_t WorkerPool::StartProcess(const std::vector &worker_command_args) { @@ -158,7 +174,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_a } void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { - auto pid = worker->Pid(); + const auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; auto &state = GetStateForLanguage(worker->GetLanguage()); state.registered_workers.insert(std::move(worker)); @@ -207,30 +223,74 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); - // Add the worker to the idle pool. 
- if (worker->GetActorId().IsNil()) { - state.idle.insert(std::move(worker)); + + auto it = state.dedicated_workers_to_tasks.find(worker->Pid()); + if (it != state.dedicated_workers_to_tasks.end()) { + // The worker is used for the actor creation task with dynamic options. + // Put it into idle dedicated worker pool. + const auto task_id = it->second; + state.idle_dedicated_workers[task_id] = std::move(worker); } else { - state.idle_actor[worker->GetActorId()] = std::move(worker); + // The worker is not used for the actor creation task without dynamic options. + // Put the worker to the corresponding idle pool. + if (worker->GetActorId().IsNil()) { + state.idle.insert(std::move(worker)); + } else { + state.idle_actor[worker->GetActorId()] = std::move(worker); + } } } std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec) { auto &state = GetStateForLanguage(task_spec.GetLanguage()); const auto &actor_id = task_spec.ActorId(); + std::shared_ptr worker = nullptr; - if (actor_id.IsNil()) { + int pid = -1; + if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { + // Code path of actor creation task with dynamic worker options. + // Try to pop it from idle dedicated pool. + auto it = state.idle_dedicated_workers.find(task_spec.TaskId()); + if (it != state.idle_dedicated_workers.end()) { + // There is an idle dedicated worker for this task. + worker = std::move(it->second); + state.idle_dedicated_workers.erase(it); + // Because we found a worker that can perform this task, + // we can remove it from dedicated_workers_to_tasks. + state.dedicated_workers_to_tasks.erase(worker->Pid()); + state.tasks_to_dedicated_workers.erase(task_spec.TaskId()); + } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { + // We are not pending a registration from a worker for this task, + // so start a new worker process for this task. + pid = StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions()); + if (pid > 0) { + state.dedicated_workers_to_tasks[pid] = task_spec.TaskId(); + state.tasks_to_dedicated_workers[task_spec.TaskId()] = pid; + } + } + } else if (!task_spec.IsActorTask()) { + // Code path of normal task or actor creation task without dynamic worker options. if (!state.idle.empty()) { worker = std::move(*state.idle.begin()); state.idle.erase(state.idle.begin()); + } else { + // There are no more non-actor workers available to execute this task. + // Start a new worker process. + pid = StartWorkerProcess(task_spec.GetLanguage()); } } else { + // Code path of actor task. auto actor_entry = state.idle_actor.find(actor_id); if (actor_entry != state.idle_actor.end()) { worker = std::move(actor_entry->second); state.idle_actor.erase(actor_entry); } } + + if (worker == nullptr && pid > 0) { + WarnAboutSize(); + } + return worker; } @@ -274,7 +334,7 @@ std::vector> WorkerPool::GetWorkersRunningTasksForDriver return workers; } -std::string WorkerPool::WarningAboutSize() { +void WorkerPool::WarnAboutSize() { int64_t num_workers_started_or_registered = 0; for (const auto &entry : states_by_lang_) { num_workers_started_or_registered += @@ -285,6 +345,8 @@ std::string WorkerPool::WarningAboutSize() { int64_t multiple = num_workers_started_or_registered / multiple_for_warning_; std::stringstream warning_message; if (multiple >= 3 && multiple > last_warning_multiple_) { + // Push an error message to the user if the worker pool tells us that it is + // getting too big. 
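For illustration, the routing that PopWorker now performs for actor-creation tasks carrying dynamic options can be summarized by a short Python model; this is a simplification, not Ray code, and the dictionary names only mirror the fields added to WorkerPool::State:

    idle_dedicated = {}   # task_id -> pid of an idle worker started for that task
    worker_to_task = {}   # pid -> task_id for dedicated workers
    task_to_worker = {}   # task_id -> pid, i.e. "a dedicated worker is pending"

    def pop_dedicated_worker(task_id, language, dynamic_options, start_worker_process):
        pid = idle_dedicated.pop(task_id, None)
        if pid is not None:
            # The dedicated worker is ready, so it is no longer tracked as pending.
            del worker_to_task[pid]
            del task_to_worker[task_id]
            return pid
        if task_id not in task_to_worker:
            # Nothing pending yet: start one worker just for this task.
            new_pid = start_worker_process(language, dynamic_options)
            worker_to_task[new_pid] = task_id
            task_to_worker[task_id] = new_pid
        # Not assignable yet; HasPendingWorkerForTask also keeps the task from
        # being forwarded to another node in the meantime.
        return None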
last_warning_multiple_ = multiple; warning_message << "WARNING: " << num_workers_started_or_registered << " workers have been started. This could be a result of using " @@ -292,8 +354,16 @@ std::string WorkerPool::WarningAboutSize() { << "using nested tasks " << "(see https://github.com/ray-project/ray/issues/3644) for " << "some a discussion of workarounds."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); } - return warning_message.str(); +} + +bool WorkerPool::HasPendingWorkerForTask(const Language &language, + const TaskID &task_id) { + auto &state = GetStateForLanguage(language); + auto it = state.tasks_to_dedicated_workers.find(task_id); + return it != state.tasks_to_dedicated_workers.end(); } std::string WorkerPool::DebugString() const { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 03443447cf58..e1e726268093 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -7,6 +7,7 @@ #include #include "ray/common/client_connection.h" +#include "ray/gcs/client.h" #include "ray/gcs/format/util.h" #include "ray/raylet/task.h" #include "ray/raylet/worker.h" @@ -37,22 +38,12 @@ class WorkerPool { /// language. WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); - /// Asynchronously start a new worker process. Once the worker process has - /// registered with an external server, the process should create and - /// register num_workers_per_process_ workers, then add them to the pool. - /// Failure to start the worker process is a fatal error. If too many workers - /// are already being started, then this function will return without starting - /// any workers. - /// - /// \param language Which language this worker process should be. - void StartWorkerProcess(const Language &language); - /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// @@ -118,6 +109,15 @@ class WorkerPool { std::vector> GetWorkersRunningTasksForDriver( const DriverID &driver_id) const; + /// Whether there is a pending worker for the given task. + /// Note that, this is only used for actor creation task with dynamic options. + /// And if the worker registered but isn't assigned a task, + /// the worker also is in pending state, and this'll return true. + /// + /// \param language The required language. + /// \param task_id The task that we want to query. + bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id); + /// Returns debug string for class. /// /// \return string. @@ -126,24 +126,37 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; - /// Generate a warning about the number of workers that have registered or - /// started if appropriate. + protected: + /// Asynchronously start a new worker process. Once the worker process has + /// registered with an external server, the process should create and + /// register num_workers_per_process_ workers, then add them to the pool. + /// Failure to start the worker process is a fatal error. If too many workers + /// are already being started, then this function will return without starting + /// any workers. 
/// - /// \return An empty string if no warning should be generated and otherwise a - /// string with a warning message. - std::string WarningAboutSize(); + /// \param language Which language this worker process should be. + /// \param dynamic_options The dynamic options that we should add for worker command. + /// \return The id of the process that we started if it's positive, + /// otherwise it means we didn't start a process. + int StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}); - protected: /// The implementation of how to start a new worker process with command arguments. /// /// \param worker_command_args The command arguments of new worker process. /// \return The process ID of started worker process. virtual pid_t StartProcess(const std::vector &worker_command_args); + /// Push an warning message to user if worker pool is getting to big. + virtual void WarnAboutSize(); + /// An internal data structure that maintains the pool state per language. struct State { /// The commands and arguments used to start the worker process std::vector worker_command; + /// The pool of dedicated workers for actor creation tasks + /// with prefix or suffix worker command. + std::unordered_map> idle_dedicated_workers; /// The pool of idle non-actor workers. std::unordered_set> idle; /// The pool of idle actor workers. @@ -156,6 +169,11 @@ class WorkerPool { /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; + /// A map for looking up the task with dynamic options by the pid of + /// worker. Note that this is used for the dedicated worker processes. + std::unordered_map dedicated_workers_to_tasks; + /// A map for speeding up looking up the pending worker for the given task. + std::unordered_map tasks_to_dedicated_workers; }; /// The number of workers per process. @@ -166,7 +184,7 @@ class WorkerPool { private: /// A helper function that returns the reference of the pool state /// for a given language. - inline State &GetStateForLanguage(const Language &language); + State &GetStateForLanguage(const Language &language); /// We'll push a warning to the user every time a multiple of this many /// workers has been started. @@ -176,6 +194,8 @@ class WorkerPool { /// The last size at which a warning about the number of registered workers /// was generated. int64_t last_warning_multiple_; + /// A client connection to the GCS. 
+ std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 143ffd57dda6..15a5fb0471e0 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -1,6 +1,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/constants.h" #include "ray/raylet/node_manager.h" #include "ray/raylet/worker_pool.h" @@ -14,21 +15,46 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, - {{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}), + : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, {"dummy_java_worker_command"}}}) {} + + explicit WorkerPoolMock( + const std::unordered_map> &worker_commands) + : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, + worker_commands), last_worker_pid_(0) {} + ~WorkerPoolMock() { // Avoid killing real processes states_by_lang_.clear(); } + void StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}) { + WorkerPool::StartWorkerProcess(language, dynamic_options); + } + pid_t StartProcess(const std::vector &worker_command_args) override { - return ++last_worker_pid_; + last_worker_pid_ += 1; + std::vector local_worker_commands_args; + for (auto item : worker_command_args) { + if (item == nullptr) { + break; + } + local_worker_commands_args.push_back(std::string(item)); + } + worker_commands_by_pid[last_worker_pid_] = std::move(local_worker_commands_args); + return last_worker_pid_; } + void WarnAboutSize() override {} + pid_t LastStartedWorkerProcess() const { return last_worker_pid_; } + const std::vector &GetWorkerCommand(int pid) { + return worker_commands_by_pid[pid]; + } + int NumWorkerProcessesStarting() const { int total = 0; for (auto &entry : states_by_lang_) { @@ -39,6 +65,8 @@ class WorkerPoolMock : public WorkerPool { private: int last_worker_pid_; + // The worker commands by pid. 
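To make the dynamic-options mechanism concrete: the worker command configured per language may contain RAY_WORKER_OPTION_<i> placeholders, and StartWorkerProcess substitutes a task's dynamic options into those slots (see the StartWorkerWithDynamicOptionsCommand test later in this patch, which fixes both the placeholder naming and the expected result). The following is only an illustrative Python rendering of that substitution under those assumptions, not the C++ implementation:

    # Illustrative sketch of the placeholder substitution performed when a
    # dedicated worker process is started with dynamic options. The placeholder
    # naming and the expected result come from the test added below.
    def build_worker_command(template, dynamic_options):
        command = []
        for token in template:
            if token.startswith("RAY_WORKER_OPTION_"):
                index = int(token[len("RAY_WORKER_OPTION_"):])
                command.append(dynamic_options[index])
            else:
                command.append(token)
        return command

    template = ["RAY_WORKER_OPTION_0", "dummy_java_worker_command",
                "RAY_WORKER_OPTION_1"]
    options = ["test_op_0", "test_op_1"]
    assert build_worker_command(template, options) == [
        "test_op_0", "dummy_java_worker_command", "test_op_1"]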
+ std::unordered_map> worker_commands_by_pid; }; class WorkerPoolTest : public ::testing::Test { @@ -61,6 +89,12 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, client)); } + void SetWorkerCommands( + const std::unordered_map> &worker_commands) { + WorkerPoolMock worker_pool(worker_commands); + this->worker_pool_ = std::move(worker_pool); + } + protected: WorkerPoolMock worker_pool_; boost::asio::io_service io_service_; @@ -72,10 +106,10 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), - const Language &language = Language::PYTHON) { + const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, + const ActorID actor_creation_id = ActorID::Nil()) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -186,6 +220,23 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); } +TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { + const std::vector java_worker_command = { + "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; + SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, java_worker_command}}); + + TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), + ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, + {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, + {"test_op_0", "test_op_1"}); + worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + const auto real_command = + worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + ASSERT_EQ(real_command, std::vector( + {"test_op_0", "dummy_java_worker_command", "test_op_1"})); +} + } // namespace raylet } // namespace ray From 11ccf6634607f45c8cd15c0b54d52d47d800addf Mon Sep 17 00:00:00 2001 From: Ashwinee Panda Date: Mon, 24 Jun 2019 11:26:53 -0700 Subject: [PATCH 108/118] [docs] docs for running Tensorboard without sudo (#5015) * Instructions for running Tensorboard without sudo When we run Tensorboard to visualize the results of Ray outputs on multi-user clusters where we don't have sudo access, such as RISE clusters, a few commands need to first be run to make sure tensorboard can edit the tmp directory. This is a pretty common usecase so I figured we may as well put it in the documentation for Tune. * Update tune-usage.rst --- doc/source/tune-usage.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 281ccbd6107e..e8ce405d9457 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -355,6 +355,12 @@ Then, after you run a experiment, you can visualize your experiment with TensorB $ tensorboard --logdir=~/ray_results/my_experiment +If you are running Ray on a remote multi-user cluster where you do not have sudo access, you can run the following commands to make sure tensorboard is able to write to the tmp directory: + +.. code-block:: bash + + $ export TMPDIR=/tmp/$USER; mkdir -p $TMPDIR; tensorboard --logdir=~/ray_results + .. 
image:: ray-tune-tensorboard.png To use rllab's VisKit (you may have to install some dependencies), run: From bd8aceb8968f0bdf6e2717a20fac0bb9def200aa Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Mon, 24 Jun 2019 21:50:37 -0700 Subject: [PATCH 109/118] [ci] Change Jenkins to py3 (#5022) * conda3 * integration * add nevergrad, remotedata * pytest 0.3.1 * otherdockers * setup * tune --- .../perf_integration_tests/run_perf_integration.sh | 2 +- ci/jenkins_tests/run_tune_tests.sh | 6 +++--- docker/base-deps/Dockerfile | 2 +- docker/examples/Dockerfile | 3 ++- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 7 ++++--- 6 files changed, 12 insertions(+), 10 deletions(-) diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index 7962b21075c0..f25d32df22a1 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 6154fe70d4f6..84e7e7fe9c0f 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -78,9 +78,9 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --smoke-test # Runs only on Python3 -# docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ -# python3 /ray/python/ray/tune/examples/nevergrad_example.py \ -# --smoke-test +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/nevergrad_example.py \ + --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index c21430c627a4..db8f28c85f86 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -12,7 +12,7 @@ RUN apt-get update \ && apt-get clean \ && echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \ && wget \ - --quiet 'https://repo.continuum.io/archive/Anaconda2-5.2.0-Linux-x86_64.sh' \ + --quiet 'https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh' \ -O /tmp/anaconda.sh \ && /bin/bash /tmp/anaconda.sh -b -p /opt/conda \ && rm /tmp/anaconda.sh \ diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index 04df50a822fa..bafcdf35e628 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -12,6 +12,7 @@ RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch 
diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 1d174ed72f92..376fe5340fd9 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 41ef63390266..1d252a62fd62 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev # The following is needed to support TensorFlow 1.14 @@ -13,8 +13,9 @@ RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch # RUN mkdir -p /root/.ssh/ @@ -22,6 +23,6 @@ RUN conda install pytorch-cpu torchvision-cpu -c pytorch # We port the source code in so that we run the most up-to-date stress tests. 
ADD ray.tar /ray ADD git-rev /ray/git-rev -RUN python /ray/python/ray/rllib/setup-rllib-dev.py --yes +RUN python /ray/python/ray/setup-dev.py --yes WORKDIR /ray From 0131353d42f20f28034480d45918130525fb0377 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Wed, 26 Jun 2019 05:31:19 +0800 Subject: [PATCH 110/118] [gRPC] Migrate gcs data structures to protobuf (#5024) --- BUILD.bazel | 96 ++-- bazel/ray_deps_build_all.bzl | 4 + bazel/ray_deps_setup.bzl | 11 +- doc/source/conf.py | 15 +- java/BUILD.bazel | 51 +-- java/dependencies.bzl | 1 + ...modify_generated_java_flatbuffers_files.py | 20 +- java/runtime/pom.xml | 5 + .../java/org/ray/runtime/gcs/GcsClient.java | 69 +-- .../runtime/objectstore/ObjectStoreProxy.java | 12 +- python/ray/gcs_utils.py | 71 ++- python/ray/monitor.py | 33 +- python/ray/state.py | 230 ++++------ python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/test_basic.py | 14 +- python/ray/tests/test_failure.py | 5 +- python/ray/utils.py | 8 +- python/ray/worker.py | 40 +- python/setup.py | 1 + src/ray/gcs/client.cc | 4 - src/ray/gcs/client.h | 6 - src/ray/gcs/client_test.cc | 353 +++++++-------- src/ray/gcs/format/gcs.fbs | 281 +----------- src/ray/gcs/redis_context.h | 15 +- src/ray/gcs/redis_module/ray_redis_module.cc | 209 ++++----- src/ray/gcs/tables.cc | 417 ++++++++---------- src/ray/gcs/tables.h | 136 +++--- src/ray/object_manager/object_directory.cc | 34 +- src/ray/object_manager/object_manager.cc | 49 +- src/ray/object_manager/object_manager.h | 4 +- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/protobuf/gcs.proto | 280 ++++++++++++ src/ray/raylet/actor_registration.cc | 51 +-- src/ray/raylet/actor_registration.h | 24 +- src/ray/raylet/lineage_cache.cc | 37 +- src/ray/raylet/lineage_cache.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 237 +++++----- src/ray/raylet/node_manager.h | 26 +- src/ray/raylet/raylet.cc | 24 +- src/ray/raylet/raylet.h | 2 + src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 + src/ray/raylet/reconstruction_policy_test.cc | 42 +- src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 + .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/worker_pool.cc | 4 +- src/ray/rpc/util.h | 13 + 52 files changed, 1465 insertions(+), 1642 deletions(-) create mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index da36eec0cf57..bc9e6bcd8006 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,22 +1,55 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# Node manager gRPC lib. 
-grpc_proto_library( - name = "node_manager_grpc_lib", +# === Begin of protobuf definitions === + +proto_library( + name = "gcs_proto", + srcs = ["src/ray/protobuf/gcs.proto"], + visibility = ["//java:__subpackages__"], +) + +cc_proto_library( + name = "gcs_cc_proto", + deps = [":gcs_proto"], +) + +python_proto_compile( + name = "gcs_py_proto", + deps = [":gcs_proto"], +) + +proto_library( + name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], ) +cc_proto_library( + name = "node_manager_cc_proto", + deps = ["node_manager_proto"], +) + +# === End of protobuf definitions === + +# Node manager gRPC lib. +cc_grpc_library( + name = "node_manager_cc_grpc", + srcs = [":node_manager_proto"], + grpc_only = True, + deps = [":node_manager_cc_proto"], +) + # Node manager server and client. cc_library( - name = "node_manager_rpc_lib", + name = "node_manager_rpc", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -25,7 +58,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_grpc_lib", + ":node_manager_cc_grpc", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -114,7 +147,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc_lib", + ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", @@ -422,9 +455,11 @@ cc_library( "src/ray/gcs/format", ], deps = [ + ":gcs_cc_proto", ":gcs_fbs", ":hiredis", ":node_manager_fbs", + ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -555,46 +590,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -flatbuffer_py_library( - name = "python_gcs_fbs", - srcs = [ - ":gcs_fbs_file", - ], - outs = [ - "ActorCheckpointIdData.py", - "ActorState.py", - "ActorTableData.py", - "Arg.py", - "ClassTableData.py", - "ClientTableData.py", - "ConfigTableData.py", - "CustomSerializerData.py", - "DriverTableData.py", - "EntryType.py", - "ErrorTableData.py", - "ErrorType.py", - "FunctionTableData.py", - "GcsEntry.py", - "HeartbeatBatchTableData.py", - "HeartbeatTableData.py", - "Language.py", - "ObjectTableData.py", - "ProfileEvent.py", - "ProfileTableData.py", - "RayResource.py", - "ResourcePair.py", - "SchedulingState.py", - "TablePrefix.py", - "TablePubsub.py", - "TaskInfo.py", - "TaskLeaseData.py", - "TaskReconstructionData.py", - "TaskTableData.py", - "TaskTableTestAndUpdate.py", - ], - out_prefix = "python/ray/core/generated/", -) - flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -679,6 +674,7 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ + ":gcs_cc_proto", ":ray_common", ], ) @@ -688,7 +684,7 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:python_gcs_fbs", + "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -710,11 +706,13 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && - for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && + for f in $(locations //:gcs_py_proto); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/; + done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl 
b/bazel/ray_deps_build_all.bzl index 3e1e1838a59a..eda88bece7d2 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,6 +4,8 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") +load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -13,4 +15,6 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() + java_proto_compile() + python_proto_compile() diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index e6dc21585699..aa322654cf9f 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,7 +105,14 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", ], - strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", + ) + + http_archive( + name = "build_stack_rules_proto", + urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], + sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", + strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", ) diff --git a/doc/source/conf.py b/doc/source/conf.py index 98fb3e0d02dd..5cf6b01217f9 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,20 +23,7 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.EntryType", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ErrorType", - "ray.core.generated.GcsEntry", - "ray.core.generated.HeartbeatBatchTableData", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.Language", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", + "ray.core.generated.gcs_pb2", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 80ccabccfc12..4960434af180 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,5 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -50,6 +51,7 @@ define_java_module( name = "runtime", additional_srcs = [ ":generate_java_gcs_fbs", + ":gcs_java_proto", ], additional_resources = [ ":java_native_deps", @@ -68,6 +70,7 @@ define_java_module( "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -148,38 +151,16 @@ java_binary( ], ) +java_proto_compile( + name = "gcs_java_proto", + deps = ["@//:gcs_proto"], +) + flatbuffers_generated_files = [ - 
"ActorCheckpointData.java", - "ActorCheckpointIdData.java", - "ActorState.java", - "ActorTableData.java", "Arg.java", - "ClassTableData.java", - "ClientTableData.java", - "ConfigTableData.java", - "CustomSerializerData.java", - "DriverTableData.java", - "EntryType.java", - "ErrorTableData.java", - "ErrorType.java", - "FunctionTableData.java", - "GcsEntry.java", - "HeartbeatBatchTableData.java", - "HeartbeatTableData.java", "Language.java", - "ObjectTableData.java", - "ProfileEvent.java", - "ProfileTableData.java", - "RayResource.java", - "ResourcePair.java", - "SchedulingState.java", - "TablePrefix.java", - "TablePubsub.java", "TaskInfo.java", - "TaskLeaseData.java", - "TaskReconstructionData.java", - "TaskTableData.java", - "TaskTableTestAndUpdate.java", + "ResourcePair.java", ] flatbuffer_java_library( @@ -198,7 +179,7 @@ genrule( cmd = """ for f in $(locations //java:java_gcs_fbs); do chmod +w $$f - cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated done python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. """, @@ -221,8 +202,10 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":java_native_deps", + ":gcs_java_proto", ":generate_java_gcs_fbs", + ":java_native_deps", + ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -237,10 +220,15 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy flatbuffers-generated files + # Copy protobuf-generated files. GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files for f in $(locations //java:generate_java_gcs_fbs); do cp $$f $$GENERATED_DIR done @@ -250,6 +238,7 @@ genrule( echo $$(date) > $@ """, local = 1, + tags = ["no-cache"], ) genrule( diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 7c716166d399..ef667137562b 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,6 +6,7 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", + "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index c1b723f25f8d..5bf62e56d7e4 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,7 +4,6 @@ import os import sys - """ This script is used for modifying the generated java flatbuffer files for the reason: The package declaration in Java is different @@ -21,19 +20,18 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_new_line(file, line_num, text): +def add_package(file): with open(file, "r") as file_handler: lines = file_handler.readlines() - if (line_num <= 0) or (line_num > len(lines) + 1): - return False - lines.insert(line_num - 1, text + os.linesep) + if "FlatBuffers" not in lines[0]: + return + + lines.insert(1, PACKAGE_DECLARATION + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) - return True - def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -41,15 +39,11 @@ def 
add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - success = add_new_line(full_name, 2, PACKAGE_DECLARATION) - if not success: - raise RuntimeError("Failed to add package declarations, " - "file name is %s" % full_name) + add_package(full_name) if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, - "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c75e2eeef13f..e13dd95f927f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,6 +41,11 @@ guava 27.0.1-jre + + com.google.protobuf + protobuf-java + 3.8.0 + com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 431b48ded58c..17c248ed0a57 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; +import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.ActorCheckpointIdData; -import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.EntryType; -import org.ray.runtime.generated.TablePrefix; +import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; +import org.ray.runtime.generated.Gcs.ClientTableData; +import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; +import org.ray.runtime.generated.Gcs.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.name(TablePrefix.CLIENT); + final String prefix = TablePrefix.CLIENT.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,36 +63,42 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); - final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); + ClientTableData data = null; + try { + data = ClientTableData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + final UniqueId clientId = UniqueId + .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); - if (data.entryType() == EntryType.INSERTION) { + if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. 
Preconditions.checkState( - data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.nodeManagerAddress(), true, resources); + clientId, data.getNodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE) { + } else if (data.getEntryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); } } else { // Code path of node deletion. - Preconditions.checkState(data.entryType() == EntryType.DELETION); + Preconditions.checkState(data.getEntryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -107,7 +113,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); + TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. 
*/ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = - ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( - data.checkpointIdsAsByteBuffer()); + ActorCheckpointIdData data = null; + try { + data = ActorCheckpointIdData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; + for (int i = 0; i < checkpointIds.length; i++) { + checkpointIds[i] = UniqueId + .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); + } for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index f9e310249a35..1a7e4701c22b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.generated.Gcs.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) - .getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) - .getBytes(); + private static final byte[] WORKER_EXCEPTION_META = String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index cadd197ec73f..ba72e96f41db 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,38 +2,39 @@ from __future__ import division from __future__ import print_function -import 
flatbuffers -import ray.core.generated.ErrorTableData - -from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData -from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.DriverTableData import DriverTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsEntry import GcsEntry -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData -from ray.core.generated.Language import Language -from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.TablePrefix import TablePrefix -from ray.core.generated.TablePubsub import TablePubsub - from ray.core.generated.ray.protocol.Task import Task +from ray.core.generated.gcs_pb2 import ( + ActorCheckpointIdData, + ClientTableData, + DriverTableData, + ErrorTableData, + ErrorType, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, + TaskTableData, +) + __all__ = [ "ActorCheckpointIdData", "ClientTableData", "DriverTableData", "ErrorTableData", + "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", - "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", "TablePubsub", "Task", + "TaskTableData", "construct_error_message", ] @@ -42,13 +43,16 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# These prefixes must be kept up-to-date with the TablePrefix enum in +# gcs.proto. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. 
""" - builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id.binary()) - error_type_offset = builder.CreateString(error_type) - message_offset = builder.CreateString(message) - - ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) - ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( - builder, driver_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddType( - builder, error_type_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( - builder, message_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( - builder, timestamp) - error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( - builder) - builder.Finish(error_data_offset) - - return bytes(builder.Output()) + data = ErrorTableData() + data.driver_id = driver_id.binary() + data.type = error_type + data.error_message = message + data.timestamp = timestamp + return data.SerializeToString() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index c9e0424b3eb8..35597ef231e3 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,28 +101,26 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - heartbeat_data = gcs_entries.Entries(0) + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] - message = (ray.gcs_utils.HeartbeatBatchTableData. - GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + message = ray.gcs_utils.HeartbeatBatchTableData.FromString( + heartbeat_data) - for j in range(message.BatchLength()): - heartbeat_message = message.Batch(j) - - num_resources = heartbeat_message.ResourcesTotalLabelLength() + for heartbeat_message in message.batch: + num_resources = len(heartbeat_message.resources_available_label) static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.ResourcesAvailableLabel(i) - static = heartbeat_message.ResourcesTotalLabel(i) + dyn = heartbeat_message.resources_available_label[i] + static = heartbeat_message.resources_total_label[i] dynamic_resources[dyn] = ( - heartbeat_message.ResourcesAvailableCapacity(i)) + heartbeat_message.resources_available_capacity[i]) static_resources[static] = ( - heartbeat_message.ResourcesTotalCapacity(i)) + heartbeat_message.resources_total_capacity[i]) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. 
""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - driver_data = gcs_entries.Entries(0) - message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( - driver_data, 0) - driver_id = message.DriverId() + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + driver_data = gcs_entries.entries[0] + message = ray.gcs_utils.DriverTableData.FromString(driver_data) + driver_id = message.driver_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/state.py b/python/ray/state.py index 14ba49987ec4..35f97cd65f5e 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,11 @@ import ray from ray.function_manager import FunctionDescriptor -import ray.gcs_utils -from ray.ray_constants import ID_SIZE -from ray import services -from ray.core.generated.EntryType import EntryType +from ray import ( + gcs_utils, + services, +) from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +31,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.CLIENT, - "", NIL_CLIENT_ID) + message = redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", + NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,36 +41,31 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) + for entry in gcs_entry.entries: + client = gcs_utils.ClientTableData.FromString(entry) resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) + client.resources_total_label[i]: client.resources_total_capacity[i] + for i in range(len(client.resources_total_label)) } - client_id = ray.utils.binary_to_hex(client.ClientId()) + client_id = ray.utils.binary_to_hex(client.client_id) - if client.EntryType() == EntryType.INSERTION: + if client.entry_type == gcs_utils.ClientTableData.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.EntryType(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), + "EntryType": client.entry_type, + "NodeManagerAddress": client.node_manager_address, + "NodeManagerPort": client.node_manager_port, + "ObjectManagerPort": client.object_manager_port, + "ObjectStoreSocketName": client.object_store_socket_name, + "RayletSocketName": client.raylet_socket_name, "Resources": resources } @@ -79,22 +74,23 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" 
- assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( - "Unexpected updation of deleted client.") + is_deletion = (node_info[client_id]["EntryType"] != + gcs_utils.ClientTableData.DELETION) + assert is_deletion, "Unexpected updation of deleted client." res_map = node_info[client_id]["Resources"] - if client.EntryType() == EntryType.RES_CREATEUPDATE: + if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.EntryType() == EntryType.RES_DELETE: + elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.EntryType() == EntryType.DELETION: + elif client.entry_type == gcs_utils.ClientTableData.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.EntryType())) + client.entry_type)) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.EntryType() + node_info[client_id]["EntryType"] = client.entry_type # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -244,20 +240,19 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, "", - object_id.binary()) + gcs_utils.TablePrefix.Value("OBJECT"), + "", object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) - assert gcs_entry.EntriesLength() > 0 + assert len(gcs_entry.entries) > 0 - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(0), 0) + entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), + "DataSize": entry.object_size, + "Manager": entry.manager, } return object_info @@ -278,10 +273,9 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + - "*") + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -301,17 +295,18 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. 
""" assert isinstance(task_id, ray.TaskID) - message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - "", task_id.binary()) + message = self._execute_command( + task_id, "RAY.TABLE_LOOKUP", + gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - - assert gcs_entries.EntriesLength() == 1 + gcs_entries = gcs_utils.GcsEntry.FromString(message) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) + assert len(gcs_entries.entries) == 1 + task_table_data = gcs_utils.TaskTableData.FromString( + gcs_entries.entries[0]) + task_table_message = gcs_utils.Task.GetRootAsTask( + task_table_data.task, 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -368,9 +363,9 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] @@ -380,27 +375,6 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results - def function_table(self, function_id=None): - """Fetch and parse the function table. - - Returns: - A dictionary that maps function IDs to information about the - function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys( - ray.gcs_utils.FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -423,37 +397,32 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. 
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.PROFILE, "", - batch_id.binary()) + gcs_utils.TablePrefix.Value("PROFILE"), + "", batch_id.binary()) if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) profile_events = [] - for i in range(gcs_entries.EntriesLength()): - profile_table_message = ( - ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( - gcs_entries.Entries(i), 0)) - - component_type = decode(profile_table_message.ComponentType()) - component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode( - profile_table_message.NodeIpAddress(), allow_none=True) + for entry in gcs_entries.entries: + profile_table_message = gcs_utils.ProfileTableData.FromString( + entry) - for j in range(profile_table_message.ProfileEventsLength()): - profile_event_message = profile_table_message.ProfileEvents(j) + component_type = profile_table_message.component_type + component_id = binary_to_hex(profile_table_message.component_id) + node_ip_address = profile_table_message.node_ip_address + for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": decode(profile_event_message.EventType()), + "event_type": profile_event_message.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.StartTime(), - "end_time": profile_event_message.EndTime(), - "extra_data": json.loads( - decode(profile_event_message.ExtraData())), + "start_time": profile_event_message.start_time, + "end_time": profile_event_message.end_time, + "extra_data": json.loads(profile_event_message.extra_data), } profile_events.append(profile_event) @@ -462,10 +431,10 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys( - ray.gcs_utils.TablePrefix_PROFILE_string + "*") + profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + + "*") batch_identifiers_binary = [ - key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -766,7 +735,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. 
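One recurring pattern in this migration is worth spelling out: the flatbuffers code exposed enum values as plain class attributes (TablePrefix.CLIENT, EntryType.INSERTION), whereas the protobuf-generated code goes through the enum wrapper's Value()/Name() helpers for top-level enums and exposes nested enum values as attributes of the containing message class, which is why the comparisons just below can use ClientTableData.DELETION directly. A minimal sketch of the access patterns, assuming the generated module is importable as ray.core.generated.gcs_pb2 (the name mocked in conf.py above) and that the nested enum keeps the EntryType name seen on the Java side:

    # Minimal sketch of the protobuf enum access patterns this patch relies on.
    from ray.core.generated import gcs_pb2

    # Top-level enums are EnumTypeWrapper objects: look values up by name or number.
    prefix = gcs_pb2.TablePrefix.Value("CLIENT")
    assert gcs_pb2.TablePrefix.Name(prefix) == "CLIENT"

    # Enums nested inside a message are exposed as attributes of the message class.
    assert (gcs_pb2.ClientTableData.DELETION ==
            gcs_pb2.ClientTableData.EntryType.Value("DELETION"))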
- if client["EntryType"] != EntryType.DELETION: + if client["EntryType"] != gcs_utils.ClientTableData.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -776,7 +745,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != EntryType.DELETION) + if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) } def available_resources(self): @@ -800,7 +769,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -809,24 +778,23 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. - GetRootAsHeartbeatTableData(heartbeat_data, 0)) + gcs_entries = gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] + message = gcs_utils.HeartbeatTableData.FromString( + heartbeat_data) # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() + num_resources = len(message.resources_available_label) dynamic_resources = {} for i in range(num_resources): - resource_id = decode(message.ResourcesAvailableLabel(i)) + resource_id = message.resources_available_label[i] dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) + message.resources_available_capacity[i]) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) + client_id = ray.utils.binary_to_hex(message.client_id) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -860,23 +828,22 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", driver_id.binary()) # If there are no errors, return early. 
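Because the error path now uses protobuf end to end, the helper rewritten in gcs_utils above and the parsing done in _error_messages below can be checked against each other with a small round trip. This is only an illustrative sketch; the error type string here is an arbitrary label, not a Ray constant:

    # Round-trip sketch for the protobuf ErrorTableData introduced by this patch.
    import time

    import ray
    from ray import gcs_utils

    driver_id = ray.DriverID.nil()
    raw = gcs_utils.construct_error_message(
        driver_id, "example_error", "something went wrong", time.time())

    # Listeners such as listen_error_messages_raylet first unwrap the GcsEntry
    # envelope and then parse each entry exactly like this:
    error_data = gcs_utils.ErrorTableData.FromString(raw)
    assert error_data.driver_id == driver_id.binary()
    assert error_data.type == "example_error"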
if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) error_messages = [] - for i in range(gcs_entries.EntriesLength()): - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entries.Entries(i), 0) - assert driver_id.binary() == error_data.DriverId() + for entry in gcs_entries.entries: + error_data = gcs_utils.ErrorTableData.FromString(entry) + assert driver_id.binary() == error_data.driver_id error_message = { - "type": decode(error_data.Type()), - "message": decode(error_data.ErrorMessage()), - "timestamp": error_data.Timestamp(), + "type": error_data.type, + "message": error_data.error_message, + "timestamp": error_data.timestamp, } error_messages.append(error_message) return error_messages @@ -899,9 +866,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") + gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -923,30 +890,23 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), "", actor_id.binary(), ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - entry = ( - ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( - gcs_entry.Entries(0), 0)) - checkpoint_ids_str = entry.CheckpointIds() - num_checkpoints = len(checkpoint_ids_str) // ID_SIZE - assert len(checkpoint_ids_str) % ID_SIZE == 0 + gcs_entry = gcs_utils.GcsEntry.FromString(message) + entry = gcs_utils.ActorCheckpointIdData.FromString( + gcs_entry.entries[0]) checkpoint_ids = [ - ray.ActorCheckpointID( - checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) - for i in range(num_checkpoints) + ray.ActorCheckpointID(checkpoint_id) + for checkpoint_id in entry.checkpoint_ids ] return { - "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": [ - entry.Timestamps(i) for i in range(num_checkpoints) - ], + "Timestamps": list(entry.timestamps), } diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 703c3a1420ed..76dfd3000b86 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.core.generated.EntryType import EntryType +from ray.gcs_utils import ClientTableData logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == EntryType.INSERTION + if client["EntryType"] == ClientTableData.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7f1f78d1b5c4..6b4bd754cd4d 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - 
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 51b906695c2d..a560e461f7a2 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), + ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, + malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/utils.py b/python/ray/utils.py index 7b87486e325e..0db48e41d025 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, - driver_id.binary(), error_data) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 7505120574a6..710f0db43c6b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.core.generated.ErrorType import ErrorType +from ray.gcs_utils import ErrorType from ray.exceptions import ( RayActorError, RayError, @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.WORKER_DIED: + if error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() - elif error_type == ErrorType.ACTOR_DIED: + elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() - elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. 
error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - msg["data"], 0) - assert gcs_entry.EntriesLength() == 1 - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entry.Entries(0), 0) - driver_id = error_data.DriverId() + gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) + assert len(gcs_entry.entries) == 1 + error_data = ray.gcs_utils.ErrorTableData.FromString( + gcs_entry.entries[0]) + driver_id = error_data.driver_id if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = ray.utils.decode(error_data.ErrorMessage()) - if (ray.utils.decode( - error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + error_message = error_data.error_message + if (error_data.type == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1878,14 +1876,16 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) + task_table_data = ray.gcs_utils.TaskTableData() + task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. - ray.state.state._execute_command(driver_task.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command( + driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), + ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), + driver_task.task_id().binary(), + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/python/setup.py b/python/setup.py index db8676042de9..e7cf14737ee2 100644 --- a/python/setup.py +++ b/python/setup.py @@ -150,6 +150,7 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", + "protobuf", ] setup( diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c9b1e138575d..6de29bb52764 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,10 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } -FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } - -ClassTable &AsyncGcsClient::class_table() { return *class_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index c9f5b4bca624..5e70025b39a0 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,11 +44,7 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. 
Status Attach(boost::asio::io_service &io_service); - inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver - inline ClassTable &class_table(); - inline CustomSerializerTable &custom_serializer_table(); - inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -81,8 +77,6 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: - std::unique_ptr function_table_; - std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c7dc02e50651..55115b1e2067 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->task_specification = "123"; + auto data = std::make_shared(); + data->set_task("123"); // Check that we added the correct task. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { RAY_CHECK(false); }; + const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. 
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); data_log.push_back(data); } // Check that we added the correct task. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id); + appended_managers.push_back(entry.node_manager_id()); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. 
auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { ASSERT_TRUE(false); }; + const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. 
auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->node_manager_id = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_node_manager_id(ObjectID::FromRandom().Hex()); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::FromRandom().Hex(); + auto task_data = std::make_shared(); + task_data->set_task(ObjectID::FromRandom().Hex()); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_manager(ObjectID::FromRandom().Hex()); object_vector.push_back(data); } }; @@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } -// Task table callbacks. 
-void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); -} - -void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data, bool do_stop) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); - if (do_stop) { - test->Stop(); - } -} -void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/false); -} -void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/true); -} - -void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); -} - -void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); - test->Stop(); -} - -void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - test->Stop(); -} - void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. 
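The test hunks above all follow the same migration pattern: flatbuffer object-API structs (e.g. ObjectTableDataT with direct field assignment) are replaced by protobuf messages with generated set_*/getter accessors and string serialization. Below is a minimal illustrative sketch of that pattern, not code taken from the repository; it assumes the generated header ray/protobuf/gcs.pb.h and the ray::rpc namespace that other hunks in this patch introduce.

// Sketch of the flatbuffers -> protobuf accessor migration (illustrative only).
// Assumes ray/protobuf/gcs.pb.h (generated from gcs.proto) and ray::rpc,
// as used by the surrounding hunks.
#include <memory>
#include <string>
#include "ray/protobuf/gcs.pb.h"

int main() {
  // Before: auto data = std::make_shared<ObjectTableDataT>(); data->manager = "abc";
  // After: protobuf messages use generated setters and getters.
  auto data = std::make_shared<ray::rpc::ObjectTableData>();
  data->set_manager("abc");
  const std::string wire = data->SerializeAsString();  // wire format stored in the GCS

  ray::rpc::ObjectTableData parsed;
  parsed.ParseFromString(wire);  // replaces the GetRootAs* flatbuffer reads
  return parsed.manager() == "abc" ? 0 : 1;
}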
@@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->driver_id = driver_ids1[0]; + auto data1 = std::make_shared(); + data1->set_driver_id(driver_ids1[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->driver_id = driver_ids2[0]; + auto data2 = std::make_shared(); + data2->set_driver_id(driver_ids2[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. 
auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->manager = managers1[0]; + auto data1 = std::make_shared(); + data1->set_manager(managers1[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->manager = managers2[0]; + auto data2 = std::make_shared(); + data2->set_manager(managers2[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. 
TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->task_specification = task_specs[0]; + auto data = std::make_shared(); + data->set_task(task_specs[0]); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task_specification, task_specs.front()); + ASSERT_EQ(data.task(), task_specs.front()); } else { - ASSERT_EQ(data.task_specification, task_specs.back()); + ASSERT_EQ(data.task(), task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->driver_id = driver_ids[0]; + auto data = std::make_shared(); + data->set_driver_id(driver_ids[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. 
auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->driver_id = remaining_driver_id; + auto data = std::make_shared(); + data->set_driver_id(remaining_driver_id); RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->manager = managers[0]; + auto data = std::make_shared(); + data->set_manager(managers[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager, managers[0]); + ASSERT_EQ(data[0].manager(), managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager); + data_managers_set.insert(entry.manager()); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. 
We should receive a notification for the @@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableDataT &data, bool is_insertion) { + const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); - ClientTableDataT cached_client; + ClientTableData cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); + ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. 
- ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + const ClientTableData &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); }); test->Start(); @@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. 
- auto cpu_data = std::make_shared(); - cpu_data->resource_name = "CPU"; - cpu_data->resource_capacity = 100; - auto gpu_data = std::make_shared(); - gpu_data->resource_name = "GPU"; - gpu_data->resource_capacity = 2; + auto cpu_data = std::make_shared(); + cpu_data->set_resource_name("CPU"); + cpu_data->set_resource_capacity(100); + auto gpu_data = std::make_shared(); + gpu_data->set_resource_name("GPU"); + gpu_data->set_resource_capacity(2); DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->resource_name = "CPU"; - data_cpu->resource_capacity = 50; - auto data_gpu = std::make_shared(); - data_gpu->resource_name = "GPU"; - data_gpu->resource_capacity = 10; - auto data_custom = std::make_shared(); - data_custom->resource_name = "CUSTOM"; - data_custom->resource_capacity = 2; + auto data_cpu = std::make_shared(); + data_cpu->set_resource_name("CPU"); + data_cpu->set_resource_capacity(50); + auto data_gpu = std::make_shared(); + data_gpu->set_resource_name("GPU"); + data_gpu->set_resource_capacity(10); + auto data_custom = std::make_shared(); + data_custom->set_resource_name("CUSTOM"); + data_custom->set_resource_capacity(2); DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->resource_capacity = 50; + data_map2["CPU"]->set_resource_capacity(50); // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name, data.second->resource_name); - ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); + ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 90476da73425..c06c79a02928 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,52 +1,9 @@ -enum Language:int { - PYTHON = 0, - CPP = 1, - JAVA = 2 -} - -// These indexes are mapped to strings in ray_redis_module.cc. -enum TablePrefix:int { - UNUSED = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - FUNCTION, - TASK_RECONSTRUCTION, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - DRIVER, - PROFILE, - TASK_LEASE, - ACTOR_CHECKPOINT, - ACTOR_CHECKPOINT_ID, - NODE_RESOURCE, -} +// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). -// The channel that Add operations to the Table should be published on, if any. 
-enum TablePubsub:int { - NO_PUBLISH = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - TASK_LEASE, - DRIVER, - NODE_RESOURCE, -} - -// Enum for the entry type in the ClientTable -enum EntryType:int { - INSERTION = 0, - DELETION, - RES_CREATEUPDATE, - RES_DELETE, +enum Language:int { + PYTHON=0, + JAVA=1, + CPP=2, } table Arg { @@ -120,118 +77,6 @@ table ResourcePair { value: double; } -enum GcsChangeMode:int { - APPEND_OR_ADD = 0, - REMOVE, -} - -table GcsEntry { - change_mode: GcsChangeMode; - id: string; - entries: [string]; -} - -table FunctionTableData { - language: Language; - name: string; - data: string; -} - -table ObjectTableData { - // The size of the object. - object_size: long; - // The node manager ID that this object appeared on or was evicted by. - manager: string; -} - -table TaskReconstructionData { - // The number of times this task has been reconstructed so far. - num_reconstructions: int; - // The node manager that is trying to reconstruct the task. - node_manager_id: string; -} - -enum SchedulingState:int { - NONE = 0, - WAITING = 1, - SCHEDULED = 2, - QUEUED = 4, - RUNNING = 8, - DONE = 16, - LOST = 32, - RECONSTRUCTING = 64 -} - -table TaskTableData { - // The state of the task. - scheduling_state: SchedulingState; - // A raylet ID. - raylet_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // The number of times the task was spilled back by raylets. - spillback_count: long; - // A string of bytes representing the task specification. - task_info: string; - // TODO(pcm): This is at the moment duplicated in task_info, remove that one - updated: bool; -} - -table TaskTableTestAndUpdate { - test_raylet_id: string; - test_state_bitmask: SchedulingState; - update_state: SchedulingState; -} - -table ClassTableData { -} - -enum ActorState:int { - // Actor is alive. - ALIVE = 0, - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1, - // Actor is already dead and won't be reconstructed. - DEAD = 2 -} - -table ActorTableData { - // The ID of the actor that was created. - actor_id: string; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - actor_creation_dummy_object_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the node manager that created the actor. - node_manager_id: string; - // Current state of this actor. - state: ActorState; - // Max number of times this actor should be reconstructed. - max_reconstructions: int; - // Remaining number of reconstructions. - remaining_reconstructions: int; -} - -table ErrorTableData { - // The ID of the driver that the error is for. - driver_id: string; - // The type of the error. - type: string; - // The error message. - error_message: string; - // The timestamp of the error message. - timestamp: double; -} - -table CustomSerializerData { -} - -table ConfigTableData { -} - table ProfileEvent { // The type of the event. event_type: string; @@ -258,119 +103,3 @@ table ProfileTableData { // we don't want each event to require a GCS command. profile_events: [ProfileEvent]; } - -table RayResource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. 
- resource_capacity: double; -} - -table ClientTableData { - // The client ID of the client that the message is about. - client_id: string; - // The IP address of the client's node manager. - node_manager_address: string; - // The IPC socket name of the client's raylet. - raylet_socket_name: string; - // The IPC socket name of the client's plasma store. - object_store_socket_name: string; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - node_manager_port: int; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - object_manager_port: int; - // Enum to store the entry type in the log - entry_type: EntryType = INSERTION; - resources_total_label: [string]; - resources_total_capacity: [double]; -} - -table HeartbeatTableData { - // Node manager client id - client_id: string; - // Resource capacity currently available on this node manager. - resources_available_label: [string]; - resources_available_capacity: [double]; - // Total resource capacity configured for this node manager. - resources_total_label: [string]; - resources_total_capacity: [double]; - // Aggregate outstanding resource load on this node manager. - resource_load_label: [string]; - resource_load_capacity: [double]; -} - -table HeartbeatBatchTableData { - batch: [HeartbeatTableData]; -} - -// Data for a lease on task execution. -table TaskLeaseData { - // Node manager client ID. - node_manager_id: string; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - acquired_at: long; - // The period that the lease is active for. - timeout: long; -} - -table DriverTableData { - // The driver ID. - driver_id: string; - // Whether it's dead. - is_dead: bool; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -table ActorCheckpointData { - // ID of this actor. - actor_id: string; - // The dummy object ID of actor's most recently executed task. - execution_dependency: string; - // A list of IDs of this actor's handles. - handle_ids: [string]; - // The task counters of the above handles. - task_counters: [long]; - // The frontier dependencies of the above handles. - frontier_dependencies: [string]; - // A list of unreleased dummy objects from this actor. - unreleased_dummy_objects: [string]; - // The numbers of dependencies for the above unreleased dummy objects. - num_dummy_object_dependencies: [int]; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -table ActorCheckpointIdData { - // ID of this actor. - actor_id: string; - // IDs of this actor's available checkpoints. - // Note, this is a long string that concatenates all the IDs. - checkpoint_ids: string; - // A list of the timestamps for each of the above `checkpoint_ids`. - timestamps: [long]; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. 
-enum ErrorType:int { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 1, - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 2, - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 3, -} diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index fc42e5cd98c2..093aab2455d9 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -25,6 +25,9 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; + /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -126,8 +129,8 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, - int64_t length, const TablePrefix prefix, + Status RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -157,9 +160,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e291b7ffdb32..c3a82c320d06 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,11 +5,16 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; +using ray::rpc::GcsChangeMode; +using ray::rpc::GcsEntry; +using ray::rpc::TablePrefix; +using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. 
Specifically, via @@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long > static_cast(TablePubsub::MAX) || - pubsub_channel_long < static_cast(TablePubsub::MIN)) { + if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || + pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long > static_cast(TablePrefix::MAX) || - table_prefix_long < static_cast(TablePrefix::MIN)) { + if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || + table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); + return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = - OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); + *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, + mutated_key_str); return Status::OK(); } @@ -165,18 +170,24 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// This is a helper method to convert a redis module string to a flatbuffer -/// string. +/// A helper function that creates `GcsEntry` protobuf object. /// -/// \param fbb The flatbuffer builder. -/// \param redis_string The redis string. -/// \return The flatbuffer string. -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); +/// \param[in] id Id of the entry. +/// \param[in] change_mode Change mode of the entry. +/// \param[in] entries Vector of entries. +/// \param[out] result The created `GcsEntry` object. 
+inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, + const std::vector &entries, + GcsEntry *result) { + const char *data; + size_t size; + data = RedisModule_StringPtrLen(id, &size); + result->set_id(data, size); + result->set_change_mode(change_mode); + for (const auto &entry : entries) { + data = RedisModule_StringPtrLen(entry, &size); + result->add_entries(data, size); + } } /// Helper method to publish formatted data to target channel. @@ -234,13 +245,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - auto data_buffer = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + CreateGcsEntry(id, change_mode, {data}, &gcs_entry); + std::string str = gcs_entry.SerializeAsString(); + auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -570,19 +578,20 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); - auto data_vec = flatbuffers::GetRoot(update_data_buf); - *change_mode = data_vec->change_mode(); + GcsEntry gcs_entry; + gcs_entry.ParseFromArray(update_data_buf, update_data_len); + *change_mode = gcs_entry.change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. - size_t total_size = data_vec->entries()->size(); + size_t total_size = gcs_entry.entries_size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); - RedisModuleString *entry_value = - RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), - data_vec->entries()->Get(i + 1)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + RedisModuleString *entry_value = RedisModule_CreateString( + ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); // Returning 0 if key exists(still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -590,27 +599,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries. 
- size_t total_size = data_vec->entries()->size(); - flatbuffers::FlatBufferBuilder fbb; - std::vector> data; + GcsEntry updated; + updated.set_id(gcs_entry.id()); + updated.set_change_mode(gcs_entry.change_mode()); + + size_t total_size = gcs_entry.entries_size(); for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), - data_vec->entries()->Get(i)->size())); + updated.add_entries(gcs_entry.entries(i)); } } - auto message = - CreateGcsEntry(fbb, data_vec->change_mode(), - fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), - fbb.CreateVector(data)); - fbb.Finish(message); - *changed_data = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + + // Serialize updated data. + std::string str = updated.SerializeAsString(); + *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -631,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry flatbugger data used to update this hash table. +/// \param data The GcsEntry protobuf data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -648,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create and finish a GcsEntry, based on the +/// A helper function to create a GcsEntry protobuf, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -658,21 +665,18 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsEntry. -Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { +/// \param[out] gcs_entry The created GcsEntry. +Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, + RedisModuleString *prefix_str, RedisModuleString *entry_id, + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. + // Build the GcsEntry from the string data. 
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); - fbb.Finish(message); + gcs_entry->add_entries(data_buf, data_len); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -696,27 +700,20 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the flatbuffer from the set of log entries. + // Build the GcsEntry from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - std::vector> data; + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - data.push_back(fbb.CreateString(element_str, len)); + gcs_entry->add_entries(element_str, len); } - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); - fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsEntry( - fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(std::vector>())); - fbb.Finish(message); + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -752,11 +749,12 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + // Serialize the data to a GcsEntry to return to the client. + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } return REDISMODULE_OK; } @@ -870,10 +868,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. 
- flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); return RedisModule_ReplyWithNull(ctx); } @@ -940,53 +939,6 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = - flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); - if (!is_nil_result) { - do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); - } - - if (do_update) { - REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), - "mutate_scheduling_state failed"); - } - REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -1016,7 +968,6 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); -AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1082,12 +1033,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 33f1615580a6..b7c19ebfd595 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,6 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" +#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -39,48 +40,44 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, const ID 
&id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } } else { if (failure != nullptr) { - (failure)(client_, id, *dataT); + (failure)(client_, id, *data); } } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback), log_length); } template @@ -89,16 +86,15 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; + std::vector results; if (!reply.IsNil()) { - const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data data; + data.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(data)); } } lookup(client_, id, results); @@ -115,7 +111,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; @@ -141,19 +137,16 @@ Status Log::Subscribe(const DriverID 
&driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + std::vector results; + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data result; + result.ParseFromString(gcs_entry.entries(i)); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->change_mode(), results); + subscribe(client_, id, gcs_entry.change_mode(), results); } } }; @@ -234,19 +227,17 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -255,7 +246,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -277,7 +268,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -299,36 +290,30 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const 
WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -348,26 +333,16 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(data_map.size() * 2); - for (auto const &pair : data_map) { - // Add the key. - data_vec.push_back(fbb.CreateString(pair.first)); - flatbuffers::FlatBufferBuilder fbb_data; - fbb_data.ForceDefaults(true); - fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); - std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), - fbb_data.GetSize()); - // Add the value. - data_vec.push_back(fbb.CreateString(data)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); + for (const auto &pair : data_map) { + gcs_entry.add_entries(pair.first); + gcs_entry.add_entries(pair.second->SerializeAsString()); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -380,19 +355,15 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(keys.size()); - // Add the keys. 
- for (auto const &key : keys) { - data_vec.push_back(fbb.CreateString(key)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::REMOVE); + for (const auto &key : keys) { + gcs_entry.add_entries(key); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), - fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -412,17 +383,15 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - results.emplace(key, std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + results.emplace(key, std::move(value)); } } lookup(client_, id, results); @@ -451,31 +420,24 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. 
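The notification parsed here, like the lookup reply handled in Log::Lookup above, is a two-level encoding: the outer string is a GcsEntry, and each element of its `entries` field is itself a serialized message of the table's Data type. A sketch of the client-side decode, assuming the generated header added in tables.h below; ObjectTableData is picked arbitrarily as the inner type and DecodeNotification is a hypothetical helper:

    #include <string>
    #include <vector>
    #include "ray/protobuf/gcs.pb.h"

    using ray::rpc::GcsEntry;
    using ray::rpc::ObjectTableData;  // stand-in for any table's Data type

    // Decode a lookup reply or subscription notification: the outer payload is
    // a GcsEntry, and every element of `entries` is a serialized Data message.
    std::vector<ObjectTableData> DecodeNotification(const std::string &payload) {
      GcsEntry gcs_entry;
      gcs_entry.ParseFromString(payload);
      std::vector<ObjectTableData> results;
      for (int i = 0; i < gcs_entry.entries_size(); i++) {
        ObjectTableData data;
        data.ParseFromString(gcs_entry.entries(i));
        results.emplace_back(std::move(data));
      }
      return results;
    }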
- auto root = flatbuffers::GetRoot(data.data()); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); DataMap data_map; - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - if (root->change_mode() == GcsChangeMode::REMOVE) { - for (size_t i = 0; i < root->entries()->size(); i++) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - data_map.emplace(key, std::shared_ptr()); + if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { + for (const auto &key : gcs_entry.entries()) { + data_map.emplace(key, std::shared_ptr()); } } else { - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - data_map.emplace(key, std::move(result)); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + data_map.emplace(key, std::move(value)); } } - subscribe(client_, id, root->change_mode(), data_map); + subscribe(client_, id, gcs_entry.change_mode(), data_map); } } }; @@ -490,11 +452,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->type = type; - data->error_message = error_message; - data->timestamp = timestamp; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_type(type); + data->set_error_message(error_message); + data->set_timestamp(timestamp); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -503,11 +465,9 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - auto data = std::make_shared(); - // There is some room for optimization here because the Append function will just - // call "Pack" and undo the "UnPack". - profile_events.UnPackTo(data.get()); - + // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. + auto data = std::make_shared(); + data->CopyFrom(profile_events); return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -517,9 +477,9 @@ std::string ProfileTable::DebugString() const { } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->is_dead = is_dead; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -527,7 +487,8 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. 
for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && + (entry.second.entry_type() == ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -537,7 +498,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -549,7 +510,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -559,15 +520,16 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && + entry.second.entry_type() == ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableDataT &data) { - ClientID client_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -578,16 +540,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); - bool is_deleted = (data.entry_type == EntryType::DELETION); - bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == ClientTableData::DELETION); + bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. 
- if (entry->second.entry_type == EntryType::DELETION) { - RAY_CHECK((data.entry_type == EntryType::DELETION)) + if (entry->second.entry_type() == ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -595,64 +557,64 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type == EntryType::INSERTION) || - (data.entry_type == EntryType::DELETION)) { + if ((data.entry_type() == ClientTableData::INSERTION) || + (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)) { + } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { - auto const &resource_name = data.resources_total_label[i]; - auto const &capacity = data.resources_total_capacity[i]; + for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { + auto const &resource_name = data.resources_total_label(i); + auto const &capacity = data.resources_total_capacity(i); // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label.begin(), - cache_data.resources_total_label.end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label.end()) { - auto index = std::distance(cache_data.resources_total_label.begin(), + std::find(cache_data.resources_total_label().begin(), + cache_data.resources_total_label().end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label().end()) { + auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_capacity[index] = capacity; + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.set_resources_total_capacity(index, capacity); } // .. delete if deletion call. 
- else if (data.entry_type == EntryType::RES_DELETE) { - cache_data.resources_total_label.erase( - cache_data.resources_total_label.begin() + index); - cache_data.resources_total_capacity.erase( - cache_data.resources_total_capacity.begin() + index); + else if (data.entry_type() == ClientTableData::RES_DELETE) { + cache_data.mutable_resources_total_label()->erase( + cache_data.resources_total_label().begin() + index); + cache_data.mutable_resources_total_capacity()->erase( + cache_data.resources_total_capacity().begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_label.push_back(resource_name); - cache_data.resources_total_capacity.push_back(capacity); + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.add_resources_total_label(resource_name); + cache_data.add_resources_total_capacity(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type == EntryType::INSERTION) { + if (data.entry_type() == ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type == EntryType::DELETION) { + } else if (data.entry_type() == ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -660,11 +622,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_DELETE) { + } else if (data.entry_type() == ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -672,54 +634,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableDataT &local_client) { +Status ClientTable::Connect(const ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id == 
local_client_.client_id); + RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::INSERTION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableDataT &data) { + const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); + if (notification.entry_type() != ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); } else { - auto iter = connected_nodes.find(notification.client_id); + auto iter = connected_nodes.find(notification.client_id()); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id, notification); + disconnected_nodes.emplace(notification.client_id(), notification); } } for (const auto &pair : connected_nodes) { @@ -742,10 +704,10 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -759,24 +721,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->client_id = dead_client_id.Binary(); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(); + data->set_client_id(dead_client_id.Binary()); + data->set_entry_type(ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableDataT &client_info) const { + ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::Nil().Binary(); + client_info.set_client_id(ClientID::Nil().Binary()); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ 
-798,31 +760,29 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { - std::shared_ptr copy = - std::make_shared(data); - copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.Binary(); + const ActorCheckpointIdData &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->add_timestamps(current_sys_time_ms()); + copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps.size() > num_to_keep) { + while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &checkpoint_id = - ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " - << actor_id; - copy->timestamps.erase(copy->timestamps.begin()); - copy->checkpoint_ids.erase(0, kUniqueIDSize); - client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); + const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; + copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); + copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); + client_->actor_checkpoint_table().Delete(driver_id, to_delete); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); - data->actor_id = id.Binary(); - data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.Binary(); + std::shared_ptr data = + std::make_shared(); + data->set_actor_id(id.Binary()); + data->add_timestamps(current_sys_time_ms()); + *data->add_checkpoint_ids() = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); @@ -830,8 +790,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -template class Log; -template class Table; +template class Log; template class Table; template class Log; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 6a1d502a7f54..2ecc3440839e 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,10 +11,8 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -// TODO(rkn): Remove this include. 
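One detail worth spelling out from the AddCheckpointId hunk above: checkpoint ids and timestamps are now parallel repeated fields on ActorCheckpointIdData, so trimming to the newest N checkpoints simply erases from the front of both. A sketch using the field names from that hunk; TrimCheckpoints itself is a hypothetical helper:

    #include "ray/protobuf/gcs.pb.h"

    using ray::rpc::ActorCheckpointIdData;

    // Drop the oldest checkpoints until only `num_to_keep` remain. The two
    // repeated fields are parallel arrays, so both erase from the front.
    void TrimCheckpoints(ActorCheckpointIdData *data, int num_to_keep) {
      while (data->timestamps_size() > num_to_keep) {
        data->mutable_checkpoint_ids()->erase(data->mutable_checkpoint_ids()->begin());
        data->mutable_timestamps()->erase(data->mutable_timestamps()->begin());
      }
    }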
-#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -22,6 +20,25 @@ namespace ray { namespace gcs { +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorTableData; +using rpc::GcsChangeMode; +using rpc::GcsEntry; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; +using rpc::ObjectTableData; +using rpc::ProfileTableData; +using rpc::RayResource; +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::TaskLeaseData; +using rpc::TaskReconstructionData; +using rpc::TaskTableData; + class RedisContext; class AsyncGcsClient; @@ -48,13 +65,12 @@ class PubsubInterface { template class LogInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -72,12 +88,11 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + const std::vector &data)>; + using NotificationCallback = + std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -86,7 +101,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -111,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -126,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. 
/// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -259,10 +274,9 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -280,9 +294,8 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -305,7 +318,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -369,12 +382,11 @@ class Table : private Log, template class SetInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -392,7 +404,6 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -414,7 +425,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -425,7 +436,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -454,8 +465,7 @@ class Set : private Log, template class HashInterface { public: - using DataT = typename Data::NativeTableType; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. 
using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -544,8 +554,7 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -595,7 +604,7 @@ class DynamicResourceTable : public Hash { DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE; + pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; prefix_ = TablePrefix::NODE_RESOURCE; }; @@ -607,7 +616,7 @@ class ObjectTable : public Set { ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT; + pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; prefix_ = TablePrefix::OBJECT; }; @@ -619,7 +628,7 @@ class HeartbeatTable : public Table { HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT; + pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} @@ -630,7 +639,7 @@ class HeartbeatBatchTable : public Table { HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} @@ -641,7 +650,7 @@ class DriverTable : public Log { DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER; + pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; prefix_ = TablePrefix::DRIVER; }; @@ -655,18 +664,6 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { - public: - FunctionTable(const std::vector> &contexts, - AsyncGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::NO_PUBLISH; - prefix_ = TablePrefix::FUNCTION; - }; -}; - -using ClassTable = Table; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -677,7 +674,7 @@ class ActorTable : public Log { ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR; + pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; prefix_ = TablePrefix::ACTOR; } }; @@ -696,12 +693,12 @@ class TaskLeaseTable : public Table { TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE; + pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. 
It's okay if this command fails // since the lease entry itself contains the expiration period. In the @@ -709,9 +706,8 @@ class TaskLeaseTable : public Table { // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.Binary(), - std::to_string(data->timeout)}; + std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), + std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } @@ -747,12 +743,12 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK; + pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; prefix_ = TablePrefix::RAYLET_TASK; } @@ -770,7 +766,7 @@ class ErrorTable : private Log { ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO; + pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; prefix_ = TablePrefix::ERROR_INFO; }; @@ -815,10 +811,6 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; - -using ConfigTable = Table; - /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -831,7 +823,7 @@ using ConfigTable = Table; class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -842,11 +834,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT; + pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.Binary(); + local_client_.set_client_id(client_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -855,7 +847,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableDataT &local_client); + ray::Status Connect(const ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -898,7 +890,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableDataT &client_info) const; + void GetClient(const ClientID &client, ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -908,7 +900,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableDataT &GetLocalClient() const; + const ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. 
/// @@ -919,7 +911,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -940,15 +932,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableDataT local_client_; + ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -958,7 +950,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 5b6794a505d3..454379d18302 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,18 +8,22 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::ClientTableData; +using ray::rpc::GcsChangeMode; +using ray::rpc::ObjectTableData; + /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -42,7 +46,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -79,9 +83,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. 
-  auto data = std::make_shared<ObjectTableDataT>();
-  data->manager = client_id.Binary();
-  data->object_size = object_info.data_size;
+  auto data = std::make_shared<ObjectTableData>();
+  data->set_manager(client_id.Binary());
+  data->set_object_size(object_info.data_size);
   ray::Status status =
       gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr);
   return status;
@@ -92,9 +96,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved(
     const object_manager::protocol::ObjectInfoT &object_info) {
   RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id;
   // Append the eviction entry to the object table.
-  auto data = std::make_shared<ObjectTableDataT>();
-  data->manager = client_id.Binary();
-  data->object_size = object_info.data_size;
+  auto data = std::make_shared<ObjectTableData>();
+  data->set_manager(client_id.Binary());
+  data->set_object_size(object_info.data_size);
   ray::Status status =
       gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr);
   return status;
@@ -102,14 +106,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved(
 void ObjectDirectory::LookupRemoteConnectionInfo(
     RemoteConnectionInfo &connection_info) const {
-  ClientTableDataT client_data;
+  ClientTableData client_data;
   gcs_client_->client_table().GetClient(connection_info.client_id, client_data);
-  ClientID result_client_id = ClientID::FromBinary(client_data.client_id);
+  ClientID result_client_id = ClientID::FromBinary(client_data.client_id());
   if (!result_client_id.IsNil()) {
     RAY_CHECK(result_client_id == connection_info.client_id);
-    if (client_data.entry_type == EntryType::INSERTION) {
-      connection_info.ip = client_data.node_manager_address;
-      connection_info.port = static_cast<uint16_t>(client_data.object_manager_port);
+    if (client_data.entry_type() == ClientTableData::INSERTION) {
+      connection_info.ip = client_data.node_manager_address();
+      connection_info.port = static_cast<uint16_t>(client_data.object_manager_port());
     }
   }
 }
@@ -208,7 +212,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
   status = gcs_client_->object_table().Lookup(
       DriverID::Nil(), object_id,
       [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
-                       const std::vector<ObjectTableDataT> &location_updates) {
+                       const std::vector<ObjectTableData> &location_updates) {
         // Build the set of current locations based on the entries in the log.
         std::unordered_set<ClientID> client_ids;
         UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates,
diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc
index 954162c21aef..964cee605ced 100644
--- a/src/ray/object_manager/object_manager.cc
+++ b/src/ray/object_manager/object_manager.cc
@@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id,
     // TODO(rkn): What do we want to do if the send failed?
   }
 
-  ProfileEventT profile_event;
-  profile_event.event_type = "transfer_send";
-  profile_event.start_time = start_time;
-  profile_event.end_time = end_time;
+  rpc::ProfileTableData::ProfileEvent profile_event;
+  profile_event.set_event_type("transfer_send");
+  profile_event.set_start_time(start_time);
+  profile_event.set_end_time(end_time);
   // Encode the object ID, client ID, chunk index, and status as a json list,
   // which will be parsed by the reader of the profile table.
- profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_receive"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_receive"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -801,11 +801,12 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEventT profile_event; - profile_event.event_type = "receive_pull_request"; - profile_event.start_time = current_sys_time_seconds(); - profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("receive_pull_request"); + profile_event.set_start_time(current_sys_time_seconds()); + profile_event.set_end_time(profile_event.start_time()); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -938,13 +939,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { - ProfileTableDataT profile_info; - profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.Binary(); +rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + rpc::ProfileTableData profile_info; + profile_info.set_component_type("object_manager"); + profile_info.set_component_id(client_id_.Binary()); for (auto const &profile_event : profile_events_) { - profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + profile_info.add_profile_events()->CopyFrom(profile_event); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6318250ae3e8..6664dd0a93bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -180,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableDataT GetAndResetProfilingInfo(); + rpc::ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. 
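The profiling hunks above replace the ProfileEventT struct with the nested rpc::ProfileTableData::ProfileEvent message, and GetAndResetProfilingInfo() now copies each buffered event into the batch with add_profile_events()->CopyFrom(...). A standalone sketch of that batching pattern, assuming the field names from gcs.proto later in this patch; MakeProfileBatch is an illustrative name, not a function in the tree:

    #include <string>
    #include <vector>
    #include "ray/protobuf/gcs.pb.h"

    ray::rpc::ProfileTableData MakeProfileBatch(
        const std::vector<ray::rpc::ProfileTableData::ProfileEvent> &buffered_events,
        const std::string &component_id_binary) {
      ray::rpc::ProfileTableData batch;
      batch.set_component_type("object_manager");
      batch.set_component_id(component_id_binary);
      // add_profile_events() appends an empty ProfileEvent slot; CopyFrom() fills it.
      for (const auto &event : buffered_events) {
        batch.add_profile_events()->CopyFrom(event);
      }
      return batch;
    }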
/// @@ -412,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 55aa59124a99..2d5292842acf 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,6 +11,8 @@ namespace ray { +using rpc::ClientTableData; + std::string store_executable; static inline void flushall_redis(void) { @@ -52,10 +54,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -242,8 +244,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -438,16 +440,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" - << "ClientIp=" << data.node_manager_address << "\n" - << "ClientPort=" << data.node_manager_port; - ClientTableDataT data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" + << "ClientIp=" << data.node_manager_address() << "\n" + << "ClientPort=" << data.node_manager_port(); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" - << "ClientIp=" << data2.node_manager_address << "\n" - << "ClientPort=" << data2.node_manager_port; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" + << "ClientIp=" << data2.node_manager_address() << "\n" + << "ClientPort=" << data2.node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index ee6c78d8ed42..45b80a267f2f 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,6 +14,8 @@ int64_t wait_timeout_ms; namespace ray { +using rpc::ClientTableData; + static inline void flushall_redis(void) { redisContext *context = 
redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -46,10 +48,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -221,8 +223,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -457,19 +459,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); - ClientTableDataT data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto new file mode 100644 index 000000000000..d0b2c5e007fe --- /dev/null +++ b/src/ray/protobuf/gcs.proto @@ -0,0 +1,280 @@ +syntax = "proto3"; + +package ray.rpc; + +option java_package = "org.ray.runtime.generated"; + +// Language of a worker or task. +enum Language { + PYTHON = 0; + CPP = 1; + JAVA = 2; +} + +// These indexes are mapped to strings in ray_redis_module.cc. 
+enum TablePrefix { + TABLE_PREFIX_MIN = 0; + UNUSED = 1; + TASK = 2; + RAYLET_TASK = 3; + CLIENT = 4; + OBJECT = 5; + ACTOR = 6; + FUNCTION = 7; + TASK_RECONSTRUCTION = 8; + HEARTBEAT = 9; + HEARTBEAT_BATCH = 10; + ERROR_INFO = 11; + DRIVER = 12; + PROFILE = 13; + TASK_LEASE = 14; + ACTOR_CHECKPOINT = 15; + ACTOR_CHECKPOINT_ID = 16; + NODE_RESOURCE = 17; + TABLE_PREFIX_MAX = 18; +} + +// The channel that Add operations to the Table should be published on, if any. +enum TablePubsub { + TABLE_PUBSUB_MIN = 0; + NO_PUBLISH = 1; + TASK_PUBSUB = 2; + RAYLET_TASK_PUBSUB = 3; + CLIENT_PUBSUB = 4; + OBJECT_PUBSUB = 5; + ACTOR_PUBSUB = 6; + HEARTBEAT_PUBSUB = 7; + HEARTBEAT_BATCH_PUBSUB = 8; + ERROR_INFO_PUBSUB = 9; + TASK_LEASE_PUBSUB = 10; + DRIVER_PUBSUB = 11; + NODE_RESOURCE_PUBSUB = 12; + TABLE_PUBSUB_MAX = 13; +} + +enum GcsChangeMode { + APPEND_OR_ADD = 0; + REMOVE = 1; +} + +message GcsEntry { + GcsChangeMode change_mode = 1; + bytes id = 2; + repeated bytes entries = 3; +} + +message ObjectTableData { + // The size of the object. + uint64 object_size = 1; + // The node manager ID that this object appeared on or was evicted by. + bytes manager = 2; +} + +message TaskReconstructionData { + // The number of times this task has been reconstructed so far. + uint64 num_reconstructions = 1; + // The node manager that is trying to reconstruct the task. + bytes node_manager_id = 2; +} + +// TODO(hchen): Task table currently still uses flatbuffers-defined data structure +// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should +// be migrated to protobuf very soon. +message TaskTableData { + // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. + bytes task = 1; +} + +message ActorTableData { + // State of an actor. + enum ActorState { + // Actor is alive. + ALIVE = 0; + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1; + // Actor is already dead and won't be reconstructed. + DEAD = 2; + } + // The ID of the actor that was created. + bytes actor_id = 1; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + bytes actor_creation_dummy_object_id = 2; + // The ID of the driver that created the actor. + bytes driver_id = 3; + // The ID of the node manager that created the actor. + bytes node_manager_id = 4; + // Current state of this actor. + ActorState state = 5; + // Max number of times this actor should be reconstructed. + uint64 max_reconstructions = 6; + // Remaining number of reconstructions. + uint64 remaining_reconstructions = 7; +} + +message ErrorTableData { + // The ID of the driver that the error is for. + bytes driver_id = 1; + // The type of the error. + string type = 2; + // The error message. + string error_message = 3; + // The timestamp of the error message. + double timestamp = 4; +} + +message ProfileTableData { + // Represents a profile event. + message ProfileEvent { + // The type of the event. + string event_type = 1; + // The start time of the event. + double start_time = 2; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + double end_time = 3; + // Additional data associated with the event. This data must be serialized + // using JSON. 
+ string extra_data = 4; + } + + // The type of the component that generated the event, e.g., worker or + // object_manager, or node_manager. + string component_type = 1; + // An identifier for the component that generated the event. + bytes component_id = 2; + // An identifier for the node that generated the event. + string node_ip_address = 3; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + repeated ProfileEvent profile_events = 4; +} + +message RayResource { + // The type of the resource. + string resource_name = 1; + // The total capacity of this resource type. + double resource_capacity = 2; +} + +message ClientTableData { + // Enum for the entry type in the ClientTable + enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; + } + + // The client ID of the client that the message is about. + bytes client_id = 1; + // The IP address of the client's node manager. + string node_manager_address = 2; + // The IPC socket name of the client's raylet. + string raylet_socket_name = 3; + // The IPC socket name of the client's plasma store. + string object_store_socket_name = 4; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + int32 node_manager_port = 5; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + int32 object_manager_port = 6; + // Enum to store the entry type in the log + EntryType entry_type = 7; + + // TODO(hchen): Define the following resources in map format. + repeated string resources_total_label = 8; + repeated double resources_total_capacity = 9; +} + +message HeartbeatTableData { + // Node manager client id + bytes client_id = 1; + // TODO(hchen): Define the following resources in map format. + // Resource capacity currently available on this node manager. + repeated string resources_available_label = 2; + repeated double resources_available_capacity = 3; + // Total resource capacity configured for this node manager. + repeated string resources_total_label = 4; + repeated double resources_total_capacity = 5; + // Aggregate outstanding resource load on this node manager. + repeated string resource_load_label = 6; + repeated double resource_load_capacity = 7; +} + +message HeartbeatBatchTableData { + repeated HeartbeatTableData batch = 1; +} + +// Data for a lease on task execution. +message TaskLeaseData { + // Node manager client ID. + bytes node_manager_id = 1; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + uint64 acquired_at = 2; + // The period that the lease is active for. + uint64 timeout = 3; +} + +message DriverTableData { + // The driver ID. + bytes driver_id = 1; + // Whether it's dead. + bool is_dead = 2; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +message ActorCheckpointData { + // ID of this actor. + bytes actor_id = 1; + // The dummy object ID of actor's most recently executed task. + bytes execution_dependency = 2; + // A list of IDs of this actor's handles. 
+ repeated bytes handle_ids = 3; + // The task counters of the above handles. + repeated uint64 task_counters = 4; + // The frontier dependencies of the above handles. + repeated bytes frontier_dependencies = 5; + // A list of unreleased dummy objects from this actor. + repeated bytes unreleased_dummy_objects = 6; + // The numbers of dependencies for the above unreleased dummy objects. + repeated uint32 num_dummy_object_dependencies = 7; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +message ActorCheckpointIdData { + // ID of this actor. + bytes actor_id = 1; + // IDs of this actor's available checkpoints. + repeated bytes checkpoint_ids = 2; + // A list of the timestamps for each of the above `checkpoint_ids`. + repeated uint64 timestamps = 3; +} + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 2; +} diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index cc587bc4d74e..7f940006b5be 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,34 +8,35 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { + execution_dependency_( + ObjectID::FromBinary(checkpoint_data.execution_dependency())) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); + for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.task_counter = checkpoint_data.task_counters(i); frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); } // Restore `dummy_objects_`. 
- for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id()); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,15 +44,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions; + return actor_table_data_.max_reconstructions(); } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions; + return actor_table_data_.remaining_reconstructions(); } const std::unordered_map @@ -96,7 +97,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -109,18 +110,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. 
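The GenerateCheckpointData hunk that follows writes the frontier into ActorCheckpointData by appending to three repeated fields in lockstep (handle_ids, task_counters, frontier_dependencies), and the checkpoint-restoring constructor above reads them back by index. A minimal sketch of that convention; AppendFrontierEntry is an illustrative helper, not part of the patch:

    #include <cstdint>
    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    // Appends one (handle_id, task_counter, frontier_dependency) triple. Because every
    // caller adds to the three repeated fields together, handle_ids(i), task_counters(i)
    // and frontier_dependencies(i) all describe the same handle, which is what the
    // restore path relies on when it iterates by index.
    void AppendFrontierEntry(ray::rpc::ActorCheckpointData *checkpoint,
                             const std::string &handle_id_binary, uint64_t task_counter,
                             const std::string &dependency_binary) {
      checkpoint->add_handle_ids(handle_id_binary);
      checkpoint->add_task_counters(task_counter);
      checkpoint->add_frontier_dependencies(dependency_binary);
    }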
- auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.Binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); + auto checkpoint_data = std::make_shared(); + checkpoint_data->set_actor_id(actor_id.Binary()); + checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.Binary()); - checkpoint_data->task_counters.push_back(frontier.second.task_counter); - checkpoint_data->frontier_dependencies.push_back( + checkpoint_data->add_handle_ids(frontier.first.Binary()); + checkpoint_data->add_task_counters(frontier.second.task_counter); + checkpoint_data->add_frontier_dependencies( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); - checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); + checkpoint_data->add_num_dummy_object_dependencies(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 8d7ce2a449ec..208e4998263f 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,13 +4,17 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::ActorTableData; +using ActorState = rpc::ActorTableData::ActorState; +using rpc::ActorCheckpointData; + /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -23,13 +27,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - ActorRegistration(const ActorTableDataT &actor_table_data); + explicit ActorRegistration(const ActorTableData &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data); + ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -46,15 +50,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableDataT &GetTableData() const { return actor_table_data_; } + const ActorTableData &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState &GetState() const { return actor_table_data_.state; } + const ActorState GetState() const { return actor_table_data_.state(); } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.state = state; } + void SetState(const ActorState &state) { actor_table_data_.set_state(state); } /// Get the actor's node manager location. /// @@ -131,13 +135,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. 
/// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableDataT actor_table_data_; + ActorTableData actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 32dddada5244..68d5aa817c2b 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,15 +63,6 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} -Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { - // Deserialize and set entries for the uncommitted tasks. - auto tasks = task_request.uncommitted_tasks(); - for (auto it = tasks->begin(); it != tasks->end(); it++) { - const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); - } -} - boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -151,20 +142,6 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } -flatbuffers::Offset Lineage::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { - RAY_CHECK(GetEntry(task_id)); - // Serialize the task and object entries. - std::vector> uncommitted_tasks; - for (const auto &entry : entries_) { - uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); - } - - auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), - fbb.CreateVector(uncommitted_tasks)); - return request; -} - const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -176,7 +153,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -292,15 +269,11 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... 
- flatbuffers::FlatBufferBuilder fbb; - auto message = task->TaskData().ToFlatbuffer(fbb); - fbb.Finish(message); - auto task_data = std::make_shared(); - auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); - root->UnPackTo(task_data.get()); + auto task_data = std::make_shared(); + task_data->set_task(task->TaskData().Serialize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -365,8 +338,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } - - return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 5436fa372fa4..37ce5caf6507 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,18 +4,17 @@ #include #include -// clang-format off -#include "ray/common/common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -// clang-format on +#include "ray/gcs/tables.h" +#include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::TaskTableData; + /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -136,12 +135,6 @@ class Lineage { /// Construct an empty Lineage. Lineage(); - /// Construct a Lineage from a ForwardTaskRequest. - /// - /// \param task_request The request to construct the lineage from. All - /// uncommitted tasks in the request will be added to the lineage. - Lineage(const protocol::ForwardTaskRequest &task_request); - /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -172,15 +165,6 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; - /// Serialize this lineage to a ForwardTaskRequest flatbuffer. - /// - /// \param entry_id The task ID to include in the ForwardTaskRequest - /// flatbuffer. - /// \return An offset to the serialized lineage. The serialization includes - /// all task and object entries in the lineage. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; - /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// @@ -221,7 +205,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -319,7 +303,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. 
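As the comment in gcs.proto notes, TaskTableData is an interim envelope: its task field still carries the flatbuffers-serialized task, which is why FlushTask above can drop the manual FlatBufferBuilder round trip and simply store the bytes returned by Serialize(). A minimal sketch of that write path; WrapSerializedTask is an illustrative name, and the read side presumably hands task() back to the existing flatbuffers deserialization code:

    #include <memory>
    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    std::shared_ptr<ray::rpc::TaskTableData> WrapSerializedTask(
        const std::string &serialized_task) {
      auto task_data = std::make_shared<ray::rpc::TaskTableData>();
      // The protobuf message is only an envelope; the payload stays flatbuffers-encoded.
      task_data->set_task(serialized_task);
      return task_data;
    }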
gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 43e64e400292..a6184902f803 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &data) { + const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. 
- auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 62ecb00b819f..0a853260887e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id && - data.entry_type == EntryType::DELETION) { + if (client_id.Binary() == data.client_id() && + data.entry_type() == ClientTableData::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,10 +84,9 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. 
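The monitor.cc hunk that follows collapses the per-node heartbeat buffer into a single HeartbeatBatchTableData publish, appending each buffered HeartbeatTableData with add_batch()->CopyFrom(...). A standalone sketch of that pattern, assuming the buffer is keyed by ClientID as declared in monitor.h; BuildHeartbeatBatch is an illustrative name:

    #include <memory>
    #include <unordered_map>
    #include "ray/common/id.h"
    #include "ray/protobuf/gcs.pb.h"

    std::shared_ptr<ray::rpc::HeartbeatBatchTableData> BuildHeartbeatBatch(
        const std::unordered_map<ray::ClientID, ray::rpc::HeartbeatTableData> &buffer) {
      auto batch = std::make_shared<ray::rpc::HeartbeatBatchTableData>();
      // add_batch() appends a new HeartbeatTableData slot that CopyFrom() then fills.
      for (const auto &entry : buffer) {
        batch->add_batch()->CopyFrom(entry.second);
      }
      return batch;
    }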
if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->batch.push_back(std::unique_ptr( - new HeartbeatTableDataT(heartbeat.second))); + batch->add_batch()->CopyFrom(heartbeat.second); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index c69cc9f003e0..5725e52cf495 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,6 +11,10 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -35,7 +39,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data); + const HeartbeatTableData &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -50,7 +54,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index fc364539ccce..808eeb6fd211 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ActorState::ALIVE) { + if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -130,7 +130,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -139,8 +139,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); + const TaskLeaseData &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -150,7 +150,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. 
- reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -164,7 +164,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -177,34 +177,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceCreateUpdated(data); }; + const ClientTableData &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceDeleted(data); }; + const ClientTableData &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { + const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -215,7 +215,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. 
const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -251,12 +251,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) - << " " << entry.is_dead; - if (entry.is_dead) { - auto driver_id = DriverID::FromBinary(entry.driver_id); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " + << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); + if (entry.is_dead()) { + auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -288,26 +288,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.Binary(); + heartbeat_data->set_client_id(my_client_id.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->resources_available_label.push_back(resource_pair.first); - heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_available_label(resource_pair.first); + heartbeat_data->add_resources_available_capacity(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->resources_total_label.push_back(resource_pair.first); - heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_total_label(resource_pair.first); + heartbeat_data->add_resources_total_capacity(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->resource_load_label.push_back(resource_pair.first); - heartbeat_data->resource_load_capacity.push_back(resource_pair.second); + heartbeat_data->add_resource_load_label(resource_pair.first); + heartbeat_data->add_resource_load_capacity(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -335,13 +335,8 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events.size() > 0) { - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_info); - fbb.Finish(message); - auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); - - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); + if (profile_info.profile_events_size() > 0) { + 
RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } // Reset the timer. @@ -358,8 +353,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ClientAdded(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -378,19 +373,20 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address, - client_data.node_manager_port, client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address(), + client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet resources_total( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { +void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. - const ClientID client_id = ClientID::FromBinary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -418,7 +414,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.GetState() == ActorState::ALIVE) { + actor_entry.second.GetState() == ActorTableData::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first << " is disconnected, because its node " << client_id << " is removed from cluster. It may be reconstructed."; @@ -436,14 +432,15 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { lineage_cache_.FlushAllUncommittedTasks(); } -void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". 
Updating resource map."; - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -472,12 +469,13 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { return; } -void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceDeleted(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -523,7 +521,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() { } void NodeManager::HeartbeatAdded(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { // Locate the client id in remote client table and update available resources based on // the received heartbeat information. auto it = cluster_resource_map_.find(client_id); @@ -535,10 +533,12 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } SchedulingResources &remote_resources = it->second; - ResourceSet remote_available(heartbeat_data.resources_available_label, - heartbeat_data.resources_available_capacity); - ResourceSet remote_load(heartbeat_data.resource_load_label, - heartbeat_data.resource_load_capacity); + ResourceSet remote_available( + rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()), + rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity())); + ResourceSet remote_load( + rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()), + rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity())); // TODO(atumanov): assert that the load is a non-empty ResourceSet. remote_resources.SetAvailableResources(std::move(remote_available)); // Extract the load information and save it locally. @@ -563,40 +563,41 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } } -void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) { +void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. - for (const auto &heartbeat_data : heartbeat_batch.batch) { - const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); + for (const auto &heartbeat_data : heartbeat_batch.batch()) { + const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); if (client_id == local_client_id) { // Skip heartbeats from self. 
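HeartbeatAdded and the resource callbacks above convert the paired repeated label/capacity fields into a ResourceSet via rpc::VectorFromProtobuf, whose exact signature is not shown in this patch. A minimal sketch of an equivalent conversion written directly against the protobuf repeated-field types; MapFromLabelCapacity is an illustrative helper, not the real API:

    #include <string>
    #include <unordered_map>
    #include <google/protobuf/repeated_field.h>

    // Pairs labels[i] with capacities[i] into a resource-name -> capacity map.
    std::unordered_map<std::string, double> MapFromLabelCapacity(
        const google::protobuf::RepeatedPtrField<std::string> &labels,
        const google::protobuf::RepeatedField<double> &capacities) {
      std::unordered_map<std::string, double> result;
      for (int i = 0; i < labels.size() && i < capacities.size(); i++) {
        result[labels[i]] = capacities[i];
      }
      return result;
    }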
continue; } - HeartbeatAdded(client_id, *heartbeat_data); + HeartbeatAdded(client_id, heartbeat_data); } } void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback) { // Copy the actor notification data. - auto actor_notification = std::make_shared(data); + auto actor_notification = std::make_shared(data); // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions - - actor_notification->remaining_reconstructions); - if (actor_notification->state != ActorState::ALIVE) { + int log_length = 2 * (actor_notification->max_reconstructions() - + actor_notification->remaining_reconstructions()); + if (actor_notification->state() != ActorTableData::ALIVE) { // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } // If we successful appended a record to the GCS table of the actor that // has died, signal this to anyone receiving signals from this actor. auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -633,11 +634,12 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", state = " + << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorState::ALIVE) { + if (actor_registration.GetState() == ActorTableData::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -664,7 +666,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorState::DEAD) { + } else if (actor_registration.GetState() == ActorTableData::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -673,7 +675,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. 
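A side note on the log-length computation in PublishActorStateTransition above: the comment describes the actor log as one ALIVE entry, then up to N (RECONSTRUCTING, ALIVE) pairs, then optionally DEAD, so 2 * (max_reconstructions - remaining_reconstructions), plus one for non-ALIVE states, gives the zero-based position where the new entry is expected to land. The standalone check below (not Ray code) walks one actor through that sequence; max_reconstructions = 2 is an assumed example value.

#include <cassert>

// Expected zero-based index of the next actor-log entry, per the comment above.
int ExpectedLogIndex(int max_reconstructions, int remaining_reconstructions, bool alive) {
  int index = 2 * (max_reconstructions - remaining_reconstructions);
  return alive ? index : index + 1;
}

int main() {
  assert(ExpectedLogIndex(2, 2, true) == 0);   // initial ALIVE
  assert(ExpectedLogIndex(2, 2, false) == 1);  // first RECONSTRUCTING (remaining not yet decremented)
  assert(ExpectedLogIndex(2, 1, true) == 2);   // ALIVE after first reconstruction
  assert(ExpectedLogIndex(2, 1, false) == 3);  // second RECONSTRUCTING
  assert(ExpectedLogIndex(2, 0, true) == 4);   // ALIVE after second reconstruction
  assert(ExpectedLogIndex(2, 0, false) == 5);  // final DEAD entry
  return 0;
}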
This will mark the tasks as waiting for actor @@ -794,8 +796,20 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + ProfileTableDataT fbs_message; + flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + rpc::ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message.component_type); + profile_table_data.set_component_id(fbs_message.component_id); + for (const auto &fbs_event : fbs_message.profile_events) { + rpc::ProfileTableData::ProfileEvent *event = + profile_table_data.add_profile_events(); + event->set_event_type(fbs_event->event_type); + event->set_start_time(fbs_event->start_time); + event->set_end_time(fbs_event->end_time); + event->set_extra_data(fbs_event->extra_data); + } + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -863,8 +877,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorState::RECONSTRUCTING - : ActorState::DEAD; + ? ActorTableData::RECONSTRUCTING + : ActorTableData::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -873,8 +887,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.state = new_state; + ActorTableData new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.set_state(new_state); if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -885,7 +899,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1160,7 +1174,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointDataT &data) { + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1244,19 +1258,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableDataT object - ClientTableDataT data; + // Add the new resource to a skeleton ClientTableData object + ClientTableData data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. 
// RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.resources_total_label = std::vector{resource_name}; - data.resources_total_capacity = std::vector{capacity}; + data.add_resources_total_label(resource_name); + data.add_resources_total_capacity(capacity); // Set the correct flag for entry_type if (is_deletion) { - data.entry_type = EntryType::RES_DELETE; + data.set_entry_type(ClientTableData::RES_DELETE); } else { - data.entry_type = EntryType::RES_CREATEUPDATE; + data.set_entry_type(ClientTableData::RES_CREATEUPDATE); } // Submit to the client table. This calls the ResourceCreateUpdated callback, which @@ -1265,7 +1279,7 @@ void NodeManager::ProcessSetResourceRequest( if (not worker) { worker = worker_pool_.GetRegisteredDriver(client); } - auto data_shared_ptr = std::make_shared(data); + auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); @@ -1370,7 +1384,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << EnumNameErrorType(error_type) << "."; + << ErrorType_Name(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1488,9 +1502,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // If we have already seen this actor and this actor is not being reconstructed, // its location is known. bool location_known = - seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; + seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING; if (location_known) { - if (actor_entry->second.GetState() == ActorState::DEAD) { + if (actor_entry->second.GetState() == ActorTableData::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); @@ -1535,7 +1549,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // we missed the creation notification. auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. @@ -1861,11 +1875,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { +ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); auto actor_id = task.GetTaskSpecification().ActorCreationId(); auto actor_entry = actor_registry_.find(actor_id); - ActorTableDataT new_actor_data; + ActorTableData new_actor_data; // TODO(swang): If this is an actor that was reconstructed, and previous // actor notifications were delayed, then this node may not have an entry for // the actor in actor_regisry_. 
Then, the fields for the number of @@ -1873,32 +1887,33 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. - new_actor_data.actor_id = actor_id.Binary(); - new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().Binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); - new_actor_data.max_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_actor_id(actor_id.Binary()); + new_actor_data.set_actor_creation_dummy_object_id( + task.GetTaskSpecification().ActorDummyObject().Binary()); + new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); + new_actor_data.set_max_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.remaining_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_remaining_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING); // Copy the static fields from the current actor entry. new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.remaining_reconstructions--; + new_actor_data.set_remaining_reconstructions( + new_actor_data.remaining_reconstructions() - 1); } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. - new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().Binary(); - new_actor_data.state = ActorState::ALIVE; + new_actor_data.set_node_manager_id( + gcs_client_->client_table().GetLocalClientId().Binary()); + new_actor_data.set_state(ActorTableData::ALIVE); return new_actor_data; } @@ -1934,7 +1949,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, - const ActorCheckpointDataT &checkpoint_data) { + const ActorCheckpointData &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -1948,7 +1963,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. 
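One mechanical consequence of moving from the flatbuffer ...T structs to protobuf messages, visible throughout the hunk above, is that fields are no longer plain public members: scalars are read with field() and written with set_field(), and repeated fields grow with add_field(). That is why remaining_reconstructions-- becomes an explicit read-modify-write. The stand-in below is hand-written for illustration only; it is not protoc output and not the real rpc::ActorTableData, it just mimics the accessor shape.

#include <cstdint>
#include <iostream>

// Hand-written stand-in mimicking a protobuf-generated message's accessors.
class FakeActorTableData {
 public:
  uint64_t remaining_reconstructions() const { return remaining_reconstructions_; }
  void set_remaining_reconstructions(uint64_t value) { remaining_reconstructions_ = value; }

 private:
  uint64_t remaining_reconstructions_ = 0;
};

int main() {
  FakeActorTableData data;
  data.set_remaining_reconstructions(3);
  // Equivalent of the old `new_actor_data.remaining_reconstructions--`:
  data.set_remaining_reconstructions(data.remaining_reconstructions() - 1);
  std::cout << data.remaining_reconstructions() << std::endl;  // prints 2
  return 0;
}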
RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -1964,8 +1979,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { PublishActorStateTransition( actor_id, new_actor_data, /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -2004,10 +2018,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - const Task task(task_data); + auto message = flatbuffers::GetRoot(task_data.task().data()); + const Task task(*message); ResubmitTask(task); }, /*failure_callback=*/ @@ -2035,7 +2050,7 @@ void NodeManager::ResubmitTask(const Task &task) { if (task.GetTaskSpecification().IsActorCreationTask()) { const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); const auto it = actor_registry_.find(actor_id); - if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { + if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) { // If the actor is still alive, then do not resubmit the task. If the // actor actually is dead and a result is needed, then reconstruction // for this task will be triggered again. diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 61613358330c..f45c8b035553 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,7 +10,6 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" -#include "ray/gcs/format/util.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -26,6 +25,13 @@ namespace ray { namespace raylet { +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorType; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + struct NodeManagerConfig { /// The node's resource configuration. ResourceSet resource_config; @@ -112,22 +118,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param data Data associated with the new client. /// \return Void. - void ClientAdded(const ClientTableDataT &data); + void ClientAdded(const ClientTableData &data); /// Handler for the removal of a GCS client. /// \param client_data Data associated with the removed client. /// \return Void. - void ClientRemoved(const ClientTableDataT &client_data); + void ClientRemoved(const ClientTableData &client_data); /// Handler for the addition or updation of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceCreateUpdated(const ClientTableDataT &client_data); + void ResourceCreateUpdated(const ClientTableData &client_data); /// Handler for the deletion of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. 
- void ResourceDeleted(const ClientTableDataT &client_data); + void ResourceDeleted(const ClientTableData &client_data); /// Evaluates the local infeasible queue to check if any tasks can be scheduled. /// This is called whenever there's an update to the resources on the local client. @@ -150,11 +156,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param id The ID of the node manager that sent the heartbeat. /// \param data The heartbeat data including load information. /// \return Void. - void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data); + void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data); /// Handler for a heartbeat batch notification from the GCS /// /// \param heartbeat_batch The batch of heartbeat data. - void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch); + void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch); /// Methods for task scheduling. @@ -206,7 +212,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Helper function to produce actor table data for a newly created actor. /// /// \param task The actor creation task that created the actor. - ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); + ActorTableData CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -317,7 +323,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param failure_callback An optional callback to call if the publish is /// unsuccessful. void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -346,7 +352,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. 
This is only used for debugging diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 473e6c263ffe..cbf9b25213ca 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = node_ip_address; - client_info.raylet_socket_name = raylet_socket_name; - client_info.object_store_socket_name = object_store_socket_name; - client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_.GetServerPort(); + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(node_ip_address); + client_info.set_raylet_socket_name(raylet_socket_name); + client_info.set_object_store_socket_name(object_store_socket_name); + client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); + client_info.set_node_manager_port(node_manager_.GetServerPort()); // Add resource information. for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.resources_total_label.push_back(resource_pair.first); - client_info.resources_total_capacity.push_back(resource_pair.second); + client_info.add_resources_total_label(resource_pair.first); + client_info.add_resources_total_capacity(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address << ":" - << client_info.node_manager_port << " object manager at " - << client_info.node_manager_address << ":" - << client_info.object_manager_port; + << " started on " << client_info.node_manager_address() << ":" + << client_info.node_manager_port() << " object manager at " + << client_info.node_manager_address() << ":" + << client_info.object_manager_port(); ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 26fe74b2b622..9367a5054591 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; + class Task; class NodeManager; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 97c86ea73cd8..bf5c1acfaa37 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. 
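The AppendAt call that follows relies on the GCS log's positional append semantics: the entry is written at the index given by reconstruction_attempt, and the append fails if another node has already claimed that slot, which is what suppresses duplicate reconstructions. The self-contained sketch below models only that rule as described in the comment; it is an illustration, not the gcs::Log implementation.

#include <iostream>
#include <string>
#include <vector>

// Minimal model of "append at index i, fail if the slot is already taken".
class PositionalLog {
 public:
  bool AppendAt(size_t index, const std::string &entry) {
    if (index != entries_.size()) {
      return false;  // another writer already appended at (or past) this index
    }
    entries_.push_back(entry);
    return true;
  }

 private:
  std::vector<std::string> entries_;
};

int main() {
  PositionalLog log;
  // Two nodes race to record reconstruction attempt 0; only one wins.
  std::cout << log.AppendAt(0, "node_A attempt 0") << std::endl;  // 1 (success)
  std::cout << log.AppendAt(0, "node_B attempt 0") << std::endl;  // 0 (suppressed)
  return 0;
}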
- auto reconstruction_entry = std::make_shared(); - reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.Binary(); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->set_num_reconstructions(reconstruction_attempt); + reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cd969cc2706e..a194443e1425 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,6 +17,8 @@ namespace ray { namespace raylet { +using rpc::TaskReconstructionData; + class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4ccebd0c0c09..12d9336a382f 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,6 +14,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -83,7 +85,7 @@ class MockGcs : public gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -110,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -132,15 +134,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -159,9 +161,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { + const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout); + task_lease.timeout()); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -314,10 +316,10 @@ 
TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = 2 * test_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(2 * test_period); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -328,7 +330,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout * 1.1); + Run(task_lease_data->timeout() * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -341,10 +343,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = reconstruction_timeout_ms_; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(reconstruction_timeout_ms_); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -393,14 +395,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); - task_reconstruction_data->num_reconstructions = 0; + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. 
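The lease-based suppression exercised by the tests above boils down to a single comparison: a task lease acquired at acquired_at with a given timeout keeps reconstruction suppressed until acquired_at + timeout has passed, and the raylet renews it every lease_period / 2 (see task_dependency_manager.cc below). The exact expiry check inside ReconstructionPolicy is not shown in these hunks, so the standalone sketch below is an assumption-level illustration of that relationship, not Ray code.

#include <cstdint>
#include <iostream>

// A lease suppresses reconstruction while the current time is still inside
// [acquired_at, acquired_at + timeout).
bool LeaseStillHeld(int64_t now_ms, int64_t acquired_at_ms, int64_t timeout_ms) {
  return now_ms < acquired_at_ms + timeout_ms;
}

int main() {
  const int64_t acquired_at = 1000;
  const int64_t timeout = 200;  // e.g. 2 * test_period in the test above
  std::cout << LeaseStillHeld(1100, acquired_at, timeout) << std::endl;  // 1: still suppressed
  std::cout << LeaseStillHeld(1300, acquired_at, timeout) << std::endl;  // 0: lease expired
  return 0;
}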
diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index c5155b96b0c1..89028c733d0d 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.Hex(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = it->second.lease_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(client_id_.Hex()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(it->second.lease_period); RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 3788a5eae7ae..a96558295234 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,6 +13,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index e0f832a12870..f7a60989fcba 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 719378216fb7..16086565de80 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -48,8 +48,8 @@ WorkerPool::WorkerPool( : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - gcs_client_(std::move(gcs_client)), - last_warning_multiple_(0) { + last_warning_multiple_(0), + gcs_client_(std::move(gcs_client)) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. 
If we don't do this, then worker processes will diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 6ecc6c3c4a34..59ae75ae33be 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,6 +1,7 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H +#include #include #include "ray/common/status.h" @@ -27,6 +28,18 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedPtrField &pb_repeated) { + return std::vector(pb_repeated.begin(), pb_repeated.end()); +} + +template +inline std::vector VectorFromProtobuf( + const ::google::protobuf::RepeatedField &pb_repeated) { + return std::vector(pb_repeated.begin(), pb_repeated.end()); +} + } // namespace rpc } // namespace ray From aa5fc52e32cea80783abd25d8c19e5eb9a1c3b3c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 25 Jun 2019 19:02:40 -0700 Subject: [PATCH 111/118] [rllib] Add QMIX mixer parameters to optimizer param list (#5014) * add mixer params * Update qmix_policy.py --- python/ray/rllib/agents/qmix/qmix_policy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/ray/rllib/agents/qmix/qmix_policy.py b/python/ray/rllib/agents/qmix/qmix_policy.py index 26ec387de004..99045899684b 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -204,6 +204,8 @@ def __init__(self, obs_space, action_space, config): # Setup optimizer self.params = list(self.model.parameters()) + if self.mixer: + self.params += list(self.mixer.parameters()) self.loss = QMixLoss(self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"]) From bb8e75b532e8fa4761976402fa3dc65c223ea843 Mon Sep 17 00:00:00 2001 From: Zhijun Fu <37800433+zhijunfu@users.noreply.github.com> Date: Wed, 26 Jun 2019 10:08:09 +0800 Subject: [PATCH 112/118] [grpc] refactor rpc server to support multiple io services (#5023) --- src/ray/raylet/node_manager.cc | 4 +- src/ray/raylet/node_manager.h | 5 +- src/ray/rpc/grpc_server.cc | 17 +++++-- src/ray/rpc/grpc_server.h | 77 +++++++++++++++++++++---------- src/ray/rpc/node_manager_server.h | 25 +++++----- src/ray/rpc/server_call.h | 26 +++++++++-- 6 files changed, 104 insertions(+), 50 deletions(-) diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 808eeb6fd211..226a8fb6d251 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -101,7 +101,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), actor_registry_(), - node_manager_server_(config.node_manager_port, io_service, *this), + node_manager_server_("NodeManager", config.node_manager_port), + node_manager_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -119,6 +120,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. 
+ node_manager_server_.RegisterService(node_manager_service_); node_manager_server_.Run(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f45c8b035553..7e812183657c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -512,7 +512,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::unordered_map checkpoint_id_to_restore_; /// The RPC server. - rpc::NodeManagerServer node_manager_server_; + rpc::GrpcServer node_manager_server_; + + /// The RPC service. + rpc::NodeManagerGrpcService node_manager_service_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index feb788da7692..f507039990c2 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -1,4 +1,5 @@ #include "ray/rpc/grpc_server.h" +#include namespace ray { namespace rpc { @@ -9,8 +10,10 @@ void GrpcServer::Run() { grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); - // Allow subclasses to register concrete services. - RegisterServices(builder); + // Register all the services to this server. + for (auto &entry : services_) { + builder.RegisterService(&entry.get()); + } // Get hold of the completion queue used for the asynchronous communication // with the gRPC runtime. cq_ = builder.AddCompletionQueue(); @@ -18,8 +21,7 @@ void GrpcServer::Run() { server_ = builder.BuildAndStart(); RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; - // Allow subclasses to initialize the server call factories. - InitServerCallFactories(&server_call_factories_and_concurrencies_); + // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { for (int i = 0; i < entry.second; i++) { // Create and request calls from the factory. @@ -31,6 +33,11 @@ void GrpcServer::Run() { polling_thread.detach(); } +void GrpcServer::RegisterService(GrpcService &service) { + services_.emplace_back(service.GetGrpcService()); + service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_); +} + void GrpcServer::PollEventsFromCompletionQueue() { void *tag; bool ok; @@ -48,7 +55,7 @@ void GrpcServer::PollEventsFromCompletionQueue() { // incoming request. server_call->GetFactory().CreateCall(); server_call->SetState(ServerCallState::PROCESSING); - main_service_.post([server_call] { server_call->HandleRequest(); }); + server_call->HandleRequest(); break; case ServerCallState::SENDING_REPLY: // The reply has been sent, this call can be deleted now. diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 4953f470610f..584da6565a47 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -12,7 +12,9 @@ namespace ray { namespace rpc { -/// Base class that represents an abstract gRPC server. +class GrpcService; + +/// Class that represents an gRPC server. /// /// A `GrpcServer` listens on a specific port. It owns /// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, @@ -28,11 +30,7 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. 
- /// \param[in] main_service The main event loop, to which service handler functions - /// will be posted. - GrpcServer(const std::string &name, const uint32_t port, - boost::asio::io_service &main_service) - : name_(name), port_(port), main_service_(main_service) {} + GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {} /// Destruct this gRPC server. ~GrpcServer() { @@ -46,36 +44,25 @@ class GrpcServer { /// Get the port of this gRPC server. int GetPort() const { return port_; } - protected: - /// Subclasses should implement this method and register one or multiple gRPC services - /// to the given `ServerBuilder`. + /// Register a grpc service. Multiple services can be registered to the same server. + /// Note that the `service` registered must remain valid for the lifetime of the + /// `GrpcServer`, as it holds the underlying `grpc::Service`. /// - /// \param[in] builder The `ServerBuilder` instance to register services to. - virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; - - /// Subclasses should implement this method to initialize the `ServerCallFactory` - /// instances, as well as specify maximum number of concurrent requests that gRPC - /// server can "accept" (not "handle"). Each factory will be used to create - /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and - /// handle an incoming request. - /// - /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, - /// and the maximum number of concurrent requests that gRPC server can accept. - virtual void InitServerCallFactories( - std::vector, int>> - *server_call_factories_and_concurrencies) = 0; + /// \param[in] service A `GrpcService` to register to this server. + void RegisterService(GrpcService &service); + protected: /// This function runs in a background thread. It keeps polling events from the /// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances /// via the `ServerCall` objects. void PollEventsFromCompletionQueue(); - /// The main event loop, to which the service handler functions will be posted. - boost::asio::io_service &main_service_; /// Name of this server, used for logging and debugging purpose. const std::string name_; /// Port of this server. int port_; + /// The `grpc::Service` objects which should be registered to `ServerBuilder`. + std::vector> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that /// gRPC server can accept. std::vector, int>> @@ -86,6 +73,46 @@ class GrpcServer { std::unique_ptr server_; }; +/// Base class that represents an abstract gRPC service. +/// +/// Subclass should implement `InitServerCallFactories` to decide +/// which kinds of requests this service should accept. +class GrpcService { + public: + /// Constructor. + /// + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {} + + /// Destruct this gRPC service. + ~GrpcService() {} + + protected: + /// Return the underlying grpc::Service object for this class. + /// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`. + virtual grpc::Service &GetGrpcService() = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). 
Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[in] cq The grpc completion queue. + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. + virtual void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) = 0; + + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; + + friend class GrpcServer; +}; + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h index afaea299ea89..d05f268c65b2 100644 --- a/src/ray/rpc/node_manager_server.h +++ b/src/ray/rpc/node_manager_server.h @@ -25,25 +25,22 @@ class NodeManagerServiceHandler { RequestDoneCallback done_callback) = 0; }; -/// The `GrpcServer` for `NodeManagerService`. -class NodeManagerServer : public GrpcServer { +/// The `GrpcService` for `NodeManagerService`. +class NodeManagerGrpcService : public GrpcService { public: /// Constructor. /// - /// \param[in] port See super class. - /// \param[in] main_service See super class. + /// \param[in] io_service See super class. /// \param[in] handler The service handler that actually handle the requests. - NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, - NodeManagerServiceHandler &service_handler) - : GrpcServer("NodeManager", port, main_service), - service_handler_(service_handler){}; + NodeManagerGrpcService(boost::asio::io_service &io_service, + NodeManagerServiceHandler &service_handler) + : GrpcService(io_service), service_handler_(service_handler){}; - void RegisterServices(grpc::ServerBuilder &builder) override { - /// Register `NodeManagerService`. - builder.RegisterService(&service_); - } + protected: + grpc::Service &GetGrpcService() override { return service_; } void InitServerCallFactories( + const std::unique_ptr &cq, std::vector, int>> *server_call_factories_and_concurrencies) override { // Initialize the factory for `ForwardTask` requests. @@ -51,7 +48,8 @@ class NodeManagerServer : public GrpcServer { new ServerCallFactoryImpl( service_, &NodeManagerService::AsyncService::RequestForwardTask, - service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq, + main_service_)); // Set `ForwardTask`'s accept concurrency to 100. server_call_factories_and_concurrencies->emplace_back( @@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer { private: /// The grpc async service object. NodeManagerService::AsyncService service_; + /// The service handler that actually handle the requests. NodeManagerServiceHandler &service_handler_; }; diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index e06278260ab6..08ca128323ee 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall { /// \param[in] factory The factory which created this call. /// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. + /// \param[in] io_service The event loop. 
ServerCallImpl( const ServerCallFactory &factory, ServiceHandler &service_handler, - HandleRequestFunction handle_request_function) + HandleRequestFunction handle_request_function, + boost::asio::io_service &io_service) : state_(ServerCallState::PENDING), factory_(factory), service_handler_(service_handler), handle_request_function_(handle_request_function), - response_writer_(&context_) {} + response_writer_(&context_), + io_service_(io_service) {} ServerCallState GetState() const override { return state_; } void SetState(const ServerCallState &new_state) override { state_ = new_state; } void HandleRequest() override { + io_service_.post([this] { HandleRequestImpl(); }); + } + + void HandleRequestImpl() { state_ = ServerCallState::PROCESSING; (service_handler_.*handle_request_function_)(request_, &reply_, [this](Status status) { @@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall { /// The reponse writer. grpc::ServerAsyncResponseWriter response_writer_; + /// The event loop. + boost::asio::io_service &io_service_; + /// The request message. Request request_; @@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] cq The `CompletionQueue`. + /// \param[in] io_service The event loop. ServerCallFactoryImpl( AsyncService &service, RequestCallFunction request_call_function, ServiceHandler &service_handler, HandleRequestFunction handle_request_function, - const std::unique_ptr &cq) + const std::unique_ptr &cq, + boost::asio::io_service &io_service) : service_(service), request_call_function_(request_call_function), service_handler_(service_handler), handle_request_function_(handle_request_function), - cq_(cq) {} + cq_(cq), + io_service_(io_service) {} ServerCall *CreateCall() const override { // Create a new `ServerCall`. This object will eventually be deleted by // `GrpcServer::PollEventsFromCompletionQueue`. auto call = new ServerCallImpl( - *this, service_handler_, handle_request_function_); + *this, service_handler_, handle_request_function_, io_service_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. (service_.*request_call_function_)(&call->context_, &call->request_, @@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// The `CompletionQueue`. const std::unique_ptr &cq_; + + /// The event loop. 
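The io_service_ member declared next is the heart of this refactor: HandleRequest() above no longer runs the handler on the completion-queue polling thread, it posts HandleRequestImpl() onto the main event loop, so service handlers still execute on that single loop. The self-contained Boost.Asio sketch below shows the same handoff pattern in isolation; it is illustrative, not Ray code, and assumes Boost.Asio is available.

#include <boost/asio.hpp>
#include <iostream>
#include <thread>

int main() {
  boost::asio::io_service main_service;
  boost::asio::io_service::work keep_alive(main_service);  // keep run() from returning early

  // Stands in for the completion-queue polling thread in GrpcServer.
  std::thread poller([&main_service] {
    main_service.post([] { std::cout << "handled on the main event loop" << std::endl; });
    main_service.post([&main_service] { main_service.stop(); });
  });

  main_service.run();  // handlers run here, on the thread that owns the loop
  poller.join();
  return 0;
}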
+ boost::asio::io_service &io_service_; }; } // namespace rpc From bbe3e5b4edfb68e555a80710ddad73f80d44fce7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 25 Jun 2019 22:06:36 -0700 Subject: [PATCH 113/118] [rllib] Give error if sample_async is used with pytorch for A3C (#5000) * give error if sample_async is used with pytorch * update * Update a3c.py --- python/ray/rllib/agents/a3c/a3c.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index c269df2fc6e5..d320b9636881 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -48,6 +48,10 @@ def get_policy_class(config): def validate_config(config): if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sample_async"] and config["use_pytorch"]: + raise ValueError( + "The sample_async option is not supported with use_pytorch: " + "Multithreading can be lead to crashes if used with pytorch.") def make_async_optimizer(workers, config): From b1827d5fbe144bfeb713b1b3498d1857baf55b09 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 25 Jun 2019 22:50:15 -0700 Subject: [PATCH 114/118] [tune] Update MNIST Example (#4991) --- ci/jenkins_tests/run_tune_tests.sh | 2 +- docker/tune_test/Dockerfile | 2 + python/ray/tune/examples/mnist_pytorch.py | 273 +++++++++------------- 3 files changed, 112 insertions(+), 165 deletions(-) diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 84e7e7fe9c0f..6b890d7d371c 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -87,7 +87,7 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test --no-cuda + python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 1d252a62fd62..77cf390493d6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,6 +4,8 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. +RUN conda install -y -c anaconda wrapt=1.11.1 +RUN conda install -y -c anaconda numpy=1.16.4 RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. 
RUN apt-get install -y zlib1g-dev diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index 03dd2f1607e2..acef9fc5105d 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -1,7 +1,10 @@ # Original Code here: # https://github.com/pytorch/examples/blob/master/mnist/main.py +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +import numpy as np import argparse import torch import torch.nn as nn @@ -9,181 +12,123 @@ import torch.optim as optim from torchvision import datasets, transforms -# Training settings -parser = argparse.ArgumentParser(description="PyTorch MNIST Example") -parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)") -parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)") -parser.add_argument( - "--epochs", - type=int, - default=1, - metavar="N", - help="number of epochs to train (default: 1)") -parser.add_argument( - "--lr", - type=float, - default=0.01, - metavar="LR", - help="learning rate (default: 0.01)") -parser.add_argument( - "--momentum", - type=float, - default=0.5, - metavar="M", - help="SGD momentum (default: 0.5)") -parser.add_argument( - "--no-cuda", - action="store_true", - default=False, - help="disables CUDA training") -parser.add_argument( - "--seed", - type=int, - default=1, - metavar="S", - help="random seed (default: 1)") -parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") - - -def train_mnist(args, config, reporter): - vars(args).update(config) - args.cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - if args.cuda: - torch.cuda.manual_seed(args.seed) - - kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} +import ray +from ray import tune +from ray.tune import track +from ray.tune.schedulers import AsyncHyperBandScheduler + +# Change these values if you want the training to run quicker or slower. 
+EPOCH_SIZE = 512 +TEST_SIZE = 256 + + +class Net(nn.Module): + def __init__(self, config): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 3, kernel_size=3) + self.fc = nn.Linear(192, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 3)) + x = x.view(-1, 192) + x = self.fc(x) + return F.log_softmax(x, dim=1) + + +def train(model, optimizer, train_loader, device): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if batch_idx * len(data) > EPOCH_SIZE: + return + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + +def test(model, data_loader, device): + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (data, target) in enumerate(data_loader): + if batch_idx * len(data) > TEST_SIZE: + break + data, target = data.to(device), target.to(device) + outputs = model(data) + _, predicted = torch.max(outputs.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + + return correct / total + + +def get_data_loaders(): + mnist_transforms = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, ))]) + train_loader = torch.utils.data.DataLoader( datasets.MNIST( - "~/data", - train=True, - download=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.batch_size, - shuffle=True, - **kwargs) + "~/data", train=True, download=True, transform=mnist_transforms), + batch_size=64, + shuffle=True) test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "~/data", - train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs) - - class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - - model = Net() - if args.cuda: - model.cuda() + datasets.MNIST("~/data", train=False, transform=mnist_transforms), + batch_size=64, + shuffle=True) + return train_loader, test_loader + + +def train_mnist(config): + use_cuda = config.get("use_gpu") and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + train_loader, test_loader = get_data_loaders() + model = Net(config).to(device) optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) - - def train(epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if args.cuda: - data, target = data.cuda(), target.cuda() - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - - def test(): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - if args.cuda: - data, target = data.cuda(), target.cuda() - output = model(data) - # sum up batch loss - test_loss += F.nll_loss(output, target, reduction="sum").item() - # get the index of 
the max log-probability - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq( - target.data.view_as(pred)).long().cpu().sum() - - test_loss = test_loss / len(test_loader.dataset) - accuracy = correct.item() / len(test_loader.dataset) - reporter(mean_loss=test_loss, mean_accuracy=accuracy) - - for epoch in range(1, args.epochs + 1): - train(epoch) - test() + model.parameters(), lr=config["lr"], momentum=config["momentum"]) + + while True: + train(model, optimizer, train_loader, device) + acc = test(model, test_loader, device) + track.log(mean_accuracy=acc) if __name__ == "__main__": - datasets.MNIST("~/data", train=True, download=True) + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--cuda", + action="store_true", + default=False, + help="Enables GPU training") + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + parser.add_argument( + "--ray-redis-address", + help="Address of Ray cluster for seamless distributed execution.") args = parser.parse_args() - - import ray - from ray import tune - from ray.tune.schedulers import AsyncHyperBandScheduler - - ray.init() + if args.ray_redis_address: + ray.init(redis_address=args.ray_redis_address) sched = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric="mean_loss", - mode="min", - max_t=400, - grace_period=20) - tune.register_trainable( - "TRAIN_FN", - lambda config, reporter: train_mnist(args, config, reporter)) + time_attr="training_iteration", metric="mean_accuracy") tune.run( - "TRAIN_FN", + train_mnist, name="exp", scheduler=sched, - **{ - "stop": { - "mean_accuracy": 0.98, - "training_iteration": 1 if args.smoke_test else 20 - }, - "resources_per_trial": { - "cpu": 3, - "gpu": int(not args.no_cuda) - }, - "num_samples": 1 if args.smoke_test else 10, - "config": { - "lr": tune.uniform(0.001, 0.1), - "momentum": tune.uniform(0.1, 0.9), - } + stop={ + "mean_accuracy": 0.98, + "training_iteration": 5 if args.smoke_test else 20 + }, + resources_per_trial={ + "cpu": 2, + "gpu": int(args.cuda) + }, + num_samples=1 if args.smoke_test else 10, + config={ + "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())), + "momentum": tune.uniform(0.1, 0.9), + "use_gpu": int(args.cuda) }) From d63973769dfd30fdbfd7a5cc60d4828d97269c2a Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Wed, 26 Jun 2019 10:04:29 +0200 Subject: [PATCH 115/118] Add entropy coeff schedule --- python/ray/rllib/agents/impala/impala.py | 1 + .../ray/rllib/agents/impala/vtrace_policy.py | 10 +++-- python/ray/rllib/policy/tf_policy.py | 38 ++++++++++++++----- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index b9699888bfaf..23b5ada167db 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -75,6 +75,7 @@ # balancing the three losses "vf_loss_coeff": 0.5, "entropy_coeff": 0.01, + "entropy_schedule": None, # use fake (infinite speed) sampler for testing "_fake_sampler": False, diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index 9b283c7172cc..20e03af0d4b8 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -14,7 +14,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy, \ - 
LearningRateSchedule + LearningRateSchedule, EntropyCoeffSchedule from ray.rllib.models.action_dist import Categorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override @@ -126,7 +126,7 @@ def postprocess_trajectory(self, return sample_batch -class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy): +class VTraceTFPolicy(LearningRateSchedule, EntropyCoeffSchedule, VTracePostprocessing, TFPolicy): def __init__(self, observation_space, action_space, @@ -241,6 +241,9 @@ def make_time_major(tensor, drop_last=False): loss_actions = actions if is_multidiscrete else tf.expand_dims( actions, axis=1) + EntropyCoeffSchedule.__init__(self, self.config["entropy_coeff"], + self.config["entropy_schedule"]) + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. self.loss = VTraceLoss( actions=make_time_major(loss_actions, drop_last=True), @@ -259,7 +262,7 @@ def make_time_major(tensor, drop_last=False): dist_class=Categorical if is_multidiscrete else dist_class, valid_mask=make_time_major(mask, drop_last=True), vf_loss_coeff=self.config["vf_loss_coeff"], - entropy_coeff=self.config["entropy_coeff"], + entropy_coeff=self.entropy_coeff, clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"]) @@ -299,6 +302,7 @@ def make_time_major(tensor, drop_last=False): self.stats_fetches = { LEARNER_STATS_KEY: { "cur_lr": tf.cast(self.cur_lr, tf.float64), + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, "grad_gnorm": tf.global_norm(self._grads), diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index ef0de42e2f0c..006973b033ca 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -2,21 +2,21 @@ from __future__ import division from __future__ import print_function -import os import errno import logging -import numpy as np +import os +import numpy as np import ray import ray.experimental.tf_utils +from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.utils import try_import_tf from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -416,7 +416,7 @@ def _build_compute_actions(self, if len(self._state_inputs) != len(state_batches): raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". - format(self._state_inputs, state_batches)) + format(self._state_inputs, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: @@ -443,7 +443,7 @@ def _build_apply_gradients(self, builder, gradients): if len(gradients) != len(self._grads): raise ValueError( "Unexpected number of gradients to apply, got {} for {}". 
- format(gradients, self._grads)) + format(gradients, self._grads)) builder.add_feed_dict({self._is_training: True}) builder.add_feed_dict(dict(zip(self._grads, gradients))) fetches = builder.add_fetches([self._apply_op]) @@ -473,9 +473,9 @@ def _get_loss_inputs_dict(self, batch): feed_dict = {} if self._batch_divisibility_req > 1: meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent else: meets_divisibility_reqs = True @@ -544,3 +544,23 @@ def on_global_var_update(self, global_vars): @override(TFPolicy) def optimizer(self): return tf.train.AdamOptimizer(self.cur_lr) + + +@DeveloperAPI +class EntropyCoeffSchedule(object): + """Mixin for TFPolicy that adds entropy coeff decay.""" + + @DeveloperAPI + def __init__(self, entropy_coeff, entropy_schedule): + self.entropy_coeff = tf.get_variable("entropy_coeff", initializer=entropy_coeff) + self._entropy_schedule = entropy_schedule + + @override(Policy) + def on_global_var_update(self, global_vars): + super(EntropyCoeffSchedule, self).on_global_var_update(global_vars) + if self._entropy_schedule is not None: + self.entropy_coeff.load( + self.config['entropy_coeff'] * + (1 - global_vars['timestep'] / + self.config['entropy_schedule']), + session=self._sess) From 92c0f88b9cd75be6281204467b95526951c03e87 Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Wed, 26 Jun 2019 11:53:03 +0200 Subject: [PATCH 116/118] Revert "Merge with ray master" This reverts commit 108bfa293001ffd589e79288e98999aacf5b59f9, reversing changes made to 2e0eec9f723f6ba96e11183f16dba0fc664cb655. 
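Among the changes backed out here is the entropy coefficient schedule introduced in the previous patch ("Add entropy coeff schedule"): the "entropy_schedule" key in the IMPALA default config and the EntropyCoeffSchedule mixin in tf_policy.py, which linearly annealed entropy_coeff toward zero as the global timestep approached entropy_schedule. For reference, a minimal sketch of how that option would have been enabled before this revert; the environment name and the schedule horizon below are illustrative assumptions, not values taken from the patch:

    from ray import tune

    tune.run(
        "IMPALA",
        config={
            "env": "CartPole-v0",        # assumed example environment
            "entropy_coeff": 0.01,       # IMPALA default starting coefficient
            # assumed horizon: coefficient anneals linearly to zero at 10M timesteps
            "entropy_schedule": 10000000,
        })

After this revert, entropy_schedule is no longer a recognized IMPALA config key and entropy_coeff stays constant over training.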
--- BUILD.bazel | 96 ++-- bazel/ray_deps_build_all.bzl | 4 - bazel/ray_deps_setup.bzl | 11 +- .../run_perf_integration.sh | 2 +- ci/jenkins_tests/run_tune_tests.sh | 8 +- doc/source/conf.py | 15 +- doc/source/tune-usage.rst | 6 - docker/base-deps/Dockerfile | 2 +- docker/examples/Dockerfile | 5 +- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 11 +- java/BUILD.bazel | 51 ++- .../src/main/java/org/ray/api/id/BaseId.java | 2 +- .../ray/api/options/ActorCreationOptions.java | 15 +- java/dependencies.bzl | 1 - ...modify_generated_java_flatbuffers_files.py | 20 +- java/runtime/pom.xml | 5 - .../org/ray/runtime/AbstractRayRuntime.java | 9 +- .../java/org/ray/runtime/gcs/GcsClient.java | 69 ++- .../runtime/objectstore/ObjectStoreProxy.java | 12 +- .../ray/runtime/raylet/RayletClientImpl.java | 18 +- .../org/ray/runtime/runner/RunManager.java | 3 - .../java/org/ray/runtime/task/TaskSpec.java | 8 +- .../src/main/java/org/ray/api/TestUtils.java | 15 - .../org/ray/api/test/DynamicResourceTest.java | 17 +- .../main/java/org/ray/api/test/WaitTest.java | 5 - .../ray/api/test/WorkerJvmOptionsTest.java | 31 -- python/ray/experimental/signal.py | 14 +- python/ray/gcs_utils.py | 71 +-- python/ray/monitor.py | 33 +- python/ray/rllib/agents/a3c/a3c.py | 4 - python/ray/rllib/agents/impala/impala.py | 1 - .../ray/rllib/agents/impala/vtrace_policy.py | 8 +- python/ray/rllib/agents/qmix/qmix_policy.py | 2 - python/ray/rllib/policy/tf_policy.py | 38 +- python/ray/rllib/tests/test_optimizers.py | 10 +- python/ray/services.py | 3 - python/ray/state.py | 230 ++++++---- python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/conftest.py | 8 - python/ray/tests/test_actor.py | 2 +- python/ray/tests/test_basic.py | 14 +- python/ray/tests/test_failure.py | 5 +- python/ray/tests/test_signal.py | 33 -- .../ray/tune/analysis/experiment_analysis.py | 94 +--- python/ray/tune/examples/mnist_pytorch.py | 273 +++++++----- python/ray/tune/examples/track_example.py | 4 +- python/ray/tune/examples/tune_mnist_keras.py | 8 +- python/ray/tune/examples/utils.py | 36 +- python/ray/tune/experiment.py | 8 - python/ray/tune/integration/__init__.py | 0 python/ray/tune/integration/keras.py | 34 -- python/ray/tune/schedulers/__init__.py | 6 +- python/ray/tune/schedulers/async_hyperband.py | 2 - .../tune/tests/test_experiment_analysis.py | 62 ++- python/ray/tune/tests/test_trial_runner.py | 8 - python/ray/tune/trial.py | 25 +- python/ray/tune/tune.py | 11 +- python/ray/utils.py | 8 +- python/ray/worker.py | 40 +- python/setup.py | 1 - src/ray/common/constants.h | 2 - src/ray/gcs/client.cc | 4 + src/ray/gcs/client.h | 6 + src/ray/gcs/client_test.cc | 353 ++++++++------- src/ray/gcs/format/gcs.fbs | 286 +++++++++++- src/ray/gcs/redis_context.h | 15 +- src/ray/gcs/redis_module/ray_redis_module.cc | 209 +++++---- src/ray/gcs/tables.cc | 417 ++++++++++-------- src/ray/gcs/tables.h | 136 +++--- src/ray/object_manager/object_directory.cc | 34 +- src/ray/object_manager/object_manager.cc | 49 +- src/ray/object_manager/object_manager.h | 4 +- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/protobuf/gcs.proto | 280 ------------ src/ray/raylet/actor_registration.cc | 51 ++- src/ray/raylet/actor_registration.h | 24 +- src/ray/raylet/lineage_cache.cc | 37 +- src/ray/raylet/lineage_cache.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 262 ++++++----- src/ray/raylet/node_manager.h | 31 +- 
src/ray/raylet/raylet.cc | 24 +- src/ray/raylet/raylet.h | 2 - src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 - src/ray/raylet/reconstruction_policy_test.cc | 42 +- src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 - .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/task_spec.cc | 12 +- src/ray/raylet/task_spec.h | 6 +- src/ray/raylet/worker_pool.cc | 100 +---- src/ray/raylet/worker_pool.h | 56 +-- src/ray/raylet/worker_pool_test.cc | 65 +-- src/ray/rpc/grpc_server.cc | 17 +- src/ray/rpc/grpc_server.h | 77 ++-- src/ray/rpc/node_manager_server.h | 25 +- src/ray/rpc/server_call.h | 26 +- src/ray/rpc/util.h | 13 - 103 files changed, 2039 insertions(+), 2338 deletions(-) delete mode 100644 java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java delete mode 100644 python/ray/tune/integration/__init__.py delete mode 100644 python/ray/tune/integration/keras.py delete mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index bc9e6bcd8006..da36eec0cf57 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,55 +1,22 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") -load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# === Begin of protobuf definitions === - -proto_library( - name = "gcs_proto", - srcs = ["src/ray/protobuf/gcs.proto"], - visibility = ["//java:__subpackages__"], -) - -cc_proto_library( - name = "gcs_cc_proto", - deps = [":gcs_proto"], -) - -python_proto_compile( - name = "gcs_py_proto", - deps = [":gcs_proto"], -) - -proto_library( - name = "node_manager_proto", - srcs = ["src/ray/protobuf/node_manager.proto"], -) - -cc_proto_library( - name = "node_manager_cc_proto", - deps = ["node_manager_proto"], -) - -# === End of protobuf definitions === - # Node manager gRPC lib. -cc_grpc_library( - name = "node_manager_cc_grpc", - srcs = [":node_manager_proto"], - grpc_only = True, - deps = [":node_manager_cc_proto"], +grpc_proto_library( + name = "node_manager_grpc_lib", + srcs = ["src/ray/protobuf/node_manager.proto"], ) # Node manager server and client. 
cc_library( - name = "node_manager_rpc", + name = "node_manager_rpc_lib", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -58,7 +25,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_cc_grpc", + ":node_manager_grpc_lib", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -147,7 +114,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc", + ":node_manager_rpc_lib", ":object_manager", ":ray_common", ":ray_util", @@ -455,11 +422,9 @@ cc_library( "src/ray/gcs/format", ], deps = [ - ":gcs_cc_proto", ":gcs_fbs", ":hiredis", ":node_manager_fbs", - ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -590,6 +555,46 @@ filegroup( visibility = ["//java:__subpackages__"], ) +flatbuffer_py_library( + name = "python_gcs_fbs", + srcs = [ + ":gcs_fbs_file", + ], + outs = [ + "ActorCheckpointIdData.py", + "ActorState.py", + "ActorTableData.py", + "Arg.py", + "ClassTableData.py", + "ClientTableData.py", + "ConfigTableData.py", + "CustomSerializerData.py", + "DriverTableData.py", + "EntryType.py", + "ErrorTableData.py", + "ErrorType.py", + "FunctionTableData.py", + "GcsEntry.py", + "HeartbeatBatchTableData.py", + "HeartbeatTableData.py", + "Language.py", + "ObjectTableData.py", + "ProfileEvent.py", + "ProfileTableData.py", + "RayResource.py", + "ResourcePair.py", + "SchedulingState.py", + "TablePrefix.py", + "TablePubsub.py", + "TaskInfo.py", + "TaskLeaseData.py", + "TaskReconstructionData.py", + "TaskTableData.py", + "TaskTableTestAndUpdate.py", + ], + out_prefix = "python/ray/core/generated/", +) + flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -674,7 +679,6 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ - ":gcs_cc_proto", ":ray_common", ], ) @@ -684,7 +688,7 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:gcs_py_proto", + "//:python_gcs_fbs", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -706,13 +710,11 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && + for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && - for f in $(locations //:gcs_py_proto); do - cp -f $$f $$WORK_DIR/python/ray/core/generated/; - done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index eda88bece7d2..3e1e1838a59a 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,8 +4,6 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") -load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") -load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -15,6 +13,4 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() - java_proto_compile() - python_proto_compile() diff --git 
a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index aa322654cf9f..e6dc21585699 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,14 +105,7 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", + "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", ], - strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", - ) - - http_archive( - name = "build_stack_rules_proto", - urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], - sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", - strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", + strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", ) diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index f25d32df22a1..7962b21075c0 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl +pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 6b890d7d371c..6154fe70d4f6 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -78,16 +78,16 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --smoke-test # Runs only on Python3 -$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/tune/examples/nevergrad_example.py \ - --smoke-test +# docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ +# python3 /ray/python/ray/tune/examples/nevergrad_example.py \ +# --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test + python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test --no-cuda $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ diff --git a/doc/source/conf.py b/doc/source/conf.py index 5cf6b01217f9..98fb3e0d02dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,7 +23,20 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.gcs_pb2", + "ray.core.generated.ActorCheckpointIdData", + "ray.core.generated.ClientTableData", + "ray.core.generated.DriverTableData", + "ray.core.generated.EntryType", + "ray.core.generated.ErrorTableData", + "ray.core.generated.ErrorType", + "ray.core.generated.GcsEntry", + "ray.core.generated.HeartbeatBatchTableData", + 
"ray.core.generated.HeartbeatTableData", + "ray.core.generated.Language", + "ray.core.generated.ObjectTableData", + "ray.core.generated.ProfileTableData", + "ray.core.generated.TablePrefix", + "ray.core.generated.TablePubsub", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index e8ce405d9457..281ccbd6107e 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -355,12 +355,6 @@ Then, after you run a experiment, you can visualize your experiment with TensorB $ tensorboard --logdir=~/ray_results/my_experiment -If you are running Ray on a remote multi-user cluster where you do not have sudo access, you can run the following commands to make sure tensorboard is able to write to the tmp directory: - -.. code-block:: bash - - $ export TMPDIR=/tmp/$USER; mkdir -p $TMPDIR; tensorboard --logdir=~/ray_results - .. image:: ray-tune-tensorboard.png To use rllab's VisKit (you may have to install some dependencies), run: diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index db8f28c85f86..c21430c627a4 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -12,7 +12,7 @@ RUN apt-get update \ && apt-get clean \ && echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \ && wget \ - --quiet 'https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh' \ + --quiet 'https://repo.continuum.io/archive/Anaconda2-5.2.0-Linux-x86_64.sh' \ -O /tmp/anaconda.sh \ && /bin/bash /tmp/anaconda.sh -b -p /opt/conda \ && rm /tmp/anaconda.sh \ diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index bafcdf35e628..6883c5a64a0e 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,14 +5,11 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev -# The following is needed to support TensorFlow 1.14 -RUN conda remove -y --force wrapt RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -RUN pip install --upgrade nevergrad +# RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize -RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 376fe5340fd9..1d174ed72f92 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 77cf390493d6..6e098d5218f6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,20 +4,15 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. 
-RUN conda install -y -c anaconda wrapt=1.11.1 -RUN conda install -y -c anaconda numpy=1.16.4 -RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 +RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev -# The following is needed to support TensorFlow 1.14 -RUN conda remove -y --force wrapt RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -RUN pip install --upgrade nevergrad +# RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize -RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch # RUN mkdir -p /root/.ssh/ @@ -25,6 +20,6 @@ RUN conda install pytorch-cpu torchvision-cpu -c pytorch # We port the source code in so that we run the most up-to-date stress tests. ADD ray.tar /ray ADD git-rev /ray/git-rev -RUN python /ray/python/ray/setup-dev.py --yes +RUN python /ray/python/ray/rllib/setup-rllib-dev.py --yes WORKDIR /ray diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 4960434af180..80ccabccfc12 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,5 +1,4 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") -load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -51,7 +50,6 @@ define_java_module( name = "runtime", additional_srcs = [ ":generate_java_gcs_fbs", - ":gcs_java_proto", ], additional_resources = [ ":java_native_deps", @@ -70,7 +68,6 @@ define_java_module( "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", - "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -151,16 +148,38 @@ java_binary( ], ) -java_proto_compile( - name = "gcs_java_proto", - deps = ["@//:gcs_proto"], -) - flatbuffers_generated_files = [ + "ActorCheckpointData.java", + "ActorCheckpointIdData.java", + "ActorState.java", + "ActorTableData.java", "Arg.java", + "ClassTableData.java", + "ClientTableData.java", + "ConfigTableData.java", + "CustomSerializerData.java", + "DriverTableData.java", + "EntryType.java", + "ErrorTableData.java", + "ErrorType.java", + "FunctionTableData.java", + "GcsEntry.java", + "HeartbeatBatchTableData.java", + "HeartbeatTableData.java", "Language.java", - "TaskInfo.java", + "ObjectTableData.java", + "ProfileEvent.java", + "ProfileTableData.java", + "RayResource.java", "ResourcePair.java", + "SchedulingState.java", + "TablePrefix.java", + "TablePubsub.java", + "TaskInfo.java", + "TaskLeaseData.java", + "TaskReconstructionData.java", + "TaskTableData.java", + "TaskTableTestAndUpdate.java", ] flatbuffer_java_library( @@ -179,7 +198,7 @@ genrule( cmd = """ for f in $(locations //java:java_gcs_fbs); do chmod +w $$f - mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated done python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. 
""", @@ -202,10 +221,8 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":gcs_java_proto", - ":generate_java_gcs_fbs", ":java_native_deps", - ":copy_pom_file", + ":generate_java_gcs_fbs", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -220,15 +237,10 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy protobuf-generated files. + # Copy flatbuffers-generated files GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR - for f in $(locations //java:gcs_java_proto); do - unzip $$f - mv org/ray/runtime/generated/* $$GENERATED_DIR - done - # Copy flatbuffers-generated files for f in $(locations //java:generate_java_gcs_fbs); do cp $$f $$GENERATED_DIR done @@ -238,7 +250,6 @@ genrule( echo $$(date) > $@ """, local = 1, - tags = ["no-cache"], ) genrule( diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index c13f0436f94d..e08955d5a93e 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -48,7 +48,7 @@ public boolean isNil() { break; } } - isNilCache = localIsNil; + isNilCache = localIsNil; } return isNilCache; } diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 2e14ca8584dd..d1e92f7bb9e9 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -13,14 +13,9 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - public final String jvmOptions; - - private ActorCreationOptions(Map resources, - int maxReconstructions, - String jvmOptions) { + private ActorCreationOptions(Map resources, int maxReconstructions) { super(resources); this.maxReconstructions = maxReconstructions; - this.jvmOptions = jvmOptions; } /** @@ -30,7 +25,6 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; - private String jvmOptions = ""; public Builder setResources(Map resources) { this.resources = resources; @@ -42,13 +36,8 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } - public Builder setJvmOptions(String jvmOptions) { - this.jvmOptions = jvmOptions; - return this; - } - public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); + return new ActorCreationOptions(resources, maxReconstructions); } } diff --git a/java/dependencies.bzl b/java/dependencies.bzl index ef667137562b..7c716166d399 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,7 +6,6 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", - "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index 5bf62e56d7e4..c1b723f25f8d 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,6 +4,7 @@ import os import sys + """ This script is used for modifying the generated java flatbuffer files for the reason: 
The package declaration in Java is different @@ -20,18 +21,19 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_package(file): +def add_new_line(file, line_num, text): with open(file, "r") as file_handler: lines = file_handler.readlines() + if (line_num <= 0) or (line_num > len(lines) + 1): + return False - if "FlatBuffers" not in lines[0]: - return - - lines.insert(1, PACKAGE_DECLARATION + os.linesep) + lines.insert(line_num - 1, text + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) + return True + def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -39,11 +41,15 @@ def add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - add_package(full_name) + success = add_new_line(full_name, 2, PACKAGE_DECLARATION) + if not success: + raise RuntimeError("Failed to add package declarations, " + "file name is %s" % full_name) if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, + "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index e13dd95f927f..c75e2eeef13f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,11 +41,6 @@ guava 27.0.1-jre - - com.google.protobuf - protobuf-java - 3.8.0 - com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 26a8d6e541ba..fbd03bf10483 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -35,7 +35,6 @@ import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.IdUtil; -import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -364,13 +363,8 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes } int maxActorReconstruction = 0; - List dynamicWorkerOptions = ImmutableList.of(); if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; - String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; - if (!StringUtil.isNullOrEmpty(jvmOptions)) { - dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); - } } TaskLanguage language; @@ -399,8 +393,7 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes numReturns, resources, language, - functionDescriptor, - dynamicWorkerOptions + functionDescriptor ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 17c248ed0a57..431b48ded58c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import com.google.protobuf.InvalidProtocolBufferException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import 
org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; -import org.ray.runtime.generated.Gcs.ClientTableData; -import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; -import org.ray.runtime.generated.Gcs.TablePrefix; +import org.ray.runtime.generated.ActorCheckpointIdData; +import org.ray.runtime.generated.ClientTableData; +import org.ray.runtime.generated.EntryType; +import org.ray.runtime.generated.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.CLIENT.toString(); + final String prefix = TablePrefix.name(TablePrefix.CLIENT); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,42 +63,36 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = null; - try { - data = ClientTableData.parseFrom(result); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Received invalid protobuf data from GCS."); - } - final UniqueId clientId = UniqueId - .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); + ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); + final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); - if (data.getEntryType() == EntryType.INSERTION) { + if (data.entryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. Preconditions.checkState( - data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); - for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { - resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); + data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.getNodeManagerAddress(), true, resources); + clientId, data.nodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { - nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); } - } else if (data.getEntryType() == EntryType.RES_DELETE) { + } else if (data.entryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { - nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); + for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { + nodeInfo.resources.remove(data.resourcesTotalLabel(i)); } } else { // Code path of node deletion. 
- Preconditions.checkState(data.getEntryType() == EntryType.DELETION); + Preconditions.checkState(data.entryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -113,7 +107,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); + TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -121,7 +115,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. */ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -132,26 +126,19 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); + final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = null; - try { - data = ActorCheckpointIdData.parseFrom(result); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException("Received invalid protobuf data from GCS."); - } - UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; - for (int i = 0; i < checkpointIds.length; i++) { - checkpointIds[i] = UniqueId - .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); - } + ActorCheckpointIdData data = + ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); + UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( + data.checkpointIdsAsByteBuffer()); for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 1a7e4701c22b..f9e310249a35 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.Gcs.ErrorType; +import org.ray.runtime.generated.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String - .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String - .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); + private static final byte[] 
WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) + .getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) + .getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index c369e6f2cab8..01b9e4675016 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -190,16 +190,9 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); - - // Deserialize dynamic worker options. - List dynamicWorkerOptions = new ArrayList<>(); - for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { - dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); - } - return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -282,12 +275,6 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); } - int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; - for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { - dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); - } - int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); - int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, @@ -306,8 +293,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { requiredResourcesOffset, requiredPlacementResourcesOffset, language, - functionDescriptorOffset, - dynamicWorkerOptionsOffset); + functionDescriptorOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 773499fcf5cf..15240e43e234 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -319,9 +319,6 @@ private String buildWorkerCommandRaylet() { cmd.addAll(rayConfig.jvmParameters); - // jvm options - cmd.add("RAY_WORKER_OPTION_0"); - // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 060ca6fff4c3..3473a9bdb3cc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -63,8 +63,6 @@ public class TaskSpec { // Language of this task. 
public final TaskLanguage language; - public final List dynamicWorkerOptions; - // Descriptor of the remote function. // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language // is Python, the type is PyFunctionDescriptor. @@ -95,8 +93,7 @@ public TaskSpec( int numReturns, Map resources, TaskLanguage language, - FunctionDescriptor functionDescriptor, - List dynamicWorkerOptions) { + FunctionDescriptor functionDescriptor) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -109,8 +106,6 @@ public TaskSpec( this.newActorHandles = newActorHandles; this.args = args; this.numReturns = numReturns; - this.dynamicWorkerOptions = dynamicWorkerOptions; - returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); @@ -162,7 +157,6 @@ public String toString() { ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + - ", dynamicWorkerOptions=" + dynamicWorkerOptions + ", executionDependencies=" + executionDependencies + '}'; } diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 3636c93e4909..9b3bbf233856 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,10 +1,8 @@ package org.ray.api; import java.util.function.Supplier; -import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; -import org.testng.Assert; import org.testng.SkipException; public class TestUtils { @@ -44,17 +42,4 @@ public static boolean waitForCondition(Supplier condition, int timeoutM } return false; } - - @RayRemote - private static String hi() { - return "hi"; - } - - /** - * Warm up the cluster. - */ - public static void warmUpCluster() { - RayObject obj = Ray.call(TestUtils::hi); - Assert.assertEquals(obj.get(), "hi"); - } } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index 71766c6cf2bf..79b3eba0ed13 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,10 +23,6 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); - - // Call a task in advance to warm up the cluster to avoid being too slow to start workers. - TestUtils.warmUpCluster(); - CallOptions op1 = new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); @@ -34,21 +30,16 @@ public void testSetResource() { Assert.assertEquals(result.getReady().size(), 0); Ray.setResource("A", 10.0); - boolean resourceReady = TestUtils.waitForCondition(() -> { - List nodes = Ray.getRuntimeContext().getAllNodeInfo(); - if (nodes.size() != 1) { - return false; - } - return (0 == Double.compare(10.0, nodes.get(0).resources.get("A"))); - }, 2000); - Assert.assertTrue(resourceReady); + // Assert node info. + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + Assert.assertEquals(nodes.size(), 1); + Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); // Assert ray call result. 
result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 1); Assert.assertEquals(Ray.get(obj.getId()), "hi"); - } } diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index bccc50a50bdf..e82b99d364ba 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -5,7 +5,6 @@ import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.TestUtils; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; import org.testng.Assert; @@ -29,9 +28,6 @@ private static String delayedHi() { } private static void testWait() { - // Call a task in advance to warm up the cluster to avoid being too slow to start workers. - TestUtils.warmUpCluster(); - RayObject obj1 = Ray.call(WaitTest::hi); RayObject obj2 = Ray.call(WaitTest::delayedHi); @@ -75,5 +71,4 @@ public void testWaitForEmpty() { Assert.assertTrue(true); } } - } diff --git a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java deleted file mode 100644 index 90a2817a8366..000000000000 --- a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java +++ /dev/null @@ -1,31 +0,0 @@ -package org.ray.api.test; - -import org.ray.api.Ray; -import org.ray.api.RayActor; -import org.ray.api.RayObject; -import org.ray.api.TestUtils; -import org.ray.api.annotation.RayRemote; -import org.ray.api.options.ActorCreationOptions; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class WorkerJvmOptionsTest extends BaseTest { - - @RayRemote - public static class Echo { - String getOptions() { - return System.getProperty("test.suffix"); - } - } - - @Test - public void testJvmOptions() { - TestUtils.skipTestUnderSingleProcess(); - ActorCreationOptions options = new ActorCreationOptions.Builder() - .setJvmOptions("-Dtest.suffix=suffix") - .createActorCreationOptions(); - RayActor actor = Ray.createActor(Echo::new, options); - RayObject obj = Ray.call(Echo::getOptions, actor); - Assert.assertEquals(obj.get(), "suffix"); - } -} diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index 25ec072d3fc7..f2a0d81ca343 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -2,8 +2,6 @@ from __future__ import division from __future__ import print_function -import logging - from collections import defaultdict import ray @@ -15,8 +13,6 @@ # in node_manager.cc ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" -logger = logging.getLogger(__name__) - class Signal(object): """Base class for Ray signals.""" @@ -129,16 +125,10 @@ def receive(sources, timeout=None): for s in sources: task_id_to_sources[_get_task_id(s).hex()].append(s) - if timeout < 1e-3: - logger.warning("Timeout too small. Using 1ms minimum") - timeout = 1e-3 - - timeout_ms = int(1000 * timeout) - # Construct the redis query. query = "XREAD BLOCK " - # redis expects ms. - query += str(timeout_ms) + # Multiply by 1000x since timeout is in sec and redis expects ms. 
+ query += str(1000 * timeout) query += " STREAMS " query += " ".join([task_id for task_id in task_id_to_sources]) query += " " diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index ba72e96f41db..cadd197ec73f 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,39 +2,38 @@ from __future__ import division from __future__ import print_function -from ray.core.generated.ray.protocol.Task import Task +import flatbuffers +import ray.core.generated.ErrorTableData + +from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData +from ray.core.generated.ClientTableData import ClientTableData +from ray.core.generated.DriverTableData import DriverTableData +from ray.core.generated.ErrorTableData import ErrorTableData +from ray.core.generated.GcsEntry import GcsEntry +from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData +from ray.core.generated.HeartbeatTableData import HeartbeatTableData +from ray.core.generated.Language import Language +from ray.core.generated.ObjectTableData import ObjectTableData +from ray.core.generated.ProfileTableData import ProfileTableData +from ray.core.generated.TablePrefix import TablePrefix +from ray.core.generated.TablePubsub import TablePubsub -from ray.core.generated.gcs_pb2 import ( - ActorCheckpointIdData, - ClientTableData, - DriverTableData, - ErrorTableData, - ErrorType, - GcsEntry, - HeartbeatBatchTableData, - HeartbeatTableData, - ObjectTableData, - ProfileTableData, - TablePrefix, - TablePubsub, - TaskTableData, -) +from ray.core.generated.ray.protocol.Task import Task __all__ = [ "ActorCheckpointIdData", "ClientTableData", "DriverTableData", "ErrorTableData", - "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", + "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", "TablePubsub", "Task", - "TaskTableData", "construct_error_message", ] @@ -43,16 +42,13 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str( - TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str( - TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in -# gcs.proto. +# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -74,9 +70,22 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. 
""" - data = ErrorTableData() - data.driver_id = driver_id.binary() - data.type = error_type - data.error_message = message - data.timestamp = timestamp - return data.SerializeToString() + builder = flatbuffers.Builder(0) + driver_offset = builder.CreateString(driver_id.binary()) + error_type_offset = builder.CreateString(error_type) + message_offset = builder.CreateString(message) + + ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) + ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( + builder, driver_offset) + ray.core.generated.ErrorTableData.ErrorTableDataAddType( + builder, error_type_offset) + ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( + builder, message_offset) + ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( + builder, timestamp) + error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( + builder) + builder.Finish(error_data_offset) + + return bytes(builder.Output()) diff --git a/python/ray/monitor.py b/python/ray/monitor.py index 35597ef231e3..c9e0424b3eb8 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,26 +101,28 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) - heartbeat_data = gcs_entries.entries[0] + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) + heartbeat_data = gcs_entries.Entries(0) - message = ray.gcs_utils.HeartbeatBatchTableData.FromString( - heartbeat_data) + message = (ray.gcs_utils.HeartbeatBatchTableData. + GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) - for heartbeat_message in message.batch: - num_resources = len(heartbeat_message.resources_available_label) + for j in range(message.BatchLength()): + heartbeat_message = message.Batch(j) + + num_resources = heartbeat_message.ResourcesTotalLabelLength() static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.resources_available_label[i] - static = heartbeat_message.resources_total_label[i] + dyn = heartbeat_message.ResourcesAvailableLabel(i) + static = heartbeat_message.ResourcesTotalLabel(i) dynamic_resources[dyn] = ( - heartbeat_message.resources_available_capacity[i]) + heartbeat_message.ResourcesAvailableCapacity(i)) static_resources[static] = ( - heartbeat_message.resources_total_capacity[i]) + heartbeat_message.ResourcesTotalCapacity(i)) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) + client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -205,10 +207,11 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. 
""" - gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) - driver_data = gcs_entries.entries[0] - message = ray.gcs_utils.DriverTableData.FromString(driver_data) - driver_id = message.driver_id + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) + driver_data = gcs_entries.Entries(0) + message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( + driver_data, 0) + driver_id = message.DriverId() logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index d320b9636881..c269df2fc6e5 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -48,10 +48,6 @@ def get_policy_class(config): def validate_config(config): if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") - if config["sample_async"] and config["use_pytorch"]: - raise ValueError( - "The sample_async option is not supported with use_pytorch: " - "Multithreading can be lead to crashes if used with pytorch.") def make_async_optimizer(workers, config): diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index 23b5ada167db..b9699888bfaf 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -75,7 +75,6 @@ # balancing the three losses "vf_loss_coeff": 0.5, "entropy_coeff": 0.01, - "entropy_schedule": None, # use fake (infinite speed) sampler for testing "_fake_sampler": False, diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index 9860783238a0..7fd137bae08b 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -14,7 +14,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy, \ - LearningRateSchedule, EntropyCoeffSchedule + LearningRateSchedule from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override @@ -126,7 +126,7 @@ def postprocess_trajectory(self, return sample_batch -class VTraceTFPolicy(LearningRateSchedule, EntropyCoeffSchedule, VTracePostprocessing, TFPolicy): +class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy): def __init__(self, observation_space, action_space, @@ -249,9 +249,6 @@ def make_time_major(tensor, drop_last=False): loss_actions = actions if is_multidiscrete else tf.expand_dims( actions, axis=1) - EntropyCoeffSchedule.__init__(self, self.config["entropy_coeff"], - self.config["entropy_schedule"]) - # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
with tf.name_scope('vtrace_loss'): self.loss = VTraceLoss( @@ -336,7 +333,6 @@ def make_time_major(tensor, drop_last=False): self.stats_fetches = { LEARNER_STATS_KEY: dict({ "cur_lr": tf.cast(self.cur_lr, tf.float64), - "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, "grad_gnorm": tf.global_norm(self._grads), diff --git a/python/ray/rllib/agents/qmix/qmix_policy.py b/python/ray/rllib/agents/qmix/qmix_policy.py index 99045899684b..26ec387de004 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -204,8 +204,6 @@ def __init__(self, obs_space, action_space, config): # Setup optimizer self.params = list(self.model.parameters()) - if self.mixer: - self.params += list(self.mixer.parameters()) self.loss = QMixLoss(self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"]) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index 591363a793be..abc5cf546184 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -2,21 +2,21 @@ from __future__ import division from __future__ import print_function +import os import errno import logging -import os - import numpy as np + import ray import ray.experimental.tf_utils -from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils import try_import_tf +from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder +from ray.rllib.utils import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -416,7 +416,7 @@ def _build_compute_actions(self, if len(self._state_inputs) != len(state_batches): raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". - format(self._state_inputs, state_batches)) + format(self._state_inputs, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: @@ -443,7 +443,7 @@ def _build_apply_gradients(self, builder, gradients): if len(gradients) != len(self._grads): raise ValueError( "Unexpected number of gradients to apply, got {} for {}". 
- format(gradients, self._grads)) + format(gradients, self._grads)) builder.add_feed_dict({self._is_training: True}) builder.add_feed_dict(dict(zip(self._grads, gradients))) fetches = builder.add_fetches([self._apply_op]) @@ -473,9 +473,9 @@ def _get_loss_inputs_dict(self, batch): feed_dict = {} if self._batch_divisibility_req > 1: meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent else: meets_divisibility_reqs = True @@ -551,23 +551,3 @@ def on_global_var_update(self, global_vars): @override(TFPolicy) def optimizer(self): return tf.train.AdamOptimizer(self.cur_lr) - - -@DeveloperAPI -class EntropyCoeffSchedule(object): - """Mixin for TFPolicy that adds entropy coeff decay.""" - - @DeveloperAPI - def __init__(self, entropy_coeff, entropy_schedule): - self.entropy_coeff = tf.get_variable("entropy_coeff", initializer=entropy_coeff) - self._entropy_schedule = entropy_schedule - - @override(Policy) - def on_global_var_update(self, global_vars): - super(EntropyCoeffSchedule, self).on_global_var_update(global_vars) - if self._entropy_schedule is not None: - self.entropy_coeff.load( - self.config['entropy_coeff'] * - (1 - global_vars['timestep'] / - self.config['entropy_schedule']), - session=self._sess) diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index d27270c20965..a87a295ccf1d 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -125,14 +125,14 @@ def testSimple(self): def testMultiGPU(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) - optimizer = AsyncSamplesOptimizer(workers, num_gpus=1, _fake_gpus=True) + optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiGPUParallelLoad(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - workers, num_gpus=1, num_data_loader_buffers=1, _fake_gpus=True) + workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiplePasses(self): @@ -211,21 +211,21 @@ def testRejectBadConfigs(self): num_data_loader_buffers=2, minibatch_buffer_size=4)) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=1, + num_gpus=2, train_batch_size=100, sample_batch_size=50, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=1, + num_gpus=2, train_batch_size=100, sample_batch_size=25, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=1, + num_gpus=2, train_batch_size=100, sample_batch_size=74, _fake_gpus=True) diff --git a/python/ray/services.py b/python/ray/services.py index ff4111b2c258..66d4069820d0 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1245,7 +1245,6 @@ def build_java_worker_command( assert java_worker_options is not None command = "java " - if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) @@ -1266,8 +1265,6 @@ def build_java_worker_command( # Put `java_worker_options` in the last, so it can overwrite the # above options. 
command += java_worker_options + " " - - command += "RAY_WORKER_OPTION_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/python/ray/state.py b/python/ray/state.py index 35f97cd65f5e..14ba49987ec4 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,11 @@ import ray from ray.function_manager import FunctionDescriptor +import ray.gcs_utils -from ray import ( - gcs_utils, - services, -) +from ray.ray_constants import ID_SIZE +from ray import services +from ray.core.generated.EntryType import EntryType from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +31,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", - NIL_CLIENT_ID) + message = redis_client.execute_command("RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.CLIENT, + "", NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,31 +41,36 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = gcs_utils.GcsEntry.FromString(message) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for entry in gcs_entry.entries: - client = gcs_utils.ClientTableData.FromString(entry) + for i in range(gcs_entry.EntriesLength()): + client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( + gcs_entry.Entries(i), 0)) resources = { - client.resources_total_label[i]: client.resources_total_capacity[i] - for i in range(len(client.resources_total_label)) + decode(client.ResourcesTotalLabel(i)): + client.ResourcesTotalCapacity(i) + for i in range(client.ResourcesTotalLabelLength()) } - client_id = ray.utils.binary_to_hex(client.client_id) + client_id = ray.utils.binary_to_hex(client.ClientId()) - if client.entry_type == gcs_utils.ClientTableData.INSERTION: + if client.EntryType() == EntryType.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.entry_type, - "NodeManagerAddress": client.node_manager_address, - "NodeManagerPort": client.node_manager_port, - "ObjectManagerPort": client.object_manager_port, - "ObjectStoreSocketName": client.object_store_socket_name, - "RayletSocketName": client.raylet_socket_name, + "EntryType": client.EntryType(), + "NodeManagerAddress": decode( + client.NodeManagerAddress(), allow_none=True), + "NodeManagerPort": client.NodeManagerPort(), + "ObjectManagerPort": client.ObjectManagerPort(), + "ObjectStoreSocketName": decode( + client.ObjectStoreSocketName(), allow_none=True), + "RayletSocketName": decode( + client.RayletSocketName(), allow_none=True), "Resources": resources } @@ -74,23 +79,22 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - is_deletion = (node_info[client_id]["EntryType"] != - gcs_utils.ClientTableData.DELETION) - assert is_deletion, "Unexpected updation of deleted client." 
+ assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( + "Unexpected updation of deleted client.") res_map = node_info[client_id]["Resources"] - if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: + if client.EntryType() == EntryType.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: + elif client.EntryType() == EntryType.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.entry_type == gcs_utils.ClientTableData.DELETION: + elif client.EntryType() == EntryType.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.entry_type)) + client.EntryType())) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.entry_type + node_info[client_id]["EntryType"] = client.EntryType() # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -240,19 +244,20 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("OBJECT"), - "", object_id.binary()) + ray.gcs_utils.TablePrefix.OBJECT, "", + object_id.binary()) if message is None: return {} - gcs_entry = gcs_utils.GcsEntry.FromString(message) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - assert len(gcs_entry.entries) > 0 + assert gcs_entry.EntriesLength() > 0 - entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) + entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( + gcs_entry.Entries(0), 0) object_info = { - "DataSize": entry.object_size, - "Manager": entry.manager, + "DataSize": entry.ObjectSize(), + "Manager": entry.Manager(), } return object_info @@ -273,9 +278,10 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") + object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + + "*") object_ids_binary = { - key[len(gcs_utils.TablePrefix_OBJECT_string):] + key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -295,18 +301,17 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. 
""" assert isinstance(task_id, ray.TaskID) - message = self._execute_command( - task_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) + message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + "", task_id.binary()) if message is None: return {} - gcs_entries = gcs_utils.GcsEntry.FromString(message) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + + assert gcs_entries.EntriesLength() == 1 - assert len(gcs_entries.entries) == 1 - task_table_data = gcs_utils.TaskTableData.FromString( - gcs_entries.entries[0]) - task_table_message = gcs_utils.Task.GetRootAsTask( - task_table_data.task, 0) + task_table_message = ray.gcs_utils.Task.GetRootAsTask( + gcs_entries.Entries(0), 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -363,9 +368,9 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] @@ -375,6 +380,27 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results + def function_table(self, function_id=None): + """Fetch and parse the function table. + + Returns: + A dictionary that maps function IDs to information about the + function. + """ + self._check_connected() + function_table_keys = self.redis_client.keys( + ray.gcs_utils.FUNCTION_PREFIX + "*") + results = {} + for key in function_table_keys: + info = self.redis_client.hgetall(key) + function_info_parsed = { + "DriverID": binary_to_hex(info[b"driver_id"]), + "Module": decode(info[b"module"]), + "Name": decode(info[b"name"]) + } + results[binary_to_hex(info[b"function_id"])] = function_info_parsed + return results + def client_table(self): """Fetch and parse the Redis DB client table. @@ -397,32 +423,37 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. 
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("PROFILE"), - "", batch_id.binary()) + ray.gcs_utils.TablePrefix.PROFILE, "", + batch_id.binary()) if message is None: return [] - gcs_entries = gcs_utils.GcsEntry.FromString(message) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) profile_events = [] - for entry in gcs_entries.entries: - profile_table_message = gcs_utils.ProfileTableData.FromString( - entry) + for i in range(gcs_entries.EntriesLength()): + profile_table_message = ( + ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( + gcs_entries.Entries(i), 0)) + + component_type = decode(profile_table_message.ComponentType()) + component_id = binary_to_hex(profile_table_message.ComponentId()) + node_ip_address = decode( + profile_table_message.NodeIpAddress(), allow_none=True) - component_type = profile_table_message.component_type - component_id = binary_to_hex(profile_table_message.component_id) - node_ip_address = profile_table_message.node_ip_address + for j in range(profile_table_message.ProfileEventsLength()): + profile_event_message = profile_table_message.ProfileEvents(j) - for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": profile_event_message.event_type, + "event_type": decode(profile_event_message.EventType()), "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.start_time, - "end_time": profile_event_message.end_time, - "extra_data": json.loads(profile_event_message.extra_data), + "start_time": profile_event_message.StartTime(), + "end_time": profile_event_message.EndTime(), + "extra_data": json.loads( + decode(profile_event_message.ExtraData())), } profile_events.append(profile_event) @@ -431,10 +462,10 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + - "*") + profile_table_keys = self._keys( + ray.gcs_utils.TablePrefix_PROFILE_string + "*") batch_identifiers_binary = [ - key[len(gcs_utils.TablePrefix_PROFILE_string):] + key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -735,7 +766,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. 
- if client["EntryType"] != gcs_utils.ClientTableData.DELETION: + if client["EntryType"] != EntryType.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -745,7 +776,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) + if (client["EntryType"] != EntryType.DELETION) } def available_resources(self): @@ -769,7 +800,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -778,23 +809,24 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - gcs_utils.XRAY_HEARTBEAT_CHANNEL): + ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = gcs_utils.GcsEntry.FromString(data) - heartbeat_data = gcs_entries.entries[0] - message = gcs_utils.HeartbeatTableData.FromString( - heartbeat_data) + gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( + data, 0)) + heartbeat_data = gcs_entries.Entries(0) + message = (ray.gcs_utils.HeartbeatTableData. + GetRootAsHeartbeatTableData(heartbeat_data, 0)) # Calculate available resources for this client - num_resources = len(message.resources_available_label) + num_resources = message.ResourcesAvailableLabelLength() dynamic_resources = {} for i in range(num_resources): - resource_id = message.resources_available_label[i] + resource_id = decode(message.ResourcesAvailableLabel(i)) dynamic_resources[resource_id] = ( - message.resources_available_capacity[i]) + message.ResourcesAvailableCapacity(i)) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.client_id) + client_id = ray.utils.binary_to_hex(message.ClientId()) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -828,22 +860,23 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", + "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", driver_id.binary()) # If there are no errors, return early. 
if message is None: return [] - gcs_entries = gcs_utils.GcsEntry.FromString(message) + gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) error_messages = [] - for entry in gcs_entries.entries: - error_data = gcs_utils.ErrorTableData.FromString(entry) - assert driver_id.binary() == error_data.driver_id + for i in range(gcs_entries.EntriesLength()): + error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( + gcs_entries.Entries(i), 0) + assert driver_id.binary() == error_data.DriverId() error_message = { - "type": error_data.type, - "message": error_data.error_message, - "timestamp": error_data.timestamp, + "type": decode(error_data.Type()), + "message": decode(error_data.ErrorMessage()), + "timestamp": error_data.Timestamp(), } error_messages.append(error_message) return error_messages @@ -866,9 +899,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - gcs_utils.TablePrefix_ERROR_INFO_string + "*") + ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -890,23 +923,30 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), + ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, "", actor_id.binary(), ) if message is None: return None - gcs_entry = gcs_utils.GcsEntry.FromString(message) - entry = gcs_utils.ActorCheckpointIdData.FromString( - gcs_entry.entries[0]) + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + entry = ( + ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( + gcs_entry.Entries(0), 0)) + checkpoint_ids_str = entry.CheckpointIds() + num_checkpoints = len(checkpoint_ids_str) // ID_SIZE + assert len(checkpoint_ids_str) % ID_SIZE == 0 checkpoint_ids = [ - ray.ActorCheckpointID(checkpoint_id) - for checkpoint_id in entry.checkpoint_ids + ray.ActorCheckpointID( + checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) + for i in range(num_checkpoints) ] return { - "ActorID": ray.utils.binary_to_hex(entry.actor_id), + "ActorID": ray.utils.binary_to_hex(entry.ActorId()), "CheckpointIds": checkpoint_ids, - "Timestamps": list(entry.timestamps), + "Timestamps": [ + entry.Timestamps(i) for i in range(num_checkpoints) + ], } diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 76dfd3000b86..703c3a1420ed 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.gcs_utils import ClientTableData +from ray.core.generated.EntryType import EntryType logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == ClientTableData.INSERTION + if client["EntryType"] == EntryType.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index f7c93fd50c2e..2e670fb0a84d 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -56,14 +56,6 @@ def _ray_start(**kwargs): ray.shutdown() -# The following fixture will start ray with 0 cpu. 
-@pytest.fixture -def ray_start_no_cpu(request): - param = getattr(request, "param", {}) - with _ray_start(num_cpus=0, **param) as res: - yield res - - # The following fixture will start ray with 1 cpu. @pytest.fixture def ray_start_regular(request): diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index 932f7b090bf7..dd726e00f27b 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -842,7 +842,7 @@ def f(): assert actor_id not in resulting_ids -def test_actors_on_nodes_with_no_cpus(ray_start_no_cpu): +def test_actors_on_nodes_with_no_cpus(ray_start_regular): @ray.remote class Foo(object): def method(self): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 6b4bd754cd4d..7f1f78d1b5c4 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,17 +2736,15 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), - ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), - driver_id.binary(), error_data) + r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), + error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), - ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), - driver_id.binary(), error_data) + r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, + ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), + error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index a560e461f7a2..51b906695c2d 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,9 +493,8 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), - ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, - malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, + ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index 176fbd45bcaa..fe2e74379245 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -353,36 +353,3 @@ def f(sources): assert len(result_list) == 1 result_list = ray.get(f.remote([a])) assert len(result_list) == 1 - - -def test_non_integral_receive_timeout(ray_start_regular): - @ray.remote - def send_signal(value): - signal.send(UserSignal(value)) - - a = send_signal.remote(0) - # make sure send_signal had a chance to execute - ray.get(a) - - result_list = ray.experimental.signal.receive([a], timeout=0.1) - - assert len(result_list) == 1 - - -def test_small_receive_timeout(ray_start_regular): - """ Test that receive handles timeout smaller than the 1ms min - """ - # 0.1 ms - small_timeout = 1e-4 - - @ray.remote - def send_signal(value): - signal.send(UserSignal(value)) - - a = send_signal.remote(0) - # make sure send_signal had a chance to execute - ray.get(a) - - result_list = ray.experimental.signal.receive([a], timeout=small_timeout) - - assert 
len(result_list) == 1 diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index a3c246aba161..0164ec2b1a2e 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -47,14 +47,7 @@ class ExperimentAnalysis(object): >>> experiment_path="~/tune_results/my_exp") """ - def __init__(self, experiment_path, trials=None): - """Initializer. - - Args: - experiment_path (str): Path to where experiment is located. - trials (list|None): List of trials that can be accessed via - `analysis.trials`. - """ + def __init__(self, experiment_path): experiment_path = os.path.expanduser(experiment_path) if not os.path.isdir(experiment_path): raise TuneError( @@ -62,8 +55,7 @@ def __init__(self, experiment_path, trials=None): experiment_state_paths = glob.glob( os.path.join(experiment_path, "experiment_state*.json")) if not experiment_state_paths: - raise TuneError( - "No experiment state found in {}!".format(experiment_path)) + raise TuneError("No experiment state found!") experiment_filename = max( list(experiment_state_paths)) # if more than one, pick latest with open(os.path.join(experiment_path, experiment_filename)) as f: @@ -73,27 +65,10 @@ def __init__(self, experiment_path, trials=None): raise TuneError("Experiment state invalid; no checkpoints found.") self._checkpoints = self._experiment_state["checkpoints"] self._scrubbed_checkpoints = unnest_checkpoints(self._checkpoints) - self.trials = trials - self._dataframe = None - - def get_all_trial_dataframes(self): - trial_dfs = {} - for checkpoint in self._checkpoints: - logdir = checkpoint["logdir"] - progress = max(glob.glob(os.path.join(logdir, "progress.csv"))) - trial_dfs[checkpoint["trial_id"]] = pd.read_csv(progress) - return trial_dfs - - def dataframe(self, refresh=False): - """Returns a pandas.DataFrame object constructed from the trials. - Args: - refresh (bool): Clears the cache which may have an existing copy. - - """ - if self._dataframe is None or refresh: - self._dataframe = pd.DataFrame(self._scrubbed_checkpoints) - return self._dataframe + def dataframe(self): + """Returns a pandas.DataFrame object constructed from the trials.""" + return pd.DataFrame(self._scrubbed_checkpoints) def stats(self): """Returns a dictionary of the statistics of the experiment.""" @@ -112,45 +87,22 @@ def trial_dataframe(self, trial_id): return pd.read_csv(progress) raise ValueError("Trial id {} not found".format(trial_id)) - def get_best_trainable(self, metric, trainable_cls, mode="max"): - """Returns the best Trainable based on the experiment metric. - - Args: - metric (str): Key for trial info to order on. - mode (str): One of [min, max]. - - """ - return trainable_cls(config=self.get_best_config(metric, mode=mode)) - - def get_best_config(self, metric, mode="max"): - """Retrieve the best config from the best trial. - - Args: - metric (str): Key for trial info to order on. - mode (str): One of [min, max]. - - """ - return self.get_best_info(metric, flatten=False, mode=mode)["config"] - - def get_best_logdir(self, metric, mode="max"): - df = self.dataframe() - if mode == "max": - return df.iloc[df[metric].idxmax()].logdir - elif mode == "min": - return df.iloc[df[metric].idxmin()].logdir - - def get_best_info(self, metric, mode="max", flatten=True): - """Retrieve the best trial based on the experiment metric. - - Args: - metric (str): Key for trial info to order on. - mode (str): One of [min, max]. 
- flatten (bool): Assumes trial info is flattened, where - nested entries are concatenated like `info:metric`. - """ - optimize_op = max if mode == "max" else min - if flatten: - return optimize_op( - self._scrubbed_checkpoints, key=lambda d: d.get(metric, 0)) - return optimize_op( + def get_best_trainable(self, metric, trainable_cls): + """Returns the best Trainable based on the experiment metric.""" + return trainable_cls(config=self.get_best_config(metric)) + + def get_best_config(self, metric): + """Retrieve the best config from the best trial.""" + return self._get_best_trial(metric)["config"] + + def _get_best_trial(self, metric): + """Retrieve the best trial based on the experiment metric.""" + return max( self._checkpoints, key=lambda d: d["last_result"].get(metric, 0)) + + def _get_sorted_trials(self, metric): + """Retrive trials in sorted order based on the experiment metric.""" + return sorted( + self._checkpoints, + key=lambda d: d["last_result"].get(metric, 0), + reverse=True) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index acef9fc5105d..03dd2f1607e2 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -1,10 +1,7 @@ # Original Code here: # https://github.com/pytorch/examples/blob/master/mnist/main.py -from __future__ import absolute_import -from __future__ import division from __future__ import print_function -import numpy as np import argparse import torch import torch.nn as nn @@ -12,123 +9,181 @@ import torch.optim as optim from torchvision import datasets, transforms -import ray -from ray import tune -from ray.tune import track -from ray.tune.schedulers import AsyncHyperBandScheduler - -# Change these values if you want the training to run quicker or slower. 
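# The ExperimentAnalysis helpers above reduce best-trial selection to a max
# over checkpointed last_result values. A short usage sketch under the
# assumption that an experiment directory exists at the illustrative path
# below.
from ray.tune.analysis import ExperimentAnalysis

ea = ExperimentAnalysis("~/tune_results/my_exp")
print(ea.stats())
print(ea.get_best_config("mean_accuracy"))
print(ea.dataframe().head())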
-EPOCH_SIZE = 512 -TEST_SIZE = 256 - - -class Net(nn.Module): - def __init__(self, config): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 3, kernel_size=3) - self.fc = nn.Linear(192, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 3)) - x = x.view(-1, 192) - x = self.fc(x) - return F.log_softmax(x, dim=1) - - -def train(model, optimizer, train_loader, device): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if batch_idx * len(data) > EPOCH_SIZE: - return - data, target = data.to(device), target.to(device) - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - - -def test(model, data_loader, device): - model.eval() - correct = 0 - total = 0 - with torch.no_grad(): - for batch_idx, (data, target) in enumerate(data_loader): - if batch_idx * len(data) > TEST_SIZE: - break - data, target = data.to(device), target.to(device) - outputs = model(data) - _, predicted = torch.max(outputs.data, 1) - total += target.size(0) - correct += (predicted == target).sum().item() - - return correct / total - - -def get_data_loaders(): - mnist_transforms = transforms.Compose( - [transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, ))]) - +# Training settings +parser = argparse.ArgumentParser(description="PyTorch MNIST Example") +parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)") +parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)") +parser.add_argument( + "--epochs", + type=int, + default=1, + metavar="N", + help="number of epochs to train (default: 1)") +parser.add_argument( + "--lr", + type=float, + default=0.01, + metavar="LR", + help="learning rate (default: 0.01)") +parser.add_argument( + "--momentum", + type=float, + default=0.5, + metavar="M", + help="SGD momentum (default: 0.5)") +parser.add_argument( + "--no-cuda", + action="store_true", + default=False, + help="disables CUDA training") +parser.add_argument( + "--seed", + type=int, + default=1, + metavar="S", + help="random seed (default: 1)") +parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + + +def train_mnist(args, config, reporter): + vars(args).update(config) + args.cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + if args.cuda: + torch.cuda.manual_seed(args.seed) + + kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} train_loader = torch.utils.data.DataLoader( datasets.MNIST( - "~/data", train=True, download=True, transform=mnist_transforms), - batch_size=64, - shuffle=True) + "~/data", + train=True, + download=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, )) + ])), + batch_size=args.batch_size, + shuffle=True, + **kwargs) test_loader = torch.utils.data.DataLoader( - datasets.MNIST("~/data", train=False, transform=mnist_transforms), - batch_size=64, - shuffle=True) - return train_loader, test_loader - - -def train_mnist(config): - use_cuda = config.get("use_gpu") and torch.cuda.is_available() - device = torch.device("cuda" if use_cuda else "cpu") - train_loader, test_loader = get_data_loaders() - model = Net(config).to(device) + datasets.MNIST( + "~/data", + train=False, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, 
)) + ])), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs) + + class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + model = Net() + if args.cuda: + model.cuda() optimizer = optim.SGD( - model.parameters(), lr=config["lr"], momentum=config["momentum"]) - - while True: - train(model, optimizer, train_loader, device) - acc = test(model, test_loader, device) - track.log(mean_accuracy=acc) + model.parameters(), lr=args.lr, momentum=args.momentum) + + def train(epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if args.cuda: + data, target = data.cuda(), target.cuda() + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + def test(): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + if args.cuda: + data, target = data.cuda(), target.cuda() + output = model(data) + # sum up batch loss + test_loss += F.nll_loss(output, target, reduction="sum").item() + # get the index of the max log-probability + pred = output.argmax(dim=1, keepdim=True) + correct += pred.eq( + target.data.view_as(pred)).long().cpu().sum() + + test_loss = test_loss / len(test_loader.dataset) + accuracy = correct.item() / len(test_loader.dataset) + reporter(mean_loss=test_loss, mean_accuracy=accuracy) + + for epoch in range(1, args.epochs + 1): + train(epoch) + test() if __name__ == "__main__": - parser = argparse.ArgumentParser(description="PyTorch MNIST Example") - parser.add_argument( - "--cuda", - action="store_true", - default=False, - help="Enables GPU training") - parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") - parser.add_argument( - "--ray-redis-address", - help="Address of Ray cluster for seamless distributed execution.") + datasets.MNIST("~/data", train=True, download=True) args = parser.parse_args() - if args.ray_redis_address: - ray.init(redis_address=args.ray_redis_address) + + import ray + from ray import tune + from ray.tune.schedulers import AsyncHyperBandScheduler + + ray.init() sched = AsyncHyperBandScheduler( - time_attr="training_iteration", metric="mean_accuracy") + time_attr="training_iteration", + metric="mean_loss", + mode="min", + max_t=400, + grace_period=20) + tune.register_trainable( + "TRAIN_FN", + lambda config, reporter: train_mnist(args, config, reporter)) tune.run( - train_mnist, + "TRAIN_FN", name="exp", scheduler=sched, - stop={ - "mean_accuracy": 0.98, - "training_iteration": 5 if args.smoke_test else 20 - }, - resources_per_trial={ - "cpu": 2, - "gpu": int(args.cuda) - }, - num_samples=1 if args.smoke_test else 10, - config={ - "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())), - "momentum": tune.uniform(0.1, 0.9), - "use_gpu": int(args.cuda) + **{ + "stop": { + "mean_accuracy": 0.98, + "training_iteration": 1 if args.smoke_test else 20 + }, + "resources_per_trial": { + "cpu": 3, + "gpu": int(not args.no_cuda) + }, + "num_samples": 1 if args.smoke_test else 10, + "config": { + 
"lr": tune.uniform(0.001, 0.1), + "momentum": tune.uniform(0.1, 0.9), + } }) diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py index 751f0ed44fa9..1ccec39462d0 100644 --- a/python/ray/tune/examples/track_example.py +++ b/python/ray/tune/examples/track_example.py @@ -9,7 +9,7 @@ from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) from ray.tune import track -from ray.tune.examples.utils import TuneReporterCallback, get_mnist_data +from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data parser = argparse.ArgumentParser() parser.add_argument( @@ -63,7 +63,7 @@ def train_mnist(args): batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), - callbacks=[TuneReporterCallback(track.metric)]) + callbacks=[TuneKerasCallback(track.metric)]) track.shutdown() diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index ecd3c34bc042..5357d86af19e 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -9,8 +9,8 @@ from keras.models import Sequential from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) -from ray.tune.integration.keras import TuneReporterCallback -from ray.tune.examples.utils import get_mnist_data, set_keras_threads +from ray.tune.examples.utils import (TuneKerasCallback, get_mnist_data, + set_keras_threads) parser = argparse.ArgumentParser() parser.add_argument( @@ -52,7 +52,7 @@ def train_mnist(config, reporter): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneReporterCallback(reporter)]) + callbacks=[TuneKerasCallback(reporter)]) if __name__ == "__main__": @@ -63,7 +63,7 @@ def train_mnist(config, reporter): ray.init() sched = AsyncHyperBandScheduler( - time_attr="training_iteration", + time_attr="timesteps_total", metric="mean_accuracy", mode="max", max_t=400, diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index f40707a014fc..a5ab1dbdb6a1 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -5,9 +5,24 @@ import keras from keras.datasets import mnist from keras import backend as K -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import OneHotEncoder + + +class TuneKerasCallback(keras.callbacks.Callback): + def __init__(self, reporter, logs={}): + self.reporter = reporter + self.iteration = 0 + super(TuneKerasCallback, self).__init__() + + def on_train_end(self, epoch, logs={}): + self.reporter( + timesteps_total=self.iteration, + done=1, + mean_accuracy=logs.get("acc")) + + def on_batch_end(self, batch, logs={}): + self.iteration += 1 + self.reporter( + timesteps_total=self.iteration, mean_accuracy=logs["acc"]) def get_mnist_data(): @@ -38,16 +53,6 @@ def get_mnist_data(): return x_train, y_train, x_test, y_test, input_shape -def get_iris_data(test_size=0.2): - iris_data = load_iris() - x = iris_data.data - y = iris_data.target.reshape(-1, 1) - encoder = OneHotEncoder(sparse=False) - y = encoder.fit_transform(y) - train_x, test_x, train_y, test_y = train_test_split(x, y) - return train_x, train_y, test_x, test_y - - def set_keras_threads(threads): # We set threads here to avoid contention, as Keras # is heavily parallelized across multiple cores. 
@@ -56,8 +61,3 @@ def set_keras_threads(threads): config=K.tf.ConfigProto( intra_op_parallelism_threads=threads, inter_op_parallelism_threads=threads))) - - -def TuneKerasCallback(*args, **kwargs): - raise DeprecationWarning("TuneKerasCallback is now " - "tune.integration.keras.TuneReporterCallback.") diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 95cb12043f8f..5f3e46aabd0a 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -176,14 +176,6 @@ def _register_if_needed(cls, run_object): else: raise TuneError("Improper 'run' - not string nor trainable.") - @property - def local_dir(self): - return self.spec.get("local_dir") - - @property - def checkpoint_dir(self): - return os.path.join(self.spec["local_dir"], self.name) - def convert_to_experiment_list(experiments): """Produces a list of Experiment objects. diff --git a/python/ray/tune/integration/__init__.py b/python/ray/tune/integration/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py deleted file mode 100644 index 197a7eef9841..000000000000 --- a/python/ray/tune/integration/keras.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import keras -from ray.tune import track - - -class TuneReporterCallback(keras.callbacks.Callback): - def __init__(self, reporter=None, freq="batch", logs={}): - self.reporter = reporter or track.log - self.iteration = 0 - if freq not in ["batch", "epoch"]: - raise ValueError("{} not supported as a frequency.".format(freq)) - self.freq = freq - super(TuneReporterCallback, self).__init__() - - def on_batch_end(self, batch, logs={}): - if not self.freq == "batch": - return - self.iteration += 1 - for metric in list(logs): - if "loss" in metric and "neg_" not in metric: - logs["neg_" + metric] = -logs[metric] - self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) - - def on_epoch_end(self, batch, logs={}): - if not self.freq == "epoch": - return - self.iteration += 1 - for metric in list(logs): - if "loss" in metric and "neg_" not in metric: - logs["neg_" + metric] = -logs[metric] - self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 34655372f40a..50bb447437e4 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,13 +4,11 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler -from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, - ASHAScheduler) +from ray.tune.schedulers.async_hyperband import AsyncHyperBandScheduler from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule from ray.tune.schedulers.pbt import PopulationBasedTraining __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", - "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", - "PopulationBasedTraining" + "MedianStoppingRule", "FIFOScheduler", "PopulationBasedTraining" ] diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 0370d03d3b50..487eb350efcf 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -168,8 +168,6 @@ def debug_str(self): return "Bracket: " + iters 
-ASHAScheduler = AsyncHyperBandScheduler - if __name__ == "__main__": sched = AsyncHyperBandScheduler( grace_period=1, max_t=10, reduction_factor=2) diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 7b613a6fdea2..a0721abc5d29 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -11,7 +11,9 @@ import ray from ray.tune import run, sample_from +from ray.tune.analysis import ExperimentAnalysis from ray.tune.examples.async_hyperband_example import MyTrainableClass +from ray.tune.schedulers import AsyncHyperBandScheduler class ExperimentAnalysisSuite(unittest.TestCase): @@ -25,22 +27,35 @@ def setUp(self): self.test_path = os.path.join(self.test_dir, self.test_name) self.run_test_exp() + self.ea = ExperimentAnalysis(self.test_path) + def tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) ray.shutdown() def run_test_exp(self): - self.ea = run( - MyTrainableClass, + ahb = AsyncHyperBandScheduler( + time_attr="training_iteration", + metric=self.metric, + mode="max", + grace_period=5, + max_t=100) + + run(MyTrainableClass, name=self.test_name, + scheduler=ahb, local_dir=self.test_dir, - return_trials=False, - stop={"training_iteration": 1}, - num_samples=self.num_samples, - config={ - "width": sample_from( - lambda spec: 10 + int(90 * random.random())), - "height": sample_from(lambda spec: int(100 * random.random())), + **{ + "stop": { + "training_iteration": 1 + }, + "num_samples": 10, + "config": { + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from( + lambda spec: int(100 * random.random())), + }, }) def testDataframe(self): @@ -72,7 +87,7 @@ def testBestConfig(self): self.assertTrue("height" in best_config) def testBestTrial(self): - best_trial = self.ea.get_best_info(self.metric, flatten=False) + best_trial = self.ea._get_best_trial(self.metric) self.assertTrue(isinstance(best_trial, dict)) self.assertTrue("local_dir" in best_trial) @@ -84,18 +99,6 @@ def testBestTrial(self): self.assertTrue("last_result" in best_trial) self.assertTrue(self.metric in best_trial["last_result"]) - min_trial = self.ea.get_best_info( - self.metric, mode="min", flatten=False) - - self.assertTrue(isinstance(min_trial, dict)) - self.assertLess(min_trial["last_result"][self.metric], - best_trial["last_result"][self.metric]) - - flat_trial = self.ea.get_best_info(self.metric, flatten=True) - - self.assertTrue(isinstance(min_trial, dict)) - self.assertTrue(self.metric in flat_trial) - def testCheckpoints(self): checkpoints = self.ea._checkpoints @@ -118,21 +121,6 @@ def testRunnerData(self): self.assertEqual(runner_data["_metadata_checkpoint_dir"], os.path.expanduser(self.test_path)) - def testBestLogdir(self): - logdir = self.ea.get_best_logdir(self.metric) - self.assertTrue(logdir.startswith(self.test_path)) - logdir2 = self.ea.get_best_logdir(self.metric, mode="min") - self.assertTrue(logdir2.startswith(self.test_path)) - self.assertNotEquals(logdir, logdir2) - - def testAllDataframes(self): - dataframes = self.ea.get_all_trial_dataframes() - self.assertTrue(len(dataframes) == self.num_samples) - - self.assertTrue(isinstance(dataframes, dict)) - for df in dataframes.values(): - self.assertEqual(df.training_iteration.max(), 1) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 64b8e9761488..37022ceab615 100644 --- 
a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -441,14 +441,6 @@ def f(): self.assertRaises(TuneError, f) - def testNestedStoppingReturn(self): - def train(config, reporter): - for i in range(10): - reporter(test={"test1": {"test2": i}}) - - [trial] = tune.run(train, stop={"test": {"test1": {"test2": 6}}}) - self.assertEqual(trial.last_result["training_iteration"], 7) - def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index a9938396e59b..f721023b4191 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -181,21 +181,6 @@ def has_trainable(trainable_name): ray.tune.registry.TRAINABLE_CLASS, trainable_name) -def recursive_criteria_check(result, criteria): - for criteria, stop_value in criteria.items(): - if criteria not in result: - raise TuneError( - "Stopping criteria {} not provided in result {}.".format( - criteria, result)) - elif isinstance(result[criteria], dict) and isinstance( - stop_value, dict): - if recursive_criteria_check(result[criteria], stop_value): - return True - elif result[criteria] >= stop_value: - return True - return False - - class Checkpoint(object): """Describes a checkpoint of trial state. @@ -440,7 +425,15 @@ def should_stop(self, result): if result.get(DONE): return True - return recursive_criteria_check(result, self.stopping_criterion) + for criteria, stop_value in self.stopping_criterion.items(): + if criteria not in result: + raise TuneError( + "Stopping criteria {} not provided in result {}.".format( + criteria, result)) + if result[criteria] >= stop_value: + return True + + return False def should_checkpoint(self): """Whether this trial is due for checkpointing.""" diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 47a82ba0c17f..1568db0f1102 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -4,11 +4,11 @@ import click import logging +import os import time from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment -from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.ray_trial_executor import RayTrialExecutor @@ -39,7 +39,7 @@ def _make_scheduler(args): def _find_checkpoint_dir(exp): # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier. # Right now it is resolved somewhere far down the trial generation process - return exp.checkpoint_dir + return os.path.join(exp.spec["local_dir"], exp.name) def _prompt_restore(checkpoint_dir, resume): @@ -89,10 +89,9 @@ def run(run_or_experiment, verbose=2, resume=False, queue_trials=False, - reuse_actors=True, + reuse_actors=False, trial_executor=None, raise_on_failed_trial=True, - return_trials=True, ray_auto_init=True): """Executes training. @@ -323,9 +322,7 @@ def override_flags(restored_config, new_config, flags_to_override): else: logger.error("Trials did not complete: %s", errored_trials) - if return_trials: - return runner.get_trials() - return ExperimentAnalysis(experiment.checkpoint_dir) + return runner.get_trials() def run_experiments(experiments, diff --git a/python/ray/utils.py b/python/ray/utils.py index 0db48e41d025..7b87486e325e 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. 
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command( - "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), - ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), - driver_id.binary(), error_data) + redis_client.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.ERROR_INFO, + ray.gcs_utils.TablePubsub.ERROR_INFO, + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 710f0db43c6b..7505120574a6 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.gcs_utils import ErrorType +from ray.core.generated.ErrorType import ErrorType from ray.exceptions import ( RayActorError, RayError, @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.Value("WORKER_DIED"): + if error_type == ErrorType.WORKER_DIED: return RayWorkerError() - elif error_type == ErrorType.Value("ACTOR_DIED"): + elif error_type == ErrorType.ACTOR_DIED: return RayActorError() - elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): + elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") + ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1656,19 +1656,21 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) - assert len(gcs_entry.entries) == 1 - error_data = ray.gcs_utils.ErrorTableData.FromString( - gcs_entry.entries[0]) - driver_id = error_data.driver_id + gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( + msg["data"], 0) + assert gcs_entry.EntriesLength() == 1 + error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( + gcs_entry.Entries(0), 0) + driver_id = error_data.DriverId() if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = error_data.error_message - if (error_data.type == ray_constants.TASK_PUSH_ERROR): + error_message = ray.utils.decode(error_data.ErrorMessage()) + if (ray.utils.decode( + error_data.Type()) == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1876,16 +1878,14 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) - task_table_data = ray.gcs_utils.TaskTableData() - task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. 
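# With the flatbuffers enums, the Redis table commands take TablePrefix and
# TablePubsub values directly instead of protobuf Value("...") lookups. A
# sketch of appending an error entry with that calling convention, reusing
# the construct_error_message helper changed earlier in this patch.
import time

import ray
import ray.gcs_utils


def push_error(redis_client, driver_id, error_type, message):
    """Append a serialized error entry to the GCS error table."""
    error_data = ray.gcs_utils.construct_error_message(
        driver_id, error_type, message, time.time())
    redis_client.execute_command(
        "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
        ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), error_data)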
- ray.state.state._execute_command( - driver_task.task_id(), "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), - ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), - driver_task.task_id().binary(), - task_table_data.SerializeToString()) + ray.state.state._execute_command(driver_task.task_id(), + "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.RAYLET_TASK, + ray.gcs_utils.TablePubsub.RAYLET_TASK, + driver_task.task_id().binary(), + driver_task._serialized_raylet_task()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/python/setup.py b/python/setup.py index 95e7e66bad3e..eb200ea7d5e4 100644 --- a/python/setup.py +++ b/python/setup.py @@ -151,7 +151,6 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", - "protobuf", ] setup( diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 1f50b8025d57..c92e6a74aa5d 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -36,6 +36,4 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. constexpr char kTaskTablePrefix[] = "TaskTable"; -constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; - #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 6de29bb52764..c9b1e138575d 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,6 +206,10 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } +FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } + +ClassTable &AsyncGcsClient::class_table() { return *class_table_; } + HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 5e70025b39a0..c9f5b4bca624 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,7 +44,11 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. Status Attach(boost::asio::io_service &io_service); + inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver + inline ClassTable &class_table(); + inline CustomSerializerTable &custom_serializer_table(); + inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -77,6 +81,8 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: + std::unique_ptr function_table_; + std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 55115b1e2067..c7dc02e50651 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->set_task("123"); + auto data = std::make_shared(); + data->task_specification = "123"; // Check that we added the correct task. 
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &d) { + const protocol::TaskT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_EQ(data->task_specification, d.task_specification); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &d) { + const protocol::TaskT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_EQ(data->task_specification, d.task_specification); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->set_node_manager_id(node_manager_id); + auto data = std::make_shared(); + data->node_manager_id = node_manager_id; // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { + const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); + ASSERT_EQ(data->node_manager_id, d.node_manager_id); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &d) { RAY_CHECK(false); }; + const protocol::TaskT &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->set_node_manager_id(node_manager_id); + auto data = std::make_shared(); + data->node_manager_id = node_manager_id; data_log.push_back(data); } // Check that we added the correct task. 
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { + const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id()); + appended_managers.push_back(entry.node_manager_id); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->set_manager(manager); + auto data = std::make_shared(); + data->manager = manager; // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableData &d) { + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); + ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, - const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers]( + gcs::AsyncGcsClient *client, const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->set_manager(manager); + auto data = std::make_shared(); + data->manager = manager; // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableData &d) { + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); + ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. 
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionData &d) { + const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); + ASSERT_EQ(data->node_manager_id, d.node_manager_id); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &d) { + const protocol::TaskT &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task(), d.task()); + ASSERT_EQ(data->task_specification, d.task_specification); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &data) { + const protocol::TaskT &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &data) { ASSERT_TRUE(false); }; + const protocol::TaskT &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableData &d) { + const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager(), d.manager()); + ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. 
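The rewritten tests above all follow the same pattern: GCS payloads are built with the flatc object API structs (the types ending in `T`, such as `ObjectTableDataT` and `TaskReconstructionDataT`), whose fields are plain members instead of protobuf getter/setter pairs. The following sketch is not part of the patch; it shows how such a struct round-trips to and from the wire format, assuming `gcs_generated.h` was produced by flatc with the object API enabled, which is what the `...T` types and their `Pack`/`UnPack` helpers imply.

    // Illustrative sketch only. Namespace/using directives for the generated
    // types are omitted; the tests use them unqualified from within the Ray
    // gcs sources.
    #include <memory>
    #include <string>
    #include "flatbuffers/flatbuffers.h"
    #include "ray/gcs/format/gcs_generated.h"

    // Build a serialized ObjectTableData buffer from the mutable object-API struct.
    std::string PackObjectEntry(const std::string &manager) {
      ObjectTableDataT entry;  // plain struct: members, not set_*() accessors
      entry.manager = manager;
      flatbuffers::FlatBufferBuilder fbb;
      fbb.Finish(ObjectTableData::Pack(fbb, &entry));
      return std::string(reinterpret_cast<const char *>(fbb.GetBufferPointer()),
                         fbb.GetSize());
    }

    // Read a buffer back into the mutable struct form the callbacks receive.
    std::unique_ptr<ObjectTableDataT> UnpackObjectEntry(const std::string &buf) {
      auto *root = flatbuffers::GetRoot<ObjectTableData>(buf.data());
      return std::unique_ptr<ObjectTableDataT>(root->UnPack());
    }

This is why the comparisons in the callbacks read `data->manager == d.manager` rather than the protobuf-style `data->manager() == d.manager()`.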
auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->set_node_manager_id(ObjectID::FromRandom().Hex()); + auto data = std::make_shared(); + data->node_manager_id = ObjectID::FromRandom().Hex(); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->set_task(ObjectID::FromRandom().Hex()); + auto task_data = std::make_shared(); + task_data->task_specification = ObjectID::FromRandom().Hex(); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->set_manager(ObjectID::FromRandom().Hex()); + auto data = std::make_shared(); + data->manager = ObjectID::FromRandom().Hex(); object_vector.push_back(data); } }; @@ -561,6 +561,45 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } +// Task table callbacks. 
+void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableDataT &data) { + ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); + ASSERT_EQ(data.raylet_id, kRandomId); +} + +void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableDataT &data, bool do_stop) { + ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); + ASSERT_EQ(data.raylet_id, kRandomId); + if (do_stop) { + test->Stop(); + } +} +void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableDataT &data) { + TaskLookupHelper(client, id, data, /*do_stop=*/false); +} +void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableDataT &data) { + TaskLookupHelper(client, id, data, /*do_stop=*/true); +} + +void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { + RAY_CHECK(false); +} + +void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, + const TaskTableDataT &data) { + ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); + test->Stop(); +} + +void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { + RAY_CHECK(false); + test->Stop(); +} + void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -570,11 +609,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -621,7 +660,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -630,7 +669,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -645,8 +684,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->set_manager(managers[j]); + auto data = std::make_shared(); + data->manager = managers[j]; for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. 
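The task table callbacks added here assert on `scheduling_state` and `raylet_id`, which is the pair that the `ray.table_test_and_update` Redis command (registered later in this patch) tests and mutates. As a rough illustration, not part of the patch, this is how a caller could serialize that command's payload with the generated builder helper, assuming the `TaskTableTestAndUpdate` table and `SchedulingState` enum from the gcs.fbs additions below:

    // Illustrative sketch only. CreateTaskTableTestAndUpdate is the
    // flatc-generated builder; argument order follows the schema fields
    // (test_raylet_id, test_state_bitmask, update_state).
    #include <string>
    #include "flatbuffers/flatbuffers.h"
    #include "ray/gcs/format/gcs_generated.h"

    std::string MakeTestAndUpdatePayload(const std::string &test_raylet_id) {
      flatbuffers::FlatBufferBuilder fbb;
      auto message = CreateTaskTableTestAndUpdate(
          fbb, fbb.CreateString(test_raylet_id),
          SchedulingState::SCHEDULED,  // only touch tasks currently SCHEDULED
          SchedulingState::LOST);      // state to write if the test passes
      fbb.Finish(message);
      return std::string(reinterpret_cast<const char *>(fbb.GetBufferPointer()),
                         fbb.GetSize());
    }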
@@ -657,8 +696,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->set_manager(managers[j]); + auto data = std::make_shared(); + data->manager = managers[j]; for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -701,11 +740,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &data) { + const protocol::TaskT &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -732,13 +771,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->set_task(task_spec); + auto data = std::make_shared(); + data->task_specification = task_spec; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->set_task(task_spec); + auto data = std::make_shared(); + data->task_specification = task_spec; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -769,27 +808,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->set_driver_id(driver_ids1[0]); + auto data1 = std::make_shared(); + data1->driver_id = driver_ids1[0]; RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->set_driver_id(driver_ids2[0]); + auto data2 = std::make_shared(); + data2->driver_id = driver_ids2[0]; RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -808,14 +847,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. 
auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->set_driver_id(driver_id_it); + auto data = std::make_shared(); + data->driver_id = driver_id_it; RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->set_driver_id(driver_id_it); + auto data = std::make_shared(); + data->driver_id = driver_id_it; RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -843,15 +882,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->set_manager(managers1[0]); + auto data1 = std::make_shared(); + data1->manager = managers1[0]; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->set_manager(managers2[0]); + auto data2 = std::make_shared(); + data2->manager = managers2[0]; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -859,13 +898,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -884,14 +923,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); + auto data = std::make_shared(); + data->manager = manager; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); + auto data = std::make_shared(); + data->manager = manager; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -919,8 +958,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. 
TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_task(task_specs[0]); + auto data = std::make_shared(); + data->task_specification = task_specs[0]; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -933,14 +972,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &data) { + const protocol::TaskT &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task(), task_specs.front()); + ASSERT_EQ(data.task_specification, task_specs.front()); } else { - ASSERT_EQ(data.task(), task_specs.back()); + ASSERT_EQ(data.task_specification, task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -962,8 +1001,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->set_task(task_spec); + auto data = std::make_shared(); + data->task_specification = task_spec; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -995,15 +1034,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_driver_id(driver_ids[0]); + auto data = std::make_shared(); + data->driver_id = driver_ids[0]; RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1011,7 +1050,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1033,8 +1072,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. 
auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->set_driver_id(remaining_driver_id); + auto data = std::make_shared(); + data->driver_id = remaining_driver_id; RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1068,8 +1107,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_manager(managers[0]); + auto data = std::make_shared(); + data->manager = managers[0]; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1077,7 +1116,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1085,7 +1124,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager(), managers[0]); + ASSERT_EQ(data[0].manager, managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1093,7 +1132,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager()); + data_managers_set.insert(entry.manager); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1117,8 +1156,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->set_manager(manager); + auto data = std::make_shared(); + data->manager = manager; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. 
We should receive a notification for the @@ -1147,17 +1186,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableData &data, bool is_insertion) { + const ClientTableDataT &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); - ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); + ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); - ClientTableData cached_client; + ClientTableDataT cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); - ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); + ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1165,17 +1204,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableData local_client_info = client->client_table().GetLocalClient(); - local_client_info.set_node_manager_address("127.0.0.1"); - local_client_info.set_node_manager_port(0); - local_client_info.set_object_manager_port(0); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1190,23 +1229,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. 
- ClientTableData local_client_info = client->client_table().GetLocalClient(); - local_client_info.set_node_manager_address("127.0.0.1"); - local_client_info.set_node_manager_port(0); - local_client_info.set_object_manager_port(0); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1221,20 +1260,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableData local_client_info = client->client_table().GetLocalClient(); - local_client_info.set_node_manager_address("127.0.0.1"); - local_client_info.set_node_manager_port(0); - local_client_info.set_object_manager_port(0); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1247,10 +1286,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableData local_client_info = client->client_table().GetLocalClient(); - local_client_info.set_node_manager_address("127.0.0.1"); - local_client_info.set_node_manager_port(0); - local_client_info.set_object_manager_port(0); + ClientTableDataT local_client_info = client->client_table().GetLocalClient(); + local_client_info.node_manager_address = "127.0.0.1"; + local_client_info.node_manager_port = 0; + local_client_info.object_manager_port = 0; // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1260,8 +1299,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableData &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); + const ClientTableDataT &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); test->Stop(); }); test->Start(); @@ -1277,31 +1316,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. 
- auto cpu_data = std::make_shared(); - cpu_data->set_resource_name("CPU"); - cpu_data->set_resource_capacity(100); - auto gpu_data = std::make_shared(); - gpu_data->set_resource_name("GPU"); - gpu_data->set_resource_capacity(2); + auto cpu_data = std::make_shared(); + cpu_data->resource_name = "CPU"; + cpu_data->resource_capacity = 100; + auto gpu_data = std::make_shared(); + gpu_data->resource_name = "GPU"; + gpu_data->resource_capacity = 2; DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->set_resource_name("CPU"); - data_cpu->set_resource_capacity(50); - auto data_gpu = std::make_shared(); - data_gpu->set_resource_name("GPU"); - data_gpu->set_resource_capacity(10); - auto data_custom = std::make_shared(); - data_custom->set_resource_name("CUSTOM"); - data_custom->set_resource_capacity(2); + auto data_cpu = std::make_shared(); + data_cpu->resource_name = "CPU"; + data_cpu->resource_capacity = 50; + auto data_gpu = std::make_shared(); + data_gpu->resource_name = "GPU"; + data_gpu->resource_capacity = 10; + auto data_custom = std::make_shared(); + data_custom->resource_name = "CUSTOM"; + data_custom->resource_capacity = 2; DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->set_resource_capacity(50); + data_map2["CPU"]->resource_capacity = 50; // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1309,8 +1348,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); - ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); + ASSERT_EQ(iter->second->resource_name, data.second->resource_name); + ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index c06c79a02928..614c80b27672 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,9 +1,52 @@ -// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). - enum Language:int { - PYTHON=0, - JAVA=1, - CPP=2, + PYTHON = 0, + CPP = 1, + JAVA = 2 +} + +// These indexes are mapped to strings in ray_redis_module.cc. +enum TablePrefix:int { + UNUSED = 0, + TASK, + RAYLET_TASK, + CLIENT, + OBJECT, + ACTOR, + FUNCTION, + TASK_RECONSTRUCTION, + HEARTBEAT, + HEARTBEAT_BATCH, + ERROR_INFO, + DRIVER, + PROFILE, + TASK_LEASE, + ACTOR_CHECKPOINT, + ACTOR_CHECKPOINT_ID, + NODE_RESOURCE, +} + +// The channel that Add operations to the Table should be published on, if any. 
+enum TablePubsub:int { + NO_PUBLISH = 0, + TASK, + RAYLET_TASK, + CLIENT, + OBJECT, + ACTOR, + HEARTBEAT, + HEARTBEAT_BATCH, + ERROR_INFO, + TASK_LEASE, + DRIVER, + NODE_RESOURCE, +} + +// Enum for the entry type in the ClientTable +enum EntryType:int { + INSERTION = 0, + DELETION, + RES_CREATEUPDATE, + RES_DELETE, } table Arg { @@ -63,11 +106,6 @@ table TaskInfo { // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] function_descriptor: [string]; - // The dynamic options used in the worker command when starting the worker process for - // an actor creation task. If the list isn't empty, the options will be used to replace - // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the - // worker command. - dynamic_worker_options: [string]; } table ResourcePair { @@ -77,6 +115,118 @@ table ResourcePair { value: double; } +enum GcsChangeMode:int { + APPEND_OR_ADD = 0, + REMOVE, +} + +table GcsEntry { + change_mode: GcsChangeMode; + id: string; + entries: [string]; +} + +table FunctionTableData { + language: Language; + name: string; + data: string; +} + +table ObjectTableData { + // The size of the object. + object_size: long; + // The node manager ID that this object appeared on or was evicted by. + manager: string; +} + +table TaskReconstructionData { + // The number of times this task has been reconstructed so far. + num_reconstructions: int; + // The node manager that is trying to reconstruct the task. + node_manager_id: string; +} + +enum SchedulingState:int { + NONE = 0, + WAITING = 1, + SCHEDULED = 2, + QUEUED = 4, + RUNNING = 8, + DONE = 16, + LOST = 32, + RECONSTRUCTING = 64 +} + +table TaskTableData { + // The state of the task. + scheduling_state: SchedulingState; + // A raylet ID. + raylet_id: string; + // A string of bytes representing the task's TaskExecutionDependencies. + execution_dependencies: string; + // The number of times the task was spilled back by raylets. + spillback_count: long; + // A string of bytes representing the task specification. + task_info: string; + // TODO(pcm): This is at the moment duplicated in task_info, remove that one + updated: bool; +} + +table TaskTableTestAndUpdate { + test_raylet_id: string; + test_state_bitmask: SchedulingState; + update_state: SchedulingState; +} + +table ClassTableData { +} + +enum ActorState:int { + // Actor is alive. + ALIVE = 0, + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1, + // Actor is already dead and won't be reconstructed. + DEAD = 2 +} + +table ActorTableData { + // The ID of the actor that was created. + actor_id: string; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + actor_creation_dummy_object_id: string; + // The ID of the driver that created the actor. + driver_id: string; + // The ID of the node manager that created the actor. + node_manager_id: string; + // Current state of this actor. + state: ActorState; + // Max number of times this actor should be reconstructed. + max_reconstructions: int; + // Remaining number of reconstructions. + remaining_reconstructions: int; +} + +table ErrorTableData { + // The ID of the driver that the error is for. + driver_id: string; + // The type of the error. + type: string; + // The error message. 
+ error_message: string; + // The timestamp of the error message. + timestamp: double; +} + +table CustomSerializerData { +} + +table ConfigTableData { +} + table ProfileEvent { // The type of the event. event_type: string; @@ -103,3 +253,119 @@ table ProfileTableData { // we don't want each event to require a GCS command. profile_events: [ProfileEvent]; } + +table RayResource { + // The type of the resource. + resource_name: string; + // The total capacity of this resource type. + resource_capacity: double; +} + +table ClientTableData { + // The client ID of the client that the message is about. + client_id: string; + // The IP address of the client's node manager. + node_manager_address: string; + // The IPC socket name of the client's raylet. + raylet_socket_name: string; + // The IPC socket name of the client's plasma store. + object_store_socket_name: string; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + node_manager_port: int; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + object_manager_port: int; + // Enum to store the entry type in the log + entry_type: EntryType = INSERTION; + resources_total_label: [string]; + resources_total_capacity: [double]; +} + +table HeartbeatTableData { + // Node manager client id + client_id: string; + // Resource capacity currently available on this node manager. + resources_available_label: [string]; + resources_available_capacity: [double]; + // Total resource capacity configured for this node manager. + resources_total_label: [string]; + resources_total_capacity: [double]; + // Aggregate outstanding resource load on this node manager. + resource_load_label: [string]; + resource_load_capacity: [double]; +} + +table HeartbeatBatchTableData { + batch: [HeartbeatTableData]; +} + +// Data for a lease on task execution. +table TaskLeaseData { + // Node manager client ID. + node_manager_id: string; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + acquired_at: long; + // The period that the lease is active for. + timeout: long; +} + +table DriverTableData { + // The driver ID. + driver_id: string; + // Whether it's dead. + is_dead: bool; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +table ActorCheckpointData { + // ID of this actor. + actor_id: string; + // The dummy object ID of actor's most recently executed task. + execution_dependency: string; + // A list of IDs of this actor's handles. + handle_ids: [string]; + // The task counters of the above handles. + task_counters: [long]; + // The frontier dependencies of the above handles. + frontier_dependencies: [string]; + // A list of unreleased dummy objects from this actor. + unreleased_dummy_objects: [string]; + // The numbers of dependencies for the above unreleased dummy objects. + num_dummy_object_dependencies: [int]; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +table ActorCheckpointIdData { + // ID of this actor. + actor_id: string; + // IDs of this actor's available checkpoints. + // Note, this is a long string that concatenates all the IDs. 
+ checkpoint_ids: string; + // A list of the timestamps for each of the above `checkpoint_ids`. + timestamps: [long]; +} + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType:int { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 1, + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 2, + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 3, +} diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 093aab2455d9..fc42e5cd98c2 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/protobuf/gcs.pb.h" +#include "ray/gcs/format/gcs_generated.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -25,9 +25,6 @@ namespace ray { namespace gcs { -using rpc::TablePrefix; -using rpc::TablePubsub; - /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -129,8 +126,8 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const void *data, - size_t length, const TablePrefix prefix, + Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, + int64_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -160,9 +157,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, - size_t length, const TablePrefix prefix, - const TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, + const uint8_t *data, int64_t length, + const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index c3a82c320d06..e291b7ffdb32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,16 +5,11 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" -#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; -using ray::rpc::GcsChangeMode; -using ray::rpc::GcsEntry; -using ray::rpc::TablePrefix; -using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. 
Specifically, via @@ -69,8 +64,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || - pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { + if (pubsub_channel_long > static_cast(TablePubsub::MAX) || + pubsub_channel_long < static_cast(TablePubsub::MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -85,7 +80,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -100,8 +95,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || - table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { + if (table_prefix_long > static_cast(TablePrefix::MAX) || + table_prefix_long < static_cast(TablePrefix::MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -118,7 +113,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); + return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -141,8 +136,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, - mutated_key_str); + *out = + OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); return Status::OK(); } @@ -170,24 +165,18 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// A helper function that creates `GcsEntry` protobuf object. +/// This is a helper method to convert a redis module string to a flatbuffer +/// string. /// -/// \param[in] id Id of the entry. -/// \param[in] change_mode Change mode of the entry. -/// \param[in] entries Vector of entries. -/// \param[out] result The created `GcsEntry` object. -inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, - const std::vector &entries, - GcsEntry *result) { - const char *data; - size_t size; - data = RedisModule_StringPtrLen(id, &size); - result->set_id(data, size); - result->set_change_mode(change_mode); - for (const auto &entry : entries) { - data = RedisModule_StringPtrLen(entry, &size); - result->add_entries(data, size); - } +/// \param fbb The flatbuffer builder. +/// \param redis_string The redis string. 
+/// \return The flatbuffer string. +flatbuffers::Offset RedisStringToFlatbuf( + flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { + size_t redis_string_size; + const char *redis_string_str = + RedisModule_StringPtrLen(redis_string, &redis_string_size); + return fbb.CreateString(redis_string_str, redis_string_size); } /// Helper method to publish formatted data to target channel. @@ -245,10 +234,13 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - GcsEntry gcs_entry; - CreateGcsEntry(id, change_mode, {data}, &gcs_entry); - std::string str = gcs_entry.SerializeAsString(); - auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); + flatbuffers::FlatBufferBuilder fbb; + auto data_flatbuf = RedisStringToFlatbuf(fbb, data); + auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), + fbb.CreateVector(&data_flatbuf, 1)); + fbb.Finish(message); + auto data_buffer = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -578,20 +570,19 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); - GcsEntry gcs_entry; - gcs_entry.ParseFromArray(update_data_buf, update_data_len); - *change_mode = gcs_entry.change_mode(); - + auto data_vec = flatbuffers::GetRoot(update_data_buf); + *change_mode = data_vec->change_mode(); if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. - size_t total_size = gcs_entry.entries_size(); + size_t total_size = data_vec->entries()->size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); - RedisModuleString *entry_value = RedisModule_CreateString( - ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + RedisModuleString *entry_value = + RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), + data_vec->entries()->Get(i + 1)->size()); // Returning 0 if key exists(still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -599,25 +590,27 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries. 
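Note that `HashUpdate_DoWrite` goes straight from the client-supplied buffer to `flatbuffers::GetRoot<GcsEntry>`, which performs no validation. Flatbuffers ships a verifier that could reject malformed payloads first; the sketch below is illustrative only and is not something this patch adds.

    // Illustrative sketch only. Verifier and VerifyBuffer<T>() are part of the
    // stock flatbuffers C++ API; passing nullptr skips the file identifier check.
    #include "flatbuffers/flatbuffers.h"
    #include "ray/gcs/format/gcs_generated.h"

    bool LooksLikeGcsEntry(const char *buf, size_t len) {
      flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(buf), len);
      return verifier.VerifyBuffer<GcsEntry>(nullptr);
    }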
- GcsEntry updated; - updated.set_id(gcs_entry.id()); - updated.set_change_mode(gcs_entry.change_mode()); - - size_t total_size = gcs_entry.entries_size(); + size_t total_size = data_vec->entries()->size(); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data; for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - updated.add_entries(gcs_entry.entries(i)); + data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), + data_vec->entries()->Get(i)->size())); } } - - // Serialize updated data. - std::string str = updated.SerializeAsString(); - *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); + auto message = + CreateGcsEntry(fbb, data_vec->change_mode(), + fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), + fbb.CreateVector(data)); + fbb.Finish(message); + *changed_data = RedisModule_CreateString( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -638,7 +631,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry protobuf data used to update this hash table. +/// \param data The GcsEntry flatbugger data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -655,7 +648,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create a GcsEntry protobuf, based on the +/// A helper function to create and finish a GcsEntry, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -665,18 +658,21 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param[out] gcs_entry The created GcsEntry. -Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - GcsEntry *gcs_entry) { +/// \param fbb A flatbuffer builder used to build the GcsEntry. +Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, + RedisModuleString *prefix_str, RedisModuleString *entry_id, + flatbuffers::FlatBufferBuilder &fbb) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the GcsEntry from the string data. - CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); + // Build the flatbuffer from the string data. 
size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - gcs_entry->add_entries(data_buf, data_len); + auto data = fbb.CreateString(data_buf, data_len); + auto message = + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); + fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -700,20 +696,27 @@ Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the GcsEntry from the set of log entries. + // Build the flatbuffer from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); + std::vector> data; for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - gcs_entry->add_entries(element_str, len); + data.push_back(fbb.CreateString(element_str, len)); } + auto message = + CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); + fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); + auto message = CreateGcsEntry( + fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), + fbb.CreateVector(std::vector>())); + fbb.Finish(message); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -749,12 +752,11 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a GcsEntry to return to the client. - GcsEntry gcs_entry; - REPLY_AND_RETURN_IF_NOT_OK( - TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); - std::string str = gcs_entry.SerializeAsString(); - RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); + // Serialize the data to a flatbuffer to return to the client. + flatbuffers::FlatBufferBuilder fbb; + REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); + RedisModule_ReplyWithStringBuffer( + ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); } return REDISMODULE_OK; } @@ -868,11 +870,10 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. 
- GcsEntry gcs_entry; - REPLY_AND_RETURN_IF_NOT_OK( - TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); - std::string str = gcs_entry.SerializeAsString(); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); + flatbuffers::FlatBufferBuilder fbb; + REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, + reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); return RedisModule_ReplyWithNull(ctx); } @@ -939,6 +940,53 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } +// This is a temporary redis command that will be removed once +// the GCS uses https://github.com/pcmoritz/credis. +// Be careful, this only supports Task Table payloads. +int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, + int argc) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + RedisModuleString *prefix_str = argv[1]; + RedisModuleString *id = argv[3]; + RedisModuleString *update_data = argv[4]; + + RedisModuleKey *key; + REPLY_AND_RETURN_IF_NOT_OK( + OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); + + size_t value_len = 0; + char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); + + size_t update_len = 0; + const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); + + auto data = + flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); + + auto update = flatbuffers::GetRoot(update_buf); + + bool do_update = static_cast(data->scheduling_state()) & + static_cast(update->test_state_bitmask()); + + bool is_nil_result; + REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); + if (!is_nil_result) { + do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); + } + + if (do_update) { + REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), + "mutate_scheduling_state failed"); + } + REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); + + int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); + + return result; +} + std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -968,6 +1016,7 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); +AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1033,6 +1082,12 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", + TableTestAndUpdate_RedisCommand, "write", 0, 0, + 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index b7c19ebfd595..33f1615580a6 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,7 +3,6 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" -#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -40,44 +39,48 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, const ID 
&id, - std::shared_ptr &data, const WriteCallback &done) { + std::shared_ptr &dataT, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, data, done](const CallbackReply &reply) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *data); + (done)(client_, id, *dataT); } }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback)); + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &dataT, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, data, done, failure](const CallbackReply &reply) { + auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *data); + (done)(client_, id, *dataT); } } else { if (failure != nullptr) { - (failure)(client_, id, *data); + (failure)(client_, id, *dataT); } } }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback), log_length); + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback), log_length); } template @@ -86,15 +89,16 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; + std::vector results; if (!reply.IsNil()) { - GcsEntry gcs_entry; - gcs_entry.ParseFromString(reply.ReadAsString()); - RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); - for (size_t i = 0; i < gcs_entry.entries_size(); i++) { - Data data; - data.ParseFromString(gcs_entry.entries(i)); - results.emplace_back(std::move(data)); + const auto data = reply.ReadAsString(); + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + for (size_t i = 0; i < root->entries()->size(); i++) { + DataT result; + auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); + data_root->UnPackTo(&result); + results.emplace_back(std::move(result)); } } lookup(client_, id, results); @@ -111,7 +115,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; @@ -137,16 +141,19 @@ Status Log::Subscribe(const DriverID 
&driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - GcsEntry gcs_entry; - gcs_entry.ParseFromString(data); - ID id = ID::FromBinary(gcs_entry.id()); - std::vector results; - for (size_t i = 0; i < gcs_entry.entries_size(); i++) { - Data result; - result.ParseFromString(gcs_entry.entries(i)); + auto root = flatbuffers::GetRoot(data.data()); + ID id; + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); + } + std::vector results; + for (size_t i = 0; i < root->entries()->size(); i++) { + DataT result; + auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); + data_root->UnPackTo(&result); results.emplace_back(std::move(result)); } - subscribe(client_, id, gcs_entry.change_mode(), results); + subscribe(client_, id, root->change_mode(), results); } } }; @@ -227,17 +234,19 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { + std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, data, done](const CallbackReply &reply) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *data); + (done)(client_, id, *dataT); } }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), - str.length(), prefix_, pubsub_channel_, - std::move(callback)); + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, + fbb.GetBufferPointer(), fbb.GetSize(), prefix_, + pubsub_channel_, std::move(callback)); } template @@ -246,7 +255,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -268,7 +277,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -290,30 +299,36 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { + std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, data, done](const CallbackReply &reply) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *data); + (done)(client_, id, *dataT); } }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), - prefix_, pubsub_channel_, std::move(callback)); + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const 
WriteCallback &done) { + std::shared_ptr &dataT, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, data, done](const CallbackReply &reply) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *data); + (done)(client_, id, *dataT); } }; - std::string str = data->SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), - prefix_, pubsub_channel_, std::move(callback)); + flatbuffers::FlatBufferBuilder fbb; + fbb.ForceDefaults(true); + fbb.Finish(Data::Pack(fbb, dataT.get())); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -333,16 +348,26 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - GcsEntry gcs_entry; - gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); - for (const auto &pair : data_map) { - gcs_entry.add_entries(pair.first); - gcs_entry.add_entries(pair.second->SerializeAsString()); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(data_map.size() * 2); + for (auto const &pair : data_map) { + // Add the key. + data_vec.push_back(fbb.CreateString(pair.first)); + flatbuffers::FlatBufferBuilder fbb_data; + fbb_data.ForceDefaults(true); + fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); + std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), + fbb_data.GetSize()); + // Add the value. + data_vec.push_back(fbb.CreateString(data)); } - std::string str = gcs_entry.SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), - prefix_, pubsub_channel_, std::move(callback)); + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, + fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -355,15 +380,19 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - GcsEntry gcs_entry; - gcs_entry.set_id(id.Binary()); - gcs_entry.set_change_mode(GcsChangeMode::REMOVE); - for (const auto &key : keys) { - gcs_entry.add_entries(key); + flatbuffers::FlatBufferBuilder fbb; + std::vector> data_vec; + data_vec.reserve(keys.size()); + // Add the keys. 
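The idiom running through these tables.cc hunks is the flatbuffers object API: each generated table `Data` exposes a native struct `Data::NativeTableType` (aliased `DataT` in this patch), packed into a buffer with `Data::Pack` and recovered with `UnPackTo`. A short sketch of that round trip, using ErrorTableData purely as a representative generated table and placeholder field values:

    // Native object -> buffer -> view -> native object again.
    ErrorTableDataT native;                    // plain struct: fields are members, not setters
    native.driver_id = "driver-id-bytes";      // placeholder value
    native.error_message = "example error";

    flatbuffers::FlatBufferBuilder fbb;
    fbb.ForceDefaults(true);                   // also serialize fields left at default values
    fbb.Finish(ErrorTableData::Pack(fbb, &native));

    auto view = flatbuffers::GetRoot<ErrorTableData>(fbb.GetBufferPointer());
    ErrorTableDataT copy;
    view->UnPackTo(&copy);                     // copy.error_message == "example error"

Hash::Update above nests one such packed buffer per value inside the outer GcsEntry's string entries, which is why it uses both an inner builder (fbb_data) and an outer one (fbb).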
+ for (auto const &key : keys) { + data_vec.push_back(fbb.CreateString(key)); } - std::string str = gcs_entry.SerializeAsString(); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), - prefix_, pubsub_channel_, std::move(callback)); + + fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), + fbb.CreateVector(data_vec))); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), + fbb.GetSize(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -383,15 +412,17 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - GcsEntry gcs_entry; - gcs_entry.ParseFromString(reply.ReadAsString()); - RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); - RAY_CHECK(gcs_entry.entries_size() % 2 == 0); - for (int i = 0; i < gcs_entry.entries_size(); i += 2) { - const auto &key = gcs_entry.entries(i); - const auto value = std::make_shared(); - value->ParseFromString(gcs_entry.entries(i + 1)); - results.emplace(key, std::move(value)); + auto root = flatbuffers::GetRoot(data.data()); + RAY_CHECK(from_flatbuf(*root->id()) == id); + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + results.emplace(key, std::move(result)); } } lookup(client_, id, results); @@ -420,24 +451,31 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. 
- GcsEntry gcs_entry; - gcs_entry.ParseFromString(data); - ID id = ID::FromBinary(gcs_entry.id()); + auto root = flatbuffers::GetRoot(data.data()); DataMap data_map; - if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { - for (const auto &key : gcs_entry.entries()) { - data_map.emplace(key, std::shared_ptr()); + ID id; + if (root->id()->size() > 0) { + id = from_flatbuf(*root->id()); + } + if (root->change_mode() == GcsChangeMode::REMOVE) { + for (size_t i = 0; i < root->entries()->size(); i++) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + data_map.emplace(key, std::shared_ptr()); } } else { - RAY_CHECK(gcs_entry.entries_size() % 2 == 0); - for (int i = 0; i < gcs_entry.entries_size(); i += 2) { - const auto &key = gcs_entry.entries(i); - const auto value = std::make_shared(); - value->ParseFromString(gcs_entry.entries(i + 1)); - data_map.emplace(key, std::move(value)); + RAY_CHECK(root->entries()->size() % 2 == 0); + for (size_t i = 0; i < root->entries()->size(); i += 2) { + std::string key(root->entries()->Get(i)->data(), + root->entries()->Get(i)->size()); + auto result = std::make_shared(); + auto data_root = + flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); + data_root->UnPackTo(result.get()); + data_map.emplace(key, std::move(result)); } } - subscribe(client_, id, gcs_entry.change_mode(), data_map); + subscribe(client_, id, root->change_mode(), data_map); } } }; @@ -452,11 +490,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->set_driver_id(driver_id.Binary()); - data->set_type(type); - data->set_error_message(error_message); - data->set_timestamp(timestamp); + auto data = std::make_shared(); + data->driver_id = driver_id.Binary(); + data->type = type; + data->error_message = error_message; + data->timestamp = timestamp; return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -465,9 +503,11 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. - auto data = std::make_shared(); - data->CopyFrom(profile_events); + auto data = std::make_shared(); + // There is some room for optimization here because the Append function will just + // call "Pack" and undo the "UnPack". + profile_events.UnPackTo(data.get()); + return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -477,9 +517,9 @@ std::string ProfileTable::DebugString() const { } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->set_driver_id(driver_id.Binary()); - data->set_is_dead(is_dead); + auto data = std::make_shared(); + data->driver_id = driver_id.Binary(); + data->is_dead = is_dead; return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -487,8 +527,7 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. 
for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && - (entry.second.entry_type() == ClientTableData::INSERTION)) { + if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -498,7 +537,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -510,7 +549,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { + (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -520,16 +559,15 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && - entry.second.entry_type() == ClientTableData::RES_DELETE) { + if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableData &data) { - ClientID client_id = ClientID::FromBinary(data.client_id()); + const ClientTableDataT &data) { + ClientID client_id = ClientID::FromBinary(data.client_id); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -540,16 +578,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); - bool is_deleted = (data.entry_type() == ClientTableData::DELETION); - bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == ClientTableData::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); + bool is_deleted = (data.entry_type == EntryType::DELETION); + bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. 
- if (entry->second.entry_type() == ClientTableData::DELETION) { - RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) + if (entry->second.entry_type == EntryType::DELETION) { + RAY_CHECK((data.entry_type == EntryType::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -557,64 +595,64 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type() == ClientTableData::INSERTION) || - (data.entry_type() == ClientTableData::DELETION)) { + if ((data.entry_type == EntryType::INSERTION) || + (data.entry_type == EntryType::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type()) + << client_id << ". EntryType: " << int(data.entry_type) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || - (data.entry_type() == ClientTableData::RES_DELETE)) { + } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || + (data.entry_type == EntryType::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type()) + << client_id << ". EntryType: " << int(data.entry_type) << ". Updating the client cache with the delta from the log."; - ClientTableData &cache_data = client_cache_[client_id]; + ClientTableDataT &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { - auto const &resource_name = data.resources_total_label(i); - auto const &capacity = data.resources_total_capacity(i); + for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { + auto const &resource_name = data.resources_total_label[i]; + auto const &capacity = data.resources_total_capacity[i]; // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label().begin(), - cache_data.resources_total_label().end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label().end()) { - auto index = std::distance(cache_data.resources_total_label().begin(), + std::find(cache_data.resources_total_label.begin(), + cache_data.resources_total_label.end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label.end()) { + auto index = std::distance(cache_data.resources_total_label.begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { - cache_data.set_resources_total_capacity(index, capacity); + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_capacity[index] = capacity; } // .. delete if deletion call. 
- else if (data.entry_type() == ClientTableData::RES_DELETE) { - cache_data.mutable_resources_total_label()->erase( - cache_data.resources_total_label().begin() + index); - cache_data.mutable_resources_total_capacity()->erase( - cache_data.resources_total_capacity().begin() + index); + else if (data.entry_type == EntryType::RES_DELETE) { + cache_data.resources_total_label.erase( + cache_data.resources_total_label.begin() + index); + cache_data.resources_total_capacity.erase( + cache_data.resources_total_capacity.begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { - cache_data.add_resources_total_label(resource_name); - cache_data.add_resources_total_capacity(capacity); + if (data.entry_type == EntryType::RES_CREATEUPDATE) { + cache_data.resources_total_label.push_back(resource_name); + cache_data.resources_total_capacity.push_back(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableData &cache_data = client_cache_[client_id]; + ClientTableDataT &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type() == ClientTableData::INSERTION) { + if (data.entry_type == EntryType::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type() == ClientTableData::DELETION) { + } else if (data.entry_type == EntryType::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -622,11 +660,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type() == ClientTableData::RES_DELETE) { + } else if (data.entry_type == EntryType::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -634,54 +672,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id()); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableData &local_client) { +Status ClientTable::Connect(const ClientTableDataT &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id() == 
local_client_.client_id()); + RAY_CHECK(local_client.client_id == local_client_.client_id); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->set_entry_type(ClientTableData::INSERTION); + auto data = std::make_shared(local_client_); + data->entry_type = EntryType::INSERTION; // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableData &data) { + const ClientTableDataT &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type() != ClientTableData::DELETION) { - connected_nodes.emplace(notification.client_id(), notification); + if (notification.entry_type != EntryType::DELETION) { + connected_nodes.emplace(notification.client_id, notification); } else { - auto iter = connected_nodes.find(notification.client_id()); + auto iter = connected_nodes.find(notification.client_id); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id(), notification); + disconnected_nodes.emplace(notification.client_id, notification); } } for (const auto &pair : connected_nodes) { @@ -704,10 +742,10 @@ Status ClientTable::Connect(const ClientTableData &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->set_entry_type(ClientTableData::DELETION); + auto data = std::make_shared(local_client_); + data->entry_type = EntryType::DELETION; auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableData &data) { + const ClientTableDataT &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -721,24 +759,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->set_client_id(dead_client_id.Binary()); - data->set_entry_type(ClientTableData::DELETION); + auto data = std::make_shared(); + data->client_id = dead_client_id.Binary(); + data->entry_type = EntryType::DELETION; return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableData &client_info) const { + ClientTableDataT &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.set_client_id(ClientID::Nil().Binary()); + client_info.client_id = ClientID::Nil().Binary(); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ 
-760,29 +798,31 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdData &data) { - std::shared_ptr copy = - std::make_shared(data); - copy->add_timestamps(current_sys_time_ms()); - copy->add_checkpoint_ids(checkpoint_id.Binary()); + const ActorCheckpointIdDataT &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->timestamps.push_back(current_sys_time_ms()); + copy->checkpoint_ids += checkpoint_id.Binary(); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps().size() > num_to_keep) { + while (copy->timestamps.size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; - copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); - copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); - client_->actor_checkpoint_table().Delete(driver_id, to_delete); + const auto &checkpoint_id = + ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " + << actor_id; + copy->timestamps.erase(copy->timestamps.begin()); + copy->checkpoint_ids.erase(0, kUniqueIDSize); + client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); - data->set_actor_id(id.Binary()); - data->add_timestamps(current_sys_time_ms()); - *data->add_checkpoint_ids() = checkpoint_id.Binary(); + std::shared_ptr data = + std::make_shared(); + data->actor_id = id.Binary(); + data->timestamps.push_back(current_sys_time_ms()); + data->checkpoint_ids = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); @@ -790,7 +830,8 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -template class Log; +template class Log; +template class Table; template class Table; template class Log; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 2ecc3440839e..6a1d502a7f54 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,8 +11,10 @@ #include "ray/common/status.h" #include "ray/util/logging.h" +#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -#include "ray/protobuf/gcs.pb.h" +// TODO(rkn): Remove this include. 
+#include "ray/raylet/format/node_manager_generated.h" struct redisAsyncContext; @@ -20,25 +22,6 @@ namespace ray { namespace gcs { -using rpc::ActorCheckpointData; -using rpc::ActorCheckpointIdData; -using rpc::ActorTableData; -using rpc::ClientTableData; -using rpc::DriverTableData; -using rpc::ErrorTableData; -using rpc::GcsChangeMode; -using rpc::GcsEntry; -using rpc::HeartbeatBatchTableData; -using rpc::HeartbeatTableData; -using rpc::ObjectTableData; -using rpc::ProfileTableData; -using rpc::RayResource; -using rpc::TablePrefix; -using rpc::TablePubsub; -using rpc::TaskLeaseData; -using rpc::TaskReconstructionData; -using rpc::TaskTableData; - class RedisContext; class AsyncGcsClient; @@ -65,12 +48,13 @@ class PubsubInterface { template class LogInterface { public: + using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -88,11 +72,12 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: + using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = - std::function &data)>; + const std::vector &data)>; + using NotificationCallback = std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -101,7 +86,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -126,7 +111,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -141,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. 
/// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -274,9 +259,10 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: + using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -294,8 +280,9 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: + using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -318,7 +305,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -382,11 +369,12 @@ class Table : private Log, template class SetInterface { public: + using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, - const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -404,6 +392,7 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: + using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -425,7 +414,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -436,7 +425,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -465,7 +454,8 @@ class Set : private Log, template class HashInterface { public: - using DataMap = std::unordered_map>; + using DataT = typename Data::NativeTableType; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. 
using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -554,7 +544,8 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataMap = std::unordered_map>; + using DataT = typename Log::DataT; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -604,7 +595,7 @@ class DynamicResourceTable : public Hash { DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; + pubsub_channel_ = TablePubsub::NODE_RESOURCE; prefix_ = TablePrefix::NODE_RESOURCE; }; @@ -616,7 +607,7 @@ class ObjectTable : public Set { ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; + pubsub_channel_ = TablePubsub::OBJECT; prefix_ = TablePrefix::OBJECT; }; @@ -628,7 +619,7 @@ class HeartbeatTable : public Table { HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; + pubsub_channel_ = TablePubsub::HEARTBEAT; prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} @@ -639,7 +630,7 @@ class HeartbeatBatchTable : public Table { HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} @@ -650,7 +641,7 @@ class DriverTable : public Log { DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; + pubsub_channel_ = TablePubsub::DRIVER; prefix_ = TablePrefix::DRIVER; }; @@ -664,6 +655,18 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; +class FunctionTable : public Table { + public: + FunctionTable(const std::vector> &contexts, + AsyncGcsClient *client) + : Table(contexts, client) { + pubsub_channel_ = TablePubsub::NO_PUBLISH; + prefix_ = TablePrefix::FUNCTION; + }; +}; + +using ClassTable = Table; + /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -674,7 +677,7 @@ class ActorTable : public Log { ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; + pubsub_channel_ = TablePubsub::ACTOR; prefix_ = TablePrefix::ACTOR; } }; @@ -693,12 +696,12 @@ class TaskLeaseTable : public Table { TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; + pubsub_channel_ = TablePubsub::TASK_LEASE; prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. 
It's okay if this command fails // since the lease entry itself contains the expiration period. In the @@ -706,8 +709,9 @@ class TaskLeaseTable : public Table { // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), - std::to_string(data->timeout())}; + std::vector args = {"PEXPIRE", + EnumNameTablePrefix(prefix_) + id.Binary(), + std::to_string(data->timeout)}; return GetRedisContext(id)->RunArgvAsync(args); } @@ -743,12 +747,12 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; + pubsub_channel_ = TablePubsub::RAYLET_TASK; prefix_ = TablePrefix::RAYLET_TASK; } @@ -766,7 +770,7 @@ class ErrorTable : private Log { ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; + pubsub_channel_ = TablePubsub::ERROR_INFO; prefix_ = TablePrefix::ERROR_INFO; }; @@ -811,6 +815,10 @@ class ProfileTable : private Log { std::string DebugString() const; }; +using CustomSerializerTable = Table; + +using ConfigTable = Table; + /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -823,7 +831,7 @@ class ProfileTable : private Log { class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -834,11 +842,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; + pubsub_channel_ = TablePubsub::CLIENT; prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.set_client_id(client_id.Binary()); + local_client_.client_id = client_id.Binary(); }; /// Connect as a client to the GCS. This registers us in the client table @@ -847,7 +855,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableData &local_client); + ray::Status Connect(const ClientTableDataT &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -890,7 +898,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableData &client_info) const; + void GetClient(const ClientID &client, ClientTableDataT &client_info) const; /// Get the local client's ID. /// @@ -900,7 +908,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableData &GetLocalClient() const; + const ClientTableDataT &GetLocalClient() const; /// Check whether the given client is removed. 
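A usage note on the ClientTable API above: callbacks now receive the native ClientTableDataT, so fields are read as plain members rather than accessor methods. A minimal sketch, assuming an AsyncGcsClient instance named gcs_client; the lambda signature matches the ClientTableCallback typedef above:

    gcs_client->client_table().RegisterClientAddedCallback(
        [](ray::gcs::AsyncGcsClient *client, const ClientID &id,
           const ClientTableDataT &data) {
          // Object-API fields are direct members of the native struct.
          RAY_LOG(INFO) << "client " << ClientID::FromBinary(data.client_id)
                        << " at " << data.node_manager_address << ":"
                        << data.node_manager_port;
        });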
/// @@ -911,7 +919,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -932,15 +940,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableData local_client_; + ClientTableDataT local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -950,7 +958,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 454379d18302..5b6794a505d3 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,22 +8,18 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { -using ray::rpc::ClientTableData; -using ray::rpc::GcsChangeMode; -using ray::rpc::ObjectTableData; - /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager()); + ClientID client_id = ClientID::FromBinary(object_table_data.manager); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -46,7 +42,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -83,9 +79,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. 
- auto data = std::make_shared(); - data->set_manager(client_id.Binary()); - data->set_object_size(object_info.data_size); + auto data = std::make_shared(); + data->manager = client_id.Binary(); + data->object_size = object_info.data_size; ray::Status status = gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; @@ -96,9 +92,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. - auto data = std::make_shared(); - data->set_manager(client_id.Binary()); - data->set_object_size(object_info.data_size); + auto data = std::make_shared(); + data->manager = client_id.Binary(); + data->object_size = object_info.data_size; ray::Status status = gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; @@ -106,14 +102,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - ClientTableData client_data; + ClientTableDataT client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::FromBinary(client_data.client_id()); + ClientID result_client_id = ClientID::FromBinary(client_data.client_id); if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.entry_type() == ClientTableData::INSERTION) { - connection_info.ip = client_data.node_manager_address(); - connection_info.port = static_cast(client_data.object_manager_port()); + if (client_data.entry_type == EntryType::INSERTION) { + connection_info.ip = client_data.node_manager_address; + connection_info.port = static_cast(client_data.object_manager_port); } } } @@ -212,7 +208,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, status = gcs_client_->object_table().Lookup( DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_updates) { + const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 964cee605ced..954162c21aef 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - rpc::ProfileTableData::ProfileEvent profile_event; - profile_event.set_event_type("transfer_send"); - profile_event.set_start_time(start_time); - profile_event.set_end_time(end_time); + ProfileEventT profile_event; + profile_event.event_type = "transfer_send"; + profile_event.start_time = start_time; + profile_event.end_time = end_time; // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. 
- profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + - "\"," + std::to_string(chunk_index) + ",\"" + - status.ToString() + "\"]"); + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + + std::to_string(chunk_index) + ",\"" + status.ToString() + + "\"]"; profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - rpc::ProfileTableData::ProfileEvent profile_event; - profile_event.set_event_type("transfer_receive"); - profile_event.set_start_time(start_time); - profile_event.set_end_time(end_time); + ProfileEventT profile_event; + profile_event.event_type = "transfer_receive"; + profile_event.start_time = start_time; + profile_event.end_time = end_time; // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + - "\"," + std::to_string(chunk_index) + ",\"" + - status.ToString() + "\"]"); + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + + std::to_string(chunk_index) + ",\"" + status.ToString() + + "\"]"; profile_events_.push_back(profile_event); } @@ -801,12 +801,11 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - rpc::ProfileTableData::ProfileEvent profile_event; - profile_event.set_event_type("receive_pull_request"); - profile_event.set_start_time(current_sys_time_seconds()); - profile_event.set_end_time(profile_event.start_time()); - profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + - "\"]"); + ProfileEventT profile_event; + profile_event.event_type = "receive_pull_request"; + profile_event.start_time = current_sys_time_seconds(); + profile_event.end_time = profile_event.start_time; + profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -939,13 +938,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { - rpc::ProfileTableData profile_info; - profile_info.set_component_type("object_manager"); - profile_info.set_component_id(client_id_.Binary()); +ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { + ProfileTableDataT profile_info; + profile_info.component_type = "object_manager"; + profile_info.component_id = client_id_.Binary(); for (auto const &profile_event : profile_events_) { - profile_info.add_profile_events()->CopyFrom(profile_event); + profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6664dd0a93bd..6318250ae3e8 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -180,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - rpc::ProfileTableData GetAndResetProfilingInfo(); + ProfileTableDataT GetAndResetProfilingInfo(); /// Returns debug string for class. 
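As the object_manager.cc hunks above show, transfer events are buffered locally in profile_events_ and only periodically folded into a single ProfileTableDataT batch. A rough sketch of that accumulate-then-flush pattern, with hand-rolled stand-ins (field subset only) for the generated structs:

#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for the generated object-API structs used above.
struct ProfileEventT {
  std::string event_type;
  double start_time = 0, end_time = 0;
  std::string extra_data;  // JSON-encoded list, built by string concatenation.
};
struct ProfileTableDataT {
  std::string component_type;
  // The object API represents nested tables as a vector of unique_ptr.
  std::vector<std::unique_ptr<ProfileEventT>> profile_events;
};

// Accumulate events locally, then hand them off in one batch, mirroring
// ObjectManager::GetAndResetProfilingInfo() above.
ProfileTableDataT FlushEvents(std::vector<ProfileEventT> &pending) {
  ProfileTableDataT batch;
  batch.component_type = "object_manager";
  for (const auto &event : pending) {
    batch.profile_events.emplace_back(new ProfileEventT(event));
  }
  pending.clear();
  return batch;
}

Batching keeps a burst of transfer events from turning into one GCS command per event.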
/// @@ -412,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 2d5292842acf..55aa59124a99 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,8 +11,6 @@ namespace ray { -using rpc::ClientTableData; - std::string store_executable; static inline void flushall_redis(void) { @@ -54,10 +52,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); - client_info.set_node_manager_address(ip); - client_info.set_node_manager_port(object_manager_port); - client_info.set_object_manager_port(object_manager_port); + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = ip; + client_info.node_manager_port = object_manager_port; + client_info.object_manager_port = object_manager_port; ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -244,8 +242,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableData &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id()); + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -440,16 +438,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableData data; + ClientTableDataT data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" - << "ClientIp=" << data.node_manager_address() << "\n" - << "ClientPort=" << data.node_manager_port(); - ClientTableData data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" + << "ClientIp=" << data.node_manager_address << "\n" + << "ClientPort=" << data.node_manager_port; + ClientTableDataT data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" - << "ClientIp=" << data2.node_manager_address() << "\n" - << "ClientPort=" << data2.node_manager_port(); + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" + << "ClientIp=" << data2.node_manager_address << "\n" + << "ClientPort=" << data2.node_manager_port; } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 45b80a267f2f..ee6c78d8ed42 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,8 +14,6 @@ int64_t wait_timeout_ms; namespace ray { -using rpc::ClientTableData; - static inline void flushall_redis(void) { redisContext *context = 
redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -48,10 +46,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); - client_info.set_node_manager_address(ip); - client_info.set_node_manager_port(object_manager_port); - client_info.set_object_manager_port(object_manager_port); + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = ip; + client_info.node_manager_port = object_manager_port; + client_info.object_manager_port = object_manager_port; ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -223,8 +221,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableData &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id()); + const ClientTableDataT &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -459,19 +457,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableData data; + ClientTableDataT data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); - ClientTableData data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); + ClientTableDataT data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto deleted file mode 100644 index d0b2c5e007fe..000000000000 --- a/src/ray/protobuf/gcs.proto +++ /dev/null @@ -1,280 +0,0 @@ -syntax = "proto3"; - -package ray.rpc; - -option java_package = "org.ray.runtime.generated"; - -// Language of a worker or task. -enum Language { - PYTHON = 0; - CPP = 1; - JAVA = 2; -} - -// These indexes are mapped to strings in ray_redis_module.cc. 
-enum TablePrefix { - TABLE_PREFIX_MIN = 0; - UNUSED = 1; - TASK = 2; - RAYLET_TASK = 3; - CLIENT = 4; - OBJECT = 5; - ACTOR = 6; - FUNCTION = 7; - TASK_RECONSTRUCTION = 8; - HEARTBEAT = 9; - HEARTBEAT_BATCH = 10; - ERROR_INFO = 11; - DRIVER = 12; - PROFILE = 13; - TASK_LEASE = 14; - ACTOR_CHECKPOINT = 15; - ACTOR_CHECKPOINT_ID = 16; - NODE_RESOURCE = 17; - TABLE_PREFIX_MAX = 18; -} - -// The channel that Add operations to the Table should be published on, if any. -enum TablePubsub { - TABLE_PUBSUB_MIN = 0; - NO_PUBLISH = 1; - TASK_PUBSUB = 2; - RAYLET_TASK_PUBSUB = 3; - CLIENT_PUBSUB = 4; - OBJECT_PUBSUB = 5; - ACTOR_PUBSUB = 6; - HEARTBEAT_PUBSUB = 7; - HEARTBEAT_BATCH_PUBSUB = 8; - ERROR_INFO_PUBSUB = 9; - TASK_LEASE_PUBSUB = 10; - DRIVER_PUBSUB = 11; - NODE_RESOURCE_PUBSUB = 12; - TABLE_PUBSUB_MAX = 13; -} - -enum GcsChangeMode { - APPEND_OR_ADD = 0; - REMOVE = 1; -} - -message GcsEntry { - GcsChangeMode change_mode = 1; - bytes id = 2; - repeated bytes entries = 3; -} - -message ObjectTableData { - // The size of the object. - uint64 object_size = 1; - // The node manager ID that this object appeared on or was evicted by. - bytes manager = 2; -} - -message TaskReconstructionData { - // The number of times this task has been reconstructed so far. - uint64 num_reconstructions = 1; - // The node manager that is trying to reconstruct the task. - bytes node_manager_id = 2; -} - -// TODO(hchen): Task table currently still uses flatbuffers-defined data structure -// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should -// be migrated to protobuf very soon. -message TaskTableData { - // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. - bytes task = 1; -} - -message ActorTableData { - // State of an actor. - enum ActorState { - // Actor is alive. - ALIVE = 0; - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1; - // Actor is already dead and won't be reconstructed. - DEAD = 2; - } - // The ID of the actor that was created. - bytes actor_id = 1; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - bytes actor_creation_dummy_object_id = 2; - // The ID of the driver that created the actor. - bytes driver_id = 3; - // The ID of the node manager that created the actor. - bytes node_manager_id = 4; - // Current state of this actor. - ActorState state = 5; - // Max number of times this actor should be reconstructed. - uint64 max_reconstructions = 6; - // Remaining number of reconstructions. - uint64 remaining_reconstructions = 7; -} - -message ErrorTableData { - // The ID of the driver that the error is for. - bytes driver_id = 1; - // The type of the error. - string type = 2; - // The error message. - string error_message = 3; - // The timestamp of the error message. - double timestamp = 4; -} - -message ProfileTableData { - // Represents a profile event. - message ProfileEvent { - // The type of the event. - string event_type = 1; - // The start time of the event. - double start_time = 2; - // The end time of the event. If the event is a point event, then this should - // be the same as the start time. - double end_time = 3; - // Additional data associated with the event. This data must be serialized - // using JSON. 
- string extra_data = 4; - } - - // The type of the component that generated the event, e.g., worker or - // object_manager, or node_manager. - string component_type = 1; - // An identifier for the component that generated the event. - bytes component_id = 2; - // An identifier for the node that generated the event. - string node_ip_address = 3; - // This is a batch of profiling events. We batch these together for - // performance reasons because a single task may generate many events, and - // we don't want each event to require a GCS command. - repeated ProfileEvent profile_events = 4; -} - -message RayResource { - // The type of the resource. - string resource_name = 1; - // The total capacity of this resource type. - double resource_capacity = 2; -} - -message ClientTableData { - // Enum for the entry type in the ClientTable - enum EntryType { - INSERTION = 0; - DELETION = 1; - RES_CREATEUPDATE = 2; - RES_DELETE = 3; - } - - // The client ID of the client that the message is about. - bytes client_id = 1; - // The IP address of the client's node manager. - string node_manager_address = 2; - // The IPC socket name of the client's raylet. - string raylet_socket_name = 3; - // The IPC socket name of the client's plasma store. - string object_store_socket_name = 4; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - int32 node_manager_port = 5; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - int32 object_manager_port = 6; - // Enum to store the entry type in the log - EntryType entry_type = 7; - - // TODO(hchen): Define the following resources in map format. - repeated string resources_total_label = 8; - repeated double resources_total_capacity = 9; -} - -message HeartbeatTableData { - // Node manager client id - bytes client_id = 1; - // TODO(hchen): Define the following resources in map format. - // Resource capacity currently available on this node manager. - repeated string resources_available_label = 2; - repeated double resources_available_capacity = 3; - // Total resource capacity configured for this node manager. - repeated string resources_total_label = 4; - repeated double resources_total_capacity = 5; - // Aggregate outstanding resource load on this node manager. - repeated string resource_load_label = 6; - repeated double resource_load_capacity = 7; -} - -message HeartbeatBatchTableData { - repeated HeartbeatTableData batch = 1; -} - -// Data for a lease on task execution. -message TaskLeaseData { - // Node manager client ID. - bytes node_manager_id = 1; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - uint64 acquired_at = 2; - // The period that the lease is active for. - uint64 timeout = 3; -} - -message DriverTableData { - // The driver ID. - bytes driver_id = 1; - // Whether it's dead. - bool is_dead = 2; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -message ActorCheckpointData { - // ID of this actor. - bytes actor_id = 1; - // The dummy object ID of actor's most recently executed task. - bytes execution_dependency = 2; - // A list of IDs of this actor's handles. 
- repeated bytes handle_ids = 3; - // The task counters of the above handles. - repeated uint64 task_counters = 4; - // The frontier dependencies of the above handles. - repeated bytes frontier_dependencies = 5; - // A list of unreleased dummy objects from this actor. - repeated bytes unreleased_dummy_objects = 6; - // The numbers of dependencies for the above unreleased dummy objects. - repeated uint32 num_dummy_object_dependencies = 7; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -message ActorCheckpointIdData { - // ID of this actor. - bytes actor_id = 1; - // IDs of this actor's available checkpoints. - repeated bytes checkpoint_ids = 2; - // A list of the timestamps for each of the above `checkpoint_ids`. - repeated uint64 timestamps = 3; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. -enum ErrorType { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 0; - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 1; - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 2; -} diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 7f940006b5be..cc587bc4d74e 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,35 +8,34 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, - const ActorCheckpointData &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, + const ActorCheckpointDataT &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_( - ObjectID::FromBinary(checkpoint_data.execution_dependency())) { + execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); + for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters(i); + frontier_entry.task_counter = checkpoint_data.task_counters[i]; frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); } // Restore `dummy_objects_`. 
- for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id()); + return ClientID::FromBinary(actor_table_data_.node_manager_id); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -44,15 +43,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id()); + return DriverID::FromBinary(actor_table_data_.driver_id); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions(); + return actor_table_data_.max_reconstructions; } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions(); + return actor_table_data_.remaining_reconstructions; } const std::unordered_map @@ -97,7 +96,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -110,18 +109,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. 
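Because the checkpoint schema has no map type, ActorCheckpointDataT stores the frontier as parallel vectors (handle_ids, task_counters, frontier_dependencies), and the constructor above zips them back into a map. A small stand-alone sketch of that restore step, with plain strings standing in for ActorHandleID and ObjectID:

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

struct FrontierLeaf {
  int64_t task_counter = 0;
  std::string execution_dependency;
};

// Zip the parallel checkpoint vectors back into a per-handle frontier map,
// as ActorRegistration's checkpoint constructor does above.
std::unordered_map<std::string, FrontierLeaf> RestoreFrontier(
    const std::vector<std::string> &handle_ids,
    const std::vector<uint64_t> &task_counters,
    const std::vector<std::string> &frontier_dependencies) {
  std::unordered_map<std::string, FrontierLeaf> frontier;
  for (size_t i = 0; i < handle_ids.size(); i++) {
    auto &leaf = frontier[handle_ids[i]];
    leaf.task_counter = static_cast<int64_t>(task_counters[i]);
    leaf.execution_dependency = frontier_dependencies[i];
  }
  return frontier;
}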
- auto checkpoint_data = std::make_shared(); - checkpoint_data->set_actor_id(actor_id.Binary()); - checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); + auto checkpoint_data = std::make_shared(); + checkpoint_data->actor_id = actor_id.Binary(); + checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->add_handle_ids(frontier.first.Binary()); - checkpoint_data->add_task_counters(frontier.second.task_counter); - checkpoint_data->add_frontier_dependencies( + checkpoint_data->handle_ids.push_back(frontier.first.Binary()); + checkpoint_data->task_counters.push_back(frontier.second.task_counter); + checkpoint_data->frontier_dependencies.push_back( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); - checkpoint_data->add_num_dummy_object_dependencies(entry.second); + checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); + checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 208e4998263f..8d7ce2a449ec 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,17 +4,13 @@ #include #include "ray/common/id.h" -#include "ray/protobuf/gcs.pb.h" +#include "ray/gcs/format/gcs_generated.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { -using rpc::ActorTableData; -using ActorState = rpc::ActorTableData::ActorState; -using rpc::ActorCheckpointData; - /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -27,13 +23,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - explicit ActorRegistration(const ActorTableData &actor_table_data); + ActorRegistration(const ActorTableDataT &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableData &actor_table_data, - const ActorCheckpointData &checkpoint_data); + ActorRegistration(const ActorTableDataT &actor_table_data, + const ActorCheckpointDataT &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -50,15 +46,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableData &GetTableData() const { return actor_table_data_; } + const ActorTableDataT &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState GetState() const { return actor_table_data_.state(); } + const ActorState &GetState() const { return actor_table_data_.state; } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.set_state(state); } + void SetState(const ActorState &state) { actor_table_data_.state = state; } /// Get the actor's node manager location. /// @@ -135,13 +131,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. 
/// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableData actor_table_data_; + ActorTableDataT actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 68d5aa817c2b..32dddada5244 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,6 +63,15 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} +Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { + // Deserialize and set entries for the uncommitted tasks. + auto tasks = task_request.uncommitted_tasks(); + for (auto it = tasks->begin(); it != tasks->end(); it++) { + const auto &task = **it; + RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); + } +} + boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -142,6 +151,20 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } +flatbuffers::Offset Lineage::ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { + RAY_CHECK(GetEntry(task_id)); + // Serialize the task and object entries. + std::vector> uncommitted_tasks; + for (const auto &entry : entries_) { + uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); + } + + auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), + fbb.CreateVector(uncommitted_tasks)); + return request; +} + const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -153,7 +176,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -269,11 +292,15 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableData &data) { HandleEntryCommitted(id); }; + const protocol::TaskT &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... 
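FlushTask above now serializes the task into a flatbuffer and immediately re-reads it into the object-API type expected by the task table. A generic sketch of that pack/unpack round trip, assuming a table type generated with the flatbuffers object API (--gen-object-api); this helper is illustrative and not part of the patch:

#include <memory>
#include "flatbuffers/flatbuffers.h"

// Pack an object-API instance into a buffer, then unpack it again. Generated
// tables expose NativeTableType, Pack(), and UnPackTo() when the object API
// is enabled; Ray's FlushTask uses its own ToFlatbuffer() for the pack step.
template <typename Table>
std::shared_ptr<typename Table::NativeTableType> RoundTrip(
    const typename Table::NativeTableType &native) {
  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Table::Pack(fbb, &native));
  auto out = std::make_shared<typename Table::NativeTableType>();
  flatbuffers::GetRoot<Table>(fbb.GetBufferPointer())->UnPackTo(out.get());
  return out;
}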
- auto task_data = std::make_shared(); - task_data->set_task(task->TaskData().Serialize()); + flatbuffers::FlatBufferBuilder fbb; + auto message = task->TaskData().ToFlatbuffer(fbb); + fbb.Finish(message); + auto task_data = std::make_shared(); + auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); + root->UnPackTo(task_data.get()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -338,6 +365,8 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } + + return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 37ce5caf6507..5436fa372fa4 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,17 +4,18 @@ #include #include +// clang-format off +#include "ray/common/common_protocol.h" +#include "ray/raylet/task.h" +#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -#include "ray/gcs/tables.h" -#include "ray/raylet/task.h" +// clang-format on namespace ray { namespace raylet { -using rpc::TaskTableData; - /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -135,6 +136,12 @@ class Lineage { /// Construct an empty Lineage. Lineage(); + /// Construct a Lineage from a ForwardTaskRequest. + /// + /// \param task_request The request to construct the lineage from. All + /// uncommitted tasks in the request will be added to the lineage. + Lineage(const protocol::ForwardTaskRequest &task_request); + /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -165,6 +172,15 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; + /// Serialize this lineage to a ForwardTaskRequest flatbuffer. + /// + /// \param entry_id The task ID to include in the ForwardTaskRequest + /// flatbuffer. + /// \return An offset to the serialized lineage. The serialization includes + /// all task and object entries in the lineage. + flatbuffers::Offset ToFlatbuffer( + flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; + /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// @@ -205,7 +221,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -303,7 +319,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. 
gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index a6184902f803..43e64e400292 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { + const protocol::TaskT &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { + const protocol::TaskT &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskTableData &data) { + const ray::protocol::TaskT &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. 
- auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 0a853260887e..62ecb00b819f 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableData &heartbeat_data) { + const HeartbeatTableDataT &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableData &heartbeat_data) { + const HeartbeatTableDataT &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id() && - data.entry_type() == ClientTableData::DELETION) { + if (client_id.Binary() == data.client_id && + data.entry_type == EntryType::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,9 +84,10 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. 
if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->add_batch()->CopyFrom(heartbeat.second); + batch->batch.push_back(std::unique_ptr( + new HeartbeatTableDataT(heartbeat.second))); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index 5725e52cf495..c69cc9f003e0 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,10 +11,6 @@ namespace ray { namespace raylet { -using rpc::ClientTableData; -using rpc::HeartbeatBatchTableData; -using rpc::HeartbeatTableData; - class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -39,7 +35,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableData &heartbeat_data); + const HeartbeatTableDataT &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -54,7 +50,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 226a8fb6d251..a0bde1ff0655 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { + if (pair.second.GetState() == ActorState::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { + } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -83,8 +83,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, gcs_client_, - config.worker_commands), + config.maximum_startup_concurrency, config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, @@ -101,8 +100,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), actor_registry_(), - node_manager_server_("NodeManager", config.node_manager_port), - node_manager_service_(io_service, *this), + node_manager_server_(config.node_manager_port, io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -120,7 +118,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. 
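The monitor batches every heartbeat received during one tick into a single HeartbeatBatchTableDataT before publishing, as in the hunk above. A compact sketch of that batching step, with trimmed-down stand-ins for the generated structs:

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct HeartbeatTableDataT {
  std::string client_id;
};
struct HeartbeatBatchTableDataT {
  // The object API stores nested tables as unique_ptr elements, hence the
  // explicit copy into a freshly allocated HeartbeatTableDataT below.
  std::vector<std::unique_ptr<HeartbeatTableDataT>> batch;
};

// Fold the per-client heartbeat buffer into one batch, mirroring Monitor::Tick().
HeartbeatBatchTableDataT MakeBatch(
    const std::unordered_map<std::string, HeartbeatTableDataT> &buffer) {
  HeartbeatBatchTableDataT batch;
  for (const auto &heartbeat : buffer) {
    batch.batch.push_back(std::unique_ptr<HeartbeatTableDataT>(
        new HeartbeatTableDataT(heartbeat.second)));
  }
  return batch;
}

Publishing one batch per tick keeps the GCS from seeing a separate publish for every raylet heartbeat.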
- node_manager_server_.RegisterService(node_manager_service_); node_manager_server_.Run(); } @@ -132,7 +129,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskTableData &task_data) { + const ray::protocol::TaskT &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -141,8 +138,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseData &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); + const TaskLeaseDataT &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -152,7 +149,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. - reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -166,7 +163,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -179,34 +176,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableData &data) { + const ClientTableDataT &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableData &data) { ClientRemoved(data); }; + const ClientTableDataT &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableData &data) { ResourceCreateUpdated(data); }; + const ClientTableDataT &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableData &data) { ResourceDeleted(data); }; + const ClientTableDataT &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. 
const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableData &heartbeat_batch) { + const HeartbeatBatchTableDataT &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -217,7 +214,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -253,12 +250,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " - << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); - if (entry.is_dead()) { - auto driver_id = DriverID::FromBinary(entry.driver_id()); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) + << " " << entry.is_dead; + if (entry.is_dead) { + auto driver_id = DriverID::FromBinary(entry.driver_id); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -290,26 +287,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->set_client_id(my_client_id.Binary()); + heartbeat_data->client_id = my_client_id.Binary(); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. 
for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->add_resources_available_label(resource_pair.first); - heartbeat_data->add_resources_available_capacity(resource_pair.second); + heartbeat_data->resources_available_label.push_back(resource_pair.first); + heartbeat_data->resources_available_capacity.push_back(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->add_resources_total_label(resource_pair.first); - heartbeat_data->add_resources_total_capacity(resource_pair.second); + heartbeat_data->resources_total_label.push_back(resource_pair.first); + heartbeat_data->resources_total_capacity.push_back(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->add_resource_load_label(resource_pair.first); - heartbeat_data->add_resource_load_capacity(resource_pair.second); + heartbeat_data->resource_load_label.push_back(resource_pair.first); + heartbeat_data->resource_load_capacity.push_back(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -337,8 +334,13 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events_size() > 0) { - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); + if (profile_info.profile_events.size() > 0) { + flatbuffers::FlatBufferBuilder fbb; + auto message = CreateProfileTableData(fbb, &profile_info); + fbb.Finish(message); + auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); + + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); } // Reset the timer. @@ -355,8 +357,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableData &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id()); +void NodeManager::ClientAdded(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -375,20 +377,19 @@ void NodeManager::ClientAdded(const ClientTableData &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address(), - client_data.node_manager_port(), client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address, + client_data.node_manager_port, client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total( - rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet resources_total(client_data.resources_total_label, + client_data.resources_total_capacity); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableData &client_data) { +void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. 
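The heartbeat schema carries resources as two parallel vectors (labels and capacities) rather than a map, so Heartbeat() above flattens the ResourceSet on the way out and HeartbeatAdded() below rebuilds a resource set from the vectors on the way in. A hypothetical helper showing the rebuild direction with plain standard-library types in place of Ray's ResourceSet:

#include <string>
#include <unordered_map>
#include <vector>

// Rebuild a name -> capacity map from parallel label/capacity vectors, the
// inverse of the flattening done when the heartbeat is produced.
std::unordered_map<std::string, double> ToResourceMap(
    const std::vector<std::string> &labels, const std::vector<double> &capacities) {
  std::unordered_map<std::string, double> resources;
  for (size_t i = 0; i < labels.size() && i < capacities.size(); i++) {
    resources[labels[i]] = capacities[i];
  }
  return resources;
}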
- const ClientID client_id = ClientID::FromBinary(client_data.client_id()); + const ClientID client_id = ClientID::FromBinary(client_data.client_id); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -416,7 +417,7 @@ void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.GetState() == ActorTableData::ALIVE) { + actor_entry.second.GetState() == ActorState::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first << " is disconnected, because its node " << client_id << " is removed from cluster. It may be reconstructed."; @@ -434,15 +435,14 @@ void NodeManager::ClientRemoved(const ClientTableData &client_data) { lineage_cache_.FlushAllUncommittedTasks(); } -void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id()); +void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". Updating resource map."; - ResourceSet new_res_set( - rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -471,13 +471,12 @@ void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { return; } -void NodeManager::ResourceDeleted(const ClientTableData &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id()); +void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set( - rpc::VectorFromProtobuf(client_data.resources_total_label()), - rpc::VectorFromProtobuf(client_data.resources_total_capacity())); + ResourceSet new_res_set(client_data.resources_total_label, + client_data.resources_total_capacity); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -523,7 +522,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() { } void NodeManager::HeartbeatAdded(const ClientID &client_id, - const HeartbeatTableData &heartbeat_data) { + const HeartbeatTableDataT &heartbeat_data) { // Locate the client id in remote client table and update available resources based on // the received heartbeat information. 
auto it = cluster_resource_map_.find(client_id); @@ -535,12 +534,10 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } SchedulingResources &remote_resources = it->second; - ResourceSet remote_available( - rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()), - rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity())); - ResourceSet remote_load( - rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()), - rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity())); + ResourceSet remote_available(heartbeat_data.resources_available_label, + heartbeat_data.resources_available_capacity); + ResourceSet remote_load(heartbeat_data.resource_load_label, + heartbeat_data.resource_load_capacity); // TODO(atumanov): assert that the load is a non-empty ResourceSet. remote_resources.SetAvailableResources(std::move(remote_available)); // Extract the load information and save it locally. @@ -565,41 +562,40 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } } -void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { +void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. - for (const auto &heartbeat_data : heartbeat_batch.batch()) { - const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); + for (const auto &heartbeat_data : heartbeat_batch.batch) { + const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); if (client_id == local_client_id) { // Skip heartbeats from self. continue; } - HeartbeatAdded(client_id, heartbeat_data); + HeartbeatAdded(client_id, *heartbeat_data); } } void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableData &data, + const ActorID &actor_id, const ActorTableDataT &data, const ray::gcs::ActorTable::WriteCallback &failure_callback) { // Copy the actor notification data. - auto actor_notification = std::make_shared(data); + auto actor_notification = std::make_shared(data); // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions() - - actor_notification->remaining_reconstructions()); - if (actor_notification->state() != ActorTableData::ALIVE) { + int log_length = 2 * (actor_notification->max_reconstructions - + actor_notification->remaining_reconstructions); + if (actor_notification->state != ActorState::ALIVE) { // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } // If we successful appended a record to the GCS table of the actor that // has died, signal this to anyone receiving signals from this actor. 
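PublishActorStateTransition() below appends to the actor's GCS log at a computed index: the log opens with an ALIVE entry, followed by up to max_reconstructions (RECONSTRUCTING, ALIVE) pairs, and optionally a final DEAD entry. A small hypothetical helper that mirrors that arithmetic (the real code computes it inline):

#include <cstdint>

enum class ActorState { ALIVE, RECONSTRUCTING, DEAD };

// Expected index of the next entry in the actor log, given how many
// reconstructions have already been consumed and the state being published.
int ExpectedLogLength(uint64_t max_reconstructions,
                      uint64_t remaining_reconstructions, ActorState state) {
  int log_length =
      2 * static_cast<int>(max_reconstructions - remaining_reconstructions);
  if (state != ActorState::ALIVE) {
    // RECONSTRUCTING and DEAD entries land at odd indices.
    log_length += 1;
  }
  return log_length;
}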
auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { + const ActorTableDataT &data) { auto redis_context = client->primary_context(); - if (data.state() == ActorTableData::DEAD || - data.state() == ActorTableData::RECONSTRUCTING) { + if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -636,12 +632,11 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " - << ActorTableData::ActorState_Name(actor_registration.GetState()) + << ", state = " << EnumNameActorState(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorTableData::ALIVE) { + if (actor_registration.GetState() == ActorState::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -668,7 +663,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorTableData::DEAD) { + } else if (actor_registration.GetState() == ActorState::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -677,7 +672,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. This will mark the tasks as waiting for actor @@ -798,20 +793,8 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - ProfileTableDataT fbs_message; - flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); - rpc::ProfileTableData profile_table_data; - profile_table_data.set_component_type(fbs_message.component_type); - profile_table_data.set_component_id(fbs_message.component_id); - for (const auto &fbs_event : fbs_message.profile_events) { - rpc::ProfileTableData::ProfileEvent *event = - profile_table_data.add_profile_events(); - event->set_event_type(fbs_event->event_type); - event->set_start_time(fbs_event->start_time); - event->set_end_time(fbs_event->end_time); - event->set_extra_data(fbs_event->extra_data); - } - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); + auto message = flatbuffers::GetRoot(message_data); + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -879,8 +862,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. 
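For reference, the three branches of HandleActorStateTransition above reduce to a dispatch on the actor's new state. The sketch below mirrors that shape with simplified stand-in types; the real code operates on `local_queues_`, `SubmitTask`, and `TreatTaskAsFailed`.

```cpp
// Sketch only: QueuedTask and the callbacks are simplified stand-ins for the
// raylet's task queues and SubmitTask/TreatTaskAsFailed.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

enum class ActorState { ALIVE, RECONSTRUCTING, DEAD };

struct QueuedTask {
  std::string id;
};

void HandleActorStateTransition(
    ActorState state, std::vector<QueuedTask> &queued,
    const std::function<void(const QueuedTask &)> &submit,
    const std::function<void(const QueuedTask &)> &fail) {
  switch (state) {
  case ActorState::ALIVE:
    // Location now known: dequeue and submit the waiting methods.
    for (const auto &task : queued) submit(task);
    queued.clear();
    break;
  case ActorState::DEAD:
    // The actor will not come back: fail everything queued for it.
    for (const auto &task : queued) fail(task);
    queued.clear();
    break;
  case ActorState::RECONSTRUCTING:
    // Resubmit so the tasks wait for the recreated actor's location; in this
    // simplified model they simply stay queued.
    for (const auto &task : queued) submit(task);
    break;
  }
}

int main() {
  std::vector<QueuedTask> queued = {{"task-1"}, {"task-2"}};
  HandleActorStateTransition(
      ActorState::DEAD, queued,
      [](const QueuedTask &t) { std::cout << "submit " << t.id << "\n"; },
      [](const QueuedTask &t) { std::cout << "fail " << t.id << "\n"; });
  return 0;
}
```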
ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorTableData::RECONSTRUCTING - : ActorTableData::DEAD; + ? ActorState::RECONSTRUCTING + : ActorState::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -889,8 +872,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableData new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.set_state(new_state); + ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.state = new_state; if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -901,7 +884,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { + const ActorTableDataT &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1176,7 +1159,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointData &data) { + const ActorCheckpointDataT &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1260,19 +1243,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableData object - ClientTableData data; + // Add the new resource to a skeleton ClientTableDataT object + ClientTableDataT data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.add_resources_total_label(resource_name); - data.add_resources_total_capacity(capacity); + data.resources_total_label = std::vector{resource_name}; + data.resources_total_capacity = std::vector{capacity}; // Set the correct flag for entry_type if (is_deletion) { - data.set_entry_type(ClientTableData::RES_DELETE); + data.entry_type = EntryType::RES_DELETE; } else { - data.set_entry_type(ClientTableData::RES_CREATEUPDATE); + data.entry_type = EntryType::RES_CREATEUPDATE; } // Submit to the client table. 
This calls the ResourceCreateUpdated callback, which @@ -1281,7 +1264,7 @@ void NodeManager::ProcessSetResourceRequest( if (not worker) { worker = worker_pool_.GetRegisteredDriver(client); } - auto data_shared_ptr = std::make_shared(data); + auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); @@ -1386,7 +1369,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << ErrorType_Name(error_type) << "."; + << EnumNameErrorType(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1504,9 +1487,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // If we have already seen this actor and this actor is not being reconstructed, // its location is known. bool location_known = - seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING; + seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; if (location_known) { - if (actor_entry->second.GetState() == ActorTableData::DEAD) { + if (actor_entry->second.GetState() == ActorState::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); @@ -1551,7 +1534,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // we missed the creation notification. auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. @@ -1740,6 +1723,18 @@ bool NodeManager::AssignTask(const Task &task) { std::shared_ptr worker = worker_pool_.PopWorker(spec); if (worker == nullptr) { // There are no workers that can execute this task. + if (!spec.IsActorTask()) { + // There are no more non-actor workers available to execute this task. + // Start a new worker. + worker_pool_.StartWorkerProcess(spec.GetLanguage()); + // Push an error message to the user if the worker pool tells us that it is + // getting too big. + const std::string warning_message = worker_pool_.WarningAboutSize(); + if (warning_message != "") { + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms())); + } + } // We couldn't assign this task, as no worker available. 
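The block added to AssignTask above is the new fallback: when no worker is available for a non-actor task, the node manager (rather than the pool) starts a worker process and forwards any size warning to the driver. A self-contained sketch of that control flow, with `WorkerPoolStub` and `PushErrorToDriver` as stand-ins for the real worker pool and GCS error table:

```cpp
// Sketch only: WorkerPoolStub and PushErrorToDriver are stand-ins for the
// raylet's worker pool and GCS error table.
#include <iostream>
#include <memory>
#include <string>

struct WorkerPoolStub {
  // Pretend no worker is idle and the pool reports that it has grown large.
  std::shared_ptr<int> PopWorker() { return nullptr; }
  void StartWorkerProcess() { std::cout << "starting worker process\n"; }
  std::string WarningAboutSize() { return "WARNING: many workers started"; }
};

void PushErrorToDriver(const std::string &type, const std::string &message) {
  std::cout << "[" << type << "] " << message << "\n";
}

bool AssignTask(WorkerPoolStub &pool, bool is_actor_task) {
  auto worker = pool.PopWorker();
  if (worker == nullptr) {
    if (!is_actor_task) {
      // No idle worker for a normal task: start one, and surface a warning if
      // the pool says it is getting too big.
      pool.StartWorkerProcess();
      const std::string warning = pool.WarningAboutSize();
      if (!warning.empty()) {
        PushErrorToDriver("worker_pool_large", warning);
      }
    }
    return false;  // The task stays queued until a worker registers.
  }
  return true;
}

int main() {
  WorkerPoolStub pool;
  AssignTask(pool, /*is_actor_task=*/false);
  return 0;
}
```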
return false; } @@ -1877,11 +1872,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { +ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); auto actor_id = task.GetTaskSpecification().ActorCreationId(); auto actor_entry = actor_registry_.find(actor_id); - ActorTableData new_actor_data; + ActorTableDataT new_actor_data; // TODO(swang): If this is an actor that was reconstructed, and previous // actor notifications were delayed, then this node may not have an entry for // the actor in actor_regisry_. Then, the fields for the number of @@ -1889,33 +1884,32 @@ ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &tas if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. - new_actor_data.set_actor_id(actor_id.Binary()); - new_actor_data.set_actor_creation_dummy_object_id( - task.GetTaskSpecification().ActorDummyObject().Binary()); - new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); - new_actor_data.set_max_reconstructions( - task.GetTaskSpecification().MaxActorReconstructions()); + new_actor_data.actor_id = actor_id.Binary(); + new_actor_data.actor_creation_dummy_object_id = + task.GetTaskSpecification().ActorDummyObject().Binary(); + new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); + new_actor_data.max_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.set_remaining_reconstructions( - task.GetTaskSpecification().MaxActorReconstructions()); + new_actor_data.remaining_reconstructions = + task.GetTaskSpecification().MaxActorReconstructions(); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING); + RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); // Copy the static fields from the current actor entry. new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.set_remaining_reconstructions( - new_actor_data.remaining_reconstructions() - 1); + new_actor_data.remaining_reconstructions--; } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. 
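CreateActorTableDataFromCreationTask above distinguishes a first creation from a reconstruction. The sketch below reproduces that decision with a plain struct carrying only the `ActorTableDataT` fields used here; it is illustrative, not the generated type.

```cpp
// Sketch only: ActorData carries just the ActorTableDataT fields used here.
#include <cassert>
#include <string>

struct ActorData {
  std::string actor_id;
  int max_reconstructions = 0;
  int remaining_reconstructions = 0;
};

ActorData MakeActorTableData(const std::string &actor_id, int max_reconstructions,
                             const ActorData *previous_entry) {
  ActorData data;
  if (previous_entry == nullptr) {
    // First creation: fill the static fields and a full reconstruction budget.
    data.actor_id = actor_id;
    data.max_reconstructions = max_reconstructions;
    data.remaining_reconstructions = max_reconstructions;
  } else {
    // Reconstruction: keep the static fields, consume one reconstruction.
    data = *previous_entry;
    data.remaining_reconstructions--;
  }
  return data;
}

int main() {
  ActorData first = MakeActorTableData("actor-1", 3, nullptr);
  assert(first.remaining_reconstructions == 3);
  ActorData second = MakeActorTableData("actor-1", 3, &first);
  assert(second.remaining_reconstructions == 2);
  return 0;
}
```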
- new_actor_data.set_node_manager_id( - gcs_client_->client_table().GetLocalClientId().Binary()); - new_actor_data.set_state(ActorTableData::ALIVE); + new_actor_data.node_manager_id = + gcs_client_->client_table().GetLocalClientId().Binary(); + new_actor_data.state = ActorState::ALIVE; return new_actor_data; } @@ -1951,7 +1945,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, - const ActorCheckpointData &checkpoint_data) { + const ActorCheckpointDataT &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -1965,7 +1959,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableData &data) { + const ActorTableDataT &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -1981,7 +1975,8 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { PublishActorStateTransition( actor_id, new_actor_data, /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { + [](gcs::AsyncGcsClient *client, const ActorID &id, + const ActorTableDataT &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -2020,11 +2015,10 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskTableData &task_data) { + const ray::protocol::TaskT &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - auto message = flatbuffers::GetRoot(task_data.task().data()); - const Task task(*message); + const Task task(task_data); ResubmitTask(task); }, /*failure_callback=*/ @@ -2052,7 +2046,7 @@ void NodeManager::ResubmitTask(const Task &task) { if (task.GetTaskSpecification().IsActorCreationTask()) { const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); const auto it = actor_registry_.find(actor_id); - if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) { + if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { // If the actor is still alive, then do not resubmit the task. If the // actor actually is dead and a result is needed, then reconstruction // for this task will be triggered again. @@ -2211,12 +2205,6 @@ void NodeManager::ForwardTask( const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); - if (worker_pool_.HasPendingWorkerForTask(spec.GetLanguage(), task_id)) { - // There is a worker being starting for this task, - // so we shouldn't forward this task to another node. - return; - } - // Get and serialize the task's unforwarded, uncommitted lineage. 
Lineage uncommitted_lineage; if (lineage_cache_.ContainsTask(task_id)) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 7e812183657c..61613358330c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,6 +10,7 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" +#include "ray/gcs/format/util.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -25,13 +26,6 @@ namespace ray { namespace raylet { -using rpc::ActorTableData; -using rpc::ClientTableData; -using rpc::DriverTableData; -using rpc::ErrorType; -using rpc::HeartbeatBatchTableData; -using rpc::HeartbeatTableData; - struct NodeManagerConfig { /// The node's resource configuration. ResourceSet resource_config; @@ -118,22 +112,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param data Data associated with the new client. /// \return Void. - void ClientAdded(const ClientTableData &data); + void ClientAdded(const ClientTableDataT &data); /// Handler for the removal of a GCS client. /// \param client_data Data associated with the removed client. /// \return Void. - void ClientRemoved(const ClientTableData &client_data); + void ClientRemoved(const ClientTableDataT &client_data); /// Handler for the addition or updation of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceCreateUpdated(const ClientTableData &client_data); + void ResourceCreateUpdated(const ClientTableDataT &client_data); /// Handler for the deletion of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceDeleted(const ClientTableData &client_data); + void ResourceDeleted(const ClientTableDataT &client_data); /// Evaluates the local infeasible queue to check if any tasks can be scheduled. /// This is called whenever there's an update to the resources on the local client. @@ -156,11 +150,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param id The ID of the node manager that sent the heartbeat. /// \param data The heartbeat data including load information. /// \return Void. - void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data); + void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data); /// Handler for a heartbeat batch notification from the GCS /// /// \param heartbeat_batch The batch of heartbeat data. - void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch); + void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch); /// Methods for task scheduling. @@ -212,7 +206,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Helper function to produce actor table data for a newly created actor. /// /// \param task The actor creation task that created the actor. - ActorTableData CreateActorTableDataFromCreationTask(const Task &task); + ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -323,7 +317,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param failure_callback An optional callback to call if the publish is /// unsuccessful. 
void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableData &data, + const ActorID &actor_id, const ActorTableDataT &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -352,7 +346,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging @@ -512,10 +506,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::unordered_map checkpoint_id_to_restore_; /// The RPC server. - rpc::GrpcServer node_manager_server_; - - /// The RPC service. - rpc::NodeManagerGrpcService node_manager_service_; + rpc::NodeManagerServer node_manager_server_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index cbf9b25213ca..473e6c263ffe 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); - client_info.set_node_manager_address(node_ip_address); - client_info.set_raylet_socket_name(raylet_socket_name); - client_info.set_object_store_socket_name(object_store_socket_name); - client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); - client_info.set_node_manager_port(node_manager_.GetServerPort()); + ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); + client_info.node_manager_address = node_ip_address; + client_info.raylet_socket_name = raylet_socket_name; + client_info.object_store_socket_name = object_store_socket_name; + client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); + client_info.node_manager_port = node_manager_.GetServerPort(); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.add_resources_total_label(resource_pair.first); - client_info.add_resources_total_capacity(resource_pair.second); + client_info.resources_total_label.push_back(resource_pair.first); + client_info.resources_total_capacity.push_back(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address() << ":" - << client_info.node_manager_port() << " object manager at " - << client_info.node_manager_address() << ":" - << client_info.object_manager_port(); + << " started on " << client_info.node_manager_address << ":" + << client_info.node_manager_port << " object manager at " + << client_info.node_manager_address << ":" + << client_info.object_manager_port; ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 9367a5054591..26fe74b2b622 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,8 +16,6 @@ namespace ray { namespace raylet { -using rpc::ClientTableData; - class Task; class NodeManager; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index bf5c1acfaa37..97c86ea73cd8 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. - auto reconstruction_entry = std::make_shared(); - reconstruction_entry->set_num_reconstructions(reconstruction_attempt); - reconstruction_entry->set_node_manager_id(client_id_.Binary()); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->num_reconstructions = reconstruction_attempt; + reconstruction_entry->node_manager_id = client_id_.Binary(); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionData &data) { + const TaskReconstructionDataT &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionData &data) { + const TaskReconstructionDataT &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index a194443e1425..cd969cc2706e 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,8 +17,6 @@ namespace ray { namespace raylet { -using rpc::TaskReconstructionData; - class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 12d9336a382f..4ccebd0c0c09 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,8 +14,6 @@ namespace ray { namespace raylet { -using rpc::TaskLeaseData; - class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -85,7 +83,7 @@ class MockGcs : public 
gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -112,7 +110,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -134,15 +132,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -161,9 +159,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseData &task_lease) { + const TaskLeaseDataT &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout()); + task_lease.timeout); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -316,10 +314,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); - task_lease_data->set_acquired_at(current_sys_time_ms()); - task_lease_data->set_timeout(2 * test_period); + auto task_lease_data = std::make_shared(); + task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); + task_lease_data->acquired_at = current_sys_time_ms(); + task_lease_data->timeout = 2 * test_period; mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -330,7 +328,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout() * 1.1); + Run(task_lease_data->timeout * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -343,10 +341,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. 
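TestReconstructionSuppressed above exercises lease expiry: reconstruction is suppressed while `acquired_at + timeout` lies in the future and allowed afterwards. A tiny sketch of that check, carrying only the `TaskLeaseDataT` fields it needs:

```cpp
// Sketch only: just the TaskLeaseDataT fields the expiry check needs.
#include <cassert>
#include <cstdint>

struct TaskLease {
  int64_t acquired_at;  // milliseconds
  int64_t timeout;      // milliseconds
};

bool LeaseExpired(const TaskLease &lease, int64_t now_ms) {
  return now_ms >= lease.acquired_at + lease.timeout;
}

int main() {
  TaskLease lease{/*acquired_at=*/1000, /*timeout=*/500};
  assert(!LeaseExpired(lease, 1400));  // Still covered: suppress reconstruction.
  assert(LeaseExpired(lease, 1600));   // Past expiry: reconstruction may proceed.
  return 0;
}
```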
SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); - task_lease_data->set_acquired_at(current_sys_time_ms()); - task_lease_data->set_timeout(reconstruction_timeout_ms_); + auto task_lease_data = std::make_shared(); + task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); + task_lease_data->acquired_at = current_sys_time_ms(); + task_lease_data->timeout = reconstruction_timeout_ms_; mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -395,14 +393,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); - task_reconstruction_data->set_num_reconstructions(0); + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); + task_reconstruction_data->num_reconstructions = 0; RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionData &data) { ASSERT_TRUE(false); }, + const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 89028c733d0d..c5155b96b0c1 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->set_node_manager_id(client_id_.Hex()); - task_lease_data->set_acquired_at(current_sys_time_ms()); - task_lease_data->set_timeout(it->second.lease_period); + auto task_lease_data = std::make_shared(); + task_lease_data->node_manager_id = client_id_.Hex(); + task_lease_data->acquired_at = current_sys_time_ms(); + task_lease_data->timeout = it->second.lease_period; RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index a96558295234..3788a5eae7ae 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,8 +13,6 @@ namespace ray { namespace raylet { -using rpc::TaskLeaseData; - class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index f7a60989fcba..e0f832a12870 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; diff --git a/src/ray/raylet/task_spec.cc 
b/src/ray/raylet/task_spec.cc index 1d722de18f73..eeab29272126 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -80,12 +80,12 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor, - const std::vector &dynamic_worker_options) + const Language &language, const std::vector &function_descriptor) : spec_() { flatbuffers::FlatBufferBuilder fbb; TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); + // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) { @@ -101,8 +101,7 @@ TaskSpecification::TaskSpecification( ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor), - string_vec_to_flatbuf(fbb, dynamic_worker_options)); + string_vec_to_flatbuf(fbb, function_descriptor)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -259,11 +258,6 @@ std::vector TaskSpecification::NewActorHandles() const { return ids_from_flatbuf(*message->new_actor_handles()); } -std::vector TaskSpecification::DynamicWorkerOptions() const { - auto message = flatbuffers::GetRoot(spec_.data()); - return string_vec_from_flatbuf(*message->dynamic_worker_options()); -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 8a08e9974ef2..d557c188ae68 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -128,7 +128,6 @@ class TaskSpecification { /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. - /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, @@ -139,8 +138,7 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor, - const std::vector &dynamic_worker_options = {}); + const Language &language, const std::vector &function_descriptor); /// Deserialize a task specification from a string. /// @@ -216,8 +214,6 @@ class TaskSpecification { ObjectID ActorDummyObject() const; std::vector NewActorHandles() const; - std::vector DynamicWorkerOptions() const; - private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 16086565de80..d4ac4cf4ecce 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,12 +5,10 @@ #include #include -#include "ray/common/constants.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/stats/stats.h" #include "ray/util/logging.h" -#include "ray/util/util.h" namespace { @@ -43,13 +41,12 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. 
WorkerPool::WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, std::shared_ptr gcs_client, + int maximum_startup_concurrency, const std::unordered_map> &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - last_warning_multiple_(0), - gcs_client_(std::move(gcs_client)) { + last_warning_multiple_(0) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. If we don't do this, then worker processes will @@ -101,8 +98,7 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -int WorkerPool::StartWorkerProcess(const Language &language, - const std::vector &dynamic_options) { +void WorkerPool::StartWorkerProcess(const Language &language) { auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -112,7 +108,7 @@ int WorkerPool::StartWorkerProcess(const Language &language, RAY_LOG(DEBUG) << "Worker not started, " << state.starting_worker_processes.size() << " worker processes of language type " << static_cast(language) << " pending registration"; - return -1; + return; } // Either there are no workers pending registration or the worker start is being forced. RAY_LOG(DEBUG) << "Starting new worker process, current pool has " @@ -121,20 +117,8 @@ int WorkerPool::StartWorkerProcess(const Language &language, // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; - size_t dynamic_option_index = 0; for (auto const &token : state.worker_command) { - const auto option_placeholder = - kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); - - if (token == option_placeholder) { - if (!dynamic_options.empty()) { - RAY_CHECK(dynamic_option_index < dynamic_options.size()); - worker_command_args.push_back(dynamic_options[dynamic_option_index].c_str()); - ++dynamic_option_index; - } - } else { - worker_command_args.push_back(token.c_str()); - } + worker_command_args.push_back(token.c_str()); } worker_command_args.push_back(nullptr); @@ -142,14 +126,14 @@ int WorkerPool::StartWorkerProcess(const Language &language, if (pid < 0) { // Failure case. RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); + return; } else if (pid > 0) { // Parent process case. 
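StartWorkerProcess above declines to fork while too many worker processes are still pending registration. The exact condition sits outside this hunk, so the sketch below assumes it compares the number of starting processes against `maximum_startup_concurrency`; treat it as an approximation of the gate, not the real check.

```cpp
// Sketch only: an assumed shape for the startup gate; the real condition is
// not part of this hunk.
#include <cassert>
#include <map>

struct PoolState {
  // pid of a forked worker process -> number of its still-unregistered workers.
  std::map<int, int> starting_worker_processes;
};

bool ShouldStartWorkerProcess(const PoolState &state,
                              int maximum_startup_concurrency) {
  int pending_processes = static_cast<int>(state.starting_worker_processes.size());
  return pending_processes < maximum_startup_concurrency;
}

int main() {
  PoolState state;
  state.starting_worker_processes[101] = 3;
  state.starting_worker_processes[102] = 2;
  // Two processes are still pending registration.
  assert(ShouldStartWorkerProcess(state, /*maximum_startup_concurrency=*/5));
  assert(!ShouldStartWorkerProcess(state, /*maximum_startup_concurrency=*/2));
  return 0;
}
```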
RAY_LOG(DEBUG) << "Started worker process with pid " << pid; state.starting_worker_processes.emplace( std::make_pair(pid, num_workers_per_process_)); - return pid; + return; } - return -1; } pid_t WorkerPool::StartProcess(const std::vector &worker_command_args) { @@ -174,7 +158,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_a } void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { - const auto pid = worker->Pid(); + auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; auto &state = GetStateForLanguage(worker->GetLanguage()); state.registered_workers.insert(std::move(worker)); @@ -223,74 +207,30 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); - - auto it = state.dedicated_workers_to_tasks.find(worker->Pid()); - if (it != state.dedicated_workers_to_tasks.end()) { - // The worker is used for the actor creation task with dynamic options. - // Put it into idle dedicated worker pool. - const auto task_id = it->second; - state.idle_dedicated_workers[task_id] = std::move(worker); + // Add the worker to the idle pool. + if (worker->GetActorId().IsNil()) { + state.idle.insert(std::move(worker)); } else { - // The worker is not used for the actor creation task without dynamic options. - // Put the worker to the corresponding idle pool. - if (worker->GetActorId().IsNil()) { - state.idle.insert(std::move(worker)); - } else { - state.idle_actor[worker->GetActorId()] = std::move(worker); - } + state.idle_actor[worker->GetActorId()] = std::move(worker); } } std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec) { auto &state = GetStateForLanguage(task_spec.GetLanguage()); const auto &actor_id = task_spec.ActorId(); - std::shared_ptr worker = nullptr; - int pid = -1; - if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { - // Code path of actor creation task with dynamic worker options. - // Try to pop it from idle dedicated pool. - auto it = state.idle_dedicated_workers.find(task_spec.TaskId()); - if (it != state.idle_dedicated_workers.end()) { - // There is an idle dedicated worker for this task. - worker = std::move(it->second); - state.idle_dedicated_workers.erase(it); - // Because we found a worker that can perform this task, - // we can remove it from dedicated_workers_to_tasks. - state.dedicated_workers_to_tasks.erase(worker->Pid()); - state.tasks_to_dedicated_workers.erase(task_spec.TaskId()); - } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { - // We are not pending a registration from a worker for this task, - // so start a new worker process for this task. - pid = StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions()); - if (pid > 0) { - state.dedicated_workers_to_tasks[pid] = task_spec.TaskId(); - state.tasks_to_dedicated_workers[task_spec.TaskId()] = pid; - } - } - } else if (!task_spec.IsActorTask()) { - // Code path of normal task or actor creation task without dynamic worker options. + if (actor_id.IsNil()) { if (!state.idle.empty()) { worker = std::move(*state.idle.begin()); state.idle.erase(state.idle.begin()); - } else { - // There are no more non-actor workers available to execute this task. - // Start a new worker process. - pid = StartWorkerProcess(task_spec.GetLanguage()); } } else { - // Code path of actor task. 
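After this change PushWorker/PopWorker only deal with two idle pools: a shared set for plain workers and a per-actor map. A compact stand-alone sketch of that bookkeeping, with a stand-in `Worker` and string ids instead of `ActorID`:

```cpp
// Sketch only: Worker and the string ids are stand-ins for the raylet types.
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Worker {
  std::string actor_id;  // Empty means "not bound to an actor".
};

struct IdlePools {
  std::unordered_set<std::shared_ptr<Worker>> idle;
  std::unordered_map<std::string, std::shared_ptr<Worker>> idle_actor;

  void PushWorker(std::shared_ptr<Worker> worker) {
    if (worker->actor_id.empty()) {
      idle.insert(std::move(worker));
    } else {
      idle_actor[worker->actor_id] = std::move(worker);
    }
  }

  std::shared_ptr<Worker> PopWorker(const std::string &actor_id) {
    std::shared_ptr<Worker> worker;
    if (actor_id.empty()) {
      if (!idle.empty()) {
        worker = *idle.begin();
        idle.erase(idle.begin());
      }
    } else {
      auto it = idle_actor.find(actor_id);
      if (it != idle_actor.end()) {
        worker = std::move(it->second);
        idle_actor.erase(it);
      }
    }
    return worker;  // May be nullptr; the caller then starts a new process.
  }
};

int main() {
  IdlePools pools;
  pools.PushWorker(std::make_shared<Worker>(Worker{""}));
  pools.PushWorker(std::make_shared<Worker>(Worker{"actor-1"}));
  assert(pools.PopWorker("") != nullptr);
  assert(pools.PopWorker("actor-2") == nullptr);
  return 0;
}
```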
auto actor_entry = state.idle_actor.find(actor_id); if (actor_entry != state.idle_actor.end()) { worker = std::move(actor_entry->second); state.idle_actor.erase(actor_entry); } } - - if (worker == nullptr && pid > 0) { - WarnAboutSize(); - } - return worker; } @@ -334,7 +274,7 @@ std::vector> WorkerPool::GetWorkersRunningTasksForDriver return workers; } -void WorkerPool::WarnAboutSize() { +std::string WorkerPool::WarningAboutSize() { int64_t num_workers_started_or_registered = 0; for (const auto &entry : states_by_lang_) { num_workers_started_or_registered += @@ -345,8 +285,6 @@ void WorkerPool::WarnAboutSize() { int64_t multiple = num_workers_started_or_registered / multiple_for_warning_; std::stringstream warning_message; if (multiple >= 3 && multiple > last_warning_multiple_) { - // Push an error message to the user if the worker pool tells us that it is - // getting too big. last_warning_multiple_ = multiple; warning_message << "WARNING: " << num_workers_started_or_registered << " workers have been started. This could be a result of using " @@ -354,16 +292,8 @@ void WorkerPool::WarnAboutSize() { << "using nested tasks " << "(see https://github.com/ray-project/ray/issues/3644) for " << "some a discussion of workarounds."; - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); } -} - -bool WorkerPool::HasPendingWorkerForTask(const Language &language, - const TaskID &task_id) { - auto &state = GetStateForLanguage(language); - auto it = state.tasks_to_dedicated_workers.find(task_id); - return it != state.tasks_to_dedicated_workers.end(); + return warning_message.str(); } std::string WorkerPool::DebugString() const { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index e1e726268093..03443447cf58 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -7,7 +7,6 @@ #include #include "ray/common/client_connection.h" -#include "ray/gcs/client.h" #include "ray/gcs/format/util.h" #include "ray/raylet/task.h" #include "ray/raylet/worker.h" @@ -38,12 +37,22 @@ class WorkerPool { /// language. WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, std::shared_ptr gcs_client, + int maximum_startup_concurrency, const std::unordered_map> &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); + /// Asynchronously start a new worker process. Once the worker process has + /// registered with an external server, the process should create and + /// register num_workers_per_process_ workers, then add them to the pool. + /// Failure to start the worker process is a fatal error. If too many workers + /// are already being started, then this function will return without starting + /// any workers. + /// + /// \param language Which language this worker process should be. + void StartWorkerProcess(const Language &language); + /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// @@ -109,15 +118,6 @@ class WorkerPool { std::vector> GetWorkersRunningTasksForDriver( const DriverID &driver_id) const; - /// Whether there is a pending worker for the given task. - /// Note that, this is only used for actor creation task with dynamic options. - /// And if the worker registered but isn't assigned a task, - /// the worker also is in pending state, and this'll return true. 
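WarningAboutSize above now only builds the message and leaves pushing the error to the caller. The sketch below isolates the thresholding policy it keeps: warn when the worker count reaches a new multiple (at least 3x) of `multiple_for_warning`, otherwise return an empty string.

```cpp
// Sketch only: the thresholding policy behind WarningAboutSize().
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>

struct WarningState {
  int64_t multiple_for_warning;
  int64_t last_warning_multiple;
};

std::string WarningAboutSize(WarningState &state,
                             int64_t num_workers_started_or_registered) {
  std::stringstream warning_message;
  int64_t multiple =
      num_workers_started_or_registered / state.multiple_for_warning;
  if (multiple >= 3 && multiple > state.last_warning_multiple) {
    // Only warn the first time each new multiple is reached.
    state.last_warning_multiple = multiple;
    warning_message << "WARNING: " << num_workers_started_or_registered
                    << " workers have been started.";
  }
  return warning_message.str();  // Empty string means: nothing to push.
}

int main() {
  WarningState state{/*multiple_for_warning=*/10, /*last_warning_multiple=*/0};
  std::cout << "20 workers: '" << WarningAboutSize(state, 20) << "'\n";  // empty
  std::cout << "30 workers: '" << WarningAboutSize(state, 30) << "'\n";  // warns
  std::cout << "31 workers: '" << WarningAboutSize(state, 31) << "'\n";  // empty again
  return 0;
}
```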
- /// - /// \param language The required language. - /// \param task_id The task that we want to query. - bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id); - /// Returns debug string for class. /// /// \return string. @@ -126,37 +126,24 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; - protected: - /// Asynchronously start a new worker process. Once the worker process has - /// registered with an external server, the process should create and - /// register num_workers_per_process_ workers, then add them to the pool. - /// Failure to start the worker process is a fatal error. If too many workers - /// are already being started, then this function will return without starting - /// any workers. + /// Generate a warning about the number of workers that have registered or + /// started if appropriate. /// - /// \param language Which language this worker process should be. - /// \param dynamic_options The dynamic options that we should add for worker command. - /// \return The id of the process that we started if it's positive, - /// otherwise it means we didn't start a process. - int StartWorkerProcess(const Language &language, - const std::vector &dynamic_options = {}); + /// \return An empty string if no warning should be generated and otherwise a + /// string with a warning message. + std::string WarningAboutSize(); + protected: /// The implementation of how to start a new worker process with command arguments. /// /// \param worker_command_args The command arguments of new worker process. /// \return The process ID of started worker process. virtual pid_t StartProcess(const std::vector &worker_command_args); - /// Push an warning message to user if worker pool is getting to big. - virtual void WarnAboutSize(); - /// An internal data structure that maintains the pool state per language. struct State { /// The commands and arguments used to start the worker process std::vector worker_command; - /// The pool of dedicated workers for actor creation tasks - /// with prefix or suffix worker command. - std::unordered_map> idle_dedicated_workers; /// The pool of idle non-actor workers. std::unordered_set> idle; /// The pool of idle actor workers. @@ -169,11 +156,6 @@ class WorkerPool { /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; - /// A map for looking up the task with dynamic options by the pid of - /// worker. Note that this is used for the dedicated worker processes. - std::unordered_map dedicated_workers_to_tasks; - /// A map for speeding up looking up the pending worker for the given task. - std::unordered_map tasks_to_dedicated_workers; }; /// The number of workers per process. @@ -184,7 +166,7 @@ class WorkerPool { private: /// A helper function that returns the reference of the pool state /// for a given language. - State &GetStateForLanguage(const Language &language); + inline State &GetStateForLanguage(const Language &language); /// We'll push a warning to the user every time a multiple of this many /// workers has been started. @@ -194,8 +176,6 @@ class WorkerPool { /// The last size at which a warning about the number of registered workers /// was generated. int64_t last_warning_multiple_; - /// A client connection to the GCS. 
- std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 15a5fb0471e0..143ffd57dda6 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -1,7 +1,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "ray/common/constants.h" #include "ray/raylet/node_manager.h" #include "ray/raylet/worker_pool.h" @@ -15,46 +14,21 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}) {} - - explicit WorkerPoolMock( - const std::unordered_map> &worker_commands) - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, - worker_commands), + : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, + {{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, {"dummy_java_worker_command"}}}), last_worker_pid_(0) {} - ~WorkerPoolMock() { // Avoid killing real processes states_by_lang_.clear(); } - void StartWorkerProcess(const Language &language, - const std::vector &dynamic_options = {}) { - WorkerPool::StartWorkerProcess(language, dynamic_options); - } - pid_t StartProcess(const std::vector &worker_command_args) override { - last_worker_pid_ += 1; - std::vector local_worker_commands_args; - for (auto item : worker_command_args) { - if (item == nullptr) { - break; - } - local_worker_commands_args.push_back(std::string(item)); - } - worker_commands_by_pid[last_worker_pid_] = std::move(local_worker_commands_args); - return last_worker_pid_; + return ++last_worker_pid_; } - void WarnAboutSize() override {} - pid_t LastStartedWorkerProcess() const { return last_worker_pid_; } - const std::vector &GetWorkerCommand(int pid) { - return worker_commands_by_pid[pid]; - } - int NumWorkerProcessesStarting() const { int total = 0; for (auto &entry : states_by_lang_) { @@ -65,8 +39,6 @@ class WorkerPoolMock : public WorkerPool { private: int last_worker_pid_; - // The worker commands by pid. 
- std::unordered_map> worker_commands_by_pid; }; class WorkerPoolTest : public ::testing::Test { @@ -89,12 +61,6 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, client)); } - void SetWorkerCommands( - const std::unordered_map> &worker_commands) { - WorkerPoolMock worker_pool(worker_commands); - this->worker_pool_ = std::move(worker_pool); - } - protected: WorkerPoolMock worker_pool_; boost::asio::io_service io_service_; @@ -106,10 +72,10 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, - const ActorID actor_creation_id = ActorID::Nil()) { + const ActorID actor_id = ActorID::Nil(), + const Language &language = Language::PYTHON) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -220,23 +186,6 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); } -TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { - const std::vector java_worker_command = { - "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; - SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, java_worker_command}}); - - TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), - ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, - {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, - {"test_op_0", "test_op_1"}); - worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); - const auto real_command = - worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); - ASSERT_EQ(real_command, std::vector( - {"test_op_0", "dummy_java_worker_command", "test_op_1"})); -} - } // namespace raylet } // namespace ray diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index f507039990c2..feb788da7692 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -1,5 +1,4 @@ #include "ray/rpc/grpc_server.h" -#include namespace ray { namespace rpc { @@ -10,10 +9,8 @@ void GrpcServer::Run() { grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); - // Register all the services to this server. - for (auto &entry : services_) { - builder.RegisterService(&entry.get()); - } + // Allow subclasses to register concrete services. + RegisterServices(builder); // Get hold of the completion queue used for the asynchronous communication // with the gRPC runtime. cq_ = builder.AddCompletionQueue(); @@ -21,7 +18,8 @@ void GrpcServer::Run() { server_ = builder.BuildAndStart(); RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; - // Create calls for all the server call factories. + // Allow subclasses to initialize the server call factories. + InitServerCallFactories(&server_call_factories_and_concurrencies_); for (auto &entry : server_call_factories_and_concurrencies_) { for (int i = 0; i < entry.second; i++) { // Create and request calls from the factory. 
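GrpcServer::Run above pre-creates a fixed number of pending calls per factory, which is what bounds how many requests of each kind the server will accept concurrently. A stripped-down sketch of that loop, with lambdas standing in for `ServerCallFactory::CreateCall`:

```cpp
// Sketch only: lambdas stand in for ServerCallFactory::CreateCall.
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  using CallFactory = std::function<void()>;
  std::vector<std::pair<CallFactory, int>> factories_and_concurrencies;

  // e.g. ForwardTask is given an accept concurrency of 100 in NodeManagerServer;
  // 3 is used here just to keep the output short.
  factories_and_concurrencies.emplace_back(
      [] { std::cout << "pending ForwardTask call created\n"; }, 3);

  for (auto &entry : factories_and_concurrencies) {
    for (int i = 0; i < entry.second; i++) {
      entry.first();  // CreateCall() in the real server.
    }
  }
  return 0;
}
```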
@@ -33,11 +31,6 @@ void GrpcServer::Run() { polling_thread.detach(); } -void GrpcServer::RegisterService(GrpcService &service) { - services_.emplace_back(service.GetGrpcService()); - service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_); -} - void GrpcServer::PollEventsFromCompletionQueue() { void *tag; bool ok; @@ -55,7 +48,7 @@ void GrpcServer::PollEventsFromCompletionQueue() { // incoming request. server_call->GetFactory().CreateCall(); server_call->SetState(ServerCallState::PROCESSING); - server_call->HandleRequest(); + main_service_.post([server_call] { server_call->HandleRequest(); }); break; case ServerCallState::SENDING_REPLY: // The reply has been sent, this call can be deleted now. diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 584da6565a47..4953f470610f 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -12,9 +12,7 @@ namespace ray { namespace rpc { -class GrpcService; - -/// Class that represents an gRPC server. +/// Base class that represents an abstract gRPC server. /// /// A `GrpcServer` listens on a specific port. It owns /// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, @@ -30,7 +28,11 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {} + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcServer(const std::string &name, const uint32_t port, + boost::asio::io_service &main_service) + : name_(name), port_(port), main_service_(main_service) {} /// Destruct this gRPC server. ~GrpcServer() { @@ -44,25 +46,36 @@ class GrpcServer { /// Get the port of this gRPC server. int GetPort() const { return port_; } - /// Register a grpc service. Multiple services can be registered to the same server. - /// Note that the `service` registered must remain valid for the lifetime of the - /// `GrpcServer`, as it holds the underlying `grpc::Service`. + protected: + /// Subclasses should implement this method and register one or multiple gRPC services + /// to the given `ServerBuilder`. /// - /// \param[in] service A `GrpcService` to register to this server. - void RegisterService(GrpcService &service); + /// \param[in] builder The `ServerBuilder` instance to register services to. + virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. + virtual void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) = 0; - protected: /// This function runs in a background thread. It keeps polling events from the /// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances /// via the `ServerCall` objects. 
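The `main_service_.post(...)` change above is the key threading decision in this hunk: the completion-queue polling thread only enqueues work, and service handlers run on the raylet's main event loop. The sketch below models that hand-off with a small hand-rolled queue instead of `boost::asio::io_service`, purely to show the intended thread boundary. Keeping handler execution on the main loop means handlers can touch node manager state without extra locking, which appears to be the motivation for moving the event loop reference out of the individual calls.

```cpp
// Sketch only: a hand-rolled work queue standing in for the raylet's
// boost::asio::io_service, to show the polling-thread -> main-loop hand-off.
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

class MainLoop {
 public:
  void post(std::function<void()> fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    queue_.push(std::move(fn));
  }

  void run_pending() {
    while (true) {
      std::function<void()> fn;
      {
        std::lock_guard<std::mutex> lock(mutex_);
        if (queue_.empty()) break;
        fn = std::move(queue_.front());
        queue_.pop();
      }
      fn();  // Runs on whichever thread calls run_pending().
    }
  }

 private:
  std::mutex mutex_;
  std::queue<std::function<void()>> queue_;
};

int main() {
  MainLoop main_service;
  // The "completion queue polling thread" only enqueues handler work; it
  // never touches node manager state directly.
  std::thread polling_thread([&main_service] {
    main_service.post([] { std::cout << "HandleRequest on main thread\n"; });
  });
  polling_thread.join();
  main_service.run_pending();  // Executed from the main thread.
  return 0;
}
```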
void PollEventsFromCompletionQueue(); + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; /// Name of this server, used for logging and debugging purpose. const std::string name_; /// Port of this server. int port_; - /// The `grpc::Service` objects which should be registered to `ServerBuilder`. - std::vector> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that /// gRPC server can accept. std::vector, int>> @@ -73,46 +86,6 @@ class GrpcServer { std::unique_ptr server_; }; -/// Base class that represents an abstract gRPC service. -/// -/// Subclass should implement `InitServerCallFactories` to decide -/// which kinds of requests this service should accept. -class GrpcService { - public: - /// Constructor. - /// - /// \param[in] main_service The main event loop, to which service handler functions - /// will be posted. - GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {} - - /// Destruct this gRPC service. - ~GrpcService() {} - - protected: - /// Return the underlying grpc::Service object for this class. - /// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`. - virtual grpc::Service &GetGrpcService() = 0; - - /// Subclasses should implement this method to initialize the `ServerCallFactory` - /// instances, as well as specify maximum number of concurrent requests that gRPC - /// server can "accept" (not "handle"). Each factory will be used to create - /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and - /// handle an incoming request. - /// - /// \param[in] cq The grpc completion queue. - /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, - /// and the maximum number of concurrent requests that gRPC server can accept. - virtual void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector, int>> - *server_call_factories_and_concurrencies) = 0; - - /// The main event loop, to which the service handler functions will be posted. - boost::asio::io_service &main_service_; - - friend class GrpcServer; -}; - } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h index d05f268c65b2..afaea299ea89 100644 --- a/src/ray/rpc/node_manager_server.h +++ b/src/ray/rpc/node_manager_server.h @@ -25,22 +25,25 @@ class NodeManagerServiceHandler { RequestDoneCallback done_callback) = 0; }; -/// The `GrpcService` for `NodeManagerService`. -class NodeManagerGrpcService : public GrpcService { +/// The `GrpcServer` for `NodeManagerService`. +class NodeManagerServer : public GrpcServer { public: /// Constructor. /// - /// \param[in] io_service See super class. + /// \param[in] port See super class. + /// \param[in] main_service See super class. /// \param[in] handler The service handler that actually handle the requests. - NodeManagerGrpcService(boost::asio::io_service &io_service, - NodeManagerServiceHandler &service_handler) - : GrpcService(io_service), service_handler_(service_handler){}; + NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, + NodeManagerServiceHandler &service_handler) + : GrpcServer("NodeManager", port, main_service), + service_handler_(service_handler){}; - protected: - grpc::Service &GetGrpcService() override { return service_; } + void RegisterServices(grpc::ServerBuilder &builder) override { + /// Register `NodeManagerService`. 
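With GrpcService folded back into GrpcServer, the header above becomes a classic template-method base: Run() drives startup and each concrete server overrides the two pure-virtual hooks. A minimal sketch of that shape with simplified stand-in types and no real gRPC objects:

```cpp
// Sketch only: simplified stand-ins for GrpcServer and NodeManagerServer.
#include <iostream>
#include <string>

class AbstractServer {
 public:
  explicit AbstractServer(std::string name) : name_(std::move(name)) {}
  virtual ~AbstractServer() = default;

  void Run() {
    RegisterServices();         // Hook: subclass registers its gRPC services.
    InitServerCallFactories();  // Hook: subclass sets up its call factories.
    std::cout << name_ << " server started\n";
  }

 protected:
  virtual void RegisterServices() = 0;
  virtual void InitServerCallFactories() = 0;

 private:
  std::string name_;
};

class NodeManagerLikeServer : public AbstractServer {
 public:
  NodeManagerLikeServer() : AbstractServer("NodeManager") {}

 protected:
  void RegisterServices() override {
    std::cout << "register NodeManagerService\n";
  }
  void InitServerCallFactories() override {
    std::cout << "init ForwardTask call factory\n";
  }
};

int main() {
  NodeManagerLikeServer server;
  server.Run();
  return 0;
}
```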
+ builder.RegisterService(&service_); + } void InitServerCallFactories( - const std::unique_ptr &cq, std::vector, int>> *server_call_factories_and_concurrencies) override { // Initialize the factory for `ForwardTask` requests. @@ -48,8 +51,7 @@ class NodeManagerGrpcService : public GrpcService { new ServerCallFactoryImpl( service_, &NodeManagerService::AsyncService::RequestForwardTask, - service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq, - main_service_)); + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); // Set `ForwardTask`'s accept concurrency to 100. server_call_factories_and_concurrencies->emplace_back( @@ -59,7 +61,6 @@ class NodeManagerGrpcService : public GrpcService { private: /// The grpc async service object. NodeManagerService::AsyncService service_; - /// The service handler that actually handle the requests. NodeManagerServiceHandler &service_handler_; }; diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 08ca128323ee..e06278260ab6 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -94,27 +94,20 @@ class ServerCallImpl : public ServerCall { /// \param[in] factory The factory which created this call. /// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. - /// \param[in] io_service The event loop. ServerCallImpl( const ServerCallFactory &factory, ServiceHandler &service_handler, - HandleRequestFunction handle_request_function, - boost::asio::io_service &io_service) + HandleRequestFunction handle_request_function) : state_(ServerCallState::PENDING), factory_(factory), service_handler_(service_handler), handle_request_function_(handle_request_function), - response_writer_(&context_), - io_service_(io_service) {} + response_writer_(&context_) {} ServerCallState GetState() const override { return state_; } void SetState(const ServerCallState &new_state) override { state_ = new_state; } void HandleRequest() override { - io_service_.post([this] { HandleRequestImpl(); }); - } - - void HandleRequestImpl() { state_ = ServerCallState::PROCESSING; (service_handler_.*handle_request_function_)(request_, &reply_, [this](Status status) { @@ -153,9 +146,6 @@ class ServerCallImpl : public ServerCall { /// The reponse writer. grpc::ServerAsyncResponseWriter response_writer_; - /// The event loop. - boost::asio::io_service &io_service_; - /// The request message. Request request_; @@ -195,26 +185,23 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// \param[in] service_handler The service handler that handles the request. /// \param[in] handle_request_function Pointer to the service handler function. /// \param[in] cq The `CompletionQueue`. - /// \param[in] io_service The event loop. ServerCallFactoryImpl( AsyncService &service, RequestCallFunction request_call_function, ServiceHandler &service_handler, HandleRequestFunction handle_request_function, - const std::unique_ptr &cq, - boost::asio::io_service &io_service) + const std::unique_ptr &cq) : service_(service), request_call_function_(request_call_function), service_handler_(service_handler), handle_request_function_(handle_request_function), - cq_(cq), - io_service_(io_service) {} + cq_(cq) {} ServerCall *CreateCall() const override { // Create a new `ServerCall`. This object will eventually be deleted by // `GrpcServer::PollEventsFromCompletionQueue`. 
auto call = new ServerCallImpl( - *this, service_handler_, handle_request_function_, io_service_); + *this, service_handler_, handle_request_function_); /// Request gRPC runtime to starting accepting this kind of request, using the call as /// the tag. (service_.*request_call_function_)(&call->context_, &call->request_, @@ -238,9 +225,6 @@ class ServerCallFactoryImpl : public ServerCallFactory { /// The `CompletionQueue`. const std::unique_ptr &cq_; - - /// The event loop. - boost::asio::io_service &io_service_; }; } // namespace rpc diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h index 59ae75ae33be..6ecc6c3c4a34 100644 --- a/src/ray/rpc/util.h +++ b/src/ray/rpc/util.h @@ -1,7 +1,6 @@ #ifndef RAY_RPC_UTIL_H #define RAY_RPC_UTIL_H -#include #include #include "ray/common/status.h" @@ -28,18 +27,6 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { } } -template -inline std::vector VectorFromProtobuf( - const ::google::protobuf::RepeatedPtrField &pb_repeated) { - return std::vector(pb_repeated.begin(), pb_repeated.end()); -} - -template -inline std::vector VectorFromProtobuf( - const ::google::protobuf::RepeatedField &pb_repeated) { - return std::vector(pb_repeated.begin(), pb_repeated.end()); -} - } // namespace rpc } // namespace ray From 014cbb70aad54ab78275bb2b6694b5ec34472150 Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Wed, 26 Jun 2019 14:59:10 +0200 Subject: [PATCH 117/118] Revert "Revert "Merge with ray master"" This reverts commit 92c0f88b9cd75be6281204467b95526951c03e87. --- BUILD.bazel | 96 ++-- bazel/ray_deps_build_all.bzl | 4 + bazel/ray_deps_setup.bzl | 11 +- .../run_perf_integration.sh | 2 +- ci/jenkins_tests/run_tune_tests.sh | 8 +- doc/source/conf.py | 15 +- doc/source/tune-usage.rst | 6 + docker/base-deps/Dockerfile | 2 +- docker/examples/Dockerfile | 5 +- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 11 +- java/BUILD.bazel | 51 +-- .../src/main/java/org/ray/api/id/BaseId.java | 2 +- .../ray/api/options/ActorCreationOptions.java | 15 +- java/dependencies.bzl | 1 + ...modify_generated_java_flatbuffers_files.py | 20 +- java/runtime/pom.xml | 5 + .../org/ray/runtime/AbstractRayRuntime.java | 9 +- .../java/org/ray/runtime/gcs/GcsClient.java | 69 +-- .../runtime/objectstore/ObjectStoreProxy.java | 12 +- .../ray/runtime/raylet/RayletClientImpl.java | 18 +- .../org/ray/runtime/runner/RunManager.java | 3 + .../java/org/ray/runtime/task/TaskSpec.java | 8 +- .../src/main/java/org/ray/api/TestUtils.java | 15 + .../org/ray/api/test/DynamicResourceTest.java | 17 +- .../main/java/org/ray/api/test/WaitTest.java | 5 + .../ray/api/test/WorkerJvmOptionsTest.java | 31 ++ python/ray/experimental/signal.py | 14 +- python/ray/gcs_utils.py | 71 ++- python/ray/monitor.py | 33 +- python/ray/rllib/agents/a3c/a3c.py | 4 + python/ray/rllib/agents/impala/impala.py | 1 + .../ray/rllib/agents/impala/vtrace_policy.py | 8 +- python/ray/rllib/agents/qmix/qmix_policy.py | 2 + python/ray/rllib/policy/tf_policy.py | 38 +- python/ray/rllib/tests/test_optimizers.py | 10 +- python/ray/services.py | 3 + python/ray/state.py | 230 ++++------ python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/conftest.py | 8 + python/ray/tests/test_actor.py | 2 +- python/ray/tests/test_basic.py | 14 +- python/ray/tests/test_failure.py | 5 +- python/ray/tests/test_signal.py | 33 ++ .../ray/tune/analysis/experiment_analysis.py | 94 +++- python/ray/tune/examples/mnist_pytorch.py | 273 +++++------- python/ray/tune/examples/track_example.py | 4 +- 
python/ray/tune/examples/tune_mnist_keras.py | 8 +- python/ray/tune/examples/utils.py | 36 +- python/ray/tune/experiment.py | 8 + python/ray/tune/integration/__init__.py | 0 python/ray/tune/integration/keras.py | 34 ++ python/ray/tune/schedulers/__init__.py | 6 +- python/ray/tune/schedulers/async_hyperband.py | 2 + .../tune/tests/test_experiment_analysis.py | 62 +-- python/ray/tune/tests/test_trial_runner.py | 8 + python/ray/tune/trial.py | 25 +- python/ray/tune/tune.py | 11 +- python/ray/utils.py | 8 +- python/ray/worker.py | 40 +- python/setup.py | 1 + src/ray/common/constants.h | 2 + src/ray/gcs/client.cc | 4 - src/ray/gcs/client.h | 6 - src/ray/gcs/client_test.cc | 353 +++++++-------- src/ray/gcs/format/gcs.fbs | 286 +----------- src/ray/gcs/redis_context.h | 15 +- src/ray/gcs/redis_module/ray_redis_module.cc | 209 ++++----- src/ray/gcs/tables.cc | 417 ++++++++---------- src/ray/gcs/tables.h | 136 +++--- src/ray/object_manager/object_directory.cc | 34 +- src/ray/object_manager/object_manager.cc | 49 +- src/ray/object_manager/object_manager.h | 4 +- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/protobuf/gcs.proto | 280 ++++++++++++ src/ray/raylet/actor_registration.cc | 51 +-- src/ray/raylet/actor_registration.h | 24 +- src/ray/raylet/lineage_cache.cc | 37 +- src/ray/raylet/lineage_cache.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 262 +++++------ src/ray/raylet/node_manager.h | 31 +- src/ray/raylet/raylet.cc | 24 +- src/ray/raylet/raylet.h | 2 + src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 + src/ray/raylet/reconstruction_policy_test.cc | 42 +- src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 + .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/task_spec.cc | 12 +- src/ray/raylet/task_spec.h | 6 +- src/ray/raylet/worker_pool.cc | 100 ++++- src/ray/raylet/worker_pool.h | 56 ++- src/ray/raylet/worker_pool_test.cc | 65 ++- src/ray/rpc/grpc_server.cc | 17 +- src/ray/rpc/grpc_server.h | 77 ++-- src/ray/rpc/node_manager_server.h | 25 +- src/ray/rpc/server_call.h | 26 +- src/ray/rpc/util.h | 13 + 103 files changed, 2338 insertions(+), 2039 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java create mode 100644 python/ray/tune/integration/__init__.py create mode 100644 python/ray/tune/integration/keras.py create mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index da36eec0cf57..bc9e6bcd8006 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,22 +1,55 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# Node manager gRPC lib. 
-grpc_proto_library( - name = "node_manager_grpc_lib", +# === Begin of protobuf definitions === + +proto_library( + name = "gcs_proto", + srcs = ["src/ray/protobuf/gcs.proto"], + visibility = ["//java:__subpackages__"], +) + +cc_proto_library( + name = "gcs_cc_proto", + deps = [":gcs_proto"], +) + +python_proto_compile( + name = "gcs_py_proto", + deps = [":gcs_proto"], +) + +proto_library( + name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], ) +cc_proto_library( + name = "node_manager_cc_proto", + deps = ["node_manager_proto"], +) + +# === End of protobuf definitions === + +# Node manager gRPC lib. +cc_grpc_library( + name = "node_manager_cc_grpc", + srcs = [":node_manager_proto"], + grpc_only = True, + deps = [":node_manager_cc_proto"], +) + # Node manager server and client. cc_library( - name = "node_manager_rpc_lib", + name = "node_manager_rpc", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -25,7 +58,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_grpc_lib", + ":node_manager_cc_grpc", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -114,7 +147,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc_lib", + ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", @@ -422,9 +455,11 @@ cc_library( "src/ray/gcs/format", ], deps = [ + ":gcs_cc_proto", ":gcs_fbs", ":hiredis", ":node_manager_fbs", + ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -555,46 +590,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -flatbuffer_py_library( - name = "python_gcs_fbs", - srcs = [ - ":gcs_fbs_file", - ], - outs = [ - "ActorCheckpointIdData.py", - "ActorState.py", - "ActorTableData.py", - "Arg.py", - "ClassTableData.py", - "ClientTableData.py", - "ConfigTableData.py", - "CustomSerializerData.py", - "DriverTableData.py", - "EntryType.py", - "ErrorTableData.py", - "ErrorType.py", - "FunctionTableData.py", - "GcsEntry.py", - "HeartbeatBatchTableData.py", - "HeartbeatTableData.py", - "Language.py", - "ObjectTableData.py", - "ProfileEvent.py", - "ProfileTableData.py", - "RayResource.py", - "ResourcePair.py", - "SchedulingState.py", - "TablePrefix.py", - "TablePubsub.py", - "TaskInfo.py", - "TaskLeaseData.py", - "TaskReconstructionData.py", - "TaskTableData.py", - "TaskTableTestAndUpdate.py", - ], - out_prefix = "python/ray/core/generated/", -) - flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -679,6 +674,7 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ + ":gcs_cc_proto", ":ray_common", ], ) @@ -688,7 +684,7 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:python_gcs_fbs", + "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -710,11 +706,13 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && - for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && + for f in $(locations //:gcs_py_proto); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/; + done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl 
b/bazel/ray_deps_build_all.bzl index 3e1e1838a59a..eda88bece7d2 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,6 +4,8 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") +load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -13,4 +15,6 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() + java_proto_compile() + python_proto_compile() diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index e6dc21585699..aa322654cf9f 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,7 +105,14 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", ], - strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", + ) + + http_archive( + name = "build_stack_rules_proto", + urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], + sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", + strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", ) diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index 7962b21075c0..f25d32df22a1 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 6154fe70d4f6..6b890d7d371c 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -78,16 +78,16 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --smoke-test # Runs only on Python3 -# docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ -# python3 /ray/python/ray/tune/examples/nevergrad_example.py \ -# --smoke-test +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/nevergrad_example.py \ + --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/tune/examples/mnist_pytorch.py 
--smoke-test --no-cuda + python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ diff --git a/doc/source/conf.py b/doc/source/conf.py index 98fb3e0d02dd..5cf6b01217f9 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,20 +23,7 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.EntryType", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ErrorType", - "ray.core.generated.GcsEntry", - "ray.core.generated.HeartbeatBatchTableData", - "ray.core.generated.HeartbeatTableData", - "ray.core.generated.Language", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", + "ray.core.generated.gcs_pb2", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 281ccbd6107e..e8ce405d9457 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -355,6 +355,12 @@ Then, after you run a experiment, you can visualize your experiment with TensorB $ tensorboard --logdir=~/ray_results/my_experiment +If you are running Ray on a remote multi-user cluster where you do not have sudo access, you can run the following commands to make sure tensorboard is able to write to the tmp directory: + +.. code-block:: bash + + $ export TMPDIR=/tmp/$USER; mkdir -p $TMPDIR; tensorboard --logdir=~/ray_results + .. image:: ray-tune-tensorboard.png To use rllab's VisKit (you may have to install some dependencies), run: diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index c21430c627a4..db8f28c85f86 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -12,7 +12,7 @@ RUN apt-get update \ && apt-get clean \ && echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \ && wget \ - --quiet 'https://repo.continuum.io/archive/Anaconda2-5.2.0-Linux-x86_64.sh' \ + --quiet 'https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh' \ -O /tmp/anaconda.sh \ && /bin/bash /tmp/anaconda.sh -b -p /opt/conda \ && rm /tmp/anaconda.sh \ diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index 6883c5a64a0e..bafcdf35e628 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,11 +5,14 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 1d174ed72f92..376fe5340fd9 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 
+4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 6e098d5218f6..77cf390493d6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,15 +4,20 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN conda install -y -c anaconda wrapt=1.11.1 +RUN conda install -y -c anaconda numpy=1.16.4 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch # RUN mkdir -p /root/.ssh/ @@ -20,6 +25,6 @@ RUN conda install pytorch-cpu torchvision-cpu -c pytorch # We port the source code in so that we run the most up-to-date stress tests. 
ADD ray.tar /ray ADD git-rev /ray/git-rev -RUN python /ray/python/ray/rllib/setup-rllib-dev.py --yes +RUN python /ray/python/ray/setup-dev.py --yes WORKDIR /ray diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 80ccabccfc12..4960434af180 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,5 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -50,6 +51,7 @@ define_java_module( name = "runtime", additional_srcs = [ ":generate_java_gcs_fbs", + ":gcs_java_proto", ], additional_resources = [ ":java_native_deps", @@ -68,6 +70,7 @@ define_java_module( "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -148,38 +151,16 @@ java_binary( ], ) +java_proto_compile( + name = "gcs_java_proto", + deps = ["@//:gcs_proto"], +) + flatbuffers_generated_files = [ - "ActorCheckpointData.java", - "ActorCheckpointIdData.java", - "ActorState.java", - "ActorTableData.java", "Arg.java", - "ClassTableData.java", - "ClientTableData.java", - "ConfigTableData.java", - "CustomSerializerData.java", - "DriverTableData.java", - "EntryType.java", - "ErrorTableData.java", - "ErrorType.java", - "FunctionTableData.java", - "GcsEntry.java", - "HeartbeatBatchTableData.java", - "HeartbeatTableData.java", "Language.java", - "ObjectTableData.java", - "ProfileEvent.java", - "ProfileTableData.java", - "RayResource.java", - "ResourcePair.java", - "SchedulingState.java", - "TablePrefix.java", - "TablePubsub.java", "TaskInfo.java", - "TaskLeaseData.java", - "TaskReconstructionData.java", - "TaskTableData.java", - "TaskTableTestAndUpdate.java", + "ResourcePair.java", ] flatbuffer_java_library( @@ -198,7 +179,7 @@ genrule( cmd = """ for f in $(locations //java:java_gcs_fbs); do chmod +w $$f - cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated done python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/.. """, @@ -221,8 +202,10 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":java_native_deps", + ":gcs_java_proto", ":generate_java_gcs_fbs", + ":java_native_deps", + ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -237,10 +220,15 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy flatbuffers-generated files + # Copy protobuf-generated files. 
GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files for f in $(locations //java:generate_java_gcs_fbs); do cp $$f $$GENERATED_DIR done @@ -250,6 +238,7 @@ genrule( echo $$(date) > $@ """, local = 1, + tags = ["no-cache"], ) genrule( diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index e08955d5a93e..c13f0436f94d 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -48,7 +48,7 @@ public boolean isNil() { break; } } - isNilCache = localIsNil; + isNilCache = localIsNil; } return isNilCache; } diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index d1e92f7bb9e9..2e14ca8584dd 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -13,9 +13,14 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - private ActorCreationOptions(Map resources, int maxReconstructions) { + public final String jvmOptions; + + private ActorCreationOptions(Map resources, + int maxReconstructions, + String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.jvmOptions = jvmOptions; } /** @@ -25,6 +30,7 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private String jvmOptions = ""; public Builder setResources(Map resources) { this.resources = resources; @@ -36,8 +42,13 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } + public Builder setJvmOptions(String jvmOptions) { + this.jvmOptions = jvmOptions; + return this; + } + public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions); + return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); } } diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 7c716166d399..ef667137562b 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,6 +6,7 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", + "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index c1b723f25f8d..5bf62e56d7e4 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,7 +4,6 @@ import os import sys - """ This script is used for modifying the generated java flatbuffer files for the reason: The package declaration in Java is different @@ -21,19 +20,18 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_new_line(file, line_num, text): +def add_package(file): with open(file, "r") as file_handler: lines = file_handler.readlines() - if (line_num <= 0) or (line_num > len(lines) + 1): - return False - lines.insert(line_num - 1, text + os.linesep) + if "FlatBuffers" not in lines[0]: + return 
+ + lines.insert(1, PACKAGE_DECLARATION + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) - return True - def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -41,15 +39,11 @@ def add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - success = add_new_line(full_name, 2, PACKAGE_DECLARATION) - if not success: - raise RuntimeError("Failed to add package declarations, " - "file name is %s" % full_name) + add_package(full_name) if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, - "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c75e2eeef13f..e13dd95f927f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,6 +41,11 @@ guava 27.0.1-jre + + com.google.protobuf + protobuf-java + 3.8.0 + com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index fbd03bf10483..26a8d6e541ba 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -35,6 +35,7 @@ import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.IdUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -363,8 +364,13 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes } int maxActorReconstruction = 0; + List dynamicWorkerOptions = ImmutableList.of(); if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; + String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; + if (!StringUtil.isNullOrEmpty(jvmOptions)) { + dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); + } } TaskLanguage language; @@ -393,7 +399,8 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes numReturns, resources, language, - functionDescriptor + functionDescriptor, + dynamicWorkerOptions ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 431b48ded58c..17c248ed0a57 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; +import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.ActorCheckpointIdData; -import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.EntryType; -import org.ray.runtime.generated.TablePrefix; +import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; +import org.ray.runtime.generated.Gcs.ClientTableData; +import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; +import 
org.ray.runtime.generated.Gcs.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.name(TablePrefix.CLIENT); + final String prefix = TablePrefix.CLIENT.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,36 +63,42 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); - final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); + ClientTableData data = null; + try { + data = ClientTableData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + final UniqueId clientId = UniqueId + .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); - if (data.entryType() == EntryType.INSERTION) { + if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. Preconditions.checkState( - data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.nodeManagerAddress(), true, resources); + clientId, data.getNodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE) { + } else if (data.getEntryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); } } else { // Code path of node deletion. 
- Preconditions.checkState(data.entryType() == EntryType.DELETION); + Preconditions.checkState(data.getEntryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -107,7 +113,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); + TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. */ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = - ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( - data.checkpointIdsAsByteBuffer()); + ActorCheckpointIdData data = null; + try { + data = ActorCheckpointIdData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; + for (int i = 0; i < checkpointIds.length; i++) { + checkpointIds[i] = UniqueId + .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); + } for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index f9e310249a35..1a7e4701c22b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.generated.Gcs.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) - .getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) - .getBytes(); + private static final byte[] WORKER_EXCEPTION_META = 
String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 01b9e4675016..c369e6f2cab8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -190,9 +190,16 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); + + // Deserialize dynamic worker options. + List dynamicWorkerOptions = new ArrayList<>(); + for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { + dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); + } + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -275,6 +282,12 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); } + int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; + for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { + dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); + } + int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); + int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, @@ -293,7 +306,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { requiredResourcesOffset, requiredPlacementResourcesOffset, language, - functionDescriptorOffset); + functionDescriptorOffset, + dynamicWorkerOptionsOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 15240e43e234..773499fcf5cf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -319,6 +319,9 @@ private String buildWorkerCommandRaylet() { cmd.addAll(rayConfig.jvmParameters); + // jvm options + cmd.add("RAY_WORKER_OPTION_0"); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 3473a9bdb3cc..060ca6fff4c3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -63,6 +63,8 @@ public class TaskSpec { // Language of this task. 
public final TaskLanguage language; + public final List dynamicWorkerOptions; + // Descriptor of the remote function. // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language // is Python, the type is PyFunctionDescriptor. @@ -93,7 +95,8 @@ public TaskSpec( int numReturns, Map resources, TaskLanguage language, - FunctionDescriptor functionDescriptor) { + FunctionDescriptor functionDescriptor, + List dynamicWorkerOptions) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -106,6 +109,8 @@ public TaskSpec( this.newActorHandles = newActorHandles; this.args = args; this.numReturns = numReturns; + this.dynamicWorkerOptions = dynamicWorkerOptions; + returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); @@ -157,6 +162,7 @@ public String toString() { ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + + ", dynamicWorkerOptions=" + dynamicWorkerOptions + ", executionDependencies=" + executionDependencies + '}'; } diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 9b3bbf233856..3636c93e4909 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,8 +1,10 @@ package org.ray.api; import java.util.function.Supplier; +import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; +import org.testng.Assert; import org.testng.SkipException; public class TestUtils { @@ -42,4 +44,17 @@ public static boolean waitForCondition(Supplier condition, int timeoutM } return false; } + + @RayRemote + private static String hi() { + return "hi"; + } + + /** + * Warm up the cluster. + */ + public static void warmUpCluster() { + RayObject obj = Ray.call(TestUtils::hi); + Assert.assertEquals(obj.get(), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index 79b3eba0ed13..71766c6cf2bf 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,6 +23,10 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); + + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + CallOptions op1 = new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); @@ -30,16 +34,21 @@ public void testSetResource() { Assert.assertEquals(result.getReady().size(), 0); Ray.setResource("A", 10.0); + boolean resourceReady = TestUtils.waitForCondition(() -> { + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + if (nodes.size() != 1) { + return false; + } + return (0 == Double.compare(10.0, nodes.get(0).resources.get("A"))); + }, 2000); - // Assert node info. - List nodes = Ray.getRuntimeContext().getAllNodeInfo(); - Assert.assertEquals(nodes.size(), 1); - Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); + Assert.assertTrue(resourceReady); // Assert ray call result. 
result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 1); Assert.assertEquals(Ray.get(obj.getId()), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index e82b99d364ba..bccc50a50bdf 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -5,6 +5,7 @@ import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; +import org.ray.api.TestUtils; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; import org.testng.Assert; @@ -28,6 +29,9 @@ private static String delayedHi() { } private static void testWait() { + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + RayObject obj1 = Ray.call(WaitTest::hi); RayObject obj2 = Ray.call(WaitTest::delayedHi); @@ -71,4 +75,5 @@ public void testWaitForEmpty() { Assert.assertTrue(true); } } + } diff --git a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java new file mode 100644 index 000000000000..90a2817a8366 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java @@ -0,0 +1,31 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class WorkerJvmOptionsTest extends BaseTest { + + @RayRemote + public static class Echo { + String getOptions() { + return System.getProperty("test.suffix"); + } + } + + @Test + public void testJvmOptions() { + TestUtils.skipTestUnderSingleProcess(); + ActorCreationOptions options = new ActorCreationOptions.Builder() + .setJvmOptions("-Dtest.suffix=suffix") + .createActorCreationOptions(); + RayActor actor = Ray.createActor(Echo::new, options); + RayObject obj = Ray.call(Echo::getOptions, actor); + Assert.assertEquals(obj.get(), "suffix"); + } +} diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index f2a0d81ca343..25ec072d3fc7 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import logging + from collections import defaultdict import ray @@ -13,6 +15,8 @@ # in node_manager.cc ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" +logger = logging.getLogger(__name__) + class Signal(object): """Base class for Ray signals.""" @@ -125,10 +129,16 @@ def receive(sources, timeout=None): for s in sources: task_id_to_sources[_get_task_id(s).hex()].append(s) + if timeout < 1e-3: + logger.warning("Timeout too small. Using 1ms minimum") + timeout = 1e-3 + + timeout_ms = int(1000 * timeout) + # Construct the redis query. query = "XREAD BLOCK " - # Multiply by 1000x since timeout is in sec and redis expects ms. - query += str(1000 * timeout) + # redis expects ms. 
+ query += str(timeout_ms) query += " STREAMS " query += " ".join([task_id for task_id in task_id_to_sources]) query += " " diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index cadd197ec73f..ba72e96f41db 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,38 +2,39 @@ from __future__ import division from __future__ import print_function -import flatbuffers -import ray.core.generated.ErrorTableData - -from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData -from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.DriverTableData import DriverTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsEntry import GcsEntry -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData -from ray.core.generated.Language import Language -from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.TablePrefix import TablePrefix -from ray.core.generated.TablePubsub import TablePubsub - from ray.core.generated.ray.protocol.Task import Task +from ray.core.generated.gcs_pb2 import ( + ActorCheckpointIdData, + ClientTableData, + DriverTableData, + ErrorTableData, + ErrorType, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, + TaskTableData, +) + __all__ = [ "ActorCheckpointIdData", "ClientTableData", "DriverTableData", "ErrorTableData", + "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", - "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", "TablePubsub", "Task", + "TaskTableData", "construct_error_message", ] @@ -42,13 +43,16 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# These prefixes must be kept up-to-date with the TablePrefix enum in +# gcs.proto. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. 
""" - builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id.binary()) - error_type_offset = builder.CreateString(error_type) - message_offset = builder.CreateString(message) - - ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) - ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( - builder, driver_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddType( - builder, error_type_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( - builder, message_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( - builder, timestamp) - error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( - builder) - builder.Finish(error_data_offset) - - return bytes(builder.Output()) + data = ErrorTableData() + data.driver_id = driver_id.binary() + data.type = error_type + data.error_message = message + data.timestamp = timestamp + return data.SerializeToString() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index c9e0424b3eb8..35597ef231e3 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,28 +101,26 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - heartbeat_data = gcs_entries.Entries(0) + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] - message = (ray.gcs_utils.HeartbeatBatchTableData. - GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + message = ray.gcs_utils.HeartbeatBatchTableData.FromString( + heartbeat_data) - for j in range(message.BatchLength()): - heartbeat_message = message.Batch(j) - - num_resources = heartbeat_message.ResourcesTotalLabelLength() + for heartbeat_message in message.batch: + num_resources = len(heartbeat_message.resources_available_label) static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.ResourcesAvailableLabel(i) - static = heartbeat_message.ResourcesTotalLabel(i) + dyn = heartbeat_message.resources_available_label[i] + static = heartbeat_message.resources_total_label[i] dynamic_resources[dyn] = ( - heartbeat_message.ResourcesAvailableCapacity(i)) + heartbeat_message.resources_available_capacity[i]) static_resources[static] = ( - heartbeat_message.ResourcesTotalCapacity(i)) + heartbeat_message.resources_total_capacity[i]) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. 
""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - driver_data = gcs_entries.Entries(0) - message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( - driver_data, 0) - driver_id = message.DriverId() + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + driver_data = gcs_entries.entries[0] + message = ray.gcs_utils.DriverTableData.FromString(driver_data) + driver_id = message.driver_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index c269df2fc6e5..d320b9636881 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -48,6 +48,10 @@ def get_policy_class(config): def validate_config(config): if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sample_async"] and config["use_pytorch"]: + raise ValueError( + "The sample_async option is not supported with use_pytorch: " + "Multithreading can be lead to crashes if used with pytorch.") def make_async_optimizer(workers, config): diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index b9699888bfaf..23b5ada167db 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -75,6 +75,7 @@ # balancing the three losses "vf_loss_coeff": 0.5, "entropy_coeff": 0.01, + "entropy_schedule": None, # use fake (infinite speed) sampler for testing "_fake_sampler": False, diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index 7fd137bae08b..9860783238a0 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -14,7 +14,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy import TFPolicy, \ - LearningRateSchedule + LearningRateSchedule, EntropyCoeffSchedule from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.annotations import override @@ -126,7 +126,7 @@ def postprocess_trajectory(self, return sample_batch -class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy): +class VTraceTFPolicy(LearningRateSchedule, EntropyCoeffSchedule, VTracePostprocessing, TFPolicy): def __init__(self, observation_space, action_space, @@ -249,6 +249,9 @@ def make_time_major(tensor, drop_last=False): loss_actions = actions if is_multidiscrete else tf.expand_dims( actions, axis=1) + EntropyCoeffSchedule.__init__(self, self.config["entropy_coeff"], + self.config["entropy_schedule"]) + # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
with tf.name_scope('vtrace_loss'): self.loss = VTraceLoss( @@ -333,6 +336,7 @@ def make_time_major(tensor, drop_last=False): self.stats_fetches = { LEARNER_STATS_KEY: dict({ "cur_lr": tf.cast(self.cur_lr, tf.float64), + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), "policy_loss": self.loss.pi_loss, "entropy": self.loss.entropy, "grad_gnorm": tf.global_norm(self._grads), diff --git a/python/ray/rllib/agents/qmix/qmix_policy.py b/python/ray/rllib/agents/qmix/qmix_policy.py index 26ec387de004..99045899684b 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -204,6 +204,8 @@ def __init__(self, obs_space, action_space, config): # Setup optimizer self.params = list(self.model.parameters()) + if self.mixer: + self.params += list(self.mixer.parameters()) self.loss = QMixLoss(self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"]) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index abc5cf546184..591363a793be 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -2,21 +2,21 @@ from __future__ import division from __future__ import print_function -import os import errno import logging -import numpy as np +import os +import numpy as np import ray import ray.experimental.tf_utils +from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.utils import try_import_tf from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -416,7 +416,7 @@ def _build_compute_actions(self, if len(self._state_inputs) != len(state_batches): raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". - format(self._state_inputs, state_batches)) + format(self._state_inputs, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: @@ -443,7 +443,7 @@ def _build_apply_gradients(self, builder, gradients): if len(gradients) != len(self._grads): raise ValueError( "Unexpected number of gradients to apply, got {} for {}". 
- format(gradients, self._grads)) + format(gradients, self._grads)) builder.add_feed_dict({self._is_training: True}) builder.add_feed_dict(dict(zip(self._grads, gradients))) fetches = builder.add_fetches([self._apply_op]) @@ -473,9 +473,9 @@ def _get_loss_inputs_dict(self, batch): feed_dict = {} if self._batch_divisibility_req > 1: meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent else: meets_divisibility_reqs = True @@ -551,3 +551,23 @@ def on_global_var_update(self, global_vars): @override(TFPolicy) def optimizer(self): return tf.train.AdamOptimizer(self.cur_lr) + + +@DeveloperAPI +class EntropyCoeffSchedule(object): + """Mixin for TFPolicy that adds entropy coeff decay.""" + + @DeveloperAPI + def __init__(self, entropy_coeff, entropy_schedule): + self.entropy_coeff = tf.get_variable("entropy_coeff", initializer=entropy_coeff) + self._entropy_schedule = entropy_schedule + + @override(Policy) + def on_global_var_update(self, global_vars): + super(EntropyCoeffSchedule, self).on_global_var_update(global_vars) + if self._entropy_schedule is not None: + self.entropy_coeff.load( + self.config['entropy_coeff'] * + (1 - global_vars['timestep'] / + self.config['entropy_schedule']), + session=self._sess) diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index a87a295ccf1d..d27270c20965 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -125,14 +125,14 @@ def testSimple(self): def testMultiGPU(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) - optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True) + optimizer = AsyncSamplesOptimizer(workers, num_gpus=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiGPUParallelLoad(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True) + workers, num_gpus=1, num_data_loader_buffers=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiplePasses(self): @@ -211,21 +211,21 @@ def testRejectBadConfigs(self): num_data_loader_buffers=2, minibatch_buffer_size=4)) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=50, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=25, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=74, _fake_gpus=True) diff --git a/python/ray/services.py b/python/ray/services.py index 66d4069820d0..ff4111b2c258 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1245,6 +1245,7 @@ def build_java_worker_command( assert java_worker_options is not None command = "java " + if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) @@ -1265,6 +1266,8 @@ def build_java_worker_command( # Put `java_worker_options` in the last, so it can overwrite the # above options. 
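The EntropyCoeffSchedule mixin introduced in tf_policy.py above anneals the entropy coefficient linearly in the global timestep. A standalone sketch of that arithmetic, using the IMPALA default of entropy_coeff=0.01 (illustrative only; the real mixin loads the value into a TF variable via entropy_coeff.load):

def decayed_entropy_coeff(initial_coeff, timestep, entropy_schedule):
    # Linear decay used by the EntropyCoeffSchedule mixin: the coefficient
    # shrinks from initial_coeff at step 0 to zero at step entropy_schedule.
    return initial_coeff * (1 - timestep / entropy_schedule)

assert decayed_entropy_coeff(0.01, 0.0, 100000.0) == 0.01
assert decayed_entropy_coeff(0.01, 50000.0, 100000.0) == 0.005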
command += java_worker_options + " " + + command += "RAY_WORKER_OPTION_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/python/ray/state.py b/python/ray/state.py index 14ba49987ec4..35f97cd65f5e 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,11 @@ import ray from ray.function_manager import FunctionDescriptor -import ray.gcs_utils -from ray.ray_constants import ID_SIZE -from ray import services -from ray.core.generated.EntryType import EntryType +from ray import ( + gcs_utils, + services, +) from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +31,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. """ NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.CLIENT, - "", NIL_CLIENT_ID) + message = redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", + NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,36 +41,31 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) + for entry in gcs_entry.entries: + client = gcs_utils.ClientTableData.FromString(entry) resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) + client.resources_total_label[i]: client.resources_total_capacity[i] + for i in range(len(client.resources_total_label)) } - client_id = ray.utils.binary_to_hex(client.ClientId()) + client_id = ray.utils.binary_to_hex(client.client_id) - if client.EntryType() == EntryType.INSERTION: + if client.entry_type == gcs_utils.ClientTableData.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.EntryType(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), + "EntryType": client.entry_type, + "NodeManagerAddress": client.node_manager_address, + "NodeManagerPort": client.node_manager_port, + "ObjectManagerPort": client.object_manager_port, + "ObjectStoreSocketName": client.object_store_socket_name, + "RayletSocketName": client.raylet_socket_name, "Resources": resources } @@ -79,22 +74,23 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( - "Unexpected updation of deleted client.") + is_deletion = (node_info[client_id]["EntryType"] != + gcs_utils.ClientTableData.DELETION) + assert is_deletion, "Unexpected updation of deleted client." 
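Taken together, these hunks replace the flatbuffers accessor style (GetRootAs*, per-field methods, explicit index loops) with protobuf's FromString parsing, attribute access, and iterable repeated fields. A minimal sketch of the new round trip, assuming ray.gcs_utils re-exports the generated GcsEntry, ErrorTableData, and TablePrefix types as the surrounding diffs indicate:

from ray import gcs_utils

# Build and serialize an error record the way the reworked
# construct_error_message does.
error = gcs_utils.ErrorTableData()
error.driver_id = b"\x00" * 20  # hypothetical driver id bytes
error.type = "test_category"
error.error_message = "something went wrong"
error.timestamp = 1557500000.0
payload = error.SerializeToString()

# The GCS hands such records back wrapped in a GcsEntry; parsing is now
# FromString plus plain attribute access instead of GetRootAs...()
# plus per-field accessor methods.
entry = gcs_utils.GcsEntry()
entry.entries.append(payload)
roundtrip = gcs_utils.GcsEntry.FromString(entry.SerializeToString())
parsed = gcs_utils.ErrorTableData.FromString(roundtrip.entries[0])
assert parsed.error_message == "something went wrong"

# Generated enums are looked up through Value()/Name(), which is why
# TablePrefix.ERROR_INFO becomes TablePrefix.Value("ERROR_INFO") in
# the Redis commands throughout this patch.
prefix = gcs_utils.TablePrefix.Value("ERROR_INFO")
assert gcs_utils.TablePrefix.Name(prefix) == "ERROR_INFO"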
res_map = node_info[client_id]["Resources"] - if client.EntryType() == EntryType.RES_CREATEUPDATE: + if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.EntryType() == EntryType.RES_DELETE: + elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.EntryType() == EntryType.DELETION: + elif client.entry_type == gcs_utils.ClientTableData.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.EntryType())) + client.entry_type)) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.EntryType() + node_info[client_id]["EntryType"] = client.entry_type # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -244,20 +240,19 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, "", - object_id.binary()) + gcs_utils.TablePrefix.Value("OBJECT"), + "", object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) - assert gcs_entry.EntriesLength() > 0 + assert len(gcs_entry.entries) > 0 - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(0), 0) + entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), + "DataSize": entry.object_size, + "Manager": entry.manager, } return object_info @@ -278,10 +273,9 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + - "*") + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -301,17 +295,18 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. 
""" assert isinstance(task_id, ray.TaskID) - message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - "", task_id.binary()) + message = self._execute_command( + task_id, "RAY.TABLE_LOOKUP", + gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - - assert gcs_entries.EntriesLength() == 1 + gcs_entries = gcs_utils.GcsEntry.FromString(message) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) + assert len(gcs_entries.entries) == 1 + task_table_data = gcs_utils.TaskTableData.FromString( + gcs_entries.entries[0]) + task_table_message = gcs_utils.Task.GetRootAsTask( + task_table_data.task, 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -368,9 +363,9 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] @@ -380,27 +375,6 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results - def function_table(self, function_id=None): - """Fetch and parse the function table. - - Returns: - A dictionary that maps function IDs to information about the - function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys( - ray.gcs_utils.FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -423,37 +397,32 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. 
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.PROFILE, "", - batch_id.binary()) + gcs_utils.TablePrefix.Value("PROFILE"), + "", batch_id.binary()) if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) profile_events = [] - for i in range(gcs_entries.EntriesLength()): - profile_table_message = ( - ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( - gcs_entries.Entries(i), 0)) - - component_type = decode(profile_table_message.ComponentType()) - component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode( - profile_table_message.NodeIpAddress(), allow_none=True) + for entry in gcs_entries.entries: + profile_table_message = gcs_utils.ProfileTableData.FromString( + entry) - for j in range(profile_table_message.ProfileEventsLength()): - profile_event_message = profile_table_message.ProfileEvents(j) + component_type = profile_table_message.component_type + component_id = binary_to_hex(profile_table_message.component_id) + node_ip_address = profile_table_message.node_ip_address + for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": decode(profile_event_message.EventType()), + "event_type": profile_event_message.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.StartTime(), - "end_time": profile_event_message.EndTime(), - "extra_data": json.loads( - decode(profile_event_message.ExtraData())), + "start_time": profile_event_message.start_time, + "end_time": profile_event_message.end_time, + "extra_data": json.loads(profile_event_message.extra_data), } profile_events.append(profile_event) @@ -462,10 +431,10 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys( - ray.gcs_utils.TablePrefix_PROFILE_string + "*") + profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + + "*") batch_identifiers_binary = [ - key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -766,7 +735,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. 
- if client["EntryType"] != EntryType.DELETION: + if client["EntryType"] != gcs_utils.ClientTableData.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -776,7 +745,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != EntryType.DELETION) + if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) } def available_resources(self): @@ -800,7 +769,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -809,24 +778,23 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. - GetRootAsHeartbeatTableData(heartbeat_data, 0)) + gcs_entries = gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] + message = gcs_utils.HeartbeatTableData.FromString( + heartbeat_data) # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() + num_resources = len(message.resources_available_label) dynamic_resources = {} for i in range(num_resources): - resource_id = decode(message.ResourcesAvailableLabel(i)) + resource_id = message.resources_available_label[i] dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) + message.resources_available_capacity[i]) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) + client_id = ray.utils.binary_to_hex(message.client_id) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -860,23 +828,22 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", driver_id.binary()) # If there are no errors, return early. 
if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) error_messages = [] - for i in range(gcs_entries.EntriesLength()): - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entries.Entries(i), 0) - assert driver_id.binary() == error_data.DriverId() + for entry in gcs_entries.entries: + error_data = gcs_utils.ErrorTableData.FromString(entry) + assert driver_id.binary() == error_data.driver_id error_message = { - "type": decode(error_data.Type()), - "message": decode(error_data.ErrorMessage()), - "timestamp": error_data.Timestamp(), + "type": error_data.type, + "message": error_data.error_message, + "timestamp": error_data.timestamp, } error_messages.append(error_message) return error_messages @@ -899,9 +866,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") + gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -923,30 +890,23 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), "", actor_id.binary(), ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - entry = ( - ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( - gcs_entry.Entries(0), 0)) - checkpoint_ids_str = entry.CheckpointIds() - num_checkpoints = len(checkpoint_ids_str) // ID_SIZE - assert len(checkpoint_ids_str) % ID_SIZE == 0 + gcs_entry = gcs_utils.GcsEntry.FromString(message) + entry = gcs_utils.ActorCheckpointIdData.FromString( + gcs_entry.entries[0]) checkpoint_ids = [ - ray.ActorCheckpointID( - checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) - for i in range(num_checkpoints) + ray.ActorCheckpointID(checkpoint_id) + for checkpoint_id in entry.checkpoint_ids ] return { - "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": [ - entry.Timestamps(i) for i in range(num_checkpoints) - ], + "Timestamps": list(entry.timestamps), } diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 703c3a1420ed..76dfd3000b86 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.core.generated.EntryType import EntryType +from ray.gcs_utils import ClientTableData logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == EntryType.INSERTION + if client["EntryType"] == ClientTableData.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 2e670fb0a84d..f7c93fd50c2e 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -56,6 +56,14 @@ def _ray_start(**kwargs): ray.shutdown() +# The following fixture will start ray with 0 cpu. 
+@pytest.fixture +def ray_start_no_cpu(request): + param = getattr(request, "param", {}) + with _ray_start(num_cpus=0, **param) as res: + yield res + + # The following fixture will start ray with 1 cpu. @pytest.fixture def ray_start_regular(request): diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index dd726e00f27b..932f7b090bf7 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -842,7 +842,7 @@ def f(): assert actor_id not in resulting_ids -def test_actors_on_nodes_with_no_cpus(ray_start_regular): +def test_actors_on_nodes_with_no_cpus(ray_start_no_cpu): @ray.remote class Foo(object): def method(self): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7f1f78d1b5c4..6b4bd754cd4d 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 51b906695c2d..a560e461f7a2 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), + ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, + malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index fe2e74379245..176fbd45bcaa 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -353,3 +353,36 @@ def f(sources): assert len(result_list) == 1 result_list = ray.get(f.remote([a])) assert len(result_list) == 1 + + +def test_non_integral_receive_timeout(ray_start_regular): + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=0.1) + + assert len(result_list) == 1 + + +def test_small_receive_timeout(ray_start_regular): + """ Test that receive handles timeout smaller than the 1ms min + """ + # 0.1 ms + small_timeout = 1e-4 + + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=small_timeout) + + assert 
len(result_list) == 1 diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 0164ec2b1a2e..a3c246aba161 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -47,7 +47,14 @@ class ExperimentAnalysis(object): >>> experiment_path="~/tune_results/my_exp") """ - def __init__(self, experiment_path): + def __init__(self, experiment_path, trials=None): + """Initializer. + + Args: + experiment_path (str): Path to where experiment is located. + trials (list|None): List of trials that can be accessed via + `analysis.trials`. + """ experiment_path = os.path.expanduser(experiment_path) if not os.path.isdir(experiment_path): raise TuneError( @@ -55,7 +62,8 @@ def __init__(self, experiment_path): experiment_state_paths = glob.glob( os.path.join(experiment_path, "experiment_state*.json")) if not experiment_state_paths: - raise TuneError("No experiment state found!") + raise TuneError( + "No experiment state found in {}!".format(experiment_path)) experiment_filename = max( list(experiment_state_paths)) # if more than one, pick latest with open(os.path.join(experiment_path, experiment_filename)) as f: @@ -65,10 +73,27 @@ def __init__(self, experiment_path): raise TuneError("Experiment state invalid; no checkpoints found.") self._checkpoints = self._experiment_state["checkpoints"] self._scrubbed_checkpoints = unnest_checkpoints(self._checkpoints) + self.trials = trials + self._dataframe = None + + def get_all_trial_dataframes(self): + trial_dfs = {} + for checkpoint in self._checkpoints: + logdir = checkpoint["logdir"] + progress = max(glob.glob(os.path.join(logdir, "progress.csv"))) + trial_dfs[checkpoint["trial_id"]] = pd.read_csv(progress) + return trial_dfs + + def dataframe(self, refresh=False): + """Returns a pandas.DataFrame object constructed from the trials. - def dataframe(self): - """Returns a pandas.DataFrame object constructed from the trials.""" - return pd.DataFrame(self._scrubbed_checkpoints) + Args: + refresh (bool): Clears the cache which may have an existing copy. + + """ + if self._dataframe is None or refresh: + self._dataframe = pd.DataFrame(self._scrubbed_checkpoints) + return self._dataframe def stats(self): """Returns a dictionary of the statistics of the experiment.""" @@ -87,22 +112,45 @@ def trial_dataframe(self, trial_id): return pd.read_csv(progress) raise ValueError("Trial id {} not found".format(trial_id)) - def get_best_trainable(self, metric, trainable_cls): - """Returns the best Trainable based on the experiment metric.""" - return trainable_cls(config=self.get_best_config(metric)) - - def get_best_config(self, metric): - """Retrieve the best config from the best trial.""" - return self._get_best_trial(metric)["config"] - - def _get_best_trial(self, metric): - """Retrieve the best trial based on the experiment metric.""" - return max( + def get_best_trainable(self, metric, trainable_cls, mode="max"): + """Returns the best Trainable based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + + """ + return trainable_cls(config=self.get_best_config(metric, mode=mode)) + + def get_best_config(self, metric, mode="max"): + """Retrieve the best config from the best trial. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. 
+ + """ + return self.get_best_info(metric, flatten=False, mode=mode)["config"] + + def get_best_logdir(self, metric, mode="max"): + df = self.dataframe() + if mode == "max": + return df.iloc[df[metric].idxmax()].logdir + elif mode == "min": + return df.iloc[df[metric].idxmin()].logdir + + def get_best_info(self, metric, mode="max", flatten=True): + """Retrieve the best trial based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + flatten (bool): Assumes trial info is flattened, where + nested entries are concatenated like `info:metric`. + """ + optimize_op = max if mode == "max" else min + if flatten: + return optimize_op( + self._scrubbed_checkpoints, key=lambda d: d.get(metric, 0)) + return optimize_op( self._checkpoints, key=lambda d: d["last_result"].get(metric, 0)) - - def _get_sorted_trials(self, metric): - """Retrive trials in sorted order based on the experiment metric.""" - return sorted( - self._checkpoints, - key=lambda d: d["last_result"].get(metric, 0), - reverse=True) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index 03dd2f1607e2..acef9fc5105d 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -1,7 +1,10 @@ # Original Code here: # https://github.com/pytorch/examples/blob/master/mnist/main.py +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +import numpy as np import argparse import torch import torch.nn as nn @@ -9,181 +12,123 @@ import torch.optim as optim from torchvision import datasets, transforms -# Training settings -parser = argparse.ArgumentParser(description="PyTorch MNIST Example") -parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)") -parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)") -parser.add_argument( - "--epochs", - type=int, - default=1, - metavar="N", - help="number of epochs to train (default: 1)") -parser.add_argument( - "--lr", - type=float, - default=0.01, - metavar="LR", - help="learning rate (default: 0.01)") -parser.add_argument( - "--momentum", - type=float, - default=0.5, - metavar="M", - help="SGD momentum (default: 0.5)") -parser.add_argument( - "--no-cuda", - action="store_true", - default=False, - help="disables CUDA training") -parser.add_argument( - "--seed", - type=int, - default=1, - metavar="S", - help="random seed (default: 1)") -parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") - - -def train_mnist(args, config, reporter): - vars(args).update(config) - args.cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - if args.cuda: - torch.cuda.manual_seed(args.seed) - - kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} +import ray +from ray import tune +from ray.tune import track +from ray.tune.schedulers import AsyncHyperBandScheduler + +# Change these values if you want the training to run quicker or slower. 
+EPOCH_SIZE = 512 +TEST_SIZE = 256 + + +class Net(nn.Module): + def __init__(self, config): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 3, kernel_size=3) + self.fc = nn.Linear(192, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 3)) + x = x.view(-1, 192) + x = self.fc(x) + return F.log_softmax(x, dim=1) + + +def train(model, optimizer, train_loader, device): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if batch_idx * len(data) > EPOCH_SIZE: + return + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + +def test(model, data_loader, device): + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (data, target) in enumerate(data_loader): + if batch_idx * len(data) > TEST_SIZE: + break + data, target = data.to(device), target.to(device) + outputs = model(data) + _, predicted = torch.max(outputs.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + + return correct / total + + +def get_data_loaders(): + mnist_transforms = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, ))]) + train_loader = torch.utils.data.DataLoader( datasets.MNIST( - "~/data", - train=True, - download=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.batch_size, - shuffle=True, - **kwargs) + "~/data", train=True, download=True, transform=mnist_transforms), + batch_size=64, + shuffle=True) test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "~/data", - train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs) - - class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - - model = Net() - if args.cuda: - model.cuda() + datasets.MNIST("~/data", train=False, transform=mnist_transforms), + batch_size=64, + shuffle=True) + return train_loader, test_loader + + +def train_mnist(config): + use_cuda = config.get("use_gpu") and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + train_loader, test_loader = get_data_loaders() + model = Net(config).to(device) optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) - - def train(epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if args.cuda: - data, target = data.cuda(), target.cuda() - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - - def test(): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - if args.cuda: - data, target = data.cuda(), target.cuda() - output = model(data) - # sum up batch loss - test_loss += F.nll_loss(output, target, reduction="sum").item() - # get the index of 
the max log-probability - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq( - target.data.view_as(pred)).long().cpu().sum() - - test_loss = test_loss / len(test_loader.dataset) - accuracy = correct.item() / len(test_loader.dataset) - reporter(mean_loss=test_loss, mean_accuracy=accuracy) - - for epoch in range(1, args.epochs + 1): - train(epoch) - test() + model.parameters(), lr=config["lr"], momentum=config["momentum"]) + + while True: + train(model, optimizer, train_loader, device) + acc = test(model, test_loader, device) + track.log(mean_accuracy=acc) if __name__ == "__main__": - datasets.MNIST("~/data", train=True, download=True) + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--cuda", + action="store_true", + default=False, + help="Enables GPU training") + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + parser.add_argument( + "--ray-redis-address", + help="Address of Ray cluster for seamless distributed execution.") args = parser.parse_args() - - import ray - from ray import tune - from ray.tune.schedulers import AsyncHyperBandScheduler - - ray.init() + if args.ray_redis_address: + ray.init(redis_address=args.ray_redis_address) sched = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric="mean_loss", - mode="min", - max_t=400, - grace_period=20) - tune.register_trainable( - "TRAIN_FN", - lambda config, reporter: train_mnist(args, config, reporter)) + time_attr="training_iteration", metric="mean_accuracy") tune.run( - "TRAIN_FN", + train_mnist, name="exp", scheduler=sched, - **{ - "stop": { - "mean_accuracy": 0.98, - "training_iteration": 1 if args.smoke_test else 20 - }, - "resources_per_trial": { - "cpu": 3, - "gpu": int(not args.no_cuda) - }, - "num_samples": 1 if args.smoke_test else 10, - "config": { - "lr": tune.uniform(0.001, 0.1), - "momentum": tune.uniform(0.1, 0.9), - } + stop={ + "mean_accuracy": 0.98, + "training_iteration": 5 if args.smoke_test else 20 + }, + resources_per_trial={ + "cpu": 2, + "gpu": int(args.cuda) + }, + num_samples=1 if args.smoke_test else 10, + config={ + "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())), + "momentum": tune.uniform(0.1, 0.9), + "use_gpu": int(args.cuda) }) diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py index 1ccec39462d0..751f0ed44fa9 100644 --- a/python/ray/tune/examples/track_example.py +++ b/python/ray/tune/examples/track_example.py @@ -9,7 +9,7 @@ from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) from ray.tune import track -from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data +from ray.tune.examples.utils import TuneReporterCallback, get_mnist_data parser = argparse.ArgumentParser() parser.add_argument( @@ -63,7 +63,7 @@ def train_mnist(args): batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(track.metric)]) + callbacks=[TuneReporterCallback(track.metric)]) track.shutdown() diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index 5357d86af19e..ecd3c34bc042 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -9,8 +9,8 @@ from keras.models import Sequential from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) -from ray.tune.examples.utils import (TuneKerasCallback, get_mnist_data, - set_keras_threads) +from 
ray.tune.integration.keras import TuneReporterCallback +from ray.tune.examples.utils import get_mnist_data, set_keras_threads parser = argparse.ArgumentParser() parser.add_argument( @@ -52,7 +52,7 @@ def train_mnist(config, reporter): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(reporter)]) + callbacks=[TuneReporterCallback(reporter)]) if __name__ == "__main__": @@ -63,7 +63,7 @@ def train_mnist(config, reporter): ray.init() sched = AsyncHyperBandScheduler( - time_attr="timesteps_total", + time_attr="training_iteration", metric="mean_accuracy", mode="max", max_t=400, diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index a5ab1dbdb6a1..f40707a014fc 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -5,24 +5,9 @@ import keras from keras.datasets import mnist from keras import backend as K - - -class TuneKerasCallback(keras.callbacks.Callback): - def __init__(self, reporter, logs={}): - self.reporter = reporter - self.iteration = 0 - super(TuneKerasCallback, self).__init__() - - def on_train_end(self, epoch, logs={}): - self.reporter( - timesteps_total=self.iteration, - done=1, - mean_accuracy=logs.get("acc")) - - def on_batch_end(self, batch, logs={}): - self.iteration += 1 - self.reporter( - timesteps_total=self.iteration, mean_accuracy=logs["acc"]) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import OneHotEncoder def get_mnist_data(): @@ -53,6 +38,16 @@ def get_mnist_data(): return x_train, y_train, x_test, y_test, input_shape +def get_iris_data(test_size=0.2): + iris_data = load_iris() + x = iris_data.data + y = iris_data.target.reshape(-1, 1) + encoder = OneHotEncoder(sparse=False) + y = encoder.fit_transform(y) + train_x, test_x, train_y, test_y = train_test_split(x, y) + return train_x, train_y, test_x, test_y + + def set_keras_threads(threads): # We set threads here to avoid contention, as Keras # is heavily parallelized across multiple cores. @@ -61,3 +56,8 @@ def set_keras_threads(threads): config=K.tf.ConfigProto( intra_op_parallelism_threads=threads, inter_op_parallelism_threads=threads))) + + +def TuneKerasCallback(*args, **kwargs): + raise DeprecationWarning("TuneKerasCallback is now " + "tune.integration.keras.TuneReporterCallback.") diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 5f3e46aabd0a..95cb12043f8f 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -176,6 +176,14 @@ def _register_if_needed(cls, run_object): else: raise TuneError("Improper 'run' - not string nor trainable.") + @property + def local_dir(self): + return self.spec.get("local_dir") + + @property + def checkpoint_dir(self): + return os.path.join(self.spec["local_dir"], self.name) + def convert_to_experiment_list(experiments): """Produces a list of Experiment objects. 
diff --git a/python/ray/tune/integration/__init__.py b/python/ray/tune/integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py new file mode 100644 index 000000000000..197a7eef9841 --- /dev/null +++ b/python/ray/tune/integration/keras.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import keras +from ray.tune import track + + +class TuneReporterCallback(keras.callbacks.Callback): + def __init__(self, reporter=None, freq="batch", logs={}): + self.reporter = reporter or track.log + self.iteration = 0 + if freq not in ["batch", "epoch"]: + raise ValueError("{} not supported as a frequency.".format(freq)) + self.freq = freq + super(TuneReporterCallback, self).__init__() + + def on_batch_end(self, batch, logs={}): + if not self.freq == "batch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) + + def on_epoch_end(self, batch, logs={}): + if not self.freq == "epoch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 50bb447437e4..34655372f40a 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,11 +4,13 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler -from ray.tune.schedulers.async_hyperband import AsyncHyperBandScheduler +from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, + ASHAScheduler) from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule from ray.tune.schedulers.pbt import PopulationBasedTraining __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", - "MedianStoppingRule", "FIFOScheduler", "PopulationBasedTraining" + "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", + "PopulationBasedTraining" ] diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 487eb350efcf..0370d03d3b50 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -168,6 +168,8 @@ def debug_str(self): return "Bracket: " + iters +ASHAScheduler = AsyncHyperBandScheduler + if __name__ == "__main__": sched = AsyncHyperBandScheduler( grace_period=1, max_t=10, reduction_factor=2) diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index a0721abc5d29..7b613a6fdea2 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -11,9 +11,7 @@ import ray from ray.tune import run, sample_from -from ray.tune.analysis import ExperimentAnalysis from ray.tune.examples.async_hyperband_example import MyTrainableClass -from ray.tune.schedulers import AsyncHyperBandScheduler class ExperimentAnalysisSuite(unittest.TestCase): @@ -27,35 +25,22 @@ def setUp(self): self.test_path = os.path.join(self.test_dir, self.test_name) self.run_test_exp() - self.ea = ExperimentAnalysis(self.test_path) - def 
tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) ray.shutdown() def run_test_exp(self): - ahb = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric=self.metric, - mode="max", - grace_period=5, - max_t=100) - - run(MyTrainableClass, + self.ea = run( + MyTrainableClass, name=self.test_name, - scheduler=ahb, local_dir=self.test_dir, - **{ - "stop": { - "training_iteration": 1 - }, - "num_samples": 10, - "config": { - "width": sample_from( - lambda spec: 10 + int(90 * random.random())), - "height": sample_from( - lambda spec: int(100 * random.random())), - }, + return_trials=False, + stop={"training_iteration": 1}, + num_samples=self.num_samples, + config={ + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from(lambda spec: int(100 * random.random())), }) def testDataframe(self): @@ -87,7 +72,7 @@ def testBestConfig(self): self.assertTrue("height" in best_config) def testBestTrial(self): - best_trial = self.ea._get_best_trial(self.metric) + best_trial = self.ea.get_best_info(self.metric, flatten=False) self.assertTrue(isinstance(best_trial, dict)) self.assertTrue("local_dir" in best_trial) @@ -99,6 +84,18 @@ def testBestTrial(self): self.assertTrue("last_result" in best_trial) self.assertTrue(self.metric in best_trial["last_result"]) + min_trial = self.ea.get_best_info( + self.metric, mode="min", flatten=False) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertLess(min_trial["last_result"][self.metric], + best_trial["last_result"][self.metric]) + + flat_trial = self.ea.get_best_info(self.metric, flatten=True) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertTrue(self.metric in flat_trial) + def testCheckpoints(self): checkpoints = self.ea._checkpoints @@ -121,6 +118,21 @@ def testRunnerData(self): self.assertEqual(runner_data["_metadata_checkpoint_dir"], os.path.expanduser(self.test_path)) + def testBestLogdir(self): + logdir = self.ea.get_best_logdir(self.metric) + self.assertTrue(logdir.startswith(self.test_path)) + logdir2 = self.ea.get_best_logdir(self.metric, mode="min") + self.assertTrue(logdir2.startswith(self.test_path)) + self.assertNotEquals(logdir, logdir2) + + def testAllDataframes(self): + dataframes = self.ea.get_all_trial_dataframes() + self.assertTrue(len(dataframes) == self.num_samples) + + self.assertTrue(isinstance(dataframes, dict)) + for df in dataframes.values(): + self.assertEqual(df.training_iteration.max(), 1) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 37022ceab615..64b8e9761488 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -441,6 +441,14 @@ def f(): self.assertRaises(TuneError, f) + def testNestedStoppingReturn(self): + def train(config, reporter): + for i in range(10): + reporter(test={"test1": {"test2": i}}) + + [trial] = tune.run(train, stop={"test": {"test1": {"test2": 6}}}) + self.assertEqual(trial.last_result["training_iteration"], 7) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f721023b4191..a9938396e59b 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -181,6 +181,21 @@ def has_trainable(trainable_name): ray.tune.registry.TRAINABLE_CLASS, trainable_name) +def recursive_criteria_check(result, criteria): + for criteria, stop_value in 
criteria.items(): + if criteria not in result: + raise TuneError( + "Stopping criteria {} not provided in result {}.".format( + criteria, result)) + elif isinstance(result[criteria], dict) and isinstance( + stop_value, dict): + if recursive_criteria_check(result[criteria], stop_value): + return True + elif result[criteria] >= stop_value: + return True + return False + + class Checkpoint(object): """Describes a checkpoint of trial state. @@ -425,15 +440,7 @@ def should_stop(self, result): if result.get(DONE): return True - for criteria, stop_value in self.stopping_criterion.items(): - if criteria not in result: - raise TuneError( - "Stopping criteria {} not provided in result {}.".format( - criteria, result)) - if result[criteria] >= stop_value: - return True - - return False + return recursive_criteria_check(result, self.stopping_criterion) def should_checkpoint(self): """Whether this trial is due for checkpointing.""" diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 1568db0f1102..47a82ba0c17f 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -4,11 +4,11 @@ import click import logging -import os import time from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment +from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.ray_trial_executor import RayTrialExecutor @@ -39,7 +39,7 @@ def _make_scheduler(args): def _find_checkpoint_dir(exp): # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier. # Right now it is resolved somewhere far down the trial generation process - return os.path.join(exp.spec["local_dir"], exp.name) + return exp.checkpoint_dir def _prompt_restore(checkpoint_dir, resume): @@ -89,9 +89,10 @@ def run(run_or_experiment, verbose=2, resume=False, queue_trials=False, - reuse_actors=False, + reuse_actors=True, trial_executor=None, raise_on_failed_trial=True, + return_trials=True, ray_auto_init=True): """Executes training. @@ -322,7 +323,9 @@ def override_flags(restored_config, new_config, flags_to_override): else: logger.error("Trials did not complete: %s", errored_trials) - return runner.get_trials() + if return_trials: + return runner.get_trials() + return ExperimentAnalysis(experiment.checkpoint_dir) def run_experiments(experiments, diff --git a/python/ray/utils.py b/python/ray/utils.py index 7b87486e325e..0db48e41d025 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. 
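The recursive stopping-criteria check added to ray/tune/trial.py above lets a nested `stop` dictionary mirror a nested result dictionary. A self-contained sketch of that walk, re-implemented here for illustration rather than the library function itself:

def nested_criteria_met(result, criteria):
    # Walk the criteria recursively; a branch matches once the reported
    # value reaches the configured threshold at the same nesting depth.
    for key, stop_value in criteria.items():
        if key not in result:
            raise ValueError(
                "Stopping criteria {} not provided in result".format(key))
        if isinstance(result[key], dict) and isinstance(stop_value, dict):
            if nested_criteria_met(result[key], stop_value):
                return True
        elif result[key] >= stop_value:
            return True
    return False

assert not nested_criteria_met({"test": {"test1": {"test2": 5}}},
                               {"test": {"test1": {"test2": 6}}})
assert nested_criteria_met({"test": {"test1": {"test2": 6}}},
                           {"test": {"test1": {"test2": 6}}})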
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, - driver_id.binary(), error_data) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 7505120574a6..710f0db43c6b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.core.generated.ErrorType import ErrorType +from ray.gcs_utils import ErrorType from ray.exceptions import ( RayActorError, RayError, @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.WORKER_DIED: + if error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() - elif error_type == ErrorType.ACTOR_DIED: + elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() - elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - msg["data"], 0) - assert gcs_entry.EntriesLength() == 1 - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entry.Entries(0), 0) - driver_id = error_data.DriverId() + gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) + assert len(gcs_entry.entries) == 1 + error_data = ray.gcs_utils.ErrorTableData.FromString( + gcs_entry.entries[0]) + driver_id = error_data.driver_id if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = ray.utils.decode(error_data.ErrorMessage()) - if (ray.utils.decode( - error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + error_message = error_data.error_message + if (error_data.type == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1878,14 +1876,16 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) + task_table_data = ray.gcs_utils.TaskTableData() + task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. 
- ray.state.state._execute_command(driver_task.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command( + driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), + ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), + driver_task.task_id().binary(), + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/python/setup.py b/python/setup.py index eb200ea7d5e4..95e7e66bad3e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -151,6 +151,7 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", + "protobuf", ] setup( diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index c92e6a74aa5d..1f50b8025d57 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -36,4 +36,6 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. constexpr char kTaskTablePrefix[] = "TaskTable"; +constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; + #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c9b1e138575d..6de29bb52764 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,10 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } -FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } - -ClassTable &AsyncGcsClient::class_table() { return *class_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index c9f5b4bca624..5e70025b39a0 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,11 +44,7 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. Status Attach(boost::asio::io_service &io_service); - inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver - inline ClassTable &class_table(); - inline CustomSerializerTable &custom_serializer_table(); - inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -81,8 +77,6 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: - std::unique_ptr function_table_; - std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c7dc02e50651..55115b1e2067 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->task_specification = "123"; + auto data = std::make_shared(); + data->set_task("123"); // Check that we added the correct task. 
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { RAY_CHECK(false); }; + const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); data_log.push_back(data); } // Check that we added the correct task. 
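[Editor's illustrative note] The test hunks above and below show the accessor migration: flatbuffer object-API members (`data->task_specification`) become protobuf setters/getters (`data->set_task(...)` / `d.task()`). For reference, the same round trip from Python uses plain attributes, as sketched here (assuming `TaskTableData` is importable from `ray.gcs_utils`):

    from ray import gcs_utils

    original = gcs_utils.TaskTableData()
    original.task = b"123"
    stored = original.SerializeToString()                 # what the GCS keeps
    restored = gcs_utils.TaskTableData.FromString(stored)  # what a lookup callback sees
    assert restored.task == original.task
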
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id); + appended_managers.push_back(entry.node_manager_id()); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. 
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { ASSERT_TRUE(false); }; + const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. 
auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->node_manager_id = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_node_manager_id(ObjectID::FromRandom().Hex()); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::FromRandom().Hex(); + auto task_data = std::make_shared(); + task_data->set_task(ObjectID::FromRandom().Hex()); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_manager(ObjectID::FromRandom().Hex()); object_vector.push_back(data); } }; @@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } -// Task table callbacks. 
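[Editor's illustrative note] The delete tests above add key counts just below, equal to, and well above `maximum_gcs_deletion_batch_size`. The batching itself happens inside the GCS client, not in the tests; the sketch below only illustrates the chunking idea and is not part of this patch.

    def delete_in_batches(ids, batch_size, delete_fn):
        # Issue deletions in chunks no larger than the configured batch size.
        for start in range(0, len(ids), batch_size):
            delete_fn(ids[start:start + batch_size])
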
-void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); -} - -void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data, bool do_stop) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); - if (do_stop) { - test->Stop(); - } -} -void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/false); -} -void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/true); -} - -void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); -} - -void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); - test->Stop(); -} - -void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - test->Stop(); -} - void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. 
@@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->driver_id = driver_ids1[0]; + auto data1 = std::make_shared(); + data1->set_driver_id(driver_ids1[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->driver_id = driver_ids2[0]; + auto data2 = std::make_shared(); + data2->set_driver_id(driver_ids2[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. 
auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->manager = managers1[0]; + auto data1 = std::make_shared(); + data1->set_manager(managers1[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->manager = managers2[0]; + auto data2 = std::make_shared(); + data2->set_manager(managers2[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. 
TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->task_specification = task_specs[0]; + auto data = std::make_shared(); + data->set_task(task_specs[0]); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task_specification, task_specs.front()); + ASSERT_EQ(data.task(), task_specs.front()); } else { - ASSERT_EQ(data.task_specification, task_specs.back()); + ASSERT_EQ(data.task(), task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->driver_id = driver_ids[0]; + auto data = std::make_shared(); + data->set_driver_id(driver_ids[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. 
auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->driver_id = remaining_driver_id; + auto data = std::make_shared(); + data->set_driver_id(remaining_driver_id); RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->manager = managers[0]; + auto data = std::make_shared(); + data->set_manager(managers[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager, managers[0]); + ASSERT_EQ(data[0].manager(), managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager); + data_managers_set.insert(entry.manager()); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. 
We should receive a notification for the @@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableDataT &data, bool is_insertion) { + const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); - ClientTableDataT cached_client; + ClientTableData cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); + ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. 
- ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + const ClientTableData &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); }); test->Start(); @@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. 
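[Editor's illustrative note] The client-table tests above fill `ClientTableData` through protobuf setters and compare the nested `entry_type` enum against `ClientTableData::INSERTION`. The equivalent from Python is plain attribute assignment, sketched below; whether `ClientTableData` is re-exported from `ray.gcs_utils` is an assumption, not something this patch shows.

    from ray import gcs_utils  # ClientTableData export assumed

    client_info = gcs_utils.ClientTableData()
    client_info.node_manager_address = "127.0.0.1"
    client_info.node_manager_port = 0
    client_info.object_manager_port = 0
    # Nested enum values are attributes of the containing message class.
    client_info.entry_type = gcs_utils.ClientTableData.INSERTION
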
- auto cpu_data = std::make_shared(); - cpu_data->resource_name = "CPU"; - cpu_data->resource_capacity = 100; - auto gpu_data = std::make_shared(); - gpu_data->resource_name = "GPU"; - gpu_data->resource_capacity = 2; + auto cpu_data = std::make_shared(); + cpu_data->set_resource_name("CPU"); + cpu_data->set_resource_capacity(100); + auto gpu_data = std::make_shared(); + gpu_data->set_resource_name("GPU"); + gpu_data->set_resource_capacity(2); DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->resource_name = "CPU"; - data_cpu->resource_capacity = 50; - auto data_gpu = std::make_shared(); - data_gpu->resource_name = "GPU"; - data_gpu->resource_capacity = 10; - auto data_custom = std::make_shared(); - data_custom->resource_name = "CUSTOM"; - data_custom->resource_capacity = 2; + auto data_cpu = std::make_shared(); + data_cpu->set_resource_name("CPU"); + data_cpu->set_resource_capacity(50); + auto data_gpu = std::make_shared(); + data_gpu->set_resource_name("GPU"); + data_gpu->set_resource_capacity(10); + auto data_custom = std::make_shared(); + data_custom->set_resource_name("CUSTOM"); + data_custom->set_resource_capacity(2); DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->resource_capacity = 50; + data_map2["CPU"]->set_resource_capacity(50); // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name, data.second->resource_name); - ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); + ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 614c80b27672..c06c79a02928 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,52 +1,9 @@ -enum Language:int { - PYTHON = 0, - CPP = 1, - JAVA = 2 -} - -// These indexes are mapped to strings in ray_redis_module.cc. -enum TablePrefix:int { - UNUSED = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - FUNCTION, - TASK_RECONSTRUCTION, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - DRIVER, - PROFILE, - TASK_LEASE, - ACTOR_CHECKPOINT, - ACTOR_CHECKPOINT_ID, - NODE_RESOURCE, -} +// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). -// The channel that Add operations to the Table should be published on, if any. 
-enum TablePubsub:int { - NO_PUBLISH = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - TASK_LEASE, - DRIVER, - NODE_RESOURCE, -} - -// Enum for the entry type in the ClientTable -enum EntryType:int { - INSERTION = 0, - DELETION, - RES_CREATEUPDATE, - RES_DELETE, +enum Language:int { + PYTHON=0, + JAVA=1, + CPP=2, } table Arg { @@ -106,6 +63,11 @@ table TaskInfo { // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] function_descriptor: [string]; + // The dynamic options used in the worker command when starting the worker process for + // an actor creation task. If the list isn't empty, the options will be used to replace + // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the + // worker command. + dynamic_worker_options: [string]; } table ResourcePair { @@ -115,118 +77,6 @@ table ResourcePair { value: double; } -enum GcsChangeMode:int { - APPEND_OR_ADD = 0, - REMOVE, -} - -table GcsEntry { - change_mode: GcsChangeMode; - id: string; - entries: [string]; -} - -table FunctionTableData { - language: Language; - name: string; - data: string; -} - -table ObjectTableData { - // The size of the object. - object_size: long; - // The node manager ID that this object appeared on or was evicted by. - manager: string; -} - -table TaskReconstructionData { - // The number of times this task has been reconstructed so far. - num_reconstructions: int; - // The node manager that is trying to reconstruct the task. - node_manager_id: string; -} - -enum SchedulingState:int { - NONE = 0, - WAITING = 1, - SCHEDULED = 2, - QUEUED = 4, - RUNNING = 8, - DONE = 16, - LOST = 32, - RECONSTRUCTING = 64 -} - -table TaskTableData { - // The state of the task. - scheduling_state: SchedulingState; - // A raylet ID. - raylet_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // The number of times the task was spilled back by raylets. - spillback_count: long; - // A string of bytes representing the task specification. - task_info: string; - // TODO(pcm): This is at the moment duplicated in task_info, remove that one - updated: bool; -} - -table TaskTableTestAndUpdate { - test_raylet_id: string; - test_state_bitmask: SchedulingState; - update_state: SchedulingState; -} - -table ClassTableData { -} - -enum ActorState:int { - // Actor is alive. - ALIVE = 0, - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1, - // Actor is already dead and won't be reconstructed. - DEAD = 2 -} - -table ActorTableData { - // The ID of the actor that was created. - actor_id: string; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - actor_creation_dummy_object_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the node manager that created the actor. - node_manager_id: string; - // Current state of this actor. - state: ActorState; - // Max number of times this actor should be reconstructed. - max_reconstructions: int; - // Remaining number of reconstructions. - remaining_reconstructions: int; -} - -table ErrorTableData { - // The ID of the driver that the error is for. - driver_id: string; - // The type of the error. 
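[Editor's illustrative note] The new `dynamic_worker_options` field on TaskInfo above is described as replacing `RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, ... placeholders in the worker command; the prefix matches `kWorkerDynamicOptionPlaceholderPrefix` added in src/ray/common/constants.h. A sketch of that substitution, assuming whole-token placeholders (the command and options below are invented):

    PLACEHOLDER_PREFIX = "RAY_WORKER_OPTION_"

    def apply_dynamic_worker_options(worker_command, dynamic_options):
        resolved = []
        for token in worker_command:
            if token.startswith(PLACEHOLDER_PREFIX):
                index = int(token[len(PLACEHOLDER_PREFIX):])
                resolved.append(dynamic_options[index])
            else:
                resolved.append(token)
        return resolved

    # ["python", "worker.py", "RAY_WORKER_OPTION_0"] with ["--my-flag"]
    # -> ["python", "worker.py", "--my-flag"]
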
- type: string; - // The error message. - error_message: string; - // The timestamp of the error message. - timestamp: double; -} - -table CustomSerializerData { -} - -table ConfigTableData { -} - table ProfileEvent { // The type of the event. event_type: string; @@ -253,119 +103,3 @@ table ProfileTableData { // we don't want each event to require a GCS command. profile_events: [ProfileEvent]; } - -table RayResource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. - resource_capacity: double; -} - -table ClientTableData { - // The client ID of the client that the message is about. - client_id: string; - // The IP address of the client's node manager. - node_manager_address: string; - // The IPC socket name of the client's raylet. - raylet_socket_name: string; - // The IPC socket name of the client's plasma store. - object_store_socket_name: string; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - node_manager_port: int; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - object_manager_port: int; - // Enum to store the entry type in the log - entry_type: EntryType = INSERTION; - resources_total_label: [string]; - resources_total_capacity: [double]; -} - -table HeartbeatTableData { - // Node manager client id - client_id: string; - // Resource capacity currently available on this node manager. - resources_available_label: [string]; - resources_available_capacity: [double]; - // Total resource capacity configured for this node manager. - resources_total_label: [string]; - resources_total_capacity: [double]; - // Aggregate outstanding resource load on this node manager. - resource_load_label: [string]; - resource_load_capacity: [double]; -} - -table HeartbeatBatchTableData { - batch: [HeartbeatTableData]; -} - -// Data for a lease on task execution. -table TaskLeaseData { - // Node manager client ID. - node_manager_id: string; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - acquired_at: long; - // The period that the lease is active for. - timeout: long; -} - -table DriverTableData { - // The driver ID. - driver_id: string; - // Whether it's dead. - is_dead: bool; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -table ActorCheckpointData { - // ID of this actor. - actor_id: string; - // The dummy object ID of actor's most recently executed task. - execution_dependency: string; - // A list of IDs of this actor's handles. - handle_ids: [string]; - // The task counters of the above handles. - task_counters: [long]; - // The frontier dependencies of the above handles. - frontier_dependencies: [string]; - // A list of unreleased dummy objects from this actor. - unreleased_dummy_objects: [string]; - // The numbers of dependencies for the above unreleased dummy objects. - num_dummy_object_dependencies: [int]; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -table ActorCheckpointIdData { - // ID of this actor. - actor_id: string; - // IDs of this actor's available checkpoints. - // Note, this is a long string that concatenates all the IDs. 
- checkpoint_ids: string; - // A list of the timestamps for each of the above `checkpoint_ids`. - timestamps: [long]; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. -enum ErrorType:int { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 1, - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 2, - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 3, -} diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index fc42e5cd98c2..093aab2455d9 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -25,6 +25,9 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; + /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -126,8 +129,8 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, - int64_t length, const TablePrefix prefix, + Status RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -157,9 +160,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e291b7ffdb32..c3a82c320d06 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,11 +5,16 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; +using ray::rpc::GcsChangeMode; +using ray::rpc::GcsEntry; +using ray::rpc::TablePrefix; +using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. 
Specifically, via @@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long > static_cast(TablePubsub::MAX) || - pubsub_channel_long < static_cast(TablePubsub::MIN)) { + if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || + pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long > static_cast(TablePrefix::MAX) || - table_prefix_long < static_cast(TablePrefix::MIN)) { + if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || + table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); + return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = - OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); + *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, + mutated_key_str); return Status::OK(); } @@ -165,18 +170,24 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// This is a helper method to convert a redis module string to a flatbuffer -/// string. +/// A helper function that creates `GcsEntry` protobuf object. /// -/// \param fbb The flatbuffer builder. -/// \param redis_string The redis string. -/// \return The flatbuffer string. -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); +/// \param[in] id Id of the entry. +/// \param[in] change_mode Change mode of the entry. +/// \param[in] entries Vector of entries. +/// \param[out] result The created `GcsEntry` object. 
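[Editor's illustrative note, before the `CreateGcsEntry` helper that follows] ParseTablePubsub/ParseTablePrefix above now bounds-check against the `TABLE_PUBSUB_MIN/MAX` and `TABLE_PREFIX_MIN/MAX` sentinel values that the protobuf enums carry. From Python the same validity check can ask the enum wrapper for its numbers, as sketched here (assuming `TablePubsub` is the wrapper re-exported from `ray.gcs_utils`, which the worker.py hunks suggest):

    from ray import gcs_utils

    def is_valid_pubsub_channel(value):
        valid = set(gcs_utils.TablePubsub.values())
        # Exclude the sentinel bounds themselves, mirroring the >=/<= checks above.
        valid.discard(gcs_utils.TablePubsub.Value("TABLE_PUBSUB_MIN"))
        valid.discard(gcs_utils.TablePubsub.Value("TABLE_PUBSUB_MAX"))
        return value in valid
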
+inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, + const std::vector &entries, + GcsEntry *result) { + const char *data; + size_t size; + data = RedisModule_StringPtrLen(id, &size); + result->set_id(data, size); + result->set_change_mode(change_mode); + for (const auto &entry : entries) { + data = RedisModule_StringPtrLen(entry, &size); + result->add_entries(data, size); + } } /// Helper method to publish formatted data to target channel. @@ -234,13 +245,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - auto data_buffer = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + CreateGcsEntry(id, change_mode, {data}, &gcs_entry); + std::string str = gcs_entry.SerializeAsString(); + auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -570,19 +578,20 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); - auto data_vec = flatbuffers::GetRoot(update_data_buf); - *change_mode = data_vec->change_mode(); + GcsEntry gcs_entry; + gcs_entry.ParseFromArray(update_data_buf, update_data_len); + *change_mode = gcs_entry.change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means they are updating command. - size_t total_size = data_vec->entries()->size(); + size_t total_size = gcs_entry.entries_size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); - RedisModuleString *entry_value = - RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), - data_vec->entries()->Get(i + 1)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + RedisModuleString *entry_value = RedisModule_CreateString( + ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); // Returning 0 if key exists(still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -590,27 +599,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries. 
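[Editor's illustrative note] `CreateGcsEntry` and `PublishTableUpdate` above pack an id, a change mode, and one or more payloads into a `GcsEntry` and publish its serialized bytes. The same message built from Python looks like the sketch below; `GcsEntry` is re-exported by `ray.gcs_utils` per this patch, while the `GcsChangeMode` export is an assumption.

    from ray import gcs_utils

    def build_notification(entry_id, payloads):
        # entry_id and each payload are bytes (a binary id and serialized table data).
        gcs_entry = gcs_utils.GcsEntry()
        gcs_entry.id = entry_id
        gcs_entry.change_mode = gcs_utils.GcsChangeMode.Value("APPEND_OR_ADD")
        gcs_entry.entries.extend(payloads)
        return gcs_entry.SerializeToString()  # bytes handed to PUBLISH
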
- size_t total_size = data_vec->entries()->size(); - flatbuffers::FlatBufferBuilder fbb; - std::vector> data; + GcsEntry updated; + updated.set_id(gcs_entry.id()); + updated.set_change_mode(gcs_entry.change_mode()); + + size_t total_size = gcs_entry.entries_size(); for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), - data_vec->entries()->Get(i)->size())); + updated.add_entries(gcs_entry.entries(i)); } } - auto message = - CreateGcsEntry(fbb, data_vec->change_mode(), - fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), - fbb.CreateVector(data)); - fbb.Finish(message); - *changed_data = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + + // Serialize updated data. + std::string str = updated.SerializeAsString(); + *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -631,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry flatbugger data used to update this hash table. +/// \param data The GcsEntry protobuf data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -648,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create and finish a GcsEntry, based on the +/// A helper function to create a GcsEntry protobuf, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -658,21 +665,18 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsEntry. -Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { +/// \param[out] gcs_entry The created GcsEntry. +Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, + RedisModuleString *prefix_str, RedisModuleString *entry_id, + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. + // Build the GcsEntry from the string data. 
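[Editor's illustrative note] The hash-update payload convention documented above flattens key/value pairs into `GcsEntry.entries` (key, value, key, value, ...) for APPEND_OR_ADD, and lists bare keys for REMOVE. A pure-Python sketch of both directions, with entries assumed to be bytes:

    def pack_hash_update(mapping):
        entries = []
        for key, value in mapping.items():
            entries.extend([key, value])
        return entries

    def unpack_hash_update(entries):
        assert len(entries) % 2 == 0, "Invalid Hash Update data vector."
        return {entries[i]: entries[i + 1] for i in range(0, len(entries), 2)}
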
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); - fbb.Finish(message); + gcs_entry->add_entries(data_buf, data_len); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -696,27 +700,20 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the flatbuffer from the set of log entries. + // Build the GcsEntry from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - std::vector> data; + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - data.push_back(fbb.CreateString(element_str, len)); + gcs_entry->add_entries(element_str, len); } - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); - fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsEntry( - fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(std::vector>())); - fbb.Finish(message); + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -752,11 +749,12 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + // Serialize the data to a GcsEntry to return to the client. + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } return REDISMODULE_OK; } @@ -870,10 +868,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. 
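// Editor's sketch (not part of the patch): the reply from ray.table_lookup is the
// serialized GcsEntry built by TableEntryToProtobuf above -- one `entries` element per
// Redis string/list/set/hash element, and zero elements for an empty key. The helper
// below is illustrative and assumes only the generated GcsEntry message.
#include <string>
#include "ray/protobuf/gcs.pb.h"

int CountLookupEntries(const std::string &reply_bytes) {
  ray::rpc::GcsEntry entry;
  if (!entry.ParseFromString(reply_bytes)) {
    return -1;  // not a valid GcsEntry
  }
  // Lookups always use APPEND_OR_ADD; an empty key simply yields no entries
  // rather than an error.
  return entry.entries_size();
}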
- flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); return RedisModule_ReplyWithNull(ctx); } @@ -940,53 +939,6 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = - flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); - if (!is_nil_result) { - do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); - } - - if (do_update) { - REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), - "mutate_scheduling_state failed"); - } - REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -1016,7 +968,6 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); -AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1082,12 +1033,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 33f1615580a6..b7c19ebfd595 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,6 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" +#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -39,48 +40,44 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, const ID 
&id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } } else { if (failure != nullptr) { - (failure)(client_, id, *dataT); + (failure)(client_, id, *data); } } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback), log_length); } template @@ -89,16 +86,15 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; + std::vector results; if (!reply.IsNil()) { - const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data data; + data.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(data)); } } lookup(client_, id, results); @@ -115,7 +111,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; @@ -141,19 +137,16 @@ Status Log::Subscribe(const DriverID 
&driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + std::vector results; + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data result; + result.ParseFromString(gcs_entry.entries(i)); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->change_mode(), results); + subscribe(client_, id, gcs_entry.change_mode(), results); } } }; @@ -234,19 +227,17 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -255,7 +246,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -277,7 +268,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -299,36 +290,30 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const 
WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -348,26 +333,16 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(data_map.size() * 2); - for (auto const &pair : data_map) { - // Add the key. - data_vec.push_back(fbb.CreateString(pair.first)); - flatbuffers::FlatBufferBuilder fbb_data; - fbb_data.ForceDefaults(true); - fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); - std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), - fbb_data.GetSize()); - // Add the value. - data_vec.push_back(fbb.CreateString(data)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); + for (const auto &pair : data_map) { + gcs_entry.add_entries(pair.first); + gcs_entry.add_entries(pair.second->SerializeAsString()); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -380,19 +355,15 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(keys.size()); - // Add the keys. 
- for (auto const &key : keys) { - data_vec.push_back(fbb.CreateString(key)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::REMOVE); + for (const auto &key : keys) { + gcs_entry.add_entries(key); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), - fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -412,17 +383,15 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - results.emplace(key, std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + results.emplace(key, std::move(value)); } } lookup(client_, id, results); @@ -451,31 +420,24 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. 
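// Editor's sketch (not part of the patch): the hash-table wire format used by
// Hash::Update, Hash::RemoveEntries, and Hash::Lookup above. For APPEND_OR_ADD the
// keys and serialized values are flattened into `entries` as key, value, key,
// value, ...; for REMOVE the entries are bare keys. The helpers below are
// illustrative and assume only the generated GcsEntry message.
#include <map>
#include <string>
#include "ray/protobuf/gcs.pb.h"

ray::rpc::GcsEntry PackHashUpdate(const std::string &id,
                                  const std::map<std::string, std::string> &kv) {
  ray::rpc::GcsEntry entry;
  entry.set_id(id);
  entry.set_change_mode(ray::rpc::GcsChangeMode::APPEND_OR_ADD);
  for (const auto &pair : kv) {
    entry.add_entries(pair.first);   // key
    entry.add_entries(pair.second);  // already-serialized value protobuf
  }
  return entry;
}

std::map<std::string, std::string> UnpackHashUpdate(const ray::rpc::GcsEntry &entry) {
  std::map<std::string, std::string> kv;
  // Pairs only make sense for APPEND_OR_ADD; REMOVE payloads carry bare keys.
  for (int i = 0; i + 1 < entry.entries_size(); i += 2) {
    kv[entry.entries(i)] = entry.entries(i + 1);
  }
  return kv;
}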
- auto root = flatbuffers::GetRoot(data.data()); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); DataMap data_map; - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - if (root->change_mode() == GcsChangeMode::REMOVE) { - for (size_t i = 0; i < root->entries()->size(); i++) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - data_map.emplace(key, std::shared_ptr()); + if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { + for (const auto &key : gcs_entry.entries()) { + data_map.emplace(key, std::shared_ptr()); } } else { - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - data_map.emplace(key, std::move(result)); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + data_map.emplace(key, std::move(value)); } } - subscribe(client_, id, root->change_mode(), data_map); + subscribe(client_, id, gcs_entry.change_mode(), data_map); } } }; @@ -490,11 +452,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->type = type; - data->error_message = error_message; - data->timestamp = timestamp; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_type(type); + data->set_error_message(error_message); + data->set_timestamp(timestamp); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -503,11 +465,9 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - auto data = std::make_shared(); - // There is some room for optimization here because the Append function will just - // call "Pack" and undo the "UnPack". - profile_events.UnPackTo(data.get()); - + // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. + auto data = std::make_shared(); + data->CopyFrom(profile_events); return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -517,9 +477,9 @@ std::string ProfileTable::DebugString() const { } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->is_dead = is_dead; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -527,7 +487,8 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. 
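// Editor's sketch (not part of the patch): with the subscription change above, a
// REMOVE notification delivers each removed key mapped to a null shared_ptr, while
// create/update notifications deliver parsed values. A subscriber can therefore
// branch on the pointer alone; DeletedKeys is an illustrative helper and assumes
// the generated RayResource message aliased in tables.h.
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ray/protobuf/gcs.pb.h"

using ResourceMap =
    std::unordered_map<std::string, std::shared_ptr<ray::rpc::RayResource>>;

std::vector<std::string> DeletedKeys(const ResourceMap &data_map) {
  std::vector<std::string> deleted;
  for (const auto &pair : data_map) {
    if (pair.second == nullptr) {
      deleted.push_back(pair.first);  // REMOVE notification: value pointer is null
    }
    // Otherwise pair.second holds the parsed value for a create/update.
  }
  return deleted;
}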
for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && + (entry.second.entry_type() == ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -537,7 +498,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -549,7 +510,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -559,15 +520,16 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && + entry.second.entry_type() == ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableDataT &data) { - ClientID client_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -578,16 +540,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); - bool is_deleted = (data.entry_type == EntryType::DELETION); - bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == ClientTableData::DELETION); + bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. 
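// Editor's sketch (not part of the patch): the freshness rule applied below, written
// as a standalone predicate. A notification is "new" only if the cached entry was not
// already a deletion and the incoming entry either deletes the client or modifies its
// resources; a client that was deleted must never reappear as an insertion.
#include "ray/protobuf/gcs.pb.h"

bool IsNotificationNew(const ray::rpc::ClientTableData &cached,
                       const ray::rpc::ClientTableData &incoming) {
  using CTD = ray::rpc::ClientTableData;
  const bool was_not_deleted = cached.entry_type() != CTD::DELETION;
  const bool is_deleted = incoming.entry_type() == CTD::DELETION;
  const bool is_res_modified = incoming.entry_type() == CTD::RES_CREATEUPDATE ||
                               incoming.entry_type() == CTD::RES_DELETE;
  return was_not_deleted && (is_deleted || is_res_modified);
}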
- if (entry->second.entry_type == EntryType::DELETION) { - RAY_CHECK((data.entry_type == EntryType::DELETION)) + if (entry->second.entry_type() == ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -595,64 +557,64 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type == EntryType::INSERTION) || - (data.entry_type == EntryType::DELETION)) { + if ((data.entry_type() == ClientTableData::INSERTION) || + (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)) { + } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector::size_type i = 0; i != data.resources_total_label.size(); i++) { - auto const &resource_name = data.resources_total_label[i]; - auto const &capacity = data.resources_total_capacity[i]; + for (std::vector::size_type i = 0; i != data.resources_total_label_size(); i++) { + auto const &resource_name = data.resources_total_label(i); + auto const &capacity = data.resources_total_capacity(i); // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label.begin(), - cache_data.resources_total_label.end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label.end()) { - auto index = std::distance(cache_data.resources_total_label.begin(), + std::find(cache_data.resources_total_label().begin(), + cache_data.resources_total_label().end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label().end()) { + auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists, set capacity if updation call.. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_capacity[index] = capacity; + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.set_resources_total_capacity(index, capacity); } // .. delete if deletion call. 
- else if (data.entry_type == EntryType::RES_DELETE) { - cache_data.resources_total_label.erase( - cache_data.resources_total_label.begin() + index); - cache_data.resources_total_capacity.erase( - cache_data.resources_total_capacity.begin() + index); + else if (data.entry_type() == ClientTableData::RES_DELETE) { + cache_data.mutable_resources_total_label()->erase( + cache_data.resources_total_label().begin() + index); + cache_data.mutable_resources_total_capacity()->erase( + cache_data.resources_total_capacity().begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_label.push_back(resource_name); - cache_data.resources_total_capacity.push_back(capacity); + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.add_resources_total_label(resource_name); + cache_data.add_resources_total_capacity(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type == EntryType::INSERTION) { + if (data.entry_type() == ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type == EntryType::DELETION) { + } else if (data.entry_type() == ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -660,11 +622,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_DELETE) { + } else if (data.entry_type() == ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -672,54 +634,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableDataT &local_client) { +Status ClientTable::Connect(const ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id == 
local_client_.client_id); + RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::INSERTION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableDataT &data) { + const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); + if (notification.entry_type() != ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); } else { - auto iter = connected_nodes.find(notification.client_id); + auto iter = connected_nodes.find(notification.client_id()); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id, notification); + disconnected_nodes.emplace(notification.client_id(), notification); } } for (const auto &pair : connected_nodes) { @@ -742,10 +704,10 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -759,24 +721,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->client_id = dead_client_id.Binary(); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(); + data->set_client_id(dead_client_id.Binary()); + data->set_entry_type(ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableDataT &client_info) const { + ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::Nil().Binary(); + client_info.set_client_id(ClientID::Nil().Binary()); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ 
-798,31 +760,29 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { - std::shared_ptr copy = - std::make_shared(data); - copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.Binary(); + const ActorCheckpointIdData &data) { + std::shared_ptr copy = + std::make_shared(data); + copy->add_timestamps(current_sys_time_ms()); + copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps.size() > num_to_keep) { + while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &checkpoint_id = - ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " - << actor_id; - copy->timestamps.erase(copy->timestamps.begin()); - copy->checkpoint_ids.erase(0, kUniqueIDSize); - client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); + const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; + copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); + copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); + client_->actor_checkpoint_table().Delete(driver_id, to_delete); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr data = - std::make_shared(); - data->actor_id = id.Binary(); - data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.Binary(); + std::shared_ptr data = + std::make_shared(); + data->set_actor_id(id.Binary()); + data->add_timestamps(current_sys_time_ms()); + *data->add_checkpoint_ids() = checkpoint_id.Binary(); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); @@ -830,8 +790,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -template class Log; -template class Table; +template class Log; template class Table; template class Log; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 6a1d502a7f54..2ecc3440839e 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,10 +11,8 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -// TODO(rkn): Remove this include. 
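// Editor's sketch (not part of the patch): the checkpoint trimming performed in
// AddCheckpointId above, isolated. Checkpoint IDs and timestamps are now parallel
// repeated fields, so keeping the most recent num_to_keep entries is a pop-from-front
// on both fields instead of substring arithmetic on a concatenated string.
// TrimCheckpoints is an illustrative helper; it returns the IDs that should also be
// deleted from the actor checkpoint table.
#include <string>
#include <vector>
#include "ray/protobuf/gcs.pb.h"

std::vector<std::string> TrimCheckpoints(ray::rpc::ActorCheckpointIdData *data,
                                         int num_to_keep) {
  std::vector<std::string> evicted;
  while (data->timestamps_size() > num_to_keep) {
    evicted.push_back(data->checkpoint_ids(0));  // oldest checkpoint ID (binary)
    data->mutable_checkpoint_ids()->erase(data->mutable_checkpoint_ids()->begin());
    data->mutable_timestamps()->erase(data->mutable_timestamps()->begin());
  }
  return evicted;
}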
-#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -22,6 +20,25 @@ namespace ray { namespace gcs { +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorTableData; +using rpc::GcsChangeMode; +using rpc::GcsEntry; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; +using rpc::ObjectTableData; +using rpc::ProfileTableData; +using rpc::RayResource; +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::TaskLeaseData; +using rpc::TaskReconstructionData; +using rpc::TaskTableData; + class RedisContext; class AsyncGcsClient; @@ -48,13 +65,12 @@ class PubsubInterface { template class LogInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -72,12 +88,11 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + const std::vector &data)>; + using NotificationCallback = + std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -86,7 +101,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -111,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -126,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. 
/// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -259,10 +274,9 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -280,9 +294,8 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -305,7 +318,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -369,12 +382,11 @@ class Table : private Log, template class SetInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -392,7 +404,6 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -414,7 +425,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -425,7 +436,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -454,8 +465,7 @@ class Set : private Log, template class HashInterface { public: - using DataT = typename Data::NativeTableType; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. 
using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -544,8 +554,7 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -595,7 +604,7 @@ class DynamicResourceTable : public Hash { DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE; + pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; prefix_ = TablePrefix::NODE_RESOURCE; }; @@ -607,7 +616,7 @@ class ObjectTable : public Set { ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT; + pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; prefix_ = TablePrefix::OBJECT; }; @@ -619,7 +628,7 @@ class HeartbeatTable : public Table { HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT; + pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} @@ -630,7 +639,7 @@ class HeartbeatBatchTable : public Table { HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} @@ -641,7 +650,7 @@ class DriverTable : public Log { DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER; + pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; prefix_ = TablePrefix::DRIVER; }; @@ -655,18 +664,6 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { - public: - FunctionTable(const std::vector> &contexts, - AsyncGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::NO_PUBLISH; - prefix_ = TablePrefix::FUNCTION; - }; -}; - -using ClassTable = Table; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -677,7 +674,7 @@ class ActorTable : public Log { ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR; + pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; prefix_ = TablePrefix::ACTOR; } }; @@ -696,12 +693,12 @@ class TaskLeaseTable : public Table { TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE; + pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. 
It's okay if this command fails // since the lease entry itself contains the expiration period. In the @@ -709,9 +706,8 @@ class TaskLeaseTable : public Table { // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.Binary(), - std::to_string(data->timeout)}; + std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), + std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } @@ -747,12 +743,12 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK; + pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; prefix_ = TablePrefix::RAYLET_TASK; } @@ -770,7 +766,7 @@ class ErrorTable : private Log { ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO; + pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; prefix_ = TablePrefix::ERROR_INFO; }; @@ -815,10 +811,6 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; - -using ConfigTable = Table; - /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -831,7 +823,7 @@ using ConfigTable = Table; class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -842,11 +834,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT; + pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.Binary(); + local_client_.set_client_id(client_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -855,7 +847,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableDataT &local_client); + ray::Status Connect(const ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -898,7 +890,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableDataT &client_info) const; + void GetClient(const ClientID &client, ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -908,7 +900,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableDataT &GetLocalClient() const; + const ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. 
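// Editor's sketch (not part of the patch): the expiration command issued by
// TaskLeaseTable::Add above, as a free function. TablePrefix_Name is the
// protobuf-generated enum-name helper that replaces the old flatbuffers
// EnumNameTablePrefix; the Redis key is the prefix name concatenated with the
// binary task ID, and the timeout comes from the lease data itself.
#include <cstdint>
#include <string>
#include <vector>
#include "ray/protobuf/gcs.pb.h"

std::vector<std::string> MakeExpireCommand(ray::rpc::TablePrefix prefix,
                                           const std::string &id_binary,
                                           int64_t timeout_ms) {
  return {"PEXPIRE", ray::rpc::TablePrefix_Name(prefix) + id_binary,
          std::to_string(timeout_ms)};
}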
/// @@ -919,7 +911,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -940,15 +932,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableDataT local_client_; + ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -958,7 +950,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 5b6794a505d3..454379d18302 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,18 +8,22 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::ClientTableData; +using ray::rpc::GcsChangeMode; +using ray::rpc::ObjectTableData; + /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -42,7 +46,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -79,9 +83,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. 
- auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; @@ -92,9 +96,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; @@ -102,14 +106,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - ClientTableDataT client_data; + ClientTableData client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::FromBinary(client_data.client_id); + ClientID result_client_id = ClientID::FromBinary(client_data.client_id()); if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.entry_type == EntryType::INSERTION) { - connection_info.ip = client_data.node_manager_address; - connection_info.port = static_cast(client_data.object_manager_port); + if (client_data.entry_type() == ClientTableData::INSERTION) { + connection_info.ip = client_data.node_manager_address(); + connection_info.port = static_cast(client_data.object_manager_port()); } } } @@ -208,7 +212,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, status = gcs_client_->object_table().Lookup( DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_updates) { + const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 954162c21aef..964cee605ced 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_send"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_send"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. 
- profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_receive"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_receive"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -801,11 +801,12 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEventT profile_event; - profile_event.event_type = "receive_pull_request"; - profile_event.start_time = current_sys_time_seconds(); - profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("receive_pull_request"); + profile_event.set_start_time(current_sys_time_seconds()); + profile_event.set_end_time(profile_event.start_time()); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -938,13 +939,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { - ProfileTableDataT profile_info; - profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.Binary(); +rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + rpc::ProfileTableData profile_info; + profile_info.set_component_type("object_manager"); + profile_info.set_component_id(client_id_.Binary()); for (auto const &profile_event : profile_events_) { - profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + profile_info.add_profile_events()->CopyFrom(profile_event); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6318250ae3e8..6664dd0a93bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -180,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableDataT GetAndResetProfilingInfo(); + rpc::ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. 
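The object manager hunks above all follow the same batching idiom: build a ProfileEvent with setters, buffer it, then copy it into the repeated field of a ProfileTableData. A minimal sketch of that pattern, assuming the generated gcs.pb.h types (BuildBatch is a hypothetical helper, not part of the patch):

    #include <vector>
    #include "ray/protobuf/gcs.pb.h"

    ray::rpc::ProfileTableData BuildBatch(
        const std::vector<ray::rpc::ProfileTableData::ProfileEvent> &events) {
      ray::rpc::ProfileTableData batch;
      batch.set_component_type("object_manager");
      for (const auto &event : events) {
        // add_profile_events() appends a new repeated entry and returns a
        // pointer to it; CopyFrom fills it from the buffered event.
        batch.add_profile_events()->CopyFrom(event);
      }
      return batch;
    }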
/// @@ -412,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 55aa59124a99..2d5292842acf 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,6 +11,8 @@ namespace ray { +using rpc::ClientTableData; + std::string store_executable; static inline void flushall_redis(void) { @@ -52,10 +54,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -242,8 +244,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -438,16 +440,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" - << "ClientIp=" << data.node_manager_address << "\n" - << "ClientPort=" << data.node_manager_port; - ClientTableDataT data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" + << "ClientIp=" << data.node_manager_address() << "\n" + << "ClientPort=" << data.node_manager_port(); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" - << "ClientIp=" << data2.node_manager_address << "\n" - << "ClientPort=" << data2.node_manager_port; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" + << "ClientIp=" << data2.node_manager_address() << "\n" + << "ClientPort=" << data2.node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index ee6c78d8ed42..45b80a267f2f 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,6 +14,8 @@ int64_t wait_timeout_ms; namespace ray { +using rpc::ClientTableData; + static inline void flushall_redis(void) { redisContext *context = 
redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -46,10 +48,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -221,8 +223,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -457,19 +459,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); - ClientTableDataT data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto new file mode 100644 index 000000000000..d0b2c5e007fe --- /dev/null +++ b/src/ray/protobuf/gcs.proto @@ -0,0 +1,280 @@ +syntax = "proto3"; + +package ray.rpc; + +option java_package = "org.ray.runtime.generated"; + +// Language of a worker or task. +enum Language { + PYTHON = 0; + CPP = 1; + JAVA = 2; +} + +// These indexes are mapped to strings in ray_redis_module.cc. 
+enum TablePrefix { + TABLE_PREFIX_MIN = 0; + UNUSED = 1; + TASK = 2; + RAYLET_TASK = 3; + CLIENT = 4; + OBJECT = 5; + ACTOR = 6; + FUNCTION = 7; + TASK_RECONSTRUCTION = 8; + HEARTBEAT = 9; + HEARTBEAT_BATCH = 10; + ERROR_INFO = 11; + DRIVER = 12; + PROFILE = 13; + TASK_LEASE = 14; + ACTOR_CHECKPOINT = 15; + ACTOR_CHECKPOINT_ID = 16; + NODE_RESOURCE = 17; + TABLE_PREFIX_MAX = 18; +} + +// The channel that Add operations to the Table should be published on, if any. +enum TablePubsub { + TABLE_PUBSUB_MIN = 0; + NO_PUBLISH = 1; + TASK_PUBSUB = 2; + RAYLET_TASK_PUBSUB = 3; + CLIENT_PUBSUB = 4; + OBJECT_PUBSUB = 5; + ACTOR_PUBSUB = 6; + HEARTBEAT_PUBSUB = 7; + HEARTBEAT_BATCH_PUBSUB = 8; + ERROR_INFO_PUBSUB = 9; + TASK_LEASE_PUBSUB = 10; + DRIVER_PUBSUB = 11; + NODE_RESOURCE_PUBSUB = 12; + TABLE_PUBSUB_MAX = 13; +} + +enum GcsChangeMode { + APPEND_OR_ADD = 0; + REMOVE = 1; +} + +message GcsEntry { + GcsChangeMode change_mode = 1; + bytes id = 2; + repeated bytes entries = 3; +} + +message ObjectTableData { + // The size of the object. + uint64 object_size = 1; + // The node manager ID that this object appeared on or was evicted by. + bytes manager = 2; +} + +message TaskReconstructionData { + // The number of times this task has been reconstructed so far. + uint64 num_reconstructions = 1; + // The node manager that is trying to reconstruct the task. + bytes node_manager_id = 2; +} + +// TODO(hchen): Task table currently still uses flatbuffers-defined data structure +// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should +// be migrated to protobuf very soon. +message TaskTableData { + // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. + bytes task = 1; +} + +message ActorTableData { + // State of an actor. + enum ActorState { + // Actor is alive. + ALIVE = 0; + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1; + // Actor is already dead and won't be reconstructed. + DEAD = 2; + } + // The ID of the actor that was created. + bytes actor_id = 1; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + bytes actor_creation_dummy_object_id = 2; + // The ID of the driver that created the actor. + bytes driver_id = 3; + // The ID of the node manager that created the actor. + bytes node_manager_id = 4; + // Current state of this actor. + ActorState state = 5; + // Max number of times this actor should be reconstructed. + uint64 max_reconstructions = 6; + // Remaining number of reconstructions. + uint64 remaining_reconstructions = 7; +} + +message ErrorTableData { + // The ID of the driver that the error is for. + bytes driver_id = 1; + // The type of the error. + string type = 2; + // The error message. + string error_message = 3; + // The timestamp of the error message. + double timestamp = 4; +} + +message ProfileTableData { + // Represents a profile event. + message ProfileEvent { + // The type of the event. + string event_type = 1; + // The start time of the event. + double start_time = 2; + // The end time of the event. If the event is a point event, then this should + // be the same as the start time. + double end_time = 3; + // Additional data associated with the event. This data must be serialized + // using JSON. 
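The GcsEntry message defined earlier in this file is the generic envelope for table notifications (a change mode, a key, and raw entries). As an illustration only, and assuming each entry is a serialized table row such as ObjectTableData, a subscriber could decode it roughly like this (DecodeObjectEntries is a hypothetical helper, not the plumbing this patch adds):

    #include <string>
    #include <vector>
    #include "ray/protobuf/gcs.pb.h"

    // Decode an object-table notification into typed rows.
    std::vector<ray::rpc::ObjectTableData> DecodeObjectEntries(
        const ray::rpc::GcsEntry &entry) {
      std::vector<ray::rpc::ObjectTableData> rows;
      for (const std::string &blob : entry.entries()) {
        ray::rpc::ObjectTableData row;
        if (row.ParseFromString(blob)) {  // assumes entries are serialized rows
          rows.push_back(row);
        }
      }
      // entry.change_mode() tells the caller whether these are additions or removals.
      return rows;
    }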
+ string extra_data = 4; + } + + // The type of the component that generated the event, e.g., worker or + // object_manager, or node_manager. + string component_type = 1; + // An identifier for the component that generated the event. + bytes component_id = 2; + // An identifier for the node that generated the event. + string node_ip_address = 3; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + repeated ProfileEvent profile_events = 4; +} + +message RayResource { + // The type of the resource. + string resource_name = 1; + // The total capacity of this resource type. + double resource_capacity = 2; +} + +message ClientTableData { + // Enum for the entry type in the ClientTable + enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; + } + + // The client ID of the client that the message is about. + bytes client_id = 1; + // The IP address of the client's node manager. + string node_manager_address = 2; + // The IPC socket name of the client's raylet. + string raylet_socket_name = 3; + // The IPC socket name of the client's plasma store. + string object_store_socket_name = 4; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + int32 node_manager_port = 5; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + int32 object_manager_port = 6; + // Enum to store the entry type in the log + EntryType entry_type = 7; + + // TODO(hchen): Define the following resources in map format. + repeated string resources_total_label = 8; + repeated double resources_total_capacity = 9; +} + +message HeartbeatTableData { + // Node manager client id + bytes client_id = 1; + // TODO(hchen): Define the following resources in map format. + // Resource capacity currently available on this node manager. + repeated string resources_available_label = 2; + repeated double resources_available_capacity = 3; + // Total resource capacity configured for this node manager. + repeated string resources_total_label = 4; + repeated double resources_total_capacity = 5; + // Aggregate outstanding resource load on this node manager. + repeated string resource_load_label = 6; + repeated double resource_load_capacity = 7; +} + +message HeartbeatBatchTableData { + repeated HeartbeatTableData batch = 1; +} + +// Data for a lease on task execution. +message TaskLeaseData { + // Node manager client ID. + bytes node_manager_id = 1; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + uint64 acquired_at = 2; + // The period that the lease is active for. + uint64 timeout = 3; +} + +message DriverTableData { + // The driver ID. + bytes driver_id = 1; + // Whether it's dead. + bool is_dead = 2; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is the snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for more detailed explanation of these fields. +message ActorCheckpointData { + // ID of this actor. + bytes actor_id = 1; + // The dummy object ID of actor's most recently executed task. + bytes execution_dependency = 2; + // A list of IDs of this actor's handles. 
+ repeated bytes handle_ids = 3; + // The task counters of the above handles. + repeated uint64 task_counters = 4; + // The frontier dependencies of the above handles. + repeated bytes frontier_dependencies = 5; + // A list of unreleased dummy objects from this actor. + repeated bytes unreleased_dummy_objects = 6; + // The numbers of dependencies for the above unreleased dummy objects. + repeated uint32 num_dummy_object_dependencies = 7; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +message ActorCheckpointIdData { + // ID of this actor. + bytes actor_id = 1; + // IDs of this actor's available checkpoints. + repeated bytes checkpoint_ids = 2; + // A list of the timestamps for each of the above `checkpoint_ids`. + repeated uint64 timestamps = 3; +} + +// This enum type is used as object's metadata to indicate the object's creating +// task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be reconstructed. + // Note, this currently only happens to actor objects. When the actor's state is already + // after the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) A object that was put by the driver. + // 2) The object's creating task is already cleaned up from GCS (this currently + // crashes raylet). + OBJECT_UNRECONSTRUCTABLE = 2; +} diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index cc587bc4d74e..7f940006b5be 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,34 +8,35 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { + execution_dependency_( + ObjectID::FromBinary(checkpoint_data.execution_dependency())) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); + for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.task_counter = checkpoint_data.task_counters(i); frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); } // Restore `dummy_objects_`. 
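The restore path above walks parallel repeated fields with the generated `_size()` and indexed accessors, and the `dummy_objects_` loop that follows uses the same shape. A compact sketch of that read idiom (ReadFrontierCounters is a hypothetical helper; keys are kept as raw bytes for brevity):

    #include <string>
    #include <unordered_map>
    #include "ray/protobuf/gcs.pb.h"

    // Parallel repeated fields are read by index; their sizes stay in lockstep.
    std::unordered_map<std::string, uint64_t> ReadFrontierCounters(
        const ray::rpc::ActorCheckpointData &checkpoint) {
      std::unordered_map<std::string, uint64_t> counters;
      for (int i = 0; i < checkpoint.handle_ids_size(); i++) {
        // handle_ids(i) and task_counters(i) describe the same actor handle.
        counters[checkpoint.handle_ids(i)] = checkpoint.task_counters(i);
      }
      return counters;
    }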
- for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id()); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,15 +44,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions; + return actor_table_data_.max_reconstructions(); } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions; + return actor_table_data_.remaining_reconstructions(); } const std::unordered_map @@ -96,7 +97,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -109,18 +110,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. 
- auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.Binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); + auto checkpoint_data = std::make_shared(); + checkpoint_data->set_actor_id(actor_id.Binary()); + checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.Binary()); - checkpoint_data->task_counters.push_back(frontier.second.task_counter); - checkpoint_data->frontier_dependencies.push_back( + checkpoint_data->add_handle_ids(frontier.first.Binary()); + checkpoint_data->add_task_counters(frontier.second.task_counter); + checkpoint_data->add_frontier_dependencies( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); - checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); + checkpoint_data->add_num_dummy_object_dependencies(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 8d7ce2a449ec..208e4998263f 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,13 +4,17 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::ActorTableData; +using ActorState = rpc::ActorTableData::ActorState; +using rpc::ActorCheckpointData; + /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -23,13 +27,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - ActorRegistration(const ActorTableDataT &actor_table_data); + explicit ActorRegistration(const ActorTableData &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data); + ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -46,15 +50,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableDataT &GetTableData() const { return actor_table_data_; } + const ActorTableData &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState &GetState() const { return actor_table_data_.state; } + const ActorState GetState() const { return actor_table_data_.state(); } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.state = state; } + void SetState(const ActorState &state) { actor_table_data_.set_state(state); } /// Get the actor's node manager location. /// @@ -131,13 +135,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. 
/// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableDataT actor_table_data_; + ActorTableData actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 32dddada5244..68d5aa817c2b 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,15 +63,6 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} -Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { - // Deserialize and set entries for the uncommitted tasks. - auto tasks = task_request.uncommitted_tasks(); - for (auto it = tasks->begin(); it != tasks->end(); it++) { - const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); - } -} - boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -151,20 +142,6 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } -flatbuffers::Offset Lineage::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { - RAY_CHECK(GetEntry(task_id)); - // Serialize the task and object entries. - std::vector> uncommitted_tasks; - for (const auto &entry : entries_) { - uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); - } - - auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), - fbb.CreateVector(uncommitted_tasks)); - return request; -} - const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -176,7 +153,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -292,15 +269,11 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... 
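The replacement below drops the flatbuffer rebuild in FlushTask and instead stores the still-flatbuffers-encoded task as opaque bytes in TaskTableData (see the TODO on that message in gcs.proto). A hedged sketch of the wrapping step, assuming the task blob is already serialized to a string (WrapTaskBlob is a hypothetical helper):

    #include <memory>
    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    // Wrap an already-serialized task blob in the protobuf table row.
    std::shared_ptr<ray::rpc::TaskTableData> WrapTaskBlob(const std::string &task_blob) {
      auto task_data = std::make_shared<ray::rpc::TaskTableData>();
      task_data->set_task(task_blob);  // opaque flatbuffer bytes, per the proto comment
      return task_data;
    }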
- flatbuffers::FlatBufferBuilder fbb; - auto message = task->TaskData().ToFlatbuffer(fbb); - fbb.Finish(message); - auto task_data = std::make_shared(); - auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); - root->UnPackTo(task_data.get()); + auto task_data = std::make_shared(); + task_data->set_task(task->TaskData().Serialize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -365,8 +338,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } - - return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 5436fa372fa4..37ce5caf6507 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,18 +4,17 @@ #include #include -// clang-format off -#include "ray/common/common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -// clang-format on +#include "ray/gcs/tables.h" +#include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::TaskTableData; + /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -136,12 +135,6 @@ class Lineage { /// Construct an empty Lineage. Lineage(); - /// Construct a Lineage from a ForwardTaskRequest. - /// - /// \param task_request The request to construct the lineage from. All - /// uncommitted tasks in the request will be added to the lineage. - Lineage(const protocol::ForwardTaskRequest &task_request); - /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -172,15 +165,6 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; - /// Serialize this lineage to a ForwardTaskRequest flatbuffer. - /// - /// \param entry_id The task ID to include in the ForwardTaskRequest - /// flatbuffer. - /// \return An offset to the serialized lineage. The serialization includes - /// all task and object entries in the lineage. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; - /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// @@ -221,7 +205,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -319,7 +303,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. 
gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 43e64e400292..a6184902f803 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &data) { + const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. 
- auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 62ecb00b819f..0a853260887e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id && - data.entry_type == EntryType::DELETION) { + if (client_id.Binary() == data.client_id() && + data.entry_type() == ClientTableData::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,10 +84,9 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. 
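The hunk that follows replaces the vector of heap-allocated heartbeats with the generated add_batch() mutator. A minimal sketch of that aggregation, assuming the buffer is keyed by ClientID as in monitor.h below (MakeBatch is a hypothetical helper):

    #include <memory>
    #include <unordered_map>
    #include "ray/common/id.h"
    #include "ray/protobuf/gcs.pb.h"

    // Fold per-client heartbeats into one HeartbeatBatchTableData message.
    std::shared_ptr<ray::rpc::HeartbeatBatchTableData> MakeBatch(
        const std::unordered_map<ray::ClientID, ray::rpc::HeartbeatTableData> &buffer) {
      auto batch = std::make_shared<ray::rpc::HeartbeatBatchTableData>();
      for (const auto &heartbeat : buffer) {
        batch->add_batch()->CopyFrom(heartbeat.second);
      }
      return batch;
    }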
if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->batch.push_back(std::unique_ptr( - new HeartbeatTableDataT(heartbeat.second))); + batch->add_batch()->CopyFrom(heartbeat.second); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index c69cc9f003e0..5725e52cf495 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,6 +11,10 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -35,7 +39,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data); + const HeartbeatTableData &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -50,7 +54,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a0bde1ff0655..226a8fb6d251 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ActorState::ALIVE) { + if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -83,7 +83,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, config.worker_commands), + config.maximum_startup_concurrency, gcs_client_, + config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, @@ -100,7 +101,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), actor_registry_(), - node_manager_server_(config.node_manager_port, io_service, *this), + node_manager_server_("NodeManager", config.node_manager_port), + node_manager_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -118,6 +120,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manger rpc server. 
+ node_manager_server_.RegisterService(node_manager_service_); node_manager_server_.Run(); } @@ -129,7 +132,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -138,8 +141,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); + const TaskLeaseData &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -149,7 +152,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. - reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -163,7 +166,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -176,34 +179,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceCreateUpdated(data); }; + const ClientTableData &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceDeleted(data); }; + const ClientTableData &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. 
const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { + const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -214,7 +217,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -250,12 +253,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) - << " " << entry.is_dead; - if (entry.is_dead) { - auto driver_id = DriverID::FromBinary(entry.driver_id); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " + << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); + if (entry.is_dead()) { + auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -287,26 +290,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.Binary(); + heartbeat_data->set_client_id(my_client_id.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. 
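The loops that follow encode each resource map as paired resources_*_label / resources_*_capacity repeated fields; the receiving side reads them back with the rpc::VectorFromProtobuf helper referenced in HeartbeatAdded further down. A small sketch of the write half, using a plain std::map in place of ResourceSet (FillAvailableResources is a hypothetical helper):

    #include <map>
    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    // Encode a resource map as the parallel label/capacity repeated fields.
    void FillAvailableResources(const std::map<std::string, double> &resources,
                                ray::rpc::HeartbeatTableData *heartbeat) {
      for (const auto &pair : resources) {
        heartbeat->add_resources_available_label(pair.first);
        heartbeat->add_resources_available_capacity(pair.second);
      }
    }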
for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->resources_available_label.push_back(resource_pair.first); - heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_available_label(resource_pair.first); + heartbeat_data->add_resources_available_capacity(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->resources_total_label.push_back(resource_pair.first); - heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_total_label(resource_pair.first); + heartbeat_data->add_resources_total_capacity(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->resource_load_label.push_back(resource_pair.first); - heartbeat_data->resource_load_capacity.push_back(resource_pair.second); + heartbeat_data->add_resource_load_label(resource_pair.first); + heartbeat_data->add_resource_load_capacity(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -334,13 +337,8 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events.size() > 0) { - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_info); - fbb.Finish(message); - auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); - - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); + if (profile_info.profile_events_size() > 0) { + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } // Reset the timer. @@ -357,8 +355,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ClientAdded(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -377,19 +375,20 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address, - client_data.node_manager_port, client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address(), + client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet resources_total( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { +void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. 
- const ClientID client_id = ClientID::FromBinary(client_data.client_id); + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id; RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId()) @@ -417,7 +416,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): This could be very slow if there are many actors. for (const auto &actor_entry : actor_registry_) { if (actor_entry.second.GetNodeManagerId() == client_id && - actor_entry.second.GetState() == ActorState::ALIVE) { + actor_entry.second.GetState() == ActorTableData::ALIVE) { RAY_LOG(INFO) << "Actor " << actor_entry.first << " is disconnected, because its node " << client_id << " is removed from cluster. It may be reconstructed."; @@ -435,14 +434,15 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { lineage_cache_.FlushAllUncommittedTasks(); } -void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id " << client_id << ". Updating resource map."; - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources(); ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set); @@ -471,12 +471,13 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) { return; } -void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ResourceDeleted(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - ResourceSet new_res_set(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet new_res_set( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id << " with new resources: " << new_res_set.ToString() << ". Updating resource map."; @@ -522,7 +523,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() { } void NodeManager::HeartbeatAdded(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { // Locate the client id in remote client table and update available resources based on // the received heartbeat information. 
auto it = cluster_resource_map_.find(client_id); @@ -534,10 +535,12 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } SchedulingResources &remote_resources = it->second; - ResourceSet remote_available(heartbeat_data.resources_available_label, - heartbeat_data.resources_available_capacity); - ResourceSet remote_load(heartbeat_data.resource_load_label, - heartbeat_data.resource_load_capacity); + ResourceSet remote_available( + rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()), + rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity())); + ResourceSet remote_load( + rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()), + rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity())); // TODO(atumanov): assert that the load is a non-empty ResourceSet. remote_resources.SetAvailableResources(std::move(remote_available)); // Extract the load information and save it locally. @@ -562,40 +565,41 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id, } } -void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) { +void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) { const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); // Update load information provided by each heartbeat. - for (const auto &heartbeat_data : heartbeat_batch.batch) { - const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id); + for (const auto &heartbeat_data : heartbeat_batch.batch()) { + const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id()); if (client_id == local_client_id) { // Skip heartbeats from self. continue; } - HeartbeatAdded(client_id, *heartbeat_data); + HeartbeatAdded(client_id, heartbeat_data); } } void NodeManager::PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback) { // Copy the actor notification data. - auto actor_notification = std::make_shared(data); + auto actor_notification = std::make_shared(data); // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of // reconstructions. This is followed optionally by a DEAD entry. - int log_length = 2 * (actor_notification->max_reconstructions - - actor_notification->remaining_reconstructions); - if (actor_notification->state != ActorState::ALIVE) { + int log_length = 2 * (actor_notification->max_reconstructions() - + actor_notification->remaining_reconstructions()); + if (actor_notification->state() != ActorTableData::ALIVE) { // RECONSTRUCTING or DEAD entries have an odd index. log_length += 1; } // If we successful appended a record to the GCS table of the actor that // has died, signal this to anyone receiving signals from this actor. 
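Worked through the formula above: with max_reconstructions = 3 and remaining_reconstructions = 2, an ALIVE entry is expected at index 2 * (3 - 2) = 2, while a RECONSTRUCTING or DEAD entry at the same point lands at the odd index 3, which is what the extra +1 accounts for.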
auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -632,11 +636,12 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", state = " + << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorState::ALIVE) { + if (actor_registration.GetState() == ActorTableData::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -663,7 +668,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorState::DEAD) { + } else if (actor_registration.GetState() == ActorTableData::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -672,7 +677,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. This will mark the tasks as waiting for actor @@ -793,8 +798,20 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + ProfileTableDataT fbs_message; + flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + rpc::ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message.component_type); + profile_table_data.set_component_id(fbs_message.component_id); + for (const auto &fbs_event : fbs_message.profile_events) { + rpc::ProfileTableData::ProfileEvent *event = + profile_table_data.add_profile_events(); + event->set_event_type(fbs_event->event_type); + event->set_start_time(fbs_event->start_time); + event->set_end_time(fbs_event->end_time); + event->set_extra_data(fbs_event->extra_data); + } + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -862,8 +879,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. 
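// Further up in this hunk, PushProfileEventsRequest shows the bridge used while
// workers still send profile events as flatbuffers but the GCS table now stores
// protobuf. The same copy, factored into a helper as a sketch (field and type names
// are taken from the diff; where such a helper would live is an assumption):
rpc::ProfileTableData ProfileEventsToProto(const ProfileTableDataT &fbs_message) {
  rpc::ProfileTableData profile_table_data;
  profile_table_data.set_component_type(fbs_message.component_type);
  profile_table_data.set_component_id(fbs_message.component_id);
  for (const auto &fbs_event : fbs_message.profile_events) {
    rpc::ProfileTableData::ProfileEvent *event =
        profile_table_data.add_profile_events();
    event->set_event_type(fbs_event->event_type);
    event->set_start_time(fbs_event->start_time);
    event->set_end_time(fbs_event->end_time);
    event->set_extra_data(fbs_event->extra_data);
  }
  return profile_table_data;
}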
ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorState::RECONSTRUCTING - : ActorState::DEAD; + ? ActorTableData::RECONSTRUCTING + : ActorTableData::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -872,8 +889,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.state = new_state; + ActorTableData new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.set_state(new_state); if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -884,7 +901,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1159,7 +1176,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointDataT &data) { + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1243,19 +1260,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableDataT object - ClientTableDataT data; + // Add the new resource to a skeleton ClientTableData object + ClientTableData data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.resources_total_label = std::vector{resource_name}; - data.resources_total_capacity = std::vector{capacity}; + data.add_resources_total_label(resource_name); + data.add_resources_total_capacity(capacity); // Set the correct flag for entry_type if (is_deletion) { - data.entry_type = EntryType::RES_DELETE; + data.set_entry_type(ClientTableData::RES_DELETE); } else { - data.entry_type = EntryType::RES_CREATEUPDATE; + data.set_entry_type(ClientTableData::RES_CREATEUPDATE); } // Submit to the client table. 
This calls the ResourceCreateUpdated callback, which @@ -1264,7 +1281,7 @@ void NodeManager::ProcessSetResourceRequest( if (not worker) { worker = worker_pool_.GetRegisteredDriver(client); } - auto data_shared_ptr = std::make_shared(data); + auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); @@ -1369,7 +1386,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const { void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) { const TaskSpecification &spec = task.GetTaskSpecification(); RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error " - << EnumNameErrorType(error_type) << "."; + << ErrorType_Name(error_type) << "."; // If this was an actor creation task that tried to resume from a checkpoint, // then erase it here since the task did not finish. if (spec.IsActorCreationTask()) { @@ -1487,9 +1504,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // If we have already seen this actor and this actor is not being reconstructed, // its location is known. bool location_known = - seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING; + seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING; if (location_known) { - if (actor_entry->second.GetState() == ActorState::DEAD) { + if (actor_entry->second.GetState() == ActorTableData::DEAD) { // If this actor is dead, either because the actor process is dead // or because its residing node is dead, treat this task as failed. TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); @@ -1534,7 +1551,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // we missed the creation notification. auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // The actor has been created. We only need the last entry, because // it represents the latest state of this actor. @@ -1723,18 +1740,6 @@ bool NodeManager::AssignTask(const Task &task) { std::shared_ptr worker = worker_pool_.PopWorker(spec); if (worker == nullptr) { // There are no workers that can execute this task. - if (!spec.IsActorTask()) { - // There are no more non-actor workers available to execute this task. - // Start a new worker. - worker_pool_.StartWorkerProcess(spec.GetLanguage()); - // Push an error message to the user if the worker pool tells us that it is - // getting too big. - const std::string warning_message = worker_pool_.WarningAboutSize(); - if (warning_message != "") { - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms())); - } - } // We couldn't assign this task, as no worker available. 
return false; } @@ -1872,11 +1877,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) { } } -ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { +ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) { RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask()); auto actor_id = task.GetTaskSpecification().ActorCreationId(); auto actor_entry = actor_registry_.find(actor_id); - ActorTableDataT new_actor_data; + ActorTableData new_actor_data; // TODO(swang): If this is an actor that was reconstructed, and previous // actor notifications were delayed, then this node may not have an entry for // the actor in actor_regisry_. Then, the fields for the number of @@ -1884,32 +1889,33 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta if (actor_entry == actor_registry_.end()) { // Set all of the static fields for the actor. These fields will not // change even if the actor fails or is reconstructed. - new_actor_data.actor_id = actor_id.Binary(); - new_actor_data.actor_creation_dummy_object_id = - task.GetTaskSpecification().ActorDummyObject().Binary(); - new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary(); - new_actor_data.max_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_actor_id(actor_id.Binary()); + new_actor_data.set_actor_creation_dummy_object_id( + task.GetTaskSpecification().ActorDummyObject().Binary()); + new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); + new_actor_data.set_max_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number // of remaining reconstructions is the max. - new_actor_data.remaining_reconstructions = - task.GetTaskSpecification().MaxActorReconstructions(); + new_actor_data.set_remaining_reconstructions( + task.GetTaskSpecification().MaxActorReconstructions()); } else { // If we've already seen this actor, it means that this actor was reconstructed. // Thus, its previous state must be RECONSTRUCTING. - RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING); // Copy the static fields from the current actor entry. new_actor_data = actor_entry->second.GetTableData(); // We are reconstructing the actor, so subtract its // remaining_reconstructions by 1. - new_actor_data.remaining_reconstructions--; + new_actor_data.set_remaining_reconstructions( + new_actor_data.remaining_reconstructions() - 1); } // Set the new fields for the actor's state to indicate that the actor is // now alive on this node manager. 
- new_actor_data.node_manager_id = - gcs_client_->client_table().GetLocalClientId().Binary(); - new_actor_data.state = ActorState::ALIVE; + new_actor_data.set_node_manager_id( + gcs_client_->client_table().GetLocalClientId().Binary()); + new_actor_data.set_state(ActorTableData::ALIVE); return new_actor_data; } @@ -1945,7 +1951,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { DriverID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, - const ActorCheckpointDataT &checkpoint_data) { + const ActorCheckpointData &checkpoint_data) { RAY_LOG(INFO) << "Restoring registration for actor " << actor_id << " from checkpoint " << checkpoint_id; ActorRegistration actor_registration = @@ -1959,7 +1965,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { actor_id, new_actor_data, /*failure_callback=*/ [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -1975,8 +1981,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { PublishActorStateTransition( actor_id, new_actor_data, /*failure_callback=*/ - [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) { // Only one node at a time should succeed at creating the actor. RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id; }); @@ -2015,10 +2020,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { DriverID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { // The task was in the GCS task table. Use the stored task spec to // re-execute the task. - const Task task(task_data); + auto message = flatbuffers::GetRoot(task_data.task().data()); + const Task task(*message); ResubmitTask(task); }, /*failure_callback=*/ @@ -2046,7 +2052,7 @@ void NodeManager::ResubmitTask(const Task &task) { if (task.GetTaskSpecification().IsActorCreationTask()) { const auto &actor_id = task.GetTaskSpecification().ActorCreationId(); const auto it = actor_registry_.find(actor_id); - if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) { + if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) { // If the actor is still alive, then do not resubmit the task. If the // actor actually is dead and a result is needed, then reconstruction // for this task will be triggered again. @@ -2205,6 +2211,12 @@ void NodeManager::ForwardTask( const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); + if (worker_pool_.HasPendingWorkerForTask(spec.GetLanguage(), task_id)) { + // There is a worker being starting for this task, + // so we shouldn't forward this task to another node. + return; + } + // Get and serialize the task's unforwarded, uncommitted lineage. 
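// HandleTaskReconstruction above receives a protobuf TaskTableData from the GCS, but
// the task payload inside it is still flatbuffer-encoded. A sketch of that decoding
// step, written as if it lived in this file; the root type protocol::Task is an
// assumption, since the template argument is not visible in the hunk above.
Task TaskFromTableData(const TaskTableData &task_data) {
  auto message = flatbuffers::GetRoot<protocol::Task>(task_data.task().data());
  return Task(*message);
}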
Lineage uncommitted_lineage; if (lineage_cache_.ContainsTask(task_id)) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 61613358330c..7e812183657c 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -10,7 +10,6 @@ #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" -#include "ray/gcs/format/util.h" #include "ray/raylet/actor_registration.h" #include "ray/raylet/lineage_cache.h" #include "ray/raylet/scheduling_policy.h" @@ -26,6 +25,13 @@ namespace ray { namespace raylet { +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorType; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + struct NodeManagerConfig { /// The node's resource configuration. ResourceSet resource_config; @@ -112,22 +118,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// /// \param data Data associated with the new client. /// \return Void. - void ClientAdded(const ClientTableDataT &data); + void ClientAdded(const ClientTableData &data); /// Handler for the removal of a GCS client. /// \param client_data Data associated with the removed client. /// \return Void. - void ClientRemoved(const ClientTableDataT &client_data); + void ClientRemoved(const ClientTableData &client_data); /// Handler for the addition or updation of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceCreateUpdated(const ClientTableDataT &client_data); + void ResourceCreateUpdated(const ClientTableData &client_data); /// Handler for the deletion of a resource in the GCS /// \param client_data Data associated with the new client. /// \return Void. - void ResourceDeleted(const ClientTableDataT &client_data); + void ResourceDeleted(const ClientTableData &client_data); /// Evaluates the local infeasible queue to check if any tasks can be scheduled. /// This is called whenever there's an update to the resources on the local client. @@ -150,11 +156,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param id The ID of the node manager that sent the heartbeat. /// \param data The heartbeat data including load information. /// \return Void. - void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data); + void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data); /// Handler for a heartbeat batch notification from the GCS /// /// \param heartbeat_batch The batch of heartbeat data. - void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch); + void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch); /// Methods for task scheduling. @@ -206,7 +212,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Helper function to produce actor table data for a newly created actor. /// /// \param task The actor creation task that created the actor. - ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task); + ActorTableData CreateActorTableDataFromCreationTask(const Task &task); /// Handle a worker finishing an assigned actor task or actor creation task. /// \param worker The worker that finished the task. /// \param task The actor task or actor creationt ask. @@ -317,7 +323,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param failure_callback An optional callback to call if the publish is /// unsuccessful. 
void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -346,7 +352,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging @@ -506,7 +512,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::unordered_map checkpoint_id_to_restore_; /// The RPC server. - rpc::NodeManagerServer node_manager_server_; + rpc::GrpcServer node_manager_server_; + + /// The RPC service. + rpc::NodeManagerGrpcService node_manager_service_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 473e6c263ffe..cbf9b25213ca 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = node_ip_address; - client_info.raylet_socket_name = raylet_socket_name; - client_info.object_store_socket_name = object_store_socket_name; - client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_.GetServerPort(); + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(node_ip_address); + client_info.set_raylet_socket_name(raylet_socket_name); + client_info.set_object_store_socket_name(object_store_socket_name); + client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); + client_info.set_node_manager_port(node_manager_.GetServerPort()); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.resources_total_label.push_back(resource_pair.first); - client_info.resources_total_capacity.push_back(resource_pair.second); + client_info.add_resources_total_label(resource_pair.first); + client_info.add_resources_total_capacity(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address << ":" - << client_info.node_manager_port << " object manager at " - << client_info.node_manager_address << ":" - << client_info.object_manager_port; + << " started on " << client_info.node_manager_address() << ":" + << client_info.node_manager_port() << " object manager at " + << client_info.node_manager_address() << ":" + << client_info.object_manager_port(); ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 26fe74b2b622..9367a5054591 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; + class Task; class NodeManager; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 97c86ea73cd8..bf5c1acfaa37 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. - auto reconstruction_entry = std::make_shared(); - reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.Binary(); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->set_num_reconstructions(reconstruction_attempt); + reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cd969cc2706e..a194443e1425 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,6 +17,8 @@ namespace ray { namespace raylet { +using rpc::TaskReconstructionData; + class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4ccebd0c0c09..12d9336a382f 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,6 +14,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -83,7 +85,7 @@ class MockGcs : public 
gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -110,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -132,15 +134,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -159,9 +161,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { + const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout); + task_lease.timeout()); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -314,10 +316,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = 2 * test_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(2 * test_period); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -328,7 +330,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout * 1.1); + Run(task_lease_data->timeout() * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -341,10 +343,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. 
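// What the timing in these tests relies on, in one line: a task lease suppresses
// reconstruction until roughly `timeout` milliseconds after the lease notification is
// handled. This helper is only a sketch of that bookkeeping; ReconstructionPolicy's
// actual expiry tracking is not shown in this patch, and `handled_at_ms` is a name
// invented here.
int64_t LeaseExpiryMs(int64_t handled_at_ms, const TaskLeaseData &lease) {
  return handled_at_ms + lease.timeout();
}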
SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = reconstruction_timeout_ms_; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(reconstruction_timeout_ms_); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -393,14 +395,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); - task_reconstruction_data->num_reconstructions = 0; + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index c5155b96b0c1..89028c733d0d 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.Hex(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = it->second.lease_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(client_id_.Hex()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(it->second.lease_period); RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 3788a5eae7ae..a96558295234 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,6 +13,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index e0f832a12870..f7a60989fcba 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; diff --git a/src/ray/raylet/task_spec.cc 
b/src/ray/raylet/task_spec.cc index eeab29272126..1d722de18f73 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -80,12 +80,12 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor) + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options) : spec_() { flatbuffers::FlatBufferBuilder fbb; TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); - // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) { @@ -101,7 +101,8 @@ TaskSpecification::TaskSpecification( ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor)); + string_vec_to_flatbuf(fbb, function_descriptor), + string_vec_to_flatbuf(fbb, dynamic_worker_options)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -258,6 +259,11 @@ std::vector TaskSpecification::NewActorHandles() const { return ids_from_flatbuf(*message->new_actor_handles()); } +std::vector TaskSpecification::DynamicWorkerOptions() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return string_vec_from_flatbuf(*message->dynamic_worker_options()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index d557c188ae68..8a08e9974ef2 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -128,6 +128,7 @@ class TaskSpecification { /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. + /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, @@ -138,7 +139,8 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor); + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options = {}); /// Deserialize a task specification from a string. /// @@ -214,6 +216,8 @@ class TaskSpecification { ObjectID ActorDummyObject() const; std::vector NewActorHandles() const; + std::vector DynamicWorkerOptions() const; + private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index d4ac4cf4ecce..16086565de80 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,10 +5,12 @@ #include #include +#include "ray/common/constants.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/stats/stats.h" #include "ray/util/logging.h" +#include "ray/util/util.h" namespace { @@ -41,12 +43,13 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. 
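// Example of the new trailing constructor argument: an actor creation spec that
// carries per-worker JVM flags, mirroring the argument list used by the test in
// worker_pool_test.cc further below. The option strings here are invented for
// illustration only.
std::vector<std::string> function_descriptor(3);
TaskSpecification actor_creation_spec(
    DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), ObjectID::Nil(), 0,
    ActorID::Nil(), ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, Language::JAVA,
    function_descriptor,
    /*dynamic_worker_options=*/{"-Xmx2g", "-Dray.example.flag=1"});
// The options round-trip through the flatbuffer and come back out via
// actor_creation_spec.DynamicWorkerOptions().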
WorkerPool::WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - last_warning_multiple_(0) { + last_warning_multiple_(0), + gcs_client_(std::move(gcs_client)) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. If we don't do this, then worker processes will @@ -98,7 +101,8 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -void WorkerPool::StartWorkerProcess(const Language &language) { +int WorkerPool::StartWorkerProcess(const Language &language, + const std::vector &dynamic_options) { auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -108,7 +112,7 @@ void WorkerPool::StartWorkerProcess(const Language &language) { RAY_LOG(DEBUG) << "Worker not started, " << state.starting_worker_processes.size() << " worker processes of language type " << static_cast(language) << " pending registration"; - return; + return -1; } // Either there are no workers pending registration or the worker start is being forced. RAY_LOG(DEBUG) << "Starting new worker process, current pool has " @@ -117,8 +121,20 @@ void WorkerPool::StartWorkerProcess(const Language &language) { // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; + size_t dynamic_option_index = 0; for (auto const &token : state.worker_command) { - worker_command_args.push_back(token.c_str()); + const auto option_placeholder = + kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); + + if (token == option_placeholder) { + if (!dynamic_options.empty()) { + RAY_CHECK(dynamic_option_index < dynamic_options.size()); + worker_command_args.push_back(dynamic_options[dynamic_option_index].c_str()); + ++dynamic_option_index; + } + } else { + worker_command_args.push_back(token.c_str()); + } } worker_command_args.push_back(nullptr); @@ -126,14 +142,14 @@ void WorkerPool::StartWorkerProcess(const Language &language) { if (pid < 0) { // Failure case. RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); - return; } else if (pid > 0) { // Parent process case. RAY_LOG(DEBUG) << "Started worker process with pid " << pid; state.starting_worker_processes.emplace( std::make_pair(pid, num_workers_per_process_)); - return; + return pid; } + return -1; } pid_t WorkerPool::StartProcess(const std::vector &worker_command_args) { @@ -158,7 +174,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_a } void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { - auto pid = worker->Pid(); + const auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; auto &state = GetStateForLanguage(worker->GetLanguage()); state.registered_workers.insert(std::move(worker)); @@ -207,30 +223,74 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); - // Add the worker to the idle pool. 
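// Standalone sketch of the placeholder substitution that StartWorkerProcess performs
// above: the i-th "RAY_WORKER_OPTION_i" token in the configured worker command is
// replaced by the i-th dynamic option. The literal prefix is taken from the test in
// worker_pool_test.cc (the constant itself, kWorkerDynamicOptionPlaceholderPrefix,
// lives in ray/common/constants.h and is not shown in this patch). Unlike the real
// loop, which RAY_CHECKs, this sketch simply drops a placeholder that has no matching
// option.
#include <string>
#include <vector>

std::vector<std::string> ComposeWorkerCommand(
    const std::vector<std::string> &command_template,
    const std::vector<std::string> &dynamic_options,
    const std::string &placeholder_prefix = "RAY_WORKER_OPTION_") {
  std::vector<std::string> result;
  size_t dynamic_option_index = 0;
  for (const auto &token : command_template) {
    const auto placeholder = placeholder_prefix + std::to_string(dynamic_option_index);
    if (token == placeholder) {
      if (dynamic_option_index < dynamic_options.size()) {
        result.push_back(dynamic_options[dynamic_option_index]);
        ++dynamic_option_index;
      }
    } else {
      result.push_back(token);
    }
  }
  return result;
}

// ComposeWorkerCommand({"RAY_WORKER_OPTION_0", "dummy_java_worker_command",
//                       "RAY_WORKER_OPTION_1"},
//                      {"test_op_0", "test_op_1"})
//   returns {"test_op_0", "dummy_java_worker_command", "test_op_1"}.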
- if (worker->GetActorId().IsNil()) { - state.idle.insert(std::move(worker)); + + auto it = state.dedicated_workers_to_tasks.find(worker->Pid()); + if (it != state.dedicated_workers_to_tasks.end()) { + // The worker is used for the actor creation task with dynamic options. + // Put it into idle dedicated worker pool. + const auto task_id = it->second; + state.idle_dedicated_workers[task_id] = std::move(worker); } else { - state.idle_actor[worker->GetActorId()] = std::move(worker); + // The worker is not used for the actor creation task without dynamic options. + // Put the worker to the corresponding idle pool. + if (worker->GetActorId().IsNil()) { + state.idle.insert(std::move(worker)); + } else { + state.idle_actor[worker->GetActorId()] = std::move(worker); + } } } std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec) { auto &state = GetStateForLanguage(task_spec.GetLanguage()); const auto &actor_id = task_spec.ActorId(); + std::shared_ptr worker = nullptr; - if (actor_id.IsNil()) { + int pid = -1; + if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { + // Code path of actor creation task with dynamic worker options. + // Try to pop it from idle dedicated pool. + auto it = state.idle_dedicated_workers.find(task_spec.TaskId()); + if (it != state.idle_dedicated_workers.end()) { + // There is an idle dedicated worker for this task. + worker = std::move(it->second); + state.idle_dedicated_workers.erase(it); + // Because we found a worker that can perform this task, + // we can remove it from dedicated_workers_to_tasks. + state.dedicated_workers_to_tasks.erase(worker->Pid()); + state.tasks_to_dedicated_workers.erase(task_spec.TaskId()); + } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { + // We are not pending a registration from a worker for this task, + // so start a new worker process for this task. + pid = StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions()); + if (pid > 0) { + state.dedicated_workers_to_tasks[pid] = task_spec.TaskId(); + state.tasks_to_dedicated_workers[task_spec.TaskId()] = pid; + } + } + } else if (!task_spec.IsActorTask()) { + // Code path of normal task or actor creation task without dynamic worker options. if (!state.idle.empty()) { worker = std::move(*state.idle.begin()); state.idle.erase(state.idle.begin()); + } else { + // There are no more non-actor workers available to execute this task. + // Start a new worker process. + pid = StartWorkerProcess(task_spec.GetLanguage()); } } else { + // Code path of actor task. auto actor_entry = state.idle_actor.find(actor_id); if (actor_entry != state.idle_actor.end()) { worker = std::move(actor_entry->second); state.idle_actor.erase(actor_entry); } } + + if (worker == nullptr && pid > 0) { + WarnAboutSize(); + } + return worker; } @@ -274,7 +334,7 @@ std::vector> WorkerPool::GetWorkersRunningTasksForDriver return workers; } -std::string WorkerPool::WarningAboutSize() { +void WorkerPool::WarnAboutSize() { int64_t num_workers_started_or_registered = 0; for (const auto &entry : states_by_lang_) { num_workers_started_or_registered += @@ -285,6 +345,8 @@ std::string WorkerPool::WarningAboutSize() { int64_t multiple = num_workers_started_or_registered / multiple_for_warning_; std::stringstream warning_message; if (multiple >= 3 && multiple > last_warning_multiple_) { + // Push an error message to the user if the worker pool tells us that it is + // getting too big. 
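// Tying PushWorker and PopWorker together: a sketch (not a test from this patch) of
// the dedicated-worker round trip implemented by the pid<->task maps. It assumes the
// WorkerPoolMock and a CreateWorker-style helper like the ones in worker_pool_test.cc
// below, and a `spec` that is an actor creation task with non-empty
// DynamicWorkerOptions().
void DedicatedWorkerRoundTrip(WorkerPoolMock &pool, const TaskSpecification &spec) {
  // 1) No dedicated worker exists yet: PopWorker forks one, records the pending
  //    pid <-> task mapping, and returns nullptr so the task stays queued locally.
  RAY_CHECK(pool.PopWorker(spec) == nullptr);
  RAY_CHECK(pool.HasPendingWorkerForTask(spec.GetLanguage(), spec.TaskId()));

  // 2) The worker process registers and goes idle: PushWorker files it under
  //    idle_dedicated_workers[spec.TaskId()] because its pid is recorded in
  //    dedicated_workers_to_tasks.
  auto worker = CreateWorker(pool.LastStartedWorkerProcess(), spec.GetLanguage());
  pool.RegisterWorker(worker);
  pool.PushWorker(worker);

  // 3) The next PopWorker for the same task finds the dedicated worker and erases
  //    both map entries.
  RAY_CHECK(pool.PopWorker(spec) != nullptr);
  RAY_CHECK(!pool.HasPendingWorkerForTask(spec.GetLanguage(), spec.TaskId()));
}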
last_warning_multiple_ = multiple; warning_message << "WARNING: " << num_workers_started_or_registered << " workers have been started. This could be a result of using " @@ -292,8 +354,16 @@ std::string WorkerPool::WarningAboutSize() { << "using nested tasks " << "(see https://github.com/ray-project/ray/issues/3644) for " << "some a discussion of workarounds."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); } - return warning_message.str(); +} + +bool WorkerPool::HasPendingWorkerForTask(const Language &language, + const TaskID &task_id) { + auto &state = GetStateForLanguage(language); + auto it = state.tasks_to_dedicated_workers.find(task_id); + return it != state.tasks_to_dedicated_workers.end(); } std::string WorkerPool::DebugString() const { diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 03443447cf58..e1e726268093 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -7,6 +7,7 @@ #include #include "ray/common/client_connection.h" +#include "ray/gcs/client.h" #include "ray/gcs/format/util.h" #include "ray/raylet/task.h" #include "ray/raylet/worker.h" @@ -37,22 +38,12 @@ class WorkerPool { /// language. WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands); /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); - /// Asynchronously start a new worker process. Once the worker process has - /// registered with an external server, the process should create and - /// register num_workers_per_process_ workers, then add them to the pool. - /// Failure to start the worker process is a fatal error. If too many workers - /// are already being started, then this function will return without starting - /// any workers. - /// - /// \param language Which language this worker process should be. - void StartWorkerProcess(const Language &language); - /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// @@ -118,6 +109,15 @@ class WorkerPool { std::vector> GetWorkersRunningTasksForDriver( const DriverID &driver_id) const; + /// Whether there is a pending worker for the given task. + /// Note that, this is only used for actor creation task with dynamic options. + /// And if the worker registered but isn't assigned a task, + /// the worker also is in pending state, and this'll return true. + /// + /// \param language The required language. + /// \param task_id The task that we want to query. + bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id); + /// Returns debug string for class. /// /// \return string. @@ -126,24 +126,37 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; - /// Generate a warning about the number of workers that have registered or - /// started if appropriate. + protected: + /// Asynchronously start a new worker process. Once the worker process has + /// registered with an external server, the process should create and + /// register num_workers_per_process_ workers, then add them to the pool. + /// Failure to start the worker process is a fatal error. If too many workers + /// are already being started, then this function will return without starting + /// any workers. 
/// - /// \return An empty string if no warning should be generated and otherwise a - /// string with a warning message. - std::string WarningAboutSize(); + /// \param language Which language this worker process should be. + /// \param dynamic_options The dynamic options that we should add for worker command. + /// \return The id of the process that we started if it's positive, + /// otherwise it means we didn't start a process. + int StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}); - protected: /// The implementation of how to start a new worker process with command arguments. /// /// \param worker_command_args The command arguments of new worker process. /// \return The process ID of started worker process. virtual pid_t StartProcess(const std::vector &worker_command_args); + /// Push an warning message to user if worker pool is getting to big. + virtual void WarnAboutSize(); + /// An internal data structure that maintains the pool state per language. struct State { /// The commands and arguments used to start the worker process std::vector worker_command; + /// The pool of dedicated workers for actor creation tasks + /// with prefix or suffix worker command. + std::unordered_map> idle_dedicated_workers; /// The pool of idle non-actor workers. std::unordered_set> idle; /// The pool of idle actor workers. @@ -156,6 +169,11 @@ class WorkerPool { /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; + /// A map for looking up the task with dynamic options by the pid of + /// worker. Note that this is used for the dedicated worker processes. + std::unordered_map dedicated_workers_to_tasks; + /// A map for speeding up looking up the pending worker for the given task. + std::unordered_map tasks_to_dedicated_workers; }; /// The number of workers per process. @@ -166,7 +184,7 @@ class WorkerPool { private: /// A helper function that returns the reference of the pool state /// for a given language. - inline State &GetStateForLanguage(const Language &language); + State &GetStateForLanguage(const Language &language); /// We'll push a warning to the user every time a multiple of this many /// workers has been started. @@ -176,6 +194,8 @@ class WorkerPool { /// The last size at which a warning about the number of registered workers /// was generated. int64_t last_warning_multiple_; + /// A client connection to the GCS. 
+ std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 143ffd57dda6..15a5fb0471e0 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -1,6 +1,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/constants.h" #include "ray/raylet/node_manager.h" #include "ray/raylet/worker_pool.h" @@ -14,21 +15,46 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, - {{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}), + : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, {"dummy_java_worker_command"}}}) {} + + explicit WorkerPoolMock( + const std::unordered_map> &worker_commands) + : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, + worker_commands), last_worker_pid_(0) {} + ~WorkerPoolMock() { // Avoid killing real processes states_by_lang_.clear(); } + void StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}) { + WorkerPool::StartWorkerProcess(language, dynamic_options); + } + pid_t StartProcess(const std::vector &worker_command_args) override { - return ++last_worker_pid_; + last_worker_pid_ += 1; + std::vector local_worker_commands_args; + for (auto item : worker_command_args) { + if (item == nullptr) { + break; + } + local_worker_commands_args.push_back(std::string(item)); + } + worker_commands_by_pid[last_worker_pid_] = std::move(local_worker_commands_args); + return last_worker_pid_; } + void WarnAboutSize() override {} + pid_t LastStartedWorkerProcess() const { return last_worker_pid_; } + const std::vector &GetWorkerCommand(int pid) { + return worker_commands_by_pid[pid]; + } + int NumWorkerProcessesStarting() const { int total = 0; for (auto &entry : states_by_lang_) { @@ -39,6 +65,8 @@ class WorkerPoolMock : public WorkerPool { private: int last_worker_pid_; + // The worker commands by pid. 
+ std::unordered_map> worker_commands_by_pid; }; class WorkerPoolTest : public ::testing::Test { @@ -61,6 +89,12 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, client)); } + void SetWorkerCommands( + const std::unordered_map> &worker_commands) { + WorkerPoolMock worker_pool(worker_commands); + this->worker_pool_ = std::move(worker_pool); + } + protected: WorkerPoolMock worker_pool_; boost::asio::io_service io_service_; @@ -72,10 +106,10 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), - const Language &language = Language::PYTHON) { + const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, + const ActorID actor_creation_id = ActorID::Nil()) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -186,6 +220,23 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); } +TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { + const std::vector java_worker_command = { + "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; + SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, java_worker_command}}); + + TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), + ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, + {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, + {"test_op_0", "test_op_1"}); + worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + const auto real_command = + worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + ASSERT_EQ(real_command, std::vector( + {"test_op_0", "dummy_java_worker_command", "test_op_1"})); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index feb788da7692..f507039990c2 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -1,4 +1,5 @@ #include "ray/rpc/grpc_server.h" +#include namespace ray { namespace rpc { @@ -9,8 +10,10 @@ void GrpcServer::Run() { grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); - // Allow subclasses to register concrete services. - RegisterServices(builder); + // Register all the services to this server. + for (auto &entry : services_) { + builder.RegisterService(&entry.get()); + } // Get hold of the completion queue used for the asynchronous communication // with the gRPC runtime. cq_ = builder.AddCompletionQueue(); @@ -18,8 +21,7 @@ void GrpcServer::Run() { server_ = builder.BuildAndStart(); RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; - // Allow subclasses to initialize the server call factories. - InitServerCallFactories(&server_call_factories_and_concurrencies_); + // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { for (int i = 0; i < entry.second; i++) { // Create and request calls from the factory. 
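// How the refactored pieces compose (a sketch, not code from this patch; the real
// wiring lives in NodeManager's constructor, which these hunks do not show). It uses
// GrpcServer::RegisterService, defined just below, and the NodeManagerGrpcService
// declared in node_manager_server.h further down; the raylet holds both as the
// members added to node_manager.h earlier in this patch. Port 0 asks gRPC for a
// random free port, and both objects must outlive the polling thread that Run()
// starts, which is why they are static in this sketch.
void StartNodeManagerRpc(boost::asio::io_service &main_service,
                         NodeManagerServiceHandler &handler) {
  static GrpcServer server("NodeManager", /*port=*/0);
  static NodeManagerGrpcService service(main_service, handler);
  server.RegisterService(service);  // pulls in the grpc::Service and call factories
  server.Run();                     // builds the server and starts polling
}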
@@ -31,6 +33,11 @@ void GrpcServer::Run() { polling_thread.detach(); } +void GrpcServer::RegisterService(GrpcService &service) { + services_.emplace_back(service.GetGrpcService()); + service.InitServerCallFactories(cq_, &server_call_factories_and_concurrencies_); +} + void GrpcServer::PollEventsFromCompletionQueue() { void *tag; bool ok; @@ -48,7 +55,7 @@ void GrpcServer::PollEventsFromCompletionQueue() { // incoming request. server_call->GetFactory().CreateCall(); server_call->SetState(ServerCallState::PROCESSING); - main_service_.post([server_call] { server_call->HandleRequest(); }); + server_call->HandleRequest(); break; case ServerCallState::SENDING_REPLY: // The reply has been sent, this call can be deleted now. diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 4953f470610f..584da6565a47 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -12,7 +12,9 @@ namespace ray { namespace rpc { -/// Base class that represents an abstract gRPC server. +class GrpcService; + +/// Class that represents an gRPC server. /// /// A `GrpcServer` listens on a specific port. It owns /// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, @@ -28,11 +30,7 @@ class GrpcServer { /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - /// \param[in] main_service The main event loop, to which service handler functions - /// will be posted. - GrpcServer(const std::string &name, const uint32_t port, - boost::asio::io_service &main_service) - : name_(name), port_(port), main_service_(main_service) {} + GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {} /// Destruct this gRPC server. ~GrpcServer() { @@ -46,36 +44,25 @@ class GrpcServer { /// Get the port of this gRPC server. int GetPort() const { return port_; } - protected: - /// Subclasses should implement this method and register one or multiple gRPC services - /// to the given `ServerBuilder`. + /// Register a grpc service. Multiple services can be registered to the same server. + /// Note that the `service` registered must remain valid for the lifetime of the + /// `GrpcServer`, as it holds the underlying `grpc::Service`. /// - /// \param[in] builder The `ServerBuilder` instance to register services to. - virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; - - /// Subclasses should implement this method to initialize the `ServerCallFactory` - /// instances, as well as specify maximum number of concurrent requests that gRPC - /// server can "accept" (not "handle"). Each factory will be used to create - /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and - /// handle an incoming request. - /// - /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, - /// and the maximum number of concurrent requests that gRPC server can accept. - virtual void InitServerCallFactories( - std::vector, int>> - *server_call_factories_and_concurrencies) = 0; + /// \param[in] service A `GrpcService` to register to this server. + void RegisterService(GrpcService &service); + protected: /// This function runs in a background thread. It keeps polling events from the /// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances /// via the `ServerCall` objects. 
void PollEventsFromCompletionQueue(); - /// The main event loop, to which the service handler functions will be posted. - boost::asio::io_service &main_service_; /// Name of this server, used for logging and debugging purpose. const std::string name_; /// Port of this server. int port_; + /// The `grpc::Service` objects which should be registered to `ServerBuilder`. + std::vector> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that /// gRPC server can accept. std::vector, int>> @@ -86,6 +73,46 @@ class GrpcServer { std::unique_ptr server_; }; +/// Base class that represents an abstract gRPC service. +/// +/// Subclass should implement `InitServerCallFactories` to decide +/// which kinds of requests this service should accept. +class GrpcService { + public: + /// Constructor. + /// + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted. + GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {} + + /// Destruct this gRPC service. + ~GrpcService() {} + + protected: + /// Return the underlying grpc::Service object for this class. + /// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`. + virtual grpc::Service &GetGrpcService() = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify maximum number of concurrent requests that gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[in] cq The grpc completion queue. + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that gRPC server can accept. + virtual void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) = 0; + + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; + + friend class GrpcServer; +}; + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h index afaea299ea89..d05f268c65b2 100644 --- a/src/ray/rpc/node_manager_server.h +++ b/src/ray/rpc/node_manager_server.h @@ -25,25 +25,22 @@ class NodeManagerServiceHandler { RequestDoneCallback done_callback) = 0; }; -/// The `GrpcServer` for `NodeManagerService`. -class NodeManagerServer : public GrpcServer { +/// The `GrpcService` for `NodeManagerService`. +class NodeManagerGrpcService : public GrpcService { public: /// Constructor. /// - /// \param[in] port See super class. - /// \param[in] main_service See super class. + /// \param[in] io_service See super class. /// \param[in] handler The service handler that actually handle the requests. - NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, - NodeManagerServiceHandler &service_handler) - : GrpcServer("NodeManager", port, main_service), - service_handler_(service_handler){}; + NodeManagerGrpcService(boost::asio::io_service &io_service, + NodeManagerServiceHandler &service_handler) + : GrpcService(io_service), service_handler_(service_handler){}; - void RegisterServices(grpc::ServerBuilder &builder) override { - /// Register `NodeManagerService`. 
-    builder.RegisterService(&service_);
-  }
+ protected:
+  grpc::Service &GetGrpcService() override { return service_; }
 
   void InitServerCallFactories(
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
           *server_call_factories_and_concurrencies) override {
     // Initialize the factory for `ForwardTask` requests.
@@ -51,7 +48,8 @@ class NodeManagerServer : public GrpcServer {
         new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
                                   ForwardTaskRequest, ForwardTaskReply>(
            service_, &NodeManagerService::AsyncService::RequestForwardTask,
-            service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
+            service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq,
+            main_service_));
 
     // Set `ForwardTask`'s accept concurrency to 100.
     server_call_factories_and_concurrencies->emplace_back(
@@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer {
  private:
   /// The grpc async service object.
   NodeManagerService::AsyncService service_;
+
   /// The service handler that actually handles the requests.
   NodeManagerServiceHandler &service_handler_;
 };
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index e06278260ab6..08ca128323ee 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall {
   /// \param[in] factory The factory which created this call.
   /// \param[in] service_handler The service handler that handles the request.
   /// \param[in] handle_request_function Pointer to the service handler function.
+  /// \param[in] io_service The event loop.
   ServerCallImpl(
       const ServerCallFactory &factory, ServiceHandler &service_handler,
-      HandleRequestFunction handle_request_function)
+      HandleRequestFunction handle_request_function,
+      boost::asio::io_service &io_service)
       : state_(ServerCallState::PENDING),
         factory_(factory),
         service_handler_(service_handler),
         handle_request_function_(handle_request_function),
-        response_writer_(&context_) {}
+        response_writer_(&context_),
+        io_service_(io_service) {}
 
   ServerCallState GetState() const override { return state_; }
 
   void SetState(const ServerCallState &new_state) override { state_ = new_state; }
 
   void HandleRequest() override {
+    io_service_.post([this] { HandleRequestImpl(); });
+  }
+
+  void HandleRequestImpl() {
     state_ = ServerCallState::PROCESSING;
     (service_handler_.*handle_request_function_)(request_, &reply_,
                                                  [this](Status status) {
@@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall {
   /// The response writer.
   grpc::ServerAsyncResponseWriter<Reply> response_writer_;
 
+  /// The event loop.
+  boost::asio::io_service &io_service_;
+
   /// The request message.
   Request request_;
 
@@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory {
   /// \param[in] service_handler The service handler that handles the request.
   /// \param[in] handle_request_function Pointer to the service handler function.
   /// \param[in] cq The `CompletionQueue`.
+  /// \param[in] io_service The event loop.
   ServerCallFactoryImpl(
       AsyncService &service, RequestCallFunction request_call_function,
       ServiceHandler &service_handler,
       HandleRequestFunction handle_request_function,
-      const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
+      boost::asio::io_service &io_service)
       : service_(service),
         request_call_function_(request_call_function),
        service_handler_(service_handler),
        handle_request_function_(handle_request_function),
-        cq_(cq) {}
+        cq_(cq),
+        io_service_(io_service) {}
 
   ServerCall *CreateCall() const override {
     // Create a new `ServerCall`. This object will eventually be deleted by
     // `GrpcServer::PollEventsFromCompletionQueue`.
     auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
-        *this, service_handler_, handle_request_function_);
+        *this, service_handler_, handle_request_function_, io_service_);
     /// Request gRPC runtime to start accepting this kind of request, using the call as
     /// the tag.
     (service_.*request_call_function_)(&call->context_, &call->request_,
@@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory {
 
   /// The `CompletionQueue`.
   const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;
+
+  /// The event loop.
+  boost::asio::io_service &io_service_;
 };
 
 }  // namespace rpc
diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h
index 6ecc6c3c4a34..59ae75ae33be 100644
--- a/src/ray/rpc/util.h
+++ b/src/ray/rpc/util.h
@@ -1,6 +1,7 @@
 #ifndef RAY_RPC_UTIL_H
 #define RAY_RPC_UTIL_H
 
+#include <google/protobuf/repeated_field.h>
 #include <grpcpp/grpcpp.h>
 
 #include "ray/common/status.h"
@@ -27,6 +28,18 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
   }
 }
 
+template <class T>
+inline std::vector<T> VectorFromProtobuf(
+    const ::google::protobuf::RepeatedPtrField<T> &pb_repeated) {
+  return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
+}
+
+template <class T>
+inline std::vector<T> VectorFromProtobuf(
+    const ::google::protobuf::RepeatedField<T> &pb_repeated) {
+  return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
+}
+
 }  // namespace rpc
 }  // namespace ray

From 342854b85839eed2c7b38fa64017905bb94ad278 Mon Sep 17 00:00:00 2001
From: Stefan Pantic
Date: Wed, 26 Jun 2019 15:13:10 +0200
Subject: [PATCH 118/118] Remove entropy decay stuff

---
 python/ray/rllib/agents/impala/impala.py     |  1 -
 .../ray/rllib/agents/impala/vtrace_policy.py | 36 ++++++++++---------
 python/ray/rllib/policy/tf_policy.py         | 20 -----------
 3 files changed, 20 insertions(+), 37 deletions(-)

diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py
index 23b5ada167db..b9699888bfaf 100644
--- a/python/ray/rllib/agents/impala/impala.py
+++ b/python/ray/rllib/agents/impala/impala.py
@@ -75,7 +75,6 @@
     # balancing the three losses
     "vf_loss_coeff": 0.5,
     "entropy_coeff": 0.01,
-    "entropy_schedule": None,
 
     # use fake (infinite speed) sampler for testing
     "_fake_sampler": False,
diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py
index 9860783238a0..8e9b0e8691e6 100644
--- a/python/ray/rllib/agents/impala/vtrace_policy.py
+++ b/python/ray/rllib/agents/impala/vtrace_policy.py
@@ -7,19 +7,18 @@
 from __future__ import print_function
 
 import gym
-import ray
 import numpy as np
+import ray
 from ray.rllib.agents.impala import vtrace
 from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
-from ray.rllib.policy.policy import Policy
-from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.tf_policy import TFPolicy, \
-    LearningRateSchedule, EntropyCoeffSchedule
 from ray.rllib.models.action_dist import MultiCategorical
 from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.tf_policy import TFPolicy, LearningRateSchedule
+from ray.rllib.utils import try_import_tf
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.explained_variance import explained_variance
-from ray.rllib.utils import try_import_tf
 
 tf = try_import_tf()
 
@@ -96,15 +95,22 @@ def __init__(self,
 
         # The baseline loss
         delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
-        self.vf_loss = tf.math.multiply(0.5, tf.reduce_sum(tf.square(delta)), name='vf_loss')
+        self.vf_loss = tf.math.multiply(
+            0.5, tf.reduce_sum(
+                tf.square(delta)), name='vf_loss')
 
         # The entropy loss
         self.entropy = tf.reduce_sum(
             tf.boolean_mask(actions_entropy, valid_mask), name='entropy_loss')
 
         # The summed weighted loss
-        self.total_loss = tf.math.add(self.pi_loss, self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff,
-                                      name='total_loss')
+        self.total_loss = tf.math.add(
+            self.pi_loss,
+            self.vf_loss *
+            vf_loss_coeff -
+            self.entropy *
+            entropy_coeff,
+            name='total_loss')
 
 
 class VTracePostprocessing(object):
@@ -126,7 +132,7 @@ def postprocess_trajectory(self,
         return sample_batch
 
 
-class VTraceTFPolicy(LearningRateSchedule, EntropyCoeffSchedule, VTracePostprocessing, TFPolicy):
+class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
     def __init__(self,
                  observation_space,
                  action_space,
@@ -249,9 +255,6 @@ def make_time_major(tensor, drop_last=False):
         loss_actions = actions if is_multidiscrete else tf.expand_dims(
             actions, axis=1)
 
-        EntropyCoeffSchedule.__init__(self, self.config["entropy_coeff"],
-                                      self.config["entropy_schedule"])
-
         # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc.
         with tf.name_scope('vtrace_loss'):
             self.loss = VTraceLoss(
@@ -277,8 +280,10 @@ def make_time_major(tensor, drop_last=False):
 
         with tf.name_scope('kl_divergence'):
             # KL divergence between worker and learner logits for debugging
-            model_dist = MultiCategorical(self.model.outputs, output_hidden_shape)
-            behaviour_dist = MultiCategorical(behaviour_logits, output_hidden_shape)
+            model_dist = MultiCategorical(
+                self.model.outputs, output_hidden_shape)
+            behaviour_dist = MultiCategorical(
+                behaviour_logits, output_hidden_shape)
 
             kls = model_dist.kl(behaviour_dist)
             if len(kls) > 1:
@@ -336,7 +341,6 @@ def make_time_major(tensor, drop_last=False):
         self.stats_fetches = {
             LEARNER_STATS_KEY: dict({
                 "cur_lr": tf.cast(self.cur_lr, tf.float64),
-                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
                 "policy_loss": self.loss.pi_loss,
                 "entropy": self.loss.entropy,
                 "grad_gnorm": tf.global_norm(self._grads),
diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py
index 591363a793be..ddee7de9745b 100644
--- a/python/ray/rllib/policy/tf_policy.py
+++ b/python/ray/rllib/policy/tf_policy.py
@@ -551,23 +551,3 @@ def on_global_var_update(self, global_vars):
     @override(TFPolicy)
     def optimizer(self):
         return tf.train.AdamOptimizer(self.cur_lr)
-
-
-@DeveloperAPI
-class EntropyCoeffSchedule(object):
-    """Mixin for TFPolicy that adds entropy coeff decay."""
-
-    @DeveloperAPI
-    def __init__(self, entropy_coeff, entropy_schedule):
-        self.entropy_coeff = tf.get_variable("entropy_coeff", initializer=entropy_coeff)
-        self._entropy_schedule = entropy_schedule
-
-    @override(Policy)
-    def on_global_var_update(self, global_vars):
-        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
-        if self._entropy_schedule is not None:
-            self.entropy_coeff.load(
-                self.config['entropy_coeff'] *
-                (1 - global_vars['timestep'] /
-                 self.config['entropy_schedule']),
-                session=self._sess)
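
Editorial note (not part of the patch series above): with the EntropyCoeffSchedule mixin removed, IMPALA's entropy bonus is governed only by the constant "entropy_coeff" key; "entropy_schedule" is no longer a recognized option. The snippet below is a minimal, illustrative sketch of configuring a trainer after this change. It assumes an RLlib build that includes patch 118 and the standard Gym "CartPole-v0" environment; the specific config values are examples, not recommendations.

    import ray
    from ray.rllib.agents.impala import ImpalaTrainer

    ray.init()

    # "entropy_schedule" does not exist anymore; the entropy bonus is a single
    # fixed coefficient applied to the summed entropy term in VTraceLoss.
    trainer = ImpalaTrainer(
        env="CartPole-v0",
        config={
            "num_workers": 1,
            "vf_loss_coeff": 0.5,   # weight on the baseline (value function) loss
            "entropy_coeff": 0.01,  # constant entropy bonus for the whole run
        })

    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])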