Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

headerdownload: handle tie breaker for forkchoice in pow networks #8616

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 78 additions & 17 deletions turbo/stages/headerdownload/header_algo_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package headerdownload_test

import (
"bytes"
"context"
"math/big"
"testing"

"github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/kv"

"github.com/ledgerwatch/erigon/core"
Expand All @@ -16,7 +18,7 @@ import (
"github.com/ledgerwatch/erigon/turbo/stages/mock"
)

func TestInserter1(t *testing.T) {
func TestSideChainInsert(t *testing.T) {
funds := big.NewInt(1000000000)
key, _ := crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
address := crypto.PubkeyToAddress(key.PublicKey)
Expand All @@ -40,24 +42,83 @@ func TestInserter1(t *testing.T) {
defer tx.Rollback()
br := m.BlockReader
hi := headerdownload.NewHeaderInserter("headers", big.NewInt(0), 0, br)
h1 := types.Header{
Number: big.NewInt(1),
Difficulty: big.NewInt(10),
ParentHash: genesis.Hash(),

// Chain with higher initial difficulty
chain1 := createTestChain(3, genesis.Hash(), 2, []byte(""))

// Smaller side chain (non-canonical)
chain2 := createTestChain(5, genesis.Hash(), 1, []byte("side1"))

// Bigger side chain (canonical)
chain3 := createTestChain(7, genesis.Hash(), 1, []byte("side2"))

// Again smaller side chain but with high difficulty (canonical)
chain4 := createTestChain(5, genesis.Hash(), 2, []byte("side3"))

// More smaller side chain with same difficulty (canonical)
chain5 := createTestChain(2, genesis.Hash(), 5, []byte("side5"))

// Bigger side chain with same difficulty (non-canonical)
chain6 := createTestChain(10, genesis.Hash(), 1, []byte("side6"))

// Same side chain (in terms of number and difficulty) but different hash
chain7 := createTestChain(2, genesis.Hash(), 5, []byte("side7"))

finalExpectedHash := chain5[len(chain5)-1].Hash()
if bytes.Compare(chain5[len(chain5)-1].Hash().Bytes(), chain7[len(chain7)-1].Hash().Bytes()) < 0 {
finalExpectedHash = chain7[len(chain7)-1].Hash()
}
h1Hash := h1.Hash()
h2 := types.Header{
Number: big.NewInt(2),
Difficulty: big.NewInt(1010),
ParentHash: h1Hash,

testCases := []struct {
name string
chain []types.Header
expectedHash common.Hash
expectedDiff int64
}{
{"normal initial insert", chain1, chain1[len(chain1)-1].Hash(), 6},
{"td(current) > td(incoming)", chain2, chain1[len(chain1)-1].Hash(), 6},
{"td(incoming) > td(current), number(incoming) > number(current)", chain3, chain3[len(chain3)-1].Hash(), 7},
{"td(incoming) > td(current), number(current) > number(incoming)", chain4, chain4[len(chain4)-1].Hash(), 10},
{"td(incoming) = td(current), number(current) > number(current)", chain5, chain5[len(chain5)-1].Hash(), 10},
{"td(incoming) = td(current), number(incoming) > number(current)", chain6, chain5[len(chain5)-1].Hash(), 10},
{"td(incoming) = td(current), number(incoming) = number(current), hash different", chain7, finalExpectedHash, 10},
}
h2Hash := h2.Hash()
data1, _ := rlp.EncodeToBytes(&h1)
if _, err = hi.FeedHeaderPoW(tx, br, &h1, data1, h1Hash, 1); err != nil {
t.Errorf("feed empty header 1: %v", err)

for _, tc := range testCases {
tc := tc
for i, h := range tc.chain {
h := h
data, _ := rlp.EncodeToBytes(&h)
if _, err = hi.FeedHeaderPoW(tx, br, &h, data, h.Hash(), uint64(i+1)); err != nil {
t.Errorf("feed empty header for %s, err: %v", tc.name, err)
}
}

if hi.GetHighestHash() != tc.expectedHash {
t.Errorf("incorrect highest hash for %s, expected %s, got %s", tc.name, tc.expectedHash, hi.GetHighestHash())
}
if hi.GetLocalTd().Int64() != tc.expectedDiff {
t.Errorf("incorrect difficulty for %s, expected %d, got %d", tc.name, tc.expectedDiff, hi.GetLocalTd().Int64())
}
}
data2, _ := rlp.EncodeToBytes(&h2)
if _, err = hi.FeedHeaderPoW(tx, br, &h2, data2, h2Hash, 2); err != nil {
t.Errorf("feed empty header 2: %v", err)
}

func createTestChain(length int64, parent common.Hash, diff int64, extra []byte) []types.Header {
var (
i int64
headers []types.Header
)

for i = 0; i < length; i++ {
h := types.Header{
Number: big.NewInt(i + 1),
Difficulty: big.NewInt(diff),
ParentHash: parent,
Extra: extra,
}
headers = append(headers, h)
parent = h.Hash()
}

return headers
}
53 changes: 37 additions & 16 deletions turbo/stages/headerdownload/header_algos.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/ledgerwatch/erigon-lib/kv/dbutils"
"io"
"math/big"
"sort"
"strings"
"time"

"github.com/ledgerwatch/erigon-lib/kv/dbutils"

libcommon "github.com/ledgerwatch/erigon-lib/common"
"github.com/ledgerwatch/erigon-lib/etl"
"github.com/ledgerwatch/erigon-lib/kv"
Expand Down Expand Up @@ -892,24 +893,40 @@ func (hi *HeaderInserter) FeedHeaderPoW(db kv.StatelessRwTx, headerReader servic
}
// Calculate total difficulty of this header using parent's total difficulty
td = new(big.Int).Add(parentTd, header.Difficulty)

// Now we can decide wether this header will create a change in the canonical head
if td.Cmp(hi.localTd) > 0 {
hi.newCanonical = true
forkingPoint, err := hi.ForkingPoint(db, header, parent)
if err != nil {
return nil, err
if td.Cmp(hi.localTd) >= 0 {
reorg := true

// TODO: Add bor check here if required
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something which I am not sure of. If no other chain is using the PoW flow, I think we don't need to wrap this with isBor check. If not, we might have to change the function signature for checking if we're running bor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PoW flow is only used for Polygon.

// Borrowed from https://github.com/maticnetwork/bor/blob/master/core/forkchoice.go#L81
if td.Cmp(hi.localTd) == 0 {
if blockHeight > hi.highest {
reorg = false
} else if blockHeight == hi.highest {
// Compare hashes of block in case of tie breaker. Lexicographically larger hash wins.
reorg = bytes.Compare(hi.highestHash.Bytes(), hash.Bytes()) < 0
}
}
hi.highest = blockHeight
hi.highestHash = hash
hi.highestTimestamp = header.Time
hi.canonicalCache.Add(blockHeight, hash)
// See if the forking point affects the unwindPoint (the block number to which other stages will need to unwind before the new canonical chain is applied)
if forkingPoint < hi.unwindPoint {
hi.unwindPoint = forkingPoint
hi.unwind = true

if reorg {
hi.newCanonical = true
forkingPoint, err := hi.ForkingPoint(db, header, parent)
if err != nil {
return nil, err
}
hi.highest = blockHeight
hi.highestHash = hash
hi.highestTimestamp = header.Time
hi.canonicalCache.Add(blockHeight, hash)
// See if the forking point affects the unwindPoint (the block number to which other stages will need to unwind before the new canonical chain is applied)
if forkingPoint < hi.unwindPoint {
hi.unwindPoint = forkingPoint
hi.unwind = true
}
// This makes sure we end up choosing the chain with the max total difficulty
hi.localTd.Set(td)
}
// This makes sure we end up choosing the chain with the max total difficulty
hi.localTd.Set(td)
}
if err = rawdb.WriteTd(db, hash, blockHeight, td); err != nil {
return nil, fmt.Errorf("[%s] failed to WriteTd: %w", hi.logPrefix, err)
Expand Down Expand Up @@ -946,6 +963,10 @@ func (hi *HeaderInserter) FeedHeaderPoS(db kv.RwTx, header *types.Header, hash l
return nil
}

func (hi *HeaderInserter) GetLocalTd() *big.Int {
return hi.localTd
}

func (hi *HeaderInserter) GetHighest() uint64 {
return hi.highest
}
Expand Down
Loading