Skip to content

Commit

Permalink
ssh: add ServerConfig.PreAuthConnCallback, ServerPreAuthConn (banner)…
Browse files Browse the repository at this point in the history
… interface

Fixes golang/go#68688

Change-Id: Id5f72b32c61c9383a26ec182339486a432c7cdf5
  • Loading branch information
bradfitz committed Oct 30, 2024
1 parent 750a45f commit 4e86b1d
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ssh/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package ssh
import (
"fmt"
"net"
"sync/atomic"
)

// OpenChannelError is returned if the other side rejects an
Expand Down Expand Up @@ -89,6 +90,11 @@ type connection struct {
transport *handshakeTransport
sshConn

// serverAuthComplete is whether, when used as an incoming server
// auth connection, the auth phase is complete. This is used to prevent
// use of ServerPreAuthConn after the auth phase is complete.
serverAuthComplete atomic.Bool

// The connection protocol.
*mux
}
Expand Down
34 changes: 34 additions & 0 deletions ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,30 @@ type GSSAPIWithMICConfig struct {
Server GSSAPIServer
}

// SendAuthanner implements [ServerPreAuthConn].
func (s *connection) SendAuthBanner(msg string) error {
if s.serverAuthComplete.Load() {
return errors.New("ssh: SendAuthBanner outside of authentication phase")
}
return s.transport.writePacket(Marshal(&userAuthBannerMsg{
Message: msg,
}))
}

func (*connection) unexportedMethodForFutureProofing() {}

// ServerPreAuthConn is the interface available on an incoming server
// connection before authentication has completed.
type ServerPreAuthConn interface {
unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB

ConnMetadata

// SendAuthBanner sends a baner message to the client.
// It returns an error once the authentication phase has ended.
SendAuthBanner(string) error
}

// ServerConfig holds server specific configuration data.
type ServerConfig struct {
// Config contains configuration shared between client and server.
Expand Down Expand Up @@ -118,6 +142,11 @@ type ServerConfig struct {
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)

// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
// before any authentication has started. The provided ServerPreAuthConn
// can be used before authentication is complete.
PreAuthConnCallback func(ServerPreAuthConn)

// ServerVersion is the version identification string to announce in
// the public handshake.
// If empty, a reasonable default is used.
Expand Down Expand Up @@ -230,6 +259,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
sshConn: sshConn{conn: c},
}
perms, err := s.serverHandshake(&fullConf)
s.serverAuthComplete.Store(true)
if err != nil {
c.Close()
return nil, nil, nil, err
Expand Down Expand Up @@ -481,6 +511,10 @@ func (b *BannerError) Error() string {
}

func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
if config.PreAuthConnCallback != nil {
config.PreAuthConnCallback(s)
}

sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions
Expand Down
64 changes: 64 additions & 0 deletions ssh/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,70 @@ func TestBannerError(t *testing.T) {
}
}

func TestPreAuthConnAndBanners(t *testing.T) {
authConnc := make(chan ServerPreAuthConn, 1)
serverConfig := &ServerConfig{
PreAuthConnCallback: func(c ServerPreAuthConn) {
t.Logf("got ServerPreAuthConn: %v", c)
authConnc <- c // for use later in the test
for _, s := range []string{"hello1", "hello2"} {
if err := c.SendAuthBanner(s); err != nil {
t.Errorf("failed to send banner %q: %v", s, err)
}
}
},
NoClientAuth: true,
NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
t.Logf("got NoClientAuthCallback")
return &Permissions{}, nil
},
}
serverConfig.AddHostKey(testSigners["rsa"])

var banners []string
clientConfig := &ClientConfig{
User: "test",
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(msg string) error {
banners = append(banners, msg)
return nil
},
}

c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
c, _, _, err := NewClientConn(c2, "", clientConfig)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
defer c.Close()

wantBanners := []string{
"hello1",
"hello2",
}
if !reflect.DeepEqual(banners, wantBanners) {
t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
}

// Now that we're authenticated, verify that use of SendBanner
// is an error.
var bc ServerPreAuthConn
select {
case bc = <-authConnc:
default:
t.Fatal("expected ServerPreAuthConn")
}
if err := bc.SendAuthBanner("wrong-phase"); err == nil {
t.Error("unexpected success of SendAuthBanner after authentication")
}
}

type markerConn struct {
closed uint32
used uint32
Expand Down

0 comments on commit 4e86b1d

Please sign in to comment.