Skip to content

Commit

Permalink
Ensure that disconnects propagate through port-forwarded tunnel
Browse files Browse the repository at this point in the history
When a client terminates, it should propagate over to the server on
remote host without it attempting to write. The defered cleanup wasn't
closing all the right connections to make this happen.

When a server terminates, re-execed teleport might not notice until
client sends new data. Re-execed teleport should exit on first observed
error in either direction and not wait for both ends.

Fixes #3749
  • Loading branch information
Andrew Lytvynov committed Jun 24, 2020
1 parent 914a284 commit 3821f76
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 17 deletions.
24 changes: 16 additions & 8 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,11 @@ func (client *NodeClient) ExecuteSCP(cmd scp.Command) error {
return trace.Wrap(err)
}

func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string) error {
type netDialer interface {
Dial(string, string) (net.Conn, error)
}

func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dialer netDialer) error {
defer conn.Close()
defer log.Debugf("Finished proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)

Expand All @@ -878,7 +882,7 @@ func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, re

log.Debugf("Attempting to connect proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)
for attempt := 1; attempt <= 5; attempt++ {
remoteConn, err = client.Client.Dial("tcp", remoteAddr)
remoteConn, err = dialer.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("Proxy connection attempt %v: %v.", attempt, err)

Expand Down Expand Up @@ -906,29 +910,33 @@ func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, re
errCh := make(chan error, 2)
go func() {
defer conn.Close()
defer remoteConn.Close()

_, err := io.Copy(conn, remoteConn)
errCh <- err
}()
go func() {
defer conn.Close()
defer remoteConn.Close()

_, err := io.Copy(remoteConn, conn)
errCh <- err
}()

var lastErr error
var errs []error
for i := 0; i < 2; i++ {
select {
case err := <-errCh:
if err != nil && err != io.EOF {
if err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
log.Warnf("Failed to proxy connection: %v.", err)
lastErr = err
errs = append(errs, err)
}
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

return lastErr
return trace.NewAggregate(errs...)
}

// listenAndForward listens on a given socket and forwards all incoming
Expand All @@ -947,7 +955,7 @@ func (c *NodeClient) listenAndForward(ctx context.Context, ln net.Listener, remo

// Proxy the connection to the remote address.
go func() {
err := c.proxyConnection(ctx, conn, remoteAddr)
err := proxyConnection(ctx, conn, remoteAddr, c.Client)
if err != nil {
log.Warnf("Failed to proxy connection: %v.", err)
}
Expand Down Expand Up @@ -981,7 +989,7 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene

// Proxy the connection to the remote address.
go func() {
err := c.proxyConnection(ctx, conn, remoteAddr)
err := proxyConnection(ctx, conn, remoteAddr, c.Client)
if err != nil {
log.Warnf("Failed to proxy connection: %v.", err)
}
Expand Down
123 changes: 123 additions & 0 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ limitations under the License.
package client

import (
"context"
"io"
"io/ioutil"
"net"
"os"
"strings"
"time"

"github.com/gravitational/teleport/lib/sshutils"

Expand Down Expand Up @@ -76,3 +82,120 @@ func (s *ClientTestSuite) TestNewSession(c *check.C) {
// the session ID must be taken from tne environ map, if passed:
c.Assert(string(ses.id), check.Equals, "session-id")
}

// TestProxyConnection verifies that client or server-side disconnect
// propagates all the way to the opposite side.
func (s *ClientTestSuite) TestProxyConnection(c *check.C) {
// remoteSrv mocks a remote listener, accepting port-forwarded connections
// over SSH.
remoteConCh := make(chan net.Conn)
remoteErrCh := make(chan error, 3)
remoteSrv := newTestListener(c, func(con net.Conn) {
defer con.Close()

remoteConCh <- con

// Echo any data back to the sender.
_, err := io.Copy(con, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
remoteErrCh <- err
})
defer remoteSrv.Close()

// localSrv mocks a local tsh listener, accepting local connections for
// port-forwarding to remote SSH node.
proxyErrCh := make(chan error, 3)
localSrv := newTestListener(c, func(con net.Conn) {
defer con.Close()

proxyErrCh <- proxyConnection(context.Background(), con, remoteSrv.Addr().String(), new(net.Dialer))
})
defer localSrv.Close()

// Dial localSrv. This should trigger proxyConnection and a dial to
// remoteSrv.
localCon, err := net.Dial("tcp", localSrv.Addr().String())
c.Assert(err, check.IsNil)
clientErrCh := make(chan error, 3)
go func(con net.Conn) {
_, err := io.Copy(ioutil.Discard, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
clientErrCh <- err
}(localCon)

// Discard remoteCon to unblock the remote handler.
<-remoteConCh

// Simulate a client-side disconnect. All other parties (tsh proxy and
// remove listener) should disconnect as well.
c.Log("simulate client-side disconnect")
err = localCon.Close()
c.Assert(err, check.IsNil)

for i := 0; i < 3; i++ {
select {
case err := <-proxyErrCh:
c.Assert(err, check.IsNil)
case err := <-remoteErrCh:
c.Assert(err, check.IsNil)
case err := <-clientErrCh:
c.Assert(err, check.IsNil)
case <-time.After(5 * time.Second):
c.Fatal("proxyConnection, client and server didn't disconnect within 5s after client connection was closed")
}
}

// Dial localSrv again. This should trigger proxyConnection and a dial to
// remoteSrv.
localCon, err = net.Dial("tcp", localSrv.Addr().String())
c.Assert(err, check.IsNil)
go func(con net.Conn) {
_, err := io.Copy(ioutil.Discard, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
clientErrCh <- err
}(localCon)

// Simulate a server-side disconnect. All other parties (tsh proxy and
// local client) should disconnect as well.
c.Log("simulate server-side disconnect")
remoteCon := <-remoteConCh
err = remoteCon.Close()
c.Assert(err, check.IsNil)

for i := 0; i < 3; i++ {
select {
case err := <-proxyErrCh:
c.Assert(err, check.IsNil)
case err := <-remoteErrCh:
c.Assert(err, check.IsNil)
case err := <-clientErrCh:
c.Assert(err, check.IsNil)
case <-time.After(5 * time.Second):
c.Fatal("proxyConnection, client and server didn't disconnect within 5s after remote connection was closed")
}
}
}

func newTestListener(c *check.C, handle func(net.Conn)) net.Listener {
l, err := net.Listen("tcp", "localhost:0")
c.Assert(err, check.IsNil)

go func() {
for {
con, err := l.Accept()
if err != nil {
c.Logf("listener error: %v", err)
return
}
go handle(con)
}
}()

return l
}
18 changes: 9 additions & 9 deletions lib/srv/reexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ func RunCommand() (io.Writer, int, error) {
// pipe) then port forwards.
func RunForward() (io.Writer, int, error) {
// errorWriter is used to return any error message back to the client.
errorWriter := os.Stdout
// Use stderr so that it's not forwarded to the remote client.
errorWriter := os.Stderr

// Parent sends the command payload in the third file descriptor.
cmdfd := os.NewFile(uintptr(3), "/proc/self/fd/3")
Expand Down Expand Up @@ -269,30 +270,29 @@ func RunForward() (io.Writer, int, error) {
// pipe to channel.
errorCh := make(chan error, 2)
go func() {
defer conn.Close()
defer os.Stdout.Close()
defer os.Stdin.Close()

_, err := io.Copy(os.Stdout, conn)
errorCh <- err
}()
go func() {
defer conn.Close()
defer os.Stdout.Close()
defer os.Stdin.Close()

_, err := io.Copy(conn, os.Stdin)
errorCh <- err
}()

// Block until copy is complete and the child process is done executing.
var errs []error
for i := 0; i < 2; i++ {
err := <-errorCh
if err != nil && err != io.EOF {
errs = append(errs, err)
}
// Block until copy is complete in either direction. The other direction
// will get cleaned up automatically.
if err = <-errorCh; err != nil && err != io.EOF {
return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err)
}

return ioutil.Discard, teleport.RemoteCommandSuccess, trace.NewAggregate(errs...)
return ioutil.Discard, teleport.RemoteCommandSuccess, nil
}

// RunAndExit will run the requested command and then exit. This wrapper
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,8 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
if err != nil {
writeStderr(channel, err.Error())
}
// Propagate stderr from the spawned Teleport process to log any errors.
cmd.Stderr = os.Stderr

// Create a pipe for std{in,out} that will be used to transfer data between
// parent and child.
Expand All @@ -971,13 +973,15 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
// pipe to channel.
errorCh := make(chan error, 2)
go func() {
defer channel.Close()
defer pw.Close()
defer pr.Close()

_, err := io.Copy(pw, channel)
errorCh <- err
}()
go func() {
defer channel.Close()
defer pw.Close()
defer pr.Close()

Expand Down

0 comments on commit 3821f76

Please sign in to comment.