Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make proxy.Client infer the cluster name from Proxy #23644

Merged
merged 1 commit into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 148 additions & 20 deletions api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ package proxy
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
"net"
"strings"
"sync/atomic"
"time"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -64,9 +66,6 @@ type ClientConfig struct {
ProxySSHAddress string
// TLSRoutingEnabled indicates if the cluster is using TLS Routing.
TLSRoutingEnabled bool
// ClusterName is the name of the Teleport cluster that the client
// will be connected to.
ClusterName string
// TLSConfig contains the tls.Config required for mTLS connections.
TLSConfig *tls.Config
// UnaryInterceptors are optional [grpc.UnaryClientInterceptor] to apply
Expand All @@ -91,16 +90,15 @@ type ClientConfig struct {
clientCreds func() client.Credentials
}

// CheckAndSetDefaults ensures required options are present and
// sets the default value of any that are omitted.
func (c *ClientConfig) CheckAndSetDefaults() error {
if c.ProxyWebAddress == "" {
return trace.BadParameter("missing required parameter ProxyWebAddress")
}
if c.ProxySSHAddress == "" {
return trace.BadParameter("missing required parameter ProxySSHAddress")
}
if c.ClusterName == "" {
return trace.BadParameter("missing required parameter ClusterName")
}
if c.SSHDialer == nil {
return trace.BadParameter("missing required parameter SSHDialer")
}
Expand All @@ -112,16 +110,27 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
}

if c.TLSConfig != nil {
if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) {
tlsCfg := c.TLSConfig.Clone()
tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC)
c.TLSConfig = tlsCfg
}
c.clientCreds = func() client.Credentials {
return client.LoadTLS(c.TLSConfig.Clone())
}
c.creds = func() credentials.TransportCredentials {
return credentials.NewTLS(c.TLSConfig.Clone())
tlsCfg := c.TLSConfig.Clone()
if !slices.Contains(c.TLSConfig.NextProtos, protocolProxySSHGRPC) {
tlsCfg.NextProtos = append(tlsCfg.NextProtos, protocolProxySSHGRPC)
}

// This logic still appears to be necessary to force client to always send
// a certificate regardless of the server setting. Otherwise the client may pick
// not to send the client certificate by looking at certificate request.
if len(tlsCfg.Certificates) > 0 {
cert := tlsCfg.Certificates[0]
tlsCfg.Certificates = nil
tlsCfg.GetClientCertificate = func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &cert, nil
}
}

return credentials.NewTLS(tlsCfg)
}
} else {
c.clientCreds = func() client.Credentials {
Expand All @@ -135,6 +144,8 @@ func (c *ClientConfig) CheckAndSetDefaults() error {
return nil
}

// insecureCredentials implements [client.Credentials] and is used by tests
// to connect to the Auth server without mTLS.
type insecureCredentials struct{}

func (mc insecureCredentials) Dialer(client.Config) (client.ContextDialer, error) {
Expand Down Expand Up @@ -164,6 +175,9 @@ type Client struct {
transport *transportv1.Client
// sshClient is the established SSH connection to the Proxy.
sshClient *tracessh.Client
// clusterName as determined by inspecting the certificate presented by
// the Proxy during the connection handshake.
clusterName *clusterName
}

// protocolProxySSHGRPC is TLS ALPN protocol value used to indicate gRPC
Expand Down Expand Up @@ -205,16 +219,84 @@ func NewClient(ctx context.Context, cfg ClientConfig) (*Client, error) {
return nil, trace.NewAggregate(grpcErr, sshErr)
}

// clusterName stores the name of the cluster
// in a protected manner which allows it to
// be set during handshakes with the server.
type clusterName struct {
name atomic.Pointer[string]
}

func (c *clusterName) get() string {
name := c.name.Load()
if name != nil {
return *name
}
return ""
}

func (c *clusterName) set(name string) {
c.name.CompareAndSwap(nil, &name)
}

// clusterCredentials is a [credentials.TransportCredentials] implementation
// that obtains the name of the cluster being connected to from the certificate
// presented by the server. This allows the client to determine the cluster name when
// connecting via using jump hosts.
type clusterCredentials struct {
credentials.TransportCredentials
clusterName *clusterName
}

var (
// teleportClusterASN1ExtensionOID is an extension ID used when encoding/decoding
// origin teleport cluster name into certificates.
teleportClusterASN1ExtensionOID = asn1.ObjectIdentifier{1, 3, 9999, 1, 7}
)

// ClientHandshake performs the handshake with the wrapped [credentials.TransportCredentials] and
// then inspects the provided cert for the [teleportClusterASN1ExtensionOID] to determine
// the cluster that the server belongs to.
func (c *clusterCredentials) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
conn, info, err := c.TransportCredentials.ClientHandshake(ctx, authority, conn)
if err != nil {
return conn, info, trace.Wrap(err)
}

tlsInfo, ok := info.(credentials.TLSInfo)
if !ok {
return conn, info, nil
}

certs := tlsInfo.State.PeerCertificates
if len(certs) == 0 {
return conn, info, nil
}

clientCert := certs[0]
for _, attr := range clientCert.Subject.Names {
if attr.Type.Equal(teleportClusterASN1ExtensionOID) {
val, ok := attr.Value.(string)
if ok {
c.clusterName.set(val)
break
}
}
}

return conn, info, nil
}

// newGRPCClient creates a Client that is connected via gRPC.
func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error) {
dialCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout)
defer cancel()

c := &clusterName{}
conn, err := grpc.DialContext(
dialCtx,
cfg.ProxySSHAddress,
append(cfg.DialOpts,
grpc.WithTransportCredentials(cfg.creds()),
grpc.WithTransportCredentials(&clusterCredentials{TransportCredentials: cfg.creds(), clusterName: c}),
grpc.WithChainUnaryInterceptor(
append(cfg.UnaryInterceptors,
otelgrpc.UnaryClientInterceptor(),
Expand Down Expand Up @@ -245,25 +327,71 @@ func newGRPCClient(ctx context.Context, cfg *ClientConfig) (_ *Client, err error
}

return &Client{
cfg: cfg,
grpcConn: conn,
transport: transport,
cfg: cfg,
grpcConn: conn,
transport: transport,
clusterName: c,
}, nil
}

// teleportAuthority is the extension set by the server
// which contains the name of the cluster it is in.
const teleportAuthority = "x-teleport-authority"

// clusterCallback is a [ssh.HostKeyCallback] that obtains the name
// of the cluster being connected to from the certificate presented by the server.
// This allows the client to determine the cluster name when using jump hosts.
func clusterCallback(c *clusterName, wrapped ssh.HostKeyCallback) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := wrapped(hostname, remote, key); err != nil {
return trace.Wrap(err)
}

cert, ok := key.(*ssh.Certificate)
if !ok {
return nil
}

clusterName, ok := cert.Permissions.Extensions[teleportAuthority]
if ok {
c.set(clusterName)
}

return nil
}
}

// newSSHClient creates a Client that is connected via SSH.
func newSSHClient(ctx context.Context, cfg *ClientConfig) (*Client, error) {
clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, cfg.SSHConfig)
c := &clusterName{}
clientCfg := &ssh.ClientConfig{
User: cfg.SSHConfig.User,
Auth: cfg.SSHConfig.Auth,
HostKeyCallback: clusterCallback(c, cfg.SSHConfig.HostKeyCallback),
BannerCallback: cfg.SSHConfig.BannerCallback,
ClientVersion: cfg.SSHConfig.ClientVersion,
HostKeyAlgorithms: cfg.SSHConfig.HostKeyAlgorithms,
Timeout: cfg.SSHConfig.Timeout,
}

clt, err := cfg.SSHDialer.Dial(ctx, "tcp", cfg.ProxySSHAddress, clientCfg)
if err != nil {
return nil, trace.Wrap(err)
}

return &Client{
cfg: cfg,
sshClient: clt,
cfg: cfg,
sshClient: clt,
clusterName: c,
}, nil
}

// ClusterName returns the name of the cluster that the
// connected Proxy is a member of.
func (c *Client) ClusterName() string {
return c.clusterName.get()
}

// Close attempts to close both the gRPC and SSH connections.
func (c *Client) Close() error {
var errs []error
Expand Down Expand Up @@ -486,7 +614,7 @@ func dialSSH(ctx context.Context, clt *tracessh.Client, proxyAddress, targetAddr
// read the stderr output from the failed SSH session and append
// it to the end of our own message:
serverErrorMsg, _ := io.ReadAll(sessionError)
return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %v. %v", targetAddress, serverErrorMsg, err)
return nil, trace.ConnectionProblem(err, "failed connecting to host %s: %s. %v", targetAddress, serverErrorMsg, err)
}

return conn, nil
Expand Down
Loading