diff --git a/tcp_transport.go b/tcp_transport.go index c439443..37a0a67 100644 --- a/tcp_transport.go +++ b/tcp_transport.go @@ -92,7 +92,12 @@ func (t *tcpTransport) SetEncryption(ctx context.Context, e SessionEncryption) e tlsConn = tls.Client(t.conn, t.TLSConfig) } - deadline, _ := ctx.Deadline() // Use the deadline zero value if ctx has no deadline defined + var deadline time.Time + var ok bool + if deadline, ok = ctx.Deadline(); !ok { + deadline = time.Now().Add(30 * time.Second) + } + if err := tlsConn.SetWriteDeadline(deadline); err != nil { return err } @@ -157,7 +162,7 @@ func (t *tcpTransport) Close() error { return err } - err := t.conn.Close() + err := t.ctxConn.Close() t.conn = nil return err } @@ -383,12 +388,12 @@ func (c *ctxConn) SetWriteContext(ctx context.Context) { } func (c *ctxConn) Read(b []byte) (n int, err error) { - for c.readCtx.Err() == nil { - var deadline time.Time - if ctxDeadline, ok := c.readCtx.Deadline(); ok { + for err = c.readCtx.Err(); err == nil; { + deadline := time.Now().Add(c.readTimeout) + + // Use the context deadline only if it is early then the default + if ctxDeadline, ok := c.readCtx.Deadline(); ok && deadline.After(ctxDeadline) { deadline = ctxDeadline - } else { - deadline = time.Now().Add(c.readTimeout) } if err = c.conn.SetReadDeadline(deadline); err != nil { @@ -410,16 +415,16 @@ func (c *ctxConn) Read(b []byte) (n int, err error) { c.readCancel() } - return 0, c.readCtx.Err() + return 0, err } func (c *ctxConn) Write(b []byte) (n int, err error) { - for c.writeCtx.Err() == nil { - var deadline time.Time - if ctxDeadline, ok := c.writeCtx.Deadline(); ok { + for err = c.writeCtx.Err(); err == nil; { + deadline := time.Now().Add(c.writeTimeout) + + // Use the context deadline only if it is early then the default + if ctxDeadline, ok := c.readCtx.Deadline(); ok && deadline.After(ctxDeadline) { deadline = ctxDeadline - } else { - deadline = time.Now().Add(c.writeTimeout) } if err = c.conn.SetWriteDeadline(deadline); err != nil { @@ -441,10 +446,18 @@ func (c *ctxConn) Write(b []byte) (n int, err error) { c.writeCancel() } - return 0, c.writeCtx.Err() + return 0, err } func (c *ctxConn) Close() error { + if c.readCancel != nil { + c.readCancel() + } + + if c.writeCancel != nil { + c.writeCancel() + } + return c.conn.Close() }