diff --git a/p2p/net/swarm/dial_worker.go b/p2p/net/swarm/dial_worker.go index 006959785a..df3a0964ec 100644 --- a/p2p/net/swarm/dial_worker.go +++ b/p2p/net/swarm/dial_worker.go @@ -48,7 +48,7 @@ type pendRequest struct { // At the time of creation addrs is initialised to all the addresses of the peer. On a failed dial, // the addr is removed from the map and err is updated. On a successful dial, the dialRequest is // completed and response is sent with the connection - addrs map[string]struct{} + addrs map[string]bool } // addrDial tracks dials to a particular multiaddress. @@ -61,9 +61,6 @@ type addrDial struct { conn *Conn // err is the err on dialing the address err error - // requests is the list of pendRequests interested in this dial - // the value in the slice is the request number assigned to this request by the dialWorker - requests []int // dialed indicates whether we have triggered the dial to the address dialed bool // createdAt is the time this struct was created @@ -79,13 +76,9 @@ type dialWorker struct { peer peer.ID // reqch is used to send dial requests to the worker. close reqch to end the worker loop reqch <-chan dialRequest - // reqno is the request number used to track different dialRequests for a peer. - // Each incoming request is assigned a reqno. This reqno is used in pendingRequests and in - // addrDial objects in trackedDials to track this request - reqno int - // pendingRequests maps reqno to the pendRequest object for a dialRequest - pendingRequests map[int]*pendRequest - // trackedDials tracks dials to the peers addresses. An entry here is used to ensure that + // pendingRequests is the set of pendingRequests + pendingRequests map[*pendRequest]bool + // trackedDials tracks dials to the peer's addresses. An entry here is used to ensure that // we dial an address at most once trackedDials map[string]*addrDial // resch is used to receive response for dials to the peers addresses. @@ -106,7 +99,7 @@ func newDialWorker(s *Swarm, p peer.ID, reqch <-chan dialRequest, cl Clock) *dia s: s, peer: p, reqch: reqch, - pendingRequests: make(map[int]*pendRequest), + pendingRequests: make(map[*pendRequest]bool), trackedDials: make(map[string]*addrDial), resch: make(chan dialResult), cl: cl, @@ -192,10 +185,10 @@ loop: pr := &pendRequest{ req: req, err: &DialError{Peer: w.peer}, - addrs: make(map[string]struct{}, len(addrRanking)), + addrs: make(map[string]bool, len(addrRanking)), } for _, adelay := range addrRanking { - pr.addrs[string(adelay.Addr.Bytes())] = struct{}{} + pr.addrs[string(adelay.Addr.Bytes())] = true addrDelay[string(adelay.Addr.Bytes())] = adelay.Delay } @@ -237,10 +230,8 @@ loop: continue loop } - // The request has some pending or new dials. We assign this request a request number. - // This value of w.reqno is used to track this request in all the structures - w.reqno++ - w.pendingRequests[w.reqno] = pr + // The request has some pending or new dials + w.pendingRequests[pr] = true for _, ad := range tojoin { if !ad.dialed { @@ -258,7 +249,6 @@ loop: } } // add the request to the addrDial - ad.requests = append(ad.requests, w.reqno) } if len(todial) > 0 { @@ -268,7 +258,6 @@ loop: w.trackedDials[string(a.Bytes())] = &addrDial{ addr: a, ctx: req.ctx, - requests: []int{w.reqno}, createdAt: now, } dq.Add(network.AddrDelay{Addr: a, Delay: addrDelay[string(a.Bytes())]}) @@ -326,20 +315,14 @@ loop: continue loop } - // request succeeded, respond to all pending requests - for _, reqno := range ad.requests { - pr, ok := w.pendingRequests[reqno] - if !ok { - // some other dial for this request succeeded before this one - continue + for pr := range w.pendingRequests { + if pr.addrs[string(ad.addr.Bytes())] { + pr.req.resch <- dialResponse{conn: conn} + delete(w.pendingRequests, pr) } - pr.req.resch <- dialResponse{conn: conn} - delete(w.pendingRequests, reqno) } ad.conn = conn - ad.requests = nil - if !w.connected { w.connected = true if w.s.metricsTracer != nil { @@ -367,32 +350,25 @@ loop: // dispatches an error to a specific addr dial func (w *dialWorker) dispatchError(ad *addrDial, err error) { ad.err = err - for _, reqno := range ad.requests { - pr, ok := w.pendingRequests[reqno] - if !ok { - // some other dial for this request succeeded before this one - continue - } - + for pr := range w.pendingRequests { // accumulate the error - pr.err.recordErr(ad.addr, err) - - delete(pr.addrs, string(ad.addr.Bytes())) - if len(pr.addrs) == 0 { - // all addrs have erred, dispatch dial error - // but first do a last one check in case an acceptable connection has landed from - // a simultaneous dial that started later and added new acceptable addrs - c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) - if c != nil { - pr.req.resch <- dialResponse{conn: c} - } else { - pr.req.resch <- dialResponse{err: pr.err} + if pr.addrs[string(ad.addr.Bytes())] { + pr.err.recordErr(ad.addr, err) + delete(pr.addrs, string(ad.addr.Bytes())) + if len(pr.addrs) == 0 { + // all addrs have erred, dispatch dial error + // but first do a last one check in case an acceptable connection has landed from + // a simultaneous dial that started later and added new acceptable addrs + c, _ := w.s.bestAcceptableConnToPeer(pr.req.ctx, w.peer) + if c != nil { + pr.req.resch <- dialResponse{conn: c} + } else { + pr.req.resch <- dialResponse{err: pr.err} + } + delete(w.pendingRequests, pr) } - delete(w.pendingRequests, reqno) } } - - ad.requests = nil } // rankAddrs ranks addresses for dialing. if it's a simConnect request we