From 3cd1ce36454fade2fe8c8e2be42756183cd5a8a0 Mon Sep 17 00:00:00 2001 From: Zhen Lian Date: Thu, 15 Dec 2022 14:08:40 -0800 Subject: [PATCH] add integration tests --- .../advancedtls_integration_test.go | 108 ++++++++++++++++-- 1 file changed, 100 insertions(+), 8 deletions(-) diff --git a/security/advancedtls/advancedtls_integration_test.go b/security/advancedtls/advancedtls_integration_test.go index d5a620d14f96..a19718abf622 100644 --- a/security/advancedtls/advancedtls_integration_test.go +++ b/security/advancedtls/advancedtls_integration_test.go @@ -731,13 +731,12 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { t.Fatalf("cs.LoadCerts() failed, err: %v", err) } for _, test := range []struct { - desc string - clientRoot *x509.CertPool - clientVerifyFunc CustomVerificationFunc - clientVType VerificationType - serverCert []tls.Certificate - serverVType VerificationType - expectError bool + desc string + clientRoot *x509.CertPool + clientVType VerificationType + serverCert []tls.Certificate + serverVType VerificationType + expectError bool }{ // Client side sets vType to CertAndHostVerification, and will do // default hostname check. Server uses a cert without "localhost" or @@ -787,7 +786,6 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { pb.RegisterGreeterServer(s, greeterServer{}) go s.Serve(lis) clientOptions := &ClientOptions{ - VerifyPeer: test.clientVerifyFunc, RootOptions: RootCertificateOptions{ RootCACerts: test.clientRoot, }, @@ -811,3 +809,97 @@ func (s) TestDefaultHostNameCheck(t *testing.T) { }) } } + +func (s) TestTLSVersions(t *testing.T) { + cs := &testutils.CertStore{} + if err := cs.LoadCerts(); err != nil { + t.Fatalf("cs.LoadCerts() failed, err: %v", err) + } + for _, test := range []struct { + desc string + expectError bool + clientMinVersion uint16 + clientMaxVersion uint16 + serverMinVersion uint16 + serverMaxVersion uint16 + }{ + // Client side sets TLS version that is higher than required from the server side. + { + desc: "Client TLS version higher than server", + clientMinVersion: tls.VersionTLS13, + clientMaxVersion: tls.VersionTLS13, + serverMinVersion: tls.VersionTLS12, + serverMaxVersion: tls.VersionTLS12, + expectError: true, + }, + // Server side sets TLS version that is higher than required from the client side. + { + desc: "Server TLS version higher than client", + clientMinVersion: tls.VersionTLS12, + clientMaxVersion: tls.VersionTLS12, + serverMinVersion: tls.VersionTLS13, + serverMaxVersion: tls.VersionTLS13, + expectError: true, + }, + // Client and server set proper TLS versions. + { + desc: "Good TLS version settings", + clientMinVersion: tls.VersionTLS12, + clientMaxVersion: tls.VersionTLS13, + serverMinVersion: tls.VersionTLS12, + serverMaxVersion: tls.VersionTLS13, + expectError: false, + }, + } { + test := test + t.Run(test.desc, func(t *testing.T) { + // Start a server using ServerOptions in another goroutine. + serverOptions := &ServerOptions{ + IdentityOptions: IdentityCertificateOptions{ + Certificates: []tls.Certificate{cs.ServerPeerLocalhost1}, + }, + RequireClientCert: false, + VType: CertAndHostVerification, + MinVersion: test.serverMinVersion, + MaxVersion: test.serverMaxVersion, + } + serverTLSCreds, err := NewServerCreds(serverOptions) + if err != nil { + t.Fatalf("failed to create server creds: %v", err) + } + s := grpc.NewServer(grpc.Creds(serverTLSCreds)) + defer s.Stop() + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("failed to listen: %v", err) + } + defer lis.Close() + addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port) + pb.RegisterGreeterServer(s, greeterServer{}) + go s.Serve(lis) + clientOptions := &ClientOptions{ + RootOptions: RootCertificateOptions{ + RootCACerts: cs.ClientTrust1, + }, + VType: CertAndHostVerification, + MinVersion: test.clientMinVersion, + MaxVersion: test.clientMaxVersion, + } + clientTLSCreds, err := NewClientCreds(clientOptions) + if err != nil { + t.Fatalf("clientTLSCreds failed to create: %v", err) + } + shouldFail := false + if test.expectError { + shouldFail = true + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + conn, _, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, shouldFail) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + }) + } +}