diff --git a/ssh/connection.go b/ssh/connection.go index 8f345ee924..f12731200c 100644 --- a/ssh/connection.go +++ b/ssh/connection.go @@ -7,6 +7,7 @@ package ssh import ( "fmt" "net" + "sync/atomic" ) // OpenChannelError is returned if the other side rejects an @@ -89,6 +90,11 @@ type connection struct { transport *handshakeTransport sshConn + // serverAuthComplete is whether, when used as an incoming server + // auth connection, the auth phase is complete. This is used to prevent + // use of ServerPreAuthConn after the auth phase is complete. + serverAuthComplete atomic.Bool + // The connection protocol. *mux } diff --git a/ssh/server.go b/ssh/server.go index c0d1c29e6f..bd782cff90 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -59,6 +59,30 @@ type GSSAPIWithMICConfig struct { Server GSSAPIServer } +// SendAuthanner implements [ServerPreAuthConn]. +func (s *connection) SendAuthBanner(msg string) error { + if s.serverAuthComplete.Load() { + return errors.New("ssh: SendAuthBanner outside of authentication phase") + } + return s.transport.writePacket(Marshal(&userAuthBannerMsg{ + Message: msg, + })) +} + +func (*connection) unexportedMethodForFutureProofing() {} + +// ServerPreAuthConn is the interface available on an incoming server +// connection before authentication has completed. +type ServerPreAuthConn interface { + unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB + + ConnMetadata + + // SendAuthBanner sends a baner message to the client. + // It returns an error once the authentication phase has ended. + SendAuthBanner(string) error +} + // ServerConfig holds server specific configuration data. type ServerConfig struct { // Config contains configuration shared between client and server. @@ -118,6 +142,11 @@ type ServerConfig struct { // attempts. AuthLogCallback func(conn ConnMetadata, method string, err error) + // PreAuthConnCallback, if non-nil, is called upon receiving a new connection + // before any authentication has started. The provided ServerPreAuthConn + // can be used before authentication is complete. + PreAuthConnCallback func(ServerPreAuthConn) + // ServerVersion is the version identification string to announce in // the public handshake. // If empty, a reasonable default is used. @@ -230,6 +259,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha sshConn: sshConn{conn: c}, } perms, err := s.serverHandshake(&fullConf) + s.serverAuthComplete.Store(true) if err != nil { c.Close() return nil, nil, nil, err @@ -481,6 +511,10 @@ func (b *BannerError) Error() string { } func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + if config.PreAuthConnCallback != nil { + config.PreAuthConnCallback(s) + } + sessionID := s.transport.getSessionID() var cache pubKeyCache var perms *Permissions diff --git a/ssh/server_test.go b/ssh/server_test.go index b6d8ab3333..2064463d5b 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -299,6 +299,70 @@ func TestBannerError(t *testing.T) { } } +func TestPreAuthConnAndBanners(t *testing.T) { + authConnc := make(chan ServerPreAuthConn, 1) + serverConfig := &ServerConfig{ + PreAuthConnCallback: func(c ServerPreAuthConn) { + t.Logf("got ServerPreAuthConn: %v", c) + authConnc <- c // for use later in the test + for _, s := range []string{"hello1", "hello2"} { + if err := c.SendAuthBanner(s); err != nil { + t.Errorf("failed to send banner %q: %v", s, err) + } + } + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + t.Logf("got NoClientAuthCallback") + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "hello1", + "hello2", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } + + // Now that we're authenticated, verify that use of SendBanner + // is an error. + var bc ServerPreAuthConn + select { + case bc = <-authConnc: + default: + t.Fatal("expected ServerPreAuthConn") + } + if err := bc.SendAuthBanner("wrong-phase"); err == nil { + t.Error("unexpected success of SendAuthBanner after authentication") + } +} + type markerConn struct { closed uint32 used uint32