diff --git a/lib/events/dynamic.go b/lib/events/dynamic.go index f8784a646a77e..b5050f5951fb2 100644 --- a/lib/events/dynamic.go +++ b/lib/events/dynamic.go @@ -107,6 +107,12 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) { e = &events.X11Forward{} case PortForwardEvent: e = &events.PortForward{} + case PortForwardLocalEvent: + e = &events.PortForward{} + case PortForwardRemoteEvent: + e = &events.PortForward{} + case PortForwardRemoteConnEvent: + e = &events.PortForward{} case AuthAttemptEvent: e = &events.AuthAttempt{} case SCPEvent: diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index a88ee2396fd6b..fb017be8ddc9a 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -1269,7 +1269,7 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata { } } -func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortForward { +func (c *ServerContext) GetPortForwardEvent(evType, code, addr string) apievents.PortForward { sconn := c.ConnectionContext.ServerConn return apievents.PortForward{ Metadata: apievents.Metadata{ @@ -1281,7 +1281,7 @@ func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortF LocalAddr: sconn.LocalAddr().String(), RemoteAddr: sconn.RemoteAddr().String(), }, - Addr: c.DstAddr, + Addr: addr, Status: apievents.Status{ Success: true, }, diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 4d5dcb4a2b497..849bc0a228c53 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -665,7 +665,7 @@ func (s *Server) Serve() { CloseCancel: func() { s.connectionContext.Close() }, }) - go s.handleClientChannels(ctx, forwardedTCPIP, sconn.LocalAddr().String(), sconn.RemoteAddr().String()) + go s.handleClientChannels(ctx, forwardedTCPIP) go s.handleConnection(ctx, chans, reqs) } @@ -874,33 +874,7 @@ func (s *Server) handleConnection(ctx context.Context, chans <-chan ssh.NewChann } // handleClientChannels handles channel open requests from the remote server. -func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel, localAddr, remoteAddr string) { - forwarding := false - defer func() { - // don't log the stop code unless we've logged the start code - if !forwarding { - return - } - - if err := s.EmitAuditEvent(ctx, &apievents.PortForward{ - Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, - Code: events.PortForwardStopCode, - }, - UserMetadata: s.identityContext.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - LocalAddr: localAddr, - RemoteAddr: remoteAddr, - }, - Addr: s.targetAddr, - Status: apievents.Status{ - Success: true, - }, - }); err != nil { - s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) - } - }() - +func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel) { for nch := range forwardedTCPIP { chanCtx, nch := tracessh.ContextFromNewChannel(nch) ctx, span := s.tracerProvider.Tracer("ssh").Start( @@ -920,28 +894,6 @@ func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan s.logger.ErrorContext(ctx, "Error handling forwarded-tcpip request", "error", err) } }() - - if !forwarding { - forwarding = true - - if err := s.EmitAuditEvent(ctx, &apievents.PortForward{ - Metadata: apievents.Metadata{ - Type: events.PortForwardEvent, - Code: events.PortForwardCode, - }, - UserMetadata: s.identityContext.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - LocalAddr: localAddr, - RemoteAddr: remoteAddr, - }, - Addr: s.targetAddr, - Status: apievents.Status{ - Success: true, - }, - }); err != nil { - s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) - } - } } } @@ -970,11 +922,6 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha scx.DstAddr = sshutils.JoinHostPort(req.Host, req.Port) defer scx.Close() - event := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode) - if err := s.EmitAuditEvent(ctx, &event); err != nil { - s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) - } - // Open a forwarding channel on the client. outCh, outRequests, err := scx.ServerConn.OpenChannel(nch.ChannelType(), nch.ExtraData()) if err != nil { @@ -994,12 +941,10 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha go io.Copy(io.Discard, ch.Stderr()) ch = scx.TrackActivity(ch) - defer func() { - stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode) - if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil { - s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) - } - }() + event := scx.GetPortForwardEvent() + if err := s.EmitAuditEvent(ctx, &event); err != nil { + s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err) + } return trace.Wrap(utils.ProxyConn(ctx, ch, outCh)) } @@ -1175,23 +1120,13 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r } defer conn.Close() - event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode) + event := scx.GetPortForwardEvent() if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } if err := utils.ProxyConn(ctx, ch, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { s.logger.WarnContext(ctx, "Failed proxying data for port forwarding connection", "error", err) - - event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode) - if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { - s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) - } - } - - event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode) - if err := s.EmitAuditEvent(s.closeContext, &event); err != nil { - s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } } diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index bc0d849420bf7..c3e0ba02ac3dd 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1489,34 +1489,21 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con return } - if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{ - Metadata: apievents.Metadata{ - Type: events.PortForwardLocalEvent, - Code: events.PortForwardCode, - }, - UserMetadata: scx.Identity.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - LocalAddr: scx.ServerConn.LocalAddr().String(), - RemoteAddr: scx.ServerConn.RemoteAddr().String(), - }, - Addr: scx.DstAddr, - Status: apievents.Status{ - Success: true, - }, - }); err != nil { + event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode, scx.DstAddr) + if err := s.EmitAuditEvent(ctx, &event); err != nil { scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) { - event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode) - if err := s.EmitAuditEvent(s.ctx, &event); err != nil { + event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode, scx.DstAddr) + if err := s.EmitAuditEvent(ctx, &event); err != nil { scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err) } - event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode) - if err := s.EmitAuditEvent(s.ctx, &event); err != nil { + event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode, scx.DstAddr) + if err := s.EmitAuditEvent(ctx, &event); err != nil { scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err) } } @@ -2212,9 +2199,9 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort) // pregenerate audit events since ServerContext may be closed before they're used - startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode) - stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode) - errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode) + startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode, scx.SrcAddr) + stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode, scx.SrcAddr) + errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode, scx.SrcAddr) proxyWithAudit := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) { startEvent.RemoteAddr = remoteAddr if err := s.EmitAuditEvent(ctx, &startEvent); err != nil { @@ -2234,11 +2221,16 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co } } - if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit); err != nil { + forwardStopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode, scx.SrcAddr) + if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit, func() { + if err := s.EmitAuditEvent(ctx, &forwardStopEvent); err != nil { + s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) + } + }); err != nil { return trace.Wrap(err) } - event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode) + event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode, scx.SrcAddr) if err := s.EmitAuditEvent(ctx, &event); err != nil { s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err) } @@ -2266,11 +2258,6 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co // Close the listener once the connection is closed, if it hasn't // been closed already via a cancel-tcpip-forward request. ccx.AddCloser(utils.CloseFunc(func() error { - event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode) - if err := s.EmitAuditEvent(context.Background(), &event); err != nil { - s.logger.WarnContext(context.Background(), "Failed to emit audit event", "error", err) - } - listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr) if ok { return trace.Wrap(listener.Close()) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index a2212126416fb..1dc151ed10a64 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -474,8 +474,6 @@ func TestSessionAuditLog(t *testing.T) { roleOptions := role.GetOptions() roleOptions.PermitX11Forwarding = types.NewBool(true) roleOptions.ForwardAgent = types.NewBool(true) - //nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward - roleOptions.PortForwarding = types.NewBoolOption(true) role.SetOptions(roleOptions) _, err = f.testSrv.Auth().UpsertRole(ctx, role) require.NoError(t, err) @@ -529,20 +527,62 @@ func TestSessionAuditLog(t *testing.T) { ts.Listener = listener ts.Start() + e = nextEvent() + remoteForwardStart, ok := e.(*apievents.PortForward) + require.True(t, ok, "expected PortForward event but got event of type %T", e) + // require.Equal(t, listener.Addr().String(), remoteForwardStart.Addr) + require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType()) + require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode()) + // Request forward to remote port. Each dial should result in a new event. Note that we don't // know what port the server will forward the connection on, so we don't have an easy way to // validate the event's addr field. conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String()) require.NoError(t, err) + + // the order of PortForwardLocal events and PortForwardRemoteConn events are sometimes swapped but order doesn't matter, so we just + // need to ensure that we receive both + foundLocalForwardStart := false + foundConnForwardStart := false + for i := 0; i < 2; i += 1 { + e = nextEvent() + require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e) + if !foundLocalForwardStart && e.GetType() == events.PortForwardLocalEvent { + foundLocalForwardStart = e.GetCode() == events.PortForwardCode + continue + } + + if !foundConnForwardStart && e.GetType() == events.PortForwardRemoteConnEvent { + foundConnForwardStart = e.GetCode() == events.PortForwardCode + } + } + require.True(t, foundLocalForwardStart && foundConnForwardStart) + conn.Close() + // similar to above, order of stop events received is inconsistent and mostly irrelevant here + foundLocalForwardStop := false + foundConnForwardStop := false + for i := 0; i < 2; i += 1 { + e = nextEvent() + require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e) + if !foundLocalForwardStop && e.GetType() == events.PortForwardLocalEvent { + foundLocalForwardStop = e.GetCode() == events.PortForwardStopCode + continue + } - directPortForwardEvent := nextEvent() - require.IsType(t, &apievents.PortForward{}, directPortForwardEvent, "expected PortForward event but got event of type %T", directPortForwardEvent) + if !foundConnForwardStop && e.GetType() == events.PortForwardRemoteConnEvent { + foundConnForwardStop = e.GetCode() == events.PortForwardStopCode + } + } + require.True(t, foundLocalForwardStop && foundConnForwardStop) + ts.Close() e = nextEvent() - remotePortForwardEvent, ok := e.(*apievents.PortForward) + remoteForwardStop, ok := e.(*apievents.PortForward) require.True(t, ok, "expected PortForward event but got event of type %T", e) - require.Equal(t, listener.Addr().String(), remotePortForwardEvent.Addr) + // require.Equal(t, listener.Addr().String(), remoteForwardStop.Addr) + require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStop.GetType()) + require.Equal(t, events.PortForwardStopCode, remoteForwardStop.Code) // End the session. Session leave, data, and end events should be emitted. se.Close() diff --git a/lib/sshutils/tcpip.go b/lib/sshutils/tcpip.go index 3358fff88d015..e5fa9518fa6ac 100644 --- a/lib/sshutils/tcpip.go +++ b/lib/sshutils/tcpip.go @@ -79,13 +79,19 @@ type channelOpener interface { // StartRemoteListener listens on the given listener and forwards any accepted // connections over a new "forwarded-tcpip" channel. -func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser)) error { +func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser), onClose func()) error { srcHost, srcPort, err := SplitHostPort(srcAddr) if err != nil { return trace.Wrap(err) } go func() { + defer func() { + if onClose != nil { + onClose() + } + }() + for { conn, err := listener.Accept() if err != nil { diff --git a/lib/sshutils/tcpip_test.go b/lib/sshutils/tcpip_test.go index 6e84efe55b5b9..ca1297e87835a 100644 --- a/lib/sshutils/tcpip_test.go +++ b/lib/sshutils/tcpip_test.go @@ -54,7 +54,7 @@ func TestStartRemoteListener(t *testing.T) { require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) t.Cleanup(cancel) - require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil)) + require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil, nil)) // Check that dialing listener makes it all the way to the test http server. resp, err := http.Get("http://" + listener.Addr().String()) @@ -65,7 +65,7 @@ func TestStartRemoteListener(t *testing.T) { require.Equal(t, "Hello, world", string(body)) } -func TestStartRemoteListenerWithCustomProxy(t *testing.T) { +func TestStartRemoteListenerWithCallbacks(t *testing.T) { // Create a test server to act as the other side of the channel. tsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello, world") @@ -82,17 +82,22 @@ func TestStartRemoteListenerWithCustomProxy(t *testing.T) { } proxied := false + closedCh := make(chan struct{}) proxyFn := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) { proxied = true _ = utils.ProxyConn(ctx, client, server) } + closeFn := func() { + closedCh <- struct{}{} + } + // Start the remote listener. listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) t.Cleanup(cancel) - require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn)) + require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn, closeFn)) // Check that dialing listener makes it all the way to the test http server. resp, err := http.Get("http://" + listener.Addr().String()) @@ -102,4 +107,12 @@ func TestStartRemoteListenerWithCustomProxy(t *testing.T) { require.NoError(t, err) require.Equal(t, "Hello, world", string(body)) require.True(t, proxied) + listener.Close() + + select { + case <-ctx.Done(): + require.Fail(t, "expected closeFn callback to be called before cancellation") + case <-closedCh: + return + } }