[prototype] boost fiber instead of asio for async #16699

Closed · wants to merge 18 commits
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -40,6 +40,8 @@ cc_library(
deps = [
":ray_common",
"@boost//:asio",
"@boost//:fiber",
"@com_google_absl//absl/types:optional",
"@com_github_grpc_grpc//:grpc++",
"@com_google_protobuf//:protobuf",
],
94 changes: 44 additions & 50 deletions src/ray/core_worker/core_worker.cc
@@ -1382,60 +1382,54 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
Status CoreWorker::GetLocationFromOwner(
const std::vector<ObjectID> &object_ids, int64_t timeout_ms,
std::vector<std::shared_ptr<ObjectLocation>> *results) {
results->resize(object_ids.size());
if (object_ids.empty()) {
return Status::OK();
}

auto mutex = std::make_shared<absl::Mutex>();
auto num_remaining = std::make_shared<size_t>(object_ids.size());
auto ready_promise = std::make_shared<std::promise<void>>();
auto location_by_id =
std::make_shared<absl::flat_hash_map<ObjectID, std::shared_ptr<ObjectLocation>>>();
boost::fibers::future<Status> f(boost::fibers::async([&]() mutable {
results->resize(object_ids.size());
if (object_ids.empty()) {
return Status::OK();
}

for (const auto &object_id : object_ids) {
auto owner_address = GetOwnerAddress(object_id);
auto client = core_worker_client_pool_->GetOrConnect(owner_address);
rpc::GetObjectLocationsOwnerRequest request;
request.set_intended_worker_id(owner_address.worker_id());
request.set_object_id(object_id.Binary());
request.set_last_version(-1);
client->GetObjectLocationsOwner(
request,
[object_id, mutex, num_remaining, ready_promise, location_by_id](
const Status &status, const rpc::GetObjectLocationsOwnerReply &reply) {
absl::MutexLock lock(mutex.get());
if (status.ok()) {
location_by_id->emplace(
object_id, std::make_shared<ObjectLocation>(CreateObjectLocation(reply)));
} else {
RAY_LOG(WARNING) << "Failed to query location information for " << object_id
<< " with error: " << status.ToString();
}
(*num_remaining)--;
if (*num_remaining == 0) {
ready_promise->set_value();
}
});
}
if (timeout_ms < 0) {
ready_promise->get_future().wait();
} else if (ready_promise->get_future().wait_for(
std::chrono::microseconds(timeout_ms)) != std::future_status::ready) {
std::ostringstream stream;
stream << "Failed querying object locations within " << timeout_ms
<< " milliseconds.";
return Status::TimedOut(stream.str());
}
absl::flat_hash_map<ObjectID, std::shared_ptr<ObjectLocation>> location_by_id;
std::vector<rpc::FutureType<rpc::GetObjectLocationsOwnerReply>> replies;
for (const auto &object_id : object_ids) {
auto owner_address = GetOwnerAddress(object_id);
auto client = core_worker_client_pool_->GetOrConnect(owner_address);
rpc::GetObjectLocationsOwnerRequest request;
request.set_intended_worker_id(owner_address.worker_id());
request.set_object_id(object_id.Binary());
request.set_last_version(-1);
client->GetObjectLocationsOwner(request);
replies.emplace_back(client->GetObjectLocationsOwner(request));
}
for (size_t i = 0; i < replies.size(); ++i) {
auto reply = replies[i].get();
if (reply.second.ok()) {
location_by_id.emplace(object_ids[i], std::make_shared<ObjectLocation>(
CreateObjectLocation(reply.first)));
} else {
RAY_LOG(WARNING) << "Failed to query location information for " << object_ids[i]
<< " with error: " << reply.second.ToString();
}
}
for (size_t i = 0; i < object_ids.size(); i++) {
auto pair = location_by_id.find(object_ids[i]);
if (pair == location_by_id.end()) {
continue;
}
(*results)[i] = pair->second;
}
return Status::OK();
}));

for (size_t i = 0; i < object_ids.size(); i++) {
auto pair = location_by_id->find(object_ids[i]);
if (pair == location_by_id->end()) {
continue;
if (timeout_ms > 0) {
auto wait_status = f.wait_for(std::chrono::milliseconds(timeout_ms));
if (wait_status != boost::fibers::future_status::ready) {
std::ostringstream stream;
stream << "Failed querying object locations within " << timeout_ms
<< " milliseconds.";
return Status::TimedOut(stream.str());
}
(*results)[i] = pair->second;
}
return Status::OK();
return f.get();
}

void CoreWorker::TriggerGlobalGC() {
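For readers less familiar with Boost.Fiber, below is a minimal sketch of the pattern the new GetLocationFromOwner follows: fan the requests out, block a single fiber on the replies, and let the caller apply one timeout to that fiber's future. It uses only plain Boost.Fiber; `FakeLocationRpc` is a hypothetical stand-in for the GetObjectLocationsOwner RPC, not code from this PR.

```cpp
// Hedged sketch of the fiber-based flow above, using only plain Boost.Fiber.
// FakeLocationRpc is a hypothetical stand-in for the GetObjectLocationsOwner RPC;
// in the real code the reply future is fulfilled by the gRPC polling thread.
#include <boost/fiber/all.hpp>
#include <chrono>
#include <string>
#include <vector>

boost::fibers::future<std::string> FakeLocationRpc(int object_id) {
  boost::fibers::packaged_task<std::string()> task(
      [object_id] { return "node-" + std::to_string(object_id); });
  boost::fibers::future<std::string> reply = task.get_future();
  boost::fibers::fiber(std::move(task)).detach();  // reply arrives "later"
  return reply;
}

int main() {
  // Fan the requests out and block a single worker fiber on the replies, so the
  // caller can apply one timeout to the whole batch, as GetLocationFromOwner does.
  boost::fibers::future<std::vector<std::string>> batch =
      boost::fibers::async([] {
        std::vector<boost::fibers::future<std::string>> replies;
        for (int id = 0; id < 3; ++id) {
          replies.emplace_back(FakeLocationRpc(id));
        }
        std::vector<std::string> locations;
        for (auto &r : replies) {
          locations.push_back(r.get());  // suspends only this fiber, not the thread
        }
        return locations;
      });

  // Mirrors the timeout_ms handling: wait on the whole batch as one unit.
  if (batch.wait_for(std::chrono::milliseconds(100)) !=
      boost::fibers::future_status::ready) {
    return 1;  // would map to Status::TimedOut in the real code
  }
  return batch.get().size() == 3 ? 0 : 1;
}
```

Note that blocking on `r.get()` suspends only the enclosing fiber; the thread stays free to run other fibers, which is what keeps the linear-looking code non-blocking.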
57 changes: 43 additions & 14 deletions src/ray/rpc/client_call.h
@@ -18,7 +18,9 @@

#include <boost/asio.hpp>

#include <boost/fiber/all.hpp>
#include "absl/synchronization/mutex.h"
#include "absl/types/optional.h"
#include "ray/common/asio/instrumented_io_context.h"
#include "ray/common/grpc_util.h"
#include "ray/common/status.h"
@@ -34,6 +36,7 @@ namespace rpc {
/// template as well.
class ClientCall {
public:
ClientCall(bool run_inline = false) : run_inline_(run_inline) {}
/// The callback to be called by `ClientCallManager` when the reply of this request is
/// received.
virtual void OnReplyReceived() = 0;
@@ -45,6 +48,11 @@ class ClientCall {
virtual std::shared_ptr<StatsHandle> GetStatsHandle() = 0;

virtual ~ClientCall() = default;

bool ShouldRunInline() const { return run_inline_; }

private:
bool run_inline_;
};

class ClientCallManager;
@@ -55,6 +63,11 @@ class ClientCallManager;
template <class Reply>
using ClientCallback = std::function<void(const Status &status, const Reply &reply)>;

template <typename Reply>
using PromiseType = boost::fibers::promise<std::pair<Reply, Status>>;
template <typename Reply>
using FutureType = boost::fibers::future<std::pair<Reply, Status>>;

/// Implementation of the `ClientCall`. It represents a `ClientCall` for a particular
/// RPC method.
///
@@ -66,9 +79,12 @@ class ClientCallImpl : public ClientCall {
///
/// \param[in] callback The callback function to handle the reply.
explicit ClientCallImpl(const ClientCallback<Reply> &callback,
std::shared_ptr<StatsHandle> stats_handle)
: callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
stats_handle_(std::move(stats_handle)) {}
std::shared_ptr<StatsHandle> stats_handle,
absl::optional<PromiseType<Reply>> promise)
: ClientCall(promise != absl::nullopt),
callback_(std::move(const_cast<ClientCallback<Reply> &>(callback))),
stats_handle_(std::move(stats_handle)),
promise_(std::move(promise)) {}

Status GetStatus() override {
absl::MutexLock lock(&mutex_);
@@ -86,6 +102,13 @@
absl::MutexLock lock(&mutex_);
status = return_status_;
}

if (promise_) {
RAY_CHECK(callback_ == nullptr);
promise_->set_value(std::make_pair(std::move(reply_), status));
return;
}

if (callback_ != nullptr) {
callback_(status, reply_);
}
@@ -123,6 +146,7 @@ class ClientCallImpl : public ClientCall {
/// the server and/or tweak certain RPC behaviors.
grpc::ClientContext context_;

absl::optional<PromiseType<Reply>> promise_;
friend class ClientCallManager;
};

@@ -215,10 +239,10 @@ class ClientCallManager {
typename GrpcService::Stub &stub,
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
const Request &request, const ClientCallback<Reply> &callback,
std::string call_name) {
std::string call_name, absl::optional<PromiseType<Reply>> promise = absl::nullopt) {
auto stats_handle = main_service_.RecordStart(call_name);
auto call =
std::make_shared<ClientCallImpl<Reply>>(callback, std::move(stats_handle));
auto call = std::make_shared<ClientCallImpl<Reply>>(callback, std::move(stats_handle),
std::move(promise));
// Send request.
// Find the next completion queue to wait for response.
call->response_reader_ = (stub.*prepare_async_function)(
@@ -265,14 +289,19 @@ class ClientCallManager {
std::shared_ptr<StatsHandle> stats_handle = tag->GetCall()->GetStatsHandle();
RAY_CHECK(stats_handle != nullptr);
if (ok && !main_service_.stopped() && !shutdown_) {
// Post the callback to the main event loop.
main_service_.post(
[tag]() {
tag->GetCall()->OnReplyReceived();
// The call is finished, and we can delete this tag now.
delete tag;
},
std::move(stats_handle));
if (tag->GetCall()->ShouldRunInline()) {
tag->GetCall()->OnReplyReceived();
delete tag;
} else {
// Post the callback to the main event loop.
main_service_.post(
[tag]() {
tag->GetCall()->OnReplyReceived();
// The call is finished, and we can delete this tag now.
delete tag;
},
std::move(stats_handle));
}
} else {
delete tag;
}
19 changes: 19 additions & 0 deletions src/ray/rpc/grpc_client.h
@@ -40,6 +40,14 @@ namespace rpc {
INVOKE_RPC_CALL(SERVICE, METHOD, request, callback, rpc_client); \
}

#define FIBER_RPC_CLIENT_METHOD(SERVICE, METHOD, rpc_client, SPECS) \
Contributor:
Will this have reduced performance compared to the RPCs run directly on gRPC threads?

How does thread pooling work for fibers, do we have a pool of fiber threads?

Contributor Author:
In the gRPC thread, we move the result out, so it won't reduce performance.
Comparison:

  • Previously, we posted a job to the io_thread pool, i.e., a job (closure) was created in the asio pool.
  • Here, we pass the result to the promise with a move. The extra cost is moving the structure, which is fine. For example, for the pub/sub message below, if the arena is the same it just swaps, and we always use the default arena:
  inline PubMessage& operator=(PubMessage&& from) noexcept {
    if (GetArena() == from.GetArena()) {
      if (this != &from) InternalSwap(&from);
    } else {
      CopyFrom(from);
    }
    return *this;
  }

No, we don't have a pool of fiber threads in this prototype. We just reuse the thread calling CoreWorker, which can be optimized later. We should put everything that is not CPU intensive onto one thread (the fiber thread) and put CPU-intensive work onto a CPU thread pool.

A fiber thread is like an asio thread; the difference is that asio only manages tasks, which are closures, while fiber also manages the context and scheduling of those tasks. In asio we wrap everything into lambda captures and lose the call stack; in fiber the state lives in the fiber's context, which makes tasks resumable.

In the middle of the migration we can share the same thread with asio, so if we go this way we can migrate granularly.

FutureType<METHOD##Reply> METHOD(const METHOD##Request &request) SPECS { \
PromiseType<METHOD##Reply> promise; \
auto future = promise.get_future(); \
INVOKE_RPC_CALL(SERVICE, METHOD, request, std::move(promise), rpc_client); \
return future; \
}
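To complement the discussion above, here is a hedged sketch (plain Boost.Fiber and the standard library, no Ray types) of the hand-off FIBER_RPC_CLIENT_METHOD relies on: the reply is moved into the promise from a foreign thread standing in for the gRPC completion-queue thread, and the waiting fiber wakes with the pair-of-reply-and-status result. `FakeReply` and `grpc_thread` are illustrative names only.

```cpp
// Hedged sketch: a non-fiber thread (standing in for the gRPC completion-queue
// thread) fulfills a boost::fibers::promise, and the caller blocks on the matching
// future. FakeReply is illustrative; the real code carries std::pair<Reply, Status>.
#include <boost/fiber/all.hpp>
#include <string>
#include <thread>
#include <utility>

struct FakeReply {
  std::string payload;
};

int main() {
  boost::fibers::promise<std::pair<FakeReply, bool>> promise;  // bool ~ Status::ok()
  boost::fibers::future<std::pair<FakeReply, bool>> future = promise.get_future();

  // "gRPC thread": on completion it calls set_value() directly (the
  // ShouldRunInline() path) instead of posting a callback to the asio event loop,
  // so the reply is moved exactly once.
  std::thread grpc_thread([p = std::move(promise)]() mutable {
    FakeReply reply{"object locations"};
    p.set_value(std::make_pair(std::move(reply), true));
  });

  // Caller side of the macro: block this fiber on the future and consume the result.
  std::pair<FakeReply, bool> result = future.get();
  grpc_thread.join();
  return (result.second && result.first.payload == "object locations") ? 0 : 1;
}
```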

template <class GrpcService>
class GrpcClient {
public:
Expand Down Expand Up @@ -94,6 +102,17 @@ class GrpcClient {
RAY_CHECK(call != nullptr);
}

template <class Request, class Reply>
void CallMethod(
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
const Request &request, PromiseType<Reply> promise,
std::string call_name = "UNKNOWN_RPC") {
auto call = client_call_manager_.CreateCall<GrpcService, Request, Reply>(
*stub_, prepare_async_function, request, nullptr, std::move(call_name),
std::move(promise));
RAY_CHECK(call != nullptr);
}

private:
ClientCallManager &client_call_manager_;
/// The gRPC-generated stub.
5 changes: 5 additions & 0 deletions src/ray/rpc/worker/core_worker_client.h
@@ -160,6 +160,9 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
const GetObjectLocationsOwnerRequest &request,
const ClientCallback<GetObjectLocationsOwnerReply> &callback) {}

virtual rpc::FutureType<GetObjectLocationsOwnerReply> GetObjectLocationsOwner(
const GetObjectLocationsOwnerRequest &request) = 0;

/// Tell this actor to exit immediately.
virtual void KillActor(const KillActorRequest &request,
const ClientCallback<KillActorReply> &callback) {}
@@ -247,6 +250,8 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,

VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetObjectLocationsOwner, grpc_client_,
override)
FIBER_RPC_CLIENT_METHOD(CoreWorkerService, GetObjectLocationsOwner, grpc_client_,
override)

VOID_RPC_CLIENT_METHOD(CoreWorkerService, GetCoreWorkerStats, grpc_client_, override)
