diff --git a/p2p/security/tls/extension_test.go b/p2p/security/tls/extension_test.go index 5a7ef47756..f50695a248 100644 --- a/p2p/security/tls/extension_test.go +++ b/p2p/security/tls/extension_test.go @@ -1,19 +1,18 @@ package libp2ptls import ( - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" + "testing" + + "github.com/stretchr/testify/require" ) -var _ = Describe("Extensions", func() { - It("generates a prefixed extension ID", func() { - Expect(getPrefixedExtensionID([]int{13, 37})).To(Equal([]int{1, 3, 6, 1, 4, 1, 53594, 13, 37})) - }) +func TestExtensionGenerating(t *testing.T) { + require.Equal(t, getPrefixedExtensionID([]int{13, 37}), []int{1, 3, 6, 1, 4, 1, 53594, 13, 37}) +} - It("compares extension IDs", func() { - Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3, 4})).To(BeTrue()) - Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3})).To(BeFalse()) - Expect(extensionIDEqual([]int{1, 2, 3}, []int{1, 2, 3, 4})).To(BeFalse()) - Expect(extensionIDEqual([]int{1, 2, 3, 4}, []int{4, 3, 2, 1})).To(BeFalse()) - }) -}) +func TestExtensionComparison(t *testing.T) { + require.True(t, extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3, 4})) + require.False(t, extensionIDEqual([]int{1, 2, 3, 4}, []int{1, 2, 3})) + require.False(t, extensionIDEqual([]int{1, 2, 3}, []int{1, 2, 3, 4})) + require.False(t, extensionIDEqual([]int{1, 2, 3, 4}, []int{4, 3, 2, 1})) +} diff --git a/p2p/security/tls/libp2p_tls_suite_test.go b/p2p/security/tls/libp2p_tls_suite_test.go deleted file mode 100644 index e0e6785862..0000000000 --- a/p2p/security/tls/libp2p_tls_suite_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package libp2ptls - -import ( - mrand "math/rand" - "testing" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -func TestLibp2pTLS(t *testing.T) { - RegisterFailHandler(Fail) - RunSpecs(t, "libp2p TLS Suite") -} - -var _ = BeforeSuite(func() { - mrand.Seed(GinkgoRandomSeed()) -}) diff --git a/p2p/security/tls/transport_test.go b/p2p/security/tls/transport_test.go index c85c64ea55..8106f815a7 100644 --- a/p2p/security/tls/transport_test.go +++ b/p2p/security/tls/transport_test.go @@ -15,486 +15,497 @@ import ( "math/big" mrand "math/rand" "net" + "testing" "time" - ci "github.com/libp2p/go-libp2p-core/crypto" + ic "github.com/libp2p/go-libp2p-core/crypto" "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/sec" - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" - "github.com/onsi/gomega/gbytes" - "github.com/onsi/gomega/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -type transform struct { - name string - apply func(*Identity) - remoteErr types.GomegaMatcher // the error that the side validating the chain gets +func createPeer(t *testing.T) (peer.ID, ic.PrivKey) { + var priv ic.PrivKey + var err error + switch mrand.Int() % 4 { + case 0: + priv, _, err = ic.GenerateECDSAKeyPair(rand.Reader) + case 1: + priv, _, err = ic.GenerateRSAKeyPair(2048, rand.Reader) + case 2: + priv, _, err = ic.GenerateEd25519Key(rand.Reader) + case 3: + priv, _, err = ic.GenerateSecp256k1Key(rand.Reader) + } + require.NoError(t, err) + id, err := peer.IDFromPrivateKey(priv) + require.NoError(t, err) + t.Logf("using a %s key: %s", priv.Type(), id.Pretty()) + return id, priv } -var _ = Describe("Transport", func() { - var ( - serverKey, clientKey ci.PrivKey - serverID, clientID peer.ID - ) - - createPeer := func() (peer.ID, ci.PrivKey) { - var priv ci.PrivKey - var err error - switch mrand.Int() % 4 { - case 0: - fmt.Fprintf(GinkgoWriter, " using an ECDSA key: ") - priv, _, err = ci.GenerateECDSAKeyPair(rand.Reader) - case 1: - fmt.Fprintf(GinkgoWriter, " using an RSA key: ") - priv, _, err = ci.GenerateRSAKeyPair(2048, rand.Reader) - case 2: - fmt.Fprintf(GinkgoWriter, " using an Ed25519 key: ") - priv, _, err = ci.GenerateEd25519Key(rand.Reader) - case 3: - fmt.Fprintf(GinkgoWriter, " using an secp256k1 key: ") - priv, _, err = ci.GenerateSecp256k1Key(rand.Reader) - } - Expect(err).ToNot(HaveOccurred()) - id, err := peer.IDFromPrivateKey(priv) - Expect(err).ToNot(HaveOccurred()) - fmt.Fprintln(GinkgoWriter, id.Pretty()) - return id, priv - } +func connect(t *testing.T) (net.Conn, net.Conn) { + ln, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + defer ln.Close() + serverConnChan := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + serverConnChan <- conn + }() + conn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + return conn, <-serverConnChan +} + +func TestHandshakeSucceeds(t *testing.T) { + clientID, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + handshake := func(t *testing.T) { + clientTransport, err := New(clientKey) + require.NoError(t, err) + serverTransport, err := New(serverKey) + require.NoError(t, err) - connect := func() (net.Conn, net.Conn) { - ln, err := net.Listen("tcp", "localhost:0") - Expect(err).ToNot(HaveOccurred()) - defer ln.Close() - serverConnChan := make(chan net.Conn) + clientInsecureConn, serverInsecureConn := connect(t) + + serverConnChan := make(chan sec.SecureConn) go func() { - defer GinkgoRecover() - conn, err := ln.Accept() - Expect(err).ToNot(HaveOccurred()) - serverConnChan <- conn + serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + require.NoError(t, err) + serverConnChan <- serverConn }() - conn, err := net.Dial("tcp", ln.Addr().String()) - Expect(err).ToNot(HaveOccurred()) - return conn, <-serverConnChan + + clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + defer clientConn.Close() + + var serverConn sec.SecureConn + select { + case serverConn = <-serverConnChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server to accept a connection") + } + defer serverConn.Close() + + require.Equal(t, clientConn.LocalPeer(), clientID) + require.Equal(t, serverConn.LocalPeer(), serverID) + require.True(t, clientConn.LocalPrivateKey().Equals(clientKey), "client private key mismatch") + require.True(t, serverConn.LocalPrivateKey().Equals(serverKey), "server private key mismatch") + require.Equal(t, clientConn.RemotePeer(), serverID) + require.Equal(t, serverConn.RemotePeer(), clientID) + require.True(t, clientConn.RemotePublicKey().Equals(serverKey.GetPublic()), "server public key mismatch") + require.True(t, serverConn.RemotePublicKey().Equals(clientKey.GetPublic()), "client public key mismatch") + // exchange some data + _, err = serverConn.Write([]byte("foobar")) + require.NoError(t, err) + b := make([]byte, 6) + _, err = clientConn.Read(b) + require.NoError(t, err) + require.Equal(t, string(b), "foobar") } - BeforeEach(func() { - fmt.Fprintf(GinkgoWriter, "Initializing a server") - serverID, serverKey = createPeer() - fmt.Fprintf(GinkgoWriter, "Initializing a client") - clientID, clientKey = createPeer() + t.Run("with extension not critical", func(t *testing.T) { + handshake(t) }) - Context("successful handshakes", func() { - for _, critical := range []bool{true, false} { - crit := critical - - It(fmt.Sprintf("handshakes, extension critical: %t", crit), func() { - extensionCritical = crit - defer func() { extensionCritical = false }() - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - - clientInsecureConn, serverInsecureConn := connect() - - serverConnChan := make(chan sec.SecureConn) - go func() { - defer GinkgoRecover() - serverConn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - Expect(err).ToNot(HaveOccurred()) - serverConnChan <- serverConn - }() - clientConn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).ToNot(HaveOccurred()) - var serverConn sec.SecureConn - Eventually(serverConnChan).Should(Receive(&serverConn)) - defer clientConn.Close() - defer serverConn.Close() - Expect(clientConn.LocalPeer()).To(Equal(clientID)) - Expect(serverConn.LocalPeer()).To(Equal(serverID)) - Expect(clientConn.LocalPrivateKey()).To(Equal(clientKey)) - Expect(serverConn.LocalPrivateKey()).To(Equal(serverKey)) - Expect(clientConn.RemotePeer()).To(Equal(serverID)) - Expect(serverConn.RemotePeer()).To(Equal(clientID)) - Expect(ci.KeyEqual(clientConn.RemotePublicKey(), serverKey.GetPublic())).To(BeTrue()) - Expect(ci.KeyEqual(serverConn.RemotePublicKey(), clientKey.GetPublic())).To(BeTrue()) - // exchange some data - _, err = serverConn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - b := make([]byte, 6) - _, err = clientConn.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foobar")) - }) - } + t.Run("with extension critical", func(t *testing.T) { + extensionCritical = true + t.Cleanup(func() { extensionCritical = false }) + + handshake(t) }) +} - It("fails when the context of the outgoing connection is canceled", func() { - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) +func TestHandshakeConnectionCancelations(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) + + clientTransport, err := New(clientKey) + require.NoError(t, err) + serverTransport, err := New(serverKey) + require.NoError(t, err) - clientInsecureConn, serverInsecureConn := connect() + t.Run("cancel outgoing connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + errChan := make(chan error) go func() { - defer GinkgoRecover() _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - Expect(err).To(HaveOccurred()) + errChan <- err }() ctx, cancel := context.WithCancel(context.Background()) cancel() _, err = clientTransport.SecureOutbound(ctx, clientInsecureConn, serverID) - Expect(err).To(MatchError(context.Canceled)) + require.ErrorIs(t, err, context.Canceled) + require.Error(t, <-errChan) }) - It("fails when the context of the incoming connection is canceled", func() { - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - - clientInsecureConn, serverInsecureConn := connect() + t.Run("cancel incoming connection", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) + errChan := make(chan error) go func() { - defer GinkgoRecover() ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := serverTransport.SecureInbound(ctx, serverInsecureConn, "") - Expect(err).To(MatchError(context.Canceled)) + errChan <- err }() _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).To(HaveOccurred()) + require.Error(t, err) + require.ErrorIs(t, <-errChan, context.Canceled) }) +} - Context("peer ID checks", func() { - It("succeeds when the server checks the client's ID", func() { - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - - clientInsecureConn, serverInsecureConn := connect() - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, clientID) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.RemotePeer()).To(Equal(clientID)) - b := make([]byte, 6) - _, err = conn.Read(b) - Expect(err).ToNot(HaveOccurred()) - Expect(string(b)).To(Equal("foobar")) - }() - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).ToNot(HaveOccurred()) - defer conn.Close() - _, err = conn.Write([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - Expect(conn.RemotePeer()).To(Equal(serverID)) - Eventually(done).Should(BeClosed()) - }) +func TestPeerIDMismatch(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) - It("fails if the peer ID doesn't match, for outgoing connections", func() { - fmt.Fprintf(GinkgoWriter, "Creating another peer") - thirdPartyID, _ := createPeer() + serverTransport, err := New(serverKey) + require.NoError(t, err) + clientTransport, err := New(clientKey) + require.NoError(t, err) - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) + t.Run("for outgoing connections", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) - clientInsecureConn, serverInsecureConn := connect() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("tls: bad certificate")) - }() - // dial, but expect the wrong peer ID - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("peer IDs don't match")) - Eventually(done).Should(BeClosed()) - }) - - It("fails if the peer ID doesn't match, for incoming connections", func() { - fmt.Fprintf(GinkgoWriter, "Creating another peer") - thirdPartyID, _ := createPeer() + errChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + errChan <- err + }() - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) + // dial, but expect the wrong peer ID + thirdPartyID, _ := createPeer(t) + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, thirdPartyID) + require.Error(t, err) + require.Contains(t, err.Error(), "peer IDs don't match") + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected handshake to return on the server side") + } + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "tls: bad certificate") + }) - clientInsecureConn, serverInsecureConn := connect() + t.Run("for incoming connections", func(t *testing.T) { + clientInsecureConn, serverInsecureConn := connect(t) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - defer close(done) - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).ToNot(HaveOccurred()) - _, err = conn.Read([]byte{0}) - Expect(err.Error()).To(ContainSubstring("tls: bad certificate")) - }() - // accept connection, but expect the wrong peer ID - _, err = serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("peer IDs don't match")) - Eventually(done).Should(BeClosed()) - }) - }) + errChan := make(chan error) + go func() { + thirdPartyID, _ := createPeer(t) + // expect the wrong peer ID + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, thirdPartyID) + errChan <- err + }() - Context("invalid certificates", func() { - invalidateCertChain := func(identity *Identity) { - switch identity.config.Certificates[0].PrivateKey.(type) { - case *rsa.PrivateKey: - key, err := rsa.GenerateKey(rand.Reader, 2048) - Expect(err).ToNot(HaveOccurred()) - identity.config.Certificates[0].PrivateKey = key - case *ecdsa.PrivateKey: - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - Expect(err).ToNot(HaveOccurred()) - identity.config.Certificates[0].PrivateKey = key - default: - Fail("unexpected private key type") - } + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + _, err = conn.Read([]byte{0}) + require.Error(t, err) + require.Contains(t, err.Error(), "tls: bad certificate") + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected handshake to return on the server side") } + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "peer IDs don't match") + }) +} - twoCerts := func(identity *Identity) { - tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} - key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - Expect(err).ToNot(HaveOccurred()) - key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - Expect(err).ToNot(HaveOccurred()) - cert1DER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key1.Public(), key1) - Expect(err).ToNot(HaveOccurred()) - cert1, err := x509.ParseCertificate(cert1DER) - Expect(err).ToNot(HaveOccurred()) - cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key1) - Expect(err).ToNot(HaveOccurred()) - identity.config.Certificates = []tls.Certificate{{ - Certificate: [][]byte{cert2DER, cert1DER}, - PrivateKey: key2, - }} - } +func TestInvalidCerts(t *testing.T) { + _, clientKey := createPeer(t) + serverID, serverKey := createPeer(t) - getCertWithKey := func(key crypto.Signer, tmpl *x509.Certificate) tls.Certificate { - cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) - Expect(err).ToNot(HaveOccurred()) - return tls.Certificate{ - Certificate: [][]byte{cert}, - PrivateKey: key, - } - } + type transform struct { + name string + apply func(*Identity) + checkErr func(*testing.T, error) // the error that the side validating the chain gets + } - getCert := func(tmpl *x509.Certificate) tls.Certificate { + invalidateCertChain := func(identity *Identity) { + switch identity.config.Certificates[0].PrivateKey.(type) { + case *rsa.PrivateKey: + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + identity.config.Certificates[0].PrivateKey = key + case *ecdsa.PrivateKey: key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - Expect(err).ToNot(HaveOccurred()) - return getCertWithKey(key, tmpl) + require.NoError(t, err) + identity.config.Certificates[0].PrivateKey = key + default: + t.Fatal("unexpected private key type") } + } - expiredCert := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(-time.Minute), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: []byte("foobar")}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } + twoCerts := func(identity *Identity) { + tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} + key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + cert1DER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key1.Public(), key1) + require.NoError(t, err) + cert1, err := x509.ParseCertificate(cert1DER) + require.NoError(t, err) + cert2DER, err := x509.CreateCertificate(rand.Reader, tmpl, cert1, key2.Public(), key1) + require.NoError(t, err) + identity.config.Certificates = []tls.Certificate{{ + Certificate: [][]byte{cert2DER, cert1DER}, + PrivateKey: key2, + }} + } - noKeyExtension := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - }) - identity.config.Certificates = []tls.Certificate{cert} + getCertWithKey := func(key crypto.Signer, tmpl *x509.Certificate) tls.Certificate { + cert, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + require.NoError(t, err) + return tls.Certificate{ + Certificate: [][]byte{cert}, + PrivateKey: key, } + } - unparseableKeyExtension := func(identity *Identity) { - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: []byte("foobar")}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } + getCert := func(tmpl *x509.Certificate) tls.Certificate { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return getCertWithKey(key, tmpl) + } - unparseableKey := func(identity *Identity) { - data, err := asn1.Marshal(signedKey{PubKey: []byte("foobar")}) - Expect(err).ToNot(HaveOccurred()) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } + expiredCert := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(-time.Minute), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: []byte("foobar")}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } - tooShortSignature := func(identity *Identity) { - key, _, err := ci.GenerateSecp256k1Key(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - keyBytes, err := ci.MarshalPublicKey(key.GetPublic()) - Expect(err).ToNot(HaveOccurred()) - data, err := asn1.Marshal(signedKey{ - PubKey: keyBytes, - Signature: []byte("foobar"), - }) - Expect(err).ToNot(HaveOccurred()) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } + noKeyExtension := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + }) + identity.config.Certificates = []tls.Certificate{cert} + } - invalidSignature := func(identity *Identity) { - key, _, err := ci.GenerateSecp256k1Key(rand.Reader) - Expect(err).ToNot(HaveOccurred()) - keyBytes, err := ci.MarshalPublicKey(key.GetPublic()) - Expect(err).ToNot(HaveOccurred()) - signature, err := key.Sign([]byte("foobar")) - Expect(err).ToNot(HaveOccurred()) - data, err := asn1.Marshal(signedKey{ - PubKey: keyBytes, - Signature: signature, - }) - Expect(err).ToNot(HaveOccurred()) - cert := getCert(&x509.Certificate{ - SerialNumber: big.NewInt(1), - NotBefore: time.Now().Add(-time.Hour), - NotAfter: time.Now().Add(time.Hour), - ExtraExtensions: []pkix.Extension{ - {Id: extensionID, Value: data}, - }, - }) - identity.config.Certificates = []tls.Certificate{cert} - } + unparseableKeyExtension := func(identity *Identity) { + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: []byte("foobar")}, + }, + }) + identity.config.Certificates = []tls.Certificate{cert} + } - transforms := []transform{ - { - name: "private key used in the TLS handshake doesn't match the public key in the cert", - apply: invalidateCertChain, - remoteErr: Or( - Equal("tls: invalid signature by the client certificate: ECDSA verification failure"), - Equal("tls: invalid signature by the server certificate: ECDSA verification failure"), - ), + unparseableKey := func(identity *Identity) { + data, err := asn1.Marshal(signedKey{PubKey: []byte("foobar")}) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, }, - { - name: "certificate chain contains 2 certs", - apply: twoCerts, - remoteErr: Equal("expected one certificates in the chain"), + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + tooShortSignature := func(identity *Identity) { + key, _, err := ic.GenerateSecp256k1Key(rand.Reader) + require.NoError(t, err) + keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) + require.NoError(t, err) + data, err := asn1.Marshal(signedKey{ + PubKey: keyBytes, + Signature: []byte("foobar"), + }) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, }, - { - name: "cert is expired", - apply: expiredCert, - remoteErr: ContainSubstring("certificate has expired or is not yet valid"), + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + invalidSignature := func(identity *Identity) { + key, _, err := ic.GenerateSecp256k1Key(rand.Reader) + require.NoError(t, err) + keyBytes, err := ic.MarshalPublicKey(key.GetPublic()) + require.NoError(t, err) + signature, err := key.Sign([]byte("foobar")) + require.NoError(t, err) + data, err := asn1.Marshal(signedKey{ + PubKey: keyBytes, + Signature: signature, + }) + require.NoError(t, err) + cert := getCert(&x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + ExtraExtensions: []pkix.Extension{ + {Id: extensionID, Value: data}, }, - { - name: "cert doesn't have the key extension", - apply: noKeyExtension, - remoteErr: Equal("expected certificate to contain the key extension"), + }) + identity.config.Certificates = []tls.Certificate{cert} + } + + transforms := []transform{ + { + name: "private key used in the TLS handshake doesn't match the public key in the cert", + apply: invalidateCertChain, + checkErr: func(t *testing.T, err error) { + if err.Error() != "tls: invalid signature by the client certificate: ECDSA verification failure" && + err.Error() != "tls: invalid signature by the server certificate: ECDSA verification failure" { + t.Fatalf("unexpected error message: %s", err) + } }, - { - name: "key extension not parseable", - apply: unparseableKeyExtension, - remoteErr: ContainSubstring("asn1"), + }, + { + name: "certificate chain contains 2 certs", + apply: twoCerts, + checkErr: func(t *testing.T, err error) { + require.EqualError(t, err, "expected one certificates in the chain") }, - { - name: "key protobuf not parseable", - apply: unparseableKey, - remoteErr: ContainSubstring("unmarshalling public key failed: proto:"), + }, + { + name: "cert is expired", + apply: expiredCert, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "certificate has expired or is not yet valid") }, - { - name: "signature is malformed", - apply: tooShortSignature, - remoteErr: ContainSubstring("signature verification failed:"), + }, + { + name: "cert doesn't have the key extension", + apply: noKeyExtension, + checkErr: func(t *testing.T, err error) { + require.EqualError(t, err, "expected certificate to contain the key extension") }, - { - name: "signature is invalid", - apply: invalidSignature, - remoteErr: Equal("signature invalid"), + }, + { + name: "key extension not parseable", + apply: unparseableKeyExtension, + checkErr: func(t *testing.T, err error) { require.Contains(t, err.Error(), "asn1") }, + }, + { + name: "key protobuf not parseable", + apply: unparseableKey, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "unmarshalling public key failed: proto:") }, - } + }, + { + name: "signature is malformed", + apply: tooShortSignature, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "signature verification failed:") + }, + }, + { + name: "signature is invalid", + apply: invalidSignature, + checkErr: func(t *testing.T, err error) { + require.Contains(t, err.Error(), "signature invalid") + }, + }, + } - for i := range transforms { - t := transforms[i] - - It(fmt.Sprintf("fails if the client presents an invalid cert: %s", t.name), func() { - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - t.apply(clientTransport.identity) - - clientInsecureConn, serverInsecureConn := connect() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(t.remoteErr) - close(done) - }() - - conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).ToNot(HaveOccurred()) - _, err = gbytes.TimeoutReader(conn, time.Second).Read([]byte{0}) - Expect(err).To(Or( - // if the certificate's public key doesn't match the private key used for signing - MatchError("remote error: tls: error decrypting message"), - // all other errors - MatchError("remote error: tls: bad certificate"), - )) - Eventually(done).Should(BeClosed()) - }) - - It(fmt.Sprintf("fails if the server presents an invalid cert: %s", t.name), func() { - serverTransport, err := New(serverKey) - Expect(err).ToNot(HaveOccurred()) - t.apply(serverTransport.identity) - clientTransport, err := New(clientKey) - Expect(err).ToNot(HaveOccurred()) - - clientInsecureConn, serverInsecureConn := connect() - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("remote error: tls:")) - close(done) - }() - - _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(t.remoteErr) - Eventually(done).Should(BeClosed()) - }) - } - }) -}) + for i := range transforms { + tr := transforms[i] + + t.Run(fmt.Sprintf("client offending: %s", tr.name), func(t *testing.T) { + serverTransport, err := New(serverKey) + require.NoError(t, err) + clientTransport, err := New(clientKey) + require.NoError(t, err) + tr.apply(clientTransport.identity) + + clientInsecureConn, serverInsecureConn := connect(t) + + serverErrChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + serverErrChan <- err + }() + + conn, err := clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.NoError(t, err) + clientErrChan := make(chan error) + go func() { + _, err := conn.Read([]byte{0}) + clientErrChan <- err + }() + var clientErr error + select { + case clientErr = <-clientErrChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + require.Error(t, clientErr) + if clientErr.Error() != "remote error: tls: error decrypting message" && + clientErr.Error() != "remote error: tls: bad certificate" { + t.Fatalf("unexpected error: %s", err.Error()) + } + + var serverErr error + select { + case serverErr = <-serverErrChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + require.Error(t, serverErr) + tr.checkErr(t, serverErr) + }) + + t.Run(fmt.Sprintf("server offending: %s", tr.name), func(t *testing.T) { + serverTransport, err := New(serverKey) + require.NoError(t, err) + tr.apply(serverTransport.identity) + clientTransport, err := New(clientKey) + require.NoError(t, err) + + clientInsecureConn, serverInsecureConn := connect(t) + + errChan := make(chan error) + go func() { + _, err := serverTransport.SecureInbound(context.Background(), serverInsecureConn, "") + errChan <- err + }() + + _, err = clientTransport.SecureOutbound(context.Background(), clientInsecureConn, serverID) + require.Error(t, err) + tr.checkErr(t, err) + + var serverErr error + select { + case serverErr = <-errChan: + case <-time.After(250 * time.Millisecond): + t.Fatal("expected the server handshake to return") + } + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "remote error: tls:") + }) + } +}