diff --git a/consensus/consortium/v2/consortium.go b/consensus/consortium/v2/consortium.go index 9106d81709..d898cc6ffe 100644 --- a/consensus/consortium/v2/consortium.go +++ b/consensus/consortium/v2/consortium.go @@ -764,7 +764,7 @@ func (c *Consortium) processSystemTransactions(chain consensus.ChainHeaderReader // If the parent's block includes the finality votes, distribute reward for the voters if c.chainConfig.IsShillin(new(big.Int).Sub(header.Number, common.Big1)) { - parentHeader := chain.GetHeaderByHash(header.ParentHash) + parentHeader := chain.GetHeader(header.ParentHash, header.Number.Uint64()-1) extraData, err := finality.DecodeExtra(parentHeader.Extra, true) if err != nil { return err diff --git a/consensus/consortium/v2/consortium_test.go b/consensus/consortium/v2/consortium_test.go index 653eca4854..f2a6612d14 100644 --- a/consensus/consortium/v2/consortium_test.go +++ b/consensus/consortium/v2/consortium_test.go @@ -2,6 +2,7 @@ package v2 import ( "bytes" + "crypto/ecdsa" "encoding/binary" "errors" "math/big" @@ -17,6 +18,7 @@ import ( "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/bls/blst" blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/params" @@ -972,3 +974,266 @@ func TestVerifyVote(t *testing.T) { t.Errorf("Expect sucessful verification have %s", err) } } + +func TestKnownBlockReorg(t *testing.T) { + db := rawdb.NewMemoryDatabase() + + blsKeys := make([]blsCommon.SecretKey, 3) + ecdsaKeys := make([]*ecdsa.PrivateKey, 3) + validatorAddrs := make([]common.Address, 3) + + for i := range blsKeys { + blsKey, err := blst.RandKey() + if err != nil { + t.Fatal(err) + } + blsKeys[i] = blsKey + + secretKey, err := crypto.GenerateKey() + if err != nil { + t.Fatal(err) + } + ecdsaKeys[i] = secretKey + validatorAddrs[i] = crypto.PubkeyToAddress(secretKey.PublicKey) + } + + for i := 0; i < len(blsKeys)-1; i++ { + for j := i; j < len(blsKeys); j++ { + if bytes.Compare(validatorAddrs[i][:], validatorAddrs[j][:]) > 0 { + validatorAddrs[i], validatorAddrs[j] = validatorAddrs[j], validatorAddrs[i] + blsKeys[i], blsKeys[j] = blsKeys[j], blsKeys[i] + ecdsaKeys[i], ecdsaKeys[j] = ecdsaKeys[j], ecdsaKeys[i] + } + } + } + + chainConfig := params.ChainConfig{ + ChainID: big.NewInt(2021), + HomesteadBlock: common.Big0, + EIP150Block: common.Big0, + EIP155Block: common.Big0, + EIP158Block: common.Big0, + ConsortiumV2Block: common.Big0, + ShillinBlock: big.NewInt(10), + Consortium: ¶ms.ConsortiumConfig{ + EpochV2: 10, + }, + } + + genesis := (&core.Genesis{ + Config: &chainConfig, + }).MustCommit(db) + + mock := &mockContract{ + validators: make(map[common.Address]blsCommon.PublicKey), + } + mock.validators[validatorAddrs[0]] = blsKeys[0].PublicKey() + recents, _ := lru.NewARC(inmemorySnapshots) + signatures, _ := lru.NewARC(inmemorySignatures) + + v2 := Consortium{ + chainConfig: &chainConfig, + contract: mock, + recents: recents, + signatures: signatures, + config: chainConfig.Consortium, + db: db, + } + + chain, _ := core.NewBlockChain(db, nil, &chainConfig, &v2, vm.Config{}, nil, nil) + extraData := [consortiumCommon.ExtraVanity + consortiumCommon.ExtraSeal]byte{} + + blocks, _ := core.GenerateConsortiumChain( + &chainConfig, + genesis, + &v2, + db, + 9, + func(i int, bg *core.BlockGen) { + bg.SetCoinbase(validatorAddrs[0]) + bg.SetExtra(extraData[:]) + bg.SetDifficulty(big.NewInt(7)) + }, + true, + func(i int, bg *core.BlockGen) { + header := bg.Header() + hash := calculateSealHash(header, big.NewInt(2021)) + sig, err := crypto.Sign(hash[:], ecdsaKeys[0]) + if err != nil { + t.Fatalf("Failed to sign block, err %s", err) + } + copy(header.Extra[len(header.Extra)-consortiumCommon.ExtraSeal:], sig) + bg.SetExtra(header.Extra) + }, + ) + + _, err := chain.InsertChain(blocks) + if err != nil { + t.Fatalf("Failed to insert block, err %s", err) + } + + for i := range validatorAddrs { + mock.validators[validatorAddrs[i]] = blsKeys[i].PublicKey() + } + + var checkpointValidators []finality.ValidatorWithBlsPub + for i := range validatorAddrs { + checkpointValidators = append(checkpointValidators, finality.ValidatorWithBlsPub{ + Address: validatorAddrs[i], + BlsPublicKey: blsKeys[i].PublicKey(), + }) + } + + // Prepare checkpoint block + blocks, _ = core.GenerateConsortiumChain( + &chainConfig, + blocks[len(blocks)-1], + &v2, + db, + 1, + func(i int, bg *core.BlockGen) { + var extra finality.HeaderExtraData + + bg.SetCoinbase(validatorAddrs[0]) + bg.SetDifficulty(big.NewInt(7)) + extra.CheckpointValidators = checkpointValidators + bg.SetExtra(extra.Encode(true)) + }, + true, + func(i int, bg *core.BlockGen) { + header := bg.Header() + hash := calculateSealHash(header, big.NewInt(2021)) + sig, err := crypto.Sign(hash[:], ecdsaKeys[0]) + if err != nil { + t.Fatalf("Failed to sign block, err %s", err) + } + copy(header.Extra[len(header.Extra)-consortiumCommon.ExtraSeal:], sig) + bg.SetExtra(header.Extra) + }, + ) + + _, err = chain.InsertChain(blocks) + if err != nil { + t.Fatalf("Failed to insert block, err %s", err) + } + + extraDataShillin := [consortiumCommon.ExtraVanity + 1 + consortiumCommon.ExtraSeal]byte{} + knownBlocks, _ := core.GenerateConsortiumChain( + &chainConfig, + blocks[len(blocks)-1], + &v2, + db, + 1, + func(i int, bg *core.BlockGen) { + bg.SetCoinbase(validatorAddrs[2]) + bg.SetExtra(extraDataShillin[:]) + bg.SetDifficulty(big.NewInt(7)) + }, + true, + func(i int, bg *core.BlockGen) { + header := bg.Header() + hash := calculateSealHash(header, big.NewInt(2021)) + sig, err := crypto.Sign(hash[:], ecdsaKeys[2]) + if err != nil { + t.Fatalf("Failed to sign block, err %s", err) + } + copy(header.Extra[len(header.Extra)-consortiumCommon.ExtraSeal:], sig) + bg.SetExtra(header.Extra) + }, + ) + + _, err = chain.InsertChain(knownBlocks) + if err != nil { + t.Fatalf("Failed to insert block, err %s", err) + } + + header := chain.CurrentHeader() + if header.Number.Uint64() != 11 { + t.Fatalf("Expect head header to be %d, got %d", 11, header.Number.Uint64()) + } + if header.Difficulty.Cmp(big.NewInt(7)) != 0 { + t.Fatalf("Expect header header to have difficulty %d, got %d", 7, header.Difficulty.Uint64()) + } + + justifiedBlocks, _ := core.GenerateConsortiumChain( + &chainConfig, + blocks[len(blocks)-1], + &v2, + db, + 2, + func(i int, bg *core.BlockGen) { + if bg.Number().Uint64() == 11 { + bg.SetCoinbase(validatorAddrs[1]) + bg.SetExtra(extraDataShillin[:]) + } else { + bg.SetCoinbase(validatorAddrs[2]) + + var ( + extra finality.HeaderExtraData + voteBitset finality.FinalityVoteBitSet + signatures []blsCommon.Signature + ) + voteBitset.SetBit(0) + voteBitset.SetBit(1) + voteBitset.SetBit(2) + extra.HasFinalityVote = 1 + extra.FinalityVotedValidators = voteBitset + + block := bg.PrevBlock(-1) + voteData := types.VoteData{ + TargetNumber: block.NumberU64(), + TargetHash: block.Hash(), + } + for i := range blsKeys { + signatures = append(signatures, blsKeys[i].Sign(voteData.Hash().Bytes())) + } + + extra.AggregatedFinalityVotes = blst.AggregateSignatures(signatures) + bg.SetExtra(extra.Encode(true)) + } + + bg.SetDifficulty(big.NewInt(3)) + }, + true, + func(i int, bg *core.BlockGen) { + header := bg.Header() + hash := calculateSealHash(header, big.NewInt(2021)) + + var ecdsaKey *ecdsa.PrivateKey + if bg.Number().Uint64() == 11 { + ecdsaKey = ecdsaKeys[1] + } else { + ecdsaKey = ecdsaKeys[2] + } + sig, err := crypto.Sign(hash[:], ecdsaKey) + if err != nil { + t.Fatalf("Failed to sign block, err %s", err) + } + copy(header.Extra[len(header.Extra)-consortiumCommon.ExtraSeal:], sig) + bg.SetExtra(header.Extra) + }, + ) + + _, err = chain.InsertChain(justifiedBlocks) + if err != nil { + t.Fatalf("Failed to insert block, err %s", err) + } + + header = chain.CurrentHeader() + if header.Number.Uint64() != 12 { + t.Fatalf("Expect head header to be %d, got %d", 12, header.Number.Uint64()) + } + + _, err = chain.InsertChain(knownBlocks) + if err != nil { + t.Fatalf("Failed to insert block, err %s", err) + } + header = chain.CurrentHeader() + if header.Number.Uint64() != 12 { + t.Fatalf("Expect head header to be %d, got %d", 12, header.Number.Uint64()) + } + header = chain.GetHeaderByNumber(11) + if header.Difficulty.Uint64() != 3 { + t.Fatalf("Expect head header to have difficulty %d, got %d", 3, header.Difficulty.Uint64()) + } +}