Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure that disconnects propagate through port-forwarded tunnel #3801

Merged
merged 1 commit into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,11 @@ func (client *NodeClient) ExecuteSCP(cmd scp.Command) error {
return trace.Wrap(err)
}

func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string) error {
type netDialer interface {
Dial(string, string) (net.Conn, error)
}

func proxyConnection(ctx context.Context, conn net.Conn, remoteAddr string, dialer netDialer) error {
defer conn.Close()
defer log.Debugf("Finished proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)

Expand All @@ -878,7 +882,7 @@ func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, re

log.Debugf("Attempting to connect proxy from %v to %v.", conn.RemoteAddr(), remoteAddr)
for attempt := 1; attempt <= 5; attempt++ {
remoteConn, err = client.Client.Dial("tcp", remoteAddr)
remoteConn, err = dialer.Dial("tcp", remoteAddr)
if err != nil {
log.Debugf("Proxy connection attempt %v: %v.", attempt, err)

Expand Down Expand Up @@ -906,29 +910,33 @@ func (client *NodeClient) proxyConnection(ctx context.Context, conn net.Conn, re
errCh := make(chan error, 2)
go func() {
defer conn.Close()
defer remoteConn.Close()

_, err := io.Copy(conn, remoteConn)
errCh <- err
}()
go func() {
defer conn.Close()
defer remoteConn.Close()

_, err := io.Copy(remoteConn, conn)
errCh <- err
}()

var lastErr error
var errs []error
for i := 0; i < 2; i++ {
select {
case err := <-errCh:
if err != nil && err != io.EOF {
if err != nil && err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
log.Warnf("Failed to proxy connection: %v.", err)
lastErr = err
errs = append(errs, err)
}
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}
}

return lastErr
return trace.NewAggregate(errs...)
}

// listenAndForward listens on a given socket and forwards all incoming
Expand All @@ -947,7 +955,7 @@ func (c *NodeClient) listenAndForward(ctx context.Context, ln net.Listener, remo

// Proxy the connection to the remote address.
go func() {
err := c.proxyConnection(ctx, conn, remoteAddr)
err := proxyConnection(ctx, conn, remoteAddr, c.Client)
if err != nil {
log.Warnf("Failed to proxy connection: %v.", err)
}
Expand Down Expand Up @@ -981,7 +989,7 @@ func (c *NodeClient) dynamicListenAndForward(ctx context.Context, ln net.Listene

// Proxy the connection to the remote address.
go func() {
err := c.proxyConnection(ctx, conn, remoteAddr)
err := proxyConnection(ctx, conn, remoteAddr, c.Client)
if err != nil {
log.Warnf("Failed to proxy connection: %v.", err)
}
Expand Down
123 changes: 123 additions & 0 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ limitations under the License.
package client

import (
"context"
"io"
"io/ioutil"
"net"
"os"
"strings"
"time"

"github.com/gravitational/teleport/lib/sshutils"

Expand Down Expand Up @@ -76,3 +82,120 @@ func (s *ClientTestSuite) TestNewSession(c *check.C) {
// the session ID must be taken from tne environ map, if passed:
c.Assert(string(ses.id), check.Equals, "session-id")
}

// TestProxyConnection verifies that client or server-side disconnect
// propagates all the way to the opposite side.
func (s *ClientTestSuite) TestProxyConnection(c *check.C) {
// remoteSrv mocks a remote listener, accepting port-forwarded connections
// over SSH.
remoteConCh := make(chan net.Conn)
remoteErrCh := make(chan error, 3)
remoteSrv := newTestListener(c, func(con net.Conn) {
defer con.Close()

remoteConCh <- con

// Echo any data back to the sender.
_, err := io.Copy(con, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
remoteErrCh <- err
})
defer remoteSrv.Close()

// localSrv mocks a local tsh listener, accepting local connections for
// port-forwarding to remote SSH node.
proxyErrCh := make(chan error, 3)
localSrv := newTestListener(c, func(con net.Conn) {
defer con.Close()

proxyErrCh <- proxyConnection(context.Background(), con, remoteSrv.Addr().String(), new(net.Dialer))
})
defer localSrv.Close()

// Dial localSrv. This should trigger proxyConnection and a dial to
// remoteSrv.
localCon, err := net.Dial("tcp", localSrv.Addr().String())
c.Assert(err, check.IsNil)
clientErrCh := make(chan error, 3)
go func(con net.Conn) {
_, err := io.Copy(ioutil.Discard, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
clientErrCh <- err
}(localCon)

// Discard remoteCon to unblock the remote handler.
<-remoteConCh

// Simulate a client-side disconnect. All other parties (tsh proxy and
// remove listener) should disconnect as well.
c.Log("simulate client-side disconnect")
err = localCon.Close()
c.Assert(err, check.IsNil)

for i := 0; i < 3; i++ {
select {
case err := <-proxyErrCh:
c.Assert(err, check.IsNil)
case err := <-remoteErrCh:
c.Assert(err, check.IsNil)
case err := <-clientErrCh:
c.Assert(err, check.IsNil)
case <-time.After(5 * time.Second):
c.Fatal("proxyConnection, client and server didn't disconnect within 5s after client connection was closed")
}
}

// Dial localSrv again. This should trigger proxyConnection and a dial to
// remoteSrv.
localCon, err = net.Dial("tcp", localSrv.Addr().String())
c.Assert(err, check.IsNil)
go func(con net.Conn) {
_, err := io.Copy(ioutil.Discard, con)
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
err = nil
}
clientErrCh <- err
}(localCon)

// Simulate a server-side disconnect. All other parties (tsh proxy and
// local client) should disconnect as well.
c.Log("simulate server-side disconnect")
remoteCon := <-remoteConCh
err = remoteCon.Close()
c.Assert(err, check.IsNil)

for i := 0; i < 3; i++ {
select {
case err := <-proxyErrCh:
c.Assert(err, check.IsNil)
case err := <-remoteErrCh:
c.Assert(err, check.IsNil)
case err := <-clientErrCh:
c.Assert(err, check.IsNil)
case <-time.After(5 * time.Second):
c.Fatal("proxyConnection, client and server didn't disconnect within 5s after remote connection was closed")
}
}
}

func newTestListener(c *check.C, handle func(net.Conn)) net.Listener {
l, err := net.Listen("tcp", "localhost:0")
c.Assert(err, check.IsNil)

go func() {
for {
con, err := l.Accept()
if err != nil {
c.Logf("listener error: %v", err)
return
}
go handle(con)
}
}()

return l
}
18 changes: 9 additions & 9 deletions lib/srv/reexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ func RunCommand() (io.Writer, int, error) {
// pipe) then port forwards.
func RunForward() (io.Writer, int, error) {
// errorWriter is used to return any error message back to the client.
errorWriter := os.Stdout
// Use stderr so that it's not forwarded to the remote client.
errorWriter := os.Stderr

// Parent sends the command payload in the third file descriptor.
cmdfd := os.NewFile(uintptr(3), "/proc/self/fd/3")
Expand Down Expand Up @@ -269,30 +270,29 @@ func RunForward() (io.Writer, int, error) {
// pipe to channel.
errorCh := make(chan error, 2)
go func() {
defer conn.Close()
defer os.Stdout.Close()
defer os.Stdin.Close()

_, err := io.Copy(os.Stdout, conn)
errorCh <- err
}()
go func() {
defer conn.Close()
defer os.Stdout.Close()
defer os.Stdin.Close()

_, err := io.Copy(conn, os.Stdin)
errorCh <- err
}()

// Block until copy is complete and the child process is done executing.
var errs []error
for i := 0; i < 2; i++ {
err := <-errorCh
if err != nil && err != io.EOF {
errs = append(errs, err)
}
// Block until copy is complete in either direction. The other direction
// will get cleaned up automatically.
if err = <-errorCh; err != nil && err != io.EOF {
return errorWriter, teleport.RemoteCommandFailure, trace.Wrap(err)
}

return ioutil.Discard, teleport.RemoteCommandSuccess, trace.NewAggregate(errs...)
return ioutil.Discard, teleport.RemoteCommandSuccess, nil
}

// RunAndExit will run the requested command and then exit. This wrapper
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,8 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
if err != nil {
writeStderr(channel, err.Error())
}
// Propagate stderr from the spawned Teleport process to log any errors.
cmd.Stderr = os.Stderr

// Create a pipe for std{in,out} that will be used to transfer data between
// parent and child.
Expand All @@ -974,13 +976,15 @@ func (s *Server) handleDirectTCPIPRequest(ccx *sshutils.ConnectionContext, ident
// pipe to channel.
errorCh := make(chan error, 2)
go func() {
defer channel.Close()
defer pw.Close()
defer pr.Close()

_, err := io.Copy(pw, channel)
errorCh <- err
}()
go func() {
defer channel.Close()
defer pw.Close()
defer pr.Close()

Expand Down