Skip to content

Commit

Permalink
adding additional audit log context around SSH port forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktate committed Jan 15, 2025
1 parent 063603a commit f5e982d
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 27 deletions.
11 changes: 7 additions & 4 deletions lib/events/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,13 @@ const (
X11ForwardErr = "error"

// Port forwarding event
PortForwardEvent = "port"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"
PortForwardEvent = "port"
PortForwardLocalEvent = "port.local"
PortForwardRemoteEvent = "port.remote"
PortForwardRemoteConnEvent = "port.remote_conn"
PortForwardAddr = "addr"
PortForwardSuccess = "success"
PortForwardErr = "error"

// AuthAttemptEvent is authentication attempt that either
// succeeded or failed based on event status
Expand Down
6 changes: 3 additions & 3 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,12 +1269,12 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent() apievents.PortForward {
func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
return apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
Type: evType,
Code: code,
},
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
Expand Down
79 changes: 72 additions & 7 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (s *Server) Serve() {
CloseCancel: func() { s.connectionContext.Close() },
})

go s.handleClientChannels(ctx, forwardedTCPIP)
go s.handleClientChannels(ctx, forwardedTCPIP, sconn.LocalAddr().String(), sconn.RemoteAddr().String())
go s.handleConnection(ctx, chans, reqs)
}

Expand Down Expand Up @@ -874,7 +874,33 @@ func (s *Server) handleConnection(ctx context.Context, chans <-chan ssh.NewChann
}

// handleClientChannels handles channel open requests from the remote server.
func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel) {
func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel, localAddr, remoteAddr string) {
forwarding := false
defer func() {
// don't log the stop code unless we've logged the start code
if !forwarding {
return
}

if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardStopCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}()

for nch := range forwardedTCPIP {
chanCtx, nch := tracessh.ContextFromNewChannel(nch)
ctx, span := s.tracerProvider.Tracer("ssh").Start(
Expand All @@ -894,6 +920,28 @@ func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan
s.logger.ErrorContext(ctx, "Error handling forwarded-tcpip request", "error", err)
}
}()

if !forwarding {
forwarding = true

if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}
}
}

Expand Down Expand Up @@ -922,6 +970,11 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
scx.DstAddr = sshutils.JoinHostPort(req.Host, req.Port)
defer scx.Close()

event := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}

// Open a forwarding channel on the client.
outCh, outRequests, err := scx.ServerConn.OpenChannel(nch.ChannelType(), nch.ExtraData())
if err != nil {
Expand All @@ -941,10 +994,12 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
defer func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
}()

return trace.Wrap(utils.ProxyConn(ctx, ch, outCh))
}
Expand Down Expand Up @@ -1120,13 +1175,23 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent()
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, ch, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
s.logger.WarnContext(ctx, "Failed proxying data for port forwarding connection", "error", err)

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

Expand Down
59 changes: 49 additions & 10 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,13 +1489,9 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Type: events.PortForwardLocalEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
Expand All @@ -1510,6 +1506,19 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
}); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

// handleSessionRequests handles out of band session requests once the session
Expand Down Expand Up @@ -2162,6 +2171,7 @@ func (s *Server) createForwardingContext(ctx context.Context, ccx *sshutils.Conn
if err != nil {
return nil, nil, trace.Wrap(err)
}

listenAddr := sshutils.JoinHostPort(req.Addr, req.Port)
scx.IsTestStub = s.isTestStub
scx.ExecType = teleport.TCPIPForwardRequest
Expand Down Expand Up @@ -2201,14 +2211,38 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
// pregenerate audit events since ServerContext may be closed before they're used
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode)
proxyWithAudit := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
startEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &startEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}

if err := utils.ProxyConn(ctx, client, server); err != nil {
errEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &errEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
}

stopEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
}
if err := sshutils.StartRemoteListener(ctx, scx.ConnectionContext.ServerConn, scx.SrcAddr, listener); err != nil {

if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit); err != nil {
return trace.Wrap(err)
}

event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}

// Report addr back to the client.
if r.WantReply {
var payload []byte
Expand All @@ -2232,6 +2266,11 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
// Close the listener once the connection is closed, if it hasn't
// been closed already via a cancel-tcpip-forward request.
ccx.AddCloser(utils.CloseFunc(func() error {
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(context.Background(), &event); err != nil {
s.logger.WarnContext(context.Background(), "Failed to emit audit event", "error", err)
}

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if ok {
return trace.Wrap(listener.Close())
Expand All @@ -2250,14 +2289,14 @@ func (s *Server) handleCancelTCPIPForwardRequest(ctx context.Context, ccx *sshut
return trace.Wrap(err)
}
defer scx.Close()

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if !ok {
return trace.NotFound("no remote forwarding listener at %v", scx.SrcAddr)
}
if err := r.Reply(true, nil); err != nil {
s.logger.WarnContext(ctx, "Failed to reply to request", "request_type", r.Type, "error", err)
}

return trace.Wrap(listener.Close())
}

Expand Down
11 changes: 9 additions & 2 deletions lib/sshutils/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type channelOpener interface {

// StartRemoteListener listens on the given listener and forwards any accepted
// connections over a new "forwarded-tcpip" channel.
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener) error {
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser)) error {
srcHost, srcPort, err := SplitHostPort(srcAddr)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -127,7 +127,14 @@ func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr str
}
go ssh.DiscardRequests(rch)
go io.Copy(io.Discard, ch.Stderr())
go utils.ProxyConn(ctx, conn, ch)
go func() {
if proxyFn != nil {
proxyFn(ctx, conn.RemoteAddr().String(), conn, ch)
return
}

utils.ProxyConn(ctx, conn, ch)
}()
}
}()

Expand Down
44 changes: 43 additions & 1 deletion lib/sshutils/tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"time"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
)

func TestStartRemoteListener(t *testing.T) {
Expand All @@ -39,19 +41,58 @@ func TestStartRemoteListener(t *testing.T) {
t.Cleanup(tsrv.Close)
testSrvConn, err := net.Dial("tcp", tsrv.Listener.Addr().String())
require.NoError(t, err)
t.Cleanup(func() { testSrvConn.Close() })

sshConn := &mockSSHConn{
mockChan: &mockChannel{
ReadWriter: testSrvConn,
},
}

// Start the remote listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, world", string(body))
}

func TestStartRemoteListenerWithCustomProxy(t *testing.T) {
// Create a test server to act as the other side of the channel.
tsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello, world")
}))
t.Cleanup(tsrv.Close)
testSrvConn, err := net.Dial("tcp", tsrv.Listener.Addr().String())
require.NoError(t, err)
t.Cleanup(func() { testSrvConn.Close() })

sshConn := &mockSSHConn{
mockChan: &mockChannel{
ReadWriter: testSrvConn,
},
}

proxied := false
proxyFn := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
proxied = true
_ = utils.ProxyConn(ctx, client, server)
}

// Start the remote listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener))
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
Expand All @@ -60,4 +101,5 @@ func TestStartRemoteListener(t *testing.T) {
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, "Hello, world", string(body))
require.True(t, proxied)
}

0 comments on commit f5e982d

Please sign in to comment.