Skip to content

Commit

Permalink
upstream: add locker test
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jan 24, 2024
1 parent 889865f commit 261dd38
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions upstream/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@ import (
"io"
"net"
"net/netip"
"net/url"
"sync"
"testing"
"time"

"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -369,3 +372,68 @@ func (q *quicConnTracer) SentLongHeaderPacket(

q.packets = append(q.packets, hdr.Header)
}

func TestDNSOverQUIC_closingConns(t *testing.T) {
addrPort := startDoQServer(t, 0)

upsURL := (&url.URL{
Scheme: "quic",
Host: addrPort.String(),
}).String()

tracer := &quicTracer{}
opts := &Options{
InsecureSkipVerify: true,
Timeout: 5 * time.Second,
QUICTracer: tracer.TracerForConnection,
}

u, err := AddressToUpstream(upsURL, opts)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
uq := testutil.RequireTypeAssert[*dnsOverQUIC](t, u)

// reqNum should be greater than the number of connections that will cause
// the race for connsMu.
const reqNum = 100

// Initialize a connection.
checkUpstream(t, u, upsURL)

errs := [reqNum]error{}
var errNum int

t.Run("resolve_concurrently", func(t *testing.T) {
// Lock the connection to make sure that we don't close it while
// resolving the upstream.
uq.connMu.Lock()

var beforeExchange, afterExchange sync.WaitGroup
beforeExchange.Add(reqNum)
afterExchange.Add(reqNum)

req := createTestMessage()
for i := 0; i < reqNum; i++ {
go func(i int) {
reqClone := req.Copy()

// Accumulate exchanging routines.
beforeExchange.Done()
_, errs[i] = u.Exchange(reqClone)

afterExchange.Done()
}(i)
}

beforeExchange.Wait()
// Let all the goroutines race for the connection.
uq.connMu.Unlock()
afterExchange.Wait()

if reqsErr := errors.Join(errs[:]...); !assert.NoError(t, reqsErr) {
errNum = len(reqsErr.(errors.WrapperSlice).Unwrap())
}
})

assert.Len(t, tracer.getConnectionsInfo(), errNum+1)
}

0 comments on commit 261dd38

Please sign in to comment.