-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7eb64fd
commit e3b5131
Showing
3 changed files
with
162 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package upstream | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
|
||
"github.com/AdguardTeam/golibs/errors" | ||
"github.com/quic-go/quic-go" | ||
) | ||
|
||
// quicConnResult is used to store the result of a single connection | ||
// establishment. | ||
type quicConnResult struct { | ||
conn quic.Connection | ||
err error | ||
} | ||
|
||
// quicConnector is used to establish a single connection on several demands. | ||
type quicConnector struct { | ||
value atomic.Pointer[quicConnResult] | ||
once atomic.Pointer[sync.Once] | ||
open func() (conn quic.Connection, err error) | ||
} | ||
|
||
// newQUICConnector creates a new quicConnector. | ||
func newQUICConnector(open func() (quic.Connection, error)) (sf *quicConnector) { | ||
sf = &quicConnector{ | ||
value: atomic.Pointer[quicConnResult]{}, | ||
once: atomic.Pointer[sync.Once]{}, | ||
open: open, | ||
} | ||
sf.value.Store(&quicConnResult{ | ||
conn: nil, | ||
err: errors.Error("not initialized"), | ||
}) | ||
sf.once.Store(&sync.Once{}) | ||
|
||
return sf | ||
} | ||
|
||
// reset enforces the next call to get to re-establish the connection. | ||
func (sf *quicConnector) reset() { | ||
sf.once.Store(&sync.Once{}) | ||
} | ||
|
||
// get returns the connection. If the connection is not established yet, it | ||
// will be established. If the connection establishment fails, the next call | ||
// to get will try to establish the connection again. | ||
func (sf *quicConnector) get() (c quic.Connection, err error) { | ||
sf.once.Load().Do(sf.do) | ||
res := sf.value.Load() | ||
|
||
return res.conn, res.err | ||
} | ||
|
||
// do actually opens the connection and stores the result. It also check the | ||
// error and resets the connector if the connection establishment failed. | ||
func (sf *quicConnector) do() { | ||
res := &quicConnResult{} | ||
res.conn, res.err = sf.open() | ||
|
||
sf.value.Store(res) | ||
if res.err != nil { | ||
sf.reset() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package upstream | ||
|
||
import ( | ||
"sync" | ||
"sync/atomic" | ||
"testing" | ||
|
||
"github.com/AdguardTeam/golibs/testutil" | ||
"github.com/quic-go/quic-go" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestShortFlighter(t *testing.T) { | ||
const ( | ||
routineNum = 100 | ||
triesNum = 4 | ||
) | ||
|
||
type testConn struct { | ||
quic.Connection | ||
} | ||
|
||
var connTriesNum atomic.Int32 | ||
var beforeGet, afterGet sync.WaitGroup | ||
pt := testutil.PanicT{} | ||
|
||
t.Run("success", func(t *testing.T) { | ||
t.Cleanup(func() { connTriesNum.Store(0) }) | ||
|
||
emptyConn := &testConn{} | ||
|
||
open := func() (conn quic.Connection, err error) { | ||
beforeGet.Wait() | ||
|
||
connTriesNum.Add(1) | ||
|
||
return emptyConn, nil | ||
} | ||
|
||
sf := newQUICConnector(open) | ||
|
||
for i := 0; i < triesNum; i++ { | ||
sf.reset() | ||
beforeGet.Add(routineNum) | ||
afterGet.Add(routineNum) | ||
|
||
for j := 0; j < routineNum; j++ { | ||
go func() { | ||
beforeGet.Done() | ||
conn, err := sf.get() | ||
afterGet.Done() | ||
|
||
require.NoError(pt, err) | ||
require.Same(pt, emptyConn, conn) | ||
}() | ||
} | ||
afterGet.Wait() | ||
|
||
assert.Equal(t, int32(i+1), connTriesNum.Load()) | ||
} | ||
}) | ||
|
||
t.Run("error", func(t *testing.T) { | ||
t.Cleanup(func() { connTriesNum.Store(0) }) | ||
|
||
open := func() (conn quic.Connection, err error) { | ||
beforeGet.Wait() | ||
|
||
connTriesNum.Add(1) | ||
|
||
return nil, assert.AnError | ||
} | ||
|
||
sf := newQUICConnector(open) | ||
|
||
for i := 0; i < triesNum; i++ { | ||
beforeGet.Add(routineNum) | ||
afterGet.Add(routineNum) | ||
|
||
for j := 0; j < routineNum; j++ { | ||
go func() { | ||
beforeGet.Done() | ||
conn, err := sf.get() | ||
afterGet.Done() | ||
|
||
require.Nil(pt, conn) | ||
require.Same(pt, assert.AnError, err) | ||
}() | ||
} | ||
afterGet.Wait() | ||
|
||
assert.Equal(t, int32(i+1), connTriesNum.Load()) | ||
} | ||
}) | ||
} |