From d37bd82db65e7557e0143fbe76f937a35cbff441 Mon Sep 17 00:00:00 2001 From: David Drysdale Date: Thu, 4 Apr 2019 17:58:15 +0100 Subject: [PATCH] Fix DialContext when using a timeout (#2737) Fixes #2736 --- clientconn.go | 13 +++++++------ clientconn_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/clientconn.go b/clientconn.go index e12ea3479f94..bd2d2b317798 100644 --- a/clientconn.go +++ b/clientconn.go @@ -138,12 +138,6 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * } defer func() { - select { - case <-ctx.Done(): - conn, err = nil, ctx.Err() - default: - } - if err != nil { cc.Close() } @@ -217,6 +211,13 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout) defer cancel() } + defer func() { + select { + case <-ctx.Done(): + conn, err = nil, ctx.Err() + default: + } + }() scSet := false if cc.dopts.scChan != nil { diff --git a/clientconn_test.go b/clientconn_test.go index ab1db5664049..356f4f25af95 100644 --- a/clientconn_test.go +++ b/clientconn_test.go @@ -52,6 +52,48 @@ func assertState(wantState connectivity.State, cc *ClientConn) (connectivity.Sta return state, state == wantState } +func (s) TestDialWithTimeout(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error while listening. Err: %v", err) + } + defer lis.Close() + lisAddr := resolver.Address{Addr: lis.Addr().String()} + lisDone := make(chan struct{}) + dialDone := make(chan struct{}) + // 1st listener accepts the connection and then does nothing + go func() { + defer close(lisDone) + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error while accepting. Err: %v", err) + return + } + framer := http2.NewFramer(conn, conn) + if err := framer.WriteSettings(http2.Setting{}); err != nil { + t.Errorf("Error while writing settings. Err: %v", err) + return + } + <-dialDone // Close conn only after dial returns. + }() + + r, cleanup := manual.GenerateAndRegisterManualResolver() + defer cleanup() + r.InitialState(resolver.State{Addresses: []resolver.Address{lisAddr}}) + client, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithTimeout(5*time.Second)) + close(dialDone) + if err != nil { + t.Fatalf("Dial failed. Err: %v", err) + } + defer client.Close() + timeout := time.After(1 * time.Second) + select { + case <-timeout: + t.Fatal("timed out waiting for server to finish") + case <-lisDone: + } +} + func (s) TestDialWithMultipleBackendsNotSendingServerPreface(t *testing.T) { lis1, err := net.Listen("tcp", "localhost:0") if err != nil {