diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index 660487b53e0f..739bd79aca77 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -6,3 +6,4 @@ Changes * http: fixed URL parsing for HTTP/1.1 fully qualified URLs and connect requests containing IPv6 addresses. * http: fixed bugs in datadog and squash filter's handling of responses with no bodies. +* tls: fix detection of the upstream connection close event. diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 03f3f8b44e4a..6bb9ae48d221 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -128,10 +128,18 @@ Network::IoResult SslSocket::doRead(Buffer::Instance& read_buffer) { case SSL_ERROR_WANT_READ: break; case SSL_ERROR_ZERO_RETURN: + // Graceful shutdown using close_notify TLS alert. end_stream = true; break; + case SSL_ERROR_SYSCALL: + if (result.error_.value() == 0) { + // Non-graceful shutdown by closing the underlying socket. + end_stream = true; + break; + } + FALLTHRU; case SSL_ERROR_WANT_WRITE: - // Renegotiation has started. We don't handle renegotiation so just fall through. + // Renegotiation has started. We don't handle renegotiation so just fall through. default: drainErrorQueue(); action = PostIoAction::Close; diff --git a/test/extensions/transport_sockets/tls/ssl_socket_test.cc b/test/extensions/transport_sockets/tls/ssl_socket_test.cc index 764f3af7cc98..2c4d191f0784 100644 --- a/test/extensions/transport_sockets/tls/ssl_socket_test.cc +++ b/test/extensions/transport_sockets/tls/ssl_socket_test.cc @@ -2462,6 +2462,182 @@ TEST_P(SslSocketTest, HalfClose) { dispatcher_->run(Event::Dispatcher::RunType::Block); } +TEST_P(SslSocketTest, ShutdownWithCloseNotify) { + const std::string server_ctx_yaml = R"EOF( + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{{ test_tmpdir }}/unittestcert.pem" + private_key: + filename: "{{ test_tmpdir }}/unittestkey.pem" + validation_context: + trusted_ca: + filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem" +)EOF"; + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); + auto server_cfg = std::make_unique(server_tls_context, factory_context_); + ContextManagerImpl manager(time_system_); + Stats::TestUtil::TestStore server_stats_store; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, + server_stats_store, std::vector{}); + + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); + Network::MockListenerCallbacks listener_callbacks; + Network::MockConnectionHandler connection_handler; + Network::ListenerPtr listener = dispatcher_->createListener(socket, listener_callbacks, true); + std::shared_ptr server_read_filter(new Network::MockReadFilter()); + std::shared_ptr client_read_filter(new Network::MockReadFilter()); + + const std::string client_ctx_yaml = R"EOF( + common_tls_context: + )EOF"; + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context); + auto client_cfg = std::make_unique(tls_context, factory_context_); + Stats::TestUtil::TestStore client_stats_store; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, + client_stats_store); + Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), + client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); + Network::MockConnectionCallbacks client_connection_callbacks; + client_connection->enableHalfClose(true); + client_connection->addReadFilter(client_read_filter); + client_connection->addConnectionCallbacks(client_connection_callbacks); + client_connection->connect(); + + Network::ConnectionPtr server_connection; + Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(listener_callbacks, onAccept_(_)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { + server_connection = dispatcher_->createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + stream_info_); + server_connection->enableHalfClose(true); + server_connection->addReadFilter(server_read_filter); + server_connection->addConnectionCallbacks(server_connection_callbacks); + })); + EXPECT_CALL(*server_read_filter, onNewConnection()); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { + Buffer::OwnedImpl data("hello"); + server_connection->write(data, true); + EXPECT_EQ(data.length(), 0); + })); + + EXPECT_CALL(*client_read_filter, onNewConnection()) + .WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)); + EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true)) + .WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus { + read_buffer.drain(read_buffer.length()); + client_connection->close(Network::ConnectionCloseType::NoFlush); + return Network::FilterStatus::StopIteration; + })); + EXPECT_CALL(*server_read_filter, onData(_, true)); + + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::RemoteClose)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { + server_connection->close(Network::ConnectionCloseType::NoFlush); + dispatcher_->exit(); + })); + + dispatcher_->run(Event::Dispatcher::RunType::Block); +} + +TEST_P(SslSocketTest, ShutdownWithoutCloseNotify) { + const std::string server_ctx_yaml = R"EOF( + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{{ test_tmpdir }}/unittestcert.pem" + private_key: + filename: "{{ test_tmpdir }}/unittestkey.pem" + validation_context: + trusted_ca: + filename: "{{ test_rundir }}/test/extensions/transport_sockets/tls/test_data/ca_certificates.pem" +)EOF"; + + envoy::extensions::transport_sockets::tls::v3::DownstreamTlsContext server_tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(server_ctx_yaml), server_tls_context); + auto server_cfg = std::make_unique(server_tls_context, factory_context_); + ContextManagerImpl manager(time_system_); + Stats::TestUtil::TestStore server_stats_store; + ServerSslSocketFactory server_ssl_socket_factory(std::move(server_cfg), manager, + server_stats_store, std::vector{}); + + auto socket = std::make_shared( + Network::Test::getCanonicalLoopbackAddress(GetParam()), nullptr, true); + Network::MockListenerCallbacks listener_callbacks; + Network::MockConnectionHandler connection_handler; + Network::ListenerPtr listener = dispatcher_->createListener(socket, listener_callbacks, true); + std::shared_ptr server_read_filter(new Network::MockReadFilter()); + std::shared_ptr client_read_filter(new Network::MockReadFilter()); + + const std::string client_ctx_yaml = R"EOF( + common_tls_context: + )EOF"; + + envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_context; + TestUtility::loadFromYaml(TestEnvironment::substitute(client_ctx_yaml), tls_context); + auto client_cfg = std::make_unique(tls_context, factory_context_); + Stats::TestUtil::TestStore client_stats_store; + ClientSslSocketFactory client_ssl_socket_factory(std::move(client_cfg), manager, + client_stats_store); + Network::ClientConnectionPtr client_connection = dispatcher_->createClientConnection( + socket->localAddress(), Network::Address::InstanceConstSharedPtr(), + client_ssl_socket_factory.createTransportSocket(nullptr), nullptr); + Network::MockConnectionCallbacks client_connection_callbacks; + client_connection->enableHalfClose(true); + client_connection->addReadFilter(client_read_filter); + client_connection->addConnectionCallbacks(client_connection_callbacks); + client_connection->connect(); + + Network::ConnectionPtr server_connection; + Network::MockConnectionCallbacks server_connection_callbacks; + EXPECT_CALL(listener_callbacks, onAccept_(_)) + .WillOnce(Invoke([&](Network::ConnectionSocketPtr& socket) -> void { + server_connection = dispatcher_->createServerConnection( + std::move(socket), server_ssl_socket_factory.createTransportSocket(nullptr), + stream_info_); + server_connection->enableHalfClose(true); + server_connection->addReadFilter(server_read_filter); + server_connection->addConnectionCallbacks(server_connection_callbacks); + })); + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { + Buffer::OwnedImpl data("hello"); + server_connection->write(data, false); + EXPECT_EQ(data.length(), 0); + // Close without sending close_notify alert. + const SslSocketInfo* ssl_socket = + dynamic_cast(client_connection->ssl().get()); + SSL_set_quiet_shutdown(ssl_socket->ssl(), 1); + server_connection->close(Network::ConnectionCloseType::NoFlush); + })); + + EXPECT_CALL(*client_read_filter, onNewConnection()) + .WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::Connected)); + EXPECT_CALL(*client_read_filter, onData(BufferStringEqual("hello"), true)) + .WillOnce(Invoke([&](Buffer::Instance& read_buffer, bool) -> Network::FilterStatus { + read_buffer.drain(read_buffer.length()); + client_connection->close(Network::ConnectionCloseType::NoFlush); + return Network::FilterStatus::StopIteration; + })); + + EXPECT_CALL(server_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)); + EXPECT_CALL(client_connection_callbacks, onEvent(Network::ConnectionEvent::LocalClose)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); })); + + dispatcher_->run(Event::Dispatcher::RunType::Block); +} + TEST_P(SslSocketTest, ClientAuthMultipleCAs) { const std::string server_ctx_yaml = R"EOF( common_tls_context: diff --git a/test/per_file_coverage.sh b/test/per_file_coverage.sh index fc1c13fa6918..31e48a5bce64 100755 --- a/test/per_file_coverage.sh +++ b/test/per_file_coverage.sh @@ -53,10 +53,10 @@ declare -a KNOWN_LOW_COVERAGE=( "source/extensions/tracers:96.3" "source/extensions/tracers/opencensus:90.1" "source/extensions/tracers/xray:95.3" -"source/extensions/transport_sockets:94.8" +"source/extensions/transport_sockets:94.6" "source/extensions/transport_sockets/raw_buffer:90.9" "source/extensions/transport_sockets/tap:95.6" -"source/extensions/transport_sockets/tls:94.2" +"source/extensions/transport_sockets/tls:93.8" "source/extensions/transport_sockets/tls/private_key:76.9" )