Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

swarm: move back off handling outside worker loop #2414

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions p2p/net/swarm/black_hole_detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ type blackHoleDetector struct {
}

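// FilterAddrs filters the peer's addresses, removing black holed addresses.
// It returns the addresses that passed the filter and the black holed addresses that were removed.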
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) {
hasUDP, hasIPv6 := false, false
for _, a := range addrs {
if !manet.IsPublicAddr(a) {
Expand All @@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
ipv6Res = d.ipv6.HandleRequest()
}

blackHoled = make([]ma.Multiaddr, 0, len(addrs))
return ma.FilterAddrs(
addrs,
func(a ma.Multiaddr) bool {
Expand All @@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
}

if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) {
blackHoled = append(blackHoled, a)
return false
}
if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) {
blackHoled = append(blackHoled, a)
return false
}
return true
},
)
), blackHoled
}

// RecordResult updates the state of the relevant `blackHoleFilter`s for addr
Expand Down
44 changes: 31 additions & 13 deletions p2p/net/swarm/black_hole_detector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) {
ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"),
}
for i := 0; i < 1000; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
require.ElementsMatch(t, addrs, filteredAddrs)
for j := 0; j < len(addrs); j++ {
bhd.RecordResult(addrs[j], false)
Expand All @@ -101,20 +101,29 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) {
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
addrs := []ma.Multiaddr{publicAddr, privAddr}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))
wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorIPv6Disabled(t *testing.T) {
udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5}
bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil)
publicAddr := ma.StringCast("/ip6/1::1/tcp/1234")
privAddr := ma.StringCast("/ip6/::1/tcp/1234")
addrs := []ma.Multiaddr{publicAddr, privAddr}
for i := 0; i < 100; i++ {
bhd.RecordResult(publicAddr, false)
}
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))

wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
wantRemovedAddrs := make([]ma.Multiaddr, 0)

gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
require.ElementsMatch(t, wantAddrs, gotAddrs)
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
}

func TestBlackHoleDetectorProbes(t *testing.T) {
Expand All @@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
bhd.RecordResult(udp6Addr, false)
}
for i := 1; i < 100; i++ {
filteredAddrs := bhd.FilterAddrs(addrs)
filteredAddrs, _ := bhd.FilterAddrs(addrs)
if i%2 == 0 || i%3 == 0 {
if len(filteredAddrs) == 0 {
t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter")
Expand All @@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1")
udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1")
upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1")
tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1")
tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1")
Expand All @@ -158,26 +167,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"},
}
for i := 0; i < 100; i++ {
bhd.RecordResult(upd4Pub, !udpBlocked)
bhd.RecordResult(udp4Pub, !udpBlocked)
}
for i := 0; i < 100; i++ {
bhd.RecordResult(tcp6Pub, !ipv6Blocked)
}
return bhd
}

allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri,
allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri,
tcp4Pub, tcp4Pri}

udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri}
udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub}
bhd := makeBHD(true, false)
require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput)
require.ElementsMatch(t, udpBlockedOutput, gotAddrs)
require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs)

ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub}
bhd = makeBHD(false, true)
require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, ip6BlockedOutput, gotAddrs)
require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs)

bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub}
bhd = makeBHD(true, true)
require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput))
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
require.ElementsMatch(t, bothBlockedOutput, gotAddrs)
require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs)
}
114 changes: 41 additions & 73 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type pendRequest struct {
// At the time of creation addrs is initialised to all the addresses of the peer. On a failed dial,
// the addr is removed from the map and err is updated. On a successful dial, the dialRequest is
// completed and response is sent with the connection
addrs map[string]struct{}
addrs map[string]bool
}

// addrDial tracks dials to a particular multiaddress.
Expand All @@ -61,9 +61,6 @@ type addrDial struct {
conn *Conn
// err is the err on dialing the address
err error
// requests is the list of pendRequests interested in this dial
// the value in the slice is the request number assigned to this request by the dialWorker
requests []int
// dialed indicates whether we have triggered the dial to the address
dialed bool
// createdAt is the time this struct was created
Expand All @@ -79,13 +76,9 @@ type dialWorker struct {
peer peer.ID
// reqch is used to send dial requests to the worker. close reqch to end the worker loop
reqch <-chan dialRequest
// reqno is the request number used to track different dialRequests for a peer.
// Each incoming request is assigned a reqno. This reqno is used in pendingRequests and in
// addrDial objects in trackedDials to track this request
reqno int
// pendingRequests maps reqno to the pendRequest object for a dialRequest
pendingRequests map[int]*pendRequest
// trackedDials tracks dials to the peers addresses. An entry here is used to ensure that
// pendingRequests is the set of pending requests that have not yet been sent a response
pendingRequests map[*pendRequest]bool
// trackedDials tracks dials to the peer's addresses. An entry here is used to ensure that
// we dial an address at most once
trackedDials map[string]*addrDial
// resch is used to receive responses for dials to the peer's addresses.
Expand All @@ -106,7 +99,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dia
s: s,
peer: p,
reqch: reqch,
pendingRequests: make(map[int]*pendRequest),
pendingRequests: make(map[*pendRequest]bool),
trackedDials: make(map[string]*addrDial),
resch: make(chan dialResult),
cl: cl,
Expand Down Expand Up @@ -177,9 +170,14 @@ loop:
continue loop
}

addrs, err := w.s.addrsForDial(req.ctx, w.peer)
addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
if err != nil {
req.resch <- dialResponse{err: err}
req.resch <- dialResponse{
err: &DialError{
Peer: w.peer,
DialErrors: addrErrs,
Cause: err,
}}
continue loop
}

Expand All @@ -191,11 +189,11 @@ loop:
// create the pending request object
pr := &pendRequest{
req: req,
err: &DialError{Peer: w.peer},
addrs: make(map[string]struct{}, len(addrRanking)),
err: &DialError{Peer: w.peer, DialErrors: addrErrs},
addrs: make(map[string]bool, len(addrRanking)),
}
for _, adelay := range addrRanking {
pr.addrs[string(adelay.Addr.Bytes())] = struct{}{}
pr.addrs[string(adelay.Addr.Bytes())] = true
addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay
}

Expand Down Expand Up @@ -233,14 +231,13 @@ loop:

if len(todial) == 0 && len(tojoin) == 0 {
// all request applicable addrs have been dialed, we must have errored
pr.err.Cause = ErrAllDialsFailed
req.resch <- dialResponse{err: pr.err}
continue loop
}

// The request has some pending or new dials. We assign this request a request number.
// This value of w.reqno is used to track this request in all the structures
w.reqno++
w.pendingRequests[w.reqno] = pr
// The request has some pending or new dials
w.pendingRequests[pr] = true

for _, ad := range tojoin {
if !ad.dialed {
Expand All @@ -258,7 +255,6 @@ loop:
}
}
// add the request to the addrDial
ad.requests = append(ad.requests, w.reqno)
}

if len(todial) > 0 {
Expand All @@ -268,7 +264,6 @@ loop:
w.trackedDials[string(a.Bytes())] = &addrDial{
addr: a,
ctx: req.ctx,
requests: []int{w.reqno},
createdAt: now,
}
dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]})
Expand All @@ -293,16 +288,9 @@ loop:
}
ad.dialed = true
ad.dialRankingDelay = now.Sub(ad.createdAt)
err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch)
if err != nil {
// Errored without attempting a dial. This happens in case of
// backoff or black hole.
w.dispatchError(ad, err)
} else {
// the dial was successful. update inflight dials
dialsInFlight++
totalDials++
}
w.s.limitedDial(ad.ctx, w.peer, ad.addr, w.resch)
dialsInFlight++
totalDials++
}
timerRunning = false
// schedule more dials
Expand Down Expand Up @@ -333,20 +321,14 @@ loop:
continue loop
}

// request succeeded, respond to all pending requests
for _, reqno := range ad.requests {
pr, ok := w.pendingRequests[reqno]
if !ok {
// some other dial for this request succeeded before this one
continue
for pr := range w.pendingRequests {
if pr.addrs[string(ad.addr.Bytes())] {
pr.req.resch <- dialResponse{conn: conn}
delete(w.pendingRequests, pr)
}
pr.req.resch <- dialResponse{conn: conn}
delete(w.pendingRequests, reqno)
}

ad.conn = conn
ad.requests = nil

if !w.connected {
w.connected = true
if w.s.metricsTracer != nil {
Expand Down Expand Up @@ -380,40 +362,26 @@ loop:
// dispatches an error to a specific addr dial
func (w *dialWorker) dispatchError(ad *addrDial, err error) {
ad.err = err
for _, reqno := range ad.requests {
pr, ok := w.pendingRequests[reqno]
if !ok {
// some other dial for this request succeeded before this one
continue
}

for pr := range w.pendingRequests {
// accumulate the error
pr.err.recordErr(ad.addr, err)

delete(pr.addrs, string(ad.addr.Bytes()))
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
// but first do a last one check in case an acceptable connection has landed from
// a simultaneous dial that started later and added new acceptable addrs
c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
pr.req.resch <- dialResponse{err: pr.err}
if pr.addrs[string(ad.addr.Bytes())] {
pr.err.recordErr(ad.addr, err)
delete(pr.addrs, string(ad.addr.Bytes()))
if len(pr.addrs) == 0 {
// all addrs have erred, dispatch dial error
// but first do one last check in case an acceptable connection has landed from
// a simultaneous dial that started later and added new acceptable addrs
c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer)
if c != nil {
pr.req.resch <- dialResponse{conn: c}
} else {
pr.err.Cause = ErrAllDialsFailed
pr.req.resch <- dialResponse{err: pr.err}
}
delete(w.pendingRequests, pr)
}
delete(w.pendingRequests, reqno)
}
}

ad.requests = nil

// if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests.
// this is necessary to support active listen scenarios, where a new dial comes in while
// another dial is in progress, and needs to do a direct connection without inhibitions from
// dial backoff.
if err == ErrDialBackoff {
delete(w.trackedDials, string(ad.addr.Bytes()))
}
}

// rankAddrs ranks addresses for dialing. if it's a simConnect request we
Expand Down
Loading