From d6fbfc62310c261fae114e568100dcc2a2228e6e Mon Sep 17 00:00:00 2001 From: Sukun Date: Wed, 12 Jul 2023 00:06:49 +0530 Subject: [PATCH 1/3] swarm: remove unnecessary reqno for pending request tracking --- p2p/net/swarm/dial_worker.go | 80 +++++++++++++----------------------- 1 file changed, 28 insertions(+), 52 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 0334ac863e..2d61fb55b8 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -48,7 +48,7 @@ type pendRequest struct { // At the time of creation addrs is initialised to all the addresses of the peer. On a failed dial, // the addr is removed from the map and err is updated. On a successful dial, the dialRequest is // completed and response is sent with the connection - addrs map[string]struct{} + addrs map[string]bool } // addrDial tracks dials to a particular multiaddress. @@ -61,9 +61,6 @@ type addrDial struct { conn *Conn // err is the err on dialing the address err error - // requests is the list of pendRequests interested in this dial - // the value in the slice is the request number assigned to this request by the dialWorker - requests []int // dialed indicates whether we have triggered the dial to the address dialed bool // createdAt is the time this struct was created @@ -79,13 +76,9 @@ type dialWorker struct { peer peer.ID // reqch is used to send dial requests to the worker. close reqch to end the worker loop reqch <-chan dialRequest - // reqno is the request number used to track different dialRequests for a peer. - // Each incoming request is assigned a reqno. This reqno is used in pendingRequests and in - // addrDial objects in trackedDials to track this request - reqno int - // pendingRequests maps reqno to the pendRequest object for a dialRequest - pendingRequests map[int]*pendRequest - // trackedDials tracks dials to the peers addresses. 
An entry here is used to ensure that + // pendingRequests is the set of dial requests that have not yet received a response + pendingRequests map[*pendRequest]bool + // trackedDials tracks dials to the peer's addresses. An entry here is used to ensure that // we dial an address at most once trackedDials map[string]*addrDial // resch is used to receive response for dials to the peers addresses. @@ -106,7 +99,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dia s: s, peer: p, reqch: reqch, - pendingRequests: make(map[int]*pendRequest), + pendingRequests: make(map[*pendRequest]bool), trackedDials: make(map[string]*addrDial), resch: make(chan dialResult), cl: cl, @@ -192,10 +185,10 @@ loop: pr := &pendRequest{ req: req, err: &DialError{Peer: w.peer}, - addrs: make(map[string]struct{}, len(addrRanking)), + addrs: make(map[string]bool, len(addrRanking)), } for _, adelay := range addrRanking { - pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} + pr.addrs[string(adelay.Addr.Bytes())] = true addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay } @@ -237,10 +230,8 @@ loop: continue loop } - // The request has some pending or new dials. We assign this request a request number. 
- // This value of w.reqno is used to track this request in all the structures - w.reqno++ - w.pendingRequests[w.reqno] = pr + // The request has some pending or new dials + w.pendingRequests[pr] = true for _, ad := range tojoin { if !ad.dialed { @@ -258,7 +249,6 @@ loop: } } // add the request to the addrDial - ad.requests = append(ad.requests, w.reqno) } if len(todial) > 0 { @@ -268,7 +258,6 @@ loop: w.trackedDials[string(a.Bytes())] = &addrDial{ addr: a, ctx: req.ctx, - requests: []int{w.reqno}, createdAt: now, } dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]}) @@ -333,20 +322,14 @@ loop: continue loop } - // request succeeded, respond to all pending requests - for _, reqno := range ad.requests { - pr, ok := w.pendingRequests[reqno] - if !ok { - // some other dial for this request succeeded before this one - continue + for pr := range w.pendingRequests { + if pr.addrs[string(ad.addr.Bytes())] { + pr.req.resch <- dialResponse{conn: conn} + delete(w.pendingRequests, pr) } - pr.req.resch <- dialResponse{conn: conn} - delete(w.pendingRequests, reqno) } ad.conn = conn - ad.requests = nil - if !w.connected { w.connected = true if w.s.metricsTracer != nil { @@ -380,33 +363,26 @@ loop: // dispatches an error to a specific addr dial func (w *dialWorker) dispatchError(ad *addrDial, err error) { ad.err = err - for _, reqno := range ad.requests { - pr, ok := w.pendingRequests[reqno] - if !ok { - // some other dial for this request succeeded before this one - continue - } - + for pr := range w.pendingRequests { // accumulate the error - pr.err.recordErr(ad.addr, err) - - delete(pr.addrs, string(ad.addr.Bytes())) - if len(pr.addrs) == 0 { - // all addrs have erred, dispatch dial error - // but first do a last one check in case an acceptable connection has landed from - // a simultaneous dial that started later and added new acceptable addrs - c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) - if c != nil { - pr.req.resch <- dialResponse{conn: c} 
- } else { - pr.req.resch <- dialResponse{err: pr.err} + if pr.addrs[string(ad.addr.Bytes())] { + pr.err.recordErr(ad.addr, err) + delete(pr.addrs, string(ad.addr.Bytes())) + if len(pr.addrs) == 0 { + // all addrs have erred, dispatch dial error + // but first do a last one check in case an acceptable connection has landed from + // a simultaneous dial that started later and added new acceptable addrs + c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) + if c != nil { + pr.req.resch <- dialResponse{conn: c} + } else { + pr.req.resch <- dialResponse{err: pr.err} + } + delete(w.pendingRequests, pr) } - delete(w.pendingRequests, reqno) } } - ad.requests = nil - // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. // this is necessary to support active listen scenarios, where a new dial comes in while // another dial is in progress, and needs to do a direct connection without inhibitions from From ccd767f6120e89e5b397ee18e6f62111c16b5ef7 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Aug 2023 16:19:50 +0530 Subject: [PATCH 2/3] swarm: return errors on filtered addresses when dialling --- p2p/net/swarm/black_hole_detector.go | 7 ++- p2p/net/swarm/black_hole_detector_test.go | 44 +++++++++++----- p2p/net/swarm/dial_worker.go | 13 +++-- p2p/net/swarm/swarm_dial.go | 61 +++++++++++++++-------- p2p/net/swarm/swarm_dial_test.go | 15 +++--- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/p2p/net/swarm/black_hole_detector.go b/p2p/net/swarm/black_hole_detector.go index 078b1126c4..dd7849eea6 100644 --- a/p2p/net/swarm/black_hole_detector.go +++ b/p2p/net/swarm/black_hole_detector.go @@ -178,7 +178,7 @@ type blackHoleDetector struct { } // FilterAddrs filters the peer's addresses removing black holed addresses -func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { +func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) { hasUDP, 
hasIPv6 := false, false for _, a := range addrs { if !manet.IsPublicAddr(a) { @@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { ipv6Res = d.ipv6.HandleRequest() } + blackHoled = make([]ma.Multiaddr, 0, len(addrs)) return ma.FilterAddrs( addrs, func(a ma.Multiaddr) bool { @@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr { } if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) { + blackHoled = append(blackHoled, a) return false } if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) { + blackHoled = append(blackHoled, a) return false } return true }, - ) + ), blackHoled } // RecordResult updates the state of the relevant `blackHoleFilter`s for addr diff --git a/p2p/net/swarm/black_hole_detector_test.go b/p2p/net/swarm/black_hole_detector_test.go index 7b10fc88a6..dfbb30f90d 100644 --- a/p2p/net/swarm/black_hole_detector_test.go +++ b/p2p/net/swarm/black_hole_detector_test.go @@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) { ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"), } for i := 0; i < 1000; i++ { - filteredAddrs := bhd.FilterAddrs(addrs) + filteredAddrs, _ := bhd.FilterAddrs(addrs) require.ElementsMatch(t, addrs, filteredAddrs) for j := 0; j < len(addrs); j++ { bhd.RecordResult(addrs[j], false) @@ -101,8 +101,12 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) { for i := 0; i < 100; i++ { bhd.RecordResult(publicAddr, false) } - addrs := []ma.Multiaddr{publicAddr, privAddr} - require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs)) + wantAddrs := []ma.Multiaddr{publicAddr, privAddr} + wantRemovedAddrs := make([]ma.Multiaddr, 0) + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) } func TestBlackHoleDetectorIPv6Disabled(t *testing.T) { @@ -110,11 +114,16 @@ func 
TestBlackHoleDetectorIPv6Disabled(t *testing.T) { bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil) publicAddr := ma.StringCast("/ip6/1::1/tcp/1234") privAddr := ma.StringCast("/ip6/::1/tcp/1234") - addrs := []ma.Multiaddr{publicAddr, privAddr} for i := 0; i < 100; i++ { bhd.RecordResult(publicAddr, false) } - require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs)) + + wantAddrs := []ma.Multiaddr{publicAddr, privAddr} + wantRemovedAddrs := make([]ma.Multiaddr, 0) + + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs) + require.ElementsMatch(t, wantAddrs, gotAddrs) + require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs) } func TestBlackHoleDetectorProbes(t *testing.T) { @@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) { bhd.RecordResult(udp6Addr, false) } for i := 1; i < 100; i++ { - filteredAddrs := bhd.FilterAddrs(addrs) + filteredAddrs, _ := bhd.FilterAddrs(addrs) if i%2 == 0 || i%3 == 0 { if len(filteredAddrs) == 0 { t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter") @@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) { func TestBlackHoleDetectorAddrFiltering(t *testing.T) { udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1") udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1") - upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1") tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1") tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1") @@ -158,7 +167,7 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) { ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"}, } for i := 0; i < 100; i++ { - bhd.RecordResult(upd4Pub, !udpBlocked) + bhd.RecordResult(udp4Pub, !udpBlocked) } for i := 0; i < 100; i++ { bhd.RecordResult(tcp6Pub, !ipv6Blocked) @@ -166,18 +175,27 @@ func TestBlackHoleDetectorAddrFiltering(t 
*testing.T) { return bhd } - allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri, + allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri} udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri} + udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub} bhd := makeBHD(true, false) - require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput) + require.ElementsMatch(t, udpBlockedOutput, gotAddrs) + require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs) - ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub} bhd = makeBHD(false, true) - require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput) + require.ElementsMatch(t, ip6BlockedOutput, gotAddrs) + require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs) bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri} + bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub} bhd = makeBHD(true, true) - require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput)) + gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput) + require.ElementsMatch(t, bothBlockedOutput, gotAddrs) + require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs) } diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 2d61fb55b8..d6e9ef77f3 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -170,9 +170,14 @@ loop: continue loop } - addrs, err := w.s.addrsForDial(req.ctx, w.peer) + addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer) if err != nil { - req.resch <- dialResponse{err: err} + req.resch <- dialResponse{ + err: &DialError{ + Peer: w.peer, + 
DialErrors: addrErrs, + Cause: err, + }} continue loop } @@ -184,7 +189,7 @@ loop: // create the pending request object pr := &pendRequest{ req: req, - err: &DialError{Peer: w.peer}, + err: &DialError{Peer: w.peer, DialErrors: addrErrs}, addrs: make(map[string]bool, len(addrRanking)), } for _, adelay := range addrRanking { @@ -226,6 +231,7 @@ loop: if len(todial) == 0 && len(tojoin) == 0 { // all request applicable addrs have been dialed, we must have errored + pr.err.Cause = ErrAllDialsFailed req.resch <- dialResponse{err: pr.err} continue loop } @@ -376,6 +382,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { if c != nil { pr.req.resch <- dialResponse{conn: c} } else { + pr.err.Cause = ErrAllDialsFailed pr.req.resch <- dialResponse{err: pr.err} } delete(w.pendingRequests, pr) diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index f2df93af2f..e17c27353b 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -280,10 +280,10 @@ func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) { w.loop() } -func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) { +func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Multiaddr, addrErrs []TransportError, err error) { peerAddrs := s.peers.Addrs(p) if len(peerAddrs) == 0 { - return nil, ErrNoAddresses + return nil, nil, ErrNoAddresses } peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs)) @@ -308,22 +308,22 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er Addrs: peerAddrsAfterTransportResolved, }) if err != nil { - return nil, err + return nil, nil, err } - goodAddrs := s.filterKnownUndialables(p, resolved) + goodAddrs = ma.Unique(resolved) + goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs) if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) } - goodAddrs = 
ma.Unique(goodAddrs) if len(goodAddrs) == 0 { - return nil, ErrNoGoodAddresses + return nil, addrErrs, ErrNoGoodAddresses } s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL) - return goodAddrs, nil + return goodAddrs, addrErrs, nil } func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) { @@ -402,11 +402,6 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, return nil } -func (s *Swarm) canDial(addr ma.Multiaddr) bool { - t := s.TransportForDialing(addr) - return t != nil && t.CanDial(addr) -} - func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() @@ -418,7 +413,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { // addresses that we know to be our own, and addresses with a better tranport // available. This is an optimization to avoid wasting time on dials that we // know are going to fail or for which we have a better alternative. -func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr { +func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) { lisAddrs, _ := s.InterfaceListenAddresses() var ourAddrs []ma.Multiaddr for _, addr := range lisAddrs { @@ -431,27 +426,49 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul }) } - // The order of these two filters is important. If we can only dial /webtransport, - // we don't want to filter /webtransport addresses out because the peer had a /quic-v1 - // address + addrErrs = make([]TransportError, 0, len(addrs)) - // filter addresses we cannot dial - addrs = ma.FilterAddrs(addrs, s.canDial) + // The order of checking for transport and filtering low priority addrs is important. 
If we + // can only dial /webtransport, we don't want to filter /webtransport addresses out because + // the peer had a /quic-v1 address + + // filter addresses with no transport + addrs = ma.FilterAddrs(addrs, func(a ma.Multiaddr) bool { + if s.TransportForDialing(a) == nil { + addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrNoTransport}) + return false + } + return true + }) // filter low priority addresses among the addresses we can dial + // We don't return an error for these addresses addrs = filterLowPriorityAddresses(addrs) // remove black holed addrs - addrs = s.bhd.FilterAddrs(addrs) + addrs, blackHoledAddrs := s.bhd.FilterAddrs(addrs) + for _, a := range blackHoledAddrs { + addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrDialRefusedBlackHole}) + } return ma.FilterAddrs(addrs, - func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) }, + func(addr ma.Multiaddr) bool { + if ma.Contains(ourAddrs, addr) { + addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialToSelf}) + return false + } + return true + }, // TODO: Consider allowing link-local addresses func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) }, func(addr ma.Multiaddr) bool { - return s.gater == nil || s.gater.InterceptAddrDial(p, addr) + if s.gater != nil && !s.gater.InterceptAddrDial(p, addr) { + addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrGaterDisallowedConnection}) + return false + } + return true }, - ) + ), addrErrs } // limitedDial will start a dial to the given peer when diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index 2f6b3f8c4d..f9d90b8c44 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "errors" "net" "sort" "testing" @@ -65,7 +66,7 @@ func TestAddrsForDial(t *testing.T) { ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour) 
ctx := context.Background() - mas, err := s.addrsForDial(ctx, otherPeer) + mas, _, err := s.addrsForDial(ctx, otherPeer) require.NoError(t, err) require.NotZero(t, len(mas)) @@ -110,7 +111,7 @@ func TestDedupAddrsForDial(t *testing.T) { ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour) ctx := context.Background() - mas, err := s.addrsForDial(ctx, otherPeer) + mas, _, err := s.addrsForDial(ctx, otherPeer) require.NoError(t, err) require.Equal(t, 1, len(mas)) @@ -183,7 +184,7 @@ func TestAddrResolution(t *testing.T) { tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() - mas, err := s.addrsForDial(tctx, p1) + mas, _, err := s.addrsForDial(tctx, p1) require.NoError(t, err) require.Len(t, mas, 1) @@ -241,7 +242,7 @@ func TestAddrResolutionRecursive(t *testing.T) { tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL) - _, err = s.addrsForDial(tctx, p1) + _, _, err = s.addrsForDial(tctx, p1) require.NoError(t, err) addrs1 := s.Peerstore().Addrs(pi1.ID) @@ -253,7 +254,7 @@ func TestAddrResolutionRecursive(t *testing.T) { require.NoError(t, err) s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL) - _, err = s.addrsForDial(tctx, p2) + _, _, err = s.addrsForDial(tctx, p2) // This never resolves to a good address require.Equal(t, ErrNoGoodAddresses, err) @@ -315,7 +316,7 @@ func TestAddrsForDialFiltering(t *testing.T) { t.Run(tc.name, func(t *testing.T) { s.Peerstore().ClearAddrs(p1) s.Peerstore().AddAddrs(p1, tc.input, peerstore.PermanentAddrTTL) - result, err := s.addrsForDial(ctx, p1) + result, _, err := s.addrsForDial(ctx, p1) require.NoError(t, err) sort.Slice(result, func(i, j int) bool { return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 }) sort.Slice(tc.output, func(i, j int) bool { return bytes.Compare(tc.output[i].Bytes(), tc.output[j].Bytes()) < 0 }) @@ -369,7 +370,7 @@ func TestBlackHoledAddrBlocked(t 
*testing.T) { if conn != nil { t.Fatalf("expected dial to be blocked") } - if err != ErrNoGoodAddresses { + if !errors.Is(err, ErrNoGoodAddresses) { t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err) } } From 12ead095c42819d22b1542b071b357b2c1de0c25 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 7 Aug 2023 16:38:44 +0530 Subject: [PATCH 3/3] swarm: move back off handling outside worker loop --- p2p/net/swarm/dial_worker.go | 21 +++-------------- p2p/net/swarm/swarm_dial.go | 34 ++++++++++++---------------- p2p/net/swarm/swarm_dial_test.go | 39 +++++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 39 deletions(-) diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index d6e9ef77f3..0bbf9c4d36 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -288,16 +288,9 @@ loop: } ad.dialed = true ad.dialRankingDelay = now.Sub(ad.createdAt) - err := w.s.dialNextAddr(ad.ctx, w.peer, ad.addr, w.resch) - if err != nil { - // Errored without attempting a dial. This happens in case of - // backoff or black hole. - w.dispatchError(ad, err) - } else { - // the dial was successful. update inflight dials - dialsInFlight++ - totalDials++ - } + w.s.limitedDial(ad.ctx, w.peer, ad.addr, w.resch) + dialsInFlight++ + totalDials++ } timerRunning = false // schedule more dials @@ -389,14 +382,6 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) { } } } - - // if it was a backoff, clear the address dial so that it doesn't inhibit new dial requests. - // this is necessary to support active listen scenarios, where a new dial comes in while - // another dial is in progress, and needs to do a direct connection without inhibitions from - // dial backoff. - if err == ErrDialBackoff { - delete(w.trackedDials, string(ad.addr.Bytes())) - } } // rankAddrs ranks addresses for dialing. 
if it's a simConnect request we diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index e17c27353b..5f4fc0a552 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -311,12 +311,10 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Mul return nil, nil, err } - goodAddrs = ma.Unique(resolved) - goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs) - if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect { - goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr) - } + forceDirect, _ := network.GetForceDirectDial(ctx) + goodAddrs = ma.Unique(resolved) + goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs, forceDirect) if len(goodAddrs) == 0 { return nil, addrErrs, ErrNoGoodAddresses } @@ -388,20 +386,6 @@ func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multia return resolved, nil } -func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, resch chan dialResult) error { - // check the dial backoff - if forceDirect, _ := network.GetForceDirectDial(ctx); !forceDirect { - if s.backf.Backoff(p, addr) { - return ErrDialBackoff - } - } - - // start the dial - s.limitedDial(ctx, p, addr, resch) - - return nil -} - func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() @@ -413,7 +397,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { // addresses that we know to be our own, and addresses with a better tranport // available. This is an optimization to avoid wasting time on dials that we // know are going to fail or for which we have a better alternative. 
-func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) { +func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr, forceDirect bool) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) { lisAddrs, _ := s.InterfaceListenAddresses() var ourAddrs []ma.Multiaddr for _, addr := range lisAddrs { @@ -468,6 +452,16 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAdd } return true }, + func(addr ma.Multiaddr) bool { + if !forceDirect && s.backf.Backoff(p, addr) { + addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialBackoff}) + return false + } + return true + }, + func(addr ma.Multiaddr) bool { + return !forceDirect || s.nonProxyAddr(addr) + }, ), addrErrs } diff --git a/p2p/net/swarm/swarm_dial_test.go b/p2p/net/swarm/swarm_dial_test.go index f9d90b8c44..acfaa40c40 100644 --- a/p2p/net/swarm/swarm_dial_test.go +++ b/p2p/net/swarm/swarm_dial_test.go @@ -370,7 +370,44 @@ func TestBlackHoledAddrBlocked(t *testing.T) { if conn != nil { t.Fatalf("expected dial to be blocked") } - if !errors.Is(err, ErrNoGoodAddresses) { + var de *DialError + if !errors.As(err, &de) { t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err) } + require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialRefusedBlackHole}) +} + +func TestBackoffAddrBlocked(t *testing.T) { + resolver, err := madns.NewResolver() + if err != nil { + t.Fatal(err) + } + s := newTestSwarmWithResolver(t, resolver) + defer s.Close() + + // all dials to the address will fail. 
RFC6666 Discard Prefix + addr := ma.StringCast("/ip6/0100::1/tcp/54321/") + p, err := test.RandPeerID() + if err != nil { + t.Error(err) + } + s.Peerstore().AddAddr(p, addr, peerstore.PermanentAddrTTL) + + // the first dial must fail so that the address is in dial backoff when the + // second dial below runs and is expected to report ErrDialBackoff + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + conn, err := s.DialPeer(ctx, p) + require.Nil(t, conn) + require.Error(t, err) + cancel() + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + conn, err = s.DialPeer(ctx, p) + require.Nil(t, conn) + var de *DialError + if !errors.As(err, &de) { + t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err) + } + require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialBackoff}) }