diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc
index 33f14580d5e09..a451217358475 100644
--- a/src/ray/core_worker/task_manager.cc
+++ b/src/ray/core_worker/task_manager.cc
@@ -30,6 +30,22 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
 // Throttle task failure logs to once this interval.
 const int64_t kTaskFailureLoggingFrequencyMillis = 5000;
 
+std::vector<ObjectID> ObjectRefStream::GetItemsUnconsumed() const {
+  std::vector<ObjectID> result;
+  if (next_index_ == end_of_stream_index_) {
+    return {};
+  }
+
+  for (const auto &it : item_index_to_refs_) {
+    const auto &index = it.first;
+    const auto &object_id = it.second;
+    if (index >= next_index_) {
+      result.push_back(object_id);
+    }
+  }
+  return result;
+}
+
 Status ObjectRefStream::TryReadNextItem(ObjectID *object_id_out) {
   bool is_eof_set = end_of_stream_index_ != -1;
   if (is_eof_set && next_index_ >= end_of_stream_index_) {
@@ -388,25 +404,8 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) {
       return;
     }
 
-    while (true) {
-      ObjectID object_id;
-      const auto &status = TryReadObjectRefStreamInternal(generator_id, &object_id);
-
-      // keyError means the stream reaches to EoF.
-      if (status.IsObjectRefStreamEoF()) {
-        break;
-      }
-
-      if (object_id == ObjectID::Nil()) {
-        // No more objects to obtain. Stop iteration.
-        break;
-      } else {
-        // It means the object hasn't been consumed.
-        // We should remove references since we have 1 reference to this object.
-        object_ids_unconsumed.push_back(object_id);
-      }
-    }
-
+    const auto &stream = it->second;
+    object_ids_unconsumed = stream.GetItemsUnconsumed();
     object_ref_streams_.erase(generator_id);
   }
 
@@ -454,7 +453,6 @@ bool TaskManager::HandleReportGeneratorItemReturns(
   absl::MutexLock lock(&mu_);
   auto stream_it = object_ref_streams_.find(generator_id);
   if (stream_it == object_ref_streams_.end()) {
-    // SANG-TODO add an unit test.
     // Stream has been already deleted. Do not handle it.
     return false;
   }
diff --git a/src/ray/core_worker/task_manager.h b/src/ray/core_worker/task_manager.h
index 31b039711f721..e3abfb24d48e3 100644
--- a/src/ray/core_worker/task_manager.h
+++ b/src/ray/core_worker/task_manager.h
@@ -120,6 +120,11 @@ class ObjectRefStream {
   /// \param[in] The last item index that means the end of stream.
   void MarkEndOfStream(int64_t item_index);
 
+  /// Get all the ObjectIDs that are not read yet via TryReadNextItem.
+  ///
+  /// \return A list of object IDs that are not read yet.
+  std::vector<ObjectID> GetItemsUnconsumed() const;
+
  private:
   const ObjectID generator_id_;
 
diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc
index ce7c325474467..9dce967c43469 100644
--- a/src/ray/core_worker/test/task_manager_test.cc
+++ b/src/ray/core_worker/test/task_manager_test.cc
@@ -1672,6 +1672,62 @@ TEST_F(TaskManagerTest, TestObjectRefStreamOutofOrder) {
   manager_.DelObjectRefStream(generator_id);
 }
 
+TEST_F(TaskManagerTest, TestObjectRefStreamDelOutOfOrder) {
+  /**
+   * Verify there's no leak when we delete a ObjectRefStream
+   * that has out of order WRITEs.
+   * WRITE index 1 -> Del -> Write index 0. Both 0 and 1 have to be
+   * deleted.
+   */
+  // Submit a generator task.
+  rpc::Address caller_address;
+  auto spec = CreateTaskHelper(1, {}, /*dynamic_returns=*/true);
+  auto generator_id = spec.ReturnId(0);
+  manager_.AddPendingTask(caller_address, spec, "", /*num_retries=*/0);
+  manager_.MarkDependenciesResolved(spec.TaskId());
+  manager_.MarkTaskWaitingForExecution(
+      spec.TaskId(), NodeID::FromRandom(), WorkerID::FromRandom());
+
+  // CREATE
+  manager_.CreateObjectRefStream(generator_id);
+
+  // WRITE to index 1
+  auto dynamic_return_id_index_1 = ObjectID::FromIndex(spec.TaskId(), 3);
+  auto data = GenerateRandomBuffer();
+  auto req = GetIntermediateTaskReturn(
+      /*idx*/ 1,
+      /*finished*/ false,
+      generator_id,
+      /*dynamic_return_id*/ dynamic_return_id_index_1,
+      /*data*/ data,
+      /*set_in_plasma*/ false);
+  ASSERT_TRUE(manager_.HandleReportGeneratorItemReturns(req));
+  ASSERT_TRUE(reference_counter_->HasReference(dynamic_return_id_index_1));
+
+  // Delete the stream. This should remove references from ^.
+  manager_.DelObjectRefStream(generator_id);
+  ASSERT_FALSE(reference_counter_->HasReference(dynamic_return_id_index_1));
+
+  // WRITE to index 0. It should fail cuz the stream has been removed.
+  auto dynamic_return_id_index_0 = ObjectID::FromIndex(spec.TaskId(), 2);
+  data = GenerateRandomBuffer();
+  req = GetIntermediateTaskReturn(
+      /*idx*/ 0,
+      /*finished*/ false,
+      generator_id,
+      /*dynamic_return_id*/ dynamic_return_id_index_0,
+      /*data*/ data,
+      /*set_in_plasma*/ false);
+  ASSERT_FALSE(manager_.HandleReportGeneratorItemReturns(req));
+  ASSERT_FALSE(reference_counter_->HasReference(dynamic_return_id_index_0));
+
+  rpc::PushTaskReply reply;
+  manager_.CompletePendingTask(spec.TaskId(), reply, caller_address, false);
+
+  // There must be only a generator ID.
+  ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
+}
+
 }  // namespace core
 }  // namespace ray
 