diff --git a/p2p/peering.go b/p2p/peering.go index ad07faf6c..7e1bec982 100644 --- a/p2p/peering.go +++ b/p2p/peering.go @@ -114,9 +114,10 @@ type RemotePeer struct { // headers message management. Headers can either be fetched synchronously // or used to push block notifications with sendheaders. - requestedHeaders chan<- *wire.MsgHeaders // non-nil result chan when synchronous getheaders in process - sendheaders bool // whether a sendheaders message was sent - requestedHeadersMu sync.Mutex + requestedHeaders chan<- *wire.MsgHeaders // non-nil result chan when synchronous getheaders in process + requestedHeadersLoc []*chainhash.Hash // non-nil when requested headers with getheaders + sendheaders bool // whether a sendheaders message was sent + requestedHeadersMu sync.Mutex // init state message management. requestedInitState chan<- *wire.MsgInitState // non-nil result chan when synchronous getinitstate in process @@ -1097,7 +1098,7 @@ func (rp *RemotePeer) receivedCFilterV2(ctx context.Context, msg *wire.MsgCFilte } } -func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheaders, newRequest bool) { +func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders, loc []*chainhash.Hash) (sendheaders, newRequest bool) { rp.requestedHeadersMu.Lock() if rp.sendheaders { rp.requestedHeadersMu.Unlock() @@ -1108,6 +1109,7 @@ func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheader return false, false } rp.requestedHeaders = c + rp.requestedHeadersLoc = loc rp.requestedHeadersMu.Unlock() return false, true } @@ -1115,6 +1117,7 @@ func (rp *RemotePeer) addRequestedHeaders(c chan<- *wire.MsgHeaders) (sendheader func (rp *RemotePeer) deleteRequestedHeaders() { rp.requestedHeadersMu.Lock() rp.requestedHeaders = nil + rp.requestedHeadersLoc = nil rp.requestedHeadersMu.Unlock() } @@ -1146,6 +1149,29 @@ func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) return } + // The parent of the first header (if there is one) MUST be one of the + // block locators we used to request headers from the peer when this + // is a response to a getheaders request. + if len(msg.Headers) > 0 && rp.requestedHeadersLoc != nil { + wantParent := msg.Headers[0].PrevBlock + contains := false + for _, loc := range rp.requestedHeadersLoc { + if *loc == wantParent { + contains = true + break + } + } + if !contains { + op := errors.Opf(opf, rp.raddr) + err := errors.E(op, errors.Protocol, + "peer sent headers that do not connect "+ + "to block locators") + rp.Disconnect(err) + rp.requestedHeadersMu.Unlock() + return + } + } + // Sanity check the headers connect to each other in sequence. var prevHash chainhash.Hash var prevHeight uint32 @@ -1186,6 +1212,7 @@ func (rp *RemotePeer) receivedHeaders(ctx context.Context, msg *wire.MsgHeaders) // Headers as a response to getheaders. c := rp.requestedHeaders rp.requestedHeaders = nil + rp.requestedHeadersLoc = nil rp.requestedHeadersMu.Unlock() select { case <-ctx.Done(): @@ -1988,7 +2015,7 @@ func (rp *RemotePeer) Headers(ctx context.Context, blockLocators []*chainhash.Ha HashStop: *hashStop, } c := make(chan *wire.MsgHeaders, 1) - sendheaders, newRequest := rp.addRequestedHeaders(c) + sendheaders, newRequest := rp.addRequestedHeaders(c, blockLocators) if sendheaders { op := errors.Opf(opf, rp.raddr) return nil, errors.E(op, errors.Invalid, "synchronous getheaders after sendheaders is unsupported") @@ -2020,28 +2047,6 @@ func (rp *RemotePeer) Headers(ctx context.Context, blockLocators []*chainhash.Ha case m := <-c: stalled.Stop() - // The parent of the first header (if there is one) MUST - // be one of the block locators we used to request - // headers from the peer. - if len(m.Headers) > 0 { - wantParent := m.Headers[0].PrevBlock - contains := false - for _, loc := range blockLocators { - if *loc == wantParent { - contains = true - break - } - } - if !contains { - op := errors.Opf(opf, rp.raddr) - err := errors.E(op, errors.Protocol, - "peer sent headers that do not connect "+ - "to block locators") - rp.Disconnect(err) - return nil, err - } - } - return m.Headers, nil } }