diff --git a/src/core/ext/xds/xds_transport_grpc.cc b/src/core/ext/xds/xds_transport_grpc.cc index c793bbba9bb9e..1deff78f2a356 100644 --- a/src/core/ext/xds/xds_transport_grpc.cc +++ b/src/core/ext/xds/xds_transport_grpc.cc @@ -88,21 +88,20 @@ GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::GrpcStreamingCall( grpc_call_error call_error; grpc_op ops[2]; memset(ops, 0, sizeof(ops)); - // Send initial metadata. No callback for this, since we don't really - // care when it finishes. + // Send initial metadata. grpc_op* op = ops; op->op = GRPC_OP_SEND_INITIAL_METADATA; op->data.send_initial_metadata.count = 0; op->flags = GRPC_INITIAL_METADATA_WAIT_FOR_READY | GRPC_INITIAL_METADATA_WAIT_FOR_READY_EXPLICITLY_SET; op->reserved = nullptr; - op++; + ++op; op->op = GRPC_OP_RECV_INITIAL_METADATA; op->data.recv_initial_metadata.recv_initial_metadata = &initial_metadata_recv_; op->flags = 0; op->reserved = nullptr; - op++; + ++op; // Ref will be released in the callback GRPC_CLOSURE_INIT( &on_recv_initial_metadata_, OnRecvInitialMetadata, @@ -119,7 +118,7 @@ GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall::GrpcStreamingCall( op->data.recv_status_on_client.status_details = &status_details_; op->flags = 0; op->reserved = nullptr; - op++; + ++op; // This callback signals the end of the call, so it relies on the initial // ref instead of a new ref. When it's invoked, it's the initial ref that is // unreffed. @@ -176,7 +175,6 @@ void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: op.op = GRPC_OP_RECV_MESSAGE; op.data.recv_message.recv_message = &recv_message_payload_; GPR_ASSERT(call_ != nullptr); - // Reuses the "OnResponseReceived" ref taken in ctor. const grpc_call_error call_error = grpc_call_start_batch_and_execute(call_, &op, 1, &on_response_received_); GPR_ASSERT(GRPC_CALL_OK == call_error); @@ -184,26 +182,23 @@ void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnRecvInitialMetadata(void* arg, grpc_error_handle /*error*/) { - auto self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); grpc_metadata_array_destroy(&self->initial_metadata_recv_); - self->Unref(DEBUG_LOCATION, "OnRecvInitialMetadata"); } void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnRequestSent(void* arg, grpc_error_handle error) { - auto* self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); // Clean up the sent message. grpc_byte_buffer_destroy(self->send_message_payload_); self->send_message_payload_ = nullptr; // Invoke request handler. self->event_handler_->OnRequestSent(error.ok()); - // Drop the ref. - self->Unref(DEBUG_LOCATION, "OnRequestSent"); } void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnResponseReceived(void* arg, grpc_error_handle /*error*/) { - auto self(static_cast(arg)); + RefCountedPtr self(static_cast(arg)); // If there was no payload, then we received status before we received // another message, so we stop reading. if (self->recv_message_payload_ != nullptr) { @@ -217,16 +212,14 @@ void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: self->event_handler_->OnRecvMessage(StringViewFromSlice(response_slice)); CSliceUnref(response_slice); } - self->Unref(DEBUG_LOCATION, "StartRecvMessage"); } void GrpcXdsTransportFactory::GrpcXdsTransport::GrpcStreamingCall:: OnStatusReceived(void* arg, grpc_error_handle /*error*/) { - auto* self = static_cast(arg); + RefCountedPtr self(static_cast(arg)); self->event_handler_->OnStatusReceived( absl::Status(static_cast(self->status_code_), StringViewFromSlice(self->status_details_))); - self->Unref(DEBUG_LOCATION, "OnStatusReceived"); } // diff --git a/test/core/xds/BUILD b/test/core/xds/BUILD index 1e19fc469982e..4dc271f1cf2cb 100644 --- a/test/core/xds/BUILD +++ b/test/core/xds/BUILD @@ -152,7 +152,6 @@ grpc_cc_library( external_deps = [ "absl/strings", "absl/types:optional", - "gtest", ], language = "C++", deps = [ diff --git a/test/core/xds/xds_client_fuzzer.cc b/test/core/xds/xds_client_fuzzer.cc index a921fe78893c9..11b2b19a6a121 100644 --- a/test/core/xds/xds_client_fuzzer.cc +++ b/test/core/xds/xds_client_fuzzer.cc @@ -57,7 +57,9 @@ class Fuzzer { // Leave xds_client_ unset, so Act() will be a no-op. return; } - auto transport_factory = MakeOrphanable(); + auto transport_factory = MakeOrphanable([]() { + gpr_assertion_failed(__FILE__, __LINE__, "Multiple concurrent reads"); + }); transport_factory->SetAutoCompleteMessagesFromClient(false); transport_factory->SetAbortOnUndrainedMessages(false); transport_factory_ = transport_factory.get(); diff --git a/test/core/xds/xds_client_test.cc b/test/core/xds/xds_client_test.cc index 6caf52dea856d..6a8e35ceeca92 100644 --- a/test/core/xds/xds_client_test.cc +++ b/test/core/xds/xds_client_test.cc @@ -285,7 +285,7 @@ class XdsClientTest : public ::testing::Test { if (!resource_and_handle.has_value()) { return nullptr; } - return resource_and_handle->resource; + return std::move(resource_and_handle->resource); } absl::optional WaitForNextError( @@ -600,7 +600,8 @@ class XdsClientTest : public ::testing::Test { void InitXdsClient( FakeXdsBootstrap::Builder bootstrap_builder = FakeXdsBootstrap::Builder(), Duration resource_request_timeout = Duration::Seconds(15)) { - auto transport_factory = MakeOrphanable(); + auto transport_factory = MakeOrphanable( + []() { FAIL() << "Multiple concurrent reads"; }); transport_factory_ = transport_factory->Ref().TakeAsSubclass(); xds_client_ = MakeRefCounted( @@ -2723,18 +2724,23 @@ TEST_F(XdsClientTest, AdsReadWaitsForHandleRelease) { InitXdsClient(); // Start watches for "foo1" and "foo2". auto watcher1 = StartFooWatch("foo1"); - auto watcher2 = StartFooWatch("foo2"); - // Watchers should initially not see any resource reported. - EXPECT_FALSE(watcher1->HasEvent()); - EXPECT_FALSE(watcher2->HasEvent()); // XdsClient should have created an ADS stream. auto stream = WaitForAdsStream(); ASSERT_TRUE(stream != nullptr); // XdsClient should have sent a subscription request on the ADS stream. auto request = WaitForRequest(stream.get()); ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"", /*response_nonce=*/"", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1"}); + auto watcher2 = StartFooWatch("foo2"); request = WaitForRequest(stream.get()); ASSERT_TRUE(request.has_value()); + CheckRequest(*request, XdsFooResourceType::Get()->type_url(), + /*version_info=*/"", /*response_nonce=*/"", + /*error_detail=*/absl::OkStatus(), + /*resource_names=*/{"foo1", "foo2"}); // Send a response with 2 resources. stream->SendMessageToClient( ResponseBuilder(XdsFooResourceType::Get()->type_url()) @@ -2771,7 +2777,6 @@ TEST_F(XdsClientTest, AdsReadWaitsForHandleRelease) { resource1->read_delay_handle.reset(); EXPECT_EQ(stream->reads_started(), 1); resource2->read_delay_handle.reset(); - EXPECT_EQ(stream->reads_started(), 2); resource1 = watcher1->WaitForNextResourceAndHandle(); ASSERT_NE(resource1, absl::nullopt); EXPECT_EQ(resource1->resource->name, "foo1"); @@ -2784,6 +2789,7 @@ TEST_F(XdsClientTest, AdsReadWaitsForHandleRelease) { /*version_info=*/"2", /*response_nonce=*/"B", /*error_detail=*/absl::OkStatus(), /*resource_names=*/{"foo1", "foo2"}); + EXPECT_EQ(stream->reads_started(), 2); resource1->read_delay_handle.reset(); EXPECT_EQ(stream->reads_started(), 3); // Cancel watch. diff --git a/test/core/xds/xds_transport_fake.cc b/test/core/xds/xds_transport_fake.cc index 10a20d68732da..e8ef0b25ad263 100644 --- a/test/core/xds/xds_transport_fake.cc +++ b/test/core/xds/xds_transport_fake.cc @@ -129,14 +129,11 @@ void FakeXdsTransportFactory::FakeStreamingCall::CompleteSendMessageFromClient( void FakeXdsTransportFactory::FakeStreamingCall::StartRecvMessage() { absl::optional pending; MutexLock lock(&mu_); - if (read_pending_) { - gpr_log(GPR_ERROR, - "StartRecvMessage had been called while there is already a pending " - "read request"); - return; + if (num_pending_reads_ > 0) { + too_many_pending_reads_callback_(); } ++reads_started_; - read_pending_ = true; + ++num_pending_reads_; if (!to_client_messages_.empty()) { // Dispatch pending message (if there's one) on a separate thread to avoid // recursion @@ -158,19 +155,21 @@ void FakeXdsTransportFactory::FakeStreamingCall::SendMessageToClient( void FakeXdsTransportFactory::FakeStreamingCall::MaybeDeliverMessageToClient() { RefCountedPtr event_handler; std::string message; - { - ReleasableMutexLock lock(&mu_); - if (!read_pending_ || to_client_messages_.empty()) { - return; + // Loop terminates with a break inside + while (true) { + { + MutexLock lock(&mu_); + if (num_pending_reads_ == 0 || to_client_messages_.empty()) { + break; + } + --num_pending_reads_; + message = std::move(to_client_messages_.front()); + to_client_messages_.pop_front(); + event_handler = event_handler_; } - read_pending_ = false; - message = std::move(to_client_messages_.front()); - to_client_messages_.pop_front(); - event_handler = event_handler_; - lock.Release(); + ExecCtx exec_ctx; + event_handler->OnRecvMessage(message); } - ExecCtx exec_ctx; - event_handler->OnRecvMessage(message); } void FakeXdsTransportFactory::FakeStreamingCall::MaybeSendStatusToClient( @@ -259,7 +258,8 @@ FakeXdsTransportFactory::FakeXdsTransport::CreateStreamingCall( const char* method, std::unique_ptr event_handler) { auto call = MakeOrphanable( - RefAsSubclass(), method, std::move(event_handler)); + RefAsSubclass(), method, std::move(event_handler), + too_many_pending_reads_callback_); MutexLock lock(&mu_); active_calls_[method] = call->Ref().TakeAsSubclass(); cv_.Signal(); @@ -284,7 +284,7 @@ FakeXdsTransportFactory::Create( auto transport = MakeOrphanable( RefAsSubclass(), server, std::move(on_connectivity_failure), auto_complete_messages_from_client_, - abort_on_undrained_messages_); + abort_on_undrained_messages_, too_many_pending_reads_callback_); entry = transport->Ref().TakeAsSubclass(); return transport; } diff --git a/test/core/xds/xds_transport_fake.h b/test/core/xds/xds_transport_fake.h index a1a9e8d85dbd0..70c54dcbc4d16 100644 --- a/test/core/xds/xds_transport_fake.h +++ b/test/core/xds/xds_transport_fake.h @@ -58,11 +58,14 @@ class FakeXdsTransportFactory : public XdsTransportFactory { public: FakeStreamingCall( RefCountedPtr transport, const char* method, - std::unique_ptr event_handler) + std::unique_ptr event_handler, + std::function too_many_pending_reads_callback) : transport_(std::move(transport)), method_(method), - event_handler_(MakeRefCounted( - std::move(event_handler))) {} + event_handler_( + MakeRefCounted(std::move(event_handler))), + too_many_pending_reads_callback_( + std::move(too_many_pending_reads_callback)) {} ~FakeStreamingCall() override; @@ -128,11 +131,14 @@ class FakeXdsTransportFactory : public XdsTransportFactory { bool status_sent_ ABSL_GUARDED_BY(&mu_) = false; bool orphaned_ ABSL_GUARDED_BY(&mu_) = false; size_t reads_started_ ABSL_GUARDED_BY(&mu_) = 0; - bool read_pending_ ABSL_GUARDED_BY(&mu_) = false; + size_t num_pending_reads_ ABSL_GUARDED_BY(&mu_) = 0; std::deque to_client_messages_ ABSL_GUARDED_BY(&mu_); + std::function too_many_pending_reads_callback_; }; - FakeXdsTransportFactory() = default; + FakeXdsTransportFactory(std::function too_many_pending_reads_callback) + : too_many_pending_reads_callback_( + std::move(too_many_pending_reads_callback)) {} using XdsTransportFactory::Ref; // Make it public. @@ -173,7 +179,8 @@ class FakeXdsTransportFactory : public XdsTransportFactory { const XdsBootstrap::XdsServer& server, std::function on_connectivity_failure, bool auto_complete_messages_from_client, - bool abort_on_undrained_messages) + bool abort_on_undrained_messages, + std::function too_many_pending_reads_callback) : factory_(std::move(factory)), server_(server), auto_complete_messages_from_client_( @@ -181,7 +188,9 @@ class FakeXdsTransportFactory : public XdsTransportFactory { abort_on_undrained_messages_(abort_on_undrained_messages), on_connectivity_failure_( MakeRefCounted( - std::move(on_connectivity_failure))) {} + std::move(on_connectivity_failure))), + too_many_pending_reads_callback_( + std::move(too_many_pending_reads_callback)) {} void Orphan() override; @@ -235,6 +244,7 @@ class FakeXdsTransportFactory : public XdsTransportFactory { ABSL_GUARDED_BY(&mu_); std::map> active_calls_ ABSL_GUARDED_BY(&mu_); + std::function too_many_pending_reads_callback_; }; OrphanablePtr Create( @@ -250,6 +260,7 @@ class FakeXdsTransportFactory : public XdsTransportFactory { transport_map_ ABSL_GUARDED_BY(&mu_); bool auto_complete_messages_from_client_ ABSL_GUARDED_BY(&mu_) = true; bool abort_on_undrained_messages_ ABSL_GUARDED_BY(&mu_) = true; + std::function too_many_pending_reads_callback_; }; } // namespace grpc_core