Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-5533: [C++] [Plasma] make plasma client thread safe #4503

Closed
wants to merge 8 commits into from
37 changes: 36 additions & 1 deletion cpp/src/plasma/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Imp
int64_t store_capacity_;
/// A hash set to record the ids that users want to delete but that are still in use.
std::unordered_set<ObjectID> deletion_cache_;
/// A mutex which protects this class.
std::recursive_mutex client_mutex_;

#ifdef PLASMA_CUDA
/// Cuda Device Manager.
Expand Down Expand Up @@ -341,6 +343,8 @@ uint8_t* PlasmaClient::Impl::LookupMmappedFile(int store_fd_val) {
}

/// Report whether this client currently has an entry for the given object in
/// its table of in-use objects.
///
/// \param object_id The ID of the object to look up.
/// \return True if objects_in_use_ contains object_id, false otherwise.
bool PlasmaClient::Impl::IsInUse(const ObjectID& object_id) {
  // Hold the client mutex for the duration of the lookup.
  std::lock_guard<std::recursive_mutex> lock(client_mutex_);
  return objects_in_use_.count(object_id) > 0;
}
Expand Down Expand Up @@ -384,6 +388,8 @@ void PlasmaClient::Impl::IncrementObjectCount(const ObjectID& object_id,
Status PlasmaClient::Impl::Create(const ObjectID& object_id, int64_t data_size,
const uint8_t* metadata, int64_t metadata_size,
std::shared_ptr<Buffer>* data, int device_num) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

ARROW_LOG(DEBUG) << "called plasma_create on conn " << store_conn_ << " with size "
<< data_size << " and metadata size " << metadata_size;
RETURN_NOT_OK(
Expand Down Expand Up @@ -451,8 +457,9 @@ Status PlasmaClient::Impl::Create(const ObjectID& object_id, int64_t data_size,
Status PlasmaClient::Impl::CreateAndSeal(const ObjectID& object_id,
const std::string& data,
const std::string& metadata) {
ARROW_LOG(DEBUG) << "called CreateAndSeal on conn " << store_conn_;
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

ARROW_LOG(DEBUG) << "called CreateAndSeal on conn " << store_conn_;
// Compute the object hash.
static unsigned char digest[kDigestSize];
// CreateAndSeal currently only supports device_num = 0, which corresponds to
Expand Down Expand Up @@ -608,6 +615,8 @@ Status PlasmaClient::Impl::GetBuffers(

Status PlasmaClient::Impl::Get(const std::vector<ObjectID>& object_ids,
int64_t timeout_ms, std::vector<ObjectBuffer>* out) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

const auto wrap_buffer = [=](const ObjectID& object_id,
const std::shared_ptr<Buffer>& buffer) {
return std::make_shared<PlasmaBuffer>(shared_from_this(), object_id, buffer);
Expand All @@ -619,6 +628,8 @@ Status PlasmaClient::Impl::Get(const std::vector<ObjectID>& object_ids,

Status PlasmaClient::Impl::Get(const ObjectID* object_ids, int64_t num_objects,
int64_t timeout_ms, ObjectBuffer* out) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

const auto wrap_buffer = [](const ObjectID& object_id,
const std::shared_ptr<Buffer>& buffer) { return buffer; };
return GetBuffers(object_ids, num_objects, timeout_ms, wrap_buffer, out);
Expand All @@ -635,6 +646,8 @@ Status PlasmaClient::Impl::MarkObjectUnused(const ObjectID& object_id) {
}

Status PlasmaClient::Impl::Release(const ObjectID& object_id) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// If the client is already disconnected, ignore release requests.
if (store_conn_ < 0) {
return Status::OK();
Expand Down Expand Up @@ -672,6 +685,8 @@ Status PlasmaClient::Impl::Release(const ObjectID& object_id) {

// This method is used to query whether the plasma store contains an object.
Status PlasmaClient::Impl::Contains(const ObjectID& object_id, bool* has_object) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// Check if we already have a reference to the object.
if (objects_in_use_.count(object_id) > 0) {
*has_object = 1;
Expand All @@ -690,6 +705,7 @@ Status PlasmaClient::Impl::Contains(const ObjectID& object_id, bool* has_object)
}

Status PlasmaClient::Impl::List(ObjectTable* objects) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);
RETURN_NOT_OK(SendListRequest(store_conn_));
std::vector<uint8_t> buffer;
RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType::PlasmaListReply, &buffer));
Expand Down Expand Up @@ -768,6 +784,8 @@ uint64_t PlasmaClient::Impl::ComputeObjectHash(const uint8_t* data, int64_t data
}

Status PlasmaClient::Impl::Seal(const ObjectID& object_id) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// Make sure this client has a reference to the object before sending the
// request to Plasma.
auto object_entry = objects_in_use_.find(object_id);
Expand All @@ -794,6 +812,7 @@ Status PlasmaClient::Impl::Seal(const ObjectID& object_id) {
}

Status PlasmaClient::Impl::Abort(const ObjectID& object_id) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);
auto object_entry = objects_in_use_.find(object_id);
ARROW_CHECK(object_entry != objects_in_use_.end())
<< "Plasma client called abort on an object without a reference to it";
Expand Down Expand Up @@ -832,6 +851,8 @@ Status PlasmaClient::Impl::Abort(const ObjectID& object_id) {
}

Status PlasmaClient::Impl::Delete(const std::vector<ObjectID>& object_ids) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

std::vector<ObjectID> not_in_use_ids;
for (auto& object_id : object_ids) {
// If the object is in use, skip it.
Expand All @@ -855,6 +876,8 @@ Status PlasmaClient::Impl::Delete(const std::vector<ObjectID>& object_ids) {
}

Status PlasmaClient::Impl::Evict(int64_t num_bytes, int64_t& num_bytes_evicted) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// Send a request to the store to evict objects.
RETURN_NOT_OK(SendEvictRequest(store_conn_, num_bytes));
// Wait for a response with the number of bytes actually evicted.
Expand All @@ -865,6 +888,8 @@ Status PlasmaClient::Impl::Evict(int64_t num_bytes, int64_t& num_bytes_evicted)
}

Status PlasmaClient::Impl::Hash(const ObjectID& object_id, uint8_t* digest) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// Get the plasma object data. We pass in a timeout of 0 to indicate that
// the operation should timeout immediately.
std::vector<ObjectBuffer> object_buffers;
Expand All @@ -880,6 +905,8 @@ Status PlasmaClient::Impl::Hash(const ObjectID& object_id, uint8_t* digest) {
}

Status PlasmaClient::Impl::Subscribe(int* fd) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

int sock[2];
// Create a non-blocking socket pair. This will only be used to send
// notifications from the Plasma store to the client.
Expand All @@ -902,6 +929,8 @@ Status PlasmaClient::Impl::Subscribe(int* fd) {
Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* object_id,
int64_t* data_size,
int64_t* metadata_size) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

auto object_info = flatbuffers::GetRoot<fb::ObjectInfo>(buffer);
ARROW_CHECK(object_info->object_id()->size() == sizeof(ObjectID));
memcpy(object_id, object_info->object_id()->data(), sizeof(ObjectID));
Expand All @@ -917,6 +946,8 @@ Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* o

Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

auto notification = ReadMessageAsync(fd);
if (notification == NULL) {
return Status::IOError("Failed to read object notification from Plasma socket");
Expand All @@ -927,6 +958,8 @@ Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
Status PlasmaClient::Impl::Connect(const std::string& store_socket_name,
const std::string& manager_socket_name,
int release_delay, int num_retries) {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

RETURN_NOT_OK(ConnectIpcSocketRetry(store_socket_name, num_retries, -1, &store_conn_));
if (manager_socket_name != "") {
return Status::NotImplemented("plasma manager is no longer supported");
Expand All @@ -944,6 +977,8 @@ Status PlasmaClient::Impl::Connect(const std::string& store_socket_name,
}

Status PlasmaClient::Impl::Disconnect() {
std::lock_guard<std::recursive_mutex> guard(client_mutex_);

// NOTE: We purposefully do not finish sending release calls for objects in
// use, so that we don't duplicate PlasmaClient::Release calls (when handling
// a SIGTERM, for example).
Expand Down