Skip to content

Commit

Permalink
[Core Worker] implement ObjectInterface and add test framework (ray-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijunfu authored and raulchen committed Jun 3, 2019
1 parent 89722ff commit b674c4a
Showing 14 changed files with 611 additions and 27 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -148,6 +148,9 @@ install:
- ./ci/suppress_output bazel build //:stats_test -c opt
- ./bazel-bin/stats_test

# core worker test.
- ./ci/suppress_output bash src/ray/test/run_core_worker_tests.sh

# Raylet tests.
- ./ci/suppress_output bash src/ray/test/run_object_manager_tests.sh
- ./ci/suppress_output bazel test --build_tests_only --test_lang_filters=cc //:all
7 changes: 6 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
@@ -77,6 +77,7 @@ cc_library(
"src/ray/raylet/mock_gcs_client.cc",
"src/ray/raylet/monitor_main.cc",
"src/ray/raylet/*_test.cc",
"src/ray/raylet/main.cc",
],
),
hdrs = glob([
@@ -122,15 +123,18 @@ cc_library(
deps = [
":ray_common",
":ray_util",
":raylet_lib",
],
)

cc_test(
# This test is run by src/ray/test/run_core_worker_tests.sh
cc_binary(
name = "core_worker_test",
srcs = ["src/ray/core_worker/core_worker_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
":gcs",
"@com_google_googletest//:gtest_main",
],
)
@@ -320,6 +324,7 @@ cc_library(
":node_manager_fbs",
":ray_util",
"@boost//:asio",
"@plasma//:plasma_client",
],
)

22 changes: 21 additions & 1 deletion src/ray/common/buffer.h
Original file line number Diff line number Diff line change
@@ -3,6 +3,11 @@

#include <cstdint>
#include <cstdio>
#include "plasma/client.h"

namespace arrow {
class Buffer;
}

namespace ray {

@@ -15,7 +20,7 @@ class Buffer {
/// Size of this buffer.
virtual size_t Size() const = 0;

virtual ~Buffer() {}
virtual ~Buffer(){};

bool operator==(const Buffer &rhs) const {
return this->Data() == rhs.Data() && this->Size() == rhs.Size();
@@ -40,6 +45,21 @@ class LocalMemoryBuffer : public Buffer {
size_t size_;
};

/// Represents a byte buffer for plasma object.
class PlasmaBuffer : public Buffer {
public:
PlasmaBuffer(std::shared_ptr<arrow::Buffer> buffer) : buffer_(buffer) {}

uint8_t *Data() const override { return const_cast<uint8_t *>(buffer_->data()); }

size_t Size() const override { return buffer_->size(); }

private:
/// shared_ptr to arrow buffer which can potentially hold a reference
/// for the object (when it's a plasma::PlasmaBuffer).
std::shared_ptr<arrow::Buffer> buffer_;
};

} // namespace ray

#endif // RAY_COMMON_BUFFER_H
4 changes: 2 additions & 2 deletions src/ray/core_worker/common.h
Original file line number Diff line number Diff line change
@@ -45,13 +45,13 @@ class TaskArg {
bool IsPassedByReference() const { return id_ != nullptr; }

/// Get the reference object ID.
ObjectID &GetReference() {
const ObjectID &GetReference() const {
RAY_CHECK(id_ != nullptr) << "This argument isn't passed by reference.";
return *id_;
}

/// Get the value.
std::shared_ptr<Buffer> GetValue() {
std::shared_ptr<Buffer> GetValue() const {
RAY_CHECK(data_ != nullptr) << "This argument isn't passed by value.";
return data_;
}
81 changes: 81 additions & 0 deletions src/ray/core_worker/context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@

#include "context.h"

namespace ray {

/// per-thread context for core worker.
struct WorkerThreadContext {
WorkerThreadContext()
: current_task_id(TaskID::FromRandom()), task_index(0), put_index(0) {}

int GetNextTaskIndex() { return ++task_index; }

int GetNextPutIndex() { return ++put_index; }

const TaskID &GetCurrentTaskID() const { return current_task_id; }

void SetCurrentTask(const TaskID &task_id) {
current_task_id = task_id;
task_index = 0;
put_index = 0;
}

void SetCurrentTask(const raylet::TaskSpecification &spec) {
SetCurrentTask(spec.TaskId());
}

private:
/// The task ID for current task.
TaskID current_task_id;

/// Number of tasks that have been submitted from current task.
int task_index;

/// Number of objects that have been put from current task.
int put_index;
};

thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
nullptr;

WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id)
: worker_type(worker_type),
worker_id(worker_type == WorkerType::DRIVER
? ClientID::FromBinary(driver_id.Binary())
: ClientID::FromRandom()),
current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) {
// For worker main thread which initializes the WorkerContext,
// set task_id according to whether current worker is a driver.
// (For other threads it's set to randmom ID via GetThreadContext).
GetThreadContext().SetCurrentTask(
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
}

const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }

const ClientID &WorkerContext::GetWorkerID() const { return worker_id; }

int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }

int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }

const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; }

const TaskID &WorkerContext::GetCurrentTaskID() const {
return GetThreadContext().GetCurrentTaskID();
}

void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) {
current_driver_id = spec.DriverId();
GetThreadContext().SetCurrentTask(spec);
}

WorkerThreadContext &WorkerContext::GetThreadContext() {
if (thread_context_ == nullptr) {
thread_context_ = std::unique_ptr<WorkerThreadContext>(new WorkerThreadContext());
}

return *thread_context_;
}

} // namespace ray
48 changes: 48 additions & 0 deletions src/ray/core_worker/context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef RAY_CORE_WORKER_CONTEXT_H
#define RAY_CORE_WORKER_CONTEXT_H

#include "common.h"
#include "ray/raylet/task_spec.h"

namespace ray {

struct WorkerThreadContext;

class WorkerContext {
public:
WorkerContext(WorkerType worker_type, const DriverID &driver_id);

const WorkerType GetWorkerType() const;

const ClientID &GetWorkerID() const;

const DriverID &GetCurrentDriverID() const;

const TaskID &GetCurrentTaskID() const;

void SetCurrentTask(const raylet::TaskSpecification &spec);

int GetNextTaskIndex();

int GetNextPutIndex();

private:
/// Type of the worker.
const WorkerType worker_type;

/// ID for this worker.
const ClientID worker_id;

/// Driver ID for this worker.
DriverID current_driver_id;

private:
static WorkerThreadContext &GetThreadContext();

/// Per-thread worker context.
static thread_local std::unique_ptr<WorkerThreadContext> thread_context_;
};

} // namespace ray

#endif // RAY_CORE_WORKER_CONTEXT_H
39 changes: 39 additions & 0 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "core_worker.h"
#include "context.h"

namespace ray {

CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum Language language,
const std::string &store_socket, const std::string &raylet_socket,
DriverID driver_id)
: worker_type_(worker_type),
language_(language),
worker_context_(worker_type, driver_id),
store_socket_(store_socket),
raylet_socket_(raylet_socket),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {}

Status CoreWorker::Connect() {
// connect to plasma.
RAY_ARROW_RETURN_NOT_OK(store_client_.Connect(store_socket_));

// connect to raylet.
::Language lang = ::Language::PYTHON;
if (language_ == ray::Language::JAVA) {
lang = ::Language::JAVA;
}

// TODO: currently RayletClient would crash in its constructor if it cannot
// connect to Raylet after a number of retries, this needs to be changed
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
raylet_client_ = std::unique_ptr<RayletClient>(
new RayletClient(raylet_socket_, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentDriverID(), lang));
return Status::OK();
}

} // namespace ray
34 changes: 26 additions & 8 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
@@ -2,8 +2,10 @@
#define RAY_CORE_WORKER_CORE_WORKER_H

#include "common.h"
#include "context.h"
#include "object_interface.h"
#include "ray/common/buffer.h"
#include "ray/raylet/raylet_client.h"
#include "task_execution.h"
#include "task_interface.h"

@@ -18,15 +20,12 @@ class CoreWorker {
///
/// \param[in] worker_type Type of this worker.
/// \param[in] langauge Language of this worker.
CoreWorker(const WorkerType worker_type, const Language language)
: worker_type_(worker_type),
language_(language),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {}
CoreWorker(const WorkerType worker_type, const Language language,
const std::string &store_socket, const std::string &raylet_socket,
DriverID driver_id = DriverID::Nil());

/// Connect this worker to Raylet.
Status Connect() { return Status::OK(); }
/// Connect to raylet.
Status Connect();

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }
@@ -53,6 +52,21 @@ class CoreWorker {
/// Language of this worker.
const enum Language language_;

/// Worker context per thread.
WorkerContext worker_context_;

/// Plasma store socket name.
std::string store_socket_;

/// raylet socket name.
std::string raylet_socket_;

/// Plasma store client.
plasma::PlasmaClient store_client_;

/// Raylet client.
std::unique_ptr<RayletClient> raylet_client_;

/// The `CoreWorkerTaskInterface` instance.
CoreWorkerTaskInterface task_interface_;

@@ -61,6 +75,10 @@ class CoreWorker {

/// The `CoreWorkerTaskExecutionInterface` instance.
CoreWorkerTaskExecutionInterface task_execution_interface_;

friend class CoreWorkerTaskInterface;
friend class CoreWorkerObjectInterface;
friend class CoreWorkerTaskExecutionInterface;
};

} // namespace ray
Loading

0 comments on commit b674c4a

Please sign in to comment.