From a0a498ed0e17e002af9e18104a3bc1393373bb16 Mon Sep 17 00:00:00 2001
From: Marco Munizaga <git@marcopolo.io>
Date: Wed, 18 Dec 2024 15:45:30 -0800
Subject: [PATCH] fix(httpauth): Correctly handle concurrent requests on server
 (#3111)

Co-authored-by: Adin Schmahmann <adin.schmahmann@gmail.com>
---
 p2p/http/auth/auth_test.go | 66 +++++++++++++++++++++++++++++++-------
 p2p/http/auth/server.go    | 39 +++++++++++++++++++---
 2 files changed, 89 insertions(+), 16 deletions(-)

diff --git a/p2p/http/auth/auth_test.go b/p2p/http/auth/auth_test.go
index d080b19511..9d1b4688d4 100644
--- a/p2p/http/auth/auth_test.go
+++ b/p2p/http/auth/auth_test.go
@@ -2,11 +2,9 @@ package httppeeridauth
 
 import (
 	"bytes"
-	"crypto/hmac"
 	"crypto/rand"
-	"crypto/sha256"
 	"crypto/tls"
-	"hash"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
@@ -171,14 +169,12 @@ func TestMutualAuth(t *testing.T) {
 
 				t.Run("Tokens Invalidated", func(t *testing.T) {
 					// Clear the auth token on the server side
-					server.Hmac = func() hash.Hash {
-						key := make([]byte, 32)
-						_, err := rand.Read(key)
-						if err != nil {
-							panic(err)
-						}
-						return hmac.New(sha256.New, key)
-					}()
+					key := make([]byte, 32)
+					_, err := rand.Read(key)
+					if err != nil {
+						panic(err)
+					}
+					server.hmacPool = newHmacPool(key)
 
 					req, err := http.NewRequest("POST", ts.URL, nil)
 					req.GetBody = func() (io.ReadCloser, error) {
@@ -241,3 +237,51 @@ func (irt *instrumentedRoundTripper) RoundTrip(req *http.Request) (*http.Respons
 func (irt *instrumentedRoundTripper) TLSClientConfig() *tls.Config {
 	return irt.RoundTripper.(*http.Transport).TLSClientConfig
 }
+
+func TestConcurrentAuth(t *testing.T) {
+	serverKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
+	require.NoError(t, err)
+
+	auth := ServerPeerIDAuth{
+		PrivKey: serverKey,
+		ValidHostnameFn: func(s string) bool {
+			return s == "example.com"
+		},
+		TokenTTL: time.Hour,
+		NoTLS:    true,
+		Next: func(peer peer.ID, w http.ResponseWriter, r *http.Request) {
+			reqBody, err := io.ReadAll(r.Body)
+			require.NoError(t, err)
+			_, err = w.Write(reqBody)
+			require.NoError(t, err)
+		},
+	}
+
+	ts := httptest.NewServer(&auth)
+	t.Cleanup(ts.Close)
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			clientKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
+			require.NoError(t, err)
+
+			clientAuth := ClientPeerIDAuth{PrivKey: clientKey}
+			reqBody := []byte(fmt.Sprintf("echo %d", i))
+			req, err := http.NewRequest("POST", ts.URL, bytes.NewReader(reqBody))
+			require.NoError(t, err)
+			req.Host = "example.com"
+
+			client := ts.Client()
+			_, resp, err := clientAuth.AuthenticatedDo(client, req)
+			require.NoError(t, err)
+			require.Equal(t, http.StatusOK, resp.StatusCode)
+			respBody, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, reqBody, respBody)
+		}()
+	}
+	wg.Wait()
+}
diff --git a/p2p/http/auth/server.go b/p2p/http/auth/server.go
index 3ee4f96dc8..b17c3fccf1 100644
--- a/p2p/http/auth/server.go
+++ b/p2p/http/auth/server.go
@@ -15,6 +15,30 @@ import (
 	"github.com/libp2p/go-libp2p/p2p/http/auth/internal/handshake"
 )
 
+type hmacPool struct {
+	p sync.Pool
+}
+
+func newHmacPool(key []byte) *hmacPool {
+	return &hmacPool{
+		p: sync.Pool{
+			New: func() any {
+				return hmac.New(sha256.New, key)
+			},
+		},
+	}
+}
+
+func (p *hmacPool) Get() hash.Hash {
+	h := p.p.Get().(hash.Hash)
+	h.Reset()
+	return h
+}
+
+func (p *hmacPool) Put(h hash.Hash) {
+	p.p.Put(h)
+}
+
 type ServerPeerIDAuth struct {
 	PrivKey  crypto.PrivKey
 	TokenTTL time.Duration
@@ -26,8 +50,9 @@ type ServerPeerIDAuth struct {
 	// which the Host header returns true.
 	ValidHostnameFn func(hostname string) bool
 
-	Hmac     hash.Hash
+	HmacKey  []byte
 	initHmac sync.Once
+	hmacPool *hmacPool
 }
 
 // ServeHTTP implements the http.Handler interface for PeerIDAuth. It will
@@ -36,14 +61,15 @@ type ServerPeerIDAuth struct {
 // requests.
 func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	a.initHmac.Do(func() {
-		if a.Hmac == nil {
+		if a.HmacKey == nil {
 			key := make([]byte, 32)
 			_, err := rand.Read(key)
 			if err != nil {
 				panic(err)
 			}
-			a.Hmac = hmac.New(sha256.New, key)
+			a.HmacKey = key
 		}
+		a.hmacPool = newHmacPool(a.HmacKey)
 	})
 
 	hostname := r.Host
@@ -76,11 +102,13 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
+	hmac := a.hmacPool.Get()
+	defer a.hmacPool.Put(hmac)
 	hs := handshake.PeerIDAuthHandshakeServer{
 		Hostname: hostname,
 		PrivKey:  a.PrivKey,
 		TokenTTL: a.TokenTTL,
-		Hmac:     a.Hmac,
+		Hmac:     hmac,
 	}
 	err := hs.ParseHeaderVal([]byte(r.Header.Get("Authorization")))
 	if err != nil {
@@ -95,11 +123,12 @@ func (a *ServerPeerIDAuth) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			errors.Is(err, handshake.ErrExpiredChallenge),
 			errors.Is(err, handshake.ErrExpiredToken):
 
+			hmac.Reset()
 			hs := handshake.PeerIDAuthHandshakeServer{
 				Hostname: hostname,
 				PrivKey:  a.PrivKey,
 				TokenTTL: a.TokenTTL,
-				Hmac:     a.Hmac,
+				Hmac:     hmac,
 			}
 			hs.Run()
 			hs.SetHeader(w.Header())