From a628f6e69320640828f8c4914f602815101e0a88 Mon Sep 17 00:00:00 2001 From: Jackie Luc <15662837+jackieluc@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:45:30 -0700 Subject: [PATCH] test: refactor TestMtlsRootCAsFromCertificate to table-based tests --- cns/service_test.go | 117 +++++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 50 deletions(-) diff --git a/cns/service_test.go b/cns/service_test.go index ea3610eac99..9bf4af8ce7f 100644 --- a/cns/service_test.go +++ b/cns/service_test.go @@ -197,58 +197,75 @@ func TestMtlsRootCAsFromCertificate(t *testing.T) { key, err := tlsCertRetriever.GetPrivateKey() require.NoError(t, err) - t.Run("returns root CA pool when provided a single self-signed CA cert", func(t *testing.T) { - // one root CA - tlsCert := tls.Certificate{ - Certificate: [][]byte{cert.Raw}, - PrivateKey: key, - Leaf: cert, - } - - var r *x509.CertPool - r, err = mtlsRootCAsFromCertificate(&tlsCert) - require.NoError(t, err) - assert.NotNil(t, r) - }) - t.Run("returns root CA pool when provided with a full cert chain", func(t *testing.T) { - // simulate a full cert chain (leaf cert + root CA cert) - tlsCert := tls.Certificate{ - Certificate: [][]byte{cert.Raw, cert.Raw}, - PrivateKey: key, - Leaf: cert, - } - require.NoError(t, err) - r, err := mtlsRootCAsFromCertificate(&tlsCert) - require.NoError(t, err) - assert.NotNil(t, r) - }) - t.Run("does not return root CA pool when provided with no cert", func(t *testing.T) { - r, err := mtlsRootCAsFromCertificate(nil) - require.Error(t, err) - assert.Nil(t, r) - - r, err = mtlsRootCAsFromCertificate(&tls.Certificate{}) - require.Error(t, err) - assert.Nil(t, r) - }) - t.Run("does not return root CA pool when provided with invalid certs", func(t *testing.T) { - tt := []struct { - invalidCert [][]byte - }{ - {nil}, - {[][]byte{[]byte("invalid leaf cert")}}, - {[][]byte{[]byte("invalid leaf cert"), []byte("invalid root CA cert")}}, - } + tests := []struct { + name string + cert *tls.Certificate + wantErr bool + wantErrMsg string + }{ + { + name: "returns root CA pool when provided a single self-signed CA cert", + cert: &tls.Certificate{ + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, + Leaf: cert, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "returns root CA pool when provided with a full cert chain", + cert: &tls.Certificate{ + Certificate: [][]byte{cert.Raw, cert.Raw}, + PrivateKey: key, + Leaf: cert, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "does not return root CA pool when provided with nil", + cert: nil, + wantErr: true, + wantErrMsg: "no certificate provided", + }, + { + name: "does not return root CA pool when provided with empty cert", + cert: &tls.Certificate{}, + wantErr: true, + wantErrMsg: "no certificate provided", + }, + { + name: "does not return root CA pool when provided with single invalid cert", + cert: &tls.Certificate{ + Certificate: [][]byte{[]byte("invalid leaf cert")}, + }, + wantErr: true, + wantErrMsg: "parsing self signed cert", + }, + { + name: "does not return root CA pool when provided with invalid full chain cert", + cert: &tls.Certificate{ + Certificate: [][]byte{[]byte("invalid leaf cert"), []byte("invalid root CA cert")}, + }, + wantErr: true, + wantErrMsg: "parsing root certs", + }, + } - for _, tc := range tt { - cert := tls.Certificate{ - Certificate: tc.invalidCert, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r, err := mtlsRootCAsFromCertificate(tt.cert) + if tt.wantErr { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrMsg) + assert.Nil(t, r) + } else { + require.NoError(t, err) + assert.NotNil(t, r) } - r, err := mtlsRootCAsFromCertificate(&cert) - require.Error(t, err) - assert.Nil(t, r) - } - }) + }) + } } // createTestCertificate is a test helper that creates a test certificate