Skip to content

Commit

Permalink
fixing tests, event placements, and forwarding addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktate committed Jan 15, 2025
1 parent f5e982d commit 3d0ccd6
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 113 deletions.
6 changes: 6 additions & 0 deletions lib/events/dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) {
e = &events.X11Forward{}
case PortForwardEvent:
e = &events.PortForward{}
case PortForwardLocalEvent:
e = &events.PortForward{}
case PortForwardRemoteEvent:
e = &events.PortForward{}
case PortForwardRemoteConnEvent:
e = &events.PortForward{}
case AuthAttemptEvent:
e = &events.AuthAttempt{}
case SCPEvent:
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ func (c *ServerContext) GetSessionMetadata() apievents.SessionMetadata {
}
}

func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortForward {
func (c *ServerContext) GetPortForwardEvent(evType, code, addr string) apievents.PortForward {
sconn := c.ConnectionContext.ServerConn
return apievents.PortForward{
Metadata: apievents.Metadata{
Expand All @@ -1281,7 +1281,7 @@ func (c *ServerContext) GetPortForwardEvent(evType, code string) apievents.PortF
LocalAddr: sconn.LocalAddr().String(),
RemoteAddr: sconn.RemoteAddr().String(),
},
Addr: c.DstAddr,
Addr: addr,
Status: apievents.Status{
Success: true,
},
Expand Down
79 changes: 7 additions & 72 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (s *Server) Serve() {
CloseCancel: func() { s.connectionContext.Close() },
})

go s.handleClientChannels(ctx, forwardedTCPIP, sconn.LocalAddr().String(), sconn.RemoteAddr().String())
go s.handleClientChannels(ctx, forwardedTCPIP)
go s.handleConnection(ctx, chans, reqs)
}

Expand Down Expand Up @@ -874,33 +874,7 @@ func (s *Server) handleConnection(ctx context.Context, chans <-chan ssh.NewChann
}

// handleClientChannels handles channel open requests from the remote server.
func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel, localAddr, remoteAddr string) {
forwarding := false
defer func() {
// don't log the stop code unless we've logged the start code
if !forwarding {
return
}

if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardStopCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}()

func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan ssh.NewChannel) {
for nch := range forwardedTCPIP {
chanCtx, nch := tracessh.ContextFromNewChannel(nch)
ctx, span := s.tracerProvider.Tracer("ssh").Start(
Expand All @@ -920,28 +894,6 @@ func (s *Server) handleClientChannels(ctx context.Context, forwardedTCPIP <-chan
s.logger.ErrorContext(ctx, "Error handling forwarded-tcpip request", "error", err)
}
}()

if !forwarding {
forwarding = true

if err := s.EmitAuditEvent(ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Code: events.PortForwardCode,
},
UserMetadata: s.identityContext.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: localAddr,
RemoteAddr: remoteAddr,
},
Addr: s.targetAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}
}
}

Expand Down Expand Up @@ -970,11 +922,6 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
scx.DstAddr = sshutils.JoinHostPort(req.Host, req.Port)
defer scx.Close()

event := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}

// Open a forwarding channel on the client.
outCh, outRequests, err := scx.ServerConn.OpenChannel(nch.ChannelType(), nch.ExtraData())
if err != nil {
Expand All @@ -994,12 +941,10 @@ func (s *Server) handleForwardedTCPIPRequest(ctx context.Context, nch ssh.NewCha
go io.Copy(io.Discard, ch.Stderr())
ch = scx.TrackActivity(ch)

defer func() {
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(ctx, &stopEvent); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}
}()
event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.ErrorContext(ctx, "Failed to emit audit event", "error", err)
}

return trace.Wrap(utils.ProxyConn(ctx, ch, outCh))
}
Expand Down Expand Up @@ -1175,23 +1120,13 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r
}
defer conn.Close()

event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode)
event := scx.GetPortForwardEvent()
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, ch, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
s.logger.WarnContext(ctx, "Failed proxying data for port forwarding connection", "error", err)

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.closeContext, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}

Expand Down
45 changes: 16 additions & 29 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,34 +1489,21 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.Con
return
}

if err := s.EmitAuditEvent(s.ctx, &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardLocalEvent,
Code: events.PortForwardCode,
},
UserMetadata: scx.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
LocalAddr: scx.ServerConn.LocalAddr().String(),
RemoteAddr: scx.ServerConn.RemoteAddr().String(),
},
Addr: scx.DstAddr,
Status: apievents.Status{
Success: true,
},
}); err != nil {
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}

if err := utils.ProxyConn(ctx, conn, channel); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardFailureCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
scx.Logger.WarnContext(ctx, "Connection problem in direct-tcpip channel", "error", err)
}

event := scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(s.ctx, &event); err != nil {
event = scx.GetPortForwardEvent(events.PortForwardLocalEvent, events.PortForwardStopCode, scx.DstAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
scx.Logger.WarnContext(ctx, "Failed to emit port forward event", "error", err)
}
}
Expand Down Expand Up @@ -2212,9 +2199,9 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
scx.SrcAddr = sshutils.JoinHostPort(srcHost, listenPort)

// pregenerate audit events since ServerContext may be closed before they're used
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode)
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode)
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode)
startEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardCode, scx.SrcAddr)
stopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardStopCode, scx.SrcAddr)
errEvent := scx.GetPortForwardEvent(events.PortForwardRemoteConnEvent, events.PortForwardFailureCode, scx.SrcAddr)
proxyWithAudit := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
startEvent.RemoteAddr = remoteAddr
if err := s.EmitAuditEvent(ctx, &startEvent); err != nil {
Expand All @@ -2234,11 +2221,16 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
}
}

if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit); err != nil {
forwardStopEvent := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode, scx.SrcAddr)
if err := sshutils.StartRemoteListener(ctx, scx.ServerConn, scx.SrcAddr, listener, proxyWithAudit, func() {
if err := s.EmitAuditEvent(ctx, &forwardStopEvent); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
}); err != nil {
return trace.Wrap(err)
}

event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode)
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardCode, scx.SrcAddr)
if err := s.EmitAuditEvent(ctx, &event); err != nil {
s.logger.WarnContext(ctx, "Failed to emit audit event", "error", err)
}
Expand Down Expand Up @@ -2266,11 +2258,6 @@ func (s *Server) handleTCPIPForwardRequest(ctx context.Context, ccx *sshutils.Co
// Close the listener once the connection is closed, if it hasn't
// been closed already via a cancel-tcpip-forward request.
ccx.AddCloser(utils.CloseFunc(func() error {
event := scx.GetPortForwardEvent(events.PortForwardRemoteEvent, events.PortForwardStopCode)
if err := s.EmitAuditEvent(context.Background(), &event); err != nil {
s.logger.WarnContext(context.Background(), "Failed to emit audit event", "error", err)
}

listener, ok := s.remoteForwardingMap.LoadAndDelete(scx.SrcAddr)
if ok {
return trace.Wrap(listener.Close())
Expand Down
52 changes: 46 additions & 6 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,6 @@ func TestSessionAuditLog(t *testing.T) {
roleOptions := role.GetOptions()
roleOptions.PermitX11Forwarding = types.NewBool(true)
roleOptions.ForwardAgent = types.NewBool(true)
//nolint:staticcheck // this field is preserved for existing deployments, but shouldn't be used going forward
roleOptions.PortForwarding = types.NewBoolOption(true)
role.SetOptions(roleOptions)
_, err = f.testSrv.Auth().UpsertRole(ctx, role)
require.NoError(t, err)
Expand Down Expand Up @@ -529,20 +527,62 @@ func TestSessionAuditLog(t *testing.T) {
ts.Listener = listener
ts.Start()

e = nextEvent()
remoteForwardStart, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
// require.Equal(t, listener.Addr().String(), remoteForwardStart.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStart.GetType())
require.Equal(t, events.PortForwardCode, remoteForwardStart.GetCode())

// Request forward to remote port. Each dial should result in a new event. Note that we don't
// know what port the server will forward the connection on, so we don't have an easy way to
// validate the event's addr field.
conn, err := f.ssh.clt.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)

// the order of PortForwardLocal events and PortForwardRemoteConn events are sometimes swapped but order doesn't matter, so we just
// need to ensure that we receive both
foundLocalForwardStart := false
foundConnForwardStart := false
for i := 0; i < 2; i += 1 {
e = nextEvent()
require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e)
if !foundLocalForwardStart && e.GetType() == events.PortForwardLocalEvent {
foundLocalForwardStart = e.GetCode() == events.PortForwardCode
continue
}

if !foundConnForwardStart && e.GetType() == events.PortForwardRemoteConnEvent {
foundConnForwardStart = e.GetCode() == events.PortForwardCode
}
}
require.True(t, foundLocalForwardStart && foundConnForwardStart)

conn.Close()
// similar to above, order of stop events received is inconsistent and mostly irrelevant here
foundLocalForwardStop := false
foundConnForwardStop := false
for i := 0; i < 2; i += 1 {
e = nextEvent()
require.IsType(t, &apievents.PortForward{}, e, "expected PortForward event but got event of type %T", e)
if !foundLocalForwardStop && e.GetType() == events.PortForwardLocalEvent {
foundLocalForwardStop = e.GetCode() == events.PortForwardStopCode
continue
}

directPortForwardEvent := nextEvent()
require.IsType(t, &apievents.PortForward{}, directPortForwardEvent, "expected PortForward event but got event of type %T", directPortForwardEvent)
if !foundConnForwardStop && e.GetType() == events.PortForwardRemoteConnEvent {
foundConnForwardStop = e.GetCode() == events.PortForwardStopCode
}
}
require.True(t, foundLocalForwardStop && foundConnForwardStop)

ts.Close()
e = nextEvent()
remotePortForwardEvent, ok := e.(*apievents.PortForward)
remoteForwardStop, ok := e.(*apievents.PortForward)
require.True(t, ok, "expected PortForward event but got event of type %T", e)
require.Equal(t, listener.Addr().String(), remotePortForwardEvent.Addr)
// require.Equal(t, listener.Addr().String(), remoteForwardStop.Addr)
require.Equal(t, events.PortForwardRemoteEvent, remoteForwardStop.GetType())
require.Equal(t, events.PortForwardStopCode, remoteForwardStop.Code)

// End the session. Session leave, data, and end events should be emitted.
se.Close()
Expand Down
8 changes: 7 additions & 1 deletion lib/sshutils/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,19 @@ type channelOpener interface {

// StartRemoteListener listens on the given listener and forwards any accepted
// connections over a new "forwarded-tcpip" channel.
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser)) error {
func StartRemoteListener(ctx context.Context, sshConn channelOpener, srcAddr string, listener net.Listener, proxyFn func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser), onClose func()) error {
srcHost, srcPort, err := SplitHostPort(srcAddr)
if err != nil {
return trace.Wrap(err)
}

go func() {
defer func() {
if onClose != nil {
onClose()
}
}()

for {
conn, err := listener.Accept()
if err != nil {
Expand Down
19 changes: 16 additions & 3 deletions lib/sshutils/tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestStartRemoteListener(t *testing.T) {
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil))
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, nil, nil))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
Expand All @@ -65,7 +65,7 @@ func TestStartRemoteListener(t *testing.T) {
require.Equal(t, "Hello, world", string(body))
}

func TestStartRemoteListenerWithCustomProxy(t *testing.T) {
func TestStartRemoteListenerWithCallbacks(t *testing.T) {
// Create a test server to act as the other side of the channel.
tsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Hello, world")
Expand All @@ -82,17 +82,22 @@ func TestStartRemoteListenerWithCustomProxy(t *testing.T) {
}

proxied := false
closedCh := make(chan struct{})
proxyFn := func(ctx context.Context, remoteAddr string, client io.ReadWriteCloser, server io.ReadWriteCloser) {
proxied = true
_ = utils.ProxyConn(ctx, client, server)
}

closeFn := func() {
closedCh <- struct{}{}
}

// Start the remote listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
t.Cleanup(cancel)
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn))
require.NoError(t, StartRemoteListener(ctx, sshConn, "127.0.0.1:12345", listener, proxyFn, closeFn))

// Check that dialing listener makes it all the way to the test http server.
resp, err := http.Get("http://" + listener.Addr().String())
Expand All @@ -102,4 +107,12 @@ func TestStartRemoteListenerWithCustomProxy(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "Hello, world", string(body))
require.True(t, proxied)
listener.Close()

select {
case <-ctx.Done():
require.Fail(t, "expected closeFn callback to be called before cancellation")
case <-closedCh:
return
}
}

0 comments on commit 3d0ccd6

Please sign in to comment.