Skip to content

Commit

Permalink
upstream: add quic connector
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jan 26, 2024
1 parent 7eb64fd commit e3b5131
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 14 deletions.
14 changes: 0 additions & 14 deletions upstream/quic_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package upstream

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
Expand All @@ -17,7 +16,6 @@ import (

"github.com/AdguardTeam/dnsproxy/proxyutil"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
Expand Down Expand Up @@ -379,16 +377,6 @@ func (q *quicConnTracer) SentLongHeaderPacket(
}

func TestDNSOverQUIC_closingConns(t *testing.T) {
// TODO(e.burkov): !! get rid of this
oldLevel, logLevel := log.GetLevel(), log.DEBUG
oldWriter, logWriter := log.Writer(), &bytes.Buffer{}
log.SetLevel(logLevel)
log.SetOutput(logWriter)
t.Cleanup(func() {
log.SetLevel(oldLevel)
log.SetOutput(oldWriter)
})

addrPort := startDoQServer(t, 0)

upsURL := (&url.URL{
Expand Down Expand Up @@ -454,6 +442,4 @@ func TestDNSOverQUIC_closingConns(t *testing.T) {
t.Logf("got %d errors", len(wrapperSlice.Unwrap()))
}
})

t.Logf("logged during test: %s", logWriter)
}
66 changes: 66 additions & 0 deletions upstream/quicconnector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package upstream

import (
"sync"
"sync/atomic"

"github.com/AdguardTeam/golibs/errors"
"github.com/quic-go/quic-go"
)

// quicConnResult is used to store the result of a single connection
// establishment.
type quicConnResult struct {
conn quic.Connection
err error
}

// quicConnector is used to establish a single connection on several demands.
type quicConnector struct {
value atomic.Pointer[quicConnResult]
once atomic.Pointer[sync.Once]
open func() (conn quic.Connection, err error)
}

// newQUICConnector creates a new quicConnector.
func newQUICConnector(open func() (quic.Connection, error)) (sf *quicConnector) {
sf = &quicConnector{
value: atomic.Pointer[quicConnResult]{},
once: atomic.Pointer[sync.Once]{},
open: open,
}
sf.value.Store(&quicConnResult{
conn: nil,
err: errors.Error("not initialized"),
})
sf.once.Store(&sync.Once{})

return sf
}

// reset enforces the next call to get to re-establish the connection.
func (sf *quicConnector) reset() {
sf.once.Store(&sync.Once{})
}

// get returns the connection. If the connection is not established yet, it
// will be established. If the connection establishment fails, the next call
// to get will try to establish the connection again.
func (sf *quicConnector) get() (c quic.Connection, err error) {
sf.once.Load().Do(sf.do)
res := sf.value.Load()

return res.conn, res.err
}

// do actually opens the connection and stores the result. It also check the
// error and resets the connector if the connection establishment failed.
func (sf *quicConnector) do() {
res := &quicConnResult{}
res.conn, res.err = sf.open()

sf.value.Store(res)
if res.err != nil {
sf.reset()
}
}
96 changes: 96 additions & 0 deletions upstream/quicconnector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package upstream

import (
"sync"
"sync/atomic"
"testing"

"github.com/AdguardTeam/golibs/testutil"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestShortFlighter(t *testing.T) {
const (
routineNum = 100
triesNum = 4
)

type testConn struct {
quic.Connection
}

var connTriesNum atomic.Int32
var beforeGet, afterGet sync.WaitGroup
pt := testutil.PanicT{}

t.Run("success", func(t *testing.T) {
t.Cleanup(func() { connTriesNum.Store(0) })

emptyConn := &testConn{}

open := func() (conn quic.Connection, err error) {
beforeGet.Wait()

connTriesNum.Add(1)

return emptyConn, nil
}

sf := newQUICConnector(open)

for i := 0; i < triesNum; i++ {
sf.reset()
beforeGet.Add(routineNum)
afterGet.Add(routineNum)

for j := 0; j < routineNum; j++ {
go func() {
beforeGet.Done()
conn, err := sf.get()
afterGet.Done()

require.NoError(pt, err)
require.Same(pt, emptyConn, conn)
}()
}
afterGet.Wait()

assert.Equal(t, int32(i+1), connTriesNum.Load())
}
})

t.Run("error", func(t *testing.T) {
t.Cleanup(func() { connTriesNum.Store(0) })

open := func() (conn quic.Connection, err error) {
beforeGet.Wait()

connTriesNum.Add(1)

return nil, assert.AnError
}

sf := newQUICConnector(open)

for i := 0; i < triesNum; i++ {
beforeGet.Add(routineNum)
afterGet.Add(routineNum)

for j := 0; j < routineNum; j++ {
go func() {
beforeGet.Done()
conn, err := sf.get()
afterGet.Done()

require.Nil(pt, conn)
require.Same(pt, assert.AnError, err)
}()
}
afterGet.Wait()

assert.Equal(t, int32(i+1), connTriesNum.Load())
}
})
}

0 comments on commit e3b5131

Please sign in to comment.