Skip to content

Commit

Permalink
Add perIPTLSConn to support MaxConnsPerIP with tls connections
Browse files Browse the repository at this point in the history
Otherwise calling RequestCtx.TLSConnectionState() will fail.

Fixes #1770
  • Loading branch information
erikdubbelboer committed Apr 29, 2024
1 parent a8fa9c0 commit 105eb3b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
46 changes: 37 additions & 9 deletions peripconn.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
package fasthttp

import (
"crypto/tls"
"net"
"sync"
)

type perIPConnCounter struct {
pool sync.Pool
lock sync.Mutex
m map[uint32]int
perIPConnPool sync.Pool
perIPTLSConnPool sync.Pool
lock sync.Mutex
m map[uint32]int
}

func (cc *perIPConnCounter) Register(ip uint32) int {
Expand Down Expand Up @@ -43,8 +45,30 @@ type perIPConn struct {
perIPConnCounter *perIPConnCounter
}

func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
v := counter.pool.Get()
type perIPTLSConn struct {
*tls.Conn

ip uint32
perIPConnCounter *perIPConnCounter
}

func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
if tlcConn, ok := conn.(*tls.Conn); ok {
v := counter.perIPTLSConnPool.Get()
if v == nil {
return &perIPTLSConn{
perIPConnCounter: counter,
Conn: tlcConn,
ip: ip,
}
}
c := v.(*perIPConn)
c.Conn = conn
c.ip = ip
return c
}

v := counter.perIPConnPool.Get()
if v == nil {
return &perIPConn{
perIPConnCounter: counter,
Expand All @@ -58,15 +82,19 @@ func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perI
return c
}

func releasePerIPConn(c *perIPConn) {
func (c *perIPConn) Close() error {
err := c.Conn.Close()
c.perIPConnCounter.Unregister(c.ip)
c.Conn = nil
c.perIPConnCounter.pool.Put(c)
c.perIPConnCounter.perIPConnPool.Put(c)
return err
}

func (c *perIPConn) Close() error {
func (c *perIPTLSConn) Close() error {
err := c.Conn.Close()
c.perIPConnCounter.Unregister(c.ip)
releasePerIPConn(c)
c.Conn = nil
c.perIPConnCounter.perIPTLSConnPool.Put(c)
return err
}

Expand Down
2 changes: 2 additions & 0 deletions peripconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"testing"
)

var _ connTLSer = &perIPTLSConn{}

func TestIPxUint32(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 105eb3b

Please sign in to comment.