From 7e1f7a8c3b1a3046bfc4e0aca089a1d6053419f5 Mon Sep 17 00:00:00 2001
From: Sukun <sukunrt@gmail.com>
Date: Mon, 4 Mar 2024 19:49:52 +0530
Subject: [PATCH 1/2] Limit size of encrypted packet queue

---
 conn.go      | 105 +++++++++++++++++++++++++++++++++++----------------
 conn_test.go |  85 +++++++++++++++++++++++++++++++++++++++++
 resume.go    |   6 ++-
 3 files changed, 162 insertions(+), 34 deletions(-)

diff --git a/conn.go b/conn.go
index 338f793ad..cf8551b40 100644
--- a/conn.go
+++ b/conn.go
@@ -34,6 +34,9 @@ const (
 	inboundBufferSize     = 8192
 	// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
 	defaultReplayProtectionWindow = 64
+	// maxAppDataPacketQueueSize is the maximum number of app data packets we will
+	// enqueue before the handshake is completed
+	maxAppDataPacketQueueSize = 100
 )
 
 func invalidKeyingLabels() map[string]bool {
@@ -81,7 +84,7 @@ type Conn struct {
 	replayProtectionWindow uint
 }
 
-func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
+func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) {
 	err := validateConfig(config)
 	if err != nil {
 		return nil, err
@@ -91,21 +94,6 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
 		return nil, errNilNextConn
 	}
 
-	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
-	if err != nil {
-		return nil, err
-	}
-
-	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
-	if err != nil {
-		return nil, err
-	}
-
-	workerInterval := initialTickerInterval
-	if config.FlightInterval != 0 {
-		workerInterval = config.FlightInterval
-	}
-
 	loggerFactory := config.LoggerFactory
 	if loggerFactory == nil {
 		loggerFactory = logging.NewDefaultLoggerFactory()
@@ -149,6 +137,38 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
 
 	c.setRemoteEpoch(0)
 	c.setLocalEpoch(0)
+	return c, nil
+}
+
+func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
+	if conn == nil {
+		return nil, errNilNextConn
+	}
+
+	cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
+	if err != nil {
+		return nil, err
+	}
+
+	signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
+	if err != nil {
+		return nil, err
+	}
+
+	workerInterval := initialTickerInterval
+	if config.FlightInterval != 0 {
+		workerInterval = config.FlightInterval
+	}
+
+	mtu := config.MTU
+	if mtu <= 0 {
+		mtu = defaultMTU
+	}
+
+	replayProtectionWindow := config.ReplayProtectionWindow
+	if replayProtectionWindow <= 0 {
+		replayProtectionWindow = defaultReplayProtectionWindow
+	}
 
 	serverName := config.ServerName
 	// Do not allow the use of an IP address literal as an SNI value.
@@ -180,7 +200,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
 		clientCAs:                   config.ClientCAs,
 		customCipherSuites:          config.CustomCipherSuites,
 		retransmitInterval:          workerInterval,
-		log:                         logger,
+		log:                         conn.log,
 		initialEpoch:                0,
 		keyLogWriter:                config.KeyLogWriter,
 		sessionStore:                config.SessionStore,
@@ -205,16 +225,16 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
 	var initialFSMState handshakeState
 
 	if initialState != nil {
-		if c.state.isClient {
+		if conn.state.isClient {
 			initialFlight = flight5
 		} else {
 			initialFlight = flight6
 		}
 		initialFSMState = handshakeFinished
 
-		c.state = *initialState
+		conn.state = *initialState
 	} else {
-		if c.state.isClient {
+		if conn.state.isClient {
 			initialFlight = flight1
 		} else {
 			initialFlight = flight0
@@ -222,13 +242,13 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
 		initialFSMState = handshakePreparing
 	}
 	// Do handshake
-	if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
+	if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
 		return nil, err
 	}
 
-	c.log.Trace("Handshake Completed")
+	conn.log.Trace("Handshake Completed")
 
-	return c, nil
+	return conn, nil
 }
 
 // Dial connects to the given network address and establishes a DTLS connection on top.
@@ -279,7 +299,12 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con
 		return nil, errPSKAndIdentityMustBeSetForClient
 	}
 
-	return createConn(ctx, conn, config, true, nil)
+	dconn, err := createConn(conn, config, true)
+	if err != nil {
+		return nil, err
+	}
+
+	return handshakeConn(ctx, dconn, config, true, nil)
 }
 
 // ServerWithContext listens for incoming DTLS connections.
@@ -287,8 +312,11 @@ func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con
 	if config == nil {
 		return nil, errNoConfigProvided
 	}
-
-	return createConn(ctx, conn, config, false, nil)
+	dconn, err := createConn(conn, config, false)
+	if err != nil {
+		return nil, err
+	}
+	return handshakeConn(ctx, dconn, config, false, nil)
 }
 
 // Read reads data from the connection.
@@ -662,7 +690,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 		c.log.Debugf("discarded broken packet: %v", err)
 		return false, nil, nil
 	}
-
 	// Validate epoch
 	remoteEpoch := c.state.getRemoteEpoch()
 	if h.Epoch > remoteEpoch {
@@ -673,8 +700,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 			return false, nil, nil
 		}
 		if enqueue {
-			c.log.Debug("received packet of next epoch, queuing packet")
-			c.encryptedPackets = append(c.encryptedPackets, buf)
+			if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
+				c.log.Debug("received packet of next epoch, queuing packet")
+				c.encryptedPackets = append(c.encryptedPackets, buf)
+			} else {
+				c.log.Debug("app data packet queue full, dropping packet")
+			}
 		}
 		return false, nil, nil
 	}
@@ -697,8 +728,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 	if h.Epoch != 0 {
 		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
 			if enqueue {
-				c.encryptedPackets = append(c.encryptedPackets, buf)
-				c.log.Debug("handshake not finished, queuing packet")
+				if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
+					c.encryptedPackets = append(c.encryptedPackets, buf)
+					c.log.Debug("handshake not finished, queuing packet")
+				} else {
+					c.log.Debug("app data packet queue full, dropping packet")
+				}
 			}
 			return false, nil, nil
 		}
@@ -749,8 +784,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 	case *protocol.ChangeCipherSpec:
 		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
 			if enqueue {
-				c.encryptedPackets = append(c.encryptedPackets, buf)
-				c.log.Debugf("CipherSuite not initialized, queuing packet")
+				if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
+					c.encryptedPackets = append(c.encryptedPackets, buf)
+					c.log.Debugf("CipherSuite not initialized, queuing packet")
+				} else {
+					c.log.Debug("app data packet queue full. dropping packet")
+				}
 			}
 			return false, nil, nil
 		}
diff --git a/conn_test.go b/conn_test.go
index ea3c842f7..6083a050a 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -3050,3 +3050,88 @@ func (c *connWithCallback) Write(b []byte) (int, error) {
 	}
 	return c.Conn.Write(b)
 }
+
+func TestApplicationDataQueueLimited(t *testing.T) {
+	// Limit runtime in case of deadlocks
+	lim := test.TimeOut(time.Second * 20)
+	defer lim.Stop()
+
+	// Check for leaking routines
+	report := test.CheckRoutines(t)
+	defer report()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	ca, cb := dpipe.Pipe()
+	defer ca.Close()
+	defer cb.Close()
+
+	done := make(chan struct{})
+	go func() {
+		serverCert, err := selfsign.GenerateSelfSigned()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		cfg := &Config{}
+		cfg.Certificates = []tls.Certificate{serverCert}
+
+		dconn, err := createConn(cb, cfg, false)
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		go func() {
+			for i := 0; i < 5; i++ {
+				dconn.lock.RLock()
+				qlen := len(dconn.encryptedPackets)
+				dconn.lock.RUnlock()
+				if qlen > maxAppDataPacketQueueSize {
+					t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
+				}
+				t.Log(qlen)
+				time.Sleep(1 * time.Second)
+			}
+
+		}()
+		if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
+			t.Error("expected handshake to fail")
+		}
+		close(done)
+	}()
+	extensions := []extension.Extension{}
+
+	time.Sleep(50 * time.Millisecond)
+
+	err := sendClientHello([]byte{}, ca, 0, extensions)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(50 * time.Millisecond)
+
+	for i := 0; i < 1000; i++ {
+		// Send an application data packet
+		packet, err := (&recordlayer.RecordLayer{
+			Header: recordlayer.Header{
+				Version:        protocol.Version1_2,
+				SequenceNumber: uint64(3),
+				Epoch:          1, // use an epoch greater than 0
+			},
+			Content: &protocol.ApplicationData{
+				Data: []byte{1, 2, 3, 4},
+			},
+		}).Marshal()
+		if err != nil {
+			t.Fatal(err)
+		}
+		ca.Write(packet)
+		if i%100 == 0 {
+			time.Sleep(10 * time.Millisecond)
+		}
+	}
+	time.Sleep(1 * time.Second)
+	ca.Close()
+	<-done
+}
diff --git a/resume.go b/resume.go
index c470d856b..f070d7537 100644
--- a/resume.go
+++ b/resume.go
@@ -13,7 +13,11 @@ func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) {
 	if err := state.initCipherSuite(); err != nil {
 		return nil, err
 	}
-	c, err := createConn(context.Background(), conn, config, state.isClient, state)
+	dconn, err := createConn(conn, config, state.isClient)
+	if err != nil {
+		return nil, err
+	}
+	c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
 	if err != nil {
 		return nil, err
 	}

From 793cefded6573f8bfdf9a3ebe4e5bb0e7ab9217b Mon Sep 17 00:00:00 2001
From: sukun <sukunrt@gmail.com>
Date: Tue, 30 Apr 2024 20:32:25 +0530
Subject: [PATCH 2/2] remove unused vars, factor out function

---
 conn.go | 33 +++++++++++----------------------
 1 file changed, 11 insertions(+), 22 deletions(-)

diff --git a/conn.go b/conn.go
index cf8551b40..04b4f7d57 100644
--- a/conn.go
+++ b/conn.go
@@ -160,16 +160,6 @@ func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient boo
 		workerInterval = config.FlightInterval
 	}
 
-	mtu := config.MTU
-	if mtu <= 0 {
-		mtu = defaultMTU
-	}
-
-	replayProtectionWindow := config.ReplayProtectionWindow
-	if replayProtectionWindow <= 0 {
-		replayProtectionWindow = defaultReplayProtectionWindow
-	}
-
 	serverName := config.ServerName
 	// Do not allow the use of an IP address literal as an SNI value.
 	// See RFC 6066, Section 3.
@@ -682,6 +672,14 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
 	return nil
 }
 
+func (c *Conn) enqueueEncryptedPackets(packet []byte) bool {
+	if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
+		c.encryptedPackets = append(c.encryptedPackets, packet)
+		return true
+	}
+	return false
+}
+
 func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
 	h := &recordlayer.Header{}
 	if err := h.Unmarshal(buf); err != nil {
@@ -700,11 +698,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 			return false, nil, nil
 		}
 		if enqueue {
-			if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
+			if ok := c.enqueueEncryptedPackets(buf); ok {
 				c.log.Debug("received packet of next epoch, queuing packet")
-				c.encryptedPackets = append(c.encryptedPackets, buf)
-			} else {
-				c.log.Debug("app data packet queue full, dropping packet")
 			}
 		}
 		return false, nil, nil
@@ -728,11 +723,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 	if h.Epoch != 0 {
 		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
 			if enqueue {
-				if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
-					c.encryptedPackets = append(c.encryptedPackets, buf)
+				if ok := c.enqueueEncryptedPackets(buf); ok {
 					c.log.Debug("handshake not finished, queuing packet")
-				} else {
-					c.log.Debug("app data packet queue full, dropping packet")
 				}
 			}
 			return false, nil, nil
@@ -784,11 +776,8 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
 	case *protocol.ChangeCipherSpec:
 		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
 			if enqueue {
-				if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
-					c.encryptedPackets = append(c.encryptedPackets, buf)
+				if ok := c.enqueueEncryptedPackets(buf); ok {
 					c.log.Debugf("CipherSuite not initialized, queuing packet")
-				} else {
-					c.log.Debug("app data packet queue full. dropping packet")
 				}
 			}
 			return false, nil, nil