diff --git a/tests/integration_tests/tests/connection.rs b/tests/integration_tests/tests/connection.rs new file mode 100644 index 000000000..788a04318 --- /dev/null +++ b/tests/integration_tests/tests/connection.rs @@ -0,0 +1,55 @@ +use futures_util::FutureExt; +use integration_tests::pb::{test_client::TestClient, test_server, Input, Output}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::{transport::Server, Request, Response, Status}; + +#[tokio::test] +async fn connect_returns_err() { + let res = TestClient::connect("http://thisdoesntexist").await; + + assert!(res.is_err()); +} + +#[tokio::test] +async fn connect_returns_err_via_call_after_connected() { + struct Svc(Arc>>>); + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _: Request) -> Result, Status> { + let mut l = self.0.lock().unwrap(); + l.take().unwrap().send(()).unwrap(); + + Ok(Response::new(Output {})) + } + } + + let (tx, rx) = oneshot::channel(); + let sender = Arc::new(Mutex::new(Some(tx))); + let svc = test_server::TestServer::new(Svc(sender)); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1338".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::delay_for(Duration::from_millis(100)).await; + + let mut client = TestClient::connect("http://127.0.0.1:1338").await.unwrap(); + + // First call should pass, then shutdown the server + client.unary_call(Request::new(Input {})).await.unwrap(); + + tokio::time::delay_for(Duration::from_millis(100)).await; + + let res = client.unary_call(Request::new(Input {})).await; + + assert!(res.is_err()); + + jh.await.unwrap(); +} diff --git a/tonic/src/transport/service/reconnect.rs b/tonic/src/transport/service/reconnect.rs index 3e6e7015f..1d6394a2b 100644 --- a/tonic/src/transport/service/reconnect.rs +++ b/tonic/src/transport/service/reconnect.rs @@ -18,6 +18,7 @@ where state: State, target: Target, error: Option, + has_been_connected: bool, } #[derive(Debug)] @@ -37,6 +38,7 @@ where state: State::Idle, target, error: None, + has_been_connected: false, } } } @@ -84,14 +86,23 @@ where } Poll::Ready(Err(e)) => { trace!("poll_ready; error"); + state = State::Idle; - self.error = Some(e.into()); - break; + + if self.has_been_connected { + self.error = Some(e.into()); + break; + } else { + return Poll::Ready(Err(e.into())); + } } } } State::Connected(ref mut inner) => { trace!("poll_ready; connected"); + + self.has_been_connected = true; + match inner.poll_ready(cx) { Poll::Ready(Ok(())) => { trace!("poll_ready; ready");