Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core worker] add store & task provider #4966

Merged
merged 14 commits into from
Jun 14, 2019
12 changes: 1 addition & 11 deletions src/ray/core_worker/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class TaskArg {

/// Task specification, which includes the immutable information about the task
/// which are determined at the submission time.
/// TODO: this can be removed after everything is moved to protobuf.
class TaskSpec {
zhijunfu marked this conversation as resolved.
Show resolved Hide resolved
public:
TaskSpec(const raylet::TaskSpecification &task_spec,
Expand All @@ -96,17 +97,6 @@ enum class StoreProviderType { PLASMA };

enum class TaskTransportType { RAYLET };

struct RayClient {
/// Plasma store client.
plasma::PlasmaClient store_client_;

/// Mutex to protect store_client_.
std::mutex store_client_mutex_;

/// Raylet client.
std::unique_ptr<RayletClient> raylet_client_;
};

} // namespace ray

#endif // RAY_CORE_WORKER_COMMON_H
22 changes: 10 additions & 12 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,23 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type,
store_socket_(store_socket),
raylet_socket_(raylet_socket),
worker_context_(worker_type, driver_id),
raylet_client_(raylet_socket_, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {}
task_execution_interface_(*this) {

Status CoreWorker::Connect() {
// connect to plasma.
RAY_ARROW_RETURN_NOT_OK(ray_client_.store_client_.Connect(store_socket_));

// connect to raylet.
// TODO: currently RayletClient would crash in its constructor if it cannot
// connect to Raylet after a number of retries, this needs to be changed
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
ray_client_.raylet_client_ = std::unique_ptr<RayletClient>(
new RayletClient(raylet_socket_, worker_context_.GetWorkerID(),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)));

return Status::OK();
auto status = store_client_.Connect(store_socket_);
if (!status.ok()) {
RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct"
<< " core worker: " << status.message();
throw std::runtime_error(status.message());
}
}

::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) {
Expand Down
17 changes: 9 additions & 8 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace ray {
/// The root class that contains all the core and language-independent functionalities
/// of the worker. This class is supposed to be used to implement app-language (Java,
/// Python, etc) workers.
///
/// Note: the constructor of CoreWorker would throw if a failure happens.
zhijunfu marked this conversation as resolved.
Show resolved Hide resolved
class CoreWorker {
public:
/// Construct a CoreWorker instance.
Expand All @@ -24,9 +26,6 @@ class CoreWorker {
const std::string &store_socket, const std::string &raylet_socket,
DriverID driver_id = DriverID::Nil());

/// Connect to raylet.
Status Connect();

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }

Expand Down Expand Up @@ -67,12 +66,14 @@ class CoreWorker {
/// Worker context.
WorkerContext worker_context_;

/// Ray client (this includes store client, raylet client and potentially gcs client
/// later).
RayClient ray_client_;
/// Plasma store client.
plasma::PlasmaClient store_client_;

/// Mutex to protect store_client_.
std::mutex store_client_mutex_;

/// Whether this worker has been initialized.
bool is_initialized_;
/// Raylet client.
RayletClient raylet_client_;

/// The `CoreWorkerTaskInterface` instance.
CoreWorkerTaskInterface task_interface_;
Expand Down
23 changes: 10 additions & 13 deletions src/ray/core_worker/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,6 @@ class CoreWorkerTest : public ::testing::Test {
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());

RAY_CHECK_OK(driver.Connect());

// Test pass by value.
{
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
Expand Down Expand Up @@ -187,7 +185,6 @@ class CoreWorkerTest : public ::testing::Test {
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(driver.Connect());

std::unique_ptr<ActorHandle> actor_handle;

Expand Down Expand Up @@ -277,13 +274,6 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
ASSERT_EQ(*data, *buffer);
}

TEST_F(ZeroNodeTest, TestAttributeGetters) {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", "",
DriverID::FromRandom());
ASSERT_EQ(core_worker.WorkerType(), WorkerType::DRIVER);
ASSERT_EQ(core_worker.Language(), WorkerLanguage::PYTHON);
}

TEST_F(ZeroNodeTest, TestWorkerContext) {
auto driver_id = DriverID::FromRandom();

Expand Down Expand Up @@ -313,7 +303,6 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(core_worker.Connect());

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -370,12 +359,10 @@ TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
DriverID::FromRandom());
RAY_CHECK_OK(worker1.Connect());

CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON,
raylet_store_socket_names_[1], raylet_socket_names_[1],
DriverID::FromRandom());
RAY_CHECK_OK(worker2.Connect());

uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
Expand Down Expand Up @@ -456,6 +443,16 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) {
TestActorTask(resources);
}

TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) {
try {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON,
"", raylet_socket_names_[0],
DriverID::FromRandom());
} catch (const std::exception& e) {
std::cout << "Caught exception when constructing core worker: " << e.what();
}
}

} // namespace ray

int main(int argc, char **argv) {
Expand Down
4 changes: 1 addition & 3 deletions src/ray/core_worker/mock_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket)
: worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket,
DriverID::FromRandom()) {
RAY_CHECK_OK(worker_.Connect());
}
DriverID::FromRandom()) {}

void Run() {
auto executor_func = [this](const RayFunction &ray_function,
Expand Down
3 changes: 2 additions & 1 deletion src/ray/core_worker/object_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker)
store_providers_.emplace(
static_cast<int>(StoreProviderType::PLASMA),
std::unique_ptr<CoreWorkerStoreProvider>(
new CoreWorkerPlasmaStoreProvider(core_worker_.ray_client_)));
new CoreWorkerPlasmaStoreProvider(core_worker_.store_client_,
core_worker_.store_client_mutex_, core_worker_.raylet_client_)));
}

Status CoreWorkerObjectInterface::Put(const Buffer &buffer, ObjectID *object_id) {
Expand Down
31 changes: 18 additions & 13 deletions src/ray/core_worker/store_provider/plasma_store_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,30 @@

namespace ray {

CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider(RayClient &ray_client)
: ray_client_(ray_client) {}
CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider(
plasma::PlasmaClient &store_client,
std::mutex &store_client_mutex,
RayletClient &raylet_client)
: store_client_(store_client),
store_client_mutex_(store_client_mutex),
raylet_client_(raylet_client) {}

Status CoreWorkerPlasmaStoreProvider::Put(const Buffer &buffer,
const ObjectID &object_id) {
auto plasma_id = object_id.ToPlasmaId();
std::shared_ptr<arrow::Buffer> data;
{
std::unique_lock<std::mutex> guard(ray_client_.store_client_mutex_);
std::unique_lock<std::mutex> guard(store_client_mutex_);
RAY_ARROW_RETURN_NOT_OK(
ray_client_.store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data));
store_client_.Create(plasma_id, buffer.Size(), nullptr, 0, &data));
}

memcpy(data->mutable_data(), buffer.Data(), buffer.Size());

{
std::unique_lock<std::mutex> guard(ray_client_.store_client_mutex_);
RAY_ARROW_RETURN_NOT_OK(ray_client_.store_client_.Seal(plasma_id));
RAY_ARROW_RETURN_NOT_OK(ray_client_.store_client_.Release(plasma_id));
std::unique_lock<std::mutex> guard(store_client_mutex_);
RAY_ARROW_RETURN_NOT_OK(store_client_.Seal(plasma_id));
RAY_ARROW_RETURN_NOT_OK(store_client_.Release(plasma_id));
}
return Status::OK();
}
Expand Down Expand Up @@ -60,7 +65,7 @@ Status CoreWorkerPlasmaStoreProvider::Get(const std::vector<ObjectID> &ids,

// TODO: can call `fetchOrReconstruct` in batches as an optimization.
RAY_CHECK_OK(
ray_client_.raylet_client_->FetchOrReconstruct(unready_ids, fetch_only, task_id));
raylet_client_.FetchOrReconstruct(unready_ids, fetch_only, task_id));

// Get the objects from the object store, and parse the result.
int64_t get_timeout;
Expand All @@ -80,9 +85,9 @@ Status CoreWorkerPlasmaStoreProvider::Get(const std::vector<ObjectID> &ids,

std::vector<plasma::ObjectBuffer> object_buffers;
{
std::unique_lock<std::mutex> guard(ray_client_.store_client_mutex_);
std::unique_lock<std::mutex> guard(store_client_mutex_);
auto status =
ray_client_.store_client_.Get(plasma_ids, get_timeout, &object_buffers);
store_client_.Get(plasma_ids, get_timeout, &object_buffers);
}

for (size_t i = 0; i < object_buffers.size(); i++) {
Expand All @@ -99,7 +104,7 @@ Status CoreWorkerPlasmaStoreProvider::Get(const std::vector<ObjectID> &ids,
}

if (was_blocked) {
RAY_CHECK_OK(ray_client_.raylet_client_->NotifyUnblocked(task_id));
RAY_CHECK_OK(raylet_client_.NotifyUnblocked(task_id));
}

return Status::OK();
Expand All @@ -110,7 +115,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(const std::vector<ObjectID> &object_i
const TaskID &task_id,
std::vector<bool> *results) {
WaitResultPair result_pair;
auto status = ray_client_.raylet_client_->Wait(object_ids, num_objects, timeout_ms,
auto status = raylet_client_.Wait(object_ids, num_objects, timeout_ms,
false, task_id, &result_pair);
std::unordered_set<ObjectID> ready_ids;
for (const auto &entry : result_pair.first) {
Expand All @@ -130,7 +135,7 @@ Status CoreWorkerPlasmaStoreProvider::Wait(const std::vector<ObjectID> &object_i
Status CoreWorkerPlasmaStoreProvider::Delete(const std::vector<ObjectID> &object_ids,
bool local_only,
bool delete_creating_tasks) {
return ray_client_.raylet_client_->FreeObjects(object_ids, local_only,
return raylet_client_.FreeObjects(object_ids, local_only,
delete_creating_tasks);
}

Expand Down
13 changes: 10 additions & 3 deletions src/ray/core_worker/store_provider/plasma_store_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class CoreWorker;
/// local and remote store, remote access is done via raylet.
class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
public:
CoreWorkerPlasmaStoreProvider(RayClient &ray_client);
CoreWorkerPlasmaStoreProvider(plasma::PlasmaClient &store_client,
std::mutex &store_client_mutex, RayletClient &raylet_client);

/// Put an object with specified ID into object store.
///
Expand Down Expand Up @@ -59,8 +60,14 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
bool delete_creating_tasks) override;

private:
/// Ray client.
RayClient &ray_client_;
/// Plasma store client.
plasma::PlasmaClient &store_client_;

/// Mutex to protect store_client_.
std::mutex &store_client_mutex_;

/// Raylet client.
RayletClient &raylet_client_;
};

} // namespace ray
Expand Down
3 changes: 1 addition & 2 deletions src/ray/core_worker/store_provider/store_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ namespace ray {

/// Provider interface for store access. Store provider should inherit from this class and
/// provide implementions for the methods. The actual store provider may use a plasma
/// store
/// or local memory store in worker process, or possibly other types of storage.
/// store or local memory store in worker process, or possibly other types of storage.

zhijunfu marked this conversation as resolved.
Show resolved Hide resolved
class CoreWorkerStoreProvider {
public:
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/task_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface(
: core_worker_(core_worker) {
task_receivers.emplace(static_cast<int>(TaskTransportType::RAYLET),
std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(core_worker_.ray_client_)));
new CoreWorkerRayletTaskReceiver(core_worker_.raylet_client_)));
}

Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) {
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/task_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ CoreWorkerTaskInterface::CoreWorkerTaskInterface(CoreWorker &core_worker)
task_submitters_.emplace(
static_cast<int>(TaskTransportType::RAYLET),
std::unique_ptr<CoreWorkerRayletTaskSubmitter>(
new CoreWorkerRayletTaskSubmitter(core_worker_.ray_client_)));
new CoreWorkerRayletTaskSubmitter(core_worker_.raylet_client_)));
}

Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
Expand Down
12 changes: 6 additions & 6 deletions src/ray/core_worker/transport/raylet_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@

namespace ray {

CoreWorkerRayletTaskSubmitter::CoreWorkerRayletTaskSubmitter(RayClient &ray_client)
: ray_client_(ray_client) {}
CoreWorkerRayletTaskSubmitter::CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client)
: raylet_client_(raylet_client) {}

Status CoreWorkerRayletTaskSubmitter::SubmitTask(const TaskSpec &task) {
return ray_client_.raylet_client_->SubmitTask(task.GetDependencies(),
return raylet_client_.SubmitTask(task.GetDependencies(),
task.GetTaskSpecification());
}

CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(RayClient &ray_client)
: ray_client_(ray_client) {}
CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver(RayletClient &raylet_client)
: raylet_client_(raylet_client) {}

Status CoreWorkerRayletTaskReceiver::GetTasks(std::vector<TaskSpec> *tasks) {
std::unique_ptr<raylet::TaskSpecification> task_spec;
auto status = ray_client_.raylet_client_->GetTask(&task_spec);
auto status = raylet_client_.GetTask(&task_spec);
if (!status.ok()) {
RAY_LOG(ERROR) << "Get task from raylet failed with error: "
<< ray::Status::IOError(status.message());
Expand Down
12 changes: 6 additions & 6 deletions src/ray/core_worker/transport/raylet_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace ray {

class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter {
public:
CoreWorkerRayletTaskSubmitter(RayClient &ray_client);
CoreWorkerRayletTaskSubmitter(RayletClient &raylet_client);

/// Submit a task for execution to raylet.
///
Expand All @@ -23,20 +23,20 @@ class CoreWorkerRayletTaskSubmitter : public CoreWorkerTaskSubmitter {
virtual Status SubmitTask(const TaskSpec &task) override;

private:
/// ray client.
RayClient &ray_client_;
/// Raylet client.
RayletClient &raylet_client_;
};

class CoreWorkerRayletTaskReceiver : public CoreWorkerTaskReceiver {
public:
CoreWorkerRayletTaskReceiver(RayClient &ray_client);
CoreWorkerRayletTaskReceiver(RayletClient &raylet_client);

// Get tasks for execution from raylet.
virtual Status GetTasks(std::vector<TaskSpec> *tasks) override;

private:
/// ray client.
RayClient &ray_client_;
/// Raylet client.
RayletClient &raylet_client_;
};

} // namespace ray
Expand Down
Loading