From 110aaab4eaf9c7fc9cdd9d191247bdb39cbeb36e Mon Sep 17 00:00:00 2001 From: Stefan Pantic Date: Wed, 26 Jun 2019 16:44:45 +0200 Subject: [PATCH] Merge with ray master (#36) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [rllib] Remove dependency on TensorFlow (#4764) * remove hard tf dep * add test * comment fix * fix test * Dynamic Custom Resources - create and delete resources (#3742) * Update tutorial link in doc (#4777) * [rllib] Implement learn_on_batch() in torch policy graph * Fix `ray stop` by killing raylet before plasma (#4778) * Fatal check if object store dies (#4763) * [rllib] fix clip by value issue as TF upgraded (#4697) * fix clip_by_value issue * fix typo * [autoscaler] Fix submit (#4782) * Queue tasks in the raylet in between async callbacks (#4766) * Add a SWAP TaskQueue so that we can keep track of tasks that are temporarily dequeued * Fix bug where tasks that fail to be forwarded don't appear to be local by adding them to SWAP queue * cleanups * updates * updates * [Java][Bazel] Refine auto-generated pom files (#4780) * Bump version to 0.7.0 (#4791) * [JAVA] setDefaultUncaughtExceptionHandler to log uncaught exception in user thread. (#4798) * Add WorkerUncaughtExceptionHandler * Fix * revert bazel and pom * [tune] Fix CLI test (#4801) * Fix pom file generation (#4800) * [rllib] Support continuous action distributions in IMPALA/APPO (#4771) * [rllib] TensorFlow 2 compatibility (#4802) * Change tagline in documentation and README. (#4807) * Update README.rst, index.rst, tutorial.rst and _config.yml * [tune] Support non-arg submit (#4803) * [autoscaler] rsync cluster (#4785) * [tune] Remove extra parsing functionality (#4804) * Fix Java worker log dir (#4781) * [tune] Initial track integration (#4362) Introduces a minimally invasive utility for logging experiment results. A broad requirement for this tool is that it should integrate seamlessly with Tune execution. * [rllib] [RFC] Dynamic definition of loss functions and modularization support (#4795) * dynamic graph * wip * clean up * fix * document trainer * wip * initialize the graph using a fake batch * clean up dynamic init * wip * spelling * use builder for ppo pol graph * add ppo graph * fix naming * order * docs * set class name correctly * add torch builder * add custom model support in builder * cleanup * remove underscores * fix py2 compat * Update dynamic_tf_policy_graph.py * Update tracking_dict.py * wip * rename * debug level * rename policy_graph -> policy in new classes * fix test * rename ppo tf policy * port appo too * forgot grads * default policy optimizer * make default config optional * add config to optimizer * use lr by default in optimizer * update * comments * remove optimizer * fix tuple actions support in dynamic tf graph * [rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy/ (#4819) This implements some of the renames proposed in #4813 We leave behind backwards-compatibility aliases for *PolicyGraph and SampleBatch. * [Java] Dynamic resource API in Java (#4824) * Add default values for Wgym flags * Fix import * Fix issue when starting `raylet_monitor` (#4829) * Refactor ID Serial 1: Separate ObjectID and TaskID from UniqueID (#4776) * Enable BaseId. * Change TaskID and make python test pass * Remove unnecessary functions and fix test failure and change TaskID to 16 bytes. 
* Java code change draft * Refine * Lint * Update java/api/src/main/java/org/ray/api/id/TaskId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/BaseId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/BaseId.java Co-Authored-By: Hao Chen * Update java/api/src/main/java/org/ray/api/id/ObjectId.java Co-Authored-By: Hao Chen * Address comment * Lint * Fix SINGLE_PROCESS * Fix comments * Refine code * Refine test * Resolve conflict * Fix bug in which actor classes are not exported multiple times. (#4838) * Bump Ray master version to 0.8.0.dev0 (#4845) * Add section to bump version of master branch and cleanup release docs (#4846) * Fix import * Export remote functions when first used and also fix bug in which rem… (#4844) * Export remote functions when first used and also fix bug in which remote functions and actor classes are not exported from workers during subsequent ray sessions. * Documentation update * Fix tests. * Fix grammar * Update wheel versions in documentation to 0.8.0.dev0 and 0.7.0. (#4847) * [tune] Later expansion of local_dir (#4806) * [rllib] [RFC] Deprecate Python 2 / RLlib (#4832) * Fix a typo in kubernetes yaml (#4872) * Move global state API out of global_state object. (#4857) * Install bazel in autoscaler development configs. (#4874) * [tune] Fix up Ax Search and Examples (#4851) * update Ax for cleaner API * docs update * [rllib] Update concepts docs and add "Building Policies in Torch/TensorFlow" section (#4821) * wip * fix index * fix bugs * todo * add imports * note on get ph * note on get ph * rename to building custom algs * add rnn state info * [rllib] Fix error getting kl when simple_optimizer: True in multi-agent PPO * Replace ReturnIds with NumReturns in TaskInfo to reduce the size (#4854) * Refine TaskInfo * Fix * Add a test to print task info size * Lint * Refine * Update deps commits of opencensus to support building with bzl 0.25.x (#4862) * Update deps to support bzl 0.25.x * Fix * Upgrade arrow to latest master (#4858) * [tune] Auto-init Ray + default SearchAlg (#4815) * Bump version from 0.8.0.dev0 to 0.7.1. (#4890) * [rllib] Allow access to batches prior to postprocessing (#4871) * [rllib] Fix Multidiscrete support (#4869) * Refactor redis callback handling (#4841) * Add CallbackReply * Fix * fix linting by format.sh * Fix linting * Address comments. * Fix * Initial high-level code structure of CoreWorker. (#4875) * Drop duplicated string format (#4897) This string format is unnecessary. java_worker_options has been appended to the command line later. * Refactor ID Serial 2: change all ID functions to `CamelCase` (#4896) * Hotfix for change of from_random to FromRandom (#4909) * [rllib] Fix documentation on custom policies (#4910) * wip * add docs * lint * todo sections * fix doc * [rllib] Allow Torch policies access to full action input dict in extra_action_out_fn (#4894) * fix torch extra out * preserve setitem * fix docs * [tune] Pretty print params json in logger.py (#4903) * [sgd] Distributed Training via PyTorch (#4797) Implements distributed SGD using distributed PyTorch.
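For context on the [sgd] bullet above: distributed SGD in PyTorch rests on torch.distributed process groups plus DistributedDataParallel, which all-reduces gradients during backward(). The sketch below is a minimal, generic illustration of that underlying PyTorch pattern only, not Ray's actual sgd runner API; the backend, address, model, and data are illustrative placeholders.

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def main():
        # One process per worker in a real run; world_size=1 keeps this
        # runnable as a standalone script.
        dist.init_process_group(
            backend="gloo", init_method="tcp://127.0.0.1:29500",
            rank=0, world_size=1)

        # Wrapping the model makes gradients get all-reduced across workers.
        model = DistributedDataParallel(nn.Linear(4, 2))
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = nn.MSELoss()

        for _ in range(3):
            x, y = torch.randn(8, 4), torch.randn(8, 2)  # placeholder data
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()  # DDP averages grads here
            optimizer.step()

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

In multi-worker runs each process uses the same init_method with its own rank, which is the coordination job a runner like the one in #4797 automates.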
* [rllib] Rough port of DQN to build_tf_policy() pattern (#4823) * fetching objects in parallel in _get_arguments_for_execution (#4775) * [tune] Disallow setting resources_per_trial when it is already configured (#4880) * disallow it * import fix * fix example * fix test * fix tests * Update mock.py * fix * make less convoluted * fix tests * [rllib] Rename PolicyEvaluator => RolloutWorker (#4820) * Fix local cluster yaml (#4918) * [tune] Directional metrics for components (#4120) (#4915) * [Core Worker] implement ObjectInterface and add test framework (#4899) * [tune] Make PBT Quantile fraction configurable (#4912) * Better organize ray_common module (#4898) * Fix error * [tune] Add requirements-dev.txt and update docs for contributing (#4925) * Add requirements-dev.txt and update docs. * Update doc/source/tune-contrib.rst Co-Authored-By: Richard Liaw * Unpin everything except for yapf. * Fix compute actions return value * Bump version from 0.7.1 to 0.8.0.dev1. (#4937) * Update version number in documentation after release 0.7.0 -> 0.7.1 and 0.8.0.dev0 -> 0.8.0.dev1. (#4941) * [doc] Update developer docs with bazel instructions (#4944) * [C++] Add hash table to Redis-Module (#4911) * Flush lineage cache on task submission instead of execution (#4942) * [rllib] Add docs on how to use TF eager execution (#4927) * [rllib] Port remainder of algorithms to build_trainer() pattern (#4920) * Fix resource bookkeeping bug with acquiring unknown resource. (#4945) * Update aws keys for uploading wheels to s3. (#4948) * Upload wheels on Travis to branchname/commit_id. (#4949) * [Java] Fix serializing issues of `RaySerializer` (#4887) * Fix * Address comment. * fix (#4950) * [Java] Add inner class `Builder` to build call options. (#4956) * Add Builder class * format * Refactor by IDE * Remove unnecessary dependency * Make release stress tests work and improve them. (#4955) * Use proper session directory for debug_string.txt (#4960) * [core] Use int64_t instead of int to keep track of fractional resources (#4959) * [core worker] add task submission & execution interface (#4922) * [sgd] Add non-distributed PyTorch runner (#4933) * Add non-distributed PyTorch runner * use dist.is_available() instead of checking OS * Nicer exception * Fix bug in choosing port * Refactor some code * Address comments * Address comments * Flush all tasks from local lineage cache after a node failure (#4964) * Remove typing from setup.py install_requirements. (#4971) * [Java] Fix bug of `BaseID` in multi-threading case. (#4974) * [rllib] Fix DDPG example (#4973) * Upgrade CI clang-format to 6.0 (#4976) * [Core worker] add store & task provider (#4966) * Fix bugs in the a3c code template. (#4984) * Inherit Function Docstrings and other metadata (#4985) * Fix a crash when unknown worker registering to raylet (#4992) * [gRPC] Use gRPC for inter-node-manager communication (#4968) * Fix Java CI failure (#4995) * fix handling of non-integral timeout values in signal.receive (#5002) * temp fix for build (#5006) * [tune] Tutorial UX Changes (#4990) * add integration, iris, ASHA, recursive changes, set reuse_actors=True, and enable Analysis as a return object * docstring * fix up example * fix * cleanup tests * experiment analysis * Fix valgrind build by installing new version of valgrind (#5008) * Fix no cpus test (#5009) * Fix tensorflow-1.14 installation in jenkins (#5007) * Add dynamic worker options for worker command.
(#4970) * Add fields for fbs * WIP * Fix compilation errors * Add java part * Fix * Fix * Fix * Fix lint * Refine API * address comments and add test * Fix * Address comment. * Address comments. * Fix linting * Refine * Fix lint * WIP: address comment. * Fix java * Fix py * Refine * Fix * Fix * Fix linting * Fix lint * Address comments * WIP * Fix * Fix * minor refine * Fix lint * Fix raylet test. * Fix lint * Update src/ray/raylet/worker_pool.h Co-Authored-By: Hao Chen * Update java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java Co-Authored-By: Hao Chen * Address comments. * Address comments. * Fix test. * Update src/ray/raylet/worker_pool.h Co-Authored-By: Hao Chen * Address comments. * Address comments. * Fix * Fix lint * Fix lint * Fix * Address comments. * Fix linting * [docs] docs for running Tensorboard without sudo (#5015) * Instructions for running Tensorboard without sudo When we run Tensorboard to visualize the results of Ray outputs on multi-user clusters where we don't have sudo access, such as RISE clusters, a few commands need to be run first to make sure tensorboard can edit the tmp directory. This is a pretty common use case so I figured we may as well put it in the documentation for Tune. * Update tune-usage.rst * [ci] Change Jenkins to py3 (#5022) * conda3 * integration * add nevergrad, remotedata * pytest 0.3.1 * otherdockers * setup * tune * [gRPC] Migrate gcs data structures to protobuf (#5024) (see the protobuf sketch below) * [rllib] Add QMIX mixer parameters to optimizer param list (#5014) * add mixer params * Update qmix_policy.py * [grpc] refactor rpc server to support multiple io services (#5023) * [rllib] Give error if sample_async is used with pytorch for A3C (#5000) * give error if sample_async is used with pytorch * update * Update a3c.py * [tune] Update MNIST Example (#4991) * Add entropy coeff schedule * Revert "Merge with ray master" This reverts commit 108bfa293001ffd589e79288e98999aacf5b59f9, reversing changes made to 2e0eec9f723f6ba96e11183f16dba0fc664cb655. * Revert "Revert "Merge with ray master"" This reverts commit 92c0f88b9cd75be6281204467b95526951c03e87.
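The gcs-to-protobuf migration referenced above replaces flatbuffers builder code with plain protobuf messages (see the gcs_utils.py and monitor.py hunks later in this patch). A minimal sketch of the resulting round-trip, assuming the ray.core.generated.gcs_pb2 module this patch generates from src/ray/protobuf/gcs.proto and the field names shown in those hunks; all field values here are placeholders:

    from ray.core.generated.gcs_pb2 import ErrorTableData

    # Build and serialize, mirroring the construct_error_message() rewrite
    # in this patch.
    data = ErrorTableData()
    data.driver_id = b"\x00" * 16        # placeholder binary driver id
    data.type = "illustrative_error"     # placeholder error type string
    data.error_message = "something went wrong"
    data.timestamp = 1561560000.0        # assumes the double field from gcs.fbs carried over
    payload = data.SerializeToString()

    # Parse it back the way monitor.py now does for GcsEntry and
    # DriverTableData.
    parsed = ErrorTableData.FromString(payload)
    assert parsed.error_message == "something went wrong"

At the call sites, SerializeToString() and FromString() take over the roles previously played by the flatbuffers builder and the GetRootAs* accessors.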
* Remove entropy decay stuff --- BUILD.bazel | 96 ++-- bazel/ray_deps_build_all.bzl | 4 + bazel/ray_deps_setup.bzl | 11 +- .../run_perf_integration.sh | 2 +- ci/jenkins_tests/run_tune_tests.sh | 8 +- doc/source/conf.py | 15 +- doc/source/tune-usage.rst | 6 + docker/base-deps/Dockerfile | 2 +- docker/examples/Dockerfile | 5 +- docker/stress_test/Dockerfile | 2 +- docker/tune_test/Dockerfile | 11 +- java/BUILD.bazel | 51 +-- .../src/main/java/org/ray/api/id/BaseId.java | 2 +- .../ray/api/options/ActorCreationOptions.java | 15 +- java/dependencies.bzl | 1 + ...modify_generated_java_flatbuffers_files.py | 20 +- java/runtime/pom.xml | 5 + .../org/ray/runtime/AbstractRayRuntime.java | 9 +- .../java/org/ray/runtime/gcs/GcsClient.java | 69 +-- .../runtime/objectstore/ObjectStoreProxy.java | 12 +- .../ray/runtime/raylet/RayletClientImpl.java | 18 +- .../org/ray/runtime/runner/RunManager.java | 3 + .../java/org/ray/runtime/task/TaskSpec.java | 8 +- .../src/main/java/org/ray/api/TestUtils.java | 15 + .../org/ray/api/test/DynamicResourceTest.java | 17 +- .../main/java/org/ray/api/test/WaitTest.java | 5 + .../ray/api/test/WorkerJvmOptionsTest.java | 31 ++ python/ray/experimental/signal.py | 14 +- python/ray/gcs_utils.py | 71 ++- python/ray/monitor.py | 33 +- python/ray/rllib/agents/a3c/a3c.py | 4 + .../ray/rllib/agents/impala/vtrace_policy.py | 30 +- python/ray/rllib/agents/qmix/qmix_policy.py | 2 + python/ray/rllib/policy/tf_policy.py | 18 +- python/ray/rllib/tests/test_optimizers.py | 10 +- python/ray/services.py | 3 + python/ray/state.py | 230 ++++------ python/ray/tests/cluster_utils.py | 4 +- python/ray/tests/conftest.py | 8 + python/ray/tests/test_actor.py | 2 +- python/ray/tests/test_basic.py | 14 +- python/ray/tests/test_failure.py | 5 +- python/ray/tests/test_signal.py | 33 ++ .../ray/tune/analysis/experiment_analysis.py | 94 +++- python/ray/tune/examples/mnist_pytorch.py | 273 +++++------- python/ray/tune/examples/track_example.py | 4 +- python/ray/tune/examples/tune_mnist_keras.py | 8 +- python/ray/tune/examples/utils.py | 36 +- python/ray/tune/experiment.py | 8 + python/ray/tune/integration/__init__.py | 0 python/ray/tune/integration/keras.py | 34 ++ python/ray/tune/schedulers/__init__.py | 6 +- python/ray/tune/schedulers/async_hyperband.py | 2 + .../tune/tests/test_experiment_analysis.py | 62 +-- python/ray/tune/tests/test_trial_runner.py | 8 + python/ray/tune/trial.py | 25 +- python/ray/tune/tune.py | 11 +- python/ray/utils.py | 8 +- python/ray/worker.py | 40 +- python/setup.py | 1 + src/ray/common/constants.h | 2 + src/ray/gcs/client.cc | 4 - src/ray/gcs/client.h | 6 - src/ray/gcs/client_test.cc | 353 +++++++-------- src/ray/gcs/format/gcs.fbs | 286 +----------- src/ray/gcs/redis_context.h | 15 +- src/ray/gcs/redis_module/ray_redis_module.cc | 209 ++++----- src/ray/gcs/tables.cc | 417 ++++++++---------- src/ray/gcs/tables.h | 136 +++--- src/ray/object_manager/object_directory.cc | 34 +- src/ray/object_manager/object_manager.cc | 49 +- src/ray/object_manager/object_manager.h | 4 +- .../test/object_manager_stress_test.cc | 30 +- .../test/object_manager_test.cc | 36 +- src/ray/protobuf/gcs.proto | 280 ++++++++++++ src/ray/raylet/actor_registration.cc | 51 +-- src/ray/raylet/actor_registration.h | 24 +- src/ray/raylet/lineage_cache.cc | 37 +- src/ray/raylet/lineage_cache.h | 28 +- src/ray/raylet/lineage_cache_test.cc | 28 +- src/ray/raylet/monitor.cc | 15 +- src/ray/raylet/monitor.h | 8 +- src/ray/raylet/node_manager.cc | 262 +++++------ src/ray/raylet/node_manager.h | 31 +- 
src/ray/raylet/raylet.cc | 24 +- src/ray/raylet/raylet.h | 2 + src/ray/raylet/reconstruction_policy.cc | 10 +- src/ray/raylet/reconstruction_policy.h | 2 + src/ray/raylet/reconstruction_policy_test.cc | 42 +- src/ray/raylet/task_dependency_manager.cc | 8 +- src/ray/raylet/task_dependency_manager.h | 2 + .../raylet/task_dependency_manager_test.cc | 2 +- src/ray/raylet/task_spec.cc | 12 +- src/ray/raylet/task_spec.h | 6 +- src/ray/raylet/worker_pool.cc | 100 ++++- src/ray/raylet/worker_pool.h | 56 ++- src/ray/raylet/worker_pool_test.cc | 65 ++- src/ray/rpc/grpc_server.cc | 17 +- src/ray/rpc/grpc_server.h | 77 ++-- src/ray/rpc/node_manager_server.h | 25 +- src/ray/rpc/server_call.h | 26 +- src/ray/rpc/util.h | 13 + 102 files changed, 2330 insertions(+), 2048 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java create mode 100644 python/ray/tune/integration/__init__.py create mode 100644 python/ray/tune/integration/keras.py create mode 100644 src/ray/protobuf/gcs.proto diff --git a/BUILD.bazel b/BUILD.bazel index da36eec0cf57..bc9e6bcd8006 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,22 +1,55 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html -load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load("@build_stack_rules_proto//python:python_proto_compile.bzl", "python_proto_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] -# Node manager gRPC lib. -grpc_proto_library( - name = "node_manager_grpc_lib", +# === Begin of protobuf definitions === + +proto_library( + name = "gcs_proto", + srcs = ["src/ray/protobuf/gcs.proto"], + visibility = ["//java:__subpackages__"], +) + +cc_proto_library( + name = "gcs_cc_proto", + deps = [":gcs_proto"], +) + +python_proto_compile( + name = "gcs_py_proto", + deps = [":gcs_proto"], +) + +proto_library( + name = "node_manager_proto", srcs = ["src/ray/protobuf/node_manager.proto"], ) +cc_proto_library( + name = "node_manager_cc_proto", + deps = ["node_manager_proto"], +) + +# === End of protobuf definitions === + +# Node manager gRPC lib. +cc_grpc_library( + name = "node_manager_cc_grpc", + srcs = [":node_manager_proto"], + grpc_only = True, + deps = [":node_manager_cc_proto"], +) + # Node manager server and client. 
cc_library( - name = "node_manager_rpc_lib", + name = "node_manager_rpc", srcs = glob([ "src/ray/rpc/*.cc", ]), @@ -25,7 +58,7 @@ cc_library( ]), copts = COPTS, deps = [ - ":node_manager_grpc_lib", + ":node_manager_cc_grpc", ":ray_common", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -114,7 +147,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", - ":node_manager_rpc_lib", + ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", @@ -422,9 +455,11 @@ cc_library( "src/ray/gcs/format", ], deps = [ + ":gcs_cc_proto", ":gcs_fbs", ":hiredis", ":node_manager_fbs", + ":node_manager_rpc", ":ray_common", ":ray_util", ":stats_lib", @@ -555,46 +590,6 @@ filegroup( visibility = ["//java:__subpackages__"], ) -flatbuffer_py_library( - name = "python_gcs_fbs", - srcs = [ - ":gcs_fbs_file", - ], - outs = [ - "ActorCheckpointIdData.py", - "ActorState.py", - "ActorTableData.py", - "Arg.py", - "ClassTableData.py", - "ClientTableData.py", - "ConfigTableData.py", - "CustomSerializerData.py", - "DriverTableData.py", - "EntryType.py", - "ErrorTableData.py", - "ErrorType.py", - "FunctionTableData.py", - "GcsEntry.py", - "HeartbeatBatchTableData.py", - "HeartbeatTableData.py", - "Language.py", - "ObjectTableData.py", - "ProfileEvent.py", - "ProfileTableData.py", - "RayResource.py", - "ResourcePair.py", - "SchedulingState.py", - "TablePrefix.py", - "TablePubsub.py", - "TaskInfo.py", - "TaskLeaseData.py", - "TaskReconstructionData.py", - "TaskTableData.py", - "TaskTableTestAndUpdate.py", - ], - out_prefix = "python/ray/core/generated/", -) - flatbuffer_py_library( name = "python_node_manager_fbs", srcs = [ @@ -679,6 +674,7 @@ cc_binary( linkstatic = 1, visibility = ["//java:__subpackages__"], deps = [ + ":gcs_cc_proto", ":ray_common", ], ) @@ -688,7 +684,7 @@ genrule( srcs = [ "python/ray/_raylet.so", "//:python_sources", - "//:python_gcs_fbs", + "//:gcs_py_proto", "//:python_node_manager_fbs", "//:redis-server", "//:redis-cli", @@ -710,11 +706,13 @@ genrule( cp -f $(location //:raylet_monitor) $$WORK_DIR/python/ray/core/src/ray/raylet/ && cp -f $(location @plasma//:plasma_store_server) $$WORK_DIR/python/ray/core/src/plasma/ && cp -f $(location //:raylet) $$WORK_DIR/python/ray/core/src/ray/raylet/ && - for f in $(locations //:python_gcs_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/; done && mkdir -p $$WORK_DIR/python/ray/core/generated/ray/protocol/ && for f in $(locations //:python_node_manager_fbs); do cp -f $$f $$WORK_DIR/python/ray/core/generated/ray/protocol/; done && + for f in $(locations //:gcs_py_proto); do + cp -f $$f $$WORK_DIR/python/ray/core/generated/; + done && echo $$WORK_DIR > $@ """, local = 1, diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 3e1e1838a59a..eda88bece7d2 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -4,6 +4,8 @@ load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_rep load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") +load("@build_stack_rules_proto//java:deps.bzl", "java_proto_compile") +load("@build_stack_rules_proto//python:deps.bzl", "python_proto_compile") def ray_deps_build_all(): @@ -13,4 +15,6 @@ def ray_deps_build_all(): prometheus_cpp_repositories() python_configure(name = "local_config_python") grpc_deps() + java_proto_compile() + python_proto_compile() diff --git 
a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index e6dc21585699..aa322654cf9f 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -105,7 +105,14 @@ def ray_deps_setup(): http_archive( name = "com_github_grpc_grpc", urls = [ - "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + "https://github.com/grpc/grpc/archive/76a381869413834692b8ed305fbe923c0f9c4472.tar.gz", ], - strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + strip_prefix = "grpc-76a381869413834692b8ed305fbe923c0f9c4472", + ) + + http_archive( + name = "build_stack_rules_proto", + urls = ["https://github.com/stackb/rules_proto/archive/b93b544f851fdcd3fc5c3d47aee3b7ca158a8841.tar.gz"], + sha256 = "c62f0b442e82a6152fcd5b1c0b7c4028233a9e314078952b6b04253421d56d61", + strip_prefix = "rules_proto-b93b544f851fdcd3fc5c3d47aee3b7ca158a8841", ) diff --git a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh index 7962b21075c0..f25d32df22a1 100755 --- a/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh +++ b/ci/jenkins_tests/perf_integration_tests/run_perf_integration.sh @@ -9,7 +9,7 @@ pushd "$ROOT_DIR" python -m pip install pytest-benchmark -pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl +pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl python -m pytest --benchmark-autosave --benchmark-min-rounds=10 --benchmark-columns="min, max, mean" $ROOT_DIR/../../../python/ray/tests/perf_integration_tests/test_perf_integration.py pushd $ROOT_DIR/../../../python diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh index 6154fe70d4f6..6b890d7d371c 100755 --- a/ci/jenkins_tests/run_tune_tests.sh +++ b/ci/jenkins_tests/run_tune_tests.sh @@ -78,16 +78,16 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} --smoke-test # Runs only on Python3 -# docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ -# python3 /ray/python/ray/tune/examples/nevergrad_example.py \ -# --smoke-test +$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ + python /ray/python/ray/tune/examples/nevergrad_example.py \ + --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/tune_mnist_keras.py \ --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ - python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test --no-cuda + python /ray/python/ray/tune/examples/mnist_pytorch.py --smoke-test $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \ python /ray/python/ray/tune/examples/mnist_pytorch_trainable.py \ diff --git a/doc/source/conf.py b/doc/source/conf.py index 98fb3e0d02dd..5cf6b01217f9 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -23,20 +23,7 @@ "gym.spaces", "ray._raylet", "ray.core.generated", - "ray.core.generated.ActorCheckpointIdData", - "ray.core.generated.ClientTableData", - "ray.core.generated.DriverTableData", - "ray.core.generated.EntryType", - "ray.core.generated.ErrorTableData", - "ray.core.generated.ErrorType", - "ray.core.generated.GcsEntry", - "ray.core.generated.HeartbeatBatchTableData", - 
"ray.core.generated.HeartbeatTableData", - "ray.core.generated.Language", - "ray.core.generated.ObjectTableData", - "ray.core.generated.ProfileTableData", - "ray.core.generated.TablePrefix", - "ray.core.generated.TablePubsub", + "ray.core.generated.gcs_pb2", "ray.core.generated.ray.protocol.Task", "scipy", "scipy.signal", diff --git a/doc/source/tune-usage.rst b/doc/source/tune-usage.rst index 281ccbd6107e..e8ce405d9457 100644 --- a/doc/source/tune-usage.rst +++ b/doc/source/tune-usage.rst @@ -355,6 +355,12 @@ Then, after you run a experiment, you can visualize your experiment with TensorB $ tensorboard --logdir=~/ray_results/my_experiment +If you are running Ray on a remote multi-user cluster where you do not have sudo access, you can run the following commands to make sure tensorboard is able to write to the tmp directory: + +.. code-block:: bash + + $ export TMPDIR=/tmp/$USER; mkdir -p $TMPDIR; tensorboard --logdir=~/ray_results + .. image:: ray-tune-tensorboard.png To use rllab's VisKit (you may have to install some dependencies), run: diff --git a/docker/base-deps/Dockerfile b/docker/base-deps/Dockerfile index c21430c627a4..db8f28c85f86 100644 --- a/docker/base-deps/Dockerfile +++ b/docker/base-deps/Dockerfile @@ -12,7 +12,7 @@ RUN apt-get update \ && apt-get clean \ && echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \ && wget \ - --quiet 'https://repo.continuum.io/archive/Anaconda2-5.2.0-Linux-x86_64.sh' \ + --quiet 'https://repo.continuum.io/archive/Anaconda3-5.2.0-Linux-x86_64.sh' \ -O /tmp/anaconda.sh \ && /bin/bash /tmp/anaconda.sh -b -p /opt/conda \ && rm /tmp/anaconda.sh \ diff --git a/docker/examples/Dockerfile b/docker/examples/Dockerfile index 6883c5a64a0e..bafcdf35e628 100644 --- a/docker/examples/Dockerfile +++ b/docker/examples/Dockerfile @@ -5,11 +5,14 @@ FROM ray-project/deploy # This updates numpy to 1.14 and mutes errors from other libraries RUN conda install -y numpy RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari] opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install -U h5py # Mutes FutureWarnings RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U pytest-remotedata>=0.3.1 RUN conda install pytorch-cpu torchvision-cpu -c pytorch diff --git a/docker/stress_test/Dockerfile b/docker/stress_test/Dockerfile index 1d174ed72f92..376fe5340fd9 100644 --- a/docker/stress_test/Dockerfile +++ b/docker/stress_test/Dockerfile @@ -4,7 +4,7 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. -RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 RUN mkdir -p /root/.ssh/ # We port the source code in so that we run the most up-to-date stress tests. diff --git a/docker/tune_test/Dockerfile b/docker/tune_test/Dockerfile index 6e098d5218f6..77cf390493d6 100644 --- a/docker/tune_test/Dockerfile +++ b/docker/tune_test/Dockerfile @@ -4,15 +4,20 @@ FROM ray-project/base-deps # We install ray and boto3 to enable the ray autoscaler as # a test runner. 
-RUN pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev1-cp27-cp27mu-manylinux1_x86_64.whl boto3 +RUN conda install -y -c anaconda wrapt=1.11.1 +RUN conda install -y -c anaconda numpy=1.16.4 +RUN pip install -U https://ray-wheels.s3-us-west-2.amazonaws.com/latest/ray-0.8.0.dev1-cp36-cp36m-manylinux1_x86_64.whl boto3 # We install this after the latest wheels -- this should not override the latest wheels. RUN apt-get install -y zlib1g-dev +# The following is needed to support TensorFlow 1.14 +RUN conda remove -y --force wrapt RUN pip install gym[atari]==0.10.11 opencv-python-headless tensorflow lz4 keras pytest-timeout smart_open RUN pip install --upgrade bayesian-optimization RUN pip install --upgrade git+git://github.com/hyperopt/hyperopt.git RUN pip install --upgrade sigopt -# RUN pip install --upgrade nevergrad +RUN pip install --upgrade nevergrad RUN pip install --upgrade scikit-optimize +RUN pip install -U "pytest-remotedata>=0.3.1" RUN conda install pytorch-cpu torchvision-cpu -c pytorch # RUN mkdir -p /root/.ssh/ @@ -20,6 +25,6 @@ RUN conda install pytorch-cpu torchvision-cpu -c pytorch # We port the source code in so that we run the most up-to-date stress tests. ADD ray.tar /ray ADD git-rev /ray/git-rev -RUN python /ray/python/ray/rllib/setup-rllib-dev.py --yes +RUN python /ray/python/ray/setup-dev.py --yes WORKDIR /ray diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 80ccabccfc12..4960434af180 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -1,4 +1,5 @@ load("//bazel:ray.bzl", "flatbuffer_java_library", "define_java_module") +load("@build_stack_rules_proto//java:java_proto_compile.bzl", "java_proto_compile") exports_files([ "testng.xml", @@ -50,6 +51,7 @@ define_java_module( name = "runtime", additional_srcs = [ ":generate_java_gcs_fbs", + ":gcs_java_proto", ], additional_resources = [ ":java_native_deps", @@ -68,6 +70,7 @@ define_java_module( "@plasma//:org_apache_arrow_arrow_plasma", "@maven//:com_github_davidmoten_flatbuffers_java", "@maven//:com_google_guava_guava", + "@maven//:com_google_protobuf_protobuf_java", "@maven//:com_typesafe_config", "@maven//:commons_io_commons_io", "@maven//:de_ruedigermoeller_fst", @@ -148,38 +151,16 @@ java_binary( ], ) +java_proto_compile( + name = "gcs_java_proto", + deps = ["@//:gcs_proto"], +) + flatbuffers_generated_files = [ - "ActorCheckpointData.java", - "ActorCheckpointIdData.java", - "ActorState.java", - "ActorTableData.java", "Arg.java", - "ClassTableData.java", - "ClientTableData.java", - "ConfigTableData.java", - "CustomSerializerData.java", - "DriverTableData.java", - "EntryType.java", - "ErrorTableData.java", - "ErrorType.java", - "FunctionTableData.java", - "GcsEntry.java", - "HeartbeatBatchTableData.java", - "HeartbeatTableData.java", "Language.java", - "ObjectTableData.java", - "ProfileEvent.java", - "ProfileTableData.java", - "RayResource.java", - "ResourcePair.java", - "SchedulingState.java", - "TablePrefix.java", - "TablePubsub.java", "TaskInfo.java", - "TaskLeaseData.java", - "TaskReconstructionData.java", - "TaskTableData.java", - "TaskTableTestAndUpdate.java", + "ResourcePair.java", ] flatbuffer_java_library( @@ -198,7 +179,7 @@ genrule( cmd = """ for f in $(locations //java:java_gcs_fbs); do chmod +w $$f - cp -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated + mv -f $$f $(@D)/runtime/src/main/java/org/ray/runtime/generated done python $$(pwd)/java/modify_generated_java_flatbuffers_files.py $(@D)/..
""", @@ -221,8 +202,10 @@ filegroup( genrule( name = "gen_maven_deps", srcs = [ - ":java_native_deps", + ":gcs_java_proto", ":generate_java_gcs_fbs", + ":java_native_deps", + ":copy_pom_file", "@plasma//:org_apache_arrow_arrow_plasma", ], outs = ["gen_maven_deps.out"], @@ -237,10 +220,15 @@ genrule( chmod +w $$f cp $$f $$NATIVE_DEPS_DIR done - # Copy flatbuffers-generated files + # Copy protobuf-generated files. GENERATED_DIR=$$WORK_DIR/java/runtime/src/main/java/org/ray/runtime/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + for f in $(locations //java:gcs_java_proto); do + unzip $$f + mv org/ray/runtime/generated/* $$GENERATED_DIR + done + # Copy flatbuffers-generated files for f in $(locations //java:generate_java_gcs_fbs); do cp $$f $$GENERATED_DIR done @@ -250,6 +238,7 @@ genrule( echo $$(date) > $@ """, local = 1, + tags = ["no-cache"], ) genrule( diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java index e08955d5a93e..c13f0436f94d 100644 --- a/java/api/src/main/java/org/ray/api/id/BaseId.java +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -48,7 +48,7 @@ public boolean isNil() { break; } } - isNilCache = localIsNil; + isNilCache = localIsNil; } return isNilCache; } diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index d1e92f7bb9e9..2e14ca8584dd 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -13,9 +13,14 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxReconstructions; - private ActorCreationOptions(Map resources, int maxReconstructions) { + public final String jvmOptions; + + private ActorCreationOptions(Map resources, + int maxReconstructions, + String jvmOptions) { super(resources); this.maxReconstructions = maxReconstructions; + this.jvmOptions = jvmOptions; } /** @@ -25,6 +30,7 @@ public static class Builder { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; + private String jvmOptions = ""; public Builder setResources(Map resources) { this.resources = resources; @@ -36,8 +42,13 @@ public Builder setMaxReconstructions(int maxReconstructions) { return this; } + public Builder setJvmOptions(String jvmOptions) { + this.jvmOptions = jvmOptions; + return this; + } + public ActorCreationOptions createActorCreationOptions() { - return new ActorCreationOptions(resources, maxReconstructions); + return new ActorCreationOptions(resources, maxReconstructions, jvmOptions); } } diff --git a/java/dependencies.bzl b/java/dependencies.bzl index 7c716166d399..ef667137562b 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -6,6 +6,7 @@ def gen_java_deps(): "com.beust:jcommander:1.72", "com.github.davidmoten:flatbuffers-java:1.9.0.1", "com.google.guava:guava:27.0.1-jre", + "com.google.protobuf:protobuf-java:3.8.0", "com.puppycrawl.tools:checkstyle:8.15", "com.sun.xml.bind:jaxb-core:2.3.0", "com.sun.xml.bind:jaxb-impl:2.3.0", diff --git a/java/modify_generated_java_flatbuffers_files.py b/java/modify_generated_java_flatbuffers_files.py index c1b723f25f8d..5bf62e56d7e4 100644 --- a/java/modify_generated_java_flatbuffers_files.py +++ b/java/modify_generated_java_flatbuffers_files.py @@ -4,7 +4,6 @@ import os import sys - """ This script is used for modifying the generated java flatbuffer files for the reason: The 
package declaration in Java is different @@ -21,19 +20,18 @@ PACKAGE_DECLARATION = "package org.ray.runtime.generated;" -def add_new_line(file, line_num, text): +def add_package(file): with open(file, "r") as file_handler: lines = file_handler.readlines() - if (line_num <= 0) or (line_num > len(lines) + 1): - return False - lines.insert(line_num - 1, text + os.linesep) + if "FlatBuffers" not in lines[0]: + return + + lines.insert(1, PACKAGE_DECLARATION + os.linesep) with open(file, "w") as file_handler: for line in lines: file_handler.write(line) - return True - def add_package_declarations(generated_root_path): file_names = os.listdir(generated_root_path) @@ -41,15 +39,11 @@ def add_package_declarations(generated_root_path): if not file_name.endswith(".java"): continue full_name = os.path.join(generated_root_path, file_name) - success = add_new_line(full_name, 2, PACKAGE_DECLARATION) - if not success: - raise RuntimeError("Failed to add package declarations, " - "file name is %s" % full_name) + add_package(full_name) if __name__ == "__main__": ray_home = sys.argv[1] root_path = os.path.join( - ray_home, - "java/runtime/src/main/java/org/ray/runtime/generated") + ray_home, "java/runtime/src/main/java/org/ray/runtime/generated") add_package_declarations(root_path) diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index c75e2eeef13f..e13dd95f927f 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -41,6 +41,11 @@ guava 27.0.1-jre + + com.google.protobuf + protobuf-java + 3.8.0 + com.typesafe config diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index fbd03bf10483..26a8d6e541ba 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -35,6 +35,7 @@ import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; import org.ray.runtime.util.IdUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -363,8 +364,13 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes } int maxActorReconstruction = 0; + List dynamicWorkerOptions = ImmutableList.of(); if (taskOptions instanceof ActorCreationOptions) { maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions; + String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions; + if (!StringUtil.isNullOrEmpty(jvmOptions)) { + dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions); + } } TaskLanguage language; @@ -393,7 +399,8 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes numReturns, resources, language, - functionDescriptor + functionDescriptor, + dynamicWorkerOptions ); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 431b48ded58c..17c248ed0a57 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -1,7 +1,7 @@ package org.ray.runtime.gcs; import com.google.common.base.Preconditions; -import java.nio.ByteBuffer; +import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -13,10 +13,10 @@ import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import 
org.ray.api.runtimecontext.NodeInfo; -import org.ray.runtime.generated.ActorCheckpointIdData; -import org.ray.runtime.generated.ClientTableData; -import org.ray.runtime.generated.EntryType; -import org.ray.runtime.generated.TablePrefix; +import org.ray.runtime.generated.Gcs.ActorCheckpointIdData; +import org.ray.runtime.generated.Gcs.ClientTableData; +import org.ray.runtime.generated.Gcs.ClientTableData.EntryType; +import org.ray.runtime.generated.Gcs.TablePrefix; import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -51,7 +51,7 @@ public GcsClient(String redisAddress, String redisPassword) { } public List getAllNodeInfo() { - final String prefix = TablePrefix.name(TablePrefix.CLIENT); + final String prefix = TablePrefix.CLIENT.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes()); List results = primary.lrange(key, 0, -1); @@ -63,36 +63,42 @@ public List getAllNodeInfo() { Map clients = new HashMap<>(); for (byte[] result : results) { Preconditions.checkNotNull(result); - ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result)); - final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer()); + ClientTableData data = null; + try { + data = ClientTableData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + final UniqueId clientId = UniqueId + .fromByteBuffer(data.getClientId().asReadOnlyByteBuffer()); - if (data.entryType() == EntryType.INSERTION) { + if (data.getEntryType() == EntryType.INSERTION) { //Code path of node insertion. Map resources = new HashMap<>(); // Compute resources. Preconditions.checkState( - data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength()); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + data.getResourcesTotalLabelCount() == data.getResourcesTotalCapacityCount()); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } NodeInfo nodeInfo = new NodeInfo( - clientId, data.nodeManagerAddress(), true, resources); + clientId, data.getNodeManagerAddress(), true, resources); clients.put(clientId, nodeInfo); - } else if (data.entryType() == EntryType.RES_CREATEUPDATE) { + } else if (data.getEntryType() == EntryType.RES_CREATEUPDATE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.put(data.getResourcesTotalLabel(i), data.getResourcesTotalCapacity(i)); } - } else if (data.entryType() == EntryType.RES_DELETE) { + } else if (data.getEntryType() == EntryType.RES_DELETE) { Preconditions.checkState(clients.containsKey(clientId)); NodeInfo nodeInfo = clients.get(clientId); - for (int i = 0; i < data.resourcesTotalLabelLength(); i++) { - nodeInfo.resources.remove(data.resourcesTotalLabel(i)); + for (int i = 0; i < data.getResourcesTotalLabelCount(); i++) { + nodeInfo.resources.remove(data.getResourcesTotalLabel(i)); } } else { // Code path of node deletion. 
- Preconditions.checkState(data.entryType() == EntryType.DELETION); + Preconditions.checkState(data.getEntryType() == EntryType.DELETION); NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress, false, clients.get(clientId).resources); clients.put(clientId, nodeInfo); @@ -107,7 +113,7 @@ public List getAllNodeInfo() { */ public boolean actorExists(UniqueId actorId) { byte[] key = ArrayUtils.addAll( - TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes()); + TablePrefix.ACTOR.toString().getBytes(), actorId.getBytes()); return primary.exists(key); } @@ -115,7 +121,7 @@ public boolean actorExists(UniqueId actorId) { * Query whether the raylet task exists in Gcs. */ public boolean rayletTaskExistsInGcs(TaskId taskId) { - byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), + byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); return client.exists(key); @@ -126,19 +132,26 @@ public boolean rayletTaskExistsInGcs(TaskId taskId) { */ public List getCheckpointsForActor(UniqueId actorId) { List checkpoints = new ArrayList<>(); - final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID); + final String prefix = TablePrefix.ACTOR_CHECKPOINT_ID.toString(); final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes()); RedisClient client = getShardClient(actorId); byte[] result = client.get(key); if (result != null) { - ActorCheckpointIdData data = - ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( - data.checkpointIdsAsByteBuffer()); + ActorCheckpointIdData data = null; + try { + data = ActorCheckpointIdData.parseFrom(result); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("Received invalid protobuf data from GCS."); + } + UniqueId[] checkpointIds = new UniqueId[data.getCheckpointIdsCount()]; + for (int i = 0; i < checkpointIds.length; i++) { + checkpointIds[i] = UniqueId + .fromByteBuffer(data.getCheckpointIds(i).asReadOnlyByteBuffer()); + } for (int i = 0; i < checkpointIds.length; i++) { - checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i))); + checkpoints.add(new Checkpoint(checkpointIds[i], data.getTimestamps(i))); } } checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index f9e310249a35..1a7e4701c22b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -16,7 +16,7 @@ import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; -import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.generated.Gcs.ErrorType; import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; import org.slf4j.Logger; @@ -29,12 +29,12 @@ public class ObjectStoreProxy { private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class); - private static final byte[] WORKER_EXCEPTION_META = String.valueOf(ErrorType.WORKER_DIED) - .getBytes(); - private static final byte[] ACTOR_EXCEPTION_META = String.valueOf(ErrorType.ACTOR_DIED) - .getBytes(); + private static final byte[] WORKER_EXCEPTION_META = 
String + .valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes(); + private static final byte[] ACTOR_EXCEPTION_META = String + .valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes(); private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String - .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE).getBytes(); + .valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes(); private static final byte[] RAW_TYPE_META = "RAW".getBytes(); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 01b9e4675016..c369e6f2cab8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -190,9 +190,16 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor( info.functionDescriptor(0), info.functionDescriptor(1), info.functionDescriptor(2) ); + + // Deserialize dynamic worker options. + List dynamicWorkerOptions = new ArrayList<>(); + for (int i = 0; i < info.dynamicWorkerOptionsLength(); ++i) { + dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); + } + return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, - args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor); + args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { @@ -275,6 +282,12 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { functionDescriptorOffset = fbb.createVectorOfTables(functionDescriptorOffsets); } + int [] dynamicWorkerOptionsOffsets = new int[task.dynamicWorkerOptions.size()]; + for (int index = 0; index < task.dynamicWorkerOptions.size(); ++index) { + dynamicWorkerOptionsOffsets[index] = fbb.createString(task.dynamicWorkerOptions.get(index)); + } + int dynamicWorkerOptionsOffset = fbb.createVectorOfTables(dynamicWorkerOptionsOffsets); + int root = TaskInfo.createTaskInfo( fbb, driverIdOffset, @@ -293,7 +306,8 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { requiredResourcesOffset, requiredPlacementResourcesOffset, language, - functionDescriptorOffset); + functionDescriptorOffset, + dynamicWorkerOptionsOffset); fbb.finish(root); ByteBuffer buffer = fbb.dataBuffer(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 15240e43e234..773499fcf5cf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -319,6 +319,9 @@ private String buildWorkerCommandRaylet() { cmd.addAll(rayConfig.jvmParameters); + // jvm options + cmd.add("RAY_WORKER_OPTION_0"); + // Main class cmd.add(WORKER_CLASS); String command = Joiner.on(" ").join(cmd); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 3473a9bdb3cc..060ca6fff4c3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -63,6 +63,8 @@ public class TaskSpec { // Language of this task. 
public final TaskLanguage language; + public final List dynamicWorkerOptions; + // Descriptor of the remote function. // Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language // is Python, the type is PyFunctionDescriptor. @@ -93,7 +95,8 @@ public TaskSpec( int numReturns, Map resources, TaskLanguage language, - FunctionDescriptor functionDescriptor) { + FunctionDescriptor functionDescriptor, + List dynamicWorkerOptions) { this.driverId = driverId; this.taskId = taskId; this.parentTaskId = parentTaskId; @@ -106,6 +109,8 @@ public TaskSpec( this.newActorHandles = newActorHandles; this.args = args; this.numReturns = numReturns; + this.dynamicWorkerOptions = dynamicWorkerOptions; + returnIds = new ObjectId[numReturns]; for (int i = 0; i < numReturns; ++i) { returnIds[i] = IdUtil.computeReturnId(taskId, i + 1); @@ -157,6 +162,7 @@ public String toString() { ", resources=" + resources + ", language=" + language + ", functionDescriptor=" + functionDescriptor + + ", dynamicWorkerOptions=" + dynamicWorkerOptions + ", executionDependencies=" + executionDependencies + '}'; } diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 9b3bbf233856..3636c93e4909 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,8 +1,10 @@ package org.ray.api; import java.util.function.Supplier; +import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RunMode; +import org.testng.Assert; import org.testng.SkipException; public class TestUtils { @@ -42,4 +44,17 @@ public static boolean waitForCondition(Supplier condition, int timeoutM } return false; } + + @RayRemote + private static String hi() { + return "hi"; + } + + /** + * Warm up the cluster. + */ + public static void warmUpCluster() { + RayObject obj = Ray.call(TestUtils::hi); + Assert.assertEquals(obj.get(), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java index 79b3eba0ed13..71766c6cf2bf 100644 --- a/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java +++ b/java/test/src/main/java/org/ray/api/test/DynamicResourceTest.java @@ -23,6 +23,10 @@ public static String sayHi() { @Test public void testSetResource() { TestUtils.skipTestUnderSingleProcess(); + + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + CallOptions op1 = new CallOptions.Builder().setResources(ImmutableMap.of("A", 10.0)).createCallOptions(); RayObject obj = Ray.call(DynamicResourceTest::sayHi, op1); @@ -30,16 +34,21 @@ public void testSetResource() { Assert.assertEquals(result.getReady().size(), 0); Ray.setResource("A", 10.0); + boolean resourceReady = TestUtils.waitForCondition(() -> { + List nodes = Ray.getRuntimeContext().getAllNodeInfo(); + if (nodes.size() != 1) { + return false; + } + return (0 == Double.compare(10.0, nodes.get(0).resources.get("A"))); + }, 2000); - // Assert node info. - List nodes = Ray.getRuntimeContext().getAllNodeInfo(); - Assert.assertEquals(nodes.size(), 1); - Assert.assertEquals(nodes.get(0).resources.get("A"), 10.0); + Assert.assertTrue(resourceReady); // Assert ray call result. 
result = Ray.wait(ImmutableList.of(obj), 1, 1000); Assert.assertEquals(result.getReady().size(), 1); Assert.assertEquals(Ray.get(obj.getId()), "hi"); + } } diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java index e82b99d364ba..bccc50a50bdf 100644 --- a/java/test/src/main/java/org/ray/api/test/WaitTest.java +++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java @@ -5,6 +5,7 @@ import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; +import org.ray.api.TestUtils; import org.ray.api.WaitResult; import org.ray.api.annotation.RayRemote; import org.testng.Assert; @@ -28,6 +29,9 @@ private static String delayedHi() { } private static void testWait() { + // Call a task in advance to warm up the cluster to avoid being too slow to start workers. + TestUtils.warmUpCluster(); + RayObject obj1 = Ray.call(WaitTest::hi); RayObject obj2 = Ray.call(WaitTest::delayedHi); @@ -71,4 +75,5 @@ public void testWaitForEmpty() { Assert.assertTrue(true); } } + } diff --git a/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java new file mode 100644 index 000000000000..90a2817a8366 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/WorkerJvmOptionsTest.java @@ -0,0 +1,31 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.TestUtils; +import org.ray.api.annotation.RayRemote; +import org.ray.api.options.ActorCreationOptions; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class WorkerJvmOptionsTest extends BaseTest { + + @RayRemote + public static class Echo { + String getOptions() { + return System.getProperty("test.suffix"); + } + } + + @Test + public void testJvmOptions() { + TestUtils.skipTestUnderSingleProcess(); + ActorCreationOptions options = new ActorCreationOptions.Builder() + .setJvmOptions("-Dtest.suffix=suffix") + .createActorCreationOptions(); + RayActor actor = Ray.createActor(Echo::new, options); + RayObject obj = Ray.call(Echo::getOptions, actor); + Assert.assertEquals(obj.get(), "suffix"); + } +} diff --git a/python/ray/experimental/signal.py b/python/ray/experimental/signal.py index f2a0d81ca343..25ec072d3fc7 100644 --- a/python/ray/experimental/signal.py +++ b/python/ray/experimental/signal.py @@ -2,6 +2,8 @@ from __future__ import division from __future__ import print_function +import logging + from collections import defaultdict import ray @@ -13,6 +15,8 @@ # in node_manager.cc ACTOR_DIED_STR = "ACTOR_DIED_SIGNAL" +logger = logging.getLogger(__name__) + class Signal(object): """Base class for Ray signals.""" @@ -125,10 +129,16 @@ def receive(sources, timeout=None): for s in sources: task_id_to_sources[_get_task_id(s).hex()].append(s) + if timeout < 1e-3: + logger.warning("Timeout too small. Using 1ms minimum") + timeout = 1e-3 + + timeout_ms = int(1000 * timeout) + # Construct the redis query. query = "XREAD BLOCK " - # Multiply by 1000x since timeout is in sec and redis expects ms. - query += str(1000 * timeout) + # redis expects ms. 
+ query += str(timeout_ms) query += " STREAMS " query += " ".join([task_id for task_id in task_id_to_sources]) query += " " diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index cadd197ec73f..ba72e96f41db 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -2,38 +2,39 @@ from __future__ import division from __future__ import print_function -import flatbuffers -import ray.core.generated.ErrorTableData - -from ray.core.generated.ActorCheckpointIdData import ActorCheckpointIdData -from ray.core.generated.ClientTableData import ClientTableData -from ray.core.generated.DriverTableData import DriverTableData -from ray.core.generated.ErrorTableData import ErrorTableData -from ray.core.generated.GcsEntry import GcsEntry -from ray.core.generated.HeartbeatBatchTableData import HeartbeatBatchTableData -from ray.core.generated.HeartbeatTableData import HeartbeatTableData -from ray.core.generated.Language import Language -from ray.core.generated.ObjectTableData import ObjectTableData -from ray.core.generated.ProfileTableData import ProfileTableData -from ray.core.generated.TablePrefix import TablePrefix -from ray.core.generated.TablePubsub import TablePubsub - from ray.core.generated.ray.protocol.Task import Task +from ray.core.generated.gcs_pb2 import ( + ActorCheckpointIdData, + ClientTableData, + DriverTableData, + ErrorTableData, + ErrorType, + GcsEntry, + HeartbeatBatchTableData, + HeartbeatTableData, + ObjectTableData, + ProfileTableData, + TablePrefix, + TablePubsub, + TaskTableData, +) + __all__ = [ "ActorCheckpointIdData", "ClientTableData", "DriverTableData", "ErrorTableData", + "ErrorType", "GcsEntry", "HeartbeatBatchTableData", "HeartbeatTableData", - "Language", "ObjectTableData", "ProfileTableData", "TablePrefix", "TablePubsub", "Task", + "TaskTableData", "construct_error_message", ] @@ -42,13 +43,16 @@ REPORTER_CHANNEL = "RAY_REPORTER" # xray heartbeats -XRAY_HEARTBEAT_CHANNEL = str(TablePubsub.HEARTBEAT).encode("ascii") -XRAY_HEARTBEAT_BATCH_CHANNEL = str(TablePubsub.HEARTBEAT_BATCH).encode("ascii") +XRAY_HEARTBEAT_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_PUBSUB")).encode("ascii") +XRAY_HEARTBEAT_BATCH_CHANNEL = str( + TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") # xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.DRIVER).encode("ascii") +XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") -# These prefixes must be kept up-to-date with the TablePrefix enum in gcs.fbs. +# These prefixes must be kept up-to-date with the TablePrefix enum in +# gcs.proto. # TODO(rkn): We should use scoped enums, in which case we should be able to # just access the flatbuffer generated values. TablePrefix_RAYLET_TASK_string = "RAYLET_TASK" @@ -70,22 +74,9 @@ def construct_error_message(driver_id, error_type, message, timestamp): Returns: The serialized object. 
""" - builder = flatbuffers.Builder(0) - driver_offset = builder.CreateString(driver_id.binary()) - error_type_offset = builder.CreateString(error_type) - message_offset = builder.CreateString(message) - - ray.core.generated.ErrorTableData.ErrorTableDataStart(builder) - ray.core.generated.ErrorTableData.ErrorTableDataAddDriverId( - builder, driver_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddType( - builder, error_type_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddErrorMessage( - builder, message_offset) - ray.core.generated.ErrorTableData.ErrorTableDataAddTimestamp( - builder, timestamp) - error_data_offset = ray.core.generated.ErrorTableData.ErrorTableDataEnd( - builder) - builder.Finish(error_data_offset) - - return bytes(builder.Output()) + data = ErrorTableData() + data.driver_id = driver_id.binary() + data.type = error_type + data.error_message = message + data.timestamp = timestamp + return data.SerializeToString() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index c9e0424b3eb8..35597ef231e3 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -101,28 +101,26 @@ def subscribe(self, channel): def xray_heartbeat_batch_handler(self, unused_channel, data): """Handle an xray heartbeat batch message from Redis.""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - heartbeat_data = gcs_entries.Entries(0) + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] - message = (ray.gcs_utils.HeartbeatBatchTableData. - GetRootAsHeartbeatBatchTableData(heartbeat_data, 0)) + message = ray.gcs_utils.HeartbeatBatchTableData.FromString( + heartbeat_data) - for j in range(message.BatchLength()): - heartbeat_message = message.Batch(j) - - num_resources = heartbeat_message.ResourcesTotalLabelLength() + for heartbeat_message in message.batch: + num_resources = len(heartbeat_message.resources_available_label) static_resources = {} dynamic_resources = {} for i in range(num_resources): - dyn = heartbeat_message.ResourcesAvailableLabel(i) - static = heartbeat_message.ResourcesTotalLabel(i) + dyn = heartbeat_message.resources_available_label[i] + static = heartbeat_message.resources_total_label[i] dynamic_resources[dyn] = ( - heartbeat_message.ResourcesAvailableCapacity(i)) + heartbeat_message.resources_available_capacity[i]) static_resources[static] = ( - heartbeat_message.ResourcesTotalCapacity(i)) + heartbeat_message.resources_total_capacity[i]) # Update the load metrics for this raylet. - client_id = ray.utils.binary_to_hex(heartbeat_message.ClientId()) + client_id = ray.utils.binary_to_hex(heartbeat_message.client_id) ip = self.raylet_id_to_ip_map.get(client_id) if ip: self.load_metrics.update(ip, static_resources, @@ -207,11 +205,10 @@ def xray_driver_removed_handler(self, unused_channel, data): unused_channel: The message channel. data: The message data. 
""" - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(data, 0) - driver_data = gcs_entries.Entries(0) - message = ray.gcs_utils.DriverTableData.GetRootAsDriverTableData( - driver_data, 0) - driver_id = message.DriverId() + gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) + driver_data = gcs_entries.entries[0] + message = ray.gcs_utils.DriverTableData.FromString(driver_data) + driver_id = message.driver_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( binary_to_hex(driver_id))) diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index c269df2fc6e5..d320b9636881 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -48,6 +48,10 @@ def get_policy_class(config): def validate_config(config): if config["entropy_coeff"] < 0: raise DeprecationWarning("entropy_coeff must be >= 0") + if config["sample_async"] and config["use_pytorch"]: + raise ValueError( + "The sample_async option is not supported with use_pytorch: " + "Multithreading can be lead to crashes if used with pytorch.") def make_async_optimizer(workers, config): diff --git a/python/ray/rllib/agents/impala/vtrace_policy.py b/python/ray/rllib/agents/impala/vtrace_policy.py index 7fd137bae08b..8e9b0e8691e6 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy.py +++ b/python/ray/rllib/agents/impala/vtrace_policy.py @@ -7,19 +7,18 @@ from __future__ import print_function import gym -import ray import numpy as np +import ray from ray.rllib.agents.impala import vtrace from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_policy import TFPolicy, \ - LearningRateSchedule from ray.rllib.models.action_dist import MultiCategorical from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_policy import TFPolicy, LearningRateSchedule +from ray.rllib.utils import try_import_tf from ray.rllib.utils.annotations import override from ray.rllib.utils.explained_variance import explained_variance -from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -96,15 +95,22 @@ def __init__(self, # The baseline loss delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask) - self.vf_loss = tf.math.multiply(0.5, tf.reduce_sum(tf.square(delta)), name='vf_loss') + self.vf_loss = tf.math.multiply( + 0.5, tf.reduce_sum( + tf.square(delta)), name='vf_loss') # The entropy loss self.entropy = tf.reduce_sum( tf.boolean_mask(actions_entropy, valid_mask), name='entropy_loss') # The summed weighted loss - self.total_loss = tf.math.add(self.pi_loss, self.vf_loss * vf_loss_coeff - self.entropy * entropy_coeff, - name='total_loss') + self.total_loss = tf.math.add( + self.pi_loss, + self.vf_loss * + vf_loss_coeff - + self.entropy * + entropy_coeff, + name='total_loss') class VTracePostprocessing(object): @@ -274,8 +280,10 @@ def make_time_major(tensor, drop_last=False): with tf.name_scope('kl_divergence'): # KL divergence between worker and learner logits for debugging - model_dist = MultiCategorical(self.model.outputs, output_hidden_shape) - behaviour_dist = MultiCategorical(behaviour_logits, output_hidden_shape) + model_dist = MultiCategorical( + self.model.outputs, output_hidden_shape) + behaviour_dist = MultiCategorical( + behaviour_logits, output_hidden_shape) kls = model_dist.kl(behaviour_dist) if len(kls) > 
1: diff --git a/python/ray/rllib/agents/qmix/qmix_policy.py b/python/ray/rllib/agents/qmix/qmix_policy.py index 26ec387de004..99045899684b 100644 --- a/python/ray/rllib/agents/qmix/qmix_policy.py +++ b/python/ray/rllib/agents/qmix/qmix_policy.py @@ -204,6 +204,8 @@ def __init__(self, obs_space, action_space, config): # Setup optimizer self.params = list(self.model.parameters()) + if self.mixer: + self.params += list(self.mixer.parameters()) self.loss = QMixLoss(self.model, self.target_model, self.mixer, self.target_mixer, self.n_agents, self.n_actions, self.config["double_q"], self.config["gamma"]) diff --git a/python/ray/rllib/policy/tf_policy.py b/python/ray/rllib/policy/tf_policy.py index abc5cf546184..ddee7de9745b 100644 --- a/python/ray/rllib/policy/tf_policy.py +++ b/python/ray/rllib/policy/tf_policy.py @@ -2,21 +2,21 @@ from __future__ import division from __future__ import print_function -import os import errno import logging -import numpy as np +import os +import numpy as np import ray import ray.experimental.tf_utils +from ray.rllib.models.lstm import chop_into_sequences from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.lstm import chop_into_sequences +from ray.rllib.utils import try_import_tf from ray.rllib.utils.annotations import override, DeveloperAPI from ray.rllib.utils.debug import log_once, summarize from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule, LinearSchedule from ray.rllib.utils.tf_run_builder import TFRunBuilder -from ray.rllib.utils import try_import_tf tf = try_import_tf() logger = logging.getLogger(__name__) @@ -416,7 +416,7 @@ def _build_compute_actions(self, if len(self._state_inputs) != len(state_batches): raise ValueError( "Must pass in RNN state batches for placeholders {}, got {}". - format(self._state_inputs, state_batches)) + format(self._state_inputs, state_batches)) builder.add_feed_dict(self.extra_compute_action_feed_dict()) builder.add_feed_dict({self._obs_input: obs_batch}) if state_batches: @@ -443,7 +443,7 @@ def _build_apply_gradients(self, builder, gradients): if len(gradients) != len(self._grads): raise ValueError( "Unexpected number of gradients to apply, got {} for {}". 
- format(gradients, self._grads)) + format(gradients, self._grads)) builder.add_feed_dict({self._is_training: True}) builder.add_feed_dict(dict(zip(self._grads, gradients))) fetches = builder.add_fetches([self._apply_op]) @@ -473,9 +473,9 @@ def _get_loss_inputs_dict(self, batch): feed_dict = {} if self._batch_divisibility_req > 1: meets_divisibility_reqs = ( - len(batch[SampleBatch.CUR_OBS]) % - self._batch_divisibility_req == 0 - and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent + len(batch[SampleBatch.CUR_OBS]) % + self._batch_divisibility_req == 0 + and max(batch[SampleBatch.AGENT_INDEX]) == 0) # not multiagent else: meets_divisibility_reqs = True diff --git a/python/ray/rllib/tests/test_optimizers.py b/python/ray/rllib/tests/test_optimizers.py index a87a295ccf1d..d27270c20965 100644 --- a/python/ray/rllib/tests/test_optimizers.py +++ b/python/ray/rllib/tests/test_optimizers.py @@ -125,14 +125,14 @@ def testSimple(self): def testMultiGPU(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) - optimizer = AsyncSamplesOptimizer(workers, num_gpus=2, _fake_gpus=True) + optimizer = AsyncSamplesOptimizer(workers, num_gpus=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiGPUParallelLoad(self): local, remotes = self._make_evs() workers = WorkerSet._from_existing(local, remotes) optimizer = AsyncSamplesOptimizer( - workers, num_gpus=2, num_data_loader_buffers=2, _fake_gpus=True) + workers, num_gpus=1, num_data_loader_buffers=1, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) def testMultiplePasses(self): @@ -211,21 +211,21 @@ def testRejectBadConfigs(self): num_data_loader_buffers=2, minibatch_buffer_size=4)) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=50, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=25, _fake_gpus=True) self._wait_for(optimizer, 1000, 1000) optimizer = AsyncSamplesOptimizer( workers, - num_gpus=2, + num_gpus=1, train_batch_size=100, sample_batch_size=74, _fake_gpus=True) diff --git a/python/ray/services.py b/python/ray/services.py index 66d4069820d0..ff4111b2c258 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1245,6 +1245,7 @@ def build_java_worker_command( assert java_worker_options is not None command = "java " + if redis_address is not None: command += "-Dray.redis.address={} ".format(redis_address) @@ -1265,6 +1266,8 @@ def build_java_worker_command( # Put `java_worker_options` in the last, so it can overwrite the # above options. command += java_worker_options + " " + + command += "RAY_WORKER_OPTION_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/python/ray/state.py b/python/ray/state.py index 14ba49987ec4..35f97cd65f5e 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -10,11 +10,11 @@ import ray from ray.function_manager import FunctionDescriptor -import ray.gcs_utils -from ray.ray_constants import ID_SIZE -from ray import services -from ray.core.generated.EntryType import EntryType +from ray import ( + gcs_utils, + services, +) from ray.utils import (decode, binary_to_object_id, binary_to_hex, hex_to_binary) @@ -31,9 +31,9 @@ def _parse_client_table(redis_client): A list of information about the nodes in the cluster. 
""" NIL_CLIENT_ID = ray.ObjectID.nil().binary() - message = redis_client.execute_command("RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.CLIENT, - "", NIL_CLIENT_ID) + message = redis_client.execute_command( + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("CLIENT"), "", + NIL_CLIENT_ID) # Handle the case where no clients are returned. This should only # occur potentially immediately after the cluster is started. @@ -41,36 +41,31 @@ def _parse_client_table(redis_client): return [] node_info = {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) ordered_client_ids = [] # Since GCS entries are append-only, we override so that # only the latest entries are kept. - for i in range(gcs_entry.EntriesLength()): - client = (ray.gcs_utils.ClientTableData.GetRootAsClientTableData( - gcs_entry.Entries(i), 0)) + for entry in gcs_entry.entries: + client = gcs_utils.ClientTableData.FromString(entry) resources = { - decode(client.ResourcesTotalLabel(i)): - client.ResourcesTotalCapacity(i) - for i in range(client.ResourcesTotalLabelLength()) + client.resources_total_label[i]: client.resources_total_capacity[i] + for i in range(len(client.resources_total_label)) } - client_id = ray.utils.binary_to_hex(client.ClientId()) + client_id = ray.utils.binary_to_hex(client.client_id) - if client.EntryType() == EntryType.INSERTION: + if client.entry_type == gcs_utils.ClientTableData.INSERTION: ordered_client_ids.append(client_id) node_info[client_id] = { "ClientID": client_id, - "EntryType": client.EntryType(), - "NodeManagerAddress": decode( - client.NodeManagerAddress(), allow_none=True), - "NodeManagerPort": client.NodeManagerPort(), - "ObjectManagerPort": client.ObjectManagerPort(), - "ObjectStoreSocketName": decode( - client.ObjectStoreSocketName(), allow_none=True), - "RayletSocketName": decode( - client.RayletSocketName(), allow_none=True), + "EntryType": client.entry_type, + "NodeManagerAddress": client.node_manager_address, + "NodeManagerPort": client.node_manager_port, + "ObjectManagerPort": client.object_manager_port, + "ObjectStoreSocketName": client.object_store_socket_name, + "RayletSocketName": client.raylet_socket_name, "Resources": resources } @@ -79,22 +74,23 @@ def _parse_client_table(redis_client): # it cannot have previously been removed. else: assert client_id in node_info, "Client not found!" - assert node_info[client_id]["EntryType"] != EntryType.DELETION, ( - "Unexpected updation of deleted client.") + is_deletion = (node_info[client_id]["EntryType"] != + gcs_utils.ClientTableData.DELETION) + assert is_deletion, "Unexpected updation of deleted client." 
res_map = node_info[client_id]["Resources"] - if client.EntryType() == EntryType.RES_CREATEUPDATE: + if client.entry_type == gcs_utils.ClientTableData.RES_CREATEUPDATE: for res in resources: res_map[res] = resources[res] - elif client.EntryType() == EntryType.RES_DELETE: + elif client.entry_type == gcs_utils.ClientTableData.RES_DELETE: for res in resources: res_map.pop(res, None) - elif client.EntryType() == EntryType.DELETION: + elif client.entry_type == gcs_utils.ClientTableData.DELETION: pass # Do nothing with the resmap if client deletion else: raise RuntimeError("Unexpected EntryType {}".format( - client.EntryType())) + client.entry_type)) node_info[client_id]["Resources"] = res_map - node_info[client_id]["EntryType"] = client.EntryType() + node_info[client_id]["EntryType"] = client.entry_type # NOTE: We return the list comprehension below instead of simply doing # 'list(node_info.values())' in order to have the nodes appear in the order # that they joined the cluster. Python dictionaries do not preserve @@ -244,20 +240,19 @@ def _object_table(self, object_id): # Return information about a single object ID. message = self._execute_command(object_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.OBJECT, "", - object_id.binary()) + gcs_utils.TablePrefix.Value("OBJECT"), + "", object_id.binary()) if message is None: return {} - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entry = gcs_utils.GcsEntry.FromString(message) - assert gcs_entry.EntriesLength() > 0 + assert len(gcs_entry.entries) > 0 - entry = ray.gcs_utils.ObjectTableData.GetRootAsObjectTableData( - gcs_entry.Entries(0), 0) + entry = gcs_utils.ObjectTableData.FromString(gcs_entry.entries[0]) object_info = { - "DataSize": entry.ObjectSize(), - "Manager": entry.Manager(), + "DataSize": entry.object_size, + "Manager": entry.manager, } return object_info @@ -278,10 +273,9 @@ def object_table(self, object_id=None): return self._object_table(object_id) else: # Return the entire object table. - object_keys = self._keys(ray.gcs_utils.TablePrefix_OBJECT_string + - "*") + object_keys = self._keys(gcs_utils.TablePrefix_OBJECT_string + "*") object_ids_binary = { - key[len(ray.gcs_utils.TablePrefix_OBJECT_string):] + key[len(gcs_utils.TablePrefix_OBJECT_string):] for key in object_keys } @@ -301,17 +295,18 @@ def _task_table(self, task_id): A dictionary with information about the task ID in question. 
""" assert isinstance(task_id, ray.TaskID) - message = self._execute_command(task_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - "", task_id.binary()) + message = self._execute_command( + task_id, "RAY.TABLE_LOOKUP", + gcs_utils.TablePrefix.Value("RAYLET_TASK"), "", task_id.binary()) if message is None: return {} - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - - assert gcs_entries.EntriesLength() == 1 + gcs_entries = gcs_utils.GcsEntry.FromString(message) - task_table_message = ray.gcs_utils.Task.GetRootAsTask( - gcs_entries.Entries(0), 0) + assert len(gcs_entries.entries) == 1 + task_table_data = gcs_utils.TaskTableData.FromString( + gcs_entries.entries[0]) + task_table_message = gcs_utils.Task.GetRootAsTask( + task_table_data.task, 0) execution_spec = task_table_message.TaskExecutionSpec() task_spec = task_table_message.TaskSpecification() @@ -368,9 +363,9 @@ def task_table(self, task_id=None): return self._task_table(task_id) else: task_table_keys = self._keys( - ray.gcs_utils.TablePrefix_RAYLET_TASK_string + "*") + gcs_utils.TablePrefix_RAYLET_TASK_string + "*") task_ids_binary = [ - key[len(ray.gcs_utils.TablePrefix_RAYLET_TASK_string):] + key[len(gcs_utils.TablePrefix_RAYLET_TASK_string):] for key in task_table_keys ] @@ -380,27 +375,6 @@ def task_table(self, task_id=None): ray.TaskID(task_id_binary)) return results - def function_table(self, function_id=None): - """Fetch and parse the function table. - - Returns: - A dictionary that maps function IDs to information about the - function. - """ - self._check_connected() - function_table_keys = self.redis_client.keys( - ray.gcs_utils.FUNCTION_PREFIX + "*") - results = {} - for key in function_table_keys: - info = self.redis_client.hgetall(key) - function_info_parsed = { - "DriverID": binary_to_hex(info[b"driver_id"]), - "Module": decode(info[b"module"]), - "Name": decode(info[b"name"]) - } - results[binary_to_hex(info[b"function_id"])] = function_info_parsed - return results - def client_table(self): """Fetch and parse the Redis DB client table. @@ -423,37 +397,32 @@ def _profile_table(self, batch_id): # TODO(rkn): This method should support limiting the number of log # events and should also support returning a window of events. 
message = self._execute_command(batch_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.PROFILE, "", - batch_id.binary()) + gcs_utils.TablePrefix.Value("PROFILE"), + "", batch_id.binary()) if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) profile_events = [] - for i in range(gcs_entries.EntriesLength()): - profile_table_message = ( - ray.gcs_utils.ProfileTableData.GetRootAsProfileTableData( - gcs_entries.Entries(i), 0)) - - component_type = decode(profile_table_message.ComponentType()) - component_id = binary_to_hex(profile_table_message.ComponentId()) - node_ip_address = decode( - profile_table_message.NodeIpAddress(), allow_none=True) + for entry in gcs_entries.entries: + profile_table_message = gcs_utils.ProfileTableData.FromString( + entry) - for j in range(profile_table_message.ProfileEventsLength()): - profile_event_message = profile_table_message.ProfileEvents(j) + component_type = profile_table_message.component_type + component_id = binary_to_hex(profile_table_message.component_id) + node_ip_address = profile_table_message.node_ip_address + for profile_event_message in profile_table_message.profile_events: profile_event = { - "event_type": decode(profile_event_message.EventType()), + "event_type": profile_event_message.event_type, "component_id": component_id, "node_ip_address": node_ip_address, "component_type": component_type, - "start_time": profile_event_message.StartTime(), - "end_time": profile_event_message.EndTime(), - "extra_data": json.loads( - decode(profile_event_message.ExtraData())), + "start_time": profile_event_message.start_time, + "end_time": profile_event_message.end_time, + "extra_data": json.loads(profile_event_message.extra_data), } profile_events.append(profile_event) @@ -462,10 +431,10 @@ def _profile_table(self, batch_id): def profile_table(self): self._check_connected() - profile_table_keys = self._keys( - ray.gcs_utils.TablePrefix_PROFILE_string + "*") + profile_table_keys = self._keys(gcs_utils.TablePrefix_PROFILE_string + + "*") batch_identifiers_binary = [ - key[len(ray.gcs_utils.TablePrefix_PROFILE_string):] + key[len(gcs_utils.TablePrefix_PROFILE_string):] for key in profile_table_keys ] @@ -766,7 +735,7 @@ def cluster_resources(self): clients = self.client_table() for client in clients: # Only count resources from latest entries of live clients. 
- if client["EntryType"] != EntryType.DELETION: + if client["EntryType"] != gcs_utils.ClientTableData.DELETION: for key, value in client["Resources"].items(): resources[key] += value return dict(resources) @@ -776,7 +745,7 @@ def _live_client_ids(self): return { client["ClientID"] for client in self.client_table() - if (client["EntryType"] != EntryType.DELETION) + if (client["EntryType"] != gcs_utils.ClientTableData.DELETION) } def available_resources(self): @@ -800,7 +769,7 @@ def available_resources(self): for redis_client in self.redis_clients ] for subscribe_client in subscribe_clients: - subscribe_client.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL) + subscribe_client.subscribe(gcs_utils.XRAY_HEARTBEAT_CHANNEL) client_ids = self._live_client_ids() @@ -809,24 +778,23 @@ def available_resources(self): # Parse client message raw_message = subscribe_client.get_message() if (raw_message is None or raw_message["channel"] != - ray.gcs_utils.XRAY_HEARTBEAT_CHANNEL): + gcs_utils.XRAY_HEARTBEAT_CHANNEL): continue data = raw_message["data"] - gcs_entries = (ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - data, 0)) - heartbeat_data = gcs_entries.Entries(0) - message = (ray.gcs_utils.HeartbeatTableData. - GetRootAsHeartbeatTableData(heartbeat_data, 0)) + gcs_entries = gcs_utils.GcsEntry.FromString(data) + heartbeat_data = gcs_entries.entries[0] + message = gcs_utils.HeartbeatTableData.FromString( + heartbeat_data) # Calculate available resources for this client - num_resources = message.ResourcesAvailableLabelLength() + num_resources = len(message.resources_available_label) dynamic_resources = {} for i in range(num_resources): - resource_id = decode(message.ResourcesAvailableLabel(i)) + resource_id = message.resources_available_label[i] dynamic_resources[resource_id] = ( - message.ResourcesAvailableCapacity(i)) + message.resources_available_capacity[i]) # Update available resources for this client - client_id = ray.utils.binary_to_hex(message.ClientId()) + client_id = ray.utils.binary_to_hex(message.client_id) available_resources_by_id[client_id] = dynamic_resources # Update clients in cluster @@ -860,23 +828,22 @@ def _error_messages(self, driver_id): """ assert isinstance(driver_id, ray.DriverID) message = self.redis_client.execute_command( - "RAY.TABLE_LOOKUP", ray.gcs_utils.TablePrefix.ERROR_INFO, "", + "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", driver_id.binary()) # If there are no errors, return early. 
if message is None: return [] - gcs_entries = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) + gcs_entries = gcs_utils.GcsEntry.FromString(message) error_messages = [] - for i in range(gcs_entries.EntriesLength()): - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entries.Entries(i), 0) - assert driver_id.binary() == error_data.DriverId() + for entry in gcs_entries.entries: + error_data = gcs_utils.ErrorTableData.FromString(entry) + assert driver_id.binary() == error_data.driver_id error_message = { - "type": decode(error_data.Type()), - "message": decode(error_data.ErrorMessage()), - "timestamp": error_data.Timestamp(), + "type": error_data.type, + "message": error_data.error_message, + "timestamp": error_data.timestamp, } error_messages.append(error_message) return error_messages @@ -899,9 +866,9 @@ def error_messages(self, driver_id=None): return self._error_messages(driver_id) error_table_keys = self.redis_client.keys( - ray.gcs_utils.TablePrefix_ERROR_INFO_string + "*") + gcs_utils.TablePrefix_ERROR_INFO_string + "*") driver_ids = [ - key[len(ray.gcs_utils.TablePrefix_ERROR_INFO_string):] + key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] @@ -923,30 +890,23 @@ def actor_checkpoint_info(self, actor_id): message = self._execute_command( actor_id, "RAY.TABLE_LOOKUP", - ray.gcs_utils.TablePrefix.ACTOR_CHECKPOINT_ID, + gcs_utils.TablePrefix.Value("ACTOR_CHECKPOINT_ID"), "", actor_id.binary(), ) if message is None: return None - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry(message, 0) - entry = ( - ray.gcs_utils.ActorCheckpointIdData.GetRootAsActorCheckpointIdData( - gcs_entry.Entries(0), 0)) - checkpoint_ids_str = entry.CheckpointIds() - num_checkpoints = len(checkpoint_ids_str) // ID_SIZE - assert len(checkpoint_ids_str) % ID_SIZE == 0 + gcs_entry = gcs_utils.GcsEntry.FromString(message) + entry = gcs_utils.ActorCheckpointIdData.FromString( + gcs_entry.entries[0]) checkpoint_ids = [ - ray.ActorCheckpointID( - checkpoint_ids_str[(i * ID_SIZE):((i + 1) * ID_SIZE)]) - for i in range(num_checkpoints) + ray.ActorCheckpointID(checkpoint_id) + for checkpoint_id in entry.checkpoint_ids ] return { - "ActorID": ray.utils.binary_to_hex(entry.ActorId()), + "ActorID": ray.utils.binary_to_hex(entry.actor_id), "CheckpointIds": checkpoint_ids, - "Timestamps": [ - entry.Timestamps(i) for i in range(num_checkpoints) - ], + "Timestamps": list(entry.timestamps), } diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 703c3a1420ed..76dfd3000b86 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -8,7 +8,7 @@ import redis import ray -from ray.core.generated.EntryType import EntryType +from ray.gcs_utils import ClientTableData logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def wait_for_nodes(self, timeout=30): clients = ray.state._parse_client_table(redis_client) live_clients = [ client for client in clients - if client["EntryType"] == EntryType.INSERTION + if client["EntryType"] == ClientTableData.INSERTION ] expected = len(self.list_all_nodes()) diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index 2e670fb0a84d..f7c93fd50c2e 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -56,6 +56,14 @@ def _ray_start(**kwargs): ray.shutdown() +# The following fixture will start ray with 0 cpu. 
+@pytest.fixture +def ray_start_no_cpu(request): + param = getattr(request, "param", {}) + with _ray_start(num_cpus=0, **param) as res: + yield res + + # The following fixture will start ray with 1 cpu. @pytest.fixture def ray_start_regular(request): diff --git a/python/ray/tests/test_actor.py b/python/ray/tests/test_actor.py index dd726e00f27b..932f7b090bf7 100644 --- a/python/ray/tests/test_actor.py +++ b/python/ray/tests/test_actor.py @@ -842,7 +842,7 @@ def f(): assert actor_id not in resulting_ids -def test_actors_on_nodes_with_no_cpus(ray_start_regular): +def test_actors_on_nodes_with_no_cpus(ray_start_no_cpu): @ray.remote class Foo(object): def method(self): diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7f1f78d1b5c4..6b4bd754cd4d 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2736,15 +2736,17 @@ def test_duplicate_error_messages(shutdown_only): r = ray.worker.global_worker.redis_client - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) # Before https://github.com/ray-project/ray/pull/3316 this would # give an error - r.execute_command("RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id.binary(), - error_data) + r.execute_command("RAY.TABLE_APPEND", + ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) @pytest.mark.skipif( diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 51b906695c2d..a560e461f7a2 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -493,8 +493,9 @@ def test_warning_monitor_died(shutdown_only): malformed_message = "asdf" redis_client = ray.worker.global_worker.redis_client redis_client.execute_command( - "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.HEARTBEAT_BATCH, - ray.gcs_utils.TablePubsub.HEARTBEAT_BATCH, fake_id, malformed_message) + "RAY.TABLE_ADD", ray.gcs_utils.TablePrefix.Value("HEARTBEAT_BATCH"), + ray.gcs_utils.TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB"), fake_id, + malformed_message) wait_for_errors(ray_constants.MONITOR_DIED_ERROR, 1) diff --git a/python/ray/tests/test_signal.py b/python/ray/tests/test_signal.py index fe2e74379245..176fbd45bcaa 100644 --- a/python/ray/tests/test_signal.py +++ b/python/ray/tests/test_signal.py @@ -353,3 +353,36 @@ def f(sources): assert len(result_list) == 1 result_list = ray.get(f.remote([a])) assert len(result_list) == 1 + + +def test_non_integral_receive_timeout(ray_start_regular): + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=0.1) + + assert len(result_list) == 1 + + +def test_small_receive_timeout(ray_start_regular): + """ Test that receive handles timeout smaller than the 1ms min + """ + # 0.1 ms + small_timeout = 1e-4 + + @ray.remote + def send_signal(value): + signal.send(UserSignal(value)) + + a = send_signal.remote(0) + # make sure send_signal had a chance to execute + ray.get(a) + + result_list = ray.experimental.signal.receive([a], timeout=small_timeout) + + assert 
len(result_list) == 1 diff --git a/python/ray/tune/analysis/experiment_analysis.py b/python/ray/tune/analysis/experiment_analysis.py index 0164ec2b1a2e..a3c246aba161 100644 --- a/python/ray/tune/analysis/experiment_analysis.py +++ b/python/ray/tune/analysis/experiment_analysis.py @@ -47,7 +47,14 @@ class ExperimentAnalysis(object): >>> experiment_path="~/tune_results/my_exp") """ - def __init__(self, experiment_path): + def __init__(self, experiment_path, trials=None): + """Initializer. + + Args: + experiment_path (str): Path to where experiment is located. + trials (list|None): List of trials that can be accessed via + `analysis.trials`. + """ experiment_path = os.path.expanduser(experiment_path) if not os.path.isdir(experiment_path): raise TuneError( @@ -55,7 +62,8 @@ def __init__(self, experiment_path): experiment_state_paths = glob.glob( os.path.join(experiment_path, "experiment_state*.json")) if not experiment_state_paths: - raise TuneError("No experiment state found!") + raise TuneError( + "No experiment state found in {}!".format(experiment_path)) experiment_filename = max( list(experiment_state_paths)) # if more than one, pick latest with open(os.path.join(experiment_path, experiment_filename)) as f: @@ -65,10 +73,27 @@ def __init__(self, experiment_path): raise TuneError("Experiment state invalid; no checkpoints found.") self._checkpoints = self._experiment_state["checkpoints"] self._scrubbed_checkpoints = unnest_checkpoints(self._checkpoints) + self.trials = trials + self._dataframe = None + + def get_all_trial_dataframes(self): + trial_dfs = {} + for checkpoint in self._checkpoints: + logdir = checkpoint["logdir"] + progress = max(glob.glob(os.path.join(logdir, "progress.csv"))) + trial_dfs[checkpoint["trial_id"]] = pd.read_csv(progress) + return trial_dfs + + def dataframe(self, refresh=False): + """Returns a pandas.DataFrame object constructed from the trials. - def dataframe(self): - """Returns a pandas.DataFrame object constructed from the trials.""" - return pd.DataFrame(self._scrubbed_checkpoints) + Args: + refresh (bool): Clears the cache which may have an existing copy. + + """ + if self._dataframe is None or refresh: + self._dataframe = pd.DataFrame(self._scrubbed_checkpoints) + return self._dataframe def stats(self): """Returns a dictionary of the statistics of the experiment.""" @@ -87,22 +112,45 @@ def trial_dataframe(self, trial_id): return pd.read_csv(progress) raise ValueError("Trial id {} not found".format(trial_id)) - def get_best_trainable(self, metric, trainable_cls): - """Returns the best Trainable based on the experiment metric.""" - return trainable_cls(config=self.get_best_config(metric)) - - def get_best_config(self, metric): - """Retrieve the best config from the best trial.""" - return self._get_best_trial(metric)["config"] - - def _get_best_trial(self, metric): - """Retrieve the best trial based on the experiment metric.""" - return max( + def get_best_trainable(self, metric, trainable_cls, mode="max"): + """Returns the best Trainable based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + + """ + return trainable_cls(config=self.get_best_config(metric, mode=mode)) + + def get_best_config(self, metric, mode="max"): + """Retrieve the best config from the best trial. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. 
+ + """ + return self.get_best_info(metric, flatten=False, mode=mode)["config"] + + def get_best_logdir(self, metric, mode="max"): + df = self.dataframe() + if mode == "max": + return df.iloc[df[metric].idxmax()].logdir + elif mode == "min": + return df.iloc[df[metric].idxmin()].logdir + + def get_best_info(self, metric, mode="max", flatten=True): + """Retrieve the best trial based on the experiment metric. + + Args: + metric (str): Key for trial info to order on. + mode (str): One of [min, max]. + flatten (bool): Assumes trial info is flattened, where + nested entries are concatenated like `info:metric`. + """ + optimize_op = max if mode == "max" else min + if flatten: + return optimize_op( + self._scrubbed_checkpoints, key=lambda d: d.get(metric, 0)) + return optimize_op( self._checkpoints, key=lambda d: d["last_result"].get(metric, 0)) - - def _get_sorted_trials(self, metric): - """Retrive trials in sorted order based on the experiment metric.""" - return sorted( - self._checkpoints, - key=lambda d: d["last_result"].get(metric, 0), - reverse=True) diff --git a/python/ray/tune/examples/mnist_pytorch.py b/python/ray/tune/examples/mnist_pytorch.py index 03dd2f1607e2..acef9fc5105d 100644 --- a/python/ray/tune/examples/mnist_pytorch.py +++ b/python/ray/tune/examples/mnist_pytorch.py @@ -1,7 +1,10 @@ # Original Code here: # https://github.com/pytorch/examples/blob/master/mnist/main.py +from __future__ import absolute_import +from __future__ import division from __future__ import print_function +import numpy as np import argparse import torch import torch.nn as nn @@ -9,181 +12,123 @@ import torch.optim as optim from torchvision import datasets, transforms -# Training settings -parser = argparse.ArgumentParser(description="PyTorch MNIST Example") -parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)") -parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)") -parser.add_argument( - "--epochs", - type=int, - default=1, - metavar="N", - help="number of epochs to train (default: 1)") -parser.add_argument( - "--lr", - type=float, - default=0.01, - metavar="LR", - help="learning rate (default: 0.01)") -parser.add_argument( - "--momentum", - type=float, - default=0.5, - metavar="M", - help="SGD momentum (default: 0.5)") -parser.add_argument( - "--no-cuda", - action="store_true", - default=False, - help="disables CUDA training") -parser.add_argument( - "--seed", - type=int, - default=1, - metavar="S", - help="random seed (default: 1)") -parser.add_argument( - "--smoke-test", action="store_true", help="Finish quickly for testing") - - -def train_mnist(args, config, reporter): - vars(args).update(config) - args.cuda = not args.no_cuda and torch.cuda.is_available() - - torch.manual_seed(args.seed) - if args.cuda: - torch.cuda.manual_seed(args.seed) - - kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} +import ray +from ray import tune +from ray.tune import track +from ray.tune.schedulers import AsyncHyperBandScheduler + +# Change these values if you want the training to run quicker or slower. 
+EPOCH_SIZE = 512 +TEST_SIZE = 256 + + +class Net(nn.Module): + def __init__(self, config): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 3, kernel_size=3) + self.fc = nn.Linear(192, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 3)) + x = x.view(-1, 192) + x = self.fc(x) + return F.log_softmax(x, dim=1) + + +def train(model, optimizer, train_loader, device): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if batch_idx * len(data) > EPOCH_SIZE: + return + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + +def test(model, data_loader, device): + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (data, target) in enumerate(data_loader): + if batch_idx * len(data) > TEST_SIZE: + break + data, target = data.to(device), target.to(device) + outputs = model(data) + _, predicted = torch.max(outputs.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + + return correct / total + + +def get_data_loaders(): + mnist_transforms = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize((0.1307, ), (0.3081, ))]) + train_loader = torch.utils.data.DataLoader( datasets.MNIST( - "~/data", - train=True, - download=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.batch_size, - shuffle=True, - **kwargs) + "~/data", train=True, download=True, transform=mnist_transforms), + batch_size=64, + shuffle=True) test_loader = torch.utils.data.DataLoader( - datasets.MNIST( - "~/data", - train=False, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307, ), (0.3081, )) - ])), - batch_size=args.test_batch_size, - shuffle=True, - **kwargs) - - class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x, dim=1) - - model = Net() - if args.cuda: - model.cuda() + datasets.MNIST("~/data", train=False, transform=mnist_transforms), + batch_size=64, + shuffle=True) + return train_loader, test_loader + + +def train_mnist(config): + use_cuda = config.get("use_gpu") and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + train_loader, test_loader = get_data_loaders() + model = Net(config).to(device) optimizer = optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum) - - def train(epoch): - model.train() - for batch_idx, (data, target) in enumerate(train_loader): - if args.cuda: - data, target = data.cuda(), target.cuda() - optimizer.zero_grad() - output = model(data) - loss = F.nll_loss(output, target) - loss.backward() - optimizer.step() - - def test(): - model.eval() - test_loss = 0 - correct = 0 - with torch.no_grad(): - for data, target in test_loader: - if args.cuda: - data, target = data.cuda(), target.cuda() - output = model(data) - # sum up batch loss - test_loss += F.nll_loss(output, target, reduction="sum").item() - # get the index of 
the max log-probability - pred = output.argmax(dim=1, keepdim=True) - correct += pred.eq( - target.data.view_as(pred)).long().cpu().sum() - - test_loss = test_loss / len(test_loader.dataset) - accuracy = correct.item() / len(test_loader.dataset) - reporter(mean_loss=test_loss, mean_accuracy=accuracy) - - for epoch in range(1, args.epochs + 1): - train(epoch) - test() + model.parameters(), lr=config["lr"], momentum=config["momentum"]) + + while True: + train(model, optimizer, train_loader, device) + acc = test(model, test_loader, device) + track.log(mean_accuracy=acc) if __name__ == "__main__": - datasets.MNIST("~/data", train=True, download=True) + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--cuda", + action="store_true", + default=False, + help="Enables GPU training") + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + parser.add_argument( + "--ray-redis-address", + help="Address of Ray cluster for seamless distributed execution.") args = parser.parse_args() - - import ray - from ray import tune - from ray.tune.schedulers import AsyncHyperBandScheduler - - ray.init() + if args.ray_redis_address: + ray.init(redis_address=args.ray_redis_address) sched = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric="mean_loss", - mode="min", - max_t=400, - grace_period=20) - tune.register_trainable( - "TRAIN_FN", - lambda config, reporter: train_mnist(args, config, reporter)) + time_attr="training_iteration", metric="mean_accuracy") tune.run( - "TRAIN_FN", + train_mnist, name="exp", scheduler=sched, - **{ - "stop": { - "mean_accuracy": 0.98, - "training_iteration": 1 if args.smoke_test else 20 - }, - "resources_per_trial": { - "cpu": 3, - "gpu": int(not args.no_cuda) - }, - "num_samples": 1 if args.smoke_test else 10, - "config": { - "lr": tune.uniform(0.001, 0.1), - "momentum": tune.uniform(0.1, 0.9), - } + stop={ + "mean_accuracy": 0.98, + "training_iteration": 5 if args.smoke_test else 20 + }, + resources_per_trial={ + "cpu": 2, + "gpu": int(args.cuda) + }, + num_samples=1 if args.smoke_test else 10, + config={ + "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())), + "momentum": tune.uniform(0.1, 0.9), + "use_gpu": int(args.cuda) }) diff --git a/python/ray/tune/examples/track_example.py b/python/ray/tune/examples/track_example.py index 1ccec39462d0..751f0ed44fa9 100644 --- a/python/ray/tune/examples/track_example.py +++ b/python/ray/tune/examples/track_example.py @@ -9,7 +9,7 @@ from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) from ray.tune import track -from ray.tune.examples.utils import TuneKerasCallback, get_mnist_data +from ray.tune.examples.utils import TuneReporterCallback, get_mnist_data parser = argparse.ArgumentParser() parser.add_argument( @@ -63,7 +63,7 @@ def train_mnist(args): batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(track.metric)]) + callbacks=[TuneReporterCallback(track.metric)]) track.shutdown() diff --git a/python/ray/tune/examples/tune_mnist_keras.py b/python/ray/tune/examples/tune_mnist_keras.py index 5357d86af19e..ecd3c34bc042 100644 --- a/python/ray/tune/examples/tune_mnist_keras.py +++ b/python/ray/tune/examples/tune_mnist_keras.py @@ -9,8 +9,8 @@ from keras.models import Sequential from keras.layers import (Dense, Dropout, Flatten, Conv2D, MaxPooling2D) -from ray.tune.examples.utils import (TuneKerasCallback, get_mnist_data, - set_keras_threads) +from 
ray.tune.integration.keras import TuneReporterCallback +from ray.tune.examples.utils import get_mnist_data, set_keras_threads parser = argparse.ArgumentParser() parser.add_argument( @@ -52,7 +52,7 @@ def train_mnist(config, reporter): epochs=epochs, verbose=0, validation_data=(x_test, y_test), - callbacks=[TuneKerasCallback(reporter)]) + callbacks=[TuneReporterCallback(reporter)]) if __name__ == "__main__": @@ -63,7 +63,7 @@ def train_mnist(config, reporter): ray.init() sched = AsyncHyperBandScheduler( - time_attr="timesteps_total", + time_attr="training_iteration", metric="mean_accuracy", mode="max", max_t=400, diff --git a/python/ray/tune/examples/utils.py b/python/ray/tune/examples/utils.py index a5ab1dbdb6a1..f40707a014fc 100644 --- a/python/ray/tune/examples/utils.py +++ b/python/ray/tune/examples/utils.py @@ -5,24 +5,9 @@ import keras from keras.datasets import mnist from keras import backend as K - - -class TuneKerasCallback(keras.callbacks.Callback): - def __init__(self, reporter, logs={}): - self.reporter = reporter - self.iteration = 0 - super(TuneKerasCallback, self).__init__() - - def on_train_end(self, epoch, logs={}): - self.reporter( - timesteps_total=self.iteration, - done=1, - mean_accuracy=logs.get("acc")) - - def on_batch_end(self, batch, logs={}): - self.iteration += 1 - self.reporter( - timesteps_total=self.iteration, mean_accuracy=logs["acc"]) +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import OneHotEncoder def get_mnist_data(): @@ -53,6 +38,16 @@ def get_mnist_data(): return x_train, y_train, x_test, y_test, input_shape +def get_iris_data(test_size=0.2): + iris_data = load_iris() + x = iris_data.data + y = iris_data.target.reshape(-1, 1) + encoder = OneHotEncoder(sparse=False) + y = encoder.fit_transform(y) + train_x, test_x, train_y, test_y = train_test_split(x, y) + return train_x, train_y, test_x, test_y + + def set_keras_threads(threads): # We set threads here to avoid contention, as Keras # is heavily parallelized across multiple cores. @@ -61,3 +56,8 @@ def set_keras_threads(threads): config=K.tf.ConfigProto( intra_op_parallelism_threads=threads, inter_op_parallelism_threads=threads))) + + +def TuneKerasCallback(*args, **kwargs): + raise DeprecationWarning("TuneKerasCallback is now " + "tune.integration.keras.TuneReporterCallback.") diff --git a/python/ray/tune/experiment.py b/python/ray/tune/experiment.py index 5f3e46aabd0a..95cb12043f8f 100644 --- a/python/ray/tune/experiment.py +++ b/python/ray/tune/experiment.py @@ -176,6 +176,14 @@ def _register_if_needed(cls, run_object): else: raise TuneError("Improper 'run' - not string nor trainable.") + @property + def local_dir(self): + return self.spec.get("local_dir") + + @property + def checkpoint_dir(self): + return os.path.join(self.spec["local_dir"], self.name) + def convert_to_experiment_list(experiments): """Produces a list of Experiment objects. 
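Note that the TuneKerasCallback stub above raises DeprecationWarning on any call, so old call sites fail immediately rather than receive a working callback. If a softer migration were preferred, a conventional shim (illustrative only; not what this patch does) would warn and delegate:

    import warnings

    from ray.tune.integration.keras import TuneReporterCallback

    def TuneKerasCallback(*args, **kwargs):
        warnings.warn(
            "TuneKerasCallback is deprecated; use "
            "tune.integration.keras.TuneReporterCallback instead.",
            DeprecationWarning)
        return TuneReporterCallback(*args, **kwargs)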
diff --git a/python/ray/tune/integration/__init__.py b/python/ray/tune/integration/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/tune/integration/keras.py b/python/ray/tune/integration/keras.py new file mode 100644 index 000000000000..197a7eef9841 --- /dev/null +++ b/python/ray/tune/integration/keras.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import keras +from ray.tune import track + + +class TuneReporterCallback(keras.callbacks.Callback): + def __init__(self, reporter=None, freq="batch", logs={}): + self.reporter = reporter or track.log + self.iteration = 0 + if freq not in ["batch", "epoch"]: + raise ValueError("{} not supported as a frequency.".format(freq)) + self.freq = freq + super(TuneReporterCallback, self).__init__() + + def on_batch_end(self, batch, logs={}): + if not self.freq == "batch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) + + def on_epoch_end(self, batch, logs={}): + if not self.freq == "epoch": + return + self.iteration += 1 + for metric in list(logs): + if "loss" in metric and "neg_" not in metric: + logs["neg_" + metric] = -logs[metric] + self.reporter(keras_info=logs, mean_accuracy=logs["acc"]) diff --git a/python/ray/tune/schedulers/__init__.py b/python/ray/tune/schedulers/__init__.py index 50bb447437e4..34655372f40a 100644 --- a/python/ray/tune/schedulers/__init__.py +++ b/python/ray/tune/schedulers/__init__.py @@ -4,11 +4,13 @@ from ray.tune.schedulers.trial_scheduler import TrialScheduler, FIFOScheduler from ray.tune.schedulers.hyperband import HyperBandScheduler -from ray.tune.schedulers.async_hyperband import AsyncHyperBandScheduler +from ray.tune.schedulers.async_hyperband import (AsyncHyperBandScheduler, + ASHAScheduler) from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule from ray.tune.schedulers.pbt import PopulationBasedTraining __all__ = [ "TrialScheduler", "HyperBandScheduler", "AsyncHyperBandScheduler", - "MedianStoppingRule", "FIFOScheduler", "PopulationBasedTraining" + "ASHAScheduler", "MedianStoppingRule", "FIFOScheduler", + "PopulationBasedTraining" ] diff --git a/python/ray/tune/schedulers/async_hyperband.py b/python/ray/tune/schedulers/async_hyperband.py index 487eb350efcf..0370d03d3b50 100644 --- a/python/ray/tune/schedulers/async_hyperband.py +++ b/python/ray/tune/schedulers/async_hyperband.py @@ -168,6 +168,8 @@ def debug_str(self): return "Bracket: " + iters +ASHAScheduler = AsyncHyperBandScheduler + if __name__ == "__main__": sched = AsyncHyperBandScheduler( grace_period=1, max_t=10, reduction_factor=2) diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index a0721abc5d29..7b613a6fdea2 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -11,9 +11,7 @@ import ray from ray.tune import run, sample_from -from ray.tune.analysis import ExperimentAnalysis from ray.tune.examples.async_hyperband_example import MyTrainableClass -from ray.tune.schedulers import AsyncHyperBandScheduler class ExperimentAnalysisSuite(unittest.TestCase): @@ -27,35 +25,22 @@ def setUp(self): self.test_path = os.path.join(self.test_dir, self.test_name) self.run_test_exp() - self.ea = ExperimentAnalysis(self.test_path) - def 
tearDown(self): shutil.rmtree(self.test_dir, ignore_errors=True) ray.shutdown() def run_test_exp(self): - ahb = AsyncHyperBandScheduler( - time_attr="training_iteration", - metric=self.metric, - mode="max", - grace_period=5, - max_t=100) - - run(MyTrainableClass, + self.ea = run( + MyTrainableClass, name=self.test_name, - scheduler=ahb, local_dir=self.test_dir, - **{ - "stop": { - "training_iteration": 1 - }, - "num_samples": 10, - "config": { - "width": sample_from( - lambda spec: 10 + int(90 * random.random())), - "height": sample_from( - lambda spec: int(100 * random.random())), - }, + return_trials=False, + stop={"training_iteration": 1}, + num_samples=self.num_samples, + config={ + "width": sample_from( + lambda spec: 10 + int(90 * random.random())), + "height": sample_from(lambda spec: int(100 * random.random())), }) def testDataframe(self): @@ -87,7 +72,7 @@ def testBestConfig(self): self.assertTrue("height" in best_config) def testBestTrial(self): - best_trial = self.ea._get_best_trial(self.metric) + best_trial = self.ea.get_best_info(self.metric, flatten=False) self.assertTrue(isinstance(best_trial, dict)) self.assertTrue("local_dir" in best_trial) @@ -99,6 +84,18 @@ def testBestTrial(self): self.assertTrue("last_result" in best_trial) self.assertTrue(self.metric in best_trial["last_result"]) + min_trial = self.ea.get_best_info( + self.metric, mode="min", flatten=False) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertLess(min_trial["last_result"][self.metric], + best_trial["last_result"][self.metric]) + + flat_trial = self.ea.get_best_info(self.metric, flatten=True) + + self.assertTrue(isinstance(min_trial, dict)) + self.assertTrue(self.metric in flat_trial) + def testCheckpoints(self): checkpoints = self.ea._checkpoints @@ -121,6 +118,21 @@ def testRunnerData(self): self.assertEqual(runner_data["_metadata_checkpoint_dir"], os.path.expanduser(self.test_path)) + def testBestLogdir(self): + logdir = self.ea.get_best_logdir(self.metric) + self.assertTrue(logdir.startswith(self.test_path)) + logdir2 = self.ea.get_best_logdir(self.metric, mode="min") + self.assertTrue(logdir2.startswith(self.test_path)) + self.assertNotEquals(logdir, logdir2) + + def testAllDataframes(self): + dataframes = self.ea.get_all_trial_dataframes() + self.assertTrue(len(dataframes) == self.num_samples) + + self.assertTrue(isinstance(dataframes, dict)) + for df in dataframes.values(): + self.assertEqual(df.training_iteration.max(), 1) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 37022ceab615..64b8e9761488 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -441,6 +441,14 @@ def f(): self.assertRaises(TuneError, f) + def testNestedStoppingReturn(self): + def train(config, reporter): + for i in range(10): + reporter(test={"test1": {"test2": i}}) + + [trial] = tune.run(train, stop={"test": {"test1": {"test2": 6}}}) + self.assertEqual(trial.last_result["training_iteration"], 7) + def testEarlyReturn(self): def train(config, reporter): reporter(timesteps_total=100, done=True) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index f721023b4191..a9938396e59b 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -181,6 +181,21 @@ def has_trainable(trainable_name): ray.tune.registry.TRAINABLE_CLASS, trainable_name) +def recursive_criteria_check(result, criteria): + for criteria, stop_value in 
criteria.items(): + if criteria not in result: + raise TuneError( + "Stopping criteria {} not provided in result {}.".format( + criteria, result)) + elif isinstance(result[criteria], dict) and isinstance( + stop_value, dict): + if recursive_criteria_check(result[criteria], stop_value): + return True + elif result[criteria] >= stop_value: + return True + return False + + class Checkpoint(object): """Describes a checkpoint of trial state. @@ -425,15 +440,7 @@ def should_stop(self, result): if result.get(DONE): return True - for criteria, stop_value in self.stopping_criterion.items(): - if criteria not in result: - raise TuneError( - "Stopping criteria {} not provided in result {}.".format( - criteria, result)) - if result[criteria] >= stop_value: - return True - - return False + return recursive_criteria_check(result, self.stopping_criterion) def should_checkpoint(self): """Whether this trial is due for checkpointing.""" diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 1568db0f1102..47a82ba0c17f 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -4,11 +4,11 @@ import click import logging -import os import time from ray.tune.error import TuneError from ray.tune.experiment import convert_to_experiment_list, Experiment +from ray.tune.analysis import ExperimentAnalysis from ray.tune.suggest import BasicVariantGenerator from ray.tune.trial import Trial, DEBUG_PRINT_INTERVAL from ray.tune.ray_trial_executor import RayTrialExecutor @@ -39,7 +39,7 @@ def _make_scheduler(args): def _find_checkpoint_dir(exp): # TODO(rliaw): Make sure the checkpoint_dir is resolved earlier. # Right now it is resolved somewhere far down the trial generation process - return os.path.join(exp.spec["local_dir"], exp.name) + return exp.checkpoint_dir def _prompt_restore(checkpoint_dir, resume): @@ -89,9 +89,10 @@ def run(run_or_experiment, verbose=2, resume=False, queue_trials=False, - reuse_actors=False, + reuse_actors=True, trial_executor=None, raise_on_failed_trial=True, + return_trials=True, ray_auto_init=True): """Executes training. @@ -322,7 +323,9 @@ def override_flags(restored_config, new_config, flags_to_override): else: logger.error("Trials did not complete: %s", errored_trials) - return runner.get_trials() + if return_trials: + return runner.get_trials() + return ExperimentAnalysis(experiment.checkpoint_dir) def run_experiments(experiments, diff --git a/python/ray/utils.py b/python/ray/utils.py index 7b87486e325e..0db48e41d025 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -93,10 +93,10 @@ def push_error_to_driver_through_redis(redis_client, # of through the raylet. 
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, message, time.time()) - redis_client.execute_command("RAY.TABLE_APPEND", - ray.gcs_utils.TablePrefix.ERROR_INFO, - ray.gcs_utils.TablePubsub.ERROR_INFO, - driver_id.binary(), error_data) + redis_client.execute_command( + "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), + driver_id.binary(), error_data) def is_cython(obj): diff --git a/python/ray/worker.py b/python/ray/worker.py index 7505120574a6..710f0db43c6b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -47,7 +47,7 @@ from ray import import_thread from ray import profiling -from ray.core.generated.ErrorType import ErrorType +from ray.gcs_utils import ErrorType from ray.exceptions import ( RayActorError, RayError, @@ -461,11 +461,11 @@ def _deserialize_object_from_arrow(self, data, metadata, object_id, # Otherwise, return an exception object based on # the error type. error_type = int(metadata) - if error_type == ErrorType.WORKER_DIED: + if error_type == ErrorType.Value("WORKER_DIED"): return RayWorkerError() - elif error_type == ErrorType.ACTOR_DIED: + elif error_type == ErrorType.Value("ACTOR_DIED"): return RayActorError() - elif error_type == ErrorType.OBJECT_UNRECONSTRUCTABLE: + elif error_type == ErrorType.Value("OBJECT_UNRECONSTRUCTABLE"): return UnreconstructableError(ray.ObjectID(object_id.binary())) else: assert False, "Unrecognized error type " + str(error_type) @@ -1637,7 +1637,7 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): # Really we should just subscribe to the errors for this specific job. # However, currently all errors seem to be published on the same channel. error_pubsub_channel = str( - ray.gcs_utils.TablePubsub.ERROR_INFO).encode("ascii") + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii") worker.error_message_pubsub_client.subscribe(error_pubsub_channel) # worker.error_message_pubsub_client.psubscribe("*") @@ -1656,21 +1656,19 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): if msg is None: threads_stopped.wait(timeout=0.01) continue - gcs_entry = ray.gcs_utils.GcsEntry.GetRootAsGcsEntry( - msg["data"], 0) - assert gcs_entry.EntriesLength() == 1 - error_data = ray.gcs_utils.ErrorTableData.GetRootAsErrorTableData( - gcs_entry.Entries(0), 0) - driver_id = error_data.DriverId() + gcs_entry = ray.gcs_utils.GcsEntry.FromString(msg["data"]) + assert len(gcs_entry.entries) == 1 + error_data = ray.gcs_utils.ErrorTableData.FromString( + gcs_entry.entries[0]) + driver_id = error_data.driver_id if driver_id not in [ worker.task_driver_id.binary(), DriverID.nil().binary() ]: continue - error_message = ray.utils.decode(error_data.ErrorMessage()) - if (ray.utils.decode( - error_data.Type()) == ray_constants.TASK_PUSH_ERROR): + error_message = error_data.error_message + if (error_data.type == ray_constants.TASK_PUSH_ERROR): # Delay it a bit to see if we can suppress it task_error_queue.put((error_message, time.time())) else: @@ -1878,14 +1876,16 @@ def connect(node, {}, # resource_map. {}, # placement_resource_map. ) + task_table_data = ray.gcs_utils.TaskTableData() + task_table_data.task = driver_task._serialized_raylet_task() # Add the driver task to the task table. 
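The same protobuf migration shows up in how the worker decodes error
notifications: flatbuffer getters (`ErrorMessage()`, `Type()`) become plain
attributes on generated messages, and channel ids are looked up by name. A
minimal sketch of the new decoding path, assuming `ray.gcs_utils` exposes the
generated types the way the hunks above access them; the `connect()` hunk
continues below.

    from ray.gcs_utils import ErrorTableData, GcsEntry, TablePubsub

    # Pubsub channels are now looked up by name on the generated enum.
    error_channel = str(
        TablePubsub.Value("ERROR_INFO_PUBSUB")).encode("ascii")

    def parse_error_message(payload):
        # payload is a serialized GcsEntry wrapping one ErrorTableData.
        gcs_entry = GcsEntry.FromString(payload)
        assert len(gcs_entry.entries) == 1
        error_data = ErrorTableData.FromString(gcs_entry.entries[0])
        return error_data.driver_id, error_data.error_message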
- ray.state.state._execute_command(driver_task.task_id(), - "RAY.TABLE_ADD", - ray.gcs_utils.TablePrefix.RAYLET_TASK, - ray.gcs_utils.TablePubsub.RAYLET_TASK, - driver_task.task_id().binary(), - driver_task._serialized_raylet_task()) + ray.state.state._execute_command( + driver_task.task_id(), "RAY.TABLE_ADD", + ray.gcs_utils.TablePrefix.Value("RAYLET_TASK"), + ray.gcs_utils.TablePubsub.Value("RAYLET_TASK_PUBSUB"), + driver_task.task_id().binary(), + task_table_data.SerializeToString()) # Set the driver's current task ID to the task ID assigned to the # driver task. diff --git a/python/setup.py b/python/setup.py index eb200ea7d5e4..95e7e66bad3e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -151,6 +151,7 @@ def find_version(*filepath): "six >= 1.0.0", "flatbuffers", "faulthandler;python_version<'3.3'", + "protobuf", ] setup( diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index c92e6a74aa5d..1f50b8025d57 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -36,4 +36,6 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. constexpr char kTaskTablePrefix[] = "TaskTable"; +constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; + #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index c9b1e138575d..6de29bb52764 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -206,10 +206,6 @@ TaskLeaseTable &AsyncGcsClient::task_lease_table() { return *task_lease_table_; ClientTable &AsyncGcsClient::client_table() { return *client_table_; } -FunctionTable &AsyncGcsClient::function_table() { return *function_table_; } - -ClassTable &AsyncGcsClient::class_table() { return *class_table_; } - HeartbeatTable &AsyncGcsClient::heartbeat_table() { return *heartbeat_table_; } HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index c9f5b4bca624..5e70025b39a0 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -44,11 +44,7 @@ class RAY_EXPORT AsyncGcsClient { /// one event loop should be attached at a time. Status Attach(boost::asio::io_service &io_service); - inline FunctionTable &function_table(); // TODO: Some API for getting the error on the driver - inline ClassTable &class_table(); - inline CustomSerializerTable &custom_serializer_table(); - inline ConfigTable &config_table(); ObjectTable &object_table(); raylet::TaskTable &raylet_task_table(); ActorTable &actor_table(); @@ -81,8 +77,6 @@ class RAY_EXPORT AsyncGcsClient { std::string DebugString() const; private: - std::unique_ptr function_table_; - std::unique_ptr class_table_; std::unique_ptr object_table_; std::unique_ptr raylet_task_table_; std::unique_ptr actor_table_; diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index c7dc02e50651..55115b1e2067 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -85,21 +85,21 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { void TestTableLookup(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); - auto data = std::make_shared(); - data->task_specification = "123"; + auto data = std::make_shared(); + data->set_task("123"); // Check that we added the correct task. 
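The new `kWorkerDynamicOptionPlaceholderPrefix` constant above pairs with the
`dynamic_worker_options` field added to `TaskInfo` later in this patch:
placeholders like `RAY_WORKER_OPTION_0` in the configured worker command are
substituted with the options carried by an actor creation task. A toy sketch
of that substitution (the helper name and command are hypothetical; the real
logic lives in the raylet); the `client_test.cc` hunk continues below.

    PLACEHOLDER_PREFIX = "RAY_WORKER_OPTION_"

    def expand_worker_command(command_args, dynamic_options):
        # Replace each RAY_WORKER_OPTION_<i> placeholder with
        # dynamic_options[i]; pass all other arguments through.
        expanded = []
        for arg in command_args:
            if arg.startswith(PLACEHOLDER_PREFIX):
                index = int(arg[len(PLACEHOLDER_PREFIX):])
                expanded.append(dynamic_options[index])
            else:
                expanded.append(arg)
        return expanded

    assert expand_worker_command(
        ["java", "RAY_WORKER_OPTION_0", "Worker"],
        ["-Xmx1g"]) == ["java", "-Xmx1g", "Worker"]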
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); }; // Check that the lookup returns the added task. auto lookup_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->Stop(); }; @@ -136,13 +136,13 @@ void TestLogLookup(const DriverID &driver_id, TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; for (auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); @@ -151,10 +151,10 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { - ASSERT_EQ(entry.node_manager_id, node_manager_ids[test->NumCallbacks()]); + ASSERT_EQ(entry.node_manager_id(), node_manager_ids[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == node_manager_ids.size()) { @@ -182,7 +182,7 @@ void TestTableLookupFailure(const DriverID &driver_id, // Check that the lookup does not return data. auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { RAY_CHECK(false); }; + const TaskTableData &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { @@ -207,16 +207,16 @@ void TestLogAppendAt(const DriverID &driver_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; - std::vector> data_log; + std::vector> data_log; for (const auto &node_manager_id : node_manager_ids) { - auto data = std::make_shared(); - data->node_manager_id = node_manager_id; + auto data = std::make_shared(); + data->set_node_manager_id(node_manager_id); data_log.push_back(data); } // Check that we added the correct task. 
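The test churn above is mechanical: generated flatbuffer structs with public
members (`data->task_specification = ...`) become protobuf messages with
setters in C++ and plain attributes in Python. A round-trip sketch of one of
the migrated messages (the Python module path is assumed); `TestLogAppendAt`
continues below.

    from ray.core.generated.gcs_pb2 import ObjectTableData  # assumed path

    data = ObjectTableData()
    data.manager = "abc"              # C++: data->set_manager("abc")
    payload = data.SerializeToString()
    restored = ObjectTableData.FromString(payload)
    assert restored.manager == "abc"  # C++: d.manager()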
auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -242,10 +242,10 @@ void TestLogAppendAt(const DriverID &driver_id, auto lookup_callback = [node_manager_ids]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { - appended_managers.push_back(entry.node_manager_id); + appended_managers.push_back(entry.node_manager_id()); } ASSERT_EQ(appended_managers, node_manager_ids); test->Stop(); @@ -268,22 +268,22 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. - auto lookup_callback = [object_id, managers]( - gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + auto lookup_callback = [object_id, managers](gcs::AsyncGcsClient *client, + const ObjectID &id, + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), managers.size()); test->IncrementNumCallbacks(); @@ -293,14 +293,14 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); for (auto &manager : managers) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -310,7 +310,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli // Check that the entries are removed. auto lookup_callback2 = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 0); test->IncrementNumCallbacks(); @@ -332,7 +332,7 @@ TEST_F(TestGcsWithAsio, TestSet) { void TestDeleteKeysFromLog( const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; TaskID task_id; for (auto &data : data_vector) { @@ -340,9 +340,9 @@ void TestDeleteKeysFromLog( ids.push_back(task_id); // Check that we added the correct object entries. 
auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const TaskReconstructionDataT &d) { + const TaskReconstructionData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->node_manager_id, d.node_manager_id); + ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK( @@ -352,7 +352,7 @@ void TestDeleteKeysFromLog( // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -367,7 +367,7 @@ void TestDeleteKeysFromLog( } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -379,7 +379,7 @@ void TestDeleteKeysFromLog( void TestDeleteKeysFromTable(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector, + std::vector> &data_vector, bool stop_at_end) { std::vector ids; TaskID task_id; @@ -388,16 +388,16 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ids.push_back(task_id); // Check that we added the correct object entries. auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &d) { + const TaskTableData &d) { ASSERT_EQ(id, task_id); - ASSERT_EQ(data->task_specification, d.task_specification); + ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; @@ -414,7 +414,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { ASSERT_TRUE(false); }; + const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { RAY_CHECK_OK(client->raylet_task_table().Lookup( driver_id, task_id, undesired_callback, expected_failure_callback)); @@ -428,7 +428,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, void TestDeleteKeysFromSet(const DriverID &driver_id, std::shared_ptr client, - std::vector> &data_vector) { + std::vector> &data_vector) { std::vector ids; ObjectID object_id; for (auto &data : data_vector) { @@ -436,9 +436,9 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ids.push_back(object_id); // Check that we added the correct object entries. auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, - const ObjectTableDataT &d) { + const ObjectTableData &d) { ASSERT_EQ(id, object_id); - ASSERT_EQ(data->manager, d.manager); + ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); @@ -447,7 +447,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, // Check that lookup returns the added object entries. 
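The `TestDeleteKeys*` helpers deliberately straddle
`RayConfig::instance().maximum_gcs_deletion_batch_size()`: bulk deletions are
expected to be split into bounded batches rather than issued as one command.
A toy helper (hypothetical, for intuition only) showing the chunking
invariant the tests probe; `TestDeleteKeysFromSet` continues below.

    def chunk_ids(ids, max_batch_size):
        # Split ids into order-preserving batches of at most max_batch_size.
        return [
            ids[i:i + max_batch_size]
            for i in range(0, len(ids), max_batch_size)
        ]

    batches = chunk_ids(list(range(7)), 3)
    assert batches == [[0, 1, 2], [3, 4, 5], [6]]
    assert all(len(batch) <= 3 for batch in batches)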
auto lookup_callback = [object_id, data_vector]( gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); @@ -461,7 +461,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, object_id); ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); @@ -474,11 +474,11 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, void TestDeleteKeys(const DriverID &driver_id, std::shared_ptr client) { // Test delete function for keys of Log. - std::vector> task_reconstruction_vector; + std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->node_manager_id = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_node_manager_id(ObjectID::FromRandom().Hex()); task_reconstruction_vector.push_back(data); } }; @@ -503,11 +503,11 @@ void TestDeleteKeys(const DriverID &driver_id, TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); // Test delete function for keys of Table. - std::vector> task_vector; + std::vector> task_vector; auto AppendTaskData = [&task_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto task_data = std::make_shared(); - task_data->task_specification = ObjectID::FromRandom().Hex(); + auto task_data = std::make_shared(); + task_data->set_task(ObjectID::FromRandom().Hex()); task_vector.push_back(task_data); } }; @@ -529,11 +529,11 @@ void TestDeleteKeys(const DriverID &driver_id, 9 * RayConfig::instance().maximum_gcs_deletion_batch_size()); // Test delete function for keys of Set. - std::vector> object_vector; + std::vector> object_vector; auto AppendObjectData = [&object_vector](size_t add_count) { for (size_t i = 0; i < add_count; ++i) { - auto data = std::make_shared(); - data->manager = ObjectID::FromRandom().Hex(); + auto data = std::make_shared(); + data->set_manager(ObjectID::FromRandom().Hex()); object_vector.push_back(data); } }; @@ -561,45 +561,6 @@ TEST_F(TestGcsWithAsio, TestDeleteKey) { TestDeleteKeys(driver_id_, client_); } -// Task table callbacks. 
-void TaskAdded(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); -} - -void TaskLookupHelper(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data, bool do_stop) { - ASSERT_EQ(data.scheduling_state, SchedulingState::SCHEDULED); - ASSERT_EQ(data.raylet_id, kRandomId); - if (do_stop) { - test->Stop(); - } -} -void TaskLookup(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/false); -} -void TaskLookupWithStop(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - TaskLookupHelper(client, id, data, /*do_stop=*/true); -} - -void TaskLookupFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); -} - -void TaskLookupAfterUpdate(gcs::AsyncGcsClient *client, const TaskID &id, - const TaskTableDataT &data) { - ASSERT_EQ(data.scheduling_state, SchedulingState::LOST); - test->Stop(); -} - -void TaskLookupAfterUpdateFailure(gcs::AsyncGcsClient *client, const TaskID &id) { - RAY_CHECK(false); - test->Stop(); -} - void TestLogSubscribeAll(const DriverID &driver_id, std::shared_ptr client) { std::vector driver_ids; @@ -609,11 +570,11 @@ void TestLogSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, const DriverID &id, - const std::vector data) { + const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids.size()) { @@ -660,7 +621,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, auto notification_callback = [object_ids, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector data) { + const std::vector data) { if (test->NumCallbacks() < 3 * 3) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); } else { @@ -669,7 +630,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, ASSERT_EQ(id, object_ids[test->NumCallbacks() / 3 % 3]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers[test->NumCallbacks() % 3]); + ASSERT_EQ(entry.manager(), managers[test->NumCallbacks() % 3]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == object_ids.size() * 3 * 2) { @@ -684,8 +645,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. 
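`TestSetSubscribeAll` in the surrounding hunks depends on the GCS set
semantics spelled out in its comments: adding an entry that already exists,
or removing one that does not, publishes no notification. A toy model of that
contract:

    def apply_set_op(entries, change_mode, value):
        """Return True when the GCS would publish a notification."""
        if change_mode == "APPEND_OR_ADD":
            if value in entries:
                return False  # duplicate add: no notification
            entries.add(value)
        elif change_mode == "REMOVE":
            if value not in entries:
                return False  # removing a missing entry: no notification
            entries.discard(value)
        return True

    entries = set()
    assert apply_set_op(entries, "APPEND_OR_ADD", "abc")
    assert not apply_set_op(entries, "APPEND_OR_ADD", "abc")
    assert apply_set_op(entries, "REMOVE", "abc")
    assert not apply_set_op(entries, "REMOVE", "abc")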
@@ -696,8 +657,8 @@ void TestSetSubscribeAll(const DriverID &driver_id, } for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { - auto data = std::make_shared(); - data->manager = managers[j]; + auto data = std::make_shared(); + data->set_manager(managers[j]); for (int k = 0; k < 3; k++) { // Remove the same entry several times. // Expect no notification if the entry doesn't exist. @@ -740,11 +701,11 @@ void TestTableSubscribeId(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id2, task_specs2](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, task_id2); // Check that we get notifications in the same order as the writes. - ASSERT_EQ(data.task_specification, task_specs2[test->NumCallbacks()]); + ASSERT_EQ(data.task(), task_specs2[test->NumCallbacks()]); test->IncrementNumCallbacks(); if (test->NumCallbacks() == task_specs2.size()) { test->Stop(); @@ -771,13 +732,13 @@ void TestTableSubscribeId(const DriverID &driver_id, // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); } }; @@ -808,27 +769,27 @@ void TestLogSubscribeId(const DriverID &driver_id, // Add a log entry. DriverID driver_id1 = DriverID::FromRandom(); std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->driver_id = driver_ids1[0]; + auto data1 = std::make_shared(); + data1->set_driver_id(driver_ids1[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); // Add a log entry at a second key. DriverID driver_id2 = DriverID::FromRandom(); std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->driver_id = driver_ids2[0]; + auto data2 = std::make_shared(); + data2->set_driver_id(driver_ids2[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [driver_id2, driver_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. ASSERT_EQ(id, driver_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids2.size()) { @@ -847,14 +808,14 @@ void TestLogSubscribeId(const DriverID &driver_id, // we requested them for. 
auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); } remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->driver_id = driver_id_it; + auto data = std::make_shared(); + data->set_driver_id(driver_id_it); RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); } }; @@ -882,15 +843,15 @@ void TestSetSubscribeId(const DriverID &driver_id, // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->manager = managers1[0]; + auto data1 = std::make_shared(); + data1->set_manager(managers1[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->manager = managers2[0]; + auto data2 = std::make_shared(); + data2->set_manager(managers2[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be @@ -898,13 +859,13 @@ void TestSetSubscribeId(const DriverID &driver_id, auto notification_callback = [object_id2, managers2]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); // Check that we only get notifications for the requested key. ASSERT_EQ(id, object_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.manager, managers2[test->NumCallbacks()]); + ASSERT_EQ(entry.manager(), managers2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == managers2.size()) { @@ -923,14 +884,14 @@ void TestSetSubscribeId(const DriverID &driver_id, // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); } }; @@ -958,8 +919,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // Add a table entry. 
TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->task_specification = task_specs[0]; + auto data = std::make_shared(); + data->set_task(task_specs[0]); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty @@ -972,14 +933,14 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // received for keys that we requested notifications for. auto notification_callback = [task_id, task_specs](gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { + const TaskTableData &data) { ASSERT_EQ(id, task_id); // Check that we only get notifications for the first and last writes, // since notifications are canceled in between. if (test->NumCallbacks() == 0) { - ASSERT_EQ(data.task_specification, task_specs.front()); + ASSERT_EQ(data.task(), task_specs.front()); } else { - ASSERT_EQ(data.task_specification, task_specs.back()); + ASSERT_EQ(data.task(), task_specs.back()); } test->IncrementNumCallbacks(); if (test->NumCallbacks() == 2) { @@ -1001,8 +962,8 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { - auto data = std::make_shared(); - data->task_specification = task_spec; + auto data = std::make_shared(); + data->set_task(task_spec); RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the @@ -1034,15 +995,15 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // Add a log entry. DriverID random_driver_id = DriverID::FromRandom(); std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->driver_id = driver_ids[0]; + auto data = std::make_shared(); + data->set_driver_id(driver_ids[0]); RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_driver_id, driver_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(id, random_driver_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications @@ -1050,7 +1011,7 @@ void TestLogSubscribeCancel(const DriverID &driver_id, auto driver_ids_copy = driver_ids; driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id, driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } if (test->NumCallbacks() == driver_ids_copy.size()) { @@ -1072,8 +1033,8 @@ void TestLogSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. 
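`TestLogSubscribeCancel` (continuing below) leans on the append-only log
semantics described in its comments: when notifications are re-requested for
a key, the log's current contents are replayed, so the first entry is
delivered twice. A toy model:

    log = []
    notifications = []

    def append(value, subscribed):
        log.append(value)
        if subscribed:
            notifications.append([value])

    def request_notifications():
        # Re-requesting notifications replays everything written so far.
        notifications.append(list(log))

    append("jkl", subscribed=True)   # first write, while subscribed
    append("mno", subscribed=False)  # written after canceling
    append("pqr", subscribed=False)
    request_notifications()

    # "jkl" arrives twice: once on the first write and again in the replay,
    # which is exactly the duplicate the test accounts for.
    assert notifications == [["jkl"], ["jkl", "mno", "pqr"]]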
auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->driver_id = remaining_driver_id; + auto data = std::make_shared(); + data->set_driver_id(remaining_driver_id); RAY_CHECK_OK( client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); } @@ -1107,8 +1068,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->manager = managers[0]; + auto data = std::make_shared(); + data->set_manager(managers[0]); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be @@ -1116,7 +1077,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, auto notification_callback = [object_id, managers]( gcs::AsyncGcsClient *client, const ObjectID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { ASSERT_EQ(change_mode, GcsChangeMode::APPEND_OR_ADD); ASSERT_EQ(id, object_id); // Check that we get a duplicate notification for the first write. We get a @@ -1124,7 +1085,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // are canceled after the first write, then requested again. if (data.size() == 1) { // first notification - ASSERT_EQ(data[0].manager, managers[0]); + ASSERT_EQ(data[0].manager(), managers[0]); test->IncrementNumCallbacks(); } else { // second notification @@ -1132,7 +1093,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, std::unordered_set managers_set(managers.begin(), managers.end()); std::unordered_set data_managers_set; for (const auto &entry : data) { - data_managers_set.insert(entry.manager); + data_managers_set.insert(entry.manager()); test->IncrementNumCallbacks(); } ASSERT_EQ(managers_set, data_managers_set); @@ -1156,8 +1117,8 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { - auto data = std::make_shared(); - data->manager = manager; + auto data = std::make_shared(); + data->set_manager(manager); RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); } // Request notifications again. 
We should receive a notification for the @@ -1186,17 +1147,17 @@ TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, - const ClientTableDataT &data, bool is_insertion) { + const ClientTableData &data, bool is_insertion) { ClientID added_id = client->client_table().GetLocalClientId(); ASSERT_EQ(client_id, added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(ClientID::FromBinary(data.client_id), added_id); - ASSERT_EQ(data.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(ClientID::FromBinary(data.client_id()), added_id); + ASSERT_EQ(data.entry_type() == ClientTableData::INSERTION, is_insertion); - ClientTableDataT cached_client; + ClientTableData cached_client; client->client_table().GetClient(added_id, cached_client); - ASSERT_EQ(ClientID::FromBinary(cached_client.client_id), added_id); - ASSERT_EQ(cached_client.entry_type == EntryType::INSERTION, is_insertion); + ASSERT_EQ(ClientID::FromBinary(cached_client.client_id()), added_id); + ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } void TestClientTableConnect(const DriverID &driver_id, @@ -1204,17 +1165,17 @@ void TestClientTableConnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); test->Stop(); }); // Connect and disconnect to client table. We should receive notifications // for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1229,23 +1190,23 @@ void TestClientTableDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/true); // Disconnect from the client table. We should receive a notification // for the removal of our own entry. RAY_CHECK_OK(client->client_table().Disconnect()); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, /*is_insertion=*/false); test->Stop(); }); // Connect to the client table. We should receive notification for the // addition of our own entry. 
- ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); test->Start(); } @@ -1260,20 +1221,20 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. client->client_table().RegisterClientAddedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, true); }); client->client_table().RegisterClientRemovedCallback( - [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableDataT &data) { + [](gcs::AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { ClientTableNotification(client, id, data, false); test->Stop(); }); // Connect to then immediately disconnect from the client table. We should // receive notifications for the addition and removal of our own entry. - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); RAY_CHECK_OK(client->client_table().Connect(local_client_info)); RAY_CHECK_OK(client->client_table().Disconnect()); test->Start(); @@ -1286,10 +1247,10 @@ TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { void TestClientTableMarkDisconnected(const DriverID &driver_id, std::shared_ptr client) { - ClientTableDataT local_client_info = client->client_table().GetLocalClient(); - local_client_info.node_manager_address = "127.0.0.1"; - local_client_info.node_manager_port = 0; - local_client_info.object_manager_port = 0; + ClientTableData local_client_info = client->client_table().GetLocalClient(); + local_client_info.set_node_manager_address("127.0.0.1"); + local_client_info.set_node_manager_port(0); + local_client_info.set_object_manager_port(0); // Connect to the client table to start receiving notifications. RAY_CHECK_OK(client->client_table().Connect(local_client_info)); // Mark a different client as dead. @@ -1299,8 +1260,8 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, // marked as dead. client->client_table().RegisterClientRemovedCallback( [dead_client_id](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { - ASSERT_EQ(ClientID::FromBinary(data.client_id), dead_client_id); + const ClientTableData &data) { + ASSERT_EQ(ClientID::FromBinary(data.client_id()), dead_client_id); test->Stop(); }); test->Start(); @@ -1316,31 +1277,31 @@ void TestHashTable(const DriverID &driver_id, const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. 
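For the client table, the standalone `EntryType` flatbuffer enum is folded
into the message itself (`ClientTableData::INSERTION` and friends), and the
struct-member writes become setters. The equivalent from Python looks like
this (module path assumed); `TestHashTable` continues below.

    from ray.core.generated.gcs_pb2 import ClientTableData  # assumed path

    info = ClientTableData()
    info.node_manager_address = "127.0.0.1"  # C++: set_node_manager_address
    info.node_manager_port = 0
    info.object_manager_port = 0
    info.entry_type = ClientTableData.INSERTION
    assert info.entry_type == ClientTableData.INSERTION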
- auto cpu_data = std::make_shared(); - cpu_data->resource_name = "CPU"; - cpu_data->resource_capacity = 100; - auto gpu_data = std::make_shared(); - gpu_data->resource_name = "GPU"; - gpu_data->resource_capacity = 2; + auto cpu_data = std::make_shared(); + cpu_data->set_resource_name("CPU"); + cpu_data->set_resource_capacity(100); + auto gpu_data = std::make_shared(); + gpu_data->set_resource_name("GPU"); + gpu_data->set_resource_capacity(2); DynamicResourceTable::DataMap data_map1; data_map1.emplace("CPU", cpu_data); data_map1.emplace("GPU", gpu_data); // Prepare the second resource map: data_map2 which decreases CPU, // increases GPU and add a new CUSTOM compared to data_map1. - auto data_cpu = std::make_shared(); - data_cpu->resource_name = "CPU"; - data_cpu->resource_capacity = 50; - auto data_gpu = std::make_shared(); - data_gpu->resource_name = "GPU"; - data_gpu->resource_capacity = 10; - auto data_custom = std::make_shared(); - data_custom->resource_name = "CUSTOM"; - data_custom->resource_capacity = 2; + auto data_cpu = std::make_shared(); + data_cpu->set_resource_name("CPU"); + data_cpu->set_resource_capacity(50); + auto data_gpu = std::make_shared(); + data_gpu->set_resource_name("GPU"); + data_gpu->set_resource_capacity(10); + auto data_custom = std::make_shared(); + data_custom->set_resource_name("CUSTOM"); + data_custom->set_resource_capacity(2); DynamicResourceTable::DataMap data_map2; data_map2.emplace("CPU", data_cpu); data_map2.emplace("GPU", data_gpu); data_map2.emplace("CUSTOM", data_custom); - data_map2["CPU"]->resource_capacity = 50; + data_map2["CPU"]->set_resource_capacity(50); // This is a common comparison function for the test. auto compare_test = [](const DynamicResourceTable::DataMap &data1, const DynamicResourceTable::DataMap &data2) { @@ -1348,8 +1309,8 @@ void TestHashTable(const DriverID &driver_id, for (const auto &data : data1) { auto iter = data2.find(data.first); ASSERT_TRUE(iter != data2.end()); - ASSERT_EQ(iter->second->resource_name, data.second->resource_name); - ASSERT_EQ(iter->second->resource_capacity, data.second->resource_capacity); + ASSERT_EQ(iter->second->resource_name(), data.second->resource_name()); + ASSERT_EQ(iter->second->resource_capacity(), data.second->resource_capacity()); } }; auto subscribe_callback = [](AsyncGcsClient *client) { diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index 614c80b27672..c06c79a02928 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -1,52 +1,9 @@ -enum Language:int { - PYTHON = 0, - CPP = 1, - JAVA = 2 -} - -// These indexes are mapped to strings in ray_redis_module.cc. -enum TablePrefix:int { - UNUSED = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - FUNCTION, - TASK_RECONSTRUCTION, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - DRIVER, - PROFILE, - TASK_LEASE, - ACTOR_CHECKPOINT, - ACTOR_CHECKPOINT_ID, - NODE_RESOURCE, -} +// TODO(hchen): Migrate data structures in this file to protobuf (`gcs.proto`). -// The channel that Add operations to the Table should be published on, if any. 
-enum TablePubsub:int { - NO_PUBLISH = 0, - TASK, - RAYLET_TASK, - CLIENT, - OBJECT, - ACTOR, - HEARTBEAT, - HEARTBEAT_BATCH, - ERROR_INFO, - TASK_LEASE, - DRIVER, - NODE_RESOURCE, -} - -// Enum for the entry type in the ClientTable -enum EntryType:int { - INSERTION = 0, - DELETION, - RES_CREATEUPDATE, - RES_DELETE, +enum Language:int { + PYTHON=0, + JAVA=1, + CPP=2, } table Arg { @@ -106,6 +63,11 @@ table TaskInfo { // For a Python function, it should be: [module_name, class_name, function_name] // For a Java function, it should be: [class_name, method_name, type_descriptor] function_descriptor: [string]; + // The dynamic options used in the worker command when starting the worker process for + // an actor creation task. If the list isn't empty, the options will be used to replace + // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the + // worker command. + dynamic_worker_options: [string]; } table ResourcePair { @@ -115,118 +77,6 @@ table ResourcePair { value: double; } -enum GcsChangeMode:int { - APPEND_OR_ADD = 0, - REMOVE, -} - -table GcsEntry { - change_mode: GcsChangeMode; - id: string; - entries: [string]; -} - -table FunctionTableData { - language: Language; - name: string; - data: string; -} - -table ObjectTableData { - // The size of the object. - object_size: long; - // The node manager ID that this object appeared on or was evicted by. - manager: string; -} - -table TaskReconstructionData { - // The number of times this task has been reconstructed so far. - num_reconstructions: int; - // The node manager that is trying to reconstruct the task. - node_manager_id: string; -} - -enum SchedulingState:int { - NONE = 0, - WAITING = 1, - SCHEDULED = 2, - QUEUED = 4, - RUNNING = 8, - DONE = 16, - LOST = 32, - RECONSTRUCTING = 64 -} - -table TaskTableData { - // The state of the task. - scheduling_state: SchedulingState; - // A raylet ID. - raylet_id: string; - // A string of bytes representing the task's TaskExecutionDependencies. - execution_dependencies: string; - // The number of times the task was spilled back by raylets. - spillback_count: long; - // A string of bytes representing the task specification. - task_info: string; - // TODO(pcm): This is at the moment duplicated in task_info, remove that one - updated: bool; -} - -table TaskTableTestAndUpdate { - test_raylet_id: string; - test_state_bitmask: SchedulingState; - update_state: SchedulingState; -} - -table ClassTableData { -} - -enum ActorState:int { - // Actor is alive. - ALIVE = 0, - // Actor is dead, now being reconstructed. - // After reconstruction finishes, the state will become alive again. - RECONSTRUCTING = 1, - // Actor is already dead and won't be reconstructed. - DEAD = 2 -} - -table ActorTableData { - // The ID of the actor that was created. - actor_id: string; - // The dummy object ID returned by the actor creation task. If the actor - // dies, then this is the object that should be reconstructed for the actor - // to be recreated. - actor_creation_dummy_object_id: string; - // The ID of the driver that created the actor. - driver_id: string; - // The ID of the node manager that created the actor. - node_manager_id: string; - // Current state of this actor. - state: ActorState; - // Max number of times this actor should be reconstructed. - max_reconstructions: int; - // Remaining number of reconstructions. - remaining_reconstructions: int; -} - -table ErrorTableData { - // The ID of the driver that the error is for. - driver_id: string; - // The type of the error. 
- type: string; - // The error message. - error_message: string; - // The timestamp of the error message. - timestamp: double; -} - -table CustomSerializerData { -} - -table ConfigTableData { -} - table ProfileEvent { // The type of the event. event_type: string; @@ -253,119 +103,3 @@ table ProfileTableData { // we don't want each event to require a GCS command. profile_events: [ProfileEvent]; } - -table RayResource { - // The type of the resource. - resource_name: string; - // The total capacity of this resource type. - resource_capacity: double; -} - -table ClientTableData { - // The client ID of the client that the message is about. - client_id: string; - // The IP address of the client's node manager. - node_manager_address: string; - // The IPC socket name of the client's raylet. - raylet_socket_name: string; - // The IPC socket name of the client's plasma store. - object_store_socket_name: string; - // The port at which the client's node manager is listening for TCP - // connections from other node managers. - node_manager_port: int; - // The port at which the client's object manager is listening for TCP - // connections from other object managers. - object_manager_port: int; - // Enum to store the entry type in the log - entry_type: EntryType = INSERTION; - resources_total_label: [string]; - resources_total_capacity: [double]; -} - -table HeartbeatTableData { - // Node manager client id - client_id: string; - // Resource capacity currently available on this node manager. - resources_available_label: [string]; - resources_available_capacity: [double]; - // Total resource capacity configured for this node manager. - resources_total_label: [string]; - resources_total_capacity: [double]; - // Aggregate outstanding resource load on this node manager. - resource_load_label: [string]; - resource_load_capacity: [double]; -} - -table HeartbeatBatchTableData { - batch: [HeartbeatTableData]; -} - -// Data for a lease on task execution. -table TaskLeaseData { - // Node manager client ID. - node_manager_id: string; - // The time that the lease was last acquired at. NOTE(swang): This is the - // system clock time according to the node that added the entry and is not - // synchronized with other nodes. - acquired_at: long; - // The period that the lease is active for. - timeout: long; -} - -table DriverTableData { - // The driver ID. - driver_id: string; - // Whether it's dead. - is_dead: bool; -} - -// This table stores the actor checkpoint data. An actor checkpoint -// is the snapshot of an actor's state in the actor registration. -// See `actor_registration.h` for more detailed explanation of these fields. -table ActorCheckpointData { - // ID of this actor. - actor_id: string; - // The dummy object ID of actor's most recently executed task. - execution_dependency: string; - // A list of IDs of this actor's handles. - handle_ids: [string]; - // The task counters of the above handles. - task_counters: [long]; - // The frontier dependencies of the above handles. - frontier_dependencies: [string]; - // A list of unreleased dummy objects from this actor. - unreleased_dummy_objects: [string]; - // The numbers of dependencies for the above unreleased dummy objects. - num_dummy_object_dependencies: [int]; -} - -// This table stores the actor-to-available-checkpoint-ids mapping. -table ActorCheckpointIdData { - // ID of this actor. - actor_id: string; - // IDs of this actor's available checkpoints. - // Note, this is a long string that concatenates all the IDs. 
- checkpoint_ids: string; - // A list of the timestamps for each of the above `checkpoint_ids`. - timestamps: [long]; -} - -// This enum type is used as object's metadata to indicate the object's creating -// task has failed because of a certain error. -// TODO(hchen): We may want to make these errors more specific. E.g., we may want -// to distinguish between intentional and expected actor failures, and between -// worker process failure and node failure. -enum ErrorType:int { - // Indicates that a task failed because the worker died unexpectedly while executing it. - WORKER_DIED = 1, - // Indicates that a task failed because the actor died unexpectedly before finishing it. - ACTOR_DIED = 2, - // Indicates that an object is lost and cannot be reconstructed. - // Note, this currently only happens to actor objects. When the actor's state is already - // after the object's creating task, the actor cannot re-run the task. - // TODO(hchen): we may want to reuse this error type for more cases. E.g., - // 1) A object that was put by the driver. - // 2) The object's creating task is already cleaned up from GCS (this currently - // crashes raylet). - OBJECT_UNRECONSTRUCTABLE = 3, -} diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index fc42e5cd98c2..093aab2455d9 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -9,7 +9,7 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" extern "C" { #include "ray/thirdparty/hiredis/adapters/ae.h" @@ -25,6 +25,9 @@ namespace ray { namespace gcs { +using rpc::TablePrefix; +using rpc::TablePubsub; + /// A simple reply wrapper for redis reply. class CallbackReply { public: @@ -126,8 +129,8 @@ class RedisContext { /// -1 for unused. If set, then data must be provided. /// \return Status. template - Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, - int64_t length, const TablePrefix prefix, + Status RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -157,9 +160,9 @@ class RedisContext { }; template -Status RedisContext::RunAsync(const std::string &command, const ID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, +Status RedisContext::RunAsync(const std::string &command, const ID &id, const void *data, + size_t length, const TablePrefix prefix, + const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); if (length > 0) { diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index e291b7ffdb32..c3a82c320d06 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -5,11 +5,16 @@ #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/util/logging.h" #include "redis_string.h" #include "redismodule.h" using ray::Status; +using ray::rpc::GcsChangeMode; +using ray::rpc::GcsEntry; +using ray::rpc::TablePrefix; +using ray::rpc::TablePubsub; #if RAY_USE_NEW_GCS // Under this flag, ray-project/credis will be loaded. 
Specifically, via @@ -64,8 +69,8 @@ Status ParseTablePubsub(TablePubsub *out, const RedisModuleString *pubsub_channe REDISMODULE_OK) { return Status::RedisError("Pubsub channel must be a valid integer."); } - if (pubsub_channel_long > static_cast(TablePubsub::MAX) || - pubsub_channel_long < static_cast(TablePubsub::MIN)) { + if (pubsub_channel_long >= static_cast(TablePubsub::TABLE_PUBSUB_MAX) || + pubsub_channel_long <= static_cast(TablePubsub::TABLE_PUBSUB_MIN)) { return Status::RedisError("Pubsub channel must be in the TablePubsub range."); } else { *out = static_cast(pubsub_channel_long); @@ -80,7 +85,7 @@ Status FormatPubsubChannel(RedisModuleString **out, RedisModuleCtx *ctx, const RedisModuleString *id) { // Format the pubsub channel enum to a string. TablePubsub_MAX should be more // than enough digits, but add 1 just in case for the null terminator. - char pubsub_channel[static_cast(TablePubsub::MAX) + 1]; + char pubsub_channel[static_cast(TablePubsub::TABLE_PUBSUB_MAX) + 1]; TablePubsub table_pubsub; RAY_RETURN_NOT_OK(ParseTablePubsub(&table_pubsub, pubsub_channel_str)); sprintf(pubsub_channel, "%d", static_cast(table_pubsub)); @@ -95,8 +100,8 @@ Status ParseTablePrefix(const RedisModuleString *table_prefix_str, TablePrefix * REDISMODULE_OK) { return Status::RedisError("Prefix must be a valid TablePrefix integer"); } - if (table_prefix_long > static_cast(TablePrefix::MAX) || - table_prefix_long < static_cast(TablePrefix::MIN)) { + if (table_prefix_long >= static_cast(TablePrefix::TABLE_PREFIX_MAX) || + table_prefix_long <= static_cast(TablePrefix::TABLE_PREFIX_MIN)) { return Status::RedisError("Prefix must be in the TablePrefix range"); } else { *out = static_cast(table_prefix_long); @@ -113,7 +118,7 @@ RedisModuleString *PrefixedKeyString(RedisModuleCtx *ctx, RedisModuleString *pre if (!ParseTablePrefix(prefix_enum, &prefix).ok()) { return nullptr; } - return RedisString_Format(ctx, "%s%S", EnumNameTablePrefix(prefix), keyname); + return RedisString_Format(ctx, "%s%S", TablePrefix_Name(prefix).c_str(), keyname); } // TODO(swang): This helper function should be deprecated by the version below, @@ -136,8 +141,8 @@ Status OpenPrefixedKey(RedisModuleKey **out, RedisModuleCtx *ctx, int mode, RedisModuleString **mutated_key_str) { TablePrefix prefix; RAY_RETURN_NOT_OK(ParseTablePrefix(prefix_enum, &prefix)); - *out = - OpenPrefixedKey(ctx, EnumNameTablePrefix(prefix), keyname, mode, mutated_key_str); + *out = OpenPrefixedKey(ctx, TablePrefix_Name(prefix).c_str(), keyname, mode, + mutated_key_str); return Status::OK(); } @@ -165,18 +170,24 @@ Status GetBroadcastKey(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st return Status::OK(); } -/// This is a helper method to convert a redis module string to a flatbuffer -/// string. +/// A helper function that creates `GcsEntry` protobuf object. /// -/// \param fbb The flatbuffer builder. -/// \param redis_string The redis string. -/// \return The flatbuffer string. -flatbuffers::Offset RedisStringToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb, RedisModuleString *redis_string) { - size_t redis_string_size; - const char *redis_string_str = - RedisModule_StringPtrLen(redis_string, &redis_string_size); - return fbb.CreateString(redis_string_str, redis_string_size); +/// \param[in] id Id of the entry. +/// \param[in] change_mode Change mode of the entry. +/// \param[in] entries Vector of entries. +/// \param[out] result The created `GcsEntry` object. 
+inline void CreateGcsEntry(RedisModuleString *id, GcsChangeMode change_mode, + const std::vector<RedisModuleString *> &entries, + GcsEntry *result) { + const char *data; + size_t size; + data = RedisModule_StringPtrLen(id, &size); + result->set_id(data, size); + result->set_change_mode(change_mode); + for (const auto &entry : entries) { + data = RedisModule_StringPtrLen(entry, &size); + result->add_entries(data, size); + } } /// Helper method to publish formatted data to the target channel. @@ -234,13 +245,10 @@ int PublishTableUpdate(RedisModuleCtx *ctx, RedisModuleString *pubsub_channel_st RedisModuleString *id, GcsChangeMode change_mode, RedisModuleString *data) { // Serialize the notification to send. - flatbuffers::FlatBufferBuilder fbb; - auto data_flatbuf = RedisStringToFlatbuf(fbb, data); - auto message = CreateGcsEntry(fbb, change_mode, RedisStringToFlatbuf(fbb, id), - fbb.CreateVector(&data_flatbuf, 1)); - fbb.Finish(message); - auto data_buffer = RedisModule_CreateString( - ctx, reinterpret_cast<const char *>(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + CreateGcsEntry(id, change_mode, {data}, &gcs_entry); + std::string str = gcs_entry.SerializeAsString(); + auto data_buffer = RedisModule_CreateString(ctx, str.data(), str.size()); return PublishDataHelper(ctx, pubsub_channel_str, id, data_buffer); } @@ -570,19 +578,20 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, size_t update_data_len = 0; const char *update_data_buf = RedisModule_StringPtrLen(update_data, &update_data_len); - auto data_vec = flatbuffers::GetRoot<GcsEntry>(update_data_buf); - *change_mode = data_vec->change_mode(); + GcsEntry gcs_entry; + gcs_entry.ParseFromArray(update_data_buf, update_data_len); + *change_mode = gcs_entry.change_mode(); + if (*change_mode == GcsChangeMode::APPEND_OR_ADD) { // This code path means this is an update command. - size_t total_size = data_vec->entries()->size(); + size_t total_size = gcs_entry.entries_size(); REPLY_AND_RETURN_IF_FALSE(total_size % 2 == 0, "Invalid Hash Update data vector."); for (int i = 0; i < total_size; i += 2) { // Reconstruct a key-value pair from a flattened list. RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); - RedisModuleString *entry_value = - RedisModule_CreateString(ctx, data_vec->entries()->Get(i + 1)->data(), - data_vec->entries()->Get(i + 1)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); + RedisModuleString *entry_value = RedisModule_CreateString( + ctx, gcs_entry.entries(i + 1).data(), gcs_entry.entries(i + 1).size()); // Returns 0 if the key exists (still updated), 1 if the key is created. RAY_IGNORE_EXPR( RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, entry_value, NULL)); @@ -590,27 +599,25 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, *changed_data = update_data; } else { // This code path means the command wants to remove the entries.
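For reference, the envelope round-trip that PublishTableUpdate and HashUpdate_DoWrite now rely on, isolated from the Redis module plumbing. This is only a sketch, assuming nothing beyond the generated header from src/ray/protobuf/gcs.proto; note that ParseFromArray reports malformed input through its return value, which the module code does not check.

    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    using ray::rpc::GcsChangeMode;
    using ray::rpc::GcsEntry;

    // Build an envelope and serialize it to the bytes stored in Redis.
    std::string EncodeEntry(const std::string &id, const std::string &value) {
      GcsEntry entry;
      entry.set_id(id);
      entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD);
      entry.add_entries(value);
      return entry.SerializeAsString();
    }

    // Parse the bytes back; false signals a malformed payload.
    bool DecodeEntry(const std::string &bytes, GcsEntry *out) {
      return out->ParseFromArray(bytes.data(), static_cast<int>(bytes.size()));
    }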
- size_t total_size = data_vec->entries()->size(); - flatbuffers::FlatBufferBuilder fbb; - std::vector> data; + GcsEntry updated; + updated.set_id(gcs_entry.id()); + updated.set_change_mode(gcs_entry.change_mode()); + + size_t total_size = gcs_entry.entries_size(); for (int i = 0; i < total_size; i++) { RedisModuleString *entry_key = RedisModule_CreateString( - ctx, data_vec->entries()->Get(i)->data(), data_vec->entries()->Get(i)->size()); + ctx, gcs_entry.entries(i).data(), gcs_entry.entries(i).size()); int deleted_num = RedisModule_HashSet(key, REDISMODULE_HASH_NONE, entry_key, REDISMODULE_HASH_DELETE, NULL); if (deleted_num != 0) { // The corresponding key is removed. - data.push_back(fbb.CreateString(data_vec->entries()->Get(i)->data(), - data_vec->entries()->Get(i)->size())); + updated.add_entries(gcs_entry.entries(i)); } } - auto message = - CreateGcsEntry(fbb, data_vec->change_mode(), - fbb.CreateString(data_vec->id()->data(), data_vec->id()->size()), - fbb.CreateVector(data)); - fbb.Finish(message); - *changed_data = RedisModule_CreateString( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + + // Serialize updated data. + std::string str = updated.SerializeAsString(); + *changed_data = RedisModule_CreateString(ctx, str.data(), str.size()); auto size = RedisModule_ValueLength(key); if (size == 0) { REPLY_AND_RETURN_IF_FALSE(RedisModule_DeleteKey(key) == REDISMODULE_OK, @@ -631,7 +638,7 @@ int HashUpdate_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, /// key should be published to. When publishing to a specific client, the /// channel name should be :. /// \param id The ID of the key to remove from. -/// \param data The GcsEntry flatbugger data used to update this hash table. +/// \param data The GcsEntry protobuf data used to update this hash table. /// 1). For deletion, this is a list of keys. /// 2). For updating, this is a list of pairs with each key followed by the value. /// \return OK if the remove succeeds, or an error message string if the remove @@ -648,7 +655,7 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a return Hash_DoPublish(ctx, new_argv.data()); } -/// A helper function to create and finish a GcsEntry, based on the +/// A helper function to create a GcsEntry protobuf, based on the /// current value or values at the given key. /// /// \param ctx The Redis module context. @@ -658,21 +665,18 @@ int HashUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int a /// \param prefix_str The string prefix associated with the open Redis key. /// When parsed, this is expected to be a TablePrefix. /// \param entry_id The UniqueID associated with the open Redis key. -/// \param fbb A flatbuffer builder used to build the GcsEntry. -Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, - RedisModuleString *prefix_str, RedisModuleString *entry_id, - flatbuffers::FlatBufferBuilder &fbb) { +/// \param[out] gcs_entry The created GcsEntry. +Status TableEntryToProtobuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, + RedisModuleString *prefix_str, RedisModuleString *entry_id, + GcsEntry *gcs_entry) { auto key_type = RedisModule_KeyType(table_key); switch (key_type) { case REDISMODULE_KEYTYPE_STRING: { - // Build the flatbuffer from the string data. + // Build the GcsEntry from the string data. 
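The remove path above and the update path before it both depend on a positional convention inside GcsEntry.entries: APPEND_OR_ADD carries [key1, value1, key2, value2, ...], while REMOVE carries only the keys. A sketch of that convention, with hypothetical helper names; the hunk continues below.

    #include <map>
    #include <string>
    #include "ray/protobuf/gcs.pb.h"

    using ray::rpc::GcsChangeMode;
    using ray::rpc::GcsEntry;

    GcsEntry EncodePairs(const std::string &id,
                         const std::map<std::string, std::string> &kv) {
      GcsEntry entry;
      entry.set_id(id);
      entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD);
      for (const auto &pair : kv) {
        entry.add_entries(pair.first);   // even index: key
        entry.add_entries(pair.second);  // odd index: value
      }
      return entry;
    }

    std::map<std::string, std::string> DecodePairs(const GcsEntry &entry) {
      // The even/odd pairing is why the write path checks entries_size() % 2 == 0.
      std::map<std::string, std::string> kv;
      for (int i = 0; i + 1 < entry.entries_size(); i += 2) {
        kv[entry.entries(i)] = entry.entries(i + 1);
      }
      return kv;
    }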
+ CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); size_t data_len = 0; char *data_buf = RedisModule_StringDMA(table_key, &data_len, REDISMODULE_READ); - auto data = fbb.CreateString(data_buf, data_len); - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(&data, 1)); - fbb.Finish(message); + gcs_entry->add_entries(data_buf, data_len); } break; case REDISMODULE_KEYTYPE_LIST: case REDISMODULE_KEYTYPE_HASH: @@ -696,27 +700,20 @@ Status TableEntryToFlatbuf(RedisModuleCtx *ctx, RedisModuleKey *table_key, reply = RedisModule_Call(ctx, "HGETALL", "s", table_key_str); break; } - // Build the flatbuffer from the set of log entries. + // Build the GcsEntry from the set of log entries. if (reply == nullptr || RedisModule_CallReplyType(reply) != REDISMODULE_REPLY_ARRAY) { return Status::RedisError("Empty list/set/hash or wrong type"); } - std::vector> data; + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); for (size_t i = 0; i < RedisModule_CallReplyLength(reply); i++) { RedisModuleCallReply *element = RedisModule_CallReplyArrayElement(reply, i); size_t len; const char *element_str = RedisModule_CallReplyStringPtr(element, &len); - data.push_back(fbb.CreateString(element_str, len)); + gcs_entry->add_entries(element_str, len); } - auto message = - CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - RedisStringToFlatbuf(fbb, entry_id), fbb.CreateVector(data)); - fbb.Finish(message); } break; case REDISMODULE_KEYTYPE_EMPTY: { - auto message = CreateGcsEntry( - fbb, GcsChangeMode::APPEND_OR_ADD, RedisStringToFlatbuf(fbb, entry_id), - fbb.CreateVector(std::vector>())); - fbb.Finish(message); + CreateGcsEntry(entry_id, GcsChangeMode::APPEND_OR_ADD, {}, gcs_entry); } break; default: return Status::RedisError("Invalid Redis type during lookup."); @@ -752,11 +749,12 @@ int TableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int if (table_key == nullptr) { RedisModule_ReplyWithNull(ctx); } else { - // Serialize the data to a flatbuffer to return to the client. - flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_ReplyWithStringBuffer( - ctx, reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + // Serialize the data to a GcsEntry to return to the client. + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_ReplyWithStringBuffer(ctx, str.data(), str.size()); } return REDISMODULE_OK; } @@ -870,10 +868,11 @@ int TableRequestNotifications_RedisCommand(RedisModuleCtx *ctx, RedisModuleStrin // Publish the current value at the key to the client that is requesting // notifications. An empty notification will be published if the key is // empty. 
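TableEntryToProtobuf, continued below, answers lookups with a serialized GcsEntry whose entries are themselves serialized row messages, so a reader parses twice. A client-side sketch, with TaskTableData standing in for whichever row type the table actually stores.

    #include <string>
    #include <vector>
    #include "ray/protobuf/gcs.pb.h"

    using ray::rpc::GcsEntry;
    using ray::rpc::TaskTableData;

    std::vector<TaskTableData> DecodeLookupReply(const std::string &reply) {
      std::vector<TaskTableData> rows;
      GcsEntry envelope;
      if (!envelope.ParseFromString(reply)) {
        return rows;  // Malformed payload: surface it as an empty result.
      }
      for (int i = 0; i < envelope.entries_size(); i++) {
        TaskTableData row;  // One serialized row message per entry.
        if (row.ParseFromString(envelope.entries(i))) {
          rows.push_back(row);
        }
      }
      return rows;
    }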
- flatbuffers::FlatBufferBuilder fbb; - REPLY_AND_RETURN_IF_NOT_OK(TableEntryToFlatbuf(ctx, table_key, prefix_str, id, fbb)); - RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, - reinterpret_cast(fbb.GetBufferPointer()), fbb.GetSize()); + GcsEntry gcs_entry; + REPLY_AND_RETURN_IF_NOT_OK( + TableEntryToProtobuf(ctx, table_key, prefix_str, id, &gcs_entry)); + std::string str = gcs_entry.SerializeAsString(); + RedisModule_Call(ctx, "PUBLISH", "sb", client_channel, str.data(), str.size()); return RedisModule_ReplyWithNull(ctx); } @@ -940,53 +939,6 @@ Status IsNil(bool *out, const std::string &data) { return Status::OK(); } -// This is a temporary redis command that will be removed once -// the GCS uses https://github.com/pcmoritz/credis. -// Be careful, this only supports Task Table payloads. -int TableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, - int argc) { - if (argc != 5) { - return RedisModule_WrongArity(ctx); - } - RedisModuleString *prefix_str = argv[1]; - RedisModuleString *id = argv[3]; - RedisModuleString *update_data = argv[4]; - - RedisModuleKey *key; - REPLY_AND_RETURN_IF_NOT_OK( - OpenPrefixedKey(&key, ctx, prefix_str, id, REDISMODULE_READ | REDISMODULE_WRITE)); - - size_t value_len = 0; - char *value_buf = RedisModule_StringDMA(key, &value_len, REDISMODULE_READ); - - size_t update_len = 0; - const char *update_buf = RedisModule_StringPtrLen(update_data, &update_len); - - auto data = - flatbuffers::GetMutableRoot(reinterpret_cast(value_buf)); - - auto update = flatbuffers::GetRoot(update_buf); - - bool do_update = static_cast(data->scheduling_state()) & - static_cast(update->test_state_bitmask()); - - bool is_nil_result; - REPLY_AND_RETURN_IF_NOT_OK(IsNil(&is_nil_result, update->test_raylet_id()->str())); - if (!is_nil_result) { - do_update = do_update && update->test_raylet_id()->str() == data->raylet_id()->str(); - } - - if (do_update) { - REPLY_AND_RETURN_IF_FALSE(data->mutate_scheduling_state(update->update_state()), - "mutate_scheduling_state failed"); - } - REPLY_AND_RETURN_IF_FALSE(data->mutate_updated(do_update), "mutate_updated failed"); - - int result = RedisModule_ReplyWithStringBuffer(ctx, value_buf, value_len); - - return result; -} - std::string DebugString() { std::stringstream result; result << "RedisModule:"; @@ -1016,7 +968,6 @@ AUTO_MEMORY(TableLookup_RedisCommand); AUTO_MEMORY(TableRequestNotifications_RedisCommand); AUTO_MEMORY(TableDelete_RedisCommand); AUTO_MEMORY(TableCancelNotifications_RedisCommand); -AUTO_MEMORY(TableTestAndUpdate_RedisCommand); AUTO_MEMORY(DebugString_RedisCommand); #if RAY_USE_NEW_GCS AUTO_MEMORY(ChainTableAdd_RedisCommand); @@ -1082,12 +1033,6 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) return REDISMODULE_ERR; } - if (RedisModule_CreateCommand(ctx, "ray.table_test_and_update", - TableTestAndUpdate_RedisCommand, "write", 0, 0, - 0) == REDISMODULE_ERR) { - return REDISMODULE_ERR; - } - if (RedisModule_CreateCommand(ctx, "ray.debug_string", DebugString_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { return REDISMODULE_ERR; diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 33f1615580a6..b7c19ebfd595 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -3,6 +3,7 @@ #include "ray/common/common_protocol.h" #include "ray/common/ray_config.h" #include "ray/gcs/client.h" +#include "ray/rpc/util.h" #include "ray/util/util.h" namespace { @@ -39,48 +40,44 @@ namespace gcs { template Status Log::Append(const DriverID &driver_id, 
const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); // Failed to append the entry. RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" << status.ToString(); if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template Status Log::AppendAt(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + auto callback = [this, id, data, done, failure](const CallbackReply &reply) { const auto status = reply.ReadAsStatus(); if (status.ok()) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } } else { if (failure != nullptr) { - (failure)(client_, id, *dataT); + (failure)(client_, id, *data); } } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback), log_length); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetLogAppendCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback), log_length); } template @@ -89,16 +86,15 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { - std::vector results; + std::vector results; if (!reply.IsNil()) { - const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); - results.emplace_back(std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data data; + data.ParseFromString(gcs_entry.entries(i)); + results.emplace_back(std::move(data)); } } lookup(client_, id, results); @@ -115,7 +111,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, const GcsChangeMode change_mode, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; @@ -141,19 +137,16 @@ Status Log::Subscribe(const 
DriverID &driver_id, const ClientID &clien // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. - auto root = flatbuffers::GetRoot(data.data()); - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - std::vector results; - for (size_t i = 0; i < root->entries()->size(); i++) { - DataT result; - auto data_root = flatbuffers::GetRoot(root->entries()->Get(i)->data()); - data_root->UnPackTo(&result); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); + std::vector results; + for (size_t i = 0; i < gcs_entry.entries_size(); i++) { + Data result; + result.ParseFromString(gcs_entry.entries(i)); results.emplace_back(std::move(result)); } - subscribe(client_, id, root->change_mode(), results); + subscribe(client_, id, gcs_entry.change_mode(), results); } } }; @@ -234,19 +227,17 @@ std::string Log::DebugString() const { template Status Table::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, - fbb.GetBufferPointer(), fbb.GetSize(), prefix_, - pubsub_channel_, std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync(GetTableAddCommand(command_type_), id, str.data(), + str.length(), prefix_, pubsub_channel_, + std::move(callback)); } template @@ -255,7 +246,7 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, num_lookups_++; return Log::Lookup(driver_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { if (data.empty()) { if (failure != nullptr) { (failure)(client, id); @@ -277,7 +268,7 @@ Status Table::Subscribe(const DriverID &driver_id, const ClientID &cli return Log::Subscribe( driver_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, - const std::vector &data) { + const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); if (data.size() == 1) { subscribe(client, id, data[0]); @@ -299,36 +290,30 @@ std::string Table::DebugString() const { template Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_ADD", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template Status Set::Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr 
&dataT, const WriteCallback &done) { + std::shared_ptr &data, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const CallbackReply &reply) { + auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { - (done)(client_, id, *dataT); + (done)(client_, id, *data); } }; - flatbuffers::FlatBufferBuilder fbb; - fbb.ForceDefaults(true); - fbb.Finish(Data::Pack(fbb, dataT.get())); - return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = data->SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.SET_REMOVE", id, str.data(), str.length(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -348,26 +333,16 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, (done)(client_, id, data_map); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(data_map.size() * 2); - for (auto const &pair : data_map) { - // Add the key. - data_vec.push_back(fbb.CreateString(pair.first)); - flatbuffers::FlatBufferBuilder fbb_data; - fbb_data.ForceDefaults(true); - fbb_data.Finish(Data::Pack(fbb_data, pair.second.get())); - std::string data(reinterpret_cast(fbb_data.GetBufferPointer()), - fbb_data.GetSize()); - // Add the value. - data_vec.push_back(fbb.CreateString(data)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::APPEND_OR_ADD); + for (const auto &pair : data_map) { + gcs_entry.add_entries(pair.first); + gcs_entry.add_entries(pair.second->SerializeAsString()); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::APPEND_OR_ADD, - fbb.CreateString(id.Binary()), fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -380,19 +355,15 @@ Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, (remove_callback)(client_, id, keys); } }; - flatbuffers::FlatBufferBuilder fbb; - std::vector> data_vec; - data_vec.reserve(keys.size()); - // Add the keys. 
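With the flatbuffers builders gone from Set::Add and Set::Remove above, callers now hand the table a shared_ptr to the protobuf row plus a plain callback. A usage sketch only: the surrounding object_table, IDs, and object_size are assumed caller context, and the field setters match the object_directory.cc hunks later in this patch.

    auto data = std::make_shared<ObjectTableData>();
    data->set_manager(client_id.Binary());
    data->set_object_size(object_size);
    RAY_CHECK_OK(object_table.Add(
        DriverID::Nil(), object_id, data,
        [](gcs::AsyncGcsClient *client, const ObjectID &id,
           const ObjectTableData &entry) {
          // The WriteCallback now receives the protobuf row by const reference.
          RAY_LOG(DEBUG) << "Location added for object " << id;
        }));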
- for (auto const &key : keys) { - data_vec.push_back(fbb.CreateString(key)); + GcsEntry gcs_entry; + gcs_entry.set_id(id.Binary()); + gcs_entry.set_change_mode(GcsChangeMode::REMOVE); + for (const auto &key : keys) { + gcs_entry.add_entries(key); } - - fbb.Finish(CreateGcsEntry(fbb, GcsChangeMode::REMOVE, fbb.CreateString(id.Binary()), - fbb.CreateVector(data_vec))); - return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, fbb.GetBufferPointer(), - fbb.GetSize(), prefix_, pubsub_channel_, - std::move(callback)); + std::string str = gcs_entry.SerializeAsString(); + return GetRedisContext(id)->RunAsync("RAY.HASH_UPDATE", id, str.data(), str.size(), + prefix_, pubsub_channel_, std::move(callback)); } template @@ -412,17 +383,15 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, DataMap results; if (!reply.IsNil()) { const auto data = reply.ReadAsString(); - auto root = flatbuffers::GetRoot(data.data()); - RAY_CHECK(from_flatbuf(*root->id()) == id); - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - results.emplace(key, std::move(result)); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(reply.ReadAsString()); + RAY_CHECK(ID::FromBinary(gcs_entry.id()) == id); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + results.emplace(key, std::move(value)); } } lookup(client_, id, results); @@ -451,31 +420,24 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie // Data is provided. This is the callback for a message. if (subscribe != nullptr) { // Parse the notification. 
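As the Subscribe rewrite continuing below shows, Hash notifications deliver a DataMap whose values are left as null shared_ptrs for REMOVE entries, so subscribers must test the pointer before dereferencing. A consumer sketch, assuming RayResource as the row type since that is what DynamicResourceTable stores.

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include "ray/protobuf/gcs.pb.h"
    #include "ray/util/logging.h"

    using DataMap =
        std::unordered_map<std::string, std::shared_ptr<ray::rpc::RayResource>>;

    void OnResourceNotification(const DataMap &data_map) {
      for (const auto &pair : data_map) {
        if (!pair.second) {
          // REMOVE notifications map each key to a null shared_ptr.
          RAY_LOG(DEBUG) << "Resource removed: " << pair.first;
        } else {
          RAY_LOG(DEBUG) << "Resource created/updated: " << pair.first;
        }
      }
    }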
- auto root = flatbuffers::GetRoot(data.data()); + GcsEntry gcs_entry; + gcs_entry.ParseFromString(data); + ID id = ID::FromBinary(gcs_entry.id()); DataMap data_map; - ID id; - if (root->id()->size() > 0) { - id = from_flatbuf(*root->id()); - } - if (root->change_mode() == GcsChangeMode::REMOVE) { - for (size_t i = 0; i < root->entries()->size(); i++) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - data_map.emplace(key, std::shared_ptr()); + if (gcs_entry.change_mode() == GcsChangeMode::REMOVE) { + for (const auto &key : gcs_entry.entries()) { + data_map.emplace(key, std::shared_ptr()); } } else { - RAY_CHECK(root->entries()->size() % 2 == 0); - for (size_t i = 0; i < root->entries()->size(); i += 2) { - std::string key(root->entries()->Get(i)->data(), - root->entries()->Get(i)->size()); - auto result = std::make_shared(); - auto data_root = - flatbuffers::GetRoot(root->entries()->Get(i + 1)->data()); - data_root->UnPackTo(result.get()); - data_map.emplace(key, std::move(result)); + RAY_CHECK(gcs_entry.entries_size() % 2 == 0); + for (int i = 0; i < gcs_entry.entries_size(); i += 2) { + const auto &key = gcs_entry.entries(i); + const auto value = std::make_shared(); + value->ParseFromString(gcs_entry.entries(i + 1)); + data_map.emplace(key, std::move(value)); } } - subscribe(client_, id, root->change_mode(), data_map); + subscribe(client_, id, gcs_entry.change_mode(), data_map); } } }; @@ -490,11 +452,11 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, const std::string &error_message, double timestamp) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->type = type; - data->error_message = error_message; - data->timestamp = timestamp; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_type(type); + data->set_error_message(error_message); + data->set_timestamp(timestamp); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -503,11 +465,9 @@ std::string ErrorTable::DebugString() const { } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { - auto data = std::make_shared(); - // There is some room for optimization here because the Append function will just - // call "Pack" and undo the "UnPack". - profile_events.UnPackTo(data.get()); - + // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. + auto data = std::make_shared(); + data->CopyFrom(profile_events); return Append(DriverID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -517,9 +477,9 @@ std::string ProfileTable::DebugString() const { } Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->driver_id = driver_id.Binary(); - data->is_dead = is_dead; + auto data = std::make_shared(); + data->set_driver_id(driver_id.Binary()); + data->set_is_dead(is_dead); return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); } @@ -527,7 +487,8 @@ void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callbac client_added_callback_ = callback; // Call the callback for any added clients that are cached. 
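The registration API whose cached-replay loop continues below now passes const ClientTableData & instead of the mutable ...T struct. A usage sketch with an assumed caller; node_manager_address() is a real accessor, used again in the object_directory.cc hunks.

    client_table.RegisterClientAddedCallback(
        [](gcs::AsyncGcsClient *client, const ClientID &id,
           const ClientTableData &data) {
          // Protobuf accessors replace raw struct fields, e.g. data.client_id()
          // rather than data.client_id.
          RAY_LOG(INFO) << "Client added: " << id << " at "
                        << data.node_manager_address();
        });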
for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && (entry.second.entry_type == EntryType::INSERTION)) { + if (!entry.first.IsNil() && + (entry.second.entry_type() == ClientTableData::INSERTION)) { client_added_callback_(client_, entry.first, entry.second); } } @@ -537,7 +498,7 @@ void ClientTable::RegisterClientRemovedCallback(const ClientTableCallback &callb client_removed_callback_ = callback; // Call the callback for any removed clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::DELETION) { + if (!entry.first.IsNil() && entry.second.entry_type() == ClientTableData::DELETION) { client_removed_callback_(client_, entry.first, entry.second); } } @@ -549,7 +510,7 @@ void ClientTable::RegisterResourceCreateUpdatedCallback( // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { if (!entry.first.IsNil() && - (entry.second.entry_type == EntryType::RES_CREATEUPDATE)) { + (entry.second.entry_type() == ClientTableData::RES_CREATEUPDATE)) { resource_createupdated_callback_(client_, entry.first, entry.second); } } @@ -559,15 +520,16 @@ void ClientTable::RegisterResourceDeletedCallback(const ClientTableCallback &cal resource_deleted_callback_ = callback; // Call the callback for any clients that are cached. for (const auto &entry : client_cache_) { - if (!entry.first.IsNil() && entry.second.entry_type == EntryType::RES_DELETE) { + if (!entry.first.IsNil() && + entry.second.entry_type() == ClientTableData::RES_DELETE) { resource_deleted_callback_(client_, entry.first, entry.second); } } } void ClientTable::HandleNotification(AsyncGcsClient *client, - const ClientTableDataT &data) { - ClientID client_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID client_id = ClientID::FromBinary(data.client_id()); // It's possible to get duplicate notifications from the client table, so // check whether this notification is new. auto entry = client_cache_.find(client_id); @@ -578,16 +540,16 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } else { // If the entry is in the cache, then the notification is new if the client // was alive and is now dead or resources have been updated. - bool was_not_deleted = (entry->second.entry_type != EntryType::DELETION); - bool is_deleted = (data.entry_type == EntryType::DELETION); - bool is_res_modified = ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)); + bool was_not_deleted = (entry->second.entry_type() != ClientTableData::DELETION); + bool is_deleted = (data.entry_type() == ClientTableData::DELETION); + bool is_res_modified = ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)); is_notif_new = (was_not_deleted && (is_deleted || is_res_modified)); // Once a client with a given ID has been removed, it should never be added // again. If the entry was in the cache and the client was deleted, check // that this new notification is not an insertion. 
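The newness condition assembled in the surrounding lines reads more easily as a predicate; it is distilled here for reference only, since the patch keeps the logic inline.

    bool IsNotificationNew(const ClientTableData &cached,
                           const ClientTableData &incoming) {
      const bool was_not_deleted = cached.entry_type() != ClientTableData::DELETION;
      const bool is_deleted = incoming.entry_type() == ClientTableData::DELETION;
      const bool is_res_modified =
          incoming.entry_type() == ClientTableData::RES_CREATEUPDATE ||
          incoming.entry_type() == ClientTableData::RES_DELETE;
      // Once DELETION is cached, only further DELETION entries are legal.
      return was_not_deleted && (is_deleted || is_res_modified);
    }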
- if (entry->second.entry_type == EntryType::DELETION) { - RAY_CHECK((data.entry_type == EntryType::DELETION)) + if (entry->second.entry_type() == ClientTableData::DELETION) { + RAY_CHECK((data.entry_type() == ClientTableData::DELETION)) << "Notification for addition of a client that was already removed:" << client_id; } @@ -595,64 +557,64 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, // Add the notification to our cache. Notifications are idempotent. // If it is a new client or a client removal, add as is - if ((data.entry_type == EntryType::INSERTION) || - (data.entry_type == EntryType::DELETION)) { + if ((data.entry_type() == ClientTableData::INSERTION) || + (data.entry_type() == ClientTableData::DELETION)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable Insertion/Deletion " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Setting the client cache to data."; client_cache_[client_id] = data; - } else if ((data.entry_type == EntryType::RES_CREATEUPDATE) || - (data.entry_type == EntryType::RES_DELETE)) { + } else if ((data.entry_type() == ClientTableData::RES_CREATEUPDATE) || + (data.entry_type() == ClientTableData::RES_DELETE)) { RAY_LOG(DEBUG) << "[ClientTableNotification] ClientTable RES_CREATEUPDATE " "notification for client id " - << client_id << ". EntryType: " << int(data.entry_type) + << client_id << ". EntryType: " << int(data.entry_type()) << ". Updating the client cache with the delta from the log."; - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; // Iterate over all resources in the new create/update notification - for (std::vector<std::string>::size_type i = 0; i != data.resources_total_label.size(); i++) { - auto const &resource_name = data.resources_total_label[i]; - auto const &capacity = data.resources_total_capacity[i]; + for (std::vector<std::string>::size_type i = 0; i != data.resources_total_label_size(); i++) { + auto const &resource_name = data.resources_total_label(i); + auto const &capacity = data.resources_total_capacity(i); // If resource exists in the ClientTableData, update it, else create it auto existing_resource_label = - std::find(cache_data.resources_total_label.begin(), - cache_data.resources_total_label.end(), resource_name); - if (existing_resource_label != cache_data.resources_total_label.end()) { - auto index = std::distance(cache_data.resources_total_label.begin(), + std::find(cache_data.resources_total_label().begin(), + cache_data.resources_total_label().end(), resource_name); + if (existing_resource_label != cache_data.resources_total_label().end()) { + auto index = std::distance(cache_data.resources_total_label().begin(), existing_resource_label); // Resource already exists; set the capacity if this is an update call... - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_capacity[index] = capacity; + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.set_resources_total_capacity(index, capacity); } // ...or delete it if this is a deletion call.
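resources_total_label and resources_total_capacity are parallel repeated fields, so every create, update, or delete has to touch both in lockstep; the deletion branch continues below. A sketch of the create-or-update half using protobuf's repeated-field API.

    void SetOrCreateResource(ClientTableData *cache, const std::string &name,
                             double capacity) {
      for (int i = 0; i < cache->resources_total_label_size(); i++) {
        if (cache->resources_total_label(i) == name) {
          cache->set_resources_total_capacity(i, capacity);  // Update in place.
          return;
        }
      }
      // Not found: append to both parallel arrays together.
      cache->add_resources_total_label(name);
      cache->add_resources_total_capacity(capacity);
    }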
- else if (data.entry_type == EntryType::RES_DELETE) { - cache_data.resources_total_label.erase( - cache_data.resources_total_label.begin() + index); - cache_data.resources_total_capacity.erase( - cache_data.resources_total_capacity.begin() + index); + else if (data.entry_type() == ClientTableData::RES_DELETE) { + cache_data.mutable_resources_total_label()->erase( + cache_data.resources_total_label().begin() + index); + cache_data.mutable_resources_total_capacity()->erase( + cache_data.resources_total_capacity().begin() + index); } } else { // Resource does not exist, create resource and add capacity if it was a resource // create call. - if (data.entry_type == EntryType::RES_CREATEUPDATE) { - cache_data.resources_total_label.push_back(resource_name); - cache_data.resources_total_capacity.push_back(capacity); + if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { + cache_data.add_resources_total_label(resource_name); + cache_data.add_resources_total_capacity(capacity); } } } } // If the notification is new, call any registered callbacks. - ClientTableDataT &cache_data = client_cache_[client_id]; + ClientTableData &cache_data = client_cache_[client_id]; if (is_notif_new) { - if (data.entry_type == EntryType::INSERTION) { + if (data.entry_type() == ClientTableData::INSERTION) { if (client_added_callback_ != nullptr) { client_added_callback_(client, client_id, cache_data); } RAY_CHECK(removed_clients_.find(client_id) == removed_clients_.end()); - } else if (data.entry_type == EntryType::DELETION) { + } else if (data.entry_type() == ClientTableData::DELETION) { // NOTE(swang): The client should be added to this data structure before // the callback gets called, in case the callback depends on the data // structure getting updated. @@ -660,11 +622,11 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, if (client_removed_callback_ != nullptr) { client_removed_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_CREATEUPDATE) { + } else if (data.entry_type() == ClientTableData::RES_CREATEUPDATE) { if (resource_createupdated_callback_ != nullptr) { resource_createupdated_callback_(client, client_id, cache_data); } - } else if (data.entry_type == EntryType::RES_DELETE) { + } else if (data.entry_type() == ClientTableData::RES_DELETE) { if (resource_deleted_callback_ != nullptr) { resource_deleted_callback_(client, client_id, cache_data); } @@ -672,54 +634,54 @@ void ClientTable::HandleNotification(AsyncGcsClient *client, } } -void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableDataT &data) { - auto connected_client_id = ClientID::FromBinary(data.client_id); +void ClientTable::HandleConnected(AsyncGcsClient *client, const ClientTableData &data) { + auto connected_client_id = ClientID::FromBinary(data.client_id()); RAY_CHECK(client_id_ == connected_client_id) << connected_client_id << " " << client_id_; } const ClientID &ClientTable::GetLocalClientId() const { return client_id_; } -const ClientTableDataT &ClientTable::GetLocalClient() const { return local_client_; } +const ClientTableData &ClientTable::GetLocalClient() const { return local_client_; } bool ClientTable::IsRemoved(const ClientID &client_id) const { return removed_clients_.count(client_id) == 1; } -Status ClientTable::Connect(const ClientTableDataT &local_client) { +Status ClientTable::Connect(const ClientTableData &local_client) { RAY_CHECK(!disconnected_) << "Tried to reconnect a disconnected client."; - RAY_CHECK(local_client.client_id == 
local_client_.client_id); + RAY_CHECK(local_client.client_id() == local_client_.client_id()); local_client_ = local_client; // Construct the data to add to the client table. - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::INSERTION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::INSERTION); // Callback to handle our own successful connection once we've added // ourselves. auto add_callback = [this](AsyncGcsClient *client, const UniqueID &log_key, - const ClientTableDataT &data) { + const ClientTableData &data) { RAY_CHECK(log_key == client_log_key_); HandleConnected(client, data); // Callback for a notification from the client table. auto notification_callback = [this]( AsyncGcsClient *client, const UniqueID &log_key, - const std::vector ¬ifications) { + const std::vector ¬ifications) { RAY_CHECK(log_key == client_log_key_); - std::unordered_map connected_nodes; - std::unordered_map disconnected_nodes; + std::unordered_map connected_nodes; + std::unordered_map disconnected_nodes; for (auto ¬ification : notifications) { // This is temporary fix for Issue 4140 to avoid connect to dead nodes. // TODO(yuhguo): remove this temporary fix after GCS entry is removable. - if (notification.entry_type != EntryType::DELETION) { - connected_nodes.emplace(notification.client_id, notification); + if (notification.entry_type() != ClientTableData::DELETION) { + connected_nodes.emplace(notification.client_id(), notification); } else { - auto iter = connected_nodes.find(notification.client_id); + auto iter = connected_nodes.find(notification.client_id()); if (iter != connected_nodes.end()) { connected_nodes.erase(iter); } - disconnected_nodes.emplace(notification.client_id, notification); + disconnected_nodes.emplace(notification.client_id(), notification); } } for (const auto &pair : connected_nodes) { @@ -742,10 +704,10 @@ Status ClientTable::Connect(const ClientTableDataT &local_client) { } Status ClientTable::Disconnect(const DisconnectCallback &callback) { - auto data = std::make_shared(local_client_); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(local_client_); + data->set_entry_type(ClientTableData::DELETION); auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { HandleConnected(client, data); RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); if (callback != nullptr) { @@ -759,24 +721,24 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { } ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { - auto data = std::make_shared(); - data->client_id = dead_client_id.Binary(); - data->entry_type = EntryType::DELETION; + auto data = std::make_shared(); + data->set_client_id(dead_client_id.Binary()); + data->set_entry_type(ClientTableData::DELETION); return Append(DriverID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, - ClientTableDataT &client_info) const { + ClientTableData &client_info) const { RAY_CHECK(!client_id.IsNil()); auto entry = client_cache_.find(client_id); if (entry != client_cache_.end()) { client_info = entry->second; } else { - client_info.client_id = ClientID::Nil().Binary(); + client_info.set_client_id(ClientID::Nil().Binary()); } } -const std::unordered_map &ClientTable::GetAllClients() const { +const std::unordered_map &ClientTable::GetAllClients() const { return client_cache_; } @@ 
-798,31 +760,29 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, const ActorCheckpointID &checkpoint_id) { auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, - const ActorCheckpointIdDataT &data) { - std::shared_ptr<ActorCheckpointIdDataT> copy = - std::make_shared<ActorCheckpointIdDataT>(data); - copy->timestamps.push_back(current_sys_time_ms()); - copy->checkpoint_ids += checkpoint_id.Binary(); + const ActorCheckpointIdData &data) { + std::shared_ptr<ActorCheckpointIdData> copy = + std::make_shared<ActorCheckpointIdData>(data); + copy->add_timestamps(current_sys_time_ms()); + copy->add_checkpoint_ids(checkpoint_id.Binary()); auto num_to_keep = RayConfig::instance().num_actor_checkpoints_to_keep(); - while (copy->timestamps.size() > num_to_keep) { + while (copy->timestamps().size() > num_to_keep) { // Delete the checkpoint from actor checkpoint table. - const auto &checkpoint_id = - ActorCheckpointID::FromBinary(copy->checkpoint_ids.substr(0, kUniqueIDSize)); - RAY_LOG(DEBUG) << "Deleting checkpoint " << checkpoint_id << " for actor " - << actor_id; - copy->timestamps.erase(copy->timestamps.begin()); - copy->checkpoint_ids.erase(0, kUniqueIDSize); - client_->actor_checkpoint_table().Delete(driver_id, checkpoint_id); + const auto &to_delete = ActorCheckpointID::FromBinary(copy->checkpoint_ids(0)); + RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; + copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); + copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); + client_->actor_checkpoint_table().Delete(driver_id, to_delete); } RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); }; auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { - std::shared_ptr<ActorCheckpointIdDataT> data = - std::make_shared<ActorCheckpointIdDataT>(); - data->actor_id = id.Binary(); - data->timestamps.push_back(current_sys_time_ms()); - data->checkpoint_ids = checkpoint_id.Binary(); + std::shared_ptr<ActorCheckpointIdData> data = + std::make_shared<ActorCheckpointIdData>(); + data->set_actor_id(id.Binary()); + data->add_timestamps(current_sys_time_ms()); + data->add_checkpoint_ids(checkpoint_id.Binary()); RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); }; return Lookup(driver_id, actor_id, lookup_callback, failure_callback); @@ -830,8 +790,7 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, template class Log; template class Set; -template class Log; -template class Table; +template class Log; template class Table; template class Log; template class Log; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 6a1d502a7f54..2ecc3440839e 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -11,10 +11,8 @@ #include "ray/common/status.h" #include "ray/util/logging.h" -#include "ray/gcs/format/gcs_generated.h" #include "ray/gcs/redis_context.h" -// TODO(rkn): Remove this include.
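The trimming loop above pops the oldest checkpoint from two parallel repeated fields; isolated here as a sketch. Both RepeatedPtrField and RepeatedField support erase(iterator), which is what keeps the two fields aligned.

    void TrimCheckpoints(ray::rpc::ActorCheckpointIdData *data, size_t num_to_keep) {
      while (static_cast<size_t>(data->timestamps_size()) > num_to_keep) {
        // The oldest entries sit at the front; erase from both fields together.
        data->mutable_checkpoint_ids()->erase(
            data->mutable_checkpoint_ids()->begin());
        data->mutable_timestamps()->erase(data->mutable_timestamps()->begin());
      }
    }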
-#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/gcs.pb.h" struct redisAsyncContext; @@ -22,6 +20,25 @@ namespace ray { namespace gcs { +using rpc::ActorCheckpointData; +using rpc::ActorCheckpointIdData; +using rpc::ActorTableData; +using rpc::ClientTableData; +using rpc::DriverTableData; +using rpc::ErrorTableData; +using rpc::GcsChangeMode; +using rpc::GcsEntry; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; +using rpc::ObjectTableData; +using rpc::ProfileTableData; +using rpc::RayResource; +using rpc::TablePrefix; +using rpc::TablePubsub; +using rpc::TaskLeaseData; +using rpc::TaskReconstructionData; +using rpc::TaskTableData; + class RedisContext; class AsyncGcsClient; @@ -48,13 +65,12 @@ class PubsubInterface { template class LogInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = - std::function; + std::function; virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done, + std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; }; @@ -72,12 +88,11 @@ class LogInterface { template class Log : public LogInterface, virtual public PubsubInterface { public: - using DataT = typename Data::NativeTableType; using Callback = std::function &data)>; - using NotificationCallback = std::function &data)>; + const std::vector &data)>; + using NotificationCallback = + std::function &data)>; /// The callback to call when a write to a key succeeds. using WriteCallback = typename LogInterface::WriteCallback; /// The callback to call when a SUBSCRIBE call completes and we are ready to @@ -86,7 +101,7 @@ class Log : public LogInterface, virtual public PubsubInterface { struct CallbackData { ID id; - std::shared_ptr data; + std::shared_ptr data; Callback callback; // An optional callback to call for subscription operations, where the // first message is a notification of subscription success. @@ -111,7 +126,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number @@ -126,7 +141,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. 
/// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); @@ -259,10 +274,9 @@ class Log : public LogInterface, virtual public PubsubInterface { template class TableInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -280,9 +294,8 @@ class Table : private Log, public TableInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = - std::function; + std::function; using WriteCallback = typename Log::WriteCallback; /// The callback to call when a Lookup call returns an empty entry. using FailureCallback = std::function; @@ -305,7 +318,7 @@ class Table : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. @@ -369,12 +382,11 @@ class Table : private Log, template class SetInterface { public: - using DataT = typename Data::NativeTableType; using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + std::shared_ptr &data, const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -392,7 +404,6 @@ class Set : private Log, public SetInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; using Callback = typename Log::Callback; using WriteCallback = typename Log::WriteCallback; using NotificationCallback = typename Log::NotificationCallback; @@ -414,7 +425,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. @@ -425,7 +436,7 @@ class Set : private Log, /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); Status Subscribe(const DriverID &driver_id, const ClientID &client_id, @@ -454,8 +465,7 @@ class Set : private Log, template class HashInterface { public: - using DataT = typename Data::NativeTableType; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; // Reuse Log's SubscriptionCallback when Subscribe is successfully called. 
using SubscriptionCallback = typename Log::SubscriptionCallback; @@ -544,8 +554,7 @@ class Hash : private Log, public HashInterface, virtual public PubsubInterface { public: - using DataT = typename Log::DataT; - using DataMap = std::unordered_map>; + using DataMap = std::unordered_map>; using HashCallback = typename HashInterface::HashCallback; using HashRemoveCallback = typename HashInterface::HashRemoveCallback; using HashNotificationCallback = @@ -595,7 +604,7 @@ class DynamicResourceTable : public Hash { DynamicResourceTable(const std::vector> &contexts, AsyncGcsClient *client) : Hash(contexts, client) { - pubsub_channel_ = TablePubsub::NODE_RESOURCE; + pubsub_channel_ = TablePubsub::NODE_RESOURCE_PUBSUB; prefix_ = TablePrefix::NODE_RESOURCE; }; @@ -607,7 +616,7 @@ class ObjectTable : public Set { ObjectTable(const std::vector> &contexts, AsyncGcsClient *client) : Set(contexts, client) { - pubsub_channel_ = TablePubsub::OBJECT; + pubsub_channel_ = TablePubsub::OBJECT_PUBSUB; prefix_ = TablePrefix::OBJECT; }; @@ -619,7 +628,7 @@ class HeartbeatTable : public Table { HeartbeatTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT; + pubsub_channel_ = TablePubsub::HEARTBEAT_PUBSUB; prefix_ = TablePrefix::HEARTBEAT; } virtual ~HeartbeatTable() {} @@ -630,7 +639,7 @@ class HeartbeatBatchTable : public Table { HeartbeatBatchTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH; + pubsub_channel_ = TablePubsub::HEARTBEAT_BATCH_PUBSUB; prefix_ = TablePrefix::HEARTBEAT_BATCH; } virtual ~HeartbeatBatchTable() {} @@ -641,7 +650,7 @@ class DriverTable : public Log { DriverTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER; + pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; prefix_ = TablePrefix::DRIVER; }; @@ -655,18 +664,6 @@ class DriverTable : public Log { Status AppendDriverData(const DriverID &driver_id, bool is_dead); }; -class FunctionTable : public Table { - public: - FunctionTable(const std::vector> &contexts, - AsyncGcsClient *client) - : Table(contexts, client) { - pubsub_channel_ = TablePubsub::NO_PUBLISH; - prefix_ = TablePrefix::FUNCTION; - }; -}; - -using ClassTable = Table; - /// Actor table starts with an ALIVE entry, which represents the first time the actor /// is created. This may be followed by 0 or more pairs of RECONSTRUCTING, ALIVE entries, /// which represent each time the actor fails (RECONSTRUCTING) and gets recreated (ALIVE). @@ -677,7 +674,7 @@ class ActorTable : public Log { ActorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ACTOR; + pubsub_channel_ = TablePubsub::ACTOR_PUBSUB; prefix_ = TablePrefix::ACTOR; } }; @@ -696,12 +693,12 @@ class TaskLeaseTable : public Table { TaskLeaseTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::TASK_LEASE; + pubsub_channel_ = TablePubsub::TASK_LEASE_PUBSUB; prefix_ = TablePrefix::TASK_LEASE; } Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { + std::shared_ptr &data, const WriteCallback &done) override { RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); // Mark the entry for expiration in Redis. 
It's okay if this command fails // since the lease entry itself contains the expiration period. In the @@ -709,9 +706,8 @@ class TaskLeaseTable : public Table { // entry will overestimate the expiration time. // TODO(swang): Use a common helper function to format the key instead of // hardcoding it to match the Redis module. - std::vector args = {"PEXPIRE", - EnumNameTablePrefix(prefix_) + id.Binary(), - std::to_string(data->timeout)}; + std::vector args = {"PEXPIRE", TablePrefix_Name(prefix_) + id.Binary(), + std::to_string(data->timeout())}; return GetRedisContext(id)->RunArgvAsync(args); } @@ -747,12 +743,12 @@ class ActorCheckpointIdTable : public Table { namespace raylet { -class TaskTable : public Table { +class TaskTable : public Table { public: TaskTable(const std::vector> &contexts, AsyncGcsClient *client) : Table(contexts, client) { - pubsub_channel_ = TablePubsub::RAYLET_TASK; + pubsub_channel_ = TablePubsub::RAYLET_TASK_PUBSUB; prefix_ = TablePrefix::RAYLET_TASK; } @@ -770,7 +766,7 @@ class ErrorTable : private Log { ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::ERROR_INFO; + pubsub_channel_ = TablePubsub::ERROR_INFO_PUBSUB; prefix_ = TablePrefix::ERROR_INFO; }; @@ -815,10 +811,6 @@ class ProfileTable : private Log { std::string DebugString() const; }; -using CustomSerializerTable = Table; - -using ConfigTable = Table; - /// \class ClientTable /// /// The ClientTable stores information about active and inactive clients. It is @@ -831,7 +823,7 @@ using ConfigTable = Table; class ClientTable : public Log { public: using ClientTableCallback = std::function; + AsyncGcsClient *client, const ClientID &id, const ClientTableData &data)>; using DisconnectCallback = std::function; ClientTable(const std::vector> &contexts, AsyncGcsClient *client, const ClientID &client_id) @@ -842,11 +834,11 @@ class ClientTable : public Log { disconnected_(false), client_id_(client_id), local_client_() { - pubsub_channel_ = TablePubsub::CLIENT; + pubsub_channel_ = TablePubsub::CLIENT_PUBSUB; prefix_ = TablePrefix::CLIENT; // Set the local client's ID. - local_client_.client_id = client_id.Binary(); + local_client_.set_client_id(client_id.Binary()); }; /// Connect as a client to the GCS. This registers us in the client table @@ -855,7 +847,7 @@ class ClientTable : public Log { /// \param Information about the connecting client. This must have the /// same client_id as the one set in the client table. /// \return Status - ray::Status Connect(const ClientTableDataT &local_client); + ray::Status Connect(const ClientTableData &local_client); /// Disconnect the client from the GCS. The client ID assigned during /// registration should never be reused after disconnecting. @@ -898,7 +890,7 @@ class ClientTable : public Log { /// about the client in the cache, then the reference will be modified to /// contain that information. Else, the reference will be updated to contain /// a nil client ID. - void GetClient(const ClientID &client, ClientTableDataT &client_info) const; + void GetClient(const ClientID &client, ClientTableData &client_info) const; /// Get the local client's ID. /// @@ -908,7 +900,7 @@ class ClientTable : public Log { /// Get the local client's information. /// /// \return The local client's information. - const ClientTableDataT &GetLocalClient() const; + const ClientTableData &GetLocalClient() const; /// Check whether the given client is removed. 
/// @@ -919,7 +911,7 @@ class ClientTable : public Log { /// Get the information of all clients. /// /// \return The client ID to client information map. - const std::unordered_map &GetAllClients() const; + const std::unordered_map &GetAllClients() const; /// Lookup the client data in the client table. /// @@ -940,15 +932,15 @@ class ClientTable : public Log { private: /// Handle a client table notification. - void HandleNotification(AsyncGcsClient *client, const ClientTableDataT ¬ifications); + void HandleNotification(AsyncGcsClient *client, const ClientTableData ¬ifications); /// Handle this client's successful connection to the GCS. - void HandleConnected(AsyncGcsClient *client, const ClientTableDataT &client_data); + void HandleConnected(AsyncGcsClient *client, const ClientTableData &client_data); /// Whether this client has called Disconnect(). bool disconnected_; /// This client's ID. const ClientID client_id_; /// Information about this client. - ClientTableDataT local_client_; + ClientTableData local_client_; /// The callback to call when a new client is added. ClientTableCallback client_added_callback_; /// The callback to call when a client is removed. @@ -958,7 +950,7 @@ class ClientTable : public Log { /// The callback to call when a resource is deleted. ClientTableCallback resource_deleted_callback_; /// A cache for information about all clients. - std::unordered_map client_cache_; + std::unordered_map client_cache_; /// The set of removed clients. std::unordered_set removed_clients_; }; diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 5b6794a505d3..454379d18302 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -8,18 +8,22 @@ ObjectDirectory::ObjectDirectory(boost::asio::io_service &io_service, namespace { +using ray::rpc::ClientTableData; +using ray::rpc::GcsChangeMode; +using ray::rpc::ObjectTableData; + /// Process a notification of the object table entries and store the result in /// client_ids. This assumes that client_ids already contains the result of the /// object table entries up to but not including this notification. void UpdateObjectLocations(const GcsChangeMode change_mode, - const std::vector &location_updates, + const std::vector &location_updates, const ray::gcs::ClientTable &client_table, std::unordered_set *client_ids) { // location_updates contains the updates of locations of the object. // with GcsChangeMode, we can determine whether the update mode is // addition or deletion. for (const auto &object_table_data : location_updates) { - ClientID client_id = ClientID::FromBinary(object_table_data.manager); + ClientID client_id = ClientID::FromBinary(object_table_data.manager()); if (change_mode != GcsChangeMode::REMOVE) { client_ids->insert(client_id); } else { @@ -42,7 +46,7 @@ void ObjectDirectory::RegisterBackend() { auto object_notification_callback = [this](gcs::AsyncGcsClient *client, const ObjectID &object_id, const GcsChangeMode change_mode, - const std::vector &location_updates) { + const std::vector &location_updates) { // Objects are added to this map in SubscribeObjectLocations. auto it = listeners_.find(object_id); // Do nothing for objects we are not listening for. @@ -79,9 +83,9 @@ ray::Status ObjectDirectory::ReportObjectAdded( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object added to GCS " << object_id; // Append the addition entry to the object table. 
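UpdateObjectLocations above is the replay half of the object directory: each notification either inserts or erases a manager ID depending on the change mode. A condensed sketch of that fold (the real function also consults the ClientTable to drop removed clients; ClientID comes from ray/common/id.h, per the includes elsewhere in this patch):

    #include <unordered_set>
    #include <vector>

    #include "ray/common/id.h"        // for ClientID
    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // APPEND_OR_ADD inserts each entry's manager; REMOVE erases it.
    void ApplyLocationUpdates(ray::rpc::GcsChangeMode change_mode,
                              const std::vector<ray::rpc::ObjectTableData> &updates,
                              std::unordered_set<ray::ClientID> *client_ids) {
      for (const auto &entry : updates) {
        const ray::ClientID client_id = ray::ClientID::FromBinary(entry.manager());
        if (change_mode != ray::rpc::GcsChangeMode::REMOVE) {
          client_ids->insert(client_id);
        } else {
          client_ids->erase(client_id);
        }
      }
    }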
- auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); return status; @@ -92,9 +96,9 @@ ray::Status ObjectDirectory::ReportObjectRemoved( const object_manager::protocol::ObjectInfoT &object_info) { RAY_LOG(DEBUG) << "Reporting object removed to GCS " << object_id; // Append the eviction entry to the object table. - auto data = std::make_shared(); - data->manager = client_id.Binary(); - data->object_size = object_info.data_size; + auto data = std::make_shared(); + data->set_manager(client_id.Binary()); + data->set_object_size(object_info.data_size); ray::Status status = gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); return status; @@ -102,14 +106,14 @@ ray::Status ObjectDirectory::ReportObjectRemoved( void ObjectDirectory::LookupRemoteConnectionInfo( RemoteConnectionInfo &connection_info) const { - ClientTableDataT client_data; + ClientTableData client_data; gcs_client_->client_table().GetClient(connection_info.client_id, client_data); - ClientID result_client_id = ClientID::FromBinary(client_data.client_id); + ClientID result_client_id = ClientID::FromBinary(client_data.client_id()); if (!result_client_id.IsNil()) { RAY_CHECK(result_client_id == connection_info.client_id); - if (client_data.entry_type == EntryType::INSERTION) { - connection_info.ip = client_data.node_manager_address; - connection_info.port = static_cast(client_data.object_manager_port); + if (client_data.entry_type() == ClientTableData::INSERTION) { + connection_info.ip = client_data.node_manager_address(); + connection_info.port = static_cast(client_data.object_manager_port()); } } } @@ -208,7 +212,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, status = gcs_client_->object_table().Lookup( DriverID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, - const std::vector &location_updates) { + const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. std::unordered_set client_ids; UpdateObjectLocations(GcsChangeMode::APPEND_OR_ADD, location_updates, diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 954162c21aef..964cee605ced 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -309,15 +309,15 @@ void ObjectManager::HandleSendFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_send"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_send"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. 
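As the comment above says, extra_data is a hand-built JSON array; the send and receive hunks on either side assemble the same four-element list, so the quoting of the string fields is what keeps the profile-table reader's JSON parse working. A standalone sketch of that encoder (the helper name is hypothetical, not in the patch):

    #include <cstdint>
    #include <string>

    // Encodes [object_id_hex, client_id_hex, chunk_index, status] exactly as
    // the hunks on either side do. String fields are quoted; the chunk index
    // is a bare JSON number.
    std::string TransferExtraData(const std::string &object_hex,
                                  const std::string &client_hex, int64_t chunk_index,
                                  const std::string &status) {
      return "[\"" + object_hex + "\",\"" + client_hex + "\"," +
             std::to_string(chunk_index) + ",\"" + status + "\"]";
    }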
- profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -329,15 +329,15 @@ void ObjectManager::HandleReceiveFinished(const ObjectID &object_id, // TODO(rkn): What do we want to do if the send failed? } - ProfileEventT profile_event; - profile_event.event_type = "transfer_receive"; - profile_event.start_time = start_time; - profile_event.end_time = end_time; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("transfer_receive"); + profile_event.set_start_time(start_time); + profile_event.set_end_time(end_time); // Encode the object ID, client ID, chunk index, and status as a json list, // which will be parsed by the reader of the profile table. - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"," + - std::to_string(chunk_index) + ",\"" + status.ToString() + - "\"]"; + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"," + std::to_string(chunk_index) + ",\"" + + status.ToString() + "\"]"); profile_events_.push_back(profile_event); } @@ -801,11 +801,12 @@ void ObjectManager::ReceivePullRequest(std::shared_ptr &con ObjectID object_id = ObjectID::FromBinary(pr->object_id()->str()); ClientID client_id = ClientID::FromBinary(pr->client_id()->str()); - ProfileEventT profile_event; - profile_event.event_type = "receive_pull_request"; - profile_event.start_time = current_sys_time_seconds(); - profile_event.end_time = profile_event.start_time; - profile_event.extra_data = "[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + "\"]"; + rpc::ProfileTableData::ProfileEvent profile_event; + profile_event.set_event_type("receive_pull_request"); + profile_event.set_start_time(current_sys_time_seconds()); + profile_event.set_end_time(profile_event.start_time()); + profile_event.set_extra_data("[\"" + object_id.Hex() + "\",\"" + client_id.Hex() + + "\"]"); profile_events_.push_back(profile_event); Push(object_id, client_id); @@ -938,13 +939,13 @@ void ObjectManager::SpreadFreeObjectRequest(const std::vector &object_ } } -ProfileTableDataT ObjectManager::GetAndResetProfilingInfo() { - ProfileTableDataT profile_info; - profile_info.component_type = "object_manager"; - profile_info.component_id = client_id_.Binary(); +rpc::ProfileTableData ObjectManager::GetAndResetProfilingInfo() { + rpc::ProfileTableData profile_info; + profile_info.set_component_type("object_manager"); + profile_info.set_component_id(client_id_.Binary()); for (auto const &profile_event : profile_events_) { - profile_info.profile_events.emplace_back(new ProfileEventT(profile_event)); + profile_info.add_profile_events()->CopyFrom(profile_event); } profile_events_.clear(); diff --git a/src/ray/object_manager/object_manager.h b/src/ray/object_manager/object_manager.h index 6318250ae3e8..6664dd0a93bd 100644 --- a/src/ray/object_manager/object_manager.h +++ b/src/ray/object_manager/object_manager.h @@ -180,7 +180,7 @@ class ObjectManager : public ObjectManagerInterface { /// /// \return All profiling information that has accumulated since the last call /// to this method. - ProfileTableDataT GetAndResetProfilingInfo(); + rpc::ProfileTableData GetAndResetProfilingInfo(); /// Returns debug string for class. 
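The hunk continuing below completes the receive-side event, which mirrors the send side with only the event_type differing; ReceivePullRequest further down also shows the degenerate point event where end_time equals start_time. A minimal sketch of such an event under the new schema (the timestamp is hypothetical):

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // A point event: end_time equals start_time, and extra_data carries the
    // hand-built JSON payload.
    ray::rpc::ProfileTableData::ProfileEvent MakePullRequestEvent() {
      ray::rpc::ProfileTableData::ProfileEvent event;
      event.set_event_type("receive_pull_request");
      event.set_start_time(1561000000.0);  // hypothetical epoch seconds
      event.set_end_time(event.start_time());
      event.set_extra_data("[\"<object hex>\",\"<client hex>\"]");
      return event;
    }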
/// @@ -412,7 +412,7 @@ class ObjectManager : public ObjectManagerInterface { /// Profiling events that are to be batched together and added to the profile /// table in the GCS. - std::vector profile_events_; + std::vector profile_events_; /// Internally maintained random number generator. std::mt19937_64 gen_; diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index 55aa59124a99..2d5292842acf 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -11,6 +11,8 @@ namespace ray { +using rpc::ClientTableData; + std::string store_executable; static inline void flushall_redis(void) { @@ -52,10 +54,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -242,8 +244,8 @@ class StressTestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -438,16 +440,16 @@ class StressTestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "All connected clients:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id) << "\n" - << "ClientIp=" << data.node_manager_address << "\n" - << "ClientPort=" << data.node_manager_port; - ClientTableDataT data2; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data.client_id()) << "\n" + << "ClientIp=" << data.node_manager_address() << "\n" + << "ClientPort=" << data.node_manager_port(); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id) << "\n" - << "ClientIp=" << data2.node_manager_address << "\n" - << "ClientPort=" << data2.node_manager_port; + RAY_LOG(DEBUG) << "ClientID=" << ClientID::FromBinary(data2.client_id()) << "\n" + << "ClientIp=" << data2.node_manager_address() << "\n" + << "ClientPort=" << data2.node_manager_port(); } }; diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index ee6c78d8ed42..45b80a267f2f 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -14,6 +14,8 @@ int64_t wait_timeout_ms; namespace ray { +using rpc::ClientTableData; + static inline void flushall_redis(void) { redisContext *context = 
redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -46,10 +48,10 @@ class MockServer { std::string ip = endpoint.address().to_string(); unsigned short object_manager_port = endpoint.port(); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = ip; - client_info.node_manager_port = object_manager_port; - client_info.object_manager_port = object_manager_port; + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(ip); + client_info.set_node_manager_port(object_manager_port); + client_info.set_object_manager_port(object_manager_port); ray::Status status = gcs_client_->client_table().Connect(client_info); object_manager_.RegisterGcs(); return status; @@ -221,8 +223,8 @@ class TestObjectManager : public TestObjectManagerBase { client_id_2 = gcs_client_2->client_table().GetLocalClientId(); gcs_client_1->client_table().RegisterClientAddedCallback( [this](gcs::AsyncGcsClient *client, const ClientID &id, - const ClientTableDataT &data) { - ClientID parsed_id = ClientID::FromBinary(data.client_id); + const ClientTableData &data) { + ClientID parsed_id = ClientID::FromBinary(data.client_id()); if (parsed_id == client_id_1 || parsed_id == client_id_2) { num_connected_clients += 1; } @@ -457,19 +459,19 @@ class TestObjectManager : public TestObjectManagerBase { RAY_LOG(DEBUG) << "\n" << "Server client ids:" << "\n"; - ClientTableDataT data; + ClientTableData data; gcs_client_1->client_table().GetClient(client_id_1, data); - RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id).IsNil()); - RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id); - RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address; - RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port; - ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id)); - ClientTableDataT data2; + RAY_LOG(DEBUG) << (ClientID::FromBinary(data.client_id()).IsNil()); + RAY_LOG(DEBUG) << "Server 1 ClientID=" << ClientID::FromBinary(data.client_id()); + RAY_LOG(DEBUG) << "Server 1 ClientIp=" << data.node_manager_address(); + RAY_LOG(DEBUG) << "Server 1 ClientPort=" << data.node_manager_port(); + ASSERT_EQ(client_id_1, ClientID::FromBinary(data.client_id())); + ClientTableData data2; gcs_client_1->client_table().GetClient(client_id_2, data2); - RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id); - RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address; - RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port; - ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id)); + RAY_LOG(DEBUG) << "Server 2 ClientID=" << ClientID::FromBinary(data2.client_id()); + RAY_LOG(DEBUG) << "Server 2 ClientIp=" << data2.node_manager_address(); + RAY_LOG(DEBUG) << "Server 2 ClientPort=" << data2.node_manager_port(); + ASSERT_EQ(client_id_2, ClientID::FromBinary(data2.client_id())); } }; diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto new file mode 100644 index 000000000000..d0b2c5e007fe --- /dev/null +++ b/src/ray/protobuf/gcs.proto @@ -0,0 +1,280 @@ +syntax = "proto3"; + +package ray.rpc; + +option java_package = "org.ray.runtime.generated"; + +// Language of a worker or task. +enum Language { + PYTHON = 0; + CPP = 1; + JAVA = 2; +} + +// These indexes are mapped to strings in ray_redis_module.cc. 
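Before the enum definitions continue below, a note on how this new file surfaces in code: package ray.rpc becomes namespace ray::rpc in C++ (with Java classes under org.ray.runtime.generated per the option above), each message compiles to a class with set_*/add_*/*_size() accessors, and serialization is a byte-exact round trip. A self-contained sketch, assuming the generated header lands at ray/protobuf/gcs.pb.h:

    #include <cassert>
    #include <string>

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    int main() {
      ray::rpc::GcsEntry entry;
      entry.set_change_mode(ray::rpc::GcsChangeMode::APPEND_OR_ADD);
      entry.set_id("<binary object id>");
      entry.add_entries("<serialized ObjectTableData>");

      // Wire round trip: this is the payload that actually reaches Redis.
      const std::string wire = entry.SerializeAsString();
      ray::rpc::GcsEntry parsed;
      const bool ok = parsed.ParseFromString(wire);
      assert(ok && parsed.entries_size() == 1);
      return ok ? 0 : 1;
    }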
+enum TablePrefix { + TABLE_PREFIX_MIN = 0; + UNUSED = 1; + TASK = 2; + RAYLET_TASK = 3; + CLIENT = 4; + OBJECT = 5; + ACTOR = 6; + FUNCTION = 7; + TASK_RECONSTRUCTION = 8; + HEARTBEAT = 9; + HEARTBEAT_BATCH = 10; + ERROR_INFO = 11; + DRIVER = 12; + PROFILE = 13; + TASK_LEASE = 14; + ACTOR_CHECKPOINT = 15; + ACTOR_CHECKPOINT_ID = 16; + NODE_RESOURCE = 17; + TABLE_PREFIX_MAX = 18; +} + +// The channel on which Add operations to the Table are published, if any. +enum TablePubsub { + TABLE_PUBSUB_MIN = 0; + NO_PUBLISH = 1; + TASK_PUBSUB = 2; + RAYLET_TASK_PUBSUB = 3; + CLIENT_PUBSUB = 4; + OBJECT_PUBSUB = 5; + ACTOR_PUBSUB = 6; + HEARTBEAT_PUBSUB = 7; + HEARTBEAT_BATCH_PUBSUB = 8; + ERROR_INFO_PUBSUB = 9; + TASK_LEASE_PUBSUB = 10; + DRIVER_PUBSUB = 11; + NODE_RESOURCE_PUBSUB = 12; + TABLE_PUBSUB_MAX = 13; +} + +enum GcsChangeMode { + APPEND_OR_ADD = 0; + REMOVE = 1; +} + +message GcsEntry { + GcsChangeMode change_mode = 1; + bytes id = 2; + repeated bytes entries = 3; +} + +message ObjectTableData { + // The size of the object. + uint64 object_size = 1; + // The node manager ID that this object appeared on or was evicted by. + bytes manager = 2; +} + +message TaskReconstructionData { + // The number of times this task has been reconstructed so far. + uint64 num_reconstructions = 1; + // The node manager that is trying to reconstruct the task. + bytes node_manager_id = 2; +} + +// TODO(hchen): The task table currently still uses a flatbuffers-defined data structure +// (`Task` in `node_manager.fbs`), because a lot of code depends on that. This should +// be migrated to protobuf very soon. +message TaskTableData { + // Flatbuffers-serialized content of the task, see `src/ray/raylet/task.h`. + bytes task = 1; +} + +message ActorTableData { + // State of an actor. + enum ActorState { + // Actor is alive. + ALIVE = 0; + // Actor is dead, now being reconstructed. + // After reconstruction finishes, the state will become alive again. + RECONSTRUCTING = 1; + // Actor is already dead and won't be reconstructed. + DEAD = 2; + } + // The ID of the actor that was created. + bytes actor_id = 1; + // The dummy object ID returned by the actor creation task. If the actor + // dies, then this is the object that should be reconstructed for the actor + // to be recreated. + bytes actor_creation_dummy_object_id = 2; + // The ID of the driver that created the actor. + bytes driver_id = 3; + // The ID of the node manager that created the actor. + bytes node_manager_id = 4; + // Current state of this actor. + ActorState state = 5; + // Max number of times this actor should be reconstructed. + uint64 max_reconstructions = 6; + // Remaining number of reconstructions. + uint64 remaining_reconstructions = 7; +}
+ string extra_data = 4; + } + + // The type of the component that generated the event, e.g., worker, + // object_manager, or node_manager. + string component_type = 1; + // An identifier for the component that generated the event. + bytes component_id = 2; + // An identifier for the node that generated the event. + string node_ip_address = 3; + // This is a batch of profiling events. We batch these together for + // performance reasons because a single task may generate many events, and + // we don't want each event to require a GCS command. + repeated ProfileEvent profile_events = 4; +} + +message RayResource { + // The type of the resource. + string resource_name = 1; + // The total capacity of this resource type. + double resource_capacity = 2; +} + +message ClientTableData { + // Enum for the entry type in the ClientTable. + enum EntryType { + INSERTION = 0; + DELETION = 1; + RES_CREATEUPDATE = 2; + RES_DELETE = 3; + } + + // The client ID of the client that the message is about. + bytes client_id = 1; + // The IP address of the client's node manager. + string node_manager_address = 2; + // The IPC socket name of the client's raylet. + string raylet_socket_name = 3; + // The IPC socket name of the client's plasma store. + string object_store_socket_name = 4; + // The port at which the client's node manager is listening for TCP + // connections from other node managers. + int32 node_manager_port = 5; + // The port at which the client's object manager is listening for TCP + // connections from other object managers. + int32 object_manager_port = 6; + // Enum to store the entry type in the log. + EntryType entry_type = 7; + + // TODO(hchen): Define the following resources in map format. + repeated string resources_total_label = 8; + repeated double resources_total_capacity = 9; +} + +message HeartbeatTableData { + // Node manager client ID. + bytes client_id = 1; + // TODO(hchen): Define the following resources in map format. + // Resource capacity currently available on this node manager. + repeated string resources_available_label = 2; + repeated double resources_available_capacity = 3; + // Total resource capacity configured for this node manager. + repeated string resources_total_label = 4; + repeated double resources_total_capacity = 5; + // Aggregate outstanding resource load on this node manager. + repeated string resource_load_label = 6; + repeated double resource_load_capacity = 7; +} + +message HeartbeatBatchTableData { + repeated HeartbeatTableData batch = 1; +} + +// Data for a lease on task execution. +message TaskLeaseData { + // Node manager client ID. + bytes node_manager_id = 1; + // The time that the lease was last acquired at. NOTE(swang): This is the + // system clock time according to the node that added the entry and is not + // synchronized with other nodes. + uint64 acquired_at = 2; + // The period that the lease is active for. + uint64 timeout = 3; +} + +message DriverTableData { + // The driver ID. + bytes driver_id = 1; + // Whether it's dead. + bool is_dead = 2; +} + +// This table stores the actor checkpoint data. An actor checkpoint +// is a snapshot of an actor's state in the actor registration. +// See `actor_registration.h` for a more detailed explanation of these fields. +message ActorCheckpointData { + // ID of this actor. + bytes actor_id = 1; + // The dummy object ID of the actor's most recently executed task. + bytes execution_dependency = 2; + // A list of IDs of this actor's handles.
+ repeated bytes handle_ids = 3; + // The task counters of the above handles. + repeated uint64 task_counters = 4; + // The frontier dependencies of the above handles. + repeated bytes frontier_dependencies = 5; + // A list of unreleased dummy objects from this actor. + repeated bytes unreleased_dummy_objects = 6; + // The numbers of dependencies for the above unreleased dummy objects. + repeated uint32 num_dummy_object_dependencies = 7; +} + +// This table stores the actor-to-available-checkpoint-ids mapping. +message ActorCheckpointIdData { + // ID of this actor. + bytes actor_id = 1; + // IDs of this actor's available checkpoints. + repeated bytes checkpoint_ids = 2; + // A list of the timestamps for each of the above `checkpoint_ids`. + repeated uint64 timestamps = 3; +} + +// This enum type is used as an object's metadata to indicate that the object's +// creating task has failed because of a certain error. +// TODO(hchen): We may want to make these errors more specific. E.g., we may want +// to distinguish between intentional and expected actor failures, and between +// worker process failure and node failure. +enum ErrorType { + // Indicates that a task failed because the worker died unexpectedly while executing it. + WORKER_DIED = 0; + // Indicates that a task failed because the actor died unexpectedly before finishing it. + ACTOR_DIED = 1; + // Indicates that an object is lost and cannot be reconstructed. + // Note: this currently only happens to actor objects. When the actor's state has + // already advanced past the object's creating task, the actor cannot re-run the task. + // TODO(hchen): we may want to reuse this error type for more cases. E.g., + // 1) An object that was put by the driver. + // 2) The object's creating task is already cleaned up from the GCS (this currently + // crashes the raylet). + OBJECT_UNRECONSTRUCTABLE = 2; +} diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index cc587bc4d74e..7f940006b5be 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -8,34 +8,35 @@ namespace ray { namespace raylet { -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data) : actor_table_data_(actor_table_data) {} -ActorRegistration::ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data) +ActorRegistration::ActorRegistration(const ActorTableData &actor_table_data, - const ActorCheckpointData &checkpoint_data) + const ActorCheckpointData &checkpoint_data) : actor_table_data_(actor_table_data), - execution_dependency_(ObjectID::FromBinary(checkpoint_data.execution_dependency)) { + execution_dependency_( + ObjectID::FromBinary(checkpoint_data.execution_dependency())) { // Restore `frontier_`. - for (size_t i = 0; i < checkpoint_data.handle_ids.size(); i++) { - auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids[i]); + for (size_t i = 0; i < checkpoint_data.handle_ids_size(); i++) { + auto handle_id = ActorHandleID::FromBinary(checkpoint_data.handle_ids(i)); auto &frontier_entry = frontier_[handle_id]; - frontier_entry.task_counter = checkpoint_data.task_counters[i]; + frontier_entry.task_counter = checkpoint_data.task_counters(i); frontier_entry.execution_dependency = - ObjectID::FromBinary(checkpoint_data.frontier_dependencies[i]); + ObjectID::FromBinary(checkpoint_data.frontier_dependencies(i)); } // Restore `dummy_objects_`.
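ActorCheckpointData keeps parallel repeated fields: handle_ids[i], task_counters[i], and frontier_dependencies[i] all describe the same handle. The constructor above reads them back with the generated _size()/(i) accessors, the dummy-object loop that continues below does the same for the other pair, and the GenerateCheckpointData hunk further down writes them. A write-side sketch (the helper name is hypothetical):

    #include <cstdint>
    #include <string>

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // The i-th element of each repeated field describes the same handle, so
    // the three add_* calls must stay in lock step with one another.
    void AppendHandle(ray::rpc::ActorCheckpointData *checkpoint,
                      const std::string &handle_id_bytes, uint64_t task_counter,
                      const std::string &frontier_dependency_bytes) {
      checkpoint->add_handle_ids(handle_id_bytes);
      checkpoint->add_task_counters(task_counter);
      checkpoint->add_frontier_dependencies(frontier_dependency_bytes);
    }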
- for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects.size(); i++) { - auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects[i]); - dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies[i]; + for (size_t i = 0; i < checkpoint_data.unreleased_dummy_objects_size(); i++) { + auto dummy = ObjectID::FromBinary(checkpoint_data.unreleased_dummy_objects(i)); + dummy_objects_[dummy] = checkpoint_data.num_dummy_object_dependencies(i); } } const ClientID ActorRegistration::GetNodeManagerId() const { - return ClientID::FromBinary(actor_table_data_.node_manager_id); + return ClientID::FromBinary(actor_table_data_.node_manager_id()); } const ObjectID ActorRegistration::GetActorCreationDependency() const { - return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id); + return ObjectID::FromBinary(actor_table_data_.actor_creation_dummy_object_id()); } const ObjectID ActorRegistration::GetExecutionDependency() const { @@ -43,15 +44,15 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { } const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id); + return DriverID::FromBinary(actor_table_data_.driver_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { - return actor_table_data_.max_reconstructions; + return actor_table_data_.max_reconstructions(); } const int64_t ActorRegistration::GetRemainingReconstructions() const { - return actor_table_data_.remaining_reconstructions; + return actor_table_data_.remaining_reconstructions(); } const std::unordered_map @@ -96,7 +97,7 @@ void ActorRegistration::AddHandle(const ActorHandleID &handle_id, int ActorRegistration::NumHandles() const { return frontier_.size(); } -std::shared_ptr ActorRegistration::GenerateCheckpointData( +std::shared_ptr ActorRegistration::GenerateCheckpointData( const ActorID &actor_id, const Task &task) { const auto actor_handle_id = task.GetTaskSpecification().ActorHandleId(); const auto dummy_object = task.GetTaskSpecification().ActorDummyObject(); @@ -109,18 +110,18 @@ std::shared_ptr ActorRegistration::GenerateCheckpointData( copy.ExtendFrontier(actor_handle_id, dummy_object); // Use actor's current state to generate checkpoint data. 
- auto checkpoint_data = std::make_shared(); - checkpoint_data->actor_id = actor_id.Binary(); - checkpoint_data->execution_dependency = copy.GetExecutionDependency().Binary(); + auto checkpoint_data = std::make_shared(); + checkpoint_data->set_actor_id(actor_id.Binary()); + checkpoint_data->set_execution_dependency(copy.GetExecutionDependency().Binary()); for (const auto &frontier : copy.GetFrontier()) { - checkpoint_data->handle_ids.push_back(frontier.first.Binary()); - checkpoint_data->task_counters.push_back(frontier.second.task_counter); - checkpoint_data->frontier_dependencies.push_back( + checkpoint_data->add_handle_ids(frontier.first.Binary()); + checkpoint_data->add_task_counters(frontier.second.task_counter); + checkpoint_data->add_frontier_dependencies( frontier.second.execution_dependency.Binary()); } for (const auto &entry : copy.GetDummyObjects()) { - checkpoint_data->unreleased_dummy_objects.push_back(entry.first.Binary()); - checkpoint_data->num_dummy_object_dependencies.push_back(entry.second); + checkpoint_data->add_unreleased_dummy_objects(entry.first.Binary()); + checkpoint_data->add_num_dummy_object_dependencies(entry.second); } return checkpoint_data; } diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 8d7ce2a449ec..208e4998263f 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -4,13 +4,17 @@ #include #include "ray/common/id.h" -#include "ray/gcs/format/gcs_generated.h" +#include "ray/protobuf/gcs.pb.h" #include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::ActorTableData; +using ActorState = rpc::ActorTableData::ActorState; +using rpc::ActorCheckpointData; + /// \class ActorRegistration /// /// Information about an actor registered in the system. This includes the @@ -23,13 +27,13 @@ class ActorRegistration { /// /// \param actor_table_data Information from the global actor table about /// this actor. This includes the actor's node manager location. - ActorRegistration(const ActorTableDataT &actor_table_data); + explicit ActorRegistration(const ActorTableData &actor_table_data); /// Recreate an actor's registration from a checkpoint. /// /// \param checkpoint_data The checkpoint used to restore the actor. - ActorRegistration(const ActorTableDataT &actor_table_data, - const ActorCheckpointDataT &checkpoint_data); + ActorRegistration(const ActorTableData &actor_table_data, + const ActorCheckpointData &checkpoint_data); /// Each actor may have multiple callers, or "handles". A frontier leaf /// represents the execution state of the actor with respect to a single @@ -46,15 +50,15 @@ class ActorRegistration { /// Get the actor table data. /// /// \return The actor table data. - const ActorTableDataT &GetTableData() const { return actor_table_data_; } + const ActorTableData &GetTableData() const { return actor_table_data_; } /// Get the actor's current state (ALIVE or DEAD). /// /// \return The actor's current state. - const ActorState &GetState() const { return actor_table_data_.state; } + const ActorState GetState() const { return actor_table_data_.state(); } /// Update actor's state. - void SetState(const ActorState &state) { actor_table_data_.state = state; } + void SetState(const ActorState &state) { actor_table_data_.set_state(state); } /// Get the actor's node manager location. /// @@ -131,13 +135,13 @@ class ActorRegistration { /// \param actor_id ID of this actor. /// \param task The task that just finished on the actor. 
/// \return A shared pointer to the generated checkpoint data. - std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, - const Task &task); + std::shared_ptr GenerateCheckpointData(const ActorID &actor_id, + const Task &task); private: /// Information from the global actor table about this actor, including the /// node manager location. - ActorTableDataT actor_table_data_; + ActorTableData actor_table_data_; /// The object representing the state following the actor's most recently /// executed task. The next task to execute on the actor should be marked as /// execution-dependent on this object. diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 32dddada5244..68d5aa817c2b 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -63,15 +63,6 @@ void LineageEntry::UpdateTaskData(const Task &task) { Lineage::Lineage() {} -Lineage::Lineage(const protocol::ForwardTaskRequest &task_request) { - // Deserialize and set entries for the uncommitted tasks. - auto tasks = task_request.uncommitted_tasks(); - for (auto it = tasks->begin(); it != tasks->end(); it++) { - const auto &task = **it; - RAY_CHECK(SetEntry(task, GcsStatus::UNCOMMITTED)); - } -} - boost::optional Lineage::GetEntry(const TaskID &task_id) const { auto entry = entries_.find(task_id); if (entry != entries_.end()) { @@ -151,20 +142,6 @@ const std::unordered_map &Lineage::GetEntries() cons return entries_; } -flatbuffers::Offset Lineage::ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &task_id) const { - RAY_CHECK(GetEntry(task_id)); - // Serialize the task and object entries. - std::vector> uncommitted_tasks; - for (const auto &entry : entries_) { - uncommitted_tasks.push_back(entry.second.TaskData().ToFlatbuffer(fbb)); - } - - auto request = protocol::CreateForwardTaskRequest(fbb, to_flatbuf(fbb, task_id), - fbb.CreateVector(uncommitted_tasks)); - return request; -} - const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) const { static const std::unordered_set empty_children; const auto it = children_.find(task_id); @@ -176,7 +153,7 @@ const std::unordered_set &Lineage::GetChildren(const TaskID &task_id) co } LineageCache::LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size) : client_id_(client_id), task_storage_(task_storage), task_pubsub_(task_pubsub) {} @@ -292,15 +269,11 @@ void LineageCache::FlushTask(const TaskID &task_id) { gcs::raylet::TaskTable::WriteCallback task_callback = [this](ray::gcs::AsyncGcsClient *client, const TaskID &id, - const protocol::TaskT &data) { HandleEntryCommitted(id); }; + const TaskTableData &data) { HandleEntryCommitted(id); }; auto task = lineage_.GetEntry(task_id); // TODO(swang): Make this better... 
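The FlushTask hunk that continues below is the payoff of the TaskTableData envelope from gcs.proto: instead of rebuilding the task with a FlatBufferBuilder and unpacking it again, the still-flatbuffers task bytes are stored verbatim in the single `task` field. A sketch, assuming the serialized task arrives as a std::string (as task->TaskData().Serialize() suggests):

    #include <memory>
    #include <string>

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // New write path: the flatbuffer bytes go verbatim into TaskTableData's
    // single `task` field; no FlatBufferBuilder, no UnPackTo.
    std::shared_ptr<ray::rpc::TaskTableData> WrapTask(const std::string &flatbuffer_bytes) {
      auto task_data = std::make_shared<ray::rpc::TaskTableData>();
      task_data->set_task(flatbuffer_bytes);
      return task_data;
    }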
- flatbuffers::FlatBufferBuilder fbb; - auto message = task->TaskData().ToFlatbuffer(fbb); - fbb.Finish(message); - auto task_data = std::make_shared(); - auto root = flatbuffers::GetRoot(fbb.GetBufferPointer()); - root->UnPackTo(task_data.get()); + auto task_data = std::make_shared(); + task_data->set_task(task->TaskData().Serialize()); RAY_CHECK_OK( task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), task_id, task_data, task_callback)); @@ -365,8 +338,6 @@ void LineageCache::EvictTask(const TaskID &task_id) { for (const auto &child_id : children) { EvictTask(child_id); } - - return; } void LineageCache::HandleEntryCommitted(const TaskID &task_id) { diff --git a/src/ray/raylet/lineage_cache.h b/src/ray/raylet/lineage_cache.h index 5436fa372fa4..37ce5caf6507 100644 --- a/src/ray/raylet/lineage_cache.h +++ b/src/ray/raylet/lineage_cache.h @@ -4,18 +4,17 @@ #include #include -// clang-format off -#include "ray/common/common_protocol.h" -#include "ray/raylet/task.h" -#include "ray/gcs/tables.h" #include "ray/common/id.h" #include "ray/common/status.h" -// clang-format on +#include "ray/gcs/tables.h" +#include "ray/raylet/task.h" namespace ray { namespace raylet { +using rpc::TaskTableData; + /// The status of a lineage cache entry according to its status in the GCS. /// Tasks can only transition to a higher GcsStatus (e.g., an UNCOMMITTED state /// can become COMMITTING but not vice versa). If a task is evicted from the @@ -136,12 +135,6 @@ class Lineage { /// Construct an empty Lineage. Lineage(); - /// Construct a Lineage from a ForwardTaskRequest. - /// - /// \param task_request The request to construct the lineage from. All - /// uncommitted tasks in the request will be added to the lineage. - Lineage(const protocol::ForwardTaskRequest &task_request); - /// Get an entry from the lineage. /// /// \param entry_id The ID of the entry to get. @@ -172,15 +165,6 @@ class Lineage { /// \return A const reference to the lineage entries. const std::unordered_map &GetEntries() const; - /// Serialize this lineage to a ForwardTaskRequest flatbuffer. - /// - /// \param entry_id The task ID to include in the ForwardTaskRequest - /// flatbuffer. - /// \return An offset to the serialized lineage. The serialization includes - /// all task and object entries in the lineage. - flatbuffers::Offset ToFlatbuffer( - flatbuffers::FlatBufferBuilder &fbb, const TaskID &entry_id) const; - /// Return the IDs of tasks in the lineage that are dependent on the given /// task. /// @@ -221,7 +205,7 @@ class LineageCache { /// Create a lineage cache for the given task storage system. /// TODO(swang): Pass in the policy (interface?). LineageCache(const ClientID &client_id, - gcs::TableInterface &task_storage, + gcs::TableInterface &task_storage, gcs::PubsubInterface &task_pubsub, uint64_t max_lineage_size); /// Asynchronously commit a task to the GCS. @@ -319,7 +303,7 @@ class LineageCache { /// TODO(swang): Move the ClientID into the generic Table implementation. ClientID client_id_; /// The durable storage system for task information. - gcs::TableInterface &task_storage_; + gcs::TableInterface &task_storage_; /// The pubsub storage system for task information. This can be used to /// request notifications for the commit of a task entry. 
gcs::PubsubInterface &task_pubsub_; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 43e64e400292..a6184902f803 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -13,7 +13,7 @@ namespace ray { namespace raylet { -class MockGcs : public gcs::TableInterface, +class MockGcs : public gcs::TableInterface, public gcs::PubsubInterface { public: MockGcs() {} @@ -23,15 +23,15 @@ class MockGcs : public gcs::TableInterface, } Status Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, - const gcs::TableInterface::WriteCallback &done) { + std::shared_ptr &task_data, + const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; auto callback = done; // If we requested notifications for this task ID, send the notification as // part of the callback. if (subscribed_tasks_.count(task_id) == 1) { callback = [this, done](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { done(client, task_id, data); // If we're subscribed to the task to be added, also send a // subscription notification. @@ -45,14 +45,14 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { + Status RemoteAdd(const TaskID &task_id, std::shared_ptr task_data) { task_table_[task_id] = task_data; // Send a notification after the add if the lineage cache requested // notifications for this key. bool send_notification = (subscribed_tasks_.count(task_id) == 1); auto callback = [this, send_notification](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const protocol::TaskT &data) { + const TaskTableData &data) { if (send_notification) { notification_callback_(client, task_id, data); } @@ -84,7 +84,7 @@ class MockGcs : public gcs::TableInterface, } } - const std::unordered_map> &TaskTable() const { + const std::unordered_map> &TaskTable() const { return task_table_; } @@ -95,7 +95,7 @@ class MockGcs : public gcs::TableInterface, const int NumTaskAdds() const { return num_task_adds_; } private: - std::unordered_map> task_table_; + std::unordered_map> task_table_; std::vector> callbacks_; gcs::raylet::TaskTable::WriteCallback notification_callback_; std::unordered_set subscribed_tasks_; @@ -111,7 +111,7 @@ class LineageCacheTest : public ::testing::Test { mock_gcs_(), lineage_cache_(ClientID::FromRandom(), mock_gcs_, mock_gcs_, max_lineage_size_) { mock_gcs_.Subscribe([this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &data) { + const TaskTableData &data) { lineage_cache_.HandleEntryCommitted(task_id); num_notifications_++; }); @@ -341,7 +341,7 @@ TEST_F(LineageCacheTest, TestEvictChain) { ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), tasks.size()); // Simulate executing the task on a remote node and adding it to the GCS. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK( mock_gcs_.RemoteAdd(tasks.at(1).GetTaskSpecification().TaskId(), task_data)); mock_gcs_.Flush(); @@ -432,7 +432,7 @@ TEST_F(LineageCacheTest, TestEviction) { // Simulate executing the first task on a remote node and adding it to the // GCS. 
- auto task_data = std::make_shared(); + auto task_data = std::make_shared(); auto it = tasks.begin(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); it++; @@ -490,7 +490,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { auto last_task = tasks.front(); tasks.erase(tasks.begin()); for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); // Check that the remote task is flushed. num_tasks_flushed++; @@ -500,7 +500,7 @@ TEST_F(LineageCacheTest, TestOutOfOrderEviction) { } // Flush the last task. The lineage should not get evicted until this task's // commit is received. - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); RAY_CHECK_OK(mock_gcs_.RemoteAdd(last_task.GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; mock_gcs_.Flush(); @@ -536,7 +536,7 @@ TEST_F(LineageCacheTest, TestEvictionUncommittedChildren) { // until after the final remote task is executed, since a task can only be // evicted once all of its ancestors have been committed. for (auto it = tasks.rbegin(); it != tasks.rend(); it++) { - auto task_data = std::make_shared(); + auto task_data = std::make_shared(); ASSERT_EQ(lineage_cache_.GetLineage().GetEntries().size(), lineage_size * 2); RAY_CHECK_OK(mock_gcs_.RemoteAdd(it->GetTaskSpecification().TaskId(), task_data)); num_tasks_flushed++; diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 62ecb00b819f..0a853260887e 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -24,14 +24,14 @@ Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_a } void Monitor::HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { heartbeats_[client_id] = num_heartbeats_timeout_; heartbeat_buffer_[client_id] = heartbeat_data; } void Monitor::Start() { const auto heartbeat_callback = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatTableDataT &heartbeat_data) { + const HeartbeatTableData &heartbeat_data) { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( @@ -49,11 +49,11 @@ void Monitor::Tick() { RAY_LOG(WARNING) << "Client timed out: " << client_id; auto lookup_callback = [this, client_id]( gcs::AsyncGcsClient *client, const ClientID &id, - const std::vector &all_data) { + const std::vector &all_data) { bool marked = false; for (const auto &data : all_data) { - if (client_id.Binary() == data.client_id && - data.entry_type == EntryType::DELETION) { + if (client_id.Binary() == data.client_id() && + data.entry_type() == ClientTableData::DELETION) { // The node has been marked dead by itself. marked = true; } @@ -84,10 +84,9 @@ void Monitor::Tick() { // Send any buffered heartbeats as a single publish. 
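Concretely, the single publish below replaces a vector of owned heartbeat structs with protobuf's repeated-message API: add_batch() appends a default element and returns a mutable pointer, and CopyFrom() deep-copies the buffered heartbeat into it. A sketch (the real buffer is keyed by ClientID; a raw string key keeps this standalone):

    #include <string>
    #include <unordered_map>

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // Coalesces one buffered HeartbeatTableData per client into a single
    // batch message, so the whole tick needs only one Add() to the GCS.
    ray::rpc::HeartbeatBatchTableData MakeHeartbeatBatch(
        const std::unordered_map<std::string, ray::rpc::HeartbeatTableData> &buffer) {
      ray::rpc::HeartbeatBatchTableData batch;
      for (const auto &entry : buffer) {
        batch.add_batch()->CopyFrom(entry.second);
      }
      return batch;
    }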
if (!heartbeat_buffer_.empty()) { - auto batch = std::make_shared(); + auto batch = std::make_shared(); for (const auto &heartbeat : heartbeat_buffer_) { - batch->batch.push_back(std::unique_ptr( - new HeartbeatTableDataT(heartbeat.second))); + batch->add_batch()->CopyFrom(heartbeat.second); } RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), batch, nullptr)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index c69cc9f003e0..5725e52cf495 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -11,6 +11,10 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; +using rpc::HeartbeatBatchTableData; +using rpc::HeartbeatTableData; + class Monitor { public: /// Create a Raylet monitor attached to the given GCS address and port. @@ -35,7 +39,7 @@ class Monitor { /// \param client_id The client ID of the Raylet that sent the heartbeat. /// \param heartbeat_data The heartbeat sent by the client. void HandleHeartbeat(const ClientID &client_id, - const HeartbeatTableDataT &heartbeat_data); + const HeartbeatTableData &heartbeat_data); private: /// A client to the GCS, through which heartbeats are received. @@ -50,7 +54,7 @@ class Monitor { /// The Raylets that have been marked as dead in the client table. std::unordered_set dead_clients_; /// A buffer containing heartbeats received from node managers in the last tick. - std::unordered_map heartbeat_buffer_; + std::unordered_map heartbeat_buffer_; }; } // namespace raylet diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index a0bde1ff0655..226a8fb6d251 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -46,9 +46,9 @@ ActorStats GetActorStatisticalData( std::unordered_map actor_registry) { ActorStats item; for (auto &pair : actor_registry) { - if (pair.second.GetState() == ActorState::ALIVE) { + if (pair.second.GetState() == ray::rpc::ActorTableData::ALIVE) { item.live_actors += 1; - } else if (pair.second.GetState() == ActorState::RECONSTRUCTING) { + } else if (pair.second.GetState() == ray::rpc::ActorTableData::RECONSTRUCTING) { item.reconstructing_actors += 1; } else { item.dead_actors += 1; @@ -83,7 +83,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, config.worker_commands), + config.maximum_startup_concurrency, gcs_client_, + config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, @@ -100,7 +101,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), actor_registry_(), - node_manager_server_(config.node_manager_port, io_service, *this), + node_manager_server_("NodeManager", config.node_manager_port), + node_manager_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with our own cluster resource configuration. @@ -118,6 +120,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); // Run the node manager rpc server.
+ node_manager_server_.RegisterService(node_manager_service_); node_manager_server_.Run(); } @@ -129,7 +132,7 @@ ray::Status NodeManager::RegisterGcs() { // that were executed remotely. const auto task_committed_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const ray::protocol::TaskT &task_data) { + const TaskTableData &task_data) { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( @@ -138,8 +141,8 @@ ray::Status NodeManager::RegisterGcs() { const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { - const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id); + const TaskLeaseData &task_lease) { + const ClientID node_manager_id = ClientID::FromBinary(task_lease.node_manager_id()); if (gcs_client_->client_table().IsRemoved(node_manager_id)) { // The node manager that added the task lease is already removed. The // lease is considered inactive. @@ -149,7 +152,7 @@ ray::Status NodeManager::RegisterGcs() { // expiration period since the entry may have been in the GCS for some // time already. For a more accurate estimate, the age of the entry in // the GCS should be subtracted from task_lease.timeout. - reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout); + reconstruction_policy_.HandleTaskLeaseNotification(task_id, task_lease.timeout()); } }; const auto task_lease_empty_callback = [this](gcs::AsyncGcsClient *client, @@ -163,7 +166,7 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback to handle actor notifications. auto actor_notification_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id, - const std::vector &data) { + const std::vector &data) { if (!data.empty()) { // We only need the last entry, because it represents the latest state of // this actor. @@ -176,34 +179,34 @@ ray::Status NodeManager::RegisterGcs() { // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { + const ClientTableData &data) { ClientAdded(data); }; gcs_client_->client_table().RegisterClientAddedCallback(node_manager_client_added); // Register a callback on the client table for removed clients. auto node_manager_client_removed = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ClientRemoved(data); }; + const ClientTableData &data) { ClientRemoved(data); }; gcs_client_->client_table().RegisterClientRemovedCallback(node_manager_client_removed); // Register a callback on the client table for resource create/update requests auto node_manager_resource_createupdated = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceCreateUpdated(data); }; + const ClientTableData &data) { ResourceCreateUpdated(data); }; gcs_client_->client_table().RegisterResourceCreateUpdatedCallback( node_manager_resource_createupdated); // Register a callback on the client table for resource delete requests auto node_manager_resource_deleted = [this](gcs::AsyncGcsClient *client, const UniqueID &id, - const ClientTableDataT &data) { ResourceDeleted(data); }; + const ClientTableData &data) { ResourceDeleted(data); }; gcs_client_->client_table().RegisterResourceDeletedCallback( node_manager_resource_deleted); // Subscribe to heartbeat batches from the monitor. 
const auto &heartbeat_batch_added = [this](gcs::AsyncGcsClient *client, const ClientID &id, - const HeartbeatBatchTableDataT &heartbeat_batch) { + const HeartbeatBatchTableData &heartbeat_batch) { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( @@ -214,7 +217,7 @@ ray::Status NodeManager::RegisterGcs() { // Subscribe to driver table updates. const auto driver_table_handler = [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { + const std::vector &driver_data) { HandleDriverTableUpdate(client_id, driver_data); }; RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( @@ -250,12 +253,12 @@ void NodeManager::KillWorker(std::shared_ptr worker) { } void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { + const DriverID &id, const std::vector &driver_data) { for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " << UniqueID::FromBinary(entry.driver_id) - << " " << entry.is_dead; - if (entry.is_dead) { - auto driver_id = DriverID::FromBinary(entry.driver_id); + RAY_LOG(DEBUG) << "HandleDriverTableUpdate " + << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); + if (entry.is_dead()) { + auto driver_id = DriverID::FromBinary(entry.driver_id()); auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); // Kill all the workers. The actual cleanup for these workers is done @@ -287,26 +290,26 @@ void NodeManager::Heartbeat() { last_heartbeat_at_ms_ = now_ms; auto &heartbeat_table = gcs_client_->heartbeat_table(); - auto heartbeat_data = std::make_shared(); + auto heartbeat_data = std::make_shared(); const auto &my_client_id = gcs_client_->client_table().GetLocalClientId(); SchedulingResources &local_resources = cluster_resource_map_[my_client_id]; - heartbeat_data->client_id = my_client_id.Binary(); + heartbeat_data->set_client_id(my_client_id.Binary()); // TODO(atumanov): modify the heartbeat table protocol to use the ResourceSet directly. // TODO(atumanov): implement a ResourceSet const_iterator. 
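The loops that follow fill paired repeated fields through the generated add_*() methods, and HeartbeatAdded on the receiving side (later in this patch) converts them back with rpc::VectorFromProtobuf before handing them to ResourceSet. A sketch of both directions; VectorFromProtobuf's exact signature is not shown in this patch, so the read side below spells out the equivalent iterator copy:

    #include <string>
    #include <vector>

    #include "ray/protobuf/gcs.pb.h"  // assumed path of the protoc-generated header

    // Write side: one add_*() call per map entry keeps label[i] paired with
    // capacity[i].
    void AddAvailable(ray::rpc::HeartbeatTableData *hb, const std::string &label,
                      double capacity) {
      hb->add_resources_available_label(label);
      hb->add_resources_available_capacity(capacity);
    }

    // Read side: essentially what rpc::VectorFromProtobuf must do -- copy the
    // repeated field back into the std::vector that ResourceSet expects.
    std::vector<std::string> AvailableLabels(const ray::rpc::HeartbeatTableData &hb) {
      return {hb.resources_available_label().begin(),
              hb.resources_available_label().end()};
    }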
for (const auto &resource_pair : local_resources.GetAvailableResources().GetResourceMap()) { - heartbeat_data->resources_available_label.push_back(resource_pair.first); - heartbeat_data->resources_available_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_available_label(resource_pair.first); + heartbeat_data->add_resources_available_capacity(resource_pair.second); } for (const auto &resource_pair : local_resources.GetTotalResources().GetResourceMap()) { - heartbeat_data->resources_total_label.push_back(resource_pair.first); - heartbeat_data->resources_total_capacity.push_back(resource_pair.second); + heartbeat_data->add_resources_total_label(resource_pair.first); + heartbeat_data->add_resources_total_capacity(resource_pair.second); } local_resources.SetLoadResources(local_queues_.GetResourceLoad()); for (const auto &resource_pair : local_resources.GetLoadResources().GetResourceMap()) { - heartbeat_data->resource_load_label.push_back(resource_pair.first); - heartbeat_data->resource_load_capacity.push_back(resource_pair.second); + heartbeat_data->add_resource_load_label(resource_pair.first); + heartbeat_data->add_resource_load_capacity(resource_pair.second); } ray::Status status = heartbeat_table.Add( @@ -334,13 +337,8 @@ void NodeManager::GetObjectManagerProfileInfo() { auto profile_info = object_manager_.GetAndResetProfilingInfo(); - if (profile_info.profile_events.size() > 0) { - flatbuffers::FlatBufferBuilder fbb; - auto message = CreateProfileTableData(fbb, &profile_info); - fbb.Finish(message); - auto profile_message = flatbuffers::GetRoot(fbb.GetBufferPointer()); - - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*profile_message)); + if (profile_info.profile_events_size() > 0) { + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_info)); } // Reset the timer. @@ -357,8 +355,8 @@ void NodeManager::GetObjectManagerProfileInfo() { } } -void NodeManager::ClientAdded(const ClientTableDataT &client_data) { - const ClientID client_id = ClientID::FromBinary(client_data.client_id); +void NodeManager::ClientAdded(const ClientTableData &client_data) { + const ClientID client_id = ClientID::FromBinary(client_data.client_id()); RAY_LOG(DEBUG) << "[ClientAdded] Received callback from client id " << client_id; if (client_id == gcs_client_->client_table().GetLocalClientId()) { @@ -377,19 +375,20 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { // Initialize a rpc client to the new node manager. std::unique_ptr client( - new rpc::NodeManagerClient(client_data.node_manager_address, - client_data.node_manager_port, client_call_manager_)); + new rpc::NodeManagerClient(client_data.node_manager_address(), + client_data.node_manager_port(), client_call_manager_)); remote_node_manager_clients_.emplace(client_id, std::move(client)); - ResourceSet resources_total(client_data.resources_total_label, - client_data.resources_total_capacity); + ResourceSet resources_total( + rpc::VectorFromProtobuf(client_data.resources_total_label()), + rpc::VectorFromProtobuf(client_data.resources_total_capacity())); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { +void NodeManager::ClientRemoved(const ClientTableData &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. 
-  const ClientID client_id = ClientID::FromBinary(client_data.client_id);
+  const ClientID client_id = ClientID::FromBinary(client_data.client_id());
   RAY_LOG(DEBUG) << "[ClientRemoved] Received callback from client id " << client_id;

   RAY_CHECK(client_id != gcs_client_->client_table().GetLocalClientId())
@@ -417,7 +416,7 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
   // TODO(swang): This could be very slow if there are many actors.
   for (const auto &actor_entry : actor_registry_) {
     if (actor_entry.second.GetNodeManagerId() == client_id &&
-        actor_entry.second.GetState() == ActorState::ALIVE) {
+        actor_entry.second.GetState() == ActorTableData::ALIVE) {
       RAY_LOG(INFO) << "Actor " << actor_entry.first
                     << " is disconnected because its node " << client_id
                     << " was removed from the cluster. It may be reconstructed.";
@@ -435,14 +434,15 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
   lineage_cache_.FlushAllUncommittedTasks();
 }

-void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) {
-  const ClientID client_id = ClientID::FromBinary(client_data.client_id);
+void NodeManager::ResourceCreateUpdated(const ClientTableData &client_data) {
+  const ClientID client_id = ClientID::FromBinary(client_data.client_id());
   const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();
   RAY_LOG(DEBUG) << "[ResourceCreateUpdated] received callback from client id "
                  << client_id << ". Updating resource map.";
-  ResourceSet new_res_set(client_data.resources_total_label,
-                          client_data.resources_total_capacity);
+  ResourceSet new_res_set(
+      rpc::VectorFromProtobuf(client_data.resources_total_label()),
+      rpc::VectorFromProtobuf(client_data.resources_total_capacity()));

   const ResourceSet &old_res_set = cluster_resource_map_[client_id].GetTotalResources();
   ResourceSet difference_set = old_res_set.FindUpdatedResources(new_res_set);
@@ -471,12 +471,13 @@ void NodeManager::ResourceCreateUpdated(const ClientTableDataT &client_data) {
   return;
 }

-void NodeManager::ResourceDeleted(const ClientTableDataT &client_data) {
-  const ClientID client_id = ClientID::FromBinary(client_data.client_id);
+void NodeManager::ResourceDeleted(const ClientTableData &client_data) {
+  const ClientID client_id = ClientID::FromBinary(client_data.client_id());
   const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();

-  ResourceSet new_res_set(client_data.resources_total_label,
-                          client_data.resources_total_capacity);
+  ResourceSet new_res_set(
+      rpc::VectorFromProtobuf(client_data.resources_total_label()),
+      rpc::VectorFromProtobuf(client_data.resources_total_capacity()));
   RAY_LOG(DEBUG) << "[ResourceDeleted] received callback from client id " << client_id
                  << " with new resources: " << new_res_set.ToString()
                  << ". Updating resource map.";
@@ -522,7 +523,7 @@ void NodeManager::TryLocalInfeasibleTaskScheduling() {
 }

 void NodeManager::HeartbeatAdded(const ClientID &client_id,
-                                 const HeartbeatTableDataT &heartbeat_data) {
+                                 const HeartbeatTableData &heartbeat_data) {
   // Locate the client id in remote client table and update available resources based on
   // the received heartbeat information.
   auto it = cluster_resource_map_.find(client_id);
@@ -534,10 +535,12 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id,
   }
   SchedulingResources &remote_resources = it->second;

-  ResourceSet remote_available(heartbeat_data.resources_available_label,
-                               heartbeat_data.resources_available_capacity);
-  ResourceSet remote_load(heartbeat_data.resource_load_label,
-                          heartbeat_data.resource_load_capacity);
+  ResourceSet remote_available(
+      rpc::VectorFromProtobuf(heartbeat_data.resources_total_label()),
+      rpc::VectorFromProtobuf(heartbeat_data.resources_total_capacity()));
+  ResourceSet remote_load(
+      rpc::VectorFromProtobuf(heartbeat_data.resource_load_label()),
+      rpc::VectorFromProtobuf(heartbeat_data.resource_load_capacity()));
   // TODO(atumanov): assert that the load is a non-empty ResourceSet.
   remote_resources.SetAvailableResources(std::move(remote_available));
   // Extract the load information and save it locally.
@@ -562,40 +565,41 @@ void NodeManager::HeartbeatAdded(const ClientID &client_id,
   }
 }

-void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch) {
+void NodeManager::HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch) {
   const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId();
   // Update load information provided by each heartbeat.
-  for (const auto &heartbeat_data : heartbeat_batch.batch) {
-    const ClientID &client_id = ClientID::FromBinary(heartbeat_data->client_id);
+  for (const auto &heartbeat_data : heartbeat_batch.batch()) {
+    const ClientID &client_id = ClientID::FromBinary(heartbeat_data.client_id());
     if (client_id == local_client_id) {
       // Skip heartbeats from self.
       continue;
     }
-    HeartbeatAdded(client_id, *heartbeat_data);
+    HeartbeatAdded(client_id, heartbeat_data);
   }
 }

 void NodeManager::PublishActorStateTransition(
-    const ActorID &actor_id, const ActorTableDataT &data,
+    const ActorID &actor_id, const ActorTableData &data,
     const ray::gcs::ActorTable::WriteCallback &failure_callback) {
   // Copy the actor notification data.
-  auto actor_notification = std::make_shared<ActorTableDataT>(data);
+  auto actor_notification = std::make_shared<ActorTableData>(data);

   // The actor log starts with an ALIVE entry. This is followed by 0 to N pairs
   // of (RECONSTRUCTING, ALIVE) entries, where N is the maximum number of
   // reconstructions. This is followed optionally by a DEAD entry.
-  int log_length = 2 * (actor_notification->max_reconstructions -
-                        actor_notification->remaining_reconstructions);
-  if (actor_notification->state != ActorState::ALIVE) {
+  int log_length = 2 * (actor_notification->max_reconstructions() -
+                        actor_notification->remaining_reconstructions());
+  if (actor_notification->state() != ActorTableData::ALIVE) {
     // RECONSTRUCTING or DEAD entries have an odd index.
     log_length += 1;
   }
   // If we successfully appended a record to the GCS table of the actor that
   // has died, signal this to anyone receiving signals from this actor.
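Reading the log_length arithmetic above with concrete numbers: for an actor created with max_reconstructions() == 3, the initial ALIVE entry sits at index 0; the first failure appends RECONSTRUCTING at 2 * (3 - 3) + 1 = 1 (remaining_reconstructions() is only decremented once the actor is re-created); the re-creation appends ALIVE at 2 * (3 - 2) = 2; and when the reconstruction budget is exhausted the final DEAD entry lands at 2 * (3 - 0) + 1 = 7. ALIVE entries therefore always occupy even indices and RECONSTRUCTING/DEAD entries odd ones, which is what lets the append-at-index GCS call detect and reject a conflicting transition raced in by another node.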
auto success_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { auto redis_context = client->primary_context(); - if (data.state == ActorState::DEAD || data.state == ActorState::RECONSTRUCTING) { + if (data.state() == ActorTableData::DEAD || + data.state() == ActorTableData::RECONSTRUCTING) { std::vector args = {"XADD", id.Hex(), "*", "signal", "ACTOR_DIED_SIGNAL"}; RAY_CHECK_OK(redis_context->RunArgvAsync(args)); @@ -632,11 +636,12 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } RAY_LOG(DEBUG) << "Actor notification received: actor_id = " << actor_id << ", node_manager_id = " << actor_registration.GetNodeManagerId() - << ", state = " << EnumNameActorState(actor_registration.GetState()) + << ", state = " + << ActorTableData::ActorState_Name(actor_registration.GetState()) << ", remaining_reconstructions = " << actor_registration.GetRemainingReconstructions(); - if (actor_registration.GetState() == ActorState::ALIVE) { + if (actor_registration.GetState() == ActorTableData::ALIVE) { // The actor's location is now known. Dequeue any methods that were // submitted before the actor's location was known. // (See design_docs/task_states.rst for the state transition diagram.) @@ -663,7 +668,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, // empty lineage this time. SubmitTask(method, Lineage()); } - } else if (actor_registration.GetState() == ActorState::DEAD) { + } else if (actor_registration.GetState() == ActorTableData::DEAD) { // When an actor dies, loop over all of the queued tasks for that actor // and treat them as failed. auto tasks_to_remove = local_queues_.GetTaskIdsForActor(actor_id); @@ -672,7 +677,7 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, TreatTaskAsFailed(task, ErrorType::ACTOR_DIED); } } else { - RAY_CHECK(actor_registration.GetState() == ActorState::RECONSTRUCTING); + RAY_CHECK(actor_registration.GetState() == ActorTableData::RECONSTRUCTING); RAY_LOG(DEBUG) << "Actor is being reconstructed: " << actor_id; // When an actor fails but can be reconstructed, resubmit all of the queued // tasks for that actor. This will mark the tasks as waiting for actor @@ -793,8 +798,20 @@ void NodeManager::ProcessClientMessage( ProcessPushErrorRequestMessage(message_data); } break; case protocol::MessageType::PushProfileEventsRequest: { - auto message = flatbuffers::GetRoot(message_data); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(*message)); + ProfileTableDataT fbs_message; + flatbuffers::GetRoot(message_data)->UnPackTo(&fbs_message); + rpc::ProfileTableData profile_table_data; + profile_table_data.set_component_type(fbs_message.component_type); + profile_table_data.set_component_id(fbs_message.component_id); + for (const auto &fbs_event : fbs_message.profile_events) { + rpc::ProfileTableData::ProfileEvent *event = + profile_table_data.add_profile_events(); + event->set_event_type(fbs_event->event_type); + event->set_start_time(fbs_event->start_time); + event->set_end_time(fbs_event->end_time); + event->set_extra_data(fbs_event->extra_data); + } + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); } break; case protocol::MessageType::FreeObjectsInObjectStoreRequest: { auto message = flatbuffers::GetRoot(message_data); @@ -862,8 +879,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca // Check if this actor needs to be reconstructed. 
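The PushProfileEventsRequest branch above is one of the seams where the two serialization systems still meet: the wire message from the worker is a FlatBuffer, but the GCS client now takes a protobuf, so the handler unpacks into the FlatBuffers object API and copies field by field. A condensed sketch of that bridge using hypothetical stand-in types (Sample, SampleT, and rpc::Sample are not names from this patch):

    #include "flatbuffers/flatbuffers.h"
    #include "sample_generated.h"  // hypothetical schema built with --gen-object-api
    #include "sample.pb.h"         // hypothetical protobuf counterpart

    rpc::Sample FlatbufferToProto(const uint8_t *wire_bytes) {
      // Decode the FlatBuffer into its object-API struct (the generated *T type).
      SampleT unpacked;
      flatbuffers::GetRoot<Sample>(wire_bytes)->UnPackTo(&unpacked);
      // Copy field by field into the protobuf message.
      rpc::Sample proto;
      proto.set_name(unpacked.name);
      proto.set_value(unpacked.value);
      return proto;
    }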
ActorState new_state = actor_registration.GetRemainingReconstructions() > 0 && !intentional_disconnect - ? ActorState::RECONSTRUCTING - : ActorState::DEAD; + ? ActorTableData::RECONSTRUCTING + : ActorTableData::DEAD; if (was_local) { // Clean up the dummy objects from this actor. RAY_LOG(DEBUG) << "Removing dummy objects for actor: " << actor_id; @@ -872,8 +889,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca } } // Update the actor's state. - ActorTableDataT new_actor_data = actor_entry->second.GetTableData(); - new_actor_data.state = new_state; + ActorTableData new_actor_data = actor_entry->second.GetTableData(); + new_actor_data.set_state(new_state); if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, @@ -884,7 +901,7 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca ray::gcs::ActorTable::WriteCallback failure_callback = nullptr; if (was_local) { failure_callback = [](gcs::AsyncGcsClient *client, const ActorID &id, - const ActorTableDataT &data) { + const ActorTableData &data) { // If the disconnected actor was local, only this node will try to update actor // state. So the update shouldn't fail. RAY_LOG(FATAL) << "Failed to update state for actor " << id; @@ -1159,7 +1176,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( DriverID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointDataT &data) { + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); // Save this actor-to-checkpoint mapping, and remove old checkpoints associated @@ -1243,19 +1260,19 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Add the new resource to a skeleton ClientTableDataT object - ClientTableDataT data; + // Add the new resource to a skeleton ClientTableData object + ClientTableData data; gcs_client_->client_table().GetClient(client_id, data); // Replace the resource vectors with the resource deltas from the message. // RES_CREATEUPDATE and RES_DELETE entries in the ClientTable track changes (deltas) in // the resources - data.resources_total_label = std::vector{resource_name}; - data.resources_total_capacity = std::vector{capacity}; + data.add_resources_total_label(resource_name); + data.add_resources_total_capacity(capacity); // Set the correct flag for entry_type if (is_deletion) { - data.entry_type = EntryType::RES_DELETE; + data.set_entry_type(ClientTableData::RES_DELETE); } else { - data.entry_type = EntryType::RES_CREATEUPDATE; + data.set_entry_type(ClientTableData::RES_CREATEUPDATE); } // Submit to the client table. 
 This calls the ResourceCreateUpdated callback, which
@@ -1264,7 +1281,7 @@ void NodeManager::ProcessSetResourceRequest(
   if (not worker) {
     worker = worker_pool_.GetRegisteredDriver(client);
   }
-  auto data_shared_ptr = std::make_shared<ClientTableDataT>(data);
+  auto data_shared_ptr = std::make_shared<ClientTableData>(data);
   auto client_table = gcs_client_->client_table();
   RAY_CHECK_OK(gcs_client_->client_table().Append(
       DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr));
@@ -1369,7 +1386,7 @@ bool NodeManager::CheckDependencyManagerInvariant() const {
 void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_type) {
   const TaskSpecification &spec = task.GetTaskSpecification();
   RAY_LOG(DEBUG) << "Treating task " << spec.TaskId() << " as failed because of error "
-                 << EnumNameErrorType(error_type) << ".";
+                 << ErrorType_Name(error_type) << ".";
   // If this was an actor creation task that tried to resume from a checkpoint,
   // then erase it here since the task did not finish.
   if (spec.IsActorCreationTask()) {
@@ -1487,9 +1504,9 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
     // If we have already seen this actor and this actor is not being reconstructed,
     // its location is known.
     bool location_known =
-        seen && actor_entry->second.GetState() != ActorState::RECONSTRUCTING;
+        seen && actor_entry->second.GetState() != ActorTableData::RECONSTRUCTING;
     if (location_known) {
-      if (actor_entry->second.GetState() == ActorState::DEAD) {
+      if (actor_entry->second.GetState() == ActorTableData::DEAD) {
         // If this actor is dead, either because the actor process is dead
         // or because its residing node is dead, treat this task as failed.
         TreatTaskAsFailed(task, ErrorType::ACTOR_DIED);
@@ -1534,7 +1551,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
     // we missed the creation notification.
     auto lookup_callback = [this](gcs::AsyncGcsClient *client, const ActorID &actor_id,
-                                  const std::vector<ActorTableDataT> &data) {
+                                  const std::vector<ActorTableData> &data) {
       if (!data.empty()) {
         // The actor has been created. We only need the last entry, because
         // it represents the latest state of this actor.
@@ -1723,18 +1740,6 @@ bool NodeManager::AssignTask(const Task &task) {
   std::shared_ptr<Worker> worker = worker_pool_.PopWorker(spec);
   if (worker == nullptr) {
     // There are no workers that can execute this task.
-    if (!spec.IsActorTask()) {
-      // There are no more non-actor workers available to execute this task.
-      // Start a new worker.
-      worker_pool_.StartWorkerProcess(spec.GetLanguage());
-      // Push an error message to the user if the worker pool tells us that it is
-      // getting too big.
-      const std::string warning_message = worker_pool_.WarningAboutSize();
-      if (warning_message != "") {
-        RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
-            DriverID::Nil(), "worker_pool_large", warning_message, current_time_ms()));
-      }
-    }
     // We couldn't assign this task, as no worker is available.
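The block deleted from AssignTask above is not lost: as the worker_pool.cc hunks further down show, starting replacement worker processes (and warning when the pool grows too large) now happens inside PopWorker itself. A simplified sketch of the resulting division of labor, with stand-in types instead of the real Worker and TaskSpecification classes:

    #include <memory>

    struct Worker {};

    struct WorkerPoolSketch {
      // PopWorker now owns the "no idle worker, so start a new process" decision.
      std::shared_ptr<Worker> PopWorker(bool idle_worker_available) {
        if (idle_worker_available) return std::make_shared<Worker>();
        StartWorkerProcess();  // fire-and-forget; the task is retried later
        return nullptr;
      }
      void StartWorkerProcess() {}
    };

    bool AssignTask(WorkerPoolSketch &pool, bool idle_worker_available) {
      auto worker = pool.PopWorker(idle_worker_available);
      // AssignTask no longer starts processes itself; it just reports failure
      // so the task stays queued until a worker registers.
      return worker != nullptr;
    }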
   return false;
 }
@@ -1872,11 +1877,11 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
   }
 }

-ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &task) {
+ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &task) {
   RAY_CHECK(task.GetTaskSpecification().IsActorCreationTask());
   auto actor_id = task.GetTaskSpecification().ActorCreationId();
   auto actor_entry = actor_registry_.find(actor_id);
-  ActorTableDataT new_actor_data;
+  ActorTableData new_actor_data;
   // TODO(swang): If this is an actor that was reconstructed, and previous
   // actor notifications were delayed, then this node may not have an entry for
   // the actor in actor_registry_. Then, the fields for the number of
@@ -1884,32 +1889,33 @@ ActorTableDataT NodeManager::CreateActorTableDataFromCreationTask(const Task &ta
   if (actor_entry == actor_registry_.end()) {
     // Set all of the static fields for the actor. These fields will not
     // change even if the actor fails or is reconstructed.
-    new_actor_data.actor_id = actor_id.Binary();
-    new_actor_data.actor_creation_dummy_object_id =
-        task.GetTaskSpecification().ActorDummyObject().Binary();
-    new_actor_data.driver_id = task.GetTaskSpecification().DriverId().Binary();
-    new_actor_data.max_reconstructions =
-        task.GetTaskSpecification().MaxActorReconstructions();
+    new_actor_data.set_actor_id(actor_id.Binary());
+    new_actor_data.set_actor_creation_dummy_object_id(
+        task.GetTaskSpecification().ActorDummyObject().Binary());
+    new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary());
+    new_actor_data.set_max_reconstructions(
+        task.GetTaskSpecification().MaxActorReconstructions());
     // This is the first time that the actor has been created, so the number
     // of remaining reconstructions is the max.
-    new_actor_data.remaining_reconstructions =
-        task.GetTaskSpecification().MaxActorReconstructions();
+    new_actor_data.set_remaining_reconstructions(
+        task.GetTaskSpecification().MaxActorReconstructions());
   } else {
     // If we've already seen this actor, it means that this actor was reconstructed.
     // Thus, its previous state must be RECONSTRUCTING.
-    RAY_CHECK(actor_entry->second.GetState() == ActorState::RECONSTRUCTING);
+    RAY_CHECK(actor_entry->second.GetState() == ActorTableData::RECONSTRUCTING);
     // Copy the static fields from the current actor entry.
     new_actor_data = actor_entry->second.GetTableData();
     // We are reconstructing the actor, so decrement its
     // remaining_reconstructions by 1.
-    new_actor_data.remaining_reconstructions--;
+    new_actor_data.set_remaining_reconstructions(
+        new_actor_data.remaining_reconstructions() - 1);
   }
   // Set the new fields for the actor's state to indicate that the actor is
   // now alive on this node manager.
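One small API consequence is visible in the remaining_reconstructions hunk above: generated protobuf code provides mutable accessors only for message, string, and repeated fields, while plain scalar fields expose just a getter/setter pair, so the old in-place decrement becomes an explicit read-modify-write. A tiny sketch, assuming the same ActorTableData fields this file already uses:

    // Scalar field: no mutable reference exists, so read, modify, write back.
    void ConsumeOneReconstruction(ray::rpc::ActorTableData *data) {
      data->set_remaining_reconstructions(data->remaining_reconstructions() - 1);
      // String field, by contrast: a mutable accessor is generated.
      data->mutable_node_manager_id()->clear();
    }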
-  new_actor_data.node_manager_id =
-      gcs_client_->client_table().GetLocalClientId().Binary();
-  new_actor_data.state = ActorState::ALIVE;
+  new_actor_data.set_node_manager_id(
+      gcs_client_->client_table().GetLocalClientId().Binary());
+  new_actor_data.set_state(ActorTableData::ALIVE);
   return new_actor_data;
 }
@@ -1945,7 +1951,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
         DriverID::Nil(), checkpoint_id,
         [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client,
                                          const UniqueID &checkpoint_id,
-                                         const ActorCheckpointDataT &checkpoint_data) {
+                                         const ActorCheckpointData &checkpoint_data) {
           RAY_LOG(INFO) << "Restoring registration for actor " << actor_id
                         << " from checkpoint " << checkpoint_id;
           ActorRegistration actor_registration =
@@ -1959,7 +1965,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
               actor_id, new_actor_data,
               /*failure_callback=*/
               [](gcs::AsyncGcsClient *client, const ActorID &id,
-                 const ActorTableDataT &data) {
+                 const ActorTableData &data) {
                 // Only one node at a time should succeed at creating the actor.
                 RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id;
               });
@@ -1975,8 +1981,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
       PublishActorStateTransition(
           actor_id, new_actor_data,
           /*failure_callback=*/
-          [](gcs::AsyncGcsClient *client, const ActorID &id,
-             const ActorTableDataT &data) {
+          [](gcs::AsyncGcsClient *client, const ActorID &id, const ActorTableData &data) {
             // Only one node at a time should succeed at creating the actor.
             RAY_LOG(FATAL) << "Failed to update state to ALIVE for actor " << id;
           });
@@ -2015,10 +2020,11 @@ void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
       DriverID::Nil(), task_id,
       /*success_callback=*/
       [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
-             const ray::protocol::TaskT &task_data) {
+             const TaskTableData &task_data) {
         // The task was in the GCS task table. Use the stored task spec to
         // re-execute the task.
-        const Task task(task_data);
+        auto message = flatbuffers::GetRoot<protocol::Task>(task_data.task().data());
+        const Task task(*message);
         ResubmitTask(task);
       },
       /*failure_callback=*/
@@ -2046,7 +2052,7 @@ void NodeManager::ResubmitTask(const Task &task) {
   if (task.GetTaskSpecification().IsActorCreationTask()) {
     const auto &actor_id = task.GetTaskSpecification().ActorCreationId();
     const auto it = actor_registry_.find(actor_id);
-    if (it != actor_registry_.end() && it->second.GetState() == ActorState::ALIVE) {
+    if (it != actor_registry_.end() && it->second.GetState() == ActorTableData::ALIVE) {
       // If the actor is still alive, then do not resubmit the task. If the
       // actor actually is dead and a result is needed, then reconstruction
       // for this task will be triggered again.
@@ -2205,6 +2211,12 @@ void NodeManager::ForwardTask(
   const auto &spec = task.GetTaskSpecification();
   auto task_id = spec.TaskId();

+  if (worker_pool_.HasPendingWorkerForTask(spec.GetLanguage(), task_id)) {
+    // There is a worker being started for this task,
+    // so we shouldn't forward this task to another node.
+    return;
+  }
+
   // Get and serialize the task's unforwarded, uncommitted lineage.
   Lineage uncommitted_lineage;
   if (lineage_cache_.ContainsTask(task_id)) {
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index 61613358330c..7e812183657c 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -10,7 +10,6 @@
 #include "ray/raylet/task.h"
 #include "ray/object_manager/object_manager.h"
 #include "ray/common/client_connection.h"
-#include "ray/gcs/format/util.h"
 #include "ray/raylet/actor_registration.h"
 #include "ray/raylet/lineage_cache.h"
 #include "ray/raylet/scheduling_policy.h"
@@ -26,6 +25,13 @@ namespace ray {

 namespace raylet {

+using rpc::ActorTableData;
+using rpc::ClientTableData;
+using rpc::DriverTableData;
+using rpc::ErrorType;
+using rpc::HeartbeatBatchTableData;
+using rpc::HeartbeatTableData;
+
 struct NodeManagerConfig {
   /// The node's resource configuration.
   ResourceSet resource_config;
@@ -112,22 +118,22 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   ///
   /// \param data Data associated with the new client.
   /// \return Void.
-  void ClientAdded(const ClientTableDataT &data);
+  void ClientAdded(const ClientTableData &data);

   /// Handler for the removal of a GCS client.
   /// \param client_data Data associated with the removed client.
   /// \return Void.
-  void ClientRemoved(const ClientTableDataT &client_data);
+  void ClientRemoved(const ClientTableData &client_data);

   /// Handler for the addition or update of a resource in the GCS.
   /// \param client_data Data associated with the new client.
   /// \return Void.
-  void ResourceCreateUpdated(const ClientTableDataT &client_data);
+  void ResourceCreateUpdated(const ClientTableData &client_data);

   /// Handler for the deletion of a resource in the GCS.
   /// \param client_data Data associated with the new client.
   /// \return Void.
-  void ResourceDeleted(const ClientTableDataT &client_data);
+  void ResourceDeleted(const ClientTableData &client_data);

   /// Evaluates the local infeasible queue to check if any tasks can be scheduled.
   /// This is called whenever there's an update to the resources on the local client.
@@ -150,11 +156,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// \param id The ID of the node manager that sent the heartbeat.
   /// \param data The heartbeat data including load information.
   /// \return Void.
-  void HeartbeatAdded(const ClientID &id, const HeartbeatTableDataT &data);
+  void HeartbeatAdded(const ClientID &id, const HeartbeatTableData &data);

   /// Handler for a heartbeat batch notification from the GCS.
   ///
   /// \param heartbeat_batch The batch of heartbeat data.
-  void HeartbeatBatchAdded(const HeartbeatBatchTableDataT &heartbeat_batch);
+  void HeartbeatBatchAdded(const HeartbeatBatchTableData &heartbeat_batch);

   /// Methods for task scheduling.

@@ -206,7 +212,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// Helper function to produce actor table data for a newly created actor.
   ///
   /// \param task The actor creation task that created the actor.
-  ActorTableDataT CreateActorTableDataFromCreationTask(const Task &task);
+  ActorTableData CreateActorTableDataFromCreationTask(const Task &task);

   /// Handle a worker finishing an assigned actor task or actor creation task.
   /// \param worker The worker that finished the task.
   /// \param task The actor task or actor creation task.
@@ -317,7 +323,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
   /// \param failure_callback An optional callback to call if the publish is
   /// unsuccessful.
void PublishActorStateTransition( - const ActorID &actor_id, const ActorTableDataT &data, + const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); /// When a driver dies, loop over all of the queued tasks for that driver and @@ -346,7 +352,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param driver_data Data associated with a driver table event. /// \return Void. void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + const std::vector &driver_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging @@ -506,7 +512,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::unordered_map checkpoint_id_to_restore_; /// The RPC server. - rpc::NodeManagerServer node_manager_server_; + rpc::GrpcServer node_manager_server_; + + /// The RPC service. + rpc::NodeManagerGrpcService node_manager_service_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 473e6c263ffe..cbf9b25213ca 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -90,23 +90,23 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); - ClientTableDataT client_info = gcs_client_->client_table().GetLocalClient(); - client_info.node_manager_address = node_ip_address; - client_info.raylet_socket_name = raylet_socket_name; - client_info.object_store_socket_name = object_store_socket_name; - client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_.GetServerPort(); + ClientTableData client_info = gcs_client_->client_table().GetLocalClient(); + client_info.set_node_manager_address(node_ip_address); + client_info.set_raylet_socket_name(raylet_socket_name); + client_info.set_object_store_socket_name(object_store_socket_name); + client_info.set_object_manager_port(object_manager_acceptor_.local_endpoint().port()); + client_info.set_node_manager_port(node_manager_.GetServerPort()); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { - client_info.resources_total_label.push_back(resource_pair.first); - client_info.resources_total_capacity.push_back(resource_pair.second); + client_info.add_resources_total_label(resource_pair.first); + client_info.add_resources_total_capacity(resource_pair.second); } RAY_LOG(DEBUG) << "Node manager " << gcs_client_->client_table().GetLocalClientId() - << " started on " << client_info.node_manager_address << ":" - << client_info.node_manager_port << " object manager at " - << client_info.node_manager_address << ":" - << client_info.object_manager_port; + << " started on " << client_info.node_manager_address() << ":" + << client_info.node_manager_port() << " object manager at " + << client_info.node_manager_address() << ":" + << client_info.object_manager_port(); ; RAY_RETURN_NOT_OK(gcs_client_->client_table().Connect(client_info)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 26fe74b2b622..9367a5054591 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -16,6 +16,8 @@ namespace ray { namespace raylet { +using rpc::ClientTableData; + class Task; class NodeManager; diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 97c86ea73cd8..bf5c1acfaa37 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -106,19 +106,19 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, // Attempt to reconstruct the task by inserting an entry into the task // reconstruction log. This will fail if another node has already inserted // an entry for this reconstruction. - auto reconstruction_entry = std::make_shared(); - reconstruction_entry->num_reconstructions = reconstruction_attempt; - reconstruction_entry->node_manager_id = client_id_.Binary(); + auto reconstruction_entry = std::make_shared(); + reconstruction_entry->set_num_reconstructions(reconstruction_attempt); + reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( DriverID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/true); }, /*failure_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { + const TaskReconstructionData &data) { HandleReconstructionLogAppend(task_id, /*success=*/false); }, reconstruction_attempt)); diff --git a/src/ray/raylet/reconstruction_policy.h b/src/ray/raylet/reconstruction_policy.h index cd969cc2706e..a194443e1425 100644 --- a/src/ray/raylet/reconstruction_policy.h +++ b/src/ray/raylet/reconstruction_policy.h @@ -17,6 +17,8 @@ namespace ray { namespace raylet { +using rpc::TaskReconstructionData; + class ReconstructionPolicyInterface { public: virtual void ListenAndMaybeReconstruct(const ObjectID &object_id) = 0; diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 4ccebd0c0c09..12d9336a382f 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -14,6 +14,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class MockObjectDirectory : public ObjectDirectoryInterface { public: MockObjectDirectory() {} @@ -83,7 +85,7 @@ class MockGcs : public 
gcs::PubsubInterface, } void Add(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_lease_data) { + std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { notification_callback_(nullptr, task_id, *task_lease_data); @@ -110,7 +112,7 @@ class MockGcs : public gcs::PubsubInterface, Status AppendAt( const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, const ray::gcs::LogInterface::WriteCallback @@ -132,15 +134,15 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const DriverID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: gcs::TaskLeaseTable::WriteCallback notification_callback_; gcs::TaskLeaseTable::FailureCallback failure_callback_; - std::unordered_map> task_lease_table_; + std::unordered_map> task_lease_table_; std::unordered_set subscribed_tasks_; - std::unordered_map> + std::unordered_map> task_reconstruction_log_; }; @@ -159,9 +161,9 @@ class ReconstructionPolicyTest : public ::testing::Test { timer_canceled_(false) { mock_gcs_.Subscribe( [this](gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskLeaseDataT &task_lease) { + const TaskLeaseData &task_lease) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, - task_lease.timeout); + task_lease.timeout()); }, [this](gcs::AsyncGcsClient *client, const TaskID &task_id) { reconstruction_policy_->HandleTaskLeaseNotification(task_id, 0); @@ -314,10 +316,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { int64_t test_period = 2 * reconstruction_timeout_ms_; // Acquire the task lease for a period longer than the test period. - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = 2 * test_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(2 * test_period); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); // Listen for an object. @@ -328,7 +330,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { ASSERT_TRUE(reconstructed_tasks_.empty()); // Run the test again past the expiration time of the lease. - Run(task_lease_data->timeout * 1.1); + Run(task_lease_data->timeout() * 1.1); // Check that this time, reconstruction is triggered. ASSERT_EQ(reconstructed_tasks_[task_id], 1); } @@ -341,10 +343,10 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { reconstruction_policy_->ListenAndMaybeReconstruct(object_id); // Send the reconstruction manager heartbeats about the object. 
SetPeriodicTimer(reconstruction_timeout_ms_ / 2, [this, task_id]() { - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = ClientID::FromRandom().Binary(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = reconstruction_timeout_ms_; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(reconstruction_timeout_ms_); mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. @@ -393,14 +395,14 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at // reconstruction. - auto task_reconstruction_data = std::make_shared(); - task_reconstruction_data->node_manager_id = ClientID::FromRandom().Binary(); - task_reconstruction_data->num_reconstructions = 0; + auto task_reconstruction_data = std::make_shared(); + task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); + task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, - const TaskReconstructionDataT &data) { ASSERT_TRUE(false); }, + const TaskReconstructionData &data) { ASSERT_TRUE(false); }, /*log_index=*/0)); // Listen for an object. diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index c5155b96b0c1..89028c733d0d 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -261,10 +261,10 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { << (it->second.expires_at - now_ms) << "ms"; } - auto task_lease_data = std::make_shared(); - task_lease_data->node_manager_id = client_id_.Hex(); - task_lease_data->acquired_at = current_sys_time_ms(); - task_lease_data->timeout = it->second.lease_period; + auto task_lease_data = std::make_shared(); + task_lease_data->set_node_manager_id(client_id_.Hex()); + task_lease_data->set_acquired_at(current_sys_time_ms()); + task_lease_data->set_timeout(it->second.lease_period); RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 3788a5eae7ae..a96558295234 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -13,6 +13,8 @@ namespace ray { namespace raylet { +using rpc::TaskLeaseData; + class ReconstructionPolicy; /// \class TaskDependencyManager diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index e0f832a12870..f7a60989fcba 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -30,7 +30,7 @@ class MockGcs : public gcs::TableInterface { MOCK_METHOD4( Add, ray::Status(const DriverID &driver_id, const TaskID &task_id, - std::shared_ptr &task_data, + std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; diff --git a/src/ray/raylet/task_spec.cc 
b/src/ray/raylet/task_spec.cc index eeab29272126..1d722de18f73 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -80,12 +80,12 @@ TaskSpecification::TaskSpecification( const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor) + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options) : spec_() { flatbuffers::FlatBufferBuilder fbb; TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); - // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) { @@ -101,7 +101,8 @@ TaskSpecification::TaskSpecification( ids_to_flatbuf(fbb, new_actor_handles), fbb.CreateVector(arguments), num_returns, map_to_flatbuf(fbb, required_resources), map_to_flatbuf(fbb, required_placement_resources), language, - string_vec_to_flatbuf(fbb, function_descriptor)); + string_vec_to_flatbuf(fbb, function_descriptor), + string_vec_to_flatbuf(fbb, dynamic_worker_options)); fbb.Finish(spec); AssignSpecification(fbb.GetBufferPointer(), fbb.GetSize()); } @@ -258,6 +259,11 @@ std::vector TaskSpecification::NewActorHandles() const { return ids_from_flatbuf(*message->new_actor_handles()); } +std::vector TaskSpecification::DynamicWorkerOptions() const { + auto message = flatbuffers::GetRoot(spec_.data()); + return string_vec_from_flatbuf(*message->dynamic_worker_options()); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index d557c188ae68..8a08e9974ef2 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -128,6 +128,7 @@ class TaskSpecification { /// will default to be equal to the required_resources argument. /// \param language The language of the worker that must execute the function. /// \param function_descriptor The function descriptor. + /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, @@ -138,7 +139,8 @@ class TaskSpecification { int64_t num_returns, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, - const Language &language, const std::vector &function_descriptor); + const Language &language, const std::vector &function_descriptor, + const std::vector &dynamic_worker_options = {}); /// Deserialize a task specification from a string. /// @@ -214,6 +216,8 @@ class TaskSpecification { ObjectID ActorDummyObject() const; std::vector NewActorHandles() const; + std::vector DynamicWorkerOptions() const; + private: /// Assign the specification data from a pointer. void AssignSpecification(const uint8_t *spec, size_t spec_size); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index d4ac4cf4ecce..16086565de80 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -5,10 +5,12 @@ #include #include +#include "ray/common/constants.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" #include "ray/stats/stats.h" #include "ray/util/logging.h" +#include "ray/util/util.h" namespace { @@ -41,12 +43,13 @@ namespace raylet { /// (num_worker_processes * num_workers_per_process) workers for each language. 
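The dynamic_worker_options introduced here are consumed by WorkerPool::StartWorkerProcess in the hunks below, which splice them into the configured worker command wherever a positional placeholder token appears. A self-contained sketch of that substitution; the RAY_WORKER_OPTION_ prefix is an assumption read off the worker_pool_test.cc expectations at the end of the patch, since the actual kWorkerDynamicOptionPlaceholderPrefix constant lives in ray/common/constants.h and is not shown here:

    #include <string>
    #include <vector>

    std::vector<std::string> BuildWorkerCommand(
        const std::vector<std::string> &command_template,
        const std::vector<std::string> &dynamic_options) {
      const std::string kPrefix = "RAY_WORKER_OPTION_";  // assumed placeholder prefix
      std::vector<std::string> result;
      size_t next_option = 0;
      for (const auto &token : command_template) {
        if (token == kPrefix + std::to_string(next_option)) {
          // Placeholder i is replaced by dynamic option i; when no options are
          // supplied, the placeholder token is simply dropped.
          if (next_option < dynamic_options.size()) {
            result.push_back(dynamic_options[next_option]);
            ++next_option;
          }
        } else {
          result.push_back(token);
        }
      }
      return result;
    }

For example, the template {"RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"} combined with options {"test_op_0", "test_op_1"} yields {"test_op_0", "dummy_java_worker_command", "test_op_1"}, matching the StartWorkerWithDynamicOptionsCommand test.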
WorkerPool::WorkerPool( int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + int maximum_startup_concurrency, std::shared_ptr gcs_client, const std::unordered_map> &worker_commands) : num_workers_per_process_(num_workers_per_process), multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), maximum_startup_concurrency_(maximum_startup_concurrency), - last_warning_multiple_(0) { + last_warning_multiple_(0), + gcs_client_(std::move(gcs_client)) { RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. If we don't do this, then worker processes will @@ -98,7 +101,8 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -void WorkerPool::StartWorkerProcess(const Language &language) { +int WorkerPool::StartWorkerProcess(const Language &language, + const std::vector &dynamic_options) { auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -108,7 +112,7 @@ void WorkerPool::StartWorkerProcess(const Language &language) { RAY_LOG(DEBUG) << "Worker not started, " << state.starting_worker_processes.size() << " worker processes of language type " << static_cast(language) << " pending registration"; - return; + return -1; } // Either there are no workers pending registration or the worker start is being forced. RAY_LOG(DEBUG) << "Starting new worker process, current pool has " @@ -117,8 +121,20 @@ void WorkerPool::StartWorkerProcess(const Language &language) { // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; + size_t dynamic_option_index = 0; for (auto const &token : state.worker_command) { - worker_command_args.push_back(token.c_str()); + const auto option_placeholder = + kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); + + if (token == option_placeholder) { + if (!dynamic_options.empty()) { + RAY_CHECK(dynamic_option_index < dynamic_options.size()); + worker_command_args.push_back(dynamic_options[dynamic_option_index].c_str()); + ++dynamic_option_index; + } + } else { + worker_command_args.push_back(token.c_str()); + } } worker_command_args.push_back(nullptr); @@ -126,14 +142,14 @@ void WorkerPool::StartWorkerProcess(const Language &language) { if (pid < 0) { // Failure case. RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); - return; } else if (pid > 0) { // Parent process case. RAY_LOG(DEBUG) << "Started worker process with pid " << pid; state.starting_worker_processes.emplace( std::make_pair(pid, num_workers_per_process_)); - return; + return pid; } + return -1; } pid_t WorkerPool::StartProcess(const std::vector &worker_command_args) { @@ -158,7 +174,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_a } void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { - auto pid = worker->Pid(); + const auto pid = worker->Pid(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid; auto &state = GetStateForLanguage(worker->GetLanguage()); state.registered_workers.insert(std::move(worker)); @@ -207,30 +223,74 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { RAY_CHECK(worker->GetAssignedTaskId().IsNil()) << "Idle workers cannot have an assigned task ID"; auto &state = GetStateForLanguage(worker->GetLanguage()); - // Add the worker to the idle pool. 
-  if (worker->GetActorId().IsNil()) {
-    state.idle.insert(std::move(worker));
+
+  auto it = state.dedicated_workers_to_tasks.find(worker->Pid());
+  if (it != state.dedicated_workers_to_tasks.end()) {
+    // The worker was started for an actor creation task with dynamic options.
+    // Put it into the idle dedicated worker pool.
+    const auto task_id = it->second;
+    state.idle_dedicated_workers[task_id] = std::move(worker);
   } else {
-    state.idle_actor[worker->GetActorId()] = std::move(worker);
+    // The worker was not started for an actor creation task with dynamic options.
+    // Put it into the corresponding idle pool.
+    if (worker->GetActorId().IsNil()) {
+      state.idle.insert(std::move(worker));
+    } else {
+      state.idle_actor[worker->GetActorId()] = std::move(worker);
+    }
   }
 }

 std::shared_ptr<Worker> WorkerPool::PopWorker(const TaskSpecification &task_spec) {
   auto &state = GetStateForLanguage(task_spec.GetLanguage());
   const auto &actor_id = task_spec.ActorId();
+
   std::shared_ptr<Worker> worker = nullptr;
-  if (actor_id.IsNil()) {
+  int pid = -1;
+  if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) {
+    // Code path of an actor creation task with dynamic worker options.
+    // Try to pop a worker from the idle dedicated pool.
+    auto it = state.idle_dedicated_workers.find(task_spec.TaskId());
+    if (it != state.idle_dedicated_workers.end()) {
+      // There is an idle dedicated worker for this task.
+      worker = std::move(it->second);
+      state.idle_dedicated_workers.erase(it);
+      // Because we found a worker that can perform this task,
+      // we can remove it from dedicated_workers_to_tasks.
+      state.dedicated_workers_to_tasks.erase(worker->Pid());
+      state.tasks_to_dedicated_workers.erase(task_spec.TaskId());
+    } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) {
+      // No worker registration is pending for this task,
+      // so start a new worker process for it.
+      pid = StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions());
+      if (pid > 0) {
+        state.dedicated_workers_to_tasks[pid] = task_spec.TaskId();
+        state.tasks_to_dedicated_workers[task_spec.TaskId()] = pid;
+      }
+    }
+  } else if (!task_spec.IsActorTask()) {
+    // Code path of a normal task or an actor creation task without dynamic
+    // worker options.
     if (!state.idle.empty()) {
       worker = std::move(*state.idle.begin());
       state.idle.erase(state.idle.begin());
+    } else {
+      // There are no more non-actor workers available to execute this task.
+      // Start a new worker process.
+      pid = StartWorkerProcess(task_spec.GetLanguage());
     }
   } else {
+    // Code path of an actor task.
     auto actor_entry = state.idle_actor.find(actor_id);
     if (actor_entry != state.idle_actor.end()) {
       worker = std::move(actor_entry->second);
       state.idle_actor.erase(actor_entry);
     }
   }
+
+  if (worker == nullptr && pid > 0) {
+    WarnAboutSize();
+  }
+
   return worker;
 }

@@ -274,7 +334,7 @@ std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForDriver
   return workers;
 }

-std::string WorkerPool::WarningAboutSize() {
+void WorkerPool::WarnAboutSize() {
   int64_t num_workers_started_or_registered = 0;
   for (const auto &entry : states_by_lang_) {
     num_workers_started_or_registered +=
@@ -285,6 +345,8 @@
   int64_t multiple = num_workers_started_or_registered / multiple_for_warning_;
   std::stringstream warning_message;
   if (multiple >= 3 && multiple > last_warning_multiple_) {
+    // Push an error message to the user if the worker pool tells us that it is
+    // getting too big.
     last_warning_multiple_ = multiple;
     warning_message << "WARNING: " << num_workers_started_or_registered
                     << " workers have been started. This could be a result of using "
                     << "a large number of actors, or it could be a consequence of "
                     << "using nested tasks "
                     << "(see https://github.com/ray-project/ray/issues/3644) for "
                     << "a discussion of workarounds.";
+    RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
+        DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms()));
   }
-  return warning_message.str();
+}
+
+bool WorkerPool::HasPendingWorkerForTask(const Language &language,
+                                         const TaskID &task_id) {
+  auto &state = GetStateForLanguage(language);
+  auto it = state.tasks_to_dedicated_workers.find(task_id);
+  return it != state.tasks_to_dedicated_workers.end();
 }

 std::string WorkerPool::DebugString() const {
diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h
index 03443447cf58..e1e726268093 100644
--- a/src/ray/raylet/worker_pool.h
+++ b/src/ray/raylet/worker_pool.h
@@ -7,6 +7,7 @@
 #include

 #include "ray/common/client_connection.h"
+#include "ray/gcs/client.h"
 #include "ray/gcs/format/util.h"
 #include "ray/raylet/task.h"
 #include "ray/raylet/worker.h"
@@ -37,22 +38,12 @@ class WorkerPool {
   /// language.
   WorkerPool(
       int num_worker_processes, int num_workers_per_process,
-      int maximum_startup_concurrency,
+      int maximum_startup_concurrency, std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
       const std::unordered_map<Language, std::vector<std::string>> &worker_commands);

   /// Destructor responsible for freeing a set of workers owned by this class.
   virtual ~WorkerPool();

-  /// Asynchronously start a new worker process. Once the worker process has
-  /// registered with an external server, the process should create and
-  /// register num_workers_per_process_ workers, then add them to the pool.
-  /// Failure to start the worker process is a fatal error. If too many workers
-  /// are already being started, then this function will return without starting
-  /// any workers.
-  ///
-  /// \param language Which language this worker process should be.
-  void StartWorkerProcess(const Language &language);
-
   /// Register a new worker. The Worker should be added by the caller to the
   /// pool after it becomes idle (e.g., requests a work assignment).
   ///
@@ -118,6 +109,15 @@ class WorkerPool {
   std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForDriver(
       const DriverID &driver_id) const;

+  /// Whether there is a pending worker for the given task.
+  /// Note that this is only used for actor creation tasks with dynamic options.
+  /// A worker that has registered but has not yet been assigned its task is
+  /// still considered pending, so this will return true for it as well.
+  ///
+  /// \param language The required language.
+  /// \param task_id The task that we want to query.
+  bool HasPendingWorkerForTask(const Language &language, const TaskID &task_id);
+
   /// Returns a debug string for the class.
   ///
   /// \return string.
@@ -126,24 +126,37 @@ class WorkerPool {
   /// Record metrics.
   void RecordMetrics() const;

-  /// Generate a warning about the number of workers that have registered or
-  /// started if appropriate.
+ protected:
+  /// Asynchronously start a new worker process. Once the worker process has
+  /// registered with an external server, the process should create and
+  /// register num_workers_per_process_ workers, then add them to the pool.
+  /// Failure to start the worker process is a fatal error. If too many workers
+  /// are already being started, then this function will return without starting
+  /// any workers.
   ///
-  /// \return An empty string if no warning should be generated and otherwise a
-  /// string with a warning message.
-  std::string WarningAboutSize();
+  /// \param language Which language this worker process should be.
+  /// \param dynamic_options The dynamic options to add to the worker command.
+  /// \return The ID of the started worker process if it is positive;
+  /// otherwise, no process was started.
+  int StartWorkerProcess(const Language &language,
+                         const std::vector<std::string> &dynamic_options = {});

-  protected:
   /// The implementation of how to start a new worker process with command arguments.
   ///
   /// \param worker_command_args The command arguments of the new worker process.
   /// \return The process ID of the started worker process.
   virtual pid_t StartProcess(const std::vector<const char *> &worker_command_args);

+  /// Push a warning message to the user if the worker pool is getting too big.
+  virtual void WarnAboutSize();
+
   /// An internal data structure that maintains the pool state per language.
   struct State {
     /// The commands and arguments used to start the worker process.
     std::vector<std::string> worker_command;
+    /// The pool of dedicated workers for actor creation tasks
+    /// with prefix or suffix worker command.
+    std::unordered_map<TaskID, std::shared_ptr<Worker>> idle_dedicated_workers;
     /// The pool of idle non-actor workers.
     std::unordered_set<std::shared_ptr<Worker>> idle;
     /// The pool of idle actor workers.
@@ -156,6 +169,11 @@ class WorkerPool {
     /// A map from the pids of starting worker processes
     /// to the number of their unregistered workers.
     std::unordered_map<pid_t, int> starting_worker_processes;
+    /// A map for looking up the task with dynamic options by the pid of the
+    /// worker. Note that this is used for the dedicated worker processes.
+    std::unordered_map<int, TaskID> dedicated_workers_to_tasks;
+    /// A map for quickly looking up the pending dedicated worker for a given task.
+    std::unordered_map<TaskID, int> tasks_to_dedicated_workers;
   };

   /// The number of workers per process.
@@ -166,7 +184,7 @@ class WorkerPool {
  private:
   /// A helper function that returns the reference of the pool state
   /// for a given language.
-  inline State &GetStateForLanguage(const Language &language);
+  State &GetStateForLanguage(const Language &language);

   /// We'll push a warning to the user every time a multiple of this many
   /// workers has been started.
@@ -176,6 +194,8 @@ class WorkerPool {
   /// The last size at which a warning about the number of registered workers
   /// was generated.
   int64_t last_warning_multiple_;
+  /// A client connection to the GCS.
+ std::shared_ptr gcs_client_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 143ffd57dda6..15a5fb0471e0 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -1,6 +1,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "ray/common/constants.h" #include "ray/raylet/node_manager.h" #include "ray/raylet/worker_pool.h" @@ -14,21 +15,46 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, - {{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}), + : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, {"dummy_java_worker_command"}}}) {} + + explicit WorkerPoolMock( + const std::unordered_map> &worker_commands) + : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, + worker_commands), last_worker_pid_(0) {} + ~WorkerPoolMock() { // Avoid killing real processes states_by_lang_.clear(); } + void StartWorkerProcess(const Language &language, + const std::vector &dynamic_options = {}) { + WorkerPool::StartWorkerProcess(language, dynamic_options); + } + pid_t StartProcess(const std::vector &worker_command_args) override { - return ++last_worker_pid_; + last_worker_pid_ += 1; + std::vector local_worker_commands_args; + for (auto item : worker_command_args) { + if (item == nullptr) { + break; + } + local_worker_commands_args.push_back(std::string(item)); + } + worker_commands_by_pid[last_worker_pid_] = std::move(local_worker_commands_args); + return last_worker_pid_; } + void WarnAboutSize() override {} + pid_t LastStartedWorkerProcess() const { return last_worker_pid_; } + const std::vector &GetWorkerCommand(int pid) { + return worker_commands_by_pid[pid]; + } + int NumWorkerProcessesStarting() const { int total = 0; for (auto &entry : states_by_lang_) { @@ -39,6 +65,8 @@ class WorkerPoolMock : public WorkerPool { private: int last_worker_pid_; + // The worker commands by pid. 
+ std::unordered_map> worker_commands_by_pid; }; class WorkerPoolTest : public ::testing::Test { @@ -61,6 +89,12 @@ class WorkerPoolTest : public ::testing::Test { return std::shared_ptr(new Worker(pid, language, client)); } + void SetWorkerCommands( + const std::unordered_map> &worker_commands) { + WorkerPoolMock worker_pool(worker_commands); + this->worker_pool_ = std::move(worker_pool); + } + protected: WorkerPoolMock worker_pool_; boost::asio::io_service io_service_; @@ -72,10 +106,10 @@ class WorkerPoolTest : public ::testing::Test { }; static inline TaskSpecification ExampleTaskSpec( - const ActorID actor_id = ActorID::Nil(), - const Language &language = Language::PYTHON) { + const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, + const ActorID actor_creation_id = ActorID::Nil()) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, ActorID::Nil(), + return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -186,6 +220,23 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); } +TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { + const std::vector java_worker_command = { + "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; + SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, java_worker_command}}); + + TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), + ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, + {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, + {"test_op_0", "test_op_1"}); + worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + const auto real_command = + worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + ASSERT_EQ(real_command, std::vector( + {"test_op_0", "dummy_java_worker_command", "test_op_1"})); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index feb788da7692..f507039990c2 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -1,4 +1,5 @@ #include "ray/rpc/grpc_server.h" +#include namespace ray { namespace rpc { @@ -9,8 +10,10 @@ void GrpcServer::Run() { grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); - // Allow subclasses to register concrete services. - RegisterServices(builder); + // Register all the services to this server. + for (auto &entry : services_) { + builder.RegisterService(&entry.get()); + } // Get hold of the completion queue used for the asynchronous communication // with the gRPC runtime. cq_ = builder.AddCompletionQueue(); @@ -18,8 +21,7 @@ void GrpcServer::Run() { server_ = builder.BuildAndStart(); RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; - // Allow subclasses to initialize the server call factories. - InitServerCallFactories(&server_call_factories_and_concurrencies_); + // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { for (int i = 0; i < entry.second; i++) { // Create and request calls from the factory. 
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
index 4953f470610f..584da6565a47 100644
--- a/src/ray/rpc/grpc_server.h
+++ b/src/ray/rpc/grpc_server.h
@@ -12,7 +12,9 @@
 namespace ray {
 namespace rpc {

-/// Base class that represents an abstract gRPC server.
+class GrpcService;
+
+/// Class that represents a gRPC server.
 ///
 /// A `GrpcServer` listens on a specific port. It owns
 /// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
@@ -28,11 +30,7 @@ class GrpcServer {
   /// \param[in] name Name of this server, used for logging and debugging purpose.
   /// \param[in] port The port to bind this server to. If it's 0, a random available port
   /// will be chosen.
-  /// \param[in] main_service The main event loop, to which service handler functions
-  /// will be posted.
-  GrpcServer(const std::string &name, const uint32_t port,
-             boost::asio::io_service &main_service)
-      : name_(name), port_(port), main_service_(main_service) {}
+  GrpcServer(const std::string &name, const uint32_t port) : name_(name), port_(port) {}

   /// Destruct this gRPC server.
   ~GrpcServer() {
@@ -46,36 +44,25 @@ class GrpcServer {
   /// Get the port of this gRPC server.
   int GetPort() const { return port_; }

- protected:
-  /// Subclasses should implement this method and register one or multiple gRPC services
-  /// to the given `ServerBuilder`.
+  /// Register a grpc service. Multiple services can be registered to the same server.
+  /// Note that the `service` registered must remain valid for the lifetime of the
+  /// `GrpcServer`, as it holds the underlying `grpc::Service`.
   ///
-  /// \param[in] builder The `ServerBuilder` instance to register services to.
-  virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;
-
-  /// Subclasses should implement this method to initialize the `ServerCallFactory`
-  /// instances, as well as specify maximum number of concurrent requests that gRPC
-  /// server can "accept" (not "handle"). Each factory will be used to create
-  /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
-  /// handle an incoming request.
-  ///
-  /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
-  /// and the maximum number of concurrent requests that gRPC server can accept.
-  virtual void InitServerCallFactories(
-      std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
-          *server_call_factories_and_concurrencies) = 0;
+  /// \param[in] service A `GrpcService` to register to this server.
+  void RegisterService(GrpcService &service);

+ protected:
   /// This function runs in a background thread. It keeps polling events from the
   /// `ServerCompletionQueue`, and dispatches the event to the `ServiceHandler` instances
   /// via the `ServerCall` objects.
   void PollEventsFromCompletionQueue();

-  /// The main event loop, to which the service handler functions will be posted.
-  boost::asio::io_service &main_service_;
   /// Name of this server, used for logging and debugging purpose.
   const std::string name_;
   /// Port of this server.
   int port_;
+  /// The `grpc::Service` objects which should be registered to `ServerBuilder`.
+  std::vector<std::reference_wrapper<grpc::Service>> services_;
   /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
   /// gRPC server can accept.
   std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
       server_call_factories_and_concurrencies_;
@@ -86,6 +73,46 @@ class GrpcServer {
   std::unique_ptr<grpc::Server> server_;
 };

+/// Base class that represents an abstract gRPC service.
+///
+/// Subclasses should implement `InitServerCallFactories` to decide
+/// which kinds of requests this service should accept.
+class GrpcService {
+ public:
+  /// Constructor.
+  ///
+  /// \param[in] main_service The main event loop, to which service handler functions
+  /// will be posted.
+  GrpcService(boost::asio::io_service &main_service) : main_service_(main_service) {}
+
+  /// Destruct this gRPC service.
+  ~GrpcService() {}
+
+ protected:
+  /// Return the underlying grpc::Service object for this class.
+  /// This is passed to `GrpcServer` to be registered to grpc `ServerBuilder`.
+  virtual grpc::Service &GetGrpcService() = 0;
+
+  /// Subclasses should implement this method to initialize the `ServerCallFactory`
+  /// instances, as well as specify the maximum number of concurrent requests that the
+  /// gRPC server can "accept" (not "handle"). Each factory will be used to create
+  /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept
+  /// and handle an incoming request.
+  ///
+  /// \param[in] cq The grpc completion queue.
+  /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory`
+  /// objects, and the maximum number of concurrent requests that gRPC server can
+  /// accept.
+  virtual void InitServerCallFactories(
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
+      std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
+          *server_call_factories_and_concurrencies) = 0;
+
+  /// The main event loop, to which the service handler functions will be posted.
+  boost::asio::io_service &main_service_;
+
+  friend class GrpcServer;
+};
+
 }  // namespace rpc
 }  // namespace ray
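To make the new contract concrete, a minimal GrpcService subclass could look as follows. Everything named Echo* is hypothetical (a stand-in for a generated proto service and its handler); the pattern mirrors NodeManagerGrpcService in the next file:

    // Hypothetical service; EchoService, EchoRequest, EchoReply and
    // EchoServiceHandler stand in for generated proto types and a handler.
    class EchoGrpcService : public GrpcService {
     public:
      EchoGrpcService(boost::asio::io_service &io_service, EchoServiceHandler &handler)
          : GrpcService(io_service), handler_(handler) {}

     protected:
      grpc::Service &GetGrpcService() override { return service_; }

      void InitServerCallFactories(
          const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
          std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
              *server_call_factories_and_concurrencies) override {
        // One factory per RPC method; the int is that method's accept concurrency.
        std::unique_ptr<ServerCallFactory> echo_call_factory(
            new ServerCallFactoryImpl<EchoService, EchoServiceHandler, EchoRequest,
                                      EchoReply>(
                service_, &EchoService::AsyncService::RequestEcho, handler_,
                &EchoServiceHandler::HandleEcho, cq, main_service_));
        server_call_factories_and_concurrencies->emplace_back(
            std::move(echo_call_factory), 10);
      }

     private:
      EchoService::AsyncService service_;
      EchoServiceHandler &handler_;
    };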
diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h
index afaea299ea89..d05f268c65b2 100644
--- a/src/ray/rpc/node_manager_server.h
+++ b/src/ray/rpc/node_manager_server.h
@@ -25,25 +25,22 @@ class NodeManagerServiceHandler {
                                  RequestDoneCallback done_callback) = 0;
 };

-/// The `GrpcServer` for `NodeManagerService`.
-class NodeManagerServer : public GrpcServer {
+/// The `GrpcService` for `NodeManagerService`.
+class NodeManagerGrpcService : public GrpcService {
  public:
   /// Constructor.
   ///
-  /// \param[in] port See super class.
-  /// \param[in] main_service See super class.
+  /// \param[in] io_service See super class.
   /// \param[in] handler The service handler that actually handles the requests.
-  NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
-                    NodeManagerServiceHandler &service_handler)
-      : GrpcServer("NodeManager", port, main_service),
-        service_handler_(service_handler){};
+  NodeManagerGrpcService(boost::asio::io_service &io_service,
+                         NodeManagerServiceHandler &service_handler)
+      : GrpcService(io_service), service_handler_(service_handler){};

-  void RegisterServices(grpc::ServerBuilder &builder) override {
-    /// Register `NodeManagerService`.
-    builder.RegisterService(&service_);
-  }
+ protected:
+  grpc::Service &GetGrpcService() override { return service_; }

   void InitServerCallFactories(
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
       std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
           *server_call_factories_and_concurrencies) override {
     // Initialize the factory for `ForwardTask` requests.
@@ -51,7 +48,8 @@ class NodeManagerServer : public GrpcServer {
         new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
                                   ForwardTaskRequest, ForwardTaskReply>(
             service_, &NodeManagerService::AsyncService::RequestForwardTask,
-            service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
+            service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq,
+            main_service_));

     // Set `ForwardTask`'s accept concurrency to 100.
     server_call_factories_and_concurrencies->emplace_back(
@@ -61,6 +59,7 @@ class NodeManagerServer : public GrpcServer {
  private:
   /// The grpc async service object.
   NodeManagerService::AsyncService service_;
+  /// The service handler that actually handles the requests.
   NodeManagerServiceHandler &service_handler_;
 };
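The matching call-site change in node_manager.cc is not shown in this excerpt; presumably the raylet now owns the service and the server side by side, roughly like the illustrative owner below (both members must outlive any in-flight request):

    // Illustrative ownership sketch; the real wiring lives in node_manager.cc.
    class NodeManagerRpc {
     public:
      NodeManagerRpc(boost::asio::io_service &io_service,
                     NodeManagerServiceHandler &handler, uint32_t port)
          : grpc_service_(io_service, handler), server_("NodeManager", port) {
        // The service must stay valid for server_'s lifetime.
        server_.RegisterService(grpc_service_);
        server_.Run();
      }

     private:
      NodeManagerGrpcService grpc_service_;  // declared before server_ on purpose
      GrpcServer server_;
    };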
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
index e06278260ab6..08ca128323ee 100644
--- a/src/ray/rpc/server_call.h
+++ b/src/ray/rpc/server_call.h
@@ -94,20 +94,27 @@ class ServerCallImpl : public ServerCall {
   /// \param[in] factory The factory which created this call.
   /// \param[in] service_handler The service handler that handles the request.
   /// \param[in] handle_request_function Pointer to the service handler function.
+  /// \param[in] io_service The event loop.
   ServerCallImpl(const ServerCallFactory &factory, ServiceHandler &service_handler,
-                 HandleRequestFunction handle_request_function)
+                 HandleRequestFunction handle_request_function,
+                 boost::asio::io_service &io_service)
       : state_(ServerCallState::PENDING),
         factory_(factory),
         service_handler_(service_handler),
         handle_request_function_(handle_request_function),
-        response_writer_(&context_) {}
+        response_writer_(&context_),
+        io_service_(io_service) {}

   ServerCallState GetState() const override { return state_; }

   void SetState(const ServerCallState &new_state) override { state_ = new_state; }

   void HandleRequest() override {
+    io_service_.post([this] { HandleRequestImpl(); });
+  }
+
+  void HandleRequestImpl() {
     state_ = ServerCallState::PROCESSING;
     (service_handler_.*handle_request_function_)(request_, &reply_,
                                                  [this](Status status) {
@@ -146,6 +153,9 @@ class ServerCallImpl : public ServerCall {
   /// The response writer.
   grpc::ServerAsyncResponseWriter<Reply> response_writer_;

+  /// The event loop.
+  boost::asio::io_service &io_service_;
+
   /// The request message.
   Request request_;

@@ -185,23 +195,26 @@ class ServerCallFactoryImpl : public ServerCallFactory {
   /// \param[in] service_handler The service handler that handles the request.
   /// \param[in] handle_request_function Pointer to the service handler function.
   /// \param[in] cq The `CompletionQueue`.
+  /// \param[in] io_service The event loop.
   ServerCallFactoryImpl(
       AsyncService &service, RequestCallFunction request_call_function,
       ServiceHandler &service_handler, HandleRequestFunction handle_request_function,
-      const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
+      const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
+      boost::asio::io_service &io_service)
       : service_(service),
         request_call_function_(request_call_function),
         service_handler_(service_handler),
         handle_request_function_(handle_request_function),
-        cq_(cq) {}
+        cq_(cq),
+        io_service_(io_service) {}

   ServerCall *CreateCall() const override {
     // Create a new `ServerCall`. This object will eventually be deleted by
     // `GrpcServer::PollEventsFromCompletionQueue`.
     auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
-        *this, service_handler_, handle_request_function_);
+        *this, service_handler_, handle_request_function_, io_service_);
     /// Request gRPC runtime to start accepting this kind of request, using the call as
     /// the tag.
     (service_.*request_call_function_)(&call->context_, &call->request_,
@@ -225,6 +238,9 @@ class ServerCallFactoryImpl : public ServerCallFactory {
   /// The `CompletionQueue`.
   const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;
+
+  /// The event loop.
+  boost::asio::io_service &io_service_;
 };
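The practical effect of the io_service plumbing above: the completion-queue polling thread now only calls HandleRequest(), which posts HandleRequestImpl() onto io_service_, so user handlers run on the main event loop instead of the polling thread. An illustrative handler, assuming the full HandleForwardTask signature matches the fragment visible in node_manager_server.h:

    // Illustrative handler; with this change its body runs on the main event
    // loop, so it may safely touch raylet state without extra locking.
    class LoggingNodeManagerHandler : public NodeManagerServiceHandler {
     public:
      void HandleForwardTask(const ForwardTaskRequest &request, ForwardTaskReply *reply,
                             RequestDoneCallback done_callback) override {
        RAY_LOG(DEBUG) << "ForwardTask handled on the main event loop.";
        // Fill in `reply`, then complete the call; the done callback moves the
        // ServerCall to SENDING_REPLY on the gRPC side.
        done_callback(Status::OK());
      }
    };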
diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h
index 6ecc6c3c4a34..59ae75ae33be 100644
--- a/src/ray/rpc/util.h
+++ b/src/ray/rpc/util.h
@@ -1,6 +1,7 @@
 #ifndef RAY_RPC_UTIL_H
 #define RAY_RPC_UTIL_H

+#include <google/protobuf/repeated_field.h>
 #include <grpcpp/grpcpp.h>

 #include "ray/common/status.h"
@@ -27,6 +28,18 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
   }
 }

+template <class T>
+inline std::vector<T> VectorFromProtobuf(
+    const ::google::protobuf::RepeatedPtrField<T> &pb_repeated) {
+  return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
+}
+
+template <class T>
+inline std::vector<T> VectorFromProtobuf(
+    const ::google::protobuf::RepeatedField<T> &pb_repeated) {
+  return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
+}
+
 }  // namespace rpc
 }  // namespace ray
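A usage sketch for the new helpers; MyMessage and its accessors are illustrative. The RepeatedPtrField overload covers string and message fields, the RepeatedField overload covers scalar fields, and overload resolution picks the right one:

    #include <string>
    #include <vector>
    #include "ray/rpc/util.h"

    // `message` is any generated protobuf object; `names()` (repeated string)
    // and `sizes()` (repeated int64) are illustrative accessors.
    void Example(const MyMessage &message) {
      std::vector<std::string> names = ray::rpc::VectorFromProtobuf(message.names());
      std::vector<int64_t> sizes = ray::rpc::VectorFromProtobuf(message.sizes());
    }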