From 16fe6a32a242dc36d21819b6263013a2556e22c6 Mon Sep 17 00:00:00 2001
From: Marius van der Wijden <m.vanderwijden@live.de>
Date: Tue, 28 Jun 2022 11:30:32 +0200
Subject: [PATCH 1/5] consensus/beacon: check that only the latest pow block is
 valid ttd block

---
 consensus/beacon/consensus.go | 46 +++++++++++++++++++++++++++++++++++
 consensus/errors.go           |  4 +++
 2 files changed, 50 insertions(+)

diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go
index 1fd7deb872fb..fe5e4807cee0 100644
--- a/consensus/beacon/consensus.go
+++ b/consensus/beacon/consensus.go
@@ -112,16 +112,29 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 			break
 		}
 	}
+
 	// All the headers have passed the transition point, use new rules.
 	if len(preHeaders) == 0 {
 		return beacon.verifyHeaders(chain, headers, nil)
 	}
+
 	// The transition point exists in the middle, separate the headers
 	// into two batches and apply different verification rules for them.
 	var (
 		abort   = make(chan struct{})
 		results = make(chan error, len(headers))
 	)
+	// Verify that the last preHeader (and only the last one) satisfies the
+	// terminal total difficulty.
+	if err := beacon.verifyTerminalPoWBlock(chain, preHeaders); err != nil {
+		go func() {
+			select {
+			case results <- err:
+			case <-abort:
+			}
+		}()
+		return abort, results
+	}
 	go func() {
 		var (
 			old, new, out      = 0, len(preHeaders), 0
@@ -154,6 +167,39 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 	return abort, results
 }
 
+// verifyTerminalPoWBlock verifies that the preHeaders confirm to the specification
+// wrt. their total difficulty.
+// It expects:
+// - preHeaders to be at least 1 element
+// - the parent of the header element to be stored in the chain correctly
+// - the preHeaders to have a set difficulty
+// - the last element to be the terminal block
+func (beacon *Beacon) verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) error {
+	var (
+		first = preHeaders[0]
+		last  = preHeaders[len(preHeaders)-1]
+	)
+
+	td := chain.GetTd(first.ParentHash, first.Number.Uint64()-1)
+	if td == nil {
+		return consensus.ErrUnknownAncestor
+	}
+	if len(preHeaders) != 1 {
+		for _, head := range preHeaders[:len(preHeaders)-2] {
+			td.Add(td, head.Difficulty)
+		}
+	}
+	// Check if the parent was already the terminal block
+	if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
+		return consensus.ErrInvalidTerminalBlock
+	}
+	// Check that the last block is the terminal block
+	if td.Add(td, last.Difficulty).Cmp(chain.Config().TerminalTotalDifficulty) < 0 {
+		return consensus.ErrInvalidTerminalBlock
+	}
+	return nil
+}
+
 // VerifyUncles verifies that the given block's uncles conform to the consensus
 // rules of the Ethereum consensus engine.
 func (beacon *Beacon) VerifyUncles(chain consensus.ChainReader, block *types.Block) error {
diff --git a/consensus/errors.go b/consensus/errors.go
index ac5242fb54c5..d508b6580f55 100644
--- a/consensus/errors.go
+++ b/consensus/errors.go
@@ -34,4 +34,8 @@ var (
 	// ErrInvalidNumber is returned if a block's number doesn't equal its parent's
 	// plus one.
 	ErrInvalidNumber = errors.New("invalid block number")
+
+	// ErrInvalidTerminalBlock is returned if a block is invalid wrt. the terminal
+	// total difficulty.
+	ErrInvalidTerminalBlock = errors.New("invalid terminal block")
 )

From 235ae91e9f19e6299dd5cc3a6252d823af6fef62 Mon Sep 17 00:00:00 2001
From: Marius van der Wijden <m.vanderwijden@live.de>
Date: Tue, 28 Jun 2022 13:08:37 +0200
Subject: [PATCH 2/5] consensus/beacon: move verification to async function

---
 consensus/beacon/consensus.go | 38 +++++++++++++++++------------------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go
index fe5e4807cee0..e772396b2e74 100644
--- a/consensus/beacon/consensus.go
+++ b/consensus/beacon/consensus.go
@@ -124,17 +124,6 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 		abort   = make(chan struct{})
 		results = make(chan error, len(headers))
 	)
-	// Verify that the last preHeader (and only the last one) satisfies the
-	// terminal total difficulty.
-	if err := beacon.verifyTerminalPoWBlock(chain, preHeaders); err != nil {
-		go func() {
-			select {
-			case results <- err:
-			case <-abort:
-			}
-		}()
-		return abort, results
-	}
 	go func() {
 		var (
 			old, new, out      = 0, len(preHeaders), 0
@@ -143,6 +132,14 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 			oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals)
 			newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1])
 		)
+		// verify that the headers are valid wrt. the terminal block.
+		index, err := beacon.verifyTerminalPoWBlock(chain, preHeaders)
+		if err != nil {
+			// Mark all subsequent headers with the error.
+			for i := index; i < len(preHeaders); i++ {
+				errors[i], done[i] = err, true
+			}
+		}
 		for {
 			for ; done[out]; out++ {
 				results <- errors[out]
@@ -174,7 +171,7 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 // - the parent of the header element to be stored in the chain correctly
 // - the preHeaders to have a set difficulty
 // - the last element to be the terminal block
-func (beacon *Beacon) verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) error {
+func (beacon *Beacon) verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) {
 	var (
 		first = preHeaders[0]
 		last  = preHeaders[len(preHeaders)-1]
@@ -182,22 +179,23 @@ func (beacon *Beacon) verifyTerminalPoWBlock(chain consensus.ChainHeaderReader,
 
 	td := chain.GetTd(first.ParentHash, first.Number.Uint64()-1)
 	if td == nil {
-		return consensus.ErrUnknownAncestor
+		return 0, consensus.ErrUnknownAncestor
 	}
 	if len(preHeaders) != 1 {
-		for _, head := range preHeaders[:len(preHeaders)-2] {
+		for i, head := range preHeaders[:len(preHeaders)-2] {
 			td.Add(td, head.Difficulty)
+			// Check if the parent was already the terminal block
+			if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
+				return i, consensus.ErrInvalidTerminalBlock
+			}
 		}
 	}
-	// Check if the parent was already the terminal block
-	if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
-		return consensus.ErrInvalidTerminalBlock
-	}
+
 	// Check that the last block is the terminal block
 	if td.Add(td, last.Difficulty).Cmp(chain.Config().TerminalTotalDifficulty) < 0 {
-		return consensus.ErrInvalidTerminalBlock
+		return len(preHeaders) - 1, consensus.ErrInvalidTerminalBlock
 	}
-	return nil
+	return 0, nil
 }
 
 // VerifyUncles verifies that the given block's uncles conform to the consensus

From 65a40d1243d30eedca5ae278a14bbfac9f239b5f Mon Sep 17 00:00:00 2001
From: Marius van der Wijden <m.vanderwijden@live.de>
Date: Wed, 29 Jun 2022 10:45:06 +0200
Subject: [PATCH 3/5] consensus/beacon: fix verifyTerminalPoWBlock, add test
 cases

---
 consensus/beacon/consensus.go      |  23 +++--
 consensus/beacon/consensus_test.go | 137 +++++++++++++++++++++++++++++
 2 files changed, 147 insertions(+), 13 deletions(-)
 create mode 100644 consensus/beacon/consensus_test.go

diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go
index e772396b2e74..557712abca41 100644
--- a/consensus/beacon/consensus.go
+++ b/consensus/beacon/consensus.go
@@ -133,9 +133,9 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 			newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1])
 		)
 		// verify that the headers are valid wrt. the terminal block.
-		index, err := beacon.verifyTerminalPoWBlock(chain, preHeaders)
+		index, err := verifyTerminalPoWBlock(chain, preHeaders)
 		if err != nil {
-			// Mark all subsequent headers with the error.
+			// Mark all subsequent pow headers with the error.
 			for i := index; i < len(preHeaders); i++ {
 				errors[i], done[i] = err, true
 			}
@@ -171,28 +171,25 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 // - the parent of the header element to be stored in the chain correctly
 // - the preHeaders to have a set difficulty
 // - the last element to be the terminal block
-func (beacon *Beacon) verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) {
+func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) {
 	var (
 		first = preHeaders[0]
-		last  = preHeaders[len(preHeaders)-1]
 	)
 
 	td := chain.GetTd(first.ParentHash, first.Number.Uint64()-1)
 	if td == nil {
 		return 0, consensus.ErrUnknownAncestor
 	}
-	if len(preHeaders) != 1 {
-		for i, head := range preHeaders[:len(preHeaders)-2] {
-			td.Add(td, head.Difficulty)
-			// Check if the parent was already the terminal block
-			if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
-				return i, consensus.ErrInvalidTerminalBlock
-			}
+
+	for i, head := range preHeaders {
+		// Check if the parent was already the terminal block
+		if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
+			return i, consensus.ErrInvalidTerminalBlock
 		}
+		td.Add(td, head.Difficulty)
 	}
-
 	// Check that the last block is the terminal block
-	if td.Add(td, last.Difficulty).Cmp(chain.Config().TerminalTotalDifficulty) < 0 {
+	if td.Cmp(chain.Config().TerminalTotalDifficulty) < 0 {
 		return len(preHeaders) - 1, consensus.ErrInvalidTerminalBlock
 	}
 	return 0, nil
diff --git a/consensus/beacon/consensus_test.go b/consensus/beacon/consensus_test.go
new file mode 100644
index 000000000000..09c0b27c4256
--- /dev/null
+++ b/consensus/beacon/consensus_test.go
@@ -0,0 +1,137 @@
+package beacon
+
+import (
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/consensus"
+	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/params"
+)
+
+type mockChain struct {
+	config *params.ChainConfig
+	tds    map[uint64]*big.Int
+}
+
+func newMockChain() *mockChain {
+	return &mockChain{
+		config: new(params.ChainConfig),
+		tds:    make(map[uint64]*big.Int),
+	}
+}
+
+func (m *mockChain) Config() *params.ChainConfig {
+	return m.config
+}
+
+func (m *mockChain) CurrentHeader() *types.Header { panic("not implemented") }
+
+func (m *mockChain) GetHeader(hash common.Hash, number uint64) *types.Header {
+	panic("not implemented")
+}
+
+func (m *mockChain) GetHeaderByNumber(number uint64) *types.Header { panic("not implemented") }
+
+func (m *mockChain) GetHeaderByHash(hash common.Hash) *types.Header { panic("not implemented") }
+
+func (m *mockChain) GetTd(hash common.Hash, number uint64) *big.Int {
+	num, ok := m.tds[number]
+	if ok {
+		return new(big.Int).Set(num)
+	}
+	return nil
+}
+
+func TestVerifyTerminalBlock(t *testing.T) {
+	chain := newMockChain()
+	chain.tds[0] = big.NewInt(10)
+	chain.config.TerminalTotalDifficulty = big.NewInt(50)
+
+	tests := []struct {
+		preHeaders []*types.Header
+		ttd        *big.Int
+		err        error
+		index      int
+	}{
+		// valid ttd
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
+			},
+			ttd: big.NewInt(50),
+		},
+		// last block doesn't reach ttd
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(3), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(4), Difficulty: big.NewInt(9)},
+			},
+			ttd:   big.NewInt(50),
+			err:   consensus.ErrInvalidTerminalBlock,
+			index: 3,
+		},
+		// two blocks reach ttd
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
+				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
+			},
+			ttd:   big.NewInt(50),
+			err:   consensus.ErrInvalidTerminalBlock,
+			index: 3,
+		},
+		// three blocks reach ttd
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(2), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(3), Difficulty: big.NewInt(20)},
+				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
+				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
+			},
+			ttd:   big.NewInt(50),
+			err:   consensus.ErrInvalidTerminalBlock,
+			index: 3,
+		},
+		// parent reached ttd
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(1), Difficulty: big.NewInt(10)},
+			},
+			ttd:   big.NewInt(9),
+			err:   consensus.ErrInvalidTerminalBlock,
+			index: 0,
+		},
+		// unknown parent
+		{
+			preHeaders: []*types.Header{
+				{Number: big.NewInt(4), Difficulty: big.NewInt(10)},
+			},
+			ttd:   big.NewInt(9),
+			err:   consensus.ErrUnknownAncestor,
+			index: 0,
+		},
+	}
+
+	for i, test := range tests {
+		fmt.Printf("Test: %v\n", i)
+		chain.config.TerminalTotalDifficulty = test.ttd
+		index, err := verifyTerminalPoWBlock(chain, test.preHeaders)
+		if err != test.err {
+			t.Fatalf("Invalid error encountered, expected %v got %v", test.err, err)
+		}
+		if index != test.index {
+			t.Fatalf("Invalid index, expected %v got %v", test.index, index)
+		}
+	}
+}

From 81d7a9049f2b3946aeee3fc89871d5cfc357c6b4 Mon Sep 17 00:00:00 2001
From: Marius van der Wijden <m.vanderwijden@live.de>
Date: Wed, 29 Jun 2022 10:46:07 +0200
Subject: [PATCH 4/5] consensus/beacon: cosmetic changes

---
 consensus/beacon/consensus.go | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go
index 557712abca41..0ae266562ea3 100644
--- a/consensus/beacon/consensus.go
+++ b/consensus/beacon/consensus.go
@@ -172,15 +172,10 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 // - the preHeaders to have a set difficulty
 // - the last element to be the terminal block
 func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*types.Header) (int, error) {
-	var (
-		first = preHeaders[0]
-	)
-
-	td := chain.GetTd(first.ParentHash, first.Number.Uint64()-1)
+	td := chain.GetTd(preHeaders[0].ParentHash, preHeaders[0].Number.Uint64()-1)
 	if td == nil {
 		return 0, consensus.ErrUnknownAncestor
 	}
-
 	for i, head := range preHeaders {
 		// Check if the parent was already the terminal block
 		if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {

From fd93fcae0d051b9a93a8e5a0ec5356281359ab2d Mon Sep 17 00:00:00 2001
From: Marius van der Wijden <m.vanderwijden@live.de>
Date: Wed, 29 Jun 2022 10:58:22 +0200
Subject: [PATCH 5/5] consensus/beacon: apply karalabe's fixes

---
 consensus/beacon/consensus.go | 12 +++++++-----
 core/block_validator_test.go  |  3 ++-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/consensus/beacon/consensus.go b/consensus/beacon/consensus.go
index 0ae266562ea3..e090a03990f6 100644
--- a/consensus/beacon/consensus.go
+++ b/consensus/beacon/consensus.go
@@ -132,14 +132,14 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 			oldDone, oldResult = beacon.ethone.VerifyHeaders(chain, preHeaders, preSeals)
 			newDone, newResult = beacon.verifyHeaders(chain, postHeaders, preHeaders[len(preHeaders)-1])
 		)
-		// verify that the headers are valid wrt. the terminal block.
-		index, err := verifyTerminalPoWBlock(chain, preHeaders)
-		if err != nil {
+		// Verify that pre-merge headers don't overflow the TTD
+		if index, err := verifyTerminalPoWBlock(chain, preHeaders); err != nil {
 			// Mark all subsequent pow headers with the error.
 			for i := index; i < len(preHeaders); i++ {
 				errors[i], done[i] = err, true
 			}
 		}
+		// Collect the results
 		for {
 			for ; done[out]; out++ {
 				results <- errors[out]
@@ -149,7 +149,9 @@ func (beacon *Beacon) VerifyHeaders(chain consensus.ChainHeaderReader, headers [
 			}
 			select {
 			case err := <-oldResult:
-				errors[old], done[old] = err, true
+				if !done[old] { // skip TTD-verified failures
+					errors[old], done[old] = err, true
+				}
 				old++
 			case err := <-newResult:
 				errors[new], done[new] = err, true
@@ -176,8 +178,8 @@ func verifyTerminalPoWBlock(chain consensus.ChainHeaderReader, preHeaders []*typ
 	if td == nil {
 		return 0, consensus.ErrUnknownAncestor
 	}
+	// Check that all blocks before the last one are below the TTD
 	for i, head := range preHeaders {
-		// Check if the parent was already the terminal block
 		if td.Cmp(chain.Config().TerminalTotalDifficulty) >= 0 {
 			return i, consensus.ErrInvalidTerminalBlock
 		}
diff --git a/core/block_validator_test.go b/core/block_validator_test.go
index 0f183ba52778..8dee8d576070 100644
--- a/core/block_validator_test.go
+++ b/core/block_validator_test.go
@@ -107,7 +107,8 @@ func testHeaderVerificationForMerging(t *testing.T, isClique bool) {
 			Alloc: map[common.Address]GenesisAccount{
 				addr: {Balance: big.NewInt(1)},
 			},
-			BaseFee: big.NewInt(params.InitialBaseFee),
+			BaseFee:    big.NewInt(params.InitialBaseFee),
+			Difficulty: new(big.Int),
 		}
 		copy(genspec.ExtraData[32:], addr[:])
 		genesis := genspec.MustCommit(testdb)