From ec5f5999750f70efe58cc061c5856250dcef5ce2 Mon Sep 17 00:00:00 2001 From: Gavin Yu Date: Mon, 16 Dec 2024 20:05:45 +0800 Subject: [PATCH] fix(taiko-client): valid status check in `BatchGetBlocksProofStatus` (#18595) --- packages/taiko-client/pkg/rpc/ethclient.go | 18 ++++- packages/taiko-client/pkg/rpc/utils.go | 77 ++++++++++------------ 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/packages/taiko-client/pkg/rpc/ethclient.go b/packages/taiko-client/pkg/rpc/ethclient.go index f16e43d57b2..23254d1b666 100644 --- a/packages/taiko-client/pkg/rpc/ethclient.go +++ b/packages/taiko-client/pkg/rpc/ethclient.go @@ -3,6 +3,7 @@ package rpc import ( "context" "errors" + "fmt" "math/big" "time" @@ -164,7 +165,7 @@ func (c *EthClient) BatchHeadersByNumbers(ctx context.Context, numbers []*big.In for i, blockNum := range numbers { reqs[i] = rpc.BatchElem{ Method: "eth_getBlockByNumber", - Args: []interface{}{blockNum, false}, + Args: []interface{}{toBlockNumArg(blockNum), false}, Result: &results[i], } } @@ -180,6 +181,21 @@ func (c *EthClient) BatchHeadersByNumbers(ctx context.Context, numbers []*big.In return results, nil } +func toBlockNumArg(number *big.Int) string { + if number == nil { + return "latest" + } + if number.Sign() >= 0 { + return hexutil.EncodeBig(number) + } + // It's negative. + if number.IsInt64() { + return rpc.BlockNumber(number.Int64()).String() + } + // It's negative and large, which is invalid. + return fmt.Sprintf("", number) +} + // TransactionByHash returns the transaction with the given hash. func (c *EthClient) TransactionByHash( ctx context.Context, diff --git a/packages/taiko-client/pkg/rpc/utils.go b/packages/taiko-client/pkg/rpc/utils.go index 570d72da166..c425b82158f 100644 --- a/packages/taiko-client/pkg/rpc/utils.go +++ b/packages/taiko-client/pkg/rpc/utils.go @@ -2,6 +2,7 @@ package rpc import ( "context" + "errors" "math/big" "os" "os/signal" @@ -32,6 +33,7 @@ var ( syscall.SIGTERM, syscall.SIGQUIT, } + ErrInvalidLength = errors.New("invalid length") ) // GetProtocolConfigs gets the protocol configs from TaikoL1 contract. @@ -265,46 +267,52 @@ func BatchGetBlocksProofStatus( defer cancel() var ( parentHashes = make([][32]byte, len(ids)) - parents = make([]*types.Header, len(ids)) - blockIDs = make([]uint64, len(ids)) + parentIDs = make([]*big.Int, len(ids)) + blockIDs = make([]*big.Int, len(ids)) + uint64BlockIDs = make([]uint64, len(ids)) result = make([]*BlockProofStatus, len(ids)) highestBlockID = big.NewInt(0) ) - // Get the local L2 parent header. - g, gCtx := errgroup.WithContext(ctxWithTimeout) for i, id := range ids { - g.Go(func() error { - parent, err := cli.L2.HeaderByNumber(gCtx, new(big.Int).Sub(id, common.Big1)) - if err != nil { - return err - } - parentHashes[i] = parent.Hash() - parents[i] = parent - blockIDs[i] = id.Uint64() - if id.Cmp(highestBlockID) > 0 { - highestBlockID = id - } - return nil - }) + parentIDs[i] = new(big.Int).Sub(id, common.Big1) + blockIDs[i] = id + uint64BlockIDs[i] = id.Uint64() + if id.Cmp(highestBlockID) > 0 { + highestBlockID = id + } } - if gErr := g.Wait(); gErr != nil { - return nil, gErr + // Get the local L2 parent headers. + parents, err := cli.L2.BatchHeadersByNumbers(ctxWithTimeout, parentIDs) + if err != nil { + return nil, err + } + if len(parents) != len(ids) { + return nil, ErrInvalidLength + } + for i := range ids { + parentHashes[i] = parents[i].Hash() } - // Get the transition state from TaikoL1 contract. transitions, err := cli.TaikoL1.GetTransitions( &bind.CallOpts{Context: ctxWithTimeout}, - blockIDs, + uint64BlockIDs, parentHashes, ) if err != nil { return nil, err } - highestHeader, err := cli.WaitL2Header(ctxWithTimeout, highestBlockID) + _, err = cli.WaitL2Header(ctxWithTimeout, highestBlockID) + if err != nil { + return nil, err + } + blockHeaders, err := cli.L2.BatchHeadersByNumbers(ctxWithTimeout, blockIDs) if err != nil { return nil, err } - g, gCtx = errgroup.WithContext(ctxWithTimeout) + if len(transitions) != len(ids) || len(blockHeaders) != len(ids) { + return nil, ErrInvalidLength + } + g, _ := errgroup.WithContext(ctxWithTimeout) for i, transition := range transitions { // No proof on chain if transition.BlockHash == (common.Hash{}) { @@ -312,28 +320,13 @@ func BatchGetBlocksProofStatus( continue } g.Go(func() error { - if err != nil { - return err - } - var ( - localBlockHash common.Hash - localStateRoot [32]byte - ) - if i+1 < len(parents) { - localBlockHash = parents[i+1].Hash() - localStateRoot = parents[i+1].Root - } else { - localBlockHash = highestHeader.Hash() - localStateRoot = highestHeader.Root - } - - if localBlockHash != transition.BlockHash || - (transition.StateRoot != (common.Hash{}) && transition.StateRoot != localStateRoot) { + if blockHeaders[i].Hash() != transition.BlockHash || + (transition.StateRoot != (common.Hash{}) && transition.StateRoot != blockHeaders[i].Root) { log.Info( "Different block hash or state root detected, try submitting a contest", - "localBlockHash", localBlockHash, + "localBlockHash", blockHeaders[i].Hash(), "protocolTransitionBlockHash", common.BytesToHash(transition.BlockHash[:]), - "localStateRoot", localStateRoot, + "localStateRoot", blockHeaders[i].Root, "protocolTransitionStateRoot", common.BytesToHash(transition.StateRoot[:]), ) result[i] = &BlockProofStatus{