diff --git a/credentials/tls.go b/credentials/tls.go index e163a473df93..bd5fe22b6af6 100644 --- a/credentials/tls.go +++ b/credentials/tls.go @@ -32,6 +32,8 @@ import ( "google.golang.org/grpc/internal/envconfig" ) +const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434" + var logger = grpclog.Component("credentials") // TLSInfo contains the auth information for a TLS authenticated connection. @@ -128,7 +130,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon if np == "" { if envconfig.EnforceALPNEnabled { conn.Close() - return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property") + return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage) } logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName) } @@ -158,7 +160,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) if cs.NegotiatedProtocol == "" { if envconfig.EnforceALPNEnabled { conn.Close() - return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property") + return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage) } else if logger.V(2) { logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases") } diff --git a/experimental/credentials/credentials_test.go b/experimental/credentials/credentials_test.go new file mode 100644 index 000000000000..4b44a52aa942 --- /dev/null +++ b/experimental/credentials/credentials_test.go @@ -0,0 +1,252 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package credentials + +import ( + "context" + "crypto/tls" + "net" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/testdata" +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestTLSOverrideServerName(t *testing.T) { + expectedServerName := "server.name" + c := NewTLSWithALPNDisabled(nil) + c.OverrideServerName(expectedServerName) + if c.Info().ServerName != expectedServerName { + t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } +} + +func (s) TestTLSClone(t *testing.T) { + expectedServerName := "server.name" + c := NewTLSWithALPNDisabled(nil) + c.OverrideServerName(expectedServerName) + cc := c.Clone() + if cc.Info().ServerName != expectedServerName { + t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) + } + cc.OverrideServerName("") + if c.Info().ServerName != expectedServerName { + t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) + } + +} + +type serverHandshake func(net.Conn) (credentials.AuthInfo, error) + +func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) { + tcs := []struct { + name string + address string + }{ + { + name: "localhost", + address: "localhost:0", + }, + { + name: "ipv4", + address: "127.0.0.1:0", + }, + { + name: "ipv6", + address: "[::1]:0", + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + done := make(chan credentials.AuthInfo, 1) + lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) + defer lis.Close() + lisAddr := lis.Addr().String() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) + } + }) + } +} + +func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) { + done := make(chan credentials.AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) + } +} + +func (s) TestServerAndClientHandshake(t *testing.T) { + done := make(chan credentials.AuthInfo, 1) + lis := launchServer(t, gRPCServerHandshake, done) + defer lis.Close() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) + } +} + +func compare(a1, a2 credentials.AuthInfo) bool { + if a1.AuthType() != a2.AuthType() { + return false + } + switch a1.AuthType() { + case "tls": + state1 := a1.(credentials.TLSInfo).State + state2 := a2.(credentials.TLSInfo).State + if state1.Version == state2.Version && + state1.HandshakeComplete == state2.HandshakeComplete && + state1.CipherSuite == state2.CipherSuite && + state1.NegotiatedProtocol == state2.NegotiatedProtocol { + return true + } + return false + default: + return false + } +} + +func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener { + return launchServerOnListenAddress(t, hs, done, "localhost:0") +} + +func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener { + lis, err := net.Listen("tcp", address) + if err != nil { + if strings.Contains(err.Error(), "bind: cannot assign requested address") || + strings.Contains(err.Error(), "socket: address family not supported by protocol") { + t.Skipf("no support for address %v", address) + } + t.Fatalf("Failed to listen: %v", err) + } + go serverHandle(t, hs, done, lis) + return lis +} + +// Is run in a separate goroutine. +func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) { + serverRawConn, err := lis.Accept() + if err != nil { + t.Errorf("Server failed to accept connection: %v", err) + close(done) + return + } + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + t.Errorf("Server failed while handshake. Error: %v", err) + serverRawConn.Close() + close(done) + return + } + done <- serverAuthInfo +} + +func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo { + conn, err := net.Dial("tcp", lisAddr) + if err != nil { + t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) + } + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + t.Fatalf("Error on client while handshake. Error: %v", err) + } + return clientAuthInfo +} + +// Server handshake implementation in gRPC. +func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) { + serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + return nil, err + } + _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) + if err != nil { + return nil, err + } + return serverAuthInfo, nil +} + +// Client handshake implementation in gRPC. +func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { + clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true}) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + _, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) { + cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + return nil, err + } + serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} + serverConn := tls.Server(conn, serverTLSConfig) + err = serverConn.Handshake() + if err != nil { + return nil, err + } + return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil +} + +func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) { + clientTLSConfig := &tls.Config{InsecureSkipVerify: true} + clientConn := tls.Client(conn, clientTLSConfig) + if err := clientConn.Handshake(); err != nil { + return nil, err + } + return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil +} diff --git a/experimental/credentials/internal/spiffe.go b/experimental/credentials/internal/spiffe.go new file mode 100644 index 000000000000..569419a98a27 --- /dev/null +++ b/experimental/credentials/internal/spiffe.go @@ -0,0 +1,75 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package internal defines APIs for parsing SPIFFE ID. +// +// All APIs in this package are experimental. +package internal + +import ( + "crypto/tls" + "crypto/x509" + "net/url" + + "google.golang.org/grpc/grpclog" +) + +var logger = grpclog.Component("credentials") + +// SPIFFEIDFromState parses the SPIFFE ID from State. If the SPIFFE ID format +// is invalid, return nil with warning. +func SPIFFEIDFromState(state tls.ConnectionState) *url.URL { + if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 { + return nil + } + return SPIFFEIDFromCert(state.PeerCertificates[0]) +} + +// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE +// ID format is invalid, return nil with warning. +func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL { + if cert == nil || cert.URIs == nil { + return nil + } + var spiffeID *url.URL + for _, uri := range cert.URIs { + if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") { + continue + } + // From this point, we assume the uri is intended for a SPIFFE ID. + if len(uri.String()) > 2048 { + logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes") + return nil + } + if len(uri.Host) == 0 || len(uri.Path) == 0 { + logger.Warning("invalid SPIFFE ID: domain or workload ID is empty") + return nil + } + if len(uri.Host) > 255 { + logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters") + return nil + } + // A valid SPIFFE certificate can only have exactly one URI SAN field. + if len(cert.URIs) > 1 { + logger.Warning("invalid SPIFFE ID: multiple URI SANs") + return nil + } + spiffeID = uri + } + return spiffeID +} diff --git a/experimental/credentials/internal/spiffe_test.go b/experimental/credentials/internal/spiffe_test.go new file mode 100644 index 000000000000..885fe1bf15d7 --- /dev/null +++ b/experimental/credentials/internal/spiffe_test.go @@ -0,0 +1,233 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "crypto/tls" + "crypto/x509" + "encoding/pem" + "net/url" + "os" + "testing" + + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/testdata" +) + +const wantURI = "spiffe://foo.bar.com/client/workload/1" + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestSPIFFEIDFromState(t *testing.T) { + tests := []struct { + name string + urls []*url.URL + // If we expect a SPIFFE ID to be returned. + wantID bool + }{ + { + name: "empty URIs", + urls: []*url.URL{}, + wantID: false, + }, + { + name: "good SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: true, + }, + { + name: "invalid host", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "invalid path", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "", + RawPath: "", + }, + }, + wantID: false, + }, + { + name: "large path", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: string(make([]byte, 2050)), + RawPath: string(make([]byte, 2050)), + }, + }, + wantID: false, + }, + { + name: "large host", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: string(make([]byte, 256)), + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "spiffe", + Host: "bar.baz.com", + Path: "workload/wl2", + RawPath: "workload/wl2", + }, + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs without SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "ssh", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + { + name: "multiple URI SANs with one SPIFFE ID", + urls: []*url.URL{ + { + Scheme: "spiffe", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + { + Scheme: "https", + Host: "foo.bar.com", + Path: "workload/wl1", + RawPath: "workload/wl1", + }, + }, + wantID: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}} + id := SPIFFEIDFromState(state) + if got, want := id != nil, tt.wantID; got != want { + t.Errorf("want wantID = %v, but SPIFFE ID is %v", want, id) + } + }) + } +} + +func (s) TestSPIFFEIDFromCert(t *testing.T) { + tests := []struct { + name string + dataPath string + // If we expect a SPIFFE ID to be returned. + wantID bool + }{ + { + name: "good certificate with SPIFFE ID", + dataPath: "x509/spiffe_cert.pem", + wantID: true, + }, + { + name: "bad certificate with SPIFFE ID and another URI", + dataPath: "x509/multiple_uri_cert.pem", + wantID: false, + }, + { + name: "certificate without SPIFFE ID", + dataPath: "x509/client1_cert.pem", + wantID: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := os.ReadFile(testdata.Path(tt.dataPath)) + if err != nil { + t.Fatalf("os.ReadFile(%s) failed: %v", testdata.Path(tt.dataPath), err) + } + block, _ := pem.Decode(data) + if block == nil { + t.Fatalf("Failed to parse the certificate: byte block is nil") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("x509.ParseCertificate(%b) failed: %v", block.Bytes, err) + } + uri := SPIFFEIDFromCert(cert) + if (uri != nil) != tt.wantID { + t.Fatalf("wantID got and want mismatch, got %t, want %t", uri != nil, tt.wantID) + } + if uri != nil && uri.String() != wantURI { + t.Fatalf("SPIFFE ID not expected, got %s, want %s", uri.String(), wantURI) + } + }) + } +} diff --git a/experimental/credentials/internal/syscallconn.go b/experimental/credentials/internal/syscallconn.go new file mode 100644 index 000000000000..35d67bd9338b --- /dev/null +++ b/experimental/credentials/internal/syscallconn.go @@ -0,0 +1,58 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "net" + "syscall" +) + +type sysConn = syscall.Conn + +// syscallConn keeps reference of rawConn to support syscall.Conn for channelz. +// SyscallConn() (the method in interface syscall.Conn) is explicitly +// implemented on this type, +// +// Interface syscall.Conn is implemented by most net.Conn implementations (e.g. +// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns +// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn +// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't +// help here). +type syscallConn struct { + net.Conn + // sysConn is a type alias of syscall.Conn. It's necessary because the name + // `Conn` collides with `net.Conn`. + sysConn +} + +// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that +// implements syscall.Conn. rawConn will be used to support syscall, and newConn +// will be used for read/write. +// +// This function returns newConn if rawConn doesn't implement syscall.Conn. +func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn { + sysConn, ok := rawConn.(syscall.Conn) + if !ok { + return newConn + } + return &syscallConn{ + Conn: newConn, + sysConn: sysConn, + } +} diff --git a/experimental/credentials/internal/syscallconn_test.go b/experimental/credentials/internal/syscallconn_test.go new file mode 100644 index 000000000000..75d7cfa855be --- /dev/null +++ b/experimental/credentials/internal/syscallconn_test.go @@ -0,0 +1,56 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package internal + +import ( + "net" + "syscall" + "testing" +) + +func (*syscallConn) SyscallConn() (syscall.RawConn, error) { + return nil, nil +} + +type nonSyscallConn struct { + net.Conn +} + +func (s) TestWrapSyscallConn(t *testing.T) { + sc := &syscallConn{} + nsc := &nonSyscallConn{} + + wrapConn := WrapSyscallConn(sc, nsc) + if _, ok := wrapConn.(syscall.Conn); !ok { + t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement", wrapConn) + } +} + +func (s) TestWrapSyscallConnNoWrap(t *testing.T) { + nscRaw := &nonSyscallConn{} + nsc := &nonSyscallConn{} + + wrapConn := WrapSyscallConn(nscRaw, nsc) + if _, ok := wrapConn.(syscall.Conn); ok { + t.Errorf("returned conn (type %T) implements syscall.Conn, want not implement", wrapConn) + } + if wrapConn != nsc { + t.Errorf("returned conn is %p, want %p (the passed-in newConn)", wrapConn, nsc) + } +} diff --git a/experimental/credentials/tls.go b/experimental/credentials/tls.go new file mode 100644 index 000000000000..7b50b0a72931 --- /dev/null +++ b/experimental/credentials/tls.go @@ -0,0 +1,249 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package credentials provides experimental TLS credentials. +// The use of this package is strongly discouraged. These credentials exist +// solely to maintain compatibility for users interacting with clients that +// violate the HTTP/2 specification by not advertising support for "h2" in ALPN. +// This package is slated for removal in upcoming grpc-go releases. Users must +// not rely on this package directly. Instead, they should either vendor a +// specific version of gRPC or copy the relevant credentials into their own +// codebase if absolutely necessary. +package credentials + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + + "golang.org/x/net/http2" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/experimental/credentials/internal" +) + +// tlsCreds is the credentials required for authenticating a connection using TLS. +type tlsCreds struct { + // TLS configuration + config *tls.Config +} + +func (c tlsCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{ + SecurityProtocol: "tls", + SecurityVersion: "1.2", + ServerName: c.config.ServerName, + } +} + +func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) { + // use local cfg to avoid clobbering ServerName if using multiple endpoints + cfg := cloneTLSConfig(c.config) + if cfg.ServerName == "" { + serverName, _, err := net.SplitHostPort(authority) + if err != nil { + // If the authority had no host port or if the authority cannot be parsed, use it as-is. + serverName = authority + } + cfg.ServerName = serverName + } + conn := tls.Client(rawConn, cfg) + errChannel := make(chan error, 1) + go func() { + errChannel <- conn.Handshake() + close(errChannel) + }() + select { + case err := <-errChannel: + if err != nil { + conn.Close() + return nil, nil, err + } + case <-ctx.Done(): + conn.Close() + return nil, nil, ctx.Err() + } + + tlsInfo := credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + } + id := internal.SPIFFEIDFromState(conn.ConnectionState()) + if id != nil { + tlsInfo.SPIFFEID = id + } + return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil +} + +func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + conn := tls.Server(rawConn, c.config) + if err := conn.Handshake(); err != nil { + conn.Close() + return nil, nil, err + } + cs := conn.ConnectionState() + tlsInfo := credentials.TLSInfo{ + State: cs, + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + } + id := internal.SPIFFEIDFromState(conn.ConnectionState()) + if id != nil { + tlsInfo.SPIFFEID = id + } + return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil +} + +func (c *tlsCreds) Clone() credentials.TransportCredentials { + return NewTLSWithALPNDisabled(c.config) +} + +func (c *tlsCreds) OverrideServerName(serverNameOverride string) error { + c.config.ServerName = serverNameOverride + return nil +} + +// The following cipher suites are forbidden for use with HTTP/2 by +// https://datatracker.ietf.org/doc/html/rfc7540#appendix-A +var tls12ForbiddenCipherSuites = map[uint16]struct{}{ + tls.TLS_RSA_WITH_AES_128_CBC_SHA: {}, + tls.TLS_RSA_WITH_AES_256_CBC_SHA: {}, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256: {}, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384: {}, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: {}, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: {}, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: {}, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: {}, +} + +// NewTLSWithALPNDisabled uses c to construct a TransportCredentials based on +// TLS. ALPN verification is disabled. +func NewTLSWithALPNDisabled(c *tls.Config) credentials.TransportCredentials { + config := applyDefaults(c) + if config.GetConfigForClient != nil { + oldFn := config.GetConfigForClient + config.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + cfgForClient, err := oldFn(hello) + if err != nil || cfgForClient == nil { + return cfgForClient, err + } + return applyDefaults(cfgForClient), nil + } + } + return &tlsCreds{config: config} +} + +func applyDefaults(c *tls.Config) *tls.Config { + config := cloneTLSConfig(c) + config.NextProtos = appendH2ToNextProtos(config.NextProtos) + // If the user did not configure a MinVersion and did not configure a + // MaxVersion < 1.2, use MinVersion=1.2, which is required by + // https://datatracker.ietf.org/doc/html/rfc7540#section-9.2 + if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) { + config.MinVersion = tls.VersionTLS12 + } + // If the user did not configure CipherSuites, use all "secure" cipher + // suites reported by the TLS package, but remove some explicitly forbidden + // by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A + if config.CipherSuites == nil { + for _, cs := range tls.CipherSuites() { + if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok { + config.CipherSuites = append(config.CipherSuites, cs.ID) + } + } + } + return config +} + +// NewClientTLSFromCertWithALPNDisabled constructs TLS credentials from the +// provided root certificate authority certificate(s) to validate server +// connections. If certificates to establish the identity of the client need to +// be included in the credentials (eg: for mTLS), use NewTLS instead, where a +// complete tls.Config can be specified. +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header +// field) in requests. ALPN verification is disabled. +func NewClientTLSFromCertWithALPNDisabled(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials { + return NewTLSWithALPNDisabled(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}) +} + +// NewClientTLSFromFileWithALPNDisabled constructs TLS credentials from the +// provided root certificate authority certificate file(s) to validate server +// connections. If certificates to establish the identity of the client need to +// be included in the credentials (eg: for mTLS), use NewTLS instead, where a +// complete tls.Config can be specified. +// serverNameOverride is for testing only. If set to a non empty string, +// it will override the virtual host name of authority (e.g. :authority header +// field) in requests. ALPN verification is disabled. +func NewClientTLSFromFileWithALPNDisabled(certFile, serverNameOverride string) (credentials.TransportCredentials, error) { + b, err := os.ReadFile(certFile) + if err != nil { + return nil, err + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("credentials: failed to append certificates") + } + return NewTLSWithALPNDisabled(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil +} + +// NewServerTLSFromCertWithALPNDisabled constructs TLS credentials from the +// input certificate for server. ALPN verification is disabled. +func NewServerTLSFromCertWithALPNDisabled(cert *tls.Certificate) credentials.TransportCredentials { + return NewTLSWithALPNDisabled(&tls.Config{Certificates: []tls.Certificate{*cert}}) +} + +// NewServerTLSFromFileWithALPNDisabled constructs TLS credentials from the +// input certificate file and key file for server. ALPN verification is disabled. +func NewServerTLSFromFileWithALPNDisabled(certFile, keyFile string) (credentials.TransportCredentials, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return NewTLSWithALPNDisabled(&tls.Config{Certificates: []tls.Certificate{cert}}), nil +} + +// cloneTLSConfig returns a shallow clone of the exported +// fields of cfg, ignoring the unexported sync.Once, which +// contains a mutex and must not be copied. +// +// If cfg is nil, a new zero tls.Config is returned. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + + return cfg.Clone() +} + +// appendH2ToNextProtos appends h2 to next protos. +func appendH2ToNextProtos(ps []string) []string { + for _, p := range ps { + if p == http2.NextProtoTLS { + return ps + } + } + ret := make([]string, 0, len(ps)+1) + ret = append(ret, ps...) + return append(ret, http2.NextProtoTLS) +} diff --git a/experimental/credentials/tls_ext_test.go b/experimental/credentials/tls_ext_test.go new file mode 100644 index 000000000000..3d4e473ff8aa --- /dev/null +++ b/experimental/credentials/tls_ext_test.go @@ -0,0 +1,604 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package credentials_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + credsstable "google.golang.org/grpc/credentials" + "google.golang.org/grpc/experimental/credentials" + "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/status" + "google.golang.org/grpc/testdata" + + testgrpc "google.golang.org/grpc/interop/grpc_testing" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +const defaultTestTimeout = 10 * time.Second + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +var serverCert tls.Certificate +var certPool *x509.CertPool +var serverName = "x.test.example.com" + +func init() { + var err error + serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) + if err != nil { + panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err)) + } + + b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem")) + if err != nil { + panic(fmt.Sprintf("Error reading CA cert file: %v", err)) + } + certPool = x509.NewCertPool() + if !certPool.AppendCertsFromPEM(b) { + panic("Error appending cert from PEM") + } +} + +// Tests that the MinVersion of tls.Config is set to 1.2 if it is not already +// set by the user. +func (s) TestTLS_MinVersion12(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + testCases := []struct { + name string + serverTLS func() *tls.Config + }{ + { + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + } + }, + }, + { + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + } + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + // MinVersion should be set to 1.2 by gRPC by default. + Certificates: []tls.Certificate{serverCert}, + }, nil + }, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without a minimum version. + serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) + + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } + + }) + } +} + +// Tests that the MinVersion of tls.Config is not changed if it is set by the +// user. +func (s) TestTLS_MinVersionOverridable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + var allCipherSuites []uint16 + for _, cs := range tls.CipherSuites() { + allCipherSuites = append(allCipherSuites, cs.ID) + } + testCases := []struct { + name string + serverTLS func() *tls.Config + }{ + { + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + } + }, + }, + { + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + } + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + MinVersion: tls.VersionTLS10, + Certificates: []tls.Certificate{serverCert}, + CipherSuites: allCipherSuites, + }, nil + }, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds that allow v1.0. + serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create client creds that supports V1.0-V1.1. + clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: allCipherSuites, + MinVersion: tls.VersionTLS10, + MaxVersion: tls.VersionTLS11, + }) + + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) + } +} + +// Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default. +func (s) TestTLS_CipherSuites(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + testCases := []struct { + name string + serverTLS func() *tls.Config + }{ + { + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + }, + }, + { + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + } + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, nil + }, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server creds without cipher suites. + serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create client creds that use a forbidden suite only. + clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) + + // Start server and client separately, because Start() blocks on a + // successful connection, which we will not get. + if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil { + t.Fatalf("Error starting server: %v", err) + } + defer ss.Stop() + + cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds)) + if err != nil { + t.Fatalf("grpc.NewClient error: %v", err) + } + defer cc.Close() + + client := testgrpc.NewTestServiceClient(cc) + + const wantStr = "authentication handshake failed" + if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) { + t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr) + } + }) + } +} + +// Tests that CipherSuites is not overridden when it is set. +func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + testCases := []struct { + name string + serverTLS func() *tls.Config + }{ + { + name: "base_case", + serverTLS: func() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + }, + }, + { + name: "fallback_to_base", + serverTLS: func() *tls.Config { + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + } + config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { + return nil, nil + } + return config + }, + }, + { + name: "dynamic_using_get_config_for_client", + serverTLS: func() *tls.Config { + return &tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + }, nil + }, + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create server that allows only a forbidden cipher suite. + serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS()) + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + // Create server that allows only a forbidden cipher suite. + clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, + MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2. + }) + + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } + }) + } +} + +// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured +// correctly for a server that doesn't specify the NextProtos field and uses +// GetConfigForClient to provide the TLS config during the handshake. +func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) { + initialVal := envconfig.EnforceALPNEnabled + defer func() { + envconfig.EnforceALPNEnabled = initialVal + }() + envconfig.EnforceALPNEnabled = true + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Create a server that doesn't set the NextProtos field. + serverCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{ + GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) { + return &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, nil + }, + }) + + ss := stubserver.StubServer{ + EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + clientCreds := credsstable.NewTLS(&tls.Config{ + ServerName: serverName, + RootCAs: certPool, + }) + + if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil { + t.Fatalf("Error starting stub server: %v", err) + } + defer ss.Stop() + + if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatalf("EmptyCall err = %v; want ", err) + } +} + +// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when +// connecting to a server that doesn't support ALPN. +func (s) TestTLS_DisabledALPNClient(t *testing.T) { + initialVal := envconfig.EnforceALPNEnabled + defer func() { + envconfig.EnforceALPNEnabled = initialVal + }() + + tests := []struct { + name string + alpnEnforced bool + wantErr bool + }{ + { + name: "enforced", + }, + { + name: "not_enforced", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + envconfig.EnforceALPNEnabled = tc.alpnEnforced + + listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{}, // Empty list indicates ALPN is disabled. + }) + if err != nil { + t.Fatalf("Error starting TLS server: %v", err) + } + + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- fmt.Errorf("listener.Accept returned error: %v", err) + } else { + // The first write to the TLS listener initiates the TLS handshake. + conn.Write([]byte("Hello, World!")) + conn.Close() + } + close(errCh) + }() + + serverAddr := listener.Addr().String() + conn, err := net.Dial("tcp", serverAddr) + if err != nil { + t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + clientCfg := tls.Config{ + ServerName: serverName, + RootCAs: certPool, + NextProtos: []string{"h2"}, + } + _, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn) + + if gotErr := (err != nil); gotErr != tc.wantErr { + t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) + } + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("Unexpected error received from server: %v", err) + } + case <-ctx.Done(): + t.Fatalf("Timeout waiting for error from server") + } + }) + } +} + +// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when +// accepting a request from a client that doesn't support ALPN. +func (s) TestTLS_DisabledALPNServer(t *testing.T) { + initialVal := envconfig.EnforceALPNEnabled + defer func() { + envconfig.EnforceALPNEnabled = initialVal + }() + + tests := []struct { + name string + alpnEnforced bool + wantErr bool + }{ + { + name: "enforced", + }, + { + name: "not_enforced", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + envconfig.EnforceALPNEnabled = tc.alpnEnforced + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Error starting server: %v", err) + } + + errCh := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + errCh <- fmt.Errorf("listener.Accept returned error: %v", err) + return + } + defer conn.Close() + serverCfg := tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{"h2"}, + } + _, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn) + if gotErr := (err != nil); gotErr != tc.wantErr { + t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) + } + close(errCh) + }() + + serverAddr := listener.Addr().String() + clientCfg := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + NextProtos: []string{}, // Empty list indicates ALPN is disabled. + RootCAs: certPool, + ServerName: serverName, + } + conn, err := tls.Dial("tcp", serverAddr, clientCfg) + if err != nil { + t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) + } + defer conn.Close() + + select { + case <-time.After(defaultTestTimeout): + t.Fatal("Timed out waiting for completion") + case err := <-errCh: + if err != nil { + t.Fatalf("Unexpected server error: %v", err) + } + } + }) + } +}