Skip to content

Commit

Permalink
Make proxy.Client infer the cluster name from Proxy
Browse files Browse the repository at this point in the history
Instead of relying on users to provide the cluster name, the client
now determines the cluster name by inspecting the certificate
presented by the Proxy during the TLS or SSH handshake. This is
required when connecting to a Proxy via a jump host since the
name of the cluster may not match the currently logged in cluster.

This is achieved by leveraging a custom `credentials.TransportCredentials`
when connecting via gRPC and a custom `ssh.HostKeyCallback` when
connecting SSH.
  • Loading branch information
rosstimothy committed Mar 27, 2023
1 parent b2d2b12 commit bfab398
Show file tree
Hide file tree
Showing 2 changed files with 346 additions and 21 deletions.
172 changes: 152 additions & 20 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ package proxy
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
"net"
"strings"
"sync"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -64,9 +66,6 @@ type ClientConfig struct {
ProxySSHAddress string
// TLSRoutingEnabled indicates if the cluster is using TLS Routing.
TLSRoutingEnabled bool
// ClusterName is the name of the Teleport cluster that the client
// will be connected to.
ClusterName string
// TLSConfig contains the tls.Config required for mTLS connections.
TLSConfig *tls.Config
// UnaryInterceptors are optional [grpc.UnaryClientInterceptor] to apply
Expand All @@ -91,16 +90,15 @@ type ClientConfig struct {
clientCreds func() client.Credentials
}

// CheckAndSetDefaults ensures required options are present and
// sets the default value of any that are omitted.
func (c *ClientConfig) CheckAndSetDefaults() error {
if c.ProxyWebAddress == "" {
return trace.BadParameter("missing required parameter ProxyWebAddress")
}
if c.ProxySSHAddress == "" {
return trace.BadParameter("missing required parameter ProxySSHAddress")
}
if c.ClusterName == "" {
return trace.BadParameter("missing required parameter ClusterName")
}
if c.SSHDialer == nil {
return trace.BadParameter("missing required parameter SSHDialer")
}
Expand All @@ -112,16 +110,27 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
}

if c.TLSConfig != nil {
if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) {
tlsCfg := c.TLSConfig.Clone()
tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC)
c.TLSConfig = tlsCfg
}
c.clientCreds = func() client.Credentials {
return client.LoadTLS(c.TLSConfig.Clone())
}
c.creds = func() credentials.TransportCredentials {
return credentials.NewTLS(c.TLSConfig.Clone())
tlsCfg := c.TLSConfig.Clone()
if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) {
tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC)
}

// This logic still appears to be necessary to force client to always send
// a certificate regardless of the server setting. Otherwise the client may pick
// not to send the client certificate by looking at certificate request.
if len(tlsCfg.Certificates) > 0 {
cert := tlsCfg.Certificates[0]
tlsCfg.Certificates = nil
tlsCfg.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &cert, nil
}
}

return credentials.NewTLS(tlsCfg)
}
} else {
c.clientCreds = func() client.Credentials {
Expand All @@ -135,6 +144,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
return nil
}

// insecureCredentials implements [client.Credentials] and is used by tests
// to connect to the Auth server without mTLS.
type insecureCredentials struct{}

func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) {
Expand Down Expand Up @@ -164,6 +175,9 @@ type Client struct {
transport *transportv1.Client
// sshClient is the established SSH connection to the Proxy.
sshClient *tracessh.Client
// clusterName as determined by inspecting the certificate presented by
// the Proxy during the connection handshake.
clusterName *clusterName
}

// protocolProxySSHGRPC is TLS ALPN protocol value used to indicate gRPC
Expand Down Expand Up @@ -205,16 +219,87 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) {
return nil, trace.NewAggregate(grpcErr, sshErr)
}

// clusterName stores the name of the cluster
// in a protected manner which allows it to
// be set during handshakes with the server.
type clusterName struct {
mu sync.Mutex
name string
}

func (c *clusterName) get() string {
c.mu.Lock()
defer c.mu.Unlock()

return c.name
}

func (c *clusterName) set(name string) {
c.mu.Lock()
defer c.mu.Unlock()

c.name = name
}

// clusterCredentials is a [credentials.TransportCredentials] implementation
// that obtains the name of the cluster being connected to from the certificate
// presented by the server. This allows the client to determine the cluster name when
// connecting via using jump hosts.
type clusterCredentials struct {
credentials.TransportCredentials
clusterName *clusterName
}

var (
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}
)

// ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and
// then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine
// the cluster that the server belongs to.
func (c *clusterCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, info, err := c.TransportCredentials.ClientHandshake(ctx, authority, conn)
if err != nil {
return conn, info, trace.Wrap(err)
}

tlsInfo, ok := info.(credentials.TLSInfo)
if !ok {
return conn, info, nil
}

certs := tlsInfo.State.PeerCertificates
if len(certs) == 0 {
return conn, info, nil
}

clientCert := certs[0]
for _, attr := range clientCert.Subject.Names {
if attr.Type.Equal(teleportClusterASN1ExtensionOID) {
val, ok := attr.Value.(string)
if ok {
c.clusterName.set(val)
break
}
}
}

return conn, info, nil
}

// newGRPCClient creates a Client that is connected via gRPC.
func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error) {
dialCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout)
defer cancel()

c := &clusterName{}
conn, err := grpc.DialContext(
dialCtx,
cfg.ProxySSHAddress,
append(cfg.DialOpts,
grpc.WithTransportCredentials(cfg.creds()),
grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}),
grpc.WithChainUnaryInterceptor(
append(cfg.UnaryInterceptors,
otelgrpc.UnaryClientInterceptor(),
Expand Down Expand Up @@ -245,25 +330,72 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error
}

return &Client{
cfg: cfg,
grpcConn: conn,
transport: transport,
cfg: cfg,
grpcConn: conn,
transport: transport,
clusterName: c,
}, nil
}

// teleportAuthority is the extension set by the server
// which contains the name of the cluster it is in.
const teleportAuthority = "x-teleport-authority"

// clusterCallback is a [ssh.HostKeyCallback] that obtains the name
// of the cluster being connected to from the certificate presented by the server.
// This allows the clien to determine the cluster name when using jump hosts.
func clusterCallback(c *clusterName, wrapped ssh.HostKeyCallback) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := wrapped(hostname, remote, key); err != nil {
return trace.Wrap(err)
}

cert, ok := key.(*ssh.Certificate)
if !ok {
return trace.BadParameter("proxy did not present a host certificate")
}

clusterName, ok := cert.Permissions.Extensions[teleportAuthority]
if !ok {
return trace.BadParameter("authority not present in host certificate")
}

c.set(clusterName)
return nil
}
}

// newSSHClient creates a Client that is connected via SSH.
func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) {
clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, cfg.SSHConfig)
c := &clusterName{}
clientCfg := &ssh.ClientConfig{
User: cfg.SSHConfig.User,
Auth: cfg.SSHConfig.Auth,
HostKeyCallback: clusterCallback(c, cfg.SSHConfig.HostKeyCallback),
BannerCallback: cfg.SSHConfig.BannerCallback,
ClientVersion: cfg.SSHConfig.ClientVersion,
HostKeyAlgorithms: cfg.SSHConfig.HostKeyAlgorithms,
Timeout: cfg.SSHConfig.Timeout,
}

clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, clientCfg)
if err != nil {
return nil, trace.Wrap(err)
}

return &Client{
cfg: cfg,
sshClient: clt,
cfg: cfg,
sshClient: clt,
clusterName: c,
}, nil
}

// ClusterName returns the name of the cluster that the
// connected Proxy is a member of.
func (c *Client) ClusterName() string {
return c.clusterName.get()
}

// Close attempts to close both the gRPC and SSH connections.
func (c *Client) Close() error {
var errs []error
Expand Down Expand Up @@ -486,7 +618,7 @@ func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddr
// read the stderr output from the failed SSH session and append
// it to the end of our own message:
serverErrorMsg, _ := io.ReadAll(sessionError)
return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %v. %v", targetAddress, serverErrorMsg, err)
return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %s. %v", targetAddress, serverErrorMsg, err)
}

return conn, nil
Expand Down
Loading

0 comments on commit bfab398

Please sign in to comment.