diff --git a/core/eds.go b/core/eds.go new file mode 100644 index 0000000000..e996c0111e --- /dev/null +++ b/core/eds.go @@ -0,0 +1,26 @@ +package core + +import ( + "github.com/tendermint/tendermint/types" + + "github.com/celestiaorg/celestia-app/pkg/da" + appshares "github.com/celestiaorg/celestia-app/pkg/shares" + "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/libs/utils" +) + +// extendBlock extends the given block data, returning the resulting +// ExtendedDataSquare (EDS). If there are no transactions in the block, +// nil is returned in place of the eds. +func extendBlock(data types.Data) (*rsmt2d.ExtendedDataSquare, error) { + if len(data.Txs) == 0 { + return nil, nil + } + shares, err := appshares.Split(data, true) + if err != nil { + return nil, err + } + size := utils.SquareSize(len(shares)) + return da.ExtendShares(size, appshares.ToBytes(shares)) +} diff --git a/core/exchange.go b/core/exchange.go index bd9d0de645..71dd3b13d6 100644 --- a/core/exchange.go +++ b/core/exchange.go @@ -5,27 +5,26 @@ import ( "context" "fmt" - "github.com/ipfs/go-blockservice" - "github.com/celestiaorg/celestia-node/header" libhead "github.com/celestiaorg/celestia-node/libs/header" + "github.com/celestiaorg/celestia-node/share/eds" ) type Exchange struct { - fetcher *BlockFetcher - shareStore blockservice.BlockService - construct header.ConstructFn + fetcher *BlockFetcher + store *eds.Store + construct header.ConstructFn } func NewExchange( fetcher *BlockFetcher, - bServ blockservice.BlockService, + store *eds.Store, construct header.ConstructFn, ) *Exchange { return &Exchange{ - fetcher: fetcher, - shareStore: bServ, - construct: construct, + fetcher: fetcher, + store: store, + construct: construct, } } @@ -86,15 +85,27 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende return nil, err } - eh, err := ce.construct(ctx, block, comm, vals, ce.shareStore) + // extend block data + eds, err := extendBlock(block.Data) + if err != nil { + return nil, err + } + // construct extended header + eh, err := ce.construct(ctx, block, comm, vals, eds) if err != nil { return nil, err } - // verify hashes match if !bytes.Equal(hash, eh.Hash()) { return nil, fmt.Errorf("incorrect hash in header: expected %x, got %x", hash, eh.Hash()) } + // store extended block if it is not empty + if eds != nil { + err = ce.store.Put(ctx, eh.DAH.Hash(), eds) + if err != nil { + return nil, err + } + } return eh, nil } @@ -115,5 +126,22 @@ func (ce *Exchange) getExtendedHeaderByHeight(ctx context.Context, height *int64 return nil, err } - return ce.construct(ctx, b, comm, vals, ce.shareStore) + // extend block data + eds, err := extendBlock(b.Data) + if err != nil { + return nil, err + } + // create extended header + eh, err := ce.construct(ctx, b, comm, vals, eds) + if err != nil { + return nil, err + } + // only store extended block if it's not empty + if eds != nil { + err = ce.store.Put(ctx, eh.DAH.Hash(), eds) + if err != nil { + return nil, err + } + } + return eh, nil } diff --git a/core/exchange_test.go b/core/exchange_test.go index 0a7b58ccd7..aae310d482 100644 --- a/core/exchange_test.go +++ b/core/exchange_test.go @@ -1,24 +1,29 @@ package core import ( - "bytes" "context" "testing" + "time" - mdutils "github.com/ipfs/go-merkledag/test" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/celestiaorg/celestia-app/testutil/testnode" + 
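Note on the new `extendBlock` helper above: it returns a `nil` EDS for empty blocks, and every caller in this diff branches on that — `nil` means the header gets the minimal `EmptyDAH()` and nothing is written to the `eds.Store`. A minimal sketch of that shared pattern, assuming it sits in package `core` next to `eds.go` (the helper name `extendAndStore` is illustrative, not part of the PR):

```go
package core // sketch only: would sit next to eds.go

import (
	"context"

	"github.com/tendermint/tendermint/types"

	"github.com/celestiaorg/celestia-app/pkg/da"

	"github.com/celestiaorg/celestia-node/header"
	"github.com/celestiaorg/celestia-node/share/eds"
)

// extendAndStore is an illustrative helper (not part of this PR) showing the
// extend -> derive DAH -> conditionally store pattern that Exchange and
// Listener now share.
func extendAndStore(
	ctx context.Context,
	store *eds.Store,
	data types.Data,
) (da.DataAvailabilityHeader, error) {
	square, err := extendBlock(data)
	if err != nil {
		return da.DataAvailabilityHeader{}, err
	}
	if square == nil {
		// empty block: nothing to persist, the header carries the minimal DAH
		return header.EmptyDAH(), nil
	}
	// the DAH is derived from the extended square; its hash doubles as the
	// key under which the square is stored
	dah := da.NewDataAvailabilityHeader(square)
	return dah, store.Put(ctx, dah.Hash(), square)
}
```

Keying the store by the DAH hash is what lets a gossiped data-hash announcement be resolved straight to a stored square later on.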
"github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share/eds" ) func TestCoreExchange_RequestHeaders(t *testing.T) { - fetcher := createCoreFetcher(t) - store := mdutils.Bserv() + fetcher, _ := createCoreFetcher(t, DefaultTestConfig()) // generate 10 blocks generateBlocks(t, fetcher) + store := createStore(t) + ce := NewExchange(fetcher, store, header.MakeExtendedHeader) headers, err := ce.GetRangeByHeight(context.Background(), 1, 10) require.NoError(t, err) @@ -26,16 +31,21 @@ func TestCoreExchange_RequestHeaders(t *testing.T) { assert.Equal(t, 10, len(headers)) } -func Test_hashMatch(t *testing.T) { - expected := []byte("AE0F153556A4FA5C0B7C3BFE0BAF0EC780C031933B281A8D759BB34C1DA31C56") - mismatch := []byte("57A0D7FE69FE88B3D277C824B3ACB9B60E5E65837A802485DE5CBB278C43576A") - - assert.False(t, bytes.Equal(expected, mismatch)) +func createCoreFetcher(t *testing.T, cfg *TestConfig) (*BlockFetcher, testnode.Context) { + cctx := StartTestNodeWithConfig(t, cfg) + // wait for height 2 in order to be able to start submitting txs (this prevents + // flakiness with accessing account state) + _, err := cctx.WaitForHeightWithTimeout(2, time.Second) // TODO @renaynay: configure? + require.NoError(t, err) + return NewBlockFetcher(cctx.Client), cctx } -func createCoreFetcher(t *testing.T) *BlockFetcher { - client := StartTestNode(t).Client - return NewBlockFetcher(client) +func createStore(t *testing.T) *eds.Store { + t.Helper() + + store, err := eds.NewStore(t.TempDir(), ds_sync.MutexWrap(ds.NewMapDatastore())) + require.NoError(t, err) + return store } func generateBlocks(t *testing.T, fetcher *BlockFetcher) { diff --git a/headertest/header_test.go b/core/header_test.go similarity index 78% rename from headertest/header_test.go rename to core/header_test.go index 3e360504a2..17337405c9 100644 --- a/headertest/header_test.go +++ b/core/header_test.go @@ -1,26 +1,23 @@ -package headertest +package core import ( "context" "testing" - mdutils "github.com/ipfs/go-merkledag/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/libs/rand" - "github.com/celestiaorg/celestia-node/core" "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/headertest" ) func TestMakeExtendedHeaderForEmptyBlock(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) - client := core.StartTestNode(t).Client - fetcher := core.NewBlockFetcher(client) - - store := mdutils.Bserv() + client := StartTestNode(t).Client + fetcher := NewBlockFetcher(client) sub, err := fetcher.SubscribeNewBlockEvent(ctx) require.NoError(t, err) @@ -33,14 +30,17 @@ func TestMakeExtendedHeaderForEmptyBlock(t *testing.T) { comm, val, err := fetcher.GetBlockInfo(ctx, &height) require.NoError(t, err) - headerExt, err := header.MakeExtendedHeader(ctx, b, comm, val, store) + eds, err := extendBlock(b.Data) + require.NoError(t, err) + + headerExt, err := header.MakeExtendedHeader(ctx, b, comm, val, eds) require.NoError(t, err) assert.Equal(t, header.EmptyDAH(), *headerExt.DAH) } func TestMismatchedDataHash_ComputedRoot(t *testing.T) { - header := RandExtendedHeader(t) + header := headertest.RandExtendedHeader(t) header.DataHash = rand.Bytes(32) diff --git a/core/listener.go b/core/listener.go index 7a8a3c8055..96a37c7f15 100644 --- a/core/listener.go +++ b/core/listener.go @@ -4,12 +4,12 @@ import ( "context" "fmt" - "github.com/ipfs/go-blockservice" pubsub "github.com/libp2p/go-libp2p-pubsub" 
"github.com/tendermint/tendermint/types" "github.com/celestiaorg/celestia-node/header" libhead "github.com/celestiaorg/celestia-node/libs/header" + "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) @@ -22,9 +22,9 @@ import ( // network. type Listener struct { fetcher *BlockFetcher - bServ blockservice.BlockService construct header.ConstructFn + store *eds.Store headerBroadcaster libhead.Broadcaster[*header.ExtendedHeader] hashBroadcaster shrexsub.BroadcastFn @@ -35,16 +35,16 @@ type Listener struct { func NewListener( bcast libhead.Broadcaster[*header.ExtendedHeader], fetcher *BlockFetcher, - extHeaderBroadcaster shrexsub.BroadcastFn, - bServ blockservice.BlockService, + hashBroadcaster shrexsub.BroadcastFn, construct header.ConstructFn, + store *eds.Store, ) *Listener { return &Listener{ - headerBroadcaster: bcast, fetcher: fetcher, - hashBroadcaster: extHeaderBroadcaster, - bServ: bServ, + headerBroadcaster: bcast, + hashBroadcaster: hashBroadcaster, construct: construct, + store: store, } } @@ -65,7 +65,7 @@ func (cl *Listener) Start(ctx context.Context) error { return nil } -// Stop stops the Listener listener loop. +// Stop stops the listener loop. func (cl *Listener) Stop(ctx context.Context) error { cl.cancel() cl.cancel = nil @@ -96,17 +96,25 @@ func (cl *Listener) listen(ctx context.Context, sub <-chan *types.Block) { return } - eh, err := cl.construct(ctx, b, comm, vals, cl.bServ) + // extend block data + eds, err := extendBlock(b.Data) if err != nil { - log.Errorw("listener: making extended header", "err", err) + log.Errorw("listener: extending block data", "err", err) return } - - // broadcast new ExtendedHeader, but if core is still syncing, notify only local subscribers - err = cl.headerBroadcaster.Broadcast(ctx, eh, pubsub.WithLocalPublication(syncing)) + // generate extended header + eh, err := cl.construct(ctx, b, comm, vals, eds) if err != nil { - log.Errorw("listener: broadcasting next header", "height", eh.Height(), - "err", err) + log.Errorw("listener: making extended header", "err", err) + return + } + // store block data if not empty + if eds != nil { + err = cl.store.Put(ctx, eh.DAH.Hash(), eds) + if err != nil { + log.Errorw("listener: storing extended header", "err", err) + return + } } // notify network of new EDS hash only if core is already synced @@ -117,6 +125,13 @@ func (cl *Listener) listen(ctx context.Context, sub <-chan *types.Block) { "hash", eh.Hash(), "err", err) } } + + // broadcast new ExtendedHeader, but if core is still syncing, notify only local subscribers + err = cl.headerBroadcaster.Broadcast(ctx, eh, pubsub.WithLocalPublication(syncing)) + if err != nil { + log.Errorw("listener: broadcasting next header", "height", eh.Height(), + "err", err) + } case <-ctx.Done(): return } diff --git a/core/listener_test.go b/core/listener_test.go index a1940b3975..87b32ab995 100644 --- a/core/listener_test.go +++ b/core/listener_test.go @@ -1,11 +1,12 @@ package core import ( + "bytes" "context" "testing" "time" - mdutils "github.com/ipfs/go-merkledag/test" + "github.com/cosmos/cosmos-sdk/client/flags" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/event" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" @@ -15,6 +16,7 @@ import ( "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/libs/header/p2p" network "github.com/celestiaorg/celestia-node/nodebuilder/p2p" + "github.com/celestiaorg/celestia-node/share/eds" 
"github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) @@ -33,21 +35,22 @@ func TestListener(t *testing.T) { require.NoError(t, subscriber.Start(ctx)) subs, err := subscriber.Subscribe() require.NoError(t, err) + t.Cleanup(subs.Cancel) // create one block to store as Head in local store and then unsubscribe from block events - fetcher := createCoreFetcher(t) + fetcher, _ := createCoreFetcher(t, DefaultTestConfig()) eds := createEdsPubSub(ctx, t) // create Listener and start listening - cl := createListener(ctx, t, fetcher, ps0, eds) + cl := createListener(ctx, t, fetcher, ps0, eds, createStore(t)) err = cl.Start(ctx) require.NoError(t, err) edsSubs, err := eds.Subscribe() require.NoError(t, err) - defer edsSubs.Cancel() + t.Cleanup(edsSubs.Cancel) // ensure headers and dataHash are getting broadcasted to the relevant topics - for i := 1; i < 6; i++ { + for i := 0; i < 5; i++ { h, err := subs.NextHeader(ctx) require.NoError(t, err) @@ -62,6 +65,61 @@ func TestListener(t *testing.T) { require.Nil(t, cl.cancel) } +// TestListenerWithNonEmptyBlocks ensures that non-empty blocks are actually +// stored to eds.Store. +func TestListenerWithNonEmptyBlocks(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + + // create mocknet with two pubsub endpoints + ps0, _ := createMocknetWithTwoPubsubEndpoints(ctx, t) + + // create one block to store as Head in local store and then unsubscribe from block events + cfg := DefaultTestConfig() + fetcher, cctx := createCoreFetcher(t, cfg) + eds := createEdsPubSub(ctx, t) + + store := createStore(t) + err := store.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { + err = store.Stop(ctx) + require.NoError(t, err) + }) + + // create Listener and start listening + cl := createListener(ctx, t, fetcher, ps0, eds, store) + err = cl.Start(ctx) + require.NoError(t, err) + + // listen for eds hashes broadcasted through eds-sub and ensure store has + // already stored them + sub, err := eds.Subscribe() + require.NoError(t, err) + t.Cleanup(sub.Cancel) + + empty := header.EmptyDAH() + // TODO extract 16 + for i := 0; i < 16; i++ { + _, err := cctx.FillBlock(16, cfg.Accounts, flags.BroadcastBlock) + require.NoError(t, err) + hash, err := sub.Next(ctx) + require.NoError(t, err) + + if bytes.Equal(empty.Hash(), hash) { + continue + } + + has, err := store.Has(ctx, hash) + require.NoError(t, err) + require.True(t, has) + } + + err = cl.Stop(ctx) + require.NoError(t, err) + require.Nil(t, cl.cancel) +} + func createMocknetWithTwoPubsubEndpoints(ctx context.Context, t *testing.T) (*pubsub.PubSub, *pubsub.PubSub) { net, err := mocknet.FullMeshLinked(2) require.NoError(t, err) @@ -103,6 +161,7 @@ func createListener( fetcher *BlockFetcher, ps *pubsub.PubSub, edsSub *shrexsub.PubSub, + store *eds.Store, ) *Listener { p2pSub := p2p.NewSubscriber[*header.ExtendedHeader](ps, header.MsgID, string(network.Private)) err := p2pSub.Start(ctx) @@ -111,7 +170,7 @@ func createListener( require.NoError(t, p2pSub.Stop(ctx)) }) - return NewListener(p2pSub, fetcher, edsSub.Broadcast, mdutils.Bserv(), header.MakeExtendedHeader) + return NewListener(p2pSub, fetcher, edsSub.Broadcast, header.MakeExtendedHeader, store) } func createEdsPubSub(ctx context.Context, t *testing.T) *shrexsub.PubSub { diff --git a/core/testing.go b/core/testing.go index 726dfea493..202cf9f0a1 100644 --- a/core/testing.go +++ b/core/testing.go @@ -2,19 +2,15 @@ package core import ( "fmt" - "math/rand" "net" "net/url" - "sort" "testing" - "time" appconfig 
"github.com/cosmos/cosmos-sdk/server/config" "github.com/stretchr/testify/require" tmconfig "github.com/tendermint/tendermint/config" tmrand "github.com/tendermint/tendermint/libs/rand" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" - tmtypes "github.com/tendermint/tendermint/types" "github.com/celestiaorg/celestia-app/testutil/testnode" ) @@ -119,76 +115,6 @@ func StartTestNodeWithConfig(t *testing.T, cfg *TestConfig) testnode.Context { return cctx } -func RandValidator(randPower bool, minPower int64) (*tmtypes.Validator, tmtypes.PrivValidator) { - privVal := tmtypes.NewMockPV() - votePower := minPower - if randPower { - //nolint:gosec // G404: Use of weak random number generator - votePower += int64(rand.Uint32()) - } - pubKey, err := privVal.GetPubKey() - if err != nil { - panic(fmt.Errorf("could not retrieve pubkey %w", err)) - } - val := tmtypes.NewValidator(pubKey, votePower) - return val, privVal -} - -func RandValidatorSet(numValidators int, votingPower int64) (*tmtypes.ValidatorSet, []tmtypes.PrivValidator) { - var ( - valz = make([]*tmtypes.Validator, numValidators) - privValidators = make([]tmtypes.PrivValidator, numValidators) - ) - - for i := 0; i < numValidators; i++ { - val, privValidator := RandValidator(false, votingPower) - valz[i] = val - privValidators[i] = privValidator - } - - sort.Sort(tmtypes.PrivValidatorsByAddress(privValidators)) - - return tmtypes.NewValidatorSet(valz), privValidators -} - -func MakeCommit(blockID tmtypes.BlockID, height int64, round int32, - voteSet *tmtypes.VoteSet, validators []tmtypes.PrivValidator, now time.Time) (*tmtypes.Commit, error) { - - // all sign - for i := 0; i < len(validators); i++ { - pubKey, err := validators[i].GetPubKey() - if err != nil { - return nil, fmt.Errorf("can't get pubkey: %w", err) - } - vote := &tmtypes.Vote{ - ValidatorAddress: pubKey.Address(), - ValidatorIndex: int32(i), - Height: height, - Round: round, - Type: tmproto.PrecommitType, - BlockID: blockID, - Timestamp: now, - } - - _, err = signAddVote(validators[i], vote, voteSet) - if err != nil { - return nil, err - } - } - - return voteSet.MakeCommit(), nil -} - -func signAddVote(privVal tmtypes.PrivValidator, vote *tmtypes.Vote, voteSet *tmtypes.VoteSet) (signed bool, err error) { - v := vote.ToProto() - err = privVal.SignVote(voteSet.ChainID(), v) - if err != nil { - return false, err - } - vote.Signature = v.Signature - return voteSet.AddVote(vote) -} - func getFreePort() int { a, err := net.ResolveTCPAddr("tcp", "localhost:0") if err == nil { diff --git a/das/coordinator.go b/das/coordinator.go index cf356349d9..83a9d65df8 100644 --- a/das/coordinator.go +++ b/das/coordinator.go @@ -6,14 +6,16 @@ import ( "github.com/celestiaorg/celestia-node/header" libhead "github.com/celestiaorg/celestia-node/libs/header" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) // samplingCoordinator runs and coordinates sampling workers and updates current sampling state type samplingCoordinator struct { concurrencyLimit int - getter libhead.Getter[*header.ExtendedHeader] - sampleFn sampleFn + getter libhead.Getter[*header.ExtendedHeader] + sampleFn sampleFn + broadcastFn shrexsub.BroadcastFn state coordinatorState @@ -40,11 +42,13 @@ func newSamplingCoordinator( params Parameters, getter libhead.Getter[*header.ExtendedHeader], sample sampleFn, + broadcast shrexsub.BroadcastFn, ) *samplingCoordinator { return &samplingCoordinator{ concurrencyLimit: params.ConcurrencyLimit, getter: getter, sampleFn: sample, + broadcastFn: broadcast, state: 
newCoordinatorState(params), resultCh: make(chan result), updHeadCh: make(chan uint64), @@ -75,7 +79,10 @@ func (sc *samplingCoordinator) run(ctx context.Context, cp checkpoint) { select { case head := <-sc.updHeadCh: - if sc.state.updateHead(head) { + if sc.state.isNewHead(head) { + sc.runWorker(ctx, sc.state.newRecentJob(head)) + sc.state.updateHead(head) + // run worker without concurrency limit restrictions to reduced delay sc.metrics.observeNewHead(ctx) } case res := <-sc.resultCh: @@ -99,7 +106,7 @@ func (sc *samplingCoordinator) runWorker(ctx context.Context, j job) { sc.workersWg.Add(1) go func() { defer sc.workersWg.Done() - w.run(ctx, sc.getter, sc.sampleFn, sc.metrics, sc.resultCh) + w.run(ctx, sc.getter, sc.sampleFn, sc.broadcastFn, sc.metrics, sc.resultCh) }() } diff --git a/das/coordinator_test.go b/das/coordinator_test.go index a37ea73f08..016e7e0cf0 100644 --- a/das/coordinator_test.go +++ b/das/coordinator_test.go @@ -8,9 +8,11 @@ import ( "testing" "time" - "github.com/celestiaorg/celestia-node/header" - "github.com/stretchr/testify/assert" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) func TestCoordinator(t *testing.T) { @@ -19,7 +21,7 @@ func TestCoordinator(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testParams.timeoutDelay) sampler := newMockSampler(testParams.sampleFrom, testParams.networkHead) - coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, onceMiddleWare(sampler.sample)) + coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, onceMiddleWare(sampler.sample), nil) go coordinator.run(ctx, sampler.checkpoint) @@ -44,15 +46,12 @@ func TestCoordinator(t *testing.T) { sampler := newMockSampler(testParams.sampleFrom, testParams.networkHead) - coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, sampler.sample) + newhead := testParams.networkHead + 200 + coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, sampler.sample, newBroadcastMock(1)) go coordinator.run(ctx, sampler.checkpoint) - time.Sleep(50 * time.Millisecond) // discover new height - for i := 0; i < 200; i++ { - // mess the order by running in go-routine - sampler.discover(ctx, testParams.networkHead+uint64(i), coordinator.listen) - } + sampler.discover(ctx, newhead, coordinator.listen) // check if all jobs were sampled successfully assert.NoError(t, sampler.finished(ctx), "not all headers were sampled") @@ -69,7 +68,6 @@ func TestCoordinator(t *testing.T) { }) t.Run("prioritize newly discovered over known", func(t *testing.T) { - testParams := defaultTestParams() testParams.dasParams.ConcurrencyLimit = 1 @@ -86,31 +84,35 @@ func TestCoordinator(t *testing.T) { // lock worker before start, to not let it indicateDone before discover lk := newLock(testParams.sampleFrom, testParams.sampleFrom) - // expect worker to prioritize newly discovered (20 -> 10) and then old (0 -> 10) - order := newCheckOrder().addInterval( - testParams.sampleFrom, - testParams.dasParams.SamplingRange, - ) // worker will pick up first job before discovery + order := newCheckOrder().addInterval(toBeDiscovered, toBeDiscovered) - order.addStacks(testParams.networkHead+1, toBeDiscovered, testParams.dasParams.SamplingRange) - order.addInterval(testParams.dasParams.SamplingRange+1, toBeDiscovered) + // expect worker to prioritize newly discovered + order.addInterval( + testParams.sampleFrom, + toBeDiscovered, + ) // 
start coordinator coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, lk.middleWare( order.middleWare(sampler.sample), ), + newBroadcastMock(1), ) go coordinator.run(ctx, sampler.checkpoint) - // wait for worker to pick up first job - time.Sleep(50 * time.Millisecond) - // discover new height sampler.discover(ctx, toBeDiscovered, coordinator.listen) // check if no header were sampled yet - assert.Equal(t, 0, sampler.sampledAmount()) + for sampler.sampledAmount() != 1 { + time.Sleep(time.Millisecond) + select { + case <-ctx.Done(): + assert.NoError(t, ctx.Err()) + default: + } + } // unblock worker lk.release(testParams.sampleFrom) @@ -129,7 +131,7 @@ func TestCoordinator(t *testing.T) { assert.Equal(t, sampler.finalState(), newCheckpoint(coordinator.state.unsafeStats())) }) - t.Run("priority routine should not lock other workers", func(t *testing.T) { + t.Run("recent headers sampling routine should not lock other workers", func(t *testing.T) { testParams := defaultTestParams() testParams.networkHead = uint64(20) @@ -139,10 +141,9 @@ func TestCoordinator(t *testing.T) { lk := newLock(testParams.sampleFrom, testParams.networkHead) // lock all workers before start coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, - lk.middleWare(sampler.sample)) + lk.middleWare(sampler.sample), newBroadcastMock(1)) go coordinator.run(ctx, sampler.checkpoint) - time.Sleep(50 * time.Millisecond) // discover new height and lock it discovered := testParams.networkHead + 1 lk.add(discovered) @@ -186,7 +187,12 @@ func TestCoordinator(t *testing.T) { bornToFail := []uint64{4, 8, 15, 16, 23, 42} sampler := newMockSampler(testParams.sampleFrom, testParams.networkHead, bornToFail...) - coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, onceMiddleWare(sampler.sample)) + coordinator := newSamplingCoordinator( + testParams.dasParams, + getterStub{}, + onceMiddleWare(sampler.sample), + newBroadcastMock(1), + ) go coordinator.run(ctx, sampler.checkpoint) // wait for coordinator to indicateDone catchup @@ -217,7 +223,12 @@ func TestCoordinator(t *testing.T) { sampler := newMockSampler(testParams.sampleFrom, testParams.networkHead, failedAgain...) 
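For context on the coordinator changes above: the old `updateHead` is split into `isNewHead`/`updateHead`, and each new head now gets a dedicated "recent" job that bypasses the concurrency limit; the job's `isRecentHeader` flag later (in `das/worker.go` below) gates the shrexsub broadcast and keeps metrics from double-counting. Condensed into a sketch (a paraphrase of `samplingCoordinator.run`, not a verbatim excerpt):

```go
package das // sketch only

import "context"

// handleNewHead condenses the coordinator's new behaviour: a freshly seen head
// is sampled right away by a dedicated worker that is not counted against
// ConcurrencyLimit, and only then is the known network head advanced.
func (sc *samplingCoordinator) handleNewHead(ctx context.Context, head uint64) {
	if !sc.state.isNewHead(head) {
		return // stale or duplicate head announcement
	}
	// newRecentJob marks the job with isRecentHeader=true, which the worker
	// later uses to broadcast the sampled data hash over shrexsub
	sc.runWorker(ctx, sc.state.newRecentJob(head))
	sc.state.updateHead(head)
	sc.metrics.observeNewHead(ctx)
}
```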
sampler.checkpoint.Failed = failedLastRun - coordinator := newSamplingCoordinator(testParams.dasParams, getterStub{}, onceMiddleWare(sampler.sample)) + coordinator := newSamplingCoordinator( + testParams.dasParams, + getterStub{}, + onceMiddleWare(sampler.sample), + newBroadcastMock(1), + ) go coordinator.run(ctx, sampler.checkpoint) // check if all jobs were sampled successfully @@ -249,8 +260,12 @@ func BenchmarkCoordinator(b *testing.B) { b.Run("bench run", func(b *testing.B) { ctx, cancel := context.WithTimeout(context.Background(), timeoutDelay) - coordinator := newSamplingCoordinator(params, newBenchGetter(), - func(ctx context.Context, h *header.ExtendedHeader) error { return nil }) + coordinator := newSamplingCoordinator( + params, + newBenchGetter(), + func(ctx context.Context, h *header.ExtendedHeader) error { return nil }, + newBroadcastMock(1), + ) go coordinator.run(ctx, checkpoint{ SampleFrom: 1, NetworkHead: uint64(b.N), @@ -436,8 +451,8 @@ func (o *checkOrder) middleWare(out sampleFn) sampleFn { if len(o.queue) > 0 { // check last item in queue to be same as input if o.queue[0] != uint64(h.Height()) { - o.lock.Unlock() - return fmt.Errorf("expected height: %v,got: %v", o.queue[0], h) + defer o.lock.Unlock() + return fmt.Errorf("expected height: %v,got: %v", o.queue[0], h.Height()) } o.queue = o.queue[1:] } @@ -547,7 +562,20 @@ func defaultTestParams() testParams { return testParams{ networkHead: uint64(500), sampleFrom: dasParamsDefault.SampleFrom, - timeoutDelay: 125 * time.Second, + timeoutDelay: 5 * time.Second, dasParams: dasParamsDefault, } } + +func newBroadcastMock(callLimit int) shrexsub.BroadcastFn { + var m sync.Mutex + return func(ctx context.Context, hash share.DataHash) error { + m.Lock() + defer m.Unlock() + if callLimit == 0 { + return errors.New("exceeded mock call limit") + } + callLimit-- + return nil + } +} diff --git a/das/daser.go b/das/daser.go index 1be0b2af87..47c5960326 100644 --- a/das/daser.go +++ b/das/daser.go @@ -6,6 +6,8 @@ import ( "fmt" "sync/atomic" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" + "github.com/ipfs/go-datastore" logging "github.com/ipfs/go-log/v2" @@ -46,6 +48,7 @@ func NewDASer( getter libhead.Getter[*header.ExtendedHeader], dstore datastore.Datastore, bcast fraud.Broadcaster, + shrexBroadcast shrexsub.BroadcastFn, options ...Option, ) (*DASer, error) { d := &DASer{ @@ -68,7 +71,7 @@ func NewDASer( return nil, err } - d.sampler = newSamplingCoordinator(d.params, getter, d.sample) + d.sampler = newSamplingCoordinator(d.params, getter, d.sample, shrexBroadcast) return d, nil } @@ -149,9 +152,6 @@ func (d *DASer) sample(ctx context.Context, h *header.ExtendedHeader) error { err := d.da.SharesAvailable(ctx, h.DAH) if err != nil { - if err == context.Canceled { - return err - } var byzantineErr *byzantine.ErrByzantine if errors.As(err, &byzantineErr) { log.Warn("Propagating proof...") @@ -159,6 +159,8 @@ func (d *DASer) sample(ctx context.Context, h *header.ExtendedHeader) error { if sendErr != nil { log.Errorw("fraud proof propagating failed", "err", sendErr) } + } else if err == context.Canceled { + return err } log.Errorw("sampling failed", "height", h.Height(), "hash", h.Hash(), diff --git a/das/daser_test.go b/das/daser_test.go index ab89e76b02..b410b0891e 100644 --- a/das/daser_test.go +++ b/das/daser_test.go @@ -11,7 +11,6 @@ import ( ds_sync "github.com/ipfs/go-datastore/sync" mdutils "github.com/ipfs/go-merkledag/test" pubsub "github.com/libp2p/go-libp2p-pubsub" - "github.com/libp2p/go-libp2p/core/peer" 
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -44,7 +43,7 @@ func TestDASerLifecycle(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), timeout) t.Cleanup(cancel) - daser, err := NewDASer(avail, sub, mockGet, ds, mockService) + daser, err := NewDASer(avail, sub, mockGet, ds, mockService, newBroadcastMock(1)) require.NoError(t, err) err = daser.Start(ctx) @@ -84,7 +83,7 @@ func TestDASer_Restart(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), timeout) t.Cleanup(cancel) - daser, err := NewDASer(avail, sub, mockGet, ds, mockService) + daser, err := NewDASer(avail, sub, mockGet, ds, mockService, newBroadcastMock(1)) require.NoError(t, err) err = daser.Start(ctx) @@ -115,7 +114,7 @@ func TestDASer_Restart(t *testing.T) { restartCtx, restartCancel := context.WithTimeout(context.Background(), timeout) t.Cleanup(restartCancel) - daser, err = NewDASer(avail, sub, mockGet, ds, mockService) + daser, err = NewDASer(avail, sub, mockGet, ds, mockService, newBroadcastMock(1)) require.NoError(t, err) err = daser.Start(restartCtx) @@ -159,11 +158,11 @@ func TestDASer_stopsAfter_BEFP(t *testing.T) { // create fraud service and break one header f := fraud.NewProofService(ps, net.Hosts()[0], mockGet.GetByHeight, ds, false, string(p2p.Private)) require.NoError(t, f.Start(ctx)) - mockGet.headers[1] = headertest.CreateFraudExtHeader(t, mockGet.headers[1], bServ) + mockGet.headers[1], _ = headertest.CreateFraudExtHeader(t, mockGet.headers[1], bServ) newCtx := context.Background() // create and start DASer - daser, err := NewDASer(avail, sub, mockGet, ds, f) + daser, err := NewDASer(avail, sub, mockGet, ds, f, newBroadcastMock(1)) require.NoError(t, err) resultCh := make(chan error) @@ -191,7 +190,7 @@ func TestDASerSampleTimeout(t *testing.T) { getter := getterStub{} avail := mocks.NewMockAvailability(gomock.NewController(t)) avail.EXPECT().SharesAvailable(gomock.Any(), gomock.Any()).DoAndReturn( - func(sampleCtx context.Context, h *share.Root, peers ...peer.ID) error { + func(sampleCtx context.Context, h *share.Root) error { select { case <-sampleCtx.Done(): return sampleCtx.Err() @@ -206,7 +205,7 @@ func TestDASerSampleTimeout(t *testing.T) { f := new(fraud.DummyService) // create and start DASer - daser, err := NewDASer(avail, sub, getter, ds, f) + daser, err := NewDASer(avail, sub, getter, ds, f, newBroadcastMock(1)) require.NoError(t, err) // assign directly to avoid params validation diff --git a/das/doc.go b/das/doc.go index c255530aef..bc67fcc7f3 100644 --- a/das/doc.go +++ b/das/doc.go @@ -18,6 +18,6 @@ workers that perform DAS on new ExtendedHeaders in the network. The DASer kicks loop by loading its last DASed headers snapshot (`checkpoint`) and kicking off worker pool to DAS all headers between the checkpoint and the current network head. It subscribes to notifications about to new ExtendedHeaders, received via gossipsub. Newly found headers -are being put into higher priority queue and will be sampled by the next available worker. +are being put into workers directly, without applying concurrency limiting restrictions. 
*/ package das diff --git a/das/metrics.go b/das/metrics.go index 47757aa740..b5105df53b 100644 --- a/das/metrics.go +++ b/das/metrics.go @@ -136,6 +136,7 @@ func (m *metrics) observeSample( h *header.ExtendedHeader, sampleTime time.Duration, err error, + isRecentHeader bool, ) { if m == nil { return @@ -154,7 +155,9 @@ func (m *metrics) observeSample( atomic.StoreUint64(&m.lastSampledTS, uint64(time.Now().UTC().Unix())) - if err == nil { + // only increment the counter if it's not a recent header job + // as those happen twice. + if err == nil && !isRecentHeader { atomic.AddUint64(&m.totalSampledInt, 1) } } diff --git a/das/options.go b/das/options.go index eca8596de4..93bc09e2f3 100644 --- a/das/options.go +++ b/das/options.go @@ -32,9 +32,6 @@ type Parameters struct { // checkpoint backup. BackgroundStoreInterval time.Duration - // PriorityQueueSize defines the size limit of the priority queue - PriorityQueueSize int - // SampleFrom is the height sampling will start from if no previous checkpoint was saved SampleFrom uint64 @@ -51,7 +48,6 @@ func DefaultParameters() Parameters { SamplingRange: 100, ConcurrencyLimit: 16, BackgroundStoreInterval: 10 * time.Minute, - PriorityQueueSize: 16 * 4, SampleFrom: 1, SampleTimeout: time.Minute, } @@ -139,14 +135,6 @@ func WithBackgroundStoreInterval(backgroundStoreInterval time.Duration) Option { } } -// WithPriorityQueueSize is a functional option to configure the daser's `priorityQueuSize` -// parameter Refer to WithSamplingRange documentation to see an example of how to use this -func WithPriorityQueueSize(priorityQueueSize int) Option { - return func(d *DASer) { - d.params.PriorityQueueSize = priorityQueueSize - } -} - // WithSampleFrom is a functional option to configure the daser's `SampleFrom` parameter // Refer to WithSamplingRange documentation to see an example of how to use this func WithSampleFrom(sampleFrom uint64) Option { diff --git a/das/state.go b/das/state.go index b377c2e0f4..41c24e5849 100644 --- a/das/state.go +++ b/das/state.go @@ -10,10 +10,9 @@ type coordinatorState struct { sampleFrom uint64 // is the height from which the DASer will start sampling samplingRange uint64 // is the maximum amount of headers processed in one job. 
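Since `NewDASer` gains a `shrexsub.BroadcastFn` argument and `WithPriorityQueueSize` disappears, every construction site needs updating. A sketch of a typical call, modeled on the PR's own `nodebuilder/das` wrapper; the package and function names here are illustrative, and tests can pass a stub such as `newBroadcastMock` instead of a real broadcaster:

```go
package nodesetup // sketch only; package and function names are illustrative

import (
	"github.com/ipfs/go-datastore"

	"github.com/celestiaorg/celestia-node/das"
	"github.com/celestiaorg/celestia-node/fraud"
	"github.com/celestiaorg/celestia-node/header"
	libhead "github.com/celestiaorg/celestia-node/libs/header"
	"github.com/celestiaorg/celestia-node/share"
	"github.com/celestiaorg/celestia-node/share/p2p/shrexsub"
)

// newDASerWithBroadcast mirrors the updated constructor call: the extra
// shrexsub.BroadcastFn is what lets workers announce freshly sampled recent
// blocks to the network.
func newDASerWithBroadcast(
	avail share.Availability,
	hsub libhead.Subscriber[*header.ExtendedHeader],
	store libhead.Store[*header.ExtendedHeader],
	ds datastore.Batching,
	fraudServ fraud.Service,
	broadcast shrexsub.BroadcastFn, // e.g. shrexSub.Broadcast, or a stub in tests
) (*das.DASer, error) {
	return das.NewDASer(avail, hsub, store, ds, fraudServ, broadcast,
		das.WithSamplingRange(100),
		das.WithConcurrencyLimit(16),
	)
}
```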
- priorityQueueSize int // the size of the priority queue - priority []job // list of headers heights that will be sampled with higher priority - inProgress map[int]func() workerState // keeps track of running workers - failed map[uint64]int // stores heights of failed headers with amount of attempt as value + retry []job // list of headers heights that will be retried after last run + inProgress map[int]func() workerState // keeps track of running workers + failed map[uint64]int // stores heights of failed headers with amount of attempt as value nextJobID int next uint64 // all headers before next were sent to workers @@ -26,26 +25,25 @@ type coordinatorState struct { // newCoordinatorState initiates state for samplingCoordinator func newCoordinatorState(params Parameters) coordinatorState { return coordinatorState{ - sampleFrom: params.SampleFrom, - samplingRange: params.SamplingRange, - priorityQueueSize: params.PriorityQueueSize, - priority: make([]job, 0), - inProgress: make(map[int]func() workerState), - failed: make(map[uint64]int), - nextJobID: 0, - next: params.SampleFrom, - networkHead: params.SampleFrom, - catchUpDoneCh: make(chan struct{}), + sampleFrom: params.SampleFrom, + samplingRange: params.SamplingRange, + retry: make([]job, 0), + inProgress: make(map[int]func() workerState), + failed: make(map[uint64]int), + nextJobID: 0, + next: params.SampleFrom, + networkHead: params.SampleFrom, + catchUpDoneCh: make(chan struct{}), } } func (s *coordinatorState) resumeFromCheckpoint(c checkpoint) { s.next = c.SampleFrom s.networkHead = c.NetworkHead - // put failed into priority to retry them on restart + // store failed to retry them on restart for h, count := range c.Failed { s.failed[h] = count - s.priority = append(s.priority, s.newJob(h, h)) + s.retry = append(s.retry, s.newJob(h, h)) } } @@ -74,30 +72,33 @@ func (s *coordinatorState) handleResult(res result) { s.checkDone() } -func (s *coordinatorState) updateHead(last uint64) bool { +func (s *coordinatorState) isNewHead(newHead uint64) bool { // seen this header before - if last <= s.networkHead { - log.Warnf("received head height: %v, which is lower or the same as previously known: %v", last, s.networkHead) + if newHead <= s.networkHead { + log.Warnf("received head height: %v, which is lower or the same as previously known: %v", newHead, s.networkHead) return false } + return true +} +func (s *coordinatorState) updateHead(newHead uint64) { if s.networkHead == s.sampleFrom { - s.networkHead = last log.Infow("found first header, starting sampling") - return true - } - - // add most recent headers into priority queue - from := s.networkHead + 1 - for from <= last && len(s.priority) < s.priorityQueueSize { - s.priority = append(s.priority, s.newJob(from, last)) - from += s.samplingRange } - log.Debugw("added recent headers to DASer priority queue", "from_height", s.networkHead, "to_height", last) - s.networkHead = last + s.networkHead = newHead + log.Debugw("updated head", "from_height", s.networkHead, "to_height", newHead) s.checkDone() - return true +} + +func (s *coordinatorState) newRecentJob(newHead uint64) job { + s.nextJobID++ + return job{ + id: s.nextJobID, + isRecentHeader: true, + From: newHead, + To: newHead, + } } // nextJob will return header height to be processed and done flag if there is none @@ -107,8 +108,8 @@ func (s *coordinatorState) nextJob() (next job, found bool) { return job{}, false } - // try to take from priority first - if next, found := s.nextFromPriority(); found { + // try to take from retry 
first + if next, found := s.nextFromRetry(); found { return next, found } @@ -122,19 +123,15 @@ func (s *coordinatorState) nextJob() (next job, found bool) { return j, true } -func (s *coordinatorState) nextFromPriority() (job, bool) { - for len(s.priority) > 0 { - next := s.priority[len(s.priority)-1] - s.priority = s.priority[:len(s.priority)-1] +func (s *coordinatorState) nextFromRetry() (job, bool) { + if len(s.retry) == 0 { + return job{}, false + } - // this job will be processed next normally, we can skip it - if next.From == s.next { - continue - } + next := s.retry[len(s.retry)-1] + s.retry = s.retry[:len(s.retry)-1] - return next, true - } - return job{}, false + return next, true } func (s *coordinatorState) putInProgress(jobID int, getState func() workerState) { @@ -207,7 +204,7 @@ func (s *coordinatorState) unsafeStats() SamplingStats { } func (s *coordinatorState) checkDone() { - if len(s.inProgress) == 0 && len(s.priority) == 0 && s.next > s.networkHead { + if len(s.inProgress) == 0 && len(s.retry) == 0 && s.next > s.networkHead { if s.catchUpDone.CompareAndSwap(false, true) { close(s.catchUpDoneCh) } diff --git a/das/worker.go b/das/worker.go index 13ab67a05e..3aa3fd91e3 100644 --- a/das/worker.go +++ b/das/worker.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" + "go.uber.org/multierr" "github.com/celestiaorg/celestia-node/header" @@ -30,7 +32,9 @@ type workerState struct { // job represents headers interval to be processed by worker type job struct { - id int + id int + isRecentHeader bool + From uint64 To uint64 } @@ -39,6 +43,7 @@ func (w *worker) run( ctx context.Context, getter libhead.Getter[*header.ExtendedHeader], sample sampleFn, + broadcast shrexsub.BroadcastFn, metrics *metrics, resultCh chan<- result) { jobStart := time.Now() @@ -77,7 +82,9 @@ func (w *worker) run( break } w.setResult(curr, err) - metrics.observeSample(ctx, h, time.Since(startSample), err) + + metrics.observeSample(ctx, h, time.Since(startSample), err, w.state.isRecentHeader) + if err != nil { log.Debugw( "failed to sampled header", @@ -87,15 +94,25 @@ func (w *worker) run( "data root", h.DAH.Hash(), "err", err, ) - } else { - log.Debugw( - "sampled header", - "height", h.Height(), - "hash", h.Hash(), - "square width", len(h.DAH.RowsRoots), - "data root", h.DAH.Hash(), - "finished (s)", time.Since(startSample), - ) + continue + } + + log.Debugw( + "sampled header", + "height", h.Height(), + "hash", h.Hash(), + "square width", len(h.DAH.RowsRoots), + "data root", h.DAH.Hash(), + "finished (s)", time.Since(startSample), + ) + + // notify network about availability of new block data (note: only full nodes can notify) + if w.state.isRecentHeader { + err = broadcast(ctx, h.DataHash.Bytes()) + if err != nil { + log.Warn("failed to broadcast availability message", + "height", h.Height(), "hash", h.Hash(), "err", err) + } } } diff --git a/fraud/service_test.go b/fraud/service_test.go index faafd1efff..a1f405e3d1 100644 --- a/fraud/service_test.go +++ b/fraud/service_test.go @@ -16,7 +16,6 @@ import ( "github.com/stretchr/testify/require" "github.com/celestiaorg/celestia-node/header" - "github.com/celestiaorg/celestia-node/nodebuilder/p2p" ) func TestService_Subscribe(t *testing.T) { @@ -133,7 +132,7 @@ func TestService_ReGossiping(t *testing.T) { }, sync.MutexWrap(datastore.NewMapDatastore()), false, - string(p2p.Private), + "private", ) addrB := host.InfoFromHost(net.Hosts()[1]) // -> B @@ -149,7 +148,7 @@ func TestService_ReGossiping(t *testing.T) 
{ }, sync.MutexWrap(datastore.NewMapDatastore()), false, - string(p2p.Private), + "private", ) // establish connections // connect peers: A -> B -> C, so A and C are not connected to each other diff --git a/fraud/testing.go b/fraud/testing.go index 368f6a5b75..355d0440e4 100644 --- a/fraud/testing.go +++ b/fraud/testing.go @@ -16,7 +16,6 @@ import ( "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/headertest" - "github.com/celestiaorg/celestia-node/nodebuilder/p2p" ) type DummyService struct { @@ -144,6 +143,6 @@ func createTestServiceWithHost( store.GetByHeight, sync.MutexWrap(datastore.NewMapDatastore()), enabledSyncer, - string(p2p.Private), + "private", ), store } diff --git a/go.mod b/go.mod index 7c8b59e1cb..a150c4b20b 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( cosmossdk.io/math v1.0.0-beta.3 github.com/BurntSushi/toml v1.2.1 github.com/alecthomas/jsonschema v0.0.0-20200530073317-71f438968921 + github.com/benbjohnson/clock v1.3.0 github.com/celestiaorg/celestia-app v0.12.0-rc7 github.com/celestiaorg/go-libp2p-messenger v0.1.0 github.com/celestiaorg/nmt v0.14.0 @@ -67,6 +68,7 @@ require ( go.uber.org/fx v1.18.2 go.uber.org/multierr v1.9.0 golang.org/x/crypto v0.5.0 + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/sync v0.1.0 golang.org/x/text v0.6.0 google.golang.org/grpc v1.52.0 @@ -90,7 +92,6 @@ require ( github.com/Workiva/go-datastructures v1.0.53 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/aws/aws-sdk-go v1.40.45 // indirect - github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d // indirect github.com/bgentry/speakeasy v0.1.0 // indirect @@ -302,7 +303,6 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/dig v1.15.0 // indirect go.uber.org/zap v1.24.0 // indirect - golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.7.0 // indirect golang.org/x/net v0.5.0 // indirect golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 // indirect diff --git a/header/header.go b/header/header.go index e233852aa5..3e371e1c70 100644 --- a/header/header.go +++ b/header/header.go @@ -7,28 +7,22 @@ import ( "fmt" "time" - "github.com/ipfs/go-blockservice" - logging "github.com/ipfs/go-log/v2" - tmjson "github.com/tendermint/tendermint/libs/json" core "github.com/tendermint/tendermint/types" - appshares "github.com/celestiaorg/celestia-app/pkg/shares" - "github.com/celestiaorg/celestia-app/pkg/da" + "github.com/celestiaorg/rsmt2d" + libhead "github.com/celestiaorg/celestia-node/libs/header" - "github.com/celestiaorg/celestia-node/share" ) -var log = logging.Logger("header") - // ConstructFn aliases a function that creates an ExtendedHeader. 
type ConstructFn = func( context.Context, *core.Block, *core.Commit, *core.ValidatorSet, - blockservice.BlockService, + *rsmt2d.ExtendedDataSquare, ) (*ExtendedHeader, error) type DataAvailabilityHeader = da.DataAvailabilityHeader @@ -79,23 +73,14 @@ func MakeExtendedHeader( b *core.Block, comm *core.Commit, vals *core.ValidatorSet, - bServ blockservice.BlockService, + eds *rsmt2d.ExtendedDataSquare, ) (*ExtendedHeader, error) { var dah DataAvailabilityHeader - if len(b.Txs) > 0 { - shares, err := appshares.Split(b.Data, true) - if err != nil { - return nil, err - } - extended, err := share.AddShares(ctx, appshares.ToBytes(shares), bServ) - if err != nil { - return nil, err - } - dah = da.NewDataAvailabilityHeader(extended) - } else { - // use MinDataAvailabilityHeader for empty block + switch eds { + case nil: dah = EmptyDAH() - log.Debugw("empty block received", "height", "blockID", "time", b.Height, b.Time.String(), comm.BlockID) + default: + dah = da.NewDataAvailabilityHeader(eds) } eh := &ExtendedHeader{ diff --git a/headertest/testing.go b/headertest/testing.go index 4028042e75..71ce072762 100644 --- a/headertest/testing.go +++ b/headertest/testing.go @@ -2,8 +2,9 @@ package headertest import ( "context" - + "fmt" mrand "math/rand" + "sort" "testing" "time" @@ -18,10 +19,11 @@ import ( "github.com/tendermint/tendermint/proto/tendermint/version" "github.com/tendermint/tendermint/types" tmtime "github.com/tendermint/tendermint/types/time" + "golang.org/x/exp/rand" "github.com/celestiaorg/celestia-app/pkg/da" + "github.com/celestiaorg/rsmt2d" - "github.com/celestiaorg/celestia-node/core" "github.com/celestiaorg/celestia-node/header" libhead "github.com/celestiaorg/celestia-node/libs/header" "github.com/celestiaorg/celestia-node/libs/header/test" @@ -44,7 +46,7 @@ type TestSuite struct { // NewTestSuite setups a new test suite with a given number of validators. 
func NewTestSuite(t *testing.T, num int) *TestSuite { - valSet, vals := core.RandValidatorSet(num, 10) + valSet, vals := RandValidatorSet(num, 10) return &TestSuite{ t: t, vals: vals, @@ -62,7 +64,7 @@ func (s *TestSuite) genesis() *header.ExtendedHeader { gen.NextValidatorsHash = s.valSet.Hash() gen.Height = 1 voteSet := types.NewVoteSet(gen.ChainID, gen.Height, 0, tmproto.PrecommitType, s.valSet) - commit, err := core.MakeCommit(RandBlockID(s.t), gen.Height, 0, voteSet, s.vals, time.Now()) + commit, err := MakeCommit(RandBlockID(s.t), gen.Height, 0, voteSet, s.vals, time.Now()) require.NoError(s.t, err) eh := &header.ExtendedHeader{ @@ -75,6 +77,44 @@ func (s *TestSuite) genesis() *header.ExtendedHeader { return eh } +func MakeCommit(blockID types.BlockID, height int64, round int32, + voteSet *types.VoteSet, validators []types.PrivValidator, now time.Time) (*types.Commit, error) { + + // all sign + for i := 0; i < len(validators); i++ { + pubKey, err := validators[i].GetPubKey() + if err != nil { + return nil, fmt.Errorf("can't get pubkey: %w", err) + } + vote := &types.Vote{ + ValidatorAddress: pubKey.Address(), + ValidatorIndex: int32(i), + Height: height, + Round: round, + Type: tmproto.PrecommitType, + BlockID: blockID, + Timestamp: now, + } + + _, err = signAddVote(validators[i], vote, voteSet) + if err != nil { + return nil, err + } + } + + return voteSet.MakeCommit(), nil +} + +func signAddVote(privVal types.PrivValidator, vote *types.Vote, voteSet *types.VoteSet) (signed bool, err error) { + v := vote.ToProto() + err = privVal.SignVote(voteSet.ChainID(), v) + if err != nil { + return false, err + } + vote.Signature = v.Signature + return voteSet.AddVote(vote) +} + func (s *TestSuite) Head() *header.ExtendedHeader { if s.head == nil { s.head = s.genesis() @@ -173,10 +213,10 @@ func RandExtendedHeader(t *testing.T) *header.ExtendedHeader { rh := RandRawHeader(t) rh.DataHash = dah.Hash() - valSet, vals := core.RandValidatorSet(3, 1) + valSet, vals := RandValidatorSet(3, 1) rh.ValidatorsHash = valSet.Hash() voteSet := types.NewVoteSet(rh.ChainID, rh.Height, 0, tmproto.PrecommitType, valSet) - commit, err := core.MakeCommit(RandBlockID(t), rh.Height, 0, voteSet, vals, time.Now()) + commit, err := MakeCommit(RandBlockID(t), rh.Height, 0, voteSet, vals, time.Now()) require.NoError(t, err) return &header.ExtendedHeader{ @@ -187,6 +227,38 @@ func RandExtendedHeader(t *testing.T) *header.ExtendedHeader { } } +func RandValidatorSet(numValidators int, votingPower int64) (*types.ValidatorSet, []types.PrivValidator) { + var ( + valz = make([]*types.Validator, numValidators) + privValidators = make([]types.PrivValidator, numValidators) + ) + + for i := 0; i < numValidators; i++ { + val, privValidator := RandValidator(false, votingPower) + valz[i] = val + privValidators[i] = privValidator + } + + sort.Sort(types.PrivValidatorsByAddress(privValidators)) + + return types.NewValidatorSet(valz), privValidators +} + +func RandValidator(randPower bool, minPower int64) (*types.Validator, types.PrivValidator) { + privVal := types.NewMockPV() + votePower := minPower + if randPower { + //nolint:gosec // G404: Use of weak random number generator + votePower += int64(rand.Uint32()) + } + pubKey, err := privVal.GetPubKey() + if err != nil { + panic(fmt.Errorf("could not retrieve pubkey %w", err)) + } + val := types.NewValidator(pubKey, votePower) + return val, privVal +} + // RandRawHeader provides a RawHeader fixture. 
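`RandValidator`, `RandValidatorSet`, `MakeCommit`, and `signAddVote` move from `core/testing.go` into `headertest`, breaking the `headertest -> core` dependency. A small usage sketch of the relocated helpers, modeled on how `TestSuite` and `RandExtendedHeader` call them; the chain ID and height are arbitrary:

```go
package headertest_test // sketch only

import (
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
	"github.com/tendermint/tendermint/types"

	"github.com/celestiaorg/celestia-node/headertest"
)

// TestCommitSketch exercises the relocated helpers: build a validator set,
// collect precommit votes, and turn them into a commit for a random block ID.
func TestCommitSketch(t *testing.T) {
	valSet, privVals := headertest.RandValidatorSet(3, 1)

	voteSet := types.NewVoteSet("test-chain", 1, 0, tmproto.PrecommitType, valSet)
	commit, err := headertest.MakeCommit(headertest.RandBlockID(t), 1, 0, voteSet, privVals, time.Now())
	require.NoError(t, err)
	require.EqualValues(t, 1, commit.Height)
}
```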
func RandRawHeader(t *testing.T) *header.RawHeader { return &header.RawHeader{ @@ -222,13 +294,14 @@ func RandBlockID(t *testing.T) types.BlockID { } // FraudMaker creates a custom ConstructFn that breaks the block at the given height. -func FraudMaker(t *testing.T, faultHeight int64) header.ConstructFn { +func FraudMaker(t *testing.T, faultHeight int64, bServ blockservice.BlockService) header.ConstructFn { log.Warn("Corrupting block...", "height", faultHeight) return func(ctx context.Context, b *types.Block, comm *types.Commit, vals *types.ValidatorSet, - bServ blockservice.BlockService) (*header.ExtendedHeader, error) { + eds *rsmt2d.ExtendedDataSquare, + ) (*header.ExtendedHeader, error) { if b.Height == faultHeight { eh := &header.ExtendedHeader{ RawHeader: b.Header, @@ -236,10 +309,13 @@ func FraudMaker(t *testing.T, faultHeight int64) header.ConstructFn { ValidatorSet: vals, } - eh = CreateFraudExtHeader(t, eh, bServ) + eh, dataSq := CreateFraudExtHeader(t, eh, bServ) + if eds != nil { + *eds = *dataSq + } return eh, nil } - return header.MakeExtendedHeader(ctx, b, comm, vals, bServ) + return header.MakeExtendedHeader(ctx, b, comm, vals, eds) } } @@ -247,7 +323,7 @@ func CreateFraudExtHeader( t *testing.T, eh *header.ExtendedHeader, dag blockservice.BlockService, -) *header.ExtendedHeader { +) (*header.ExtendedHeader, *rsmt2d.ExtendedDataSquare) { extended := share.RandEDS(t, 2) shares := share.ExtractEDS(extended) copy(shares[0][share.NamespaceSize:], shares[1][share.NamespaceSize:]) @@ -256,7 +332,7 @@ func CreateFraudExtHeader( dah := da.NewDataAvailabilityHeader(extended) eh.DAH = &dah eh.RawHeader.DataHash = dah.Hash() - return eh + return eh, extended } type DummySubscriber struct { diff --git a/libs/header/interface.go b/libs/header/interface.go index 387159d92c..96cdb39628 100644 --- a/libs/header/interface.go +++ b/libs/header/interface.go @@ -87,6 +87,9 @@ type Store[H Header] interface { // Has checks whether Header is already stored. Has(context.Context, Hash) (bool, error) + // HasAt checks whether Header at the given height is already stored. + HasAt(context.Context, uint64) bool + // Append stores and verifies the given Header(s). // It requires them to be adjacent and in ascending order, // as it applies them contiguously on top of the current head height. 
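`HasAt` is meant as a cheap, error-free height check next to the hash-based `Has`; the implementations added below simply compare against the store head. Its purpose in this PR is the guard in the p2p exchange server further down, which refuses range requests it cannot fully serve — roughly the following sketch, assuming a `header.Store[H]` as the updated server holds (the server treats the request's `to` bound as exclusive, hence the `to-1` check):

```go
package p2p // sketch only

import (
	"context"

	"github.com/celestiaorg/celestia-node/libs/header"
)

// serveRange sketches the guard pattern HasAt enables on the server side:
// refuse a range request early instead of letting GetRangeByHeight fail or
// hang on headers the node has not stored yet.
func serveRange[H header.Header](ctx context.Context, store header.Store[H], from, to uint64) ([]H, error) {
	if !store.HasAt(ctx, to-1) { // highest height needed by the request
		return nil, header.ErrNotFound
	}
	return store.GetRangeByHeight(ctx, from, to)
}
```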
diff --git a/libs/header/mocks/store.go b/libs/header/mocks/store.go index a6227a60c9..8d3667d131 100644 --- a/libs/header/mocks/store.go +++ b/libs/header/mocks/store.go @@ -86,6 +86,10 @@ func (m *MockStore[H]) Has(context.Context, header.Hash) (bool, error) { return false, nil } +func (m *MockStore[H]) HasAt(_ context.Context, height uint64) bool { + return height != 0 && m.HeadHeight >= int64(height) +} + func (m *MockStore[H]) Append(ctx context.Context, headers ...H) (int, error) { for _, header := range headers { m.Headers[header.Height()] = header diff --git a/libs/header/p2p/exchange_test.go b/libs/header/p2p/exchange_test.go index b27c760949..b816d65f1e 100644 --- a/libs/header/p2p/exchange_test.go +++ b/libs/header/p2p/exchange_test.go @@ -347,19 +347,19 @@ func TestExchange_RequestHeadersFromAnotherPeerWhenTimeout(t *testing.T) { ) require.NoError(t, err) // change store implementation - serverSideEx.getter = &timedOutStore{exchg.Params.RequestTimeout} + serverSideEx.store = &timedOutStore{timeout: exchg.Params.RequestTimeout} require.NoError(t, serverSideEx.Start(context.Background())) t.Cleanup(func() { serverSideEx.Stop(context.Background()) //nolint:errcheck }) + prevScore := exchg.peerTracker.trackedPeers[host1.ID()].score() exchg.peerTracker.peerLk.Lock() exchg.peerTracker.trackedPeers[host2.ID()] = &peerStat{peerID: host2.ID(), peerScore: 200} exchg.peerTracker.peerLk.Unlock() _, err = exchg.GetRangeByHeight(context.Background(), 1, 3) require.NoError(t, err) - // ensure that peerScore for the first peer was decrease by 20% - newPeerScore := exchg.peerTracker.trackedPeers[host2.ID()].score() - assert.Less(t, newPeerScore, float32(200)) + newPeerScore := exchg.peerTracker.trackedPeers[host1.ID()].score() + assert.NotEqual(t, newPeerScore, prevScore) } func createMocknet(t *testing.T, amount int) []libhost.Host { @@ -400,34 +400,11 @@ func createP2PExAndServer( } type timedOutStore struct { + headerMock.MockStore[*test.DummyHeader] timeout time.Duration } -func (t *timedOutStore) Head(context.Context) (*test.DummyHeader, error) { - // TODO implement me - panic("implement me") -} - -func (t *timedOutStore) Get(context.Context, header.Hash) (*test.DummyHeader, error) { - // TODO implement me - panic("implement me") -} - -func (t *timedOutStore) GetByHeight(context.Context, uint64) (*test.DummyHeader, error) { - // TODO implement me - panic("implement me") -} - func (t *timedOutStore) GetRangeByHeight(_ context.Context, _, _ uint64) ([]*test.DummyHeader, error) { time.Sleep(t.timeout + 1) return []*test.DummyHeader{}, nil } - -func (t *timedOutStore) GetVerifiedRange( - context.Context, - *test.DummyHeader, - uint64, -) ([]*test.DummyHeader, error) { - // TODO implement me - panic("implement me") -} diff --git a/libs/header/p2p/server.go b/libs/header/p2p/server.go index d7d6b81d9e..e9d54aaa16 100644 --- a/libs/header/p2p/server.go +++ b/libs/header/p2p/server.go @@ -28,8 +28,8 @@ var ( type ExchangeServer[H header.Header] struct { protocolID protocol.ID - host host.Host - getter header.Getter[H] + host host.Host + store header.Store[H] ctx context.Context cancel context.CancelFunc @@ -41,7 +41,7 @@ type ExchangeServer[H header.Header] struct { // header-related requests. 
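The `timedOutStore` rewrite above is worth imitating: instead of stubbing every `Store` method with `panic("implement me")`, it embeds `headerMock.MockStore` and overrides only the behaviour under test. A generic sketch of the same technique; the `slowHeadStore` type, its delay, and the assumption that the mock's `Head` follows the `Store` signature are illustrative, not part of the PR:

```go
package p2p // sketch only

import (
	"context"
	"time"

	headerMock "github.com/celestiaorg/celestia-node/libs/header/mocks"
	"github.com/celestiaorg/celestia-node/libs/header/test"
)

// slowHeadStore satisfies the whole Store interface by embedding the mock and
// overrides just one method for the behaviour under test.
type slowHeadStore struct {
	headerMock.MockStore[*test.DummyHeader]
	delay time.Duration
}

// Head simulates a slow responder before falling through to the embedded mock.
func (s *slowHeadStore) Head(ctx context.Context) (*test.DummyHeader, error) {
	select {
	case <-time.After(s.delay):
		return s.MockStore.Head(ctx)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
```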
func NewExchangeServer[H header.Header]( host host.Host, - getter header.Getter[H], + store header.Store[H], opts ...Option[ServerParameters], ) (*ExchangeServer[H], error) { params := DefaultServerParameters() @@ -55,7 +55,7 @@ func NewExchangeServer[H header.Header]( return &ExchangeServer[H]{ protocolID: protocolID(params.protocolSuffix), host: host, - getter: getter, + store: store, Params: params, }, nil } @@ -164,7 +164,7 @@ func (serv *ExchangeServer[H]) handleRequestByHash(hash []byte) ([]H, error) { )) defer span.End() - h, err := serv.getter.Get(ctx, hash) + h, err := serv.store.Get(ctx, hash) if err != nil { log.Errorw("server: getting header by hash", "hash", header.Hash(hash).String(), "err", err) span.SetStatus(codes.Error, err.Error()) @@ -201,7 +201,13 @@ func (serv *ExchangeServer[H]) handleRequest(from, to uint64) ([]H, error) { } log.Debugw("server: handling headers request", "from", from, "to", to) - headersByRange, err := serv.getter.GetRangeByHeight(ctx, from, to) + if !serv.store.HasAt(ctx, to-1) { + span.SetStatus(codes.Error, header.ErrNotFound.Error()) + log.Debugw("server: requested headers not stored", "from", from, "to", to) + return nil, header.ErrNotFound + } + + headersByRange, err := serv.store.GetRangeByHeight(ctx, from, to) if err != nil { span.SetStatus(codes.Error, err.Error()) if errors.Is(err, context.DeadlineExceeded) { @@ -226,7 +232,7 @@ func (serv *ExchangeServer[H]) handleHeadRequest() ([]H, error) { ctx, span := tracer.Start(ctx, "request-head") defer span.End() - head, err := serv.getter.Head(ctx) + head, err := serv.store.Head(ctx) if err != nil { log.Errorw("server: getting head", "err", err) span.SetStatus(codes.Error, err.Error()) diff --git a/libs/header/store/store.go b/libs/header/store/store.go index 59c8897323..1f36fa839a 100644 --- a/libs/header/store/store.go +++ b/libs/header/store/store.go @@ -283,6 +283,10 @@ func (s *Store[H]) Has(ctx context.Context, hash header.Hash) (bool, error) { return s.ds.Has(ctx, datastore.NewKey(hash.String())) } +func (s *Store[H]) HasAt(_ context.Context, height uint64) bool { + return height != uint64(0) && s.Height() >= height +} + func (s *Store[H]) Append(ctx context.Context, headers ...H) (int, error) { lh := len(headers) if lh == 0 { diff --git a/libs/utils/square.go b/libs/utils/square.go new file mode 100644 index 0000000000..ce2663fd81 --- /dev/null +++ b/libs/utils/square.go @@ -0,0 +1,8 @@ +package utils + +import "math" + +// SquareSize returns the size of the square based on the given amount of shares. 
+func SquareSize(lenShares int) uint64 { + return uint64(math.Sqrt(float64(lenShares))) +} diff --git a/nodebuilder/config.go b/nodebuilder/config.go index e034a303b5..ef9784461a 100644 --- a/nodebuilder/config.go +++ b/nodebuilder/config.go @@ -39,7 +39,7 @@ func DefaultConfig(tp node.Type) *Config { commonConfig := &Config{ Core: core.DefaultConfig(), State: state.DefaultConfig(), - P2P: p2p.DefaultConfig(), + P2P: p2p.DefaultConfig(tp), RPC: rpc.DefaultConfig(), Gateway: gateway.DefaultConfig(), Share: share.DefaultConfig(), diff --git a/nodebuilder/core/module.go b/nodebuilder/core/module.go index efaa3dd8e8..ecb002890a 100644 --- a/nodebuilder/core/module.go +++ b/nodebuilder/core/module.go @@ -3,7 +3,6 @@ package core import ( "context" - "github.com/ipfs/go-blockservice" "go.uber.org/fx" "github.com/celestiaorg/celestia-node/core" @@ -11,6 +10,7 @@ import ( "github.com/celestiaorg/celestia-node/libs/fxutil" libhead "github.com/celestiaorg/celestia-node/libs/header" "github.com/celestiaorg/celestia-node/nodebuilder/node" + "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) @@ -35,13 +35,14 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option fx.Provide(core.NewBlockFetcher), fxutil.ProvideAs(core.NewExchange, new(libhead.Exchange[*header.ExtendedHeader])), fx.Invoke(fx.Annotate( - func(bcast libhead.Broadcaster[*header.ExtendedHeader], + func( + bcast libhead.Broadcaster[*header.ExtendedHeader], fetcher *core.BlockFetcher, pubsub *shrexsub.PubSub, - bServ blockservice.BlockService, construct header.ConstructFn, + store *eds.Store, ) *core.Listener { - return core.NewListener(bcast, fetcher, pubsub.Broadcast, bServ, construct) + return core.NewListener(bcast, fetcher, pubsub.Broadcast, construct, store) }, fx.OnStart(func(ctx context.Context, listener *core.Listener) error { return listener.Start(ctx) diff --git a/nodebuilder/das/constructors.go b/nodebuilder/das/constructors.go index f73c9bbf53..5247a13597 100644 --- a/nodebuilder/das/constructors.go +++ b/nodebuilder/das/constructors.go @@ -11,6 +11,7 @@ import ( "github.com/celestiaorg/celestia-node/header" libhead "github.com/celestiaorg/celestia-node/libs/header" "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) var _ Module = (*daserStub)(nil) @@ -39,7 +40,8 @@ func newDASer( store libhead.Store[*header.ExtendedHeader], batching datastore.Batching, fraudServ fraud.Service, + bFn shrexsub.BroadcastFn, options ...das.Option, ) (*das.DASer, error) { - return das.NewDASer(da, hsub, store, batching, fraudServ, options...) + return das.NewDASer(da, hsub, store, batching, fraudServ, bFn, options...) 
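`utils.SquareSize` recovers the width of the original data square from the number of shares produced by `appshares.Split`, which `extendBlock` relies on being a perfect square; `da.ExtendShares` then erasure-codes that into an extended square of twice the width. A small runnable illustration of the arithmetic:

```go
package main

import (
	"fmt"

	"github.com/celestiaorg/celestia-node/libs/utils"
)

// Worked example of the share-count -> square-width relationship extendBlock
// relies on: a block that splits into 16 shares fills a 4x4 original square,
// which erasure coding doubles to an 8x8 extended square.
func main() {
	for _, shareCount := range []int{1, 4, 16, 64} {
		width := utils.SquareSize(shareCount)
		fmt.Printf("%2d shares -> %dx%d square -> %dx%d extended square\n",
			shareCount, width, width, 2*width, 2*width)
	}
}
```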
} diff --git a/nodebuilder/das/module.go b/nodebuilder/das/module.go index d060a6a0b2..dbe13684a6 100644 --- a/nodebuilder/das/module.go +++ b/nodebuilder/das/module.go @@ -27,7 +27,6 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { return []das.Option{ das.WithSamplingRange(c.SamplingRange), das.WithConcurrencyLimit(c.ConcurrencyLimit), - das.WithPriorityQueueSize(c.PriorityQueueSize), das.WithBackgroundStoreInterval(c.BackgroundStoreInterval), das.WithSampleFrom(c.SampleFrom), das.WithSampleTimeout(c.SampleTimeout), diff --git a/nodebuilder/gateway/module.go b/nodebuilder/gateway/module.go index 4cdf325dc0..b3070e01a6 100644 --- a/nodebuilder/gateway/module.go +++ b/nodebuilder/gateway/module.go @@ -21,6 +21,8 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { if !cfg.Enabled { return fx.Options() } + // NOTE @distractedm1nd @renaynay: Remove whenever/if we decide to add auth to gateway + log.Warn("Gateway is enabled, however gateway endpoints are not authenticated. Use with caution!") baseComponents := fx.Options( fx.Supply(cfg), diff --git a/nodebuilder/p2p/bitswap.go b/nodebuilder/p2p/bitswap.go index 36be9ab953..fc4c72638d 100644 --- a/nodebuilder/p2p/bitswap.go +++ b/nodebuilder/p2p/bitswap.go @@ -13,6 +13,8 @@ import ( hst "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/protocol" "go.uber.org/fx" + + "github.com/celestiaorg/celestia-node/share/eds" ) const ( @@ -25,30 +27,40 @@ const ( ) // dataExchange provides a constructor for IPFS block's DataExchange over BitSwap. -func dataExchange(params bitSwapParams) (exchange.Interface, blockstore.Blockstore, error) { - bs, err := blockstore.CachedBlockstore( - params.Ctx, - blockstore.NewBlockstore(params.Ds), - blockstore.CacheOpts{ - HasBloomFilterSize: defaultBloomFilterSize, - HasBloomFilterHashes: defaultBloomFilterHashes, - HasARCCacheSize: defaultARCCacheSize, - }, - ) - if err != nil { - return nil, nil, err - } +func dataExchange(params bitSwapParams) exchange.Interface { prefix := protocol.ID(fmt.Sprintf("/celestia/%s", params.Net)) return bitswap.New( params.Ctx, network.NewFromIpfsHost(params.Host, &routinghelpers.Null{}, network.Prefix(prefix)), - bs, + params.Bs, bitswap.ProvideEnabled(false), // NOTE: These below ar required for our protocol to work reliably. 
// See https://github.com/celestiaorg/celestia-node/issues/732 bitswap.SetSendDontHaves(false), bitswap.SetSimulateDontHavesOnTimeout(false), - ), bs, nil + ) +} + +func blockstoreFromDatastore(ctx context.Context, ds datastore.Batching) (blockstore.Blockstore, error) { + return blockstore.CachedBlockstore( + ctx, + blockstore.NewBlockstore(ds), + blockstore.CacheOpts{ + HasBloomFilterSize: defaultBloomFilterSize, + HasBloomFilterHashes: defaultBloomFilterHashes, + HasARCCacheSize: defaultARCCacheSize, + }, + ) +} + +func blockstoreFromEDSStore(ctx context.Context, store *eds.Store) (blockstore.Blockstore, error) { + return blockstore.CachedBlockstore( + ctx, + store.Blockstore(), + blockstore.CacheOpts{ + HasARCCacheSize: defaultARCCacheSize, + }, + ) } type bitSwapParams struct { @@ -57,5 +69,5 @@ type bitSwapParams struct { Ctx context.Context Net Network Host hst.Host - Ds datastore.Batching + Bs blockstore.Blockstore } diff --git a/nodebuilder/p2p/config.go b/nodebuilder/p2p/config.go index 0388a883c9..b968100095 100644 --- a/nodebuilder/p2p/config.go +++ b/nodebuilder/p2p/config.go @@ -6,6 +6,8 @@ import ( "github.com/libp2p/go-libp2p/core/peer" ma "github.com/multiformats/go-multiaddr" + + "github.com/celestiaorg/celestia-node/nodebuilder/node" ) const defaultRoutingRefreshPeriod = time.Minute @@ -36,7 +38,7 @@ type Config struct { } // DefaultConfig returns default configuration for P2P subsystem. -func DefaultConfig() Config { +func DefaultConfig(tp node.Type) Config { return Config{ ListenAddresses: []string{ "/ip4/0.0.0.0/udp/2121/quic-v1", @@ -55,7 +57,7 @@ func DefaultConfig() Config { }, MutualPeers: []string{}, Bootstrapper: false, - PeerExchange: false, + PeerExchange: tp == node.Bridge || tp == node.Full, ConnManager: defaultConnManagerConfig(), RoutingTableRefreshPeriod: defaultRoutingRefreshPeriod, } diff --git a/nodebuilder/p2p/module.go b/nodebuilder/p2p/module.go index c00f8ab60c..5b3e9f9d52 100644 --- a/nodebuilder/p2p/module.go +++ b/nodebuilder/p2p/module.go @@ -44,6 +44,7 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { return fx.Module( "p2p", baseComponents, + fx.Provide(blockstoreFromEDSStore), fx.Provide(func() (network.ResourceManager, error) { return rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(rcmgr.InfiniteLimits)) }), @@ -52,6 +53,7 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { return fx.Module( "p2p", baseComponents, + fx.Provide(blockstoreFromDatastore), fx.Provide(func() (network.ResourceManager, error) { limits := rcmgr.DefaultLimits libp2p.SetDefaultServiceLimits(&limits) diff --git a/nodebuilder/p2p/network.go b/nodebuilder/p2p/network.go index 07a9a8b0ac..b13c5d7311 100644 --- a/nodebuilder/p2p/network.go +++ b/nodebuilder/p2p/network.go @@ -12,7 +12,7 @@ const ( // DefaultNetwork is the default network of the current build. DefaultNetwork = Mocha // Arabica testnet. See: celestiaorg/networks. - Arabica Network = "arabica-5" + Arabica Network = "arabica-6" // Mocha testnet. See: celestiaorg/networks. Mocha Network = "mocha" // Private can be used to set up any private network, including local testing setups. diff --git a/nodebuilder/share/config.go b/nodebuilder/share/config.go index 22564dd941..aeb37a93dc 100644 --- a/nodebuilder/share/config.go +++ b/nodebuilder/share/config.go @@ -18,6 +18,9 @@ type Config struct { // AdvertiseInterval is a interval between advertising sessions. // NOTE: only full and bridge can advertise themselves. 
AdvertiseInterval time.Duration + // UseShareExchange is a flag toggling the usage of shrex protocols for blocksync. + // NOTE: This config variable only has an effect on full and bridge nodes. + UseShareExchange bool } func DefaultConfig() Config { @@ -25,6 +28,7 @@ func DefaultConfig() Config { PeersLimit: 3, DiscoveryInterval: time.Second * 30, AdvertiseInterval: time.Second * 30, + UseShareExchange: true, } } diff --git a/nodebuilder/share/constructors.go b/nodebuilder/share/constructors.go index 0a5f5061aa..d039d535b0 100644 --- a/nodebuilder/share/constructors.go +++ b/nodebuilder/share/constructors.go @@ -3,6 +3,7 @@ package share import ( "context" "errors" + "time" "github.com/filecoin-project/dagstore" "github.com/ipfs/go-datastore" @@ -16,6 +17,7 @@ import ( "github.com/celestiaorg/celestia-node/share/availability/cache" disc "github.com/celestiaorg/celestia-node/share/availability/discovery" "github.com/celestiaorg/celestia-node/share/eds" + "github.com/celestiaorg/celestia-node/share/getters" ) func discovery(cfg Config) func(routing.ContentRouting, host.Host) *disc.Discovery { @@ -57,3 +59,42 @@ func ensureEmptyCARExists(ctx context.Context, store *eds.Store) error { } return err } + +func lightGetter( + shrexGetter *getters.ShrexGetter, + ipldGetter *getters.IPLDGetter, +) share.Getter { + return getters.NewCascadeGetter( + []share.Getter{ + shrexGetter, + ipldGetter, + }, + // based on the default value of das.SampleTimeout. + // will no longer be needed when async cascadegetter is merged + time.Minute, + ) +} + +func fullGetter( + store *eds.Store, + shrexGetter *getters.ShrexGetter, + ipldGetter *getters.IPLDGetter, + cfg Config, +) share.Getter { + var cascade []share.Getter + // based on the default value of das.SampleTimeout + timeout := time.Minute + cascade = append(cascade, getters.NewStoreGetter(store)) + if cfg.UseShareExchange { + // if we are using share exchange, we split the timeout between the two getters + // once async cascadegetter is implemented, we can remove this + timeout /= 2 + cascade = append(cascade, shrexGetter) + } + cascade = append(cascade, ipldGetter) + + return getters.NewTeeGetter( + getters.NewCascadeGetter(cascade, timeout), + store, + ) +} diff --git a/nodebuilder/share/module.go b/nodebuilder/share/module.go index b6fe19592d..6630d5039e 100644 --- a/nodebuilder/share/module.go +++ b/nodebuilder/share/module.go @@ -8,7 +8,6 @@ import ( "github.com/libp2p/go-libp2p/core/host" "go.uber.org/fx" - "github.com/celestiaorg/celestia-node/libs/fxutil" "github.com/celestiaorg/celestia-node/nodebuilder/node" modp2p "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/share" @@ -16,7 +15,9 @@ import ( "github.com/celestiaorg/celestia-node/share/availability/light" "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/getters" + "github.com/celestiaorg/celestia-node/share/p2p/peers" "github.com/celestiaorg/celestia-node/share/p2p/shrexeds" + "github.com/celestiaorg/celestia-node/share/p2p/shrexnd" "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" ) @@ -33,9 +34,37 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option fx.Error(cfgErr), fx.Options(options...), fx.Provide(discovery(*cfg)), - fx.Invoke(share.EnsureEmptySquareExists), fx.Provide(newModule), - fxutil.ProvideAs(getters.NewIPLDGetter, new(share.Getter)), + fx.Provide(getters.NewIPLDGetter), + fx.Provide(peers.NewManager), + fx.Provide( + func(ctx context.Context, h host.Host, 
network modp2p.Network) (*shrexsub.PubSub, error) { + return shrexsub.NewPubSub( + ctx, + h, + string(network), + ) + }, + ), + fx.Provide( + func(host host.Host, network modp2p.Network) (*shrexnd.Client, error) { + return shrexnd.NewClient(host, shrexnd.WithProtocolSuffix(string(network))) + }, + ), + fx.Provide( + func(host host.Host, network modp2p.Network) (*shrexeds.Client, error) { + return shrexeds.NewClient(host, shrexeds.WithProtocolSuffix(string(network))) + }, + ), + fx.Provide(fx.Annotate( + getters.NewShrexGetter, + fx.OnStart(func(ctx context.Context, getter *getters.ShrexGetter) error { + return getter.Start(ctx) + }), + fx.OnStop(func(ctx context.Context, getter *getters.ShrexGetter) error { + return getter.Stop(ctx) + }), + )), ) switch tp { @@ -43,15 +72,15 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option return fx.Module( "share", baseComponents, - fx.Provide(fx.Annotate( - light.NewShareAvailability, - fx.OnStart(func(ctx context.Context, avail *light.ShareAvailability) error { - return avail.Start(ctx) - }), - fx.OnStop(func(ctx context.Context, avail *light.ShareAvailability) error { - return avail.Stop(ctx) - }), - )), + fx.Invoke(share.EnsureEmptySquareExists), + fx.Provide(lightGetter), + // shrexsub broadcaster stub for daser + fx.Provide(func() shrexsub.BroadcastFn { + return func(context.Context, share.DataHash) error { + return nil + } + }), + fx.Provide(fx.Annotate(light.NewShareAvailability)), // cacheAvailability's lifecycle continues to use a fx hook, // since the LC requires a cacheAvailability but the constructor returns a share.Availability fx.Provide(cacheAvailability[*light.ShareAvailability]), @@ -60,6 +89,7 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option return fx.Module( "share", baseComponents, + fx.Invoke(func(edsSrv *shrexeds.Server, ndSrc *shrexnd.Server) {}), fx.Provide(fx.Annotate( func(host host.Host, store *eds.Store, network modp2p.Network) (*shrexeds.Server, error) { return shrexeds.NewServer(host, store, shrexeds.WithProtocolSuffix(string(network))) @@ -71,12 +101,22 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option return server.Stop(ctx) }), )), - // Bridge Nodes need a client as well, for requests over FullAvailability - fx.Provide( - func(host host.Host, network modp2p.Network) (*shrexeds.Client, error) { - return shrexeds.NewClient(host, shrexeds.WithProtocolSuffix(string(network))) + fx.Provide(fx.Annotate( + func( + host host.Host, + store *eds.Store, + getter *getters.IPLDGetter, + network modp2p.Network, + ) (*shrexnd.Server, error) { + return shrexnd.NewServer(host, store, getter, shrexnd.WithProtocolSuffix(string(network))) }, - ), + fx.OnStart(func(ctx context.Context, server *shrexnd.Server) error { + return server.Start(ctx) + }), + fx.OnStop(func(ctx context.Context, server *shrexnd.Server) error { + return server.Stop(ctx) + }), + )), fx.Provide(fx.Annotate( func(path node.StorePath, ds datastore.Batching) (*eds.Store, error) { return eds.NewStore(string(path), ds) @@ -102,24 +142,13 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option return avail.Stop(ctx) }), )), - fx.Provide(fx.Annotate( - func(ctx context.Context, h host.Host, network modp2p.Network) (*shrexsub.PubSub, error) { - return shrexsub.NewPubSub( - ctx, - h, - string(network), - ) - }, - fx.OnStart(func(ctx context.Context, pubsub *shrexsub.PubSub) error { - return pubsub.Start(ctx) - }), - fx.OnStop(func(ctx context.Context, 
pubsub *shrexsub.PubSub) error { - return pubsub.Stop(ctx) - }), - )), // cacheAvailability's lifecycle continues to use a fx hook, // since the LC requires a cacheAvailability but the constructor returns a share.Availability fx.Provide(cacheAvailability[*full.ShareAvailability]), + fx.Provide(func(shrexSub *shrexsub.PubSub) shrexsub.BroadcastFn { + return shrexSub.Broadcast + }), + fx.Provide(fullGetter), ) default: panic("invalid node type") diff --git a/nodebuilder/tests/fraud_test.go b/nodebuilder/tests/fraud_test.go index a19c042e91..ba065a2551 100644 --- a/nodebuilder/tests/fraud_test.go +++ b/nodebuilder/tests/fraud_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + mdutils "github.com/ipfs/go-merkledag/test" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" @@ -29,23 +30,34 @@ Steps: 5. Subscribe to a fraud proof and wait when it will be received. 6. Check FN is not synced to 15. Note: 15 is not available because DASer will be stopped before reaching this height due to receiving a fraud proof. +Another note: this test disables share exchange to speed up test results. */ func TestFraudProofBroadcasting(t *testing.T) { - // we increase the timeout for this test to decrease flakiness in CI - testTimeout := time.Millisecond * 200 - sw := swamp.NewSwamp(t, swamp.WithBlockTime(testTimeout)) - - bridge := sw.NewBridgeNode(core.WithHeaderConstructFn(headertest.FraudMaker(t, 20))) - + const ( + blocks = 15 + bsize = 2 + btime = time.Millisecond * 300 + ) + sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) ctx, cancel := context.WithTimeout(context.Background(), swamp.DefaultTestTimeout) t.Cleanup(cancel) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) + cfg := nodebuilder.DefaultConfig(node.Bridge) + cfg.Share.UseShareExchange = false + bridge := sw.NewNodeWithConfig( + node.Bridge, + cfg, + core.WithHeaderConstructFn(headertest.FraudMaker(t, 10, mdutils.Bserv())), + ) + err := bridge.Start(ctx) require.NoError(t, err) addrs, err := peer.AddrInfoToP2pAddrs(host.InfoFromHost(bridge.Host)) require.NoError(t, err) - cfg := nodebuilder.DefaultConfig(node.Full) + cfg = nodebuilder.DefaultConfig(node.Full) + cfg.Share.UseShareExchange = false cfg.Header.TrustedPeers = append(cfg.Header.TrustedPeers, addrs[0].String()) store := nodebuilder.MockStore(t, cfg) full := sw.NewNodeWithStore(node.Full, store) @@ -59,13 +71,13 @@ func TestFraudProofBroadcasting(t *testing.T) { require.NoError(t, err) p := <-subscr - require.Equal(t, 20, int(p.Height())) + require.Equal(t, 10, int(p.Height())) // This is an obscure way to check if the Syncer was stopped. // If we cannot get a height header within a timeframe it means the syncer was stopped // FIXME: Eventually, this should be a check on service registry managing and keeping // lifecycles of each Module. - syncCtx, syncCancel := context.WithTimeout(context.Background(), testTimeout) + syncCtx, syncCancel := context.WithTimeout(context.Background(), btime) _, err = full.HeaderServ.GetByHeight(syncCtx, 100) require.ErrorIs(t, err, context.DeadlineExceeded) syncCancel() @@ -79,6 +91,7 @@ func TestFraudProofBroadcasting(t *testing.T) { proofs, err := full.FraudServ.Get(ctx, fraud.BadEncoding) require.NoError(t, err) require.NotNil(t, proofs) + require.NoError(t, <-fillDn) } /* @@ -93,16 +106,27 @@ Steps: 5. Subscribe to a fraud proof and wait when it will be received. 6. Start LN once a fraud proof is received and verified by FN. 7. 
Wait until LN will be connected to FN and fetch a fraud proof. +Note: this test disables share exchange to speed up test results. */ func TestFraudProofSyncing(t *testing.T) { - sw := swamp.NewSwamp(t, swamp.WithBlockTime(time.Millisecond*300)) + const ( + blocks = 15 + bsize = 2 + btime = time.Millisecond * 300 + ) + sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) + ctx, cancel := context.WithTimeout(context.Background(), swamp.DefaultTestTimeout) + t.Cleanup(cancel) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) cfg := nodebuilder.DefaultConfig(node.Bridge) + cfg.Share.UseShareExchange = false store := nodebuilder.MockStore(t, cfg) - bridge := sw.NewNodeWithStore(node.Bridge, store, core.WithHeaderConstructFn(headertest.FraudMaker(t, 10))) - - ctx, cancel := context.WithTimeout(context.Background(), swamp.DefaultTestTimeout) - t.Cleanup(cancel) + bridge := sw.NewNodeWithStore( + node.Bridge, + store, + core.WithHeaderConstructFn(headertest.FraudMaker(t, 10, mdutils.Bserv())), + ) err := bridge.Start(ctx) require.NoError(t, err) @@ -111,6 +135,7 @@ func TestFraudProofSyncing(t *testing.T) { require.NoError(t, err) fullCfg := nodebuilder.DefaultConfig(node.Full) + fullCfg.Share.UseShareExchange = false fullCfg.Header.TrustedPeers = append(fullCfg.Header.TrustedPeers, addrs[0].String()) full := sw.NewNodeWithStore(node.Full, nodebuilder.MockStore(t, fullCfg)) @@ -147,4 +172,5 @@ func TestFraudProofSyncing(t *testing.T) { case <-ctx.Done(): t.Fatal("light node didn't get FP in time") } + require.NoError(t, <-fillDn) } diff --git a/nodebuilder/tests/reconstruct_test.go b/nodebuilder/tests/reconstruct_test.go index 27c1aaf22f..579a02e804 100644 --- a/nodebuilder/tests/reconstruct_test.go +++ b/nodebuilder/tests/reconstruct_test.go @@ -47,7 +47,7 @@ func TestFullReconstructFromBridge(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), swamp.DefaultTestTimeout) t.Cleanup(cancel) sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) - fillDn := sw.FillBlocks(ctx, bsize, blocks) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) bridge := sw.NewBridgeNode() err := bridge.Start(ctx) @@ -106,7 +106,7 @@ func TestFullReconstructFromLights(t *testing.T) { t.Cleanup(cancel) sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) - fillDn := sw.FillBlocks(ctx, bsize, blocks) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) const defaultTimeInterval = time.Second * 5 cfg := nodebuilder.DefaultConfig(node.Full) diff --git a/nodebuilder/tests/swamp/swamp.go b/nodebuilder/tests/swamp/swamp.go index 492c2ff93f..32434788a2 100644 --- a/nodebuilder/tests/swamp/swamp.go +++ b/nodebuilder/tests/swamp/swamp.go @@ -8,7 +8,8 @@ import ( "testing" "time" - mdutils "github.com/ipfs/go-merkledag/test" + ds "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" "github.com/libp2p/go-libp2p/core/host" "github.com/libp2p/go-libp2p/core/peer" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" @@ -28,6 +29,7 @@ import ( "github.com/celestiaorg/celestia-node/nodebuilder/node" "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/nodebuilder/state" + "github.com/celestiaorg/celestia-node/share/eds" ) var blackholeIP6 = net.ParseIP("100::") @@ -50,7 +52,7 @@ type Swamp struct { comps *Components ClientContext testnode.Context - accounts []string + Accounts []string genesis *header.ExtendedHeader } @@ -78,7 +80,7 @@ func NewSwamp(t *testing.T, options 
...Option) *Swamp { Network: mocknet.New(), ClientContext: cctx, comps: ic, - accounts: ic.Accounts, + Accounts: ic.Accounts, } swp.t.Cleanup(func() { @@ -163,9 +165,12 @@ func (s *Swamp) createPeer(ks keystore.Keystore) host.Host { func (s *Swamp) setupGenesis(ctx context.Context) { s.WaitTillHeight(ctx, 1) + store, err := eds.NewStore(s.t.TempDir(), ds_sync.MutexWrap(ds.NewMapDatastore())) + require.NoError(s.t, err) + ex := core.NewExchange( core.NewBlockFetcher(s.ClientContext.Client), - mdutils.Bserv(), + store, header.MakeExtendedHeader, ) diff --git a/nodebuilder/tests/swamp/swamp_tx.go b/nodebuilder/tests/swamp/swamp_tx.go index e96c968bd4..956c8cf7c4 100644 --- a/nodebuilder/tests/swamp/swamp_tx.go +++ b/nodebuilder/tests/swamp/swamp_tx.go @@ -5,19 +5,21 @@ import ( "time" "github.com/cosmos/cosmos-sdk/client/flags" + + "github.com/celestiaorg/celestia-app/testutil/testnode" ) // FillBlocks produces the given amount of contiguous blocks with customizable size. // The returned channel reports when the process is finished. -func (s *Swamp) FillBlocks(ctx context.Context, bsize, blocks int) chan error { - // TODO @renaynay: figure out why sleep is necessary to prevent flakeyness for macOS - time.Sleep(time.Millisecond * 50) +func FillBlocks(ctx context.Context, cctx testnode.Context, accounts []string, bsize, blocks int) chan error { errCh := make(chan error) go func() { // TODO: FillBlock must respect the context + // fill blocks is not working correctly without sleep rn. + time.Sleep(time.Millisecond * 50) var err error for i := 0; i < blocks; i++ { - _, err = s.ClientContext.FillBlock(bsize, s.accounts, flags.BroadcastBlock) + _, err = cctx.FillBlock(bsize, accounts, flags.BroadcastBlock) if err != nil { break } diff --git a/nodebuilder/tests/sync_test.go b/nodebuilder/tests/sync_test.go index f6dd1552a2..0575f223c9 100644 --- a/nodebuilder/tests/sync_test.go +++ b/nodebuilder/tests/sync_test.go @@ -45,7 +45,7 @@ func TestSyncLightWithBridge(t *testing.T) { t.Cleanup(cancel) sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) - fillDn := sw.FillBlocks(ctx, bsize, blocks) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) bridge := sw.NewBridgeNode() @@ -160,7 +160,7 @@ func TestSyncFullWithBridge(t *testing.T) { t.Cleanup(cancel) sw := swamp.NewSwamp(t, swamp.WithBlockTime(btime)) - fillDn := sw.FillBlocks(ctx, bsize, blocks) + fillDn := swamp.FillBlocks(ctx, sw.ClientContext, sw.Accounts, bsize, blocks) bridge := sw.NewBridgeNode() diff --git a/share/add.go b/share/add.go index 4e60eabb2c..02016cadf6 100644 --- a/share/add.go +++ b/share/add.go @@ -3,14 +3,15 @@ package share import ( "context" "fmt" - "math" "github.com/ipfs/go-blockservice" "github.com/celestiaorg/celestia-app/pkg/wrapper" - "github.com/celestiaorg/celestia-node/share/ipld" "github.com/celestiaorg/nmt" "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/libs/utils" + "github.com/celestiaorg/celestia-node/share/ipld" ) // AddShares erasures and extends shares to blockservice.BlockService using the provided @@ -23,7 +24,7 @@ func AddShares( if len(shares) == 0 { return nil, fmt.Errorf("empty data") // empty block is not an empty Data } - squareSize := int(math.Sqrt(float64(len(shares)))) + squareSize := int(utils.SquareSize(len(shares))) // create nmt adder wrapping batch adder with calculated size batchAdder := ipld.NewNmtNodeAdder(ctx, adder, ipld.MaxSizeBatchOption(squareSize*2)) // create the nmt wrapper to generate row and col commitments @@ -43,7 +44,7 @@ 
func AddShares( return eds, batchAdder.Commit() } -// ImportShares imports flattend chunks of data into Extended Data square and saves it in +// ImportShares imports flattened chunks of data into Extended Data square and saves it in // blockservice.BlockService func ImportShares( ctx context.Context, @@ -52,7 +53,7 @@ func ImportShares( if len(shares) == 0 { return nil, fmt.Errorf("ipld: importing empty data") } - squareSize := int(math.Sqrt(float64(len(shares)))) + squareSize := int(utils.SquareSize(len(shares))) // create nmt adder wrapping batch adder with calculated size batchAdder := ipld.NewNmtNodeAdder(ctx, adder, ipld.MaxSizeBatchOption(squareSize*2)) // recompute the eds diff --git a/share/availability.go b/share/availability.go index 190fb31dd1..045de350e8 100644 --- a/share/availability.go +++ b/share/availability.go @@ -5,8 +5,6 @@ import ( "errors" "time" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/celestiaorg/celestia-app/pkg/da" ) @@ -28,7 +26,7 @@ type Root = da.DataAvailabilityHeader type Availability interface { // SharesAvailable subjectively validates if Shares committed to the given Root are available on // the Network by requesting the EDS from the provided peers. - SharesAvailable(context.Context, *Root, ...peer.ID) error + SharesAvailable(context.Context, *Root) error // ProbabilityOfAvailability calculates the probability of the data square // being available based on the number of samples collected. // TODO(@Wondertan): Merge with SharesAvailable method, eventually diff --git a/share/availability/cache/availability.go b/share/availability/cache/availability.go index 2a7934930c..2425f59867 100644 --- a/share/availability/cache/availability.go +++ b/share/availability/cache/availability.go @@ -9,10 +9,8 @@ import ( "github.com/ipfs/go-datastore/autobatch" "github.com/ipfs/go-datastore/namespace" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/peer" "github.com/celestiaorg/celestia-app/pkg/da" - "github.com/celestiaorg/celestia-node/share" ) @@ -62,7 +60,7 @@ func (ca *ShareAvailability) Stop(context.Context) error { } // SharesAvailable will store, upon success, the hash of the given Root to disk. -func (ca *ShareAvailability) SharesAvailable(ctx context.Context, root *share.Root, _ ...peer.ID) error { +func (ca *ShareAvailability) SharesAvailable(ctx context.Context, root *share.Root) error { // short-circuit if the given root is minimum DAH of an empty data square if isMinRoot(root) { return nil diff --git a/share/availability/cache/availability_test.go b/share/availability/cache/availability_test.go index e578c0db6f..47df434c06 100644 --- a/share/availability/cache/availability_test.go +++ b/share/availability/cache/availability_test.go @@ -10,7 +10,6 @@ import ( "github.com/ipfs/go-datastore" "github.com/ipfs/go-datastore/sync" mdutils "github.com/ipfs/go-merkledag/test" - "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -123,7 +122,7 @@ type dummyAvailability struct { // SharesAvailable should only be called once, if called more than once, return // error. 
-func (da *dummyAvailability) SharesAvailable(_ context.Context, root *share.Root, _ ...peer.ID) error { +func (da *dummyAvailability) SharesAvailable(_ context.Context, root *share.Root) error { if root == &invalidHeader { return fmt.Errorf("invalid header") } diff --git a/share/availability/discovery/discovery.go b/share/availability/discovery/discovery.go index fc59c7bf41..f8fc129724 100644 --- a/share/availability/discovery/discovery.go +++ b/share/availability/discovery/discovery.go @@ -2,7 +2,6 @@ package discovery import ( "context" - "sync" "time" logging "github.com/ipfs/go-log/v2" @@ -47,8 +46,6 @@ type Discovery struct { advertiseInterval time.Duration // onUpdatedPeers will be called on peer set changes onUpdatedPeers OnUpdatedPeers - // ensureIsRunning allows only one ensurePeers process to be running - ensurePeersOnce sync.Once } type OnUpdatedPeers func(peerID peer.ID, isAdded bool) @@ -70,14 +67,14 @@ func NewDiscovery( discInterval, advertiseInterval, func(peer.ID, bool) {}, - sync.Once{}, } } -// WithOnPeersUpdate adds OnPeersUpdate callback call on every update of discovered peers list. +// WithOnPeersUpdate chains OnPeersUpdate callbacks on every update of discovered peers list. func (d *Discovery) WithOnPeersUpdate(f OnUpdatedPeers) { + prev := d.onUpdatedPeers d.onUpdatedPeers = func(peerID peer.ID, isAdded bool) { - d.onUpdatedPeers(peerID, isAdded) + prev(peerID, isAdded) f(peerID, isAdded) } } @@ -111,12 +108,6 @@ func (d *Discovery) handlePeerFound(ctx context.Context, topic string, peer peer // It starts peer discovery every 30 seconds until peer cache reaches peersLimit. // Discovery is restarted if any previously connected peers disconnect. func (d *Discovery) EnsurePeers(ctx context.Context) { - d.ensurePeersOnce.Do(func() { - go d.ensurePeers(ctx) - }) -} - -func (d *Discovery) ensurePeers(ctx context.Context) { if d.peersLimit == 0 { log.Warn("peers limit is set to 0. Skipping discovery...") return diff --git a/share/availability/full/availability.go b/share/availability/full/availability.go index 9a0ce67f70..deaa86dec3 100644 --- a/share/availability/full/availability.go +++ b/share/availability/full/availability.go @@ -6,10 +6,10 @@ import ( ipldFormat "github.com/ipfs/go-ipld-format" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/peer" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/availability/discovery" + "github.com/celestiaorg/celestia-node/share/eds/byzantine" ) var log = logging.Logger("share/full") @@ -37,7 +37,6 @@ func (fa *ShareAvailability) Start(context.Context) error { fa.cancel = cancel go fa.disc.Advertise(ctx) - fa.disc.EnsurePeers(ctx) return nil } @@ -48,7 +47,7 @@ func (fa *ShareAvailability) Stop(context.Context) error { // SharesAvailable reconstructs the data committed to the given Root by requesting // enough Shares from the network. 
-func (fa *ShareAvailability) SharesAvailable(ctx context.Context, root *share.Root, _ ...peer.ID) error { +func (fa *ShareAvailability) SharesAvailable(ctx context.Context, root *share.Root) error { ctx, cancel := context.WithTimeout(ctx, share.AvailabilityTimeout) defer cancel() // we assume the caller of this method has already performed basic validation on the @@ -61,8 +60,9 @@ func (fa *ShareAvailability) SharesAvailable(ctx context.Context, root *share.Ro _, err := fa.getter.GetEDS(ctx, root) if err != nil { - log.Errorw("availability validation failed", "root", root.Hash(), "err", err) - if ipldFormat.IsNotFound(err) || errors.Is(err, context.DeadlineExceeded) { + log.Errorw("availability validation failed", "root", root.Hash(), "err", err.Error()) + var byzantineErr *byzantine.ErrByzantine + if ipldFormat.IsNotFound(err) || errors.Is(err, context.DeadlineExceeded) && !errors.As(err, &byzantineErr) { return share.ErrNotAvailable } diff --git a/share/availability/light/availability.go b/share/availability/light/availability.go index 25858ffd7f..c23b28045f 100644 --- a/share/availability/light/availability.go +++ b/share/availability/light/availability.go @@ -7,10 +7,8 @@ import ( ipldFormat "github.com/ipfs/go-ipld-format" logging "github.com/ipfs/go-log/v2" - "github.com/libp2p/go-libp2p/core/peer" "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/availability/discovery" "github.com/celestiaorg/celestia-node/share/getters" ) @@ -22,40 +20,16 @@ var log = logging.Logger("share/light") // on the network doing sampling over the same Root to collectively verify its availability. type ShareAvailability struct { getter share.Getter - // disc discovers new full nodes in the network. - // it is not allowed to call advertise for light nodes (Full nodes only). - disc *discovery.Discovery - cancel context.CancelFunc } // NewShareAvailability creates a new light Availability. -func NewShareAvailability( - getter share.Getter, - disc *discovery.Discovery, -) *ShareAvailability { - la := &ShareAvailability{ - getter: getter, - disc: disc, - } - return la -} - -func (la *ShareAvailability) Start(context.Context) error { - ctx, cancel := context.WithCancel(context.Background()) - la.cancel = cancel - - la.disc.EnsurePeers(ctx) - return nil -} - -func (la *ShareAvailability) Stop(ctx context.Context) error { - la.cancel() - return nil +func NewShareAvailability(getter share.Getter) *ShareAvailability { + return &ShareAvailability{getter} } // SharesAvailable randomly samples DefaultSampleAmount amount of Shares committed to the given // Root. This way SharesAvailable subjectively verifies that Shares are available. -func (la *ShareAvailability) SharesAvailable(ctx context.Context, dah *share.Root, _ ...peer.ID) error { +func (la *ShareAvailability) SharesAvailable(ctx context.Context, dah *share.Root) error { log.Debugw("Validate availability", "root", dah.Hash()) // We assume the caller of this method has already performed basic validation on the // given dah/root. If for some reason this has not happened, the node should panic. 
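With the peer-ID variadic dropped from Availability.SharesAvailable, callers now pass only a context and the DA root. Below is a minimal sketch, assuming a hypothetical sampleRoots helper and a one-minute per-root timeout (neither is part of this changeset), of what a call site looks like against the new signature:

package sampling

import (
	"context"
	"fmt"
	"time"

	"github.com/celestiaorg/celestia-node/share"
)

// sampleRoots is a hypothetical caller of the trimmed-down interface: each root
// is checked for availability with only a context and the root itself.
func sampleRoots(ctx context.Context, avail share.Availability, roots []*share.Root) error {
	for _, root := range roots {
		// bound each availability check independently of the parent context
		checkCtx, cancel := context.WithTimeout(ctx, time.Minute)
		err := avail.SharesAvailable(checkCtx, root)
		cancel()
		if err != nil {
			return fmt.Errorf("root %X not available: %w", root.Hash(), err)
		}
	}
	return nil
}

Peer selection is now an implementation concern: the ShrexGetter and the shrex peer Manager introduced further below pick peers internally, so callers of Availability no longer supply peer IDs.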
diff --git a/share/availability/light/testing.go b/share/availability/light/testing.go index 0072f226c6..59163d6356 100644 --- a/share/availability/light/testing.go +++ b/share/availability/light/testing.go @@ -2,15 +2,11 @@ package light import ( "testing" - "time" "github.com/ipfs/go-blockservice" mdutils "github.com/ipfs/go-merkledag/test" - routinghelpers "github.com/libp2p/go-libp2p-routing-helpers" - "github.com/libp2p/go-libp2p/p2p/discovery/routing" "github.com/celestiaorg/celestia-node/share" - "github.com/celestiaorg/celestia-node/share/availability/discovery" availability_test "github.com/celestiaorg/celestia-node/share/availability/test" "github.com/celestiaorg/celestia-node/share/getters" ) @@ -46,8 +42,7 @@ func Node(dn *availability_test.TestDagNet) *availability_test.TestNode { } func TestAvailability(getter share.Getter) *ShareAvailability { - disc := discovery.NewDiscovery(nil, routing.NewRoutingDiscovery(routinghelpers.Null{}), 0, time.Second, time.Second) - return NewShareAvailability(getter, disc) + return NewShareAvailability(getter) } func SubNetNode(sn *availability_test.SubNet) *availability_test.TestNode { diff --git a/share/availability/mocks/availability.go b/share/availability/mocks/availability.go index 030348f4e4..ff4b8e1328 100644 --- a/share/availability/mocks/availability.go +++ b/share/availability/mocks/availability.go @@ -9,7 +9,6 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - peer "github.com/libp2p/go-libp2p/core/peer" da "github.com/celestiaorg/celestia-app/pkg/da" ) @@ -52,20 +51,15 @@ func (mr *MockAvailabilityMockRecorder) ProbabilityOfAvailability(arg0 interface } // SharesAvailable mocks base method. -func (m *MockAvailability) SharesAvailable(arg0 context.Context, arg1 *da.DataAvailabilityHeader, arg2 ...peer.ID) error { +func (m *MockAvailability) SharesAvailable(arg0 context.Context, arg1 *da.DataAvailabilityHeader) error { m.ctrl.T.Helper() - varargs := []interface{}{arg0, arg1} - for _, a := range arg2 { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "SharesAvailable", varargs...) + ret := m.ctrl.Call(m, "SharesAvailable", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // SharesAvailable indicates an expected call of SharesAvailable. -func (mr *MockAvailabilityMockRecorder) SharesAvailable(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { +func (mr *MockAvailabilityMockRecorder) SharesAvailable(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{arg0, arg1}, arg2...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SharesAvailable", reflect.TypeOf((*MockAvailability)(nil).SharesAvailable), varargs...) 
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SharesAvailable", reflect.TypeOf((*MockAvailability)(nil).SharesAvailable), arg0, arg1) } diff --git a/share/eds/retriever_test.go b/share/eds/retriever_test.go index a7bb8353ec..cc398890d0 100644 --- a/share/eds/retriever_test.go +++ b/share/eds/retriever_test.go @@ -157,7 +157,7 @@ func generateByzantineError( h, err := store.GetByHeight(ctx, 1) require.NoError(t, err) - faultHeader := headertest.CreateFraudExtHeader(t, h, bServ) + faultHeader, _ := headertest.CreateFraudExtHeader(t, h, bServ) _, err = NewRetriever(bServ).Retrieve(ctx, faultHeader.DAH) return faultHeader, err } diff --git a/share/get_test.go b/share/get_test.go index fdf882c569..3466a726f1 100644 --- a/share/get_test.go +++ b/share/get_test.go @@ -2,7 +2,6 @@ package share import ( "context" - "math" "math/rand" "strconv" "testing" @@ -20,10 +19,12 @@ import ( "github.com/stretchr/testify/require" "github.com/celestiaorg/celestia-app/pkg/wrapper" - "github.com/celestiaorg/celestia-node/share/ipld" "github.com/celestiaorg/nmt" "github.com/celestiaorg/nmt/namespace" "github.com/celestiaorg/rsmt2d" + + "github.com/celestiaorg/celestia-node/libs/utils" + "github.com/celestiaorg/celestia-node/share/ipld" ) func TestGetShare(t *testing.T) { @@ -73,7 +74,7 @@ func TestBlockRecovery(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - squareSize := uint64(math.Sqrt(float64(len(tc.shares)))) + squareSize := utils.SquareSize(len(tc.shares)) eds, err := rsmt2d.ComputeExtendedDataSquare(tc.shares, rsmt2d.NewRSGF8Codec(), wrapper.NewConstructor(squareSize)) require.NoError(t, err) diff --git a/share/getters/shrex.go b/share/getters/shrex.go new file mode 100644 index 0000000000..2c5a9d36d8 --- /dev/null +++ b/share/getters/shrex.go @@ -0,0 +1,115 @@ +package getters + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/p2p" + "github.com/celestiaorg/celestia-node/share/p2p/peers" + "github.com/celestiaorg/celestia-node/share/p2p/shrexeds" + "github.com/celestiaorg/celestia-node/share/p2p/shrexnd" + + "github.com/celestiaorg/nmt/namespace" + "github.com/celestiaorg/rsmt2d" +) + +var _ share.Getter = (*ShrexGetter)(nil) + +const defaultMaxRequestDuration = time.Second * 10 + +// ShrexGetter is a share.Getter that uses the shrex/eds and shrex/nd protocol to retrieve shares. 
+type ShrexGetter struct { + edsClient *shrexeds.Client + ndClient *shrexnd.Client + + peerManager *peers.Manager + maxRequestDuration time.Duration +} + +func NewShrexGetter(edsClient *shrexeds.Client, ndClient *shrexnd.Client, peerManager *peers.Manager) *ShrexGetter { + return &ShrexGetter{ + edsClient: edsClient, + ndClient: ndClient, + peerManager: peerManager, + maxRequestDuration: defaultMaxRequestDuration, + } +} + +func (sg *ShrexGetter) Start(ctx context.Context) error { + return sg.peerManager.Start(ctx) +} + +func (sg *ShrexGetter) Stop(ctx context.Context) error { + return sg.peerManager.Stop(ctx) +} + +func (sg *ShrexGetter) GetShare(ctx context.Context, root *share.Root, row, col int) (share.Share, error) { + return nil, errors.New("getter/shrex: GetShare is not supported") +} + +func (sg *ShrexGetter) GetEDS(ctx context.Context, root *share.Root) (*rsmt2d.ExtendedDataSquare, error) { + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("getter/shrex: %w", ctx.Err()) + default: + } + peer, setStatus, err := sg.peerManager.Peer(ctx, root.Hash()) + if err != nil { + log.Debugw("couldn't find peer", "datahash", root.String(), "err", err) + return nil, fmt.Errorf("getter/shrex: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, sg.maxRequestDuration) + eds, err := sg.edsClient.RequestEDS(reqCtx, root.Hash(), peer) + cancel() + switch err { + case nil: + setStatus(peers.ResultSynced) + return eds, nil + case context.DeadlineExceeded: + log.Debugw("request exceeded deadline, trying with new peer", "datahash", root.String()) + case p2p.ErrInvalidResponse: + setStatus(peers.ResultBlacklistPeer) + default: + setStatus(peers.ResultCooldownPeer) + } + } +} + +func (sg *ShrexGetter) GetSharesByNamespace( + ctx context.Context, + root *share.Root, + id namespace.ID, +) (share.NamespacedShares, error) { + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("getter/shrex: %w", ctx.Err()) + default: + } + peer, setStatus, err := sg.peerManager.Peer(ctx, root.Hash()) + if err != nil { + log.Debugw("couldn't find peer", "datahash", root.String(), "err", err) + return nil, fmt.Errorf("getter/shrex: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, sg.maxRequestDuration) + nd, err := sg.ndClient.RequestND(reqCtx, root, id, peer) + cancel() + switch err { + case nil: + setStatus(peers.ResultSuccess) + return nd, nil + case context.DeadlineExceeded: + log.Debugw("request exceeded deadline, trying with new peer", "datahash", root.String()) + case p2p.ErrInvalidResponse: + setStatus(peers.ResultBlacklistPeer) + default: + setStatus(peers.ResultCooldownPeer) + } + } +} diff --git a/share/getters/shrex_test.go b/share/getters/shrex_test.go index 37cb864bb7..41a39749cf 100644 --- a/share/getters/shrex_test.go +++ b/share/getters/shrex_test.go @@ -46,8 +46,12 @@ func TestGetSharesWithProofByNamespace(t *testing.T) { srvHost := net.NewTestNode().Host srv, err := shrexnd.NewServer(srvHost, edsStore, NewIPLDGetter(bServ)) require.NoError(t, err) - srv.Start() - t.Cleanup(srv.Stop) + err = srv.Start(ctx) + require.NoError(t, err) + + t.Cleanup(func() { + _ = srv.Stop(ctx) + }) // create client and connect it to server client, err := shrexnd.NewClient(net.NewTestNode().Host) diff --git a/share/getters/store.go b/share/getters/store.go index f2a222894e..cc87f0e6f1 100644 --- a/share/getters/store.go +++ b/share/getters/store.go @@ -2,8 +2,10 @@ package getters import ( "context" + "errors" "fmt" + "github.com/filecoin-project/dagstore" 
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -70,6 +72,9 @@ func (sg *StoreGetter) GetEDS(ctx context.Context, root *share.Root) (eds *rsmt2 }() eds, err = sg.store.Get(ctx, root.Hash()) + if errors.Is(err, dagstore.ErrShardUnknown) { + return nil, fmt.Errorf("getter/store: eds not found") + } if err != nil { return nil, fmt.Errorf("getter/store: failed to retrieve eds: %w", err) } diff --git a/share/ipld/nmt_test.go b/share/ipld/nmt_test.go index edc8824d04..ff0d38ea72 100644 --- a/share/ipld/nmt_test.go +++ b/share/ipld/nmt_test.go @@ -2,7 +2,6 @@ package ipld import ( "bytes" - "math" "math/rand" "sort" "strconv" @@ -13,6 +12,8 @@ import ( "github.com/celestiaorg/celestia-app/pkg/appconsts" "github.com/celestiaorg/celestia-app/pkg/da" + + "github.com/celestiaorg/celestia-node/libs/utils" ) // TestNamespaceFromCID checks that deriving the Namespaced hash from @@ -28,7 +29,7 @@ func TestNamespaceFromCID(t *testing.T) { for i, tt := range tests { t.Run(strconv.Itoa(i), func(t *testing.T) { // create DAH from rand data - squareSize := uint64(math.Sqrt(float64(len(tt.randData)))) + squareSize := utils.SquareSize(len(tt.randData)) eds, err := da.ExtendShares(squareSize, tt.randData) require.NoError(t, err) dah := da.NewDataAvailabilityHeader(eds) diff --git a/share/p2p/middleware.go b/share/p2p/middleware.go new file mode 100644 index 0000000000..fa3f0e26ae --- /dev/null +++ b/share/p2p/middleware.go @@ -0,0 +1,30 @@ +package p2p + +import ( + "sync/atomic" + + logging "github.com/ipfs/go-log/v2" + + "github.com/libp2p/go-libp2p/core/network" +) + +var log = logging.Logger("shrex/middleware") + +func RateLimitMiddleware(inner network.StreamHandler, concurrencyLimit int) network.StreamHandler { + var parallelRequests int64 + limit := int64(concurrencyLimit) + return func(stream network.Stream) { + current := atomic.AddInt64(¶llelRequests, 1) + defer atomic.AddInt64(¶llelRequests, -1) + + if current > limit { + log.Debug("concurrency limit reached") + err := stream.Close() + if err != nil { + log.Errorw("server: closing stream", "err", err) + } + return + } + inner(stream) + } +} diff --git a/share/p2p/peers/manager.go b/share/p2p/peers/manager.go new file mode 100644 index 0000000000..ae4e75abb2 --- /dev/null +++ b/share/p2p/peers/manager.go @@ -0,0 +1,390 @@ +package peers + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/libp2p/go-libp2p/p2p/net/conngater" + + logging "github.com/ipfs/go-log/v2" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + + "github.com/celestiaorg/celestia-node/header" + libhead "github.com/celestiaorg/celestia-node/libs/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/availability/discovery" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" +) + +const ( + // ResultSuccess indicates operation was successful and no extra action is required + ResultSuccess result = iota + // ResultSynced will save the status of pool as "synced" and will remove peers from it + ResultSynced + // ResultCooldownPeer will put returned peer on cooldown, meaning it won't be available by Peer + // method for some time + ResultCooldownPeer + // ResultBlacklistPeer will blacklist peer. 
Blacklisted peers will be disconnected and blocked from + // any p2p communication in future by libp2p Gater + ResultBlacklistPeer +) + +var log = logging.Logger("shrex/peer-manager") + +// Manager keeps track of peers coming from shrex.Sub and from discovery +type Manager struct { + lock sync.Mutex + + // header subscription is necessary in order to validate the inbound eds hash + headerSub libhead.Subscriber[*header.ExtendedHeader] + shrexSub *shrexsub.PubSub + disc *discovery.Discovery + host host.Host + connGater *conngater.BasicConnectionGater + + // pools collecting peers from shrexSub + pools map[string]*syncPool + poolValidationTimeout time.Duration + peerCooldownTime time.Duration + gcInterval time.Duration + // fullNodes collects full nodes peer.ID found via discovery + fullNodes *pool + + // hashes that are not in the chain + blacklistedHashes map[string]bool + + cancel context.CancelFunc + done chan struct{} +} + +// DoneFunc updates internal state depending on call results. Should be called once per returned +// peer from Peer method +type DoneFunc func(result) + +type result int + +type syncPool struct { + *pool + + // isValidatedDataHash indicates if datahash was validated by receiving corresponding extended + // header from headerSub + isValidatedDataHash atomic.Bool + isSynced atomic.Bool + createdAt time.Time +} + +func NewManager( + headerSub libhead.Subscriber[*header.ExtendedHeader], + shrexSub *shrexsub.PubSub, + discovery *discovery.Discovery, + host host.Host, + connGater *conngater.BasicConnectionGater, + opts ...Option, +) *Manager { + params := DefaultParameters() + + s := &Manager{ + headerSub: headerSub, + shrexSub: shrexSub, + disc: discovery, + connGater: connGater, + host: host, + pools: make(map[string]*syncPool), + poolValidationTimeout: params.ValidationTimeout, + peerCooldownTime: params.PeerCooldown, + gcInterval: params.GcInterval, + blacklistedHashes: make(map[string]bool), + done: make(chan struct{}), + } + + for _, opt := range opts { + opt(s) + } + + s.fullNodes = newPool(s.peerCooldownTime) + + discovery.WithOnPeersUpdate( + func(peerID peer.ID, isAdded bool) { + if isAdded { + if s.peerIsBlacklisted(peerID) { + log.Debugw("got blacklisted from discovery", "peer", peerID) + return + } + log.Debugw("added to full nodes", "peer", peerID) + s.fullNodes.add(peerID) + return + } + + log.Debugw("removing peer from discovered full nodes", "peer", peerID) + s.fullNodes.remove(peerID) + }) + + return s +} + +func (m *Manager) Start(startCtx context.Context) error { + ctx, cancel := context.WithCancel(context.Background()) + m.cancel = cancel + + err := m.shrexSub.Start(startCtx) + if err != nil { + return fmt.Errorf("starting shrexsub: %w", err) + } + + err = m.shrexSub.AddValidator(m.validate) + if err != nil { + return fmt.Errorf("registering validator: %w", err) + } + + _, err = m.shrexSub.Subscribe() + if err != nil { + return fmt.Errorf("subscribing to shrexsub: %w", err) + } + + sub, err := m.headerSub.Subscribe() + if err != nil { + return fmt.Errorf("subscribing to headersub: %w", err) + } + + go m.disc.EnsurePeers(ctx) + go m.subscribeHeader(ctx, sub) + go m.GC(ctx) + + return nil +} + +func (m *Manager) Stop(ctx context.Context) error { + m.cancel() + select { + case <-m.done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// Peer returns peer collected from shrex.Sub for given datahash if any available. +// If there is none, it will look for full nodes collected from discovery. 
If there are no discovered + full nodes, it will wait until a peer appears in either source or a timeout occurs. +// After fetching data using the given peer, the caller is required to call the returned DoneFunc with an +// appropriate result value. +func (m *Manager) Peer( + ctx context.Context, datahash share.DataHash, +) (peer.ID, DoneFunc, error) { + p := m.getOrCreatePool(datahash.String()) + if p.markValidated() { + log.Debugw("marked validated", "datahash", datahash.String()) + } + + // first, check if a peer is available for the given datahash + peerID, ok := p.tryGet() + if ok { + // some pools could still have blacklisted peers in storage + if m.peerIsBlacklisted(peerID) { + log.Debugw("removing blacklisted peer from pool", "hash", datahash.String(), + "peer", peerID.String()) + p.remove(peerID) + return m.Peer(ctx, datahash) + } + log.Debugw("returning shrex-sub peer", "hash", datahash.String(), + "peer", peerID.String()) + return peerID, m.doneFunc(datahash, peerID), nil + } + + // if no peer for datahash is currently available, try to use full node + // obtained from discovery + peerID, ok = m.fullNodes.tryGet() + if ok { + log.Debugw("got peer from full nodes discovery pool", "peer", peerID, "datahash", datahash.String()) + return peerID, m.doneFunc(datahash, peerID), nil + } + + // no peers are available right now, wait for the first one + select { + case peerID = <-p.next(ctx): + log.Debugw("got peer from shrexSub pool after wait", "peer", peerID, "datahash", datahash.String()) + return peerID, m.doneFunc(datahash, peerID), nil + case peerID = <-m.fullNodes.next(ctx): + log.Debugw("got peer from discovery pool after wait", "peer", peerID, "datahash", datahash.String()) + return peerID, m.doneFunc(datahash, peerID), nil + case <-ctx.Done(): + return "", nil, ctx.Err() + } +} + +func (m *Manager) doneFunc(datahash share.DataHash, peerID peer.ID) DoneFunc { + return func(result result) { + log.Debugw("set peer status", + "peer", peerID, + "datahash", datahash.String(), + "result", result) + switch result { + case ResultSuccess: + case ResultSynced: + m.getOrCreatePool(datahash.String()).markSynced() + case ResultCooldownPeer: + m.getOrCreatePool(datahash.String()).putOnCooldown(peerID) + case ResultBlacklistPeer: + m.blacklistPeers(peerID) + } + } +} + +// subscribeHeader takes datahash from received header and validates corresponding peer pool. 
+func (m *Manager) subscribeHeader(ctx context.Context, headerSub libhead.Subscription[*header.ExtendedHeader]) { + defer close(m.done) + defer headerSub.Cancel() + + for { + h, err := headerSub.NextHeader(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + return + } + log.Errorw("get next header from sub", "err", err) + continue + } + + if m.getOrCreatePool(h.DataHash.String()).markValidated() { + log.Debugw("marked validated", "datahash", h.DataHash.String()) + } + } +} + +// Validate will collect peer.ID into corresponding peer pool +func (m *Manager) validate(ctx context.Context, peerID peer.ID, hash share.DataHash) pubsub.ValidationResult { + // messages broadcast from self should bypass the validation with Accept + if peerID == m.host.ID() { + log.Debugw("received datahash from self", "datahash", hash.String()) + return pubsub.ValidationAccept + } + + // punish peer for sending invalid hash if it has misbehaved in the past + if m.hashIsBlacklisted(hash) { + log.Debugw("received blacklisted hash, reject validation", "peer", peerID, "datahash", hash.String()) + return pubsub.ValidationReject + } + + if m.peerIsBlacklisted(peerID) { + log.Debugw("received message from blacklisted peer, reject validation", "peer", peerID, "datahash", hash.String()) + return pubsub.ValidationReject + } + + m.getOrCreatePool(hash.String()).add(peerID) + log.Debugw("got hash from shrex-sub", "peer", peerID, "datahash", hash.String()) + return pubsub.ValidationIgnore +} + +func (m *Manager) getOrCreatePool(datahash string) *syncPool { + m.lock.Lock() + defer m.lock.Unlock() + + p, ok := m.pools[datahash] + if !ok { + p = &syncPool{ + pool: newPool(m.peerCooldownTime), + createdAt: time.Now(), + } + m.pools[datahash] = p + } + + return p +} + +func (m *Manager) blacklistPeers(peerIDs ...peer.ID) { + log.Debugw("blacklisting peers", "peers", peerIDs) + for _, peerID := range peerIDs { + m.fullNodes.remove(peerID) + // add peer to the blacklist, so we can't connect to it in the future. + err := m.connGater.BlockPeer(peerID) + if err != nil { + log.Warnw("failed to block peer", "peer", peerID, "err", err) + } + // close connections to peer. + err = m.host.Network().ClosePeer(peerID) + if err != nil { + log.Warnw("failed to close connection with peer", "peer", peerID, "err", err) + } + } +} + +func (m *Manager) peerIsBlacklisted(peerID peer.ID) bool { + return !m.connGater.InterceptPeerDial(peerID) +} + +func (m *Manager) hashIsBlacklisted(hash share.DataHash) bool { + m.lock.Lock() + defer m.lock.Unlock() + return m.blacklistedHashes[hash.String()] +} + +func (m *Manager) GC(ctx context.Context) { + ticker := time.NewTicker(m.gcInterval) + defer ticker.Stop() + + var blacklist []peer.ID + for { + blacklist = m.cleanUp() + if len(blacklist) > 0 { + m.blacklistPeers(blacklist...) 
+ } + + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + } +} + +func (m *Manager) cleanUp() []peer.ID { + m.lock.Lock() + defer m.lock.Unlock() + + addToBlackList := make(map[peer.ID]struct{}) + for h, p := range m.pools { + if time.Since(p.createdAt) > m.poolValidationTimeout && !p.isValidatedDataHash.Load() { + log.Debug("blacklisting datahash with all corresponding peers", + "datahash", h, + "peer_list", p.peersList) + // blacklist hash + delete(m.pools, h) + m.blacklistedHashes[h] = true + + // blacklist peers + for _, peer := range p.peersList { + addToBlackList[peer] = struct{}{} + } + } + } + + blacklist := make([]peer.ID, 0, len(addToBlackList)) + for peerID := range addToBlackList { + blacklist = append(blacklist, peerID) + } + return blacklist +} + +func (p *syncPool) markSynced() { + p.isSynced.Store(true) + old := (*unsafe.Pointer)(unsafe.Pointer(&p.pool)) + // release pointer to old pool to free up memory + atomic.StorePointer(old, unsafe.Pointer(newPool(time.Second))) +} + +func (p *syncPool) markValidated() bool { + return p.isValidatedDataHash.CompareAndSwap(false, true) +} + +func (p *syncPool) add(peers ...peer.ID) { + if !p.isSynced.Load() { + p.pool.add(peers...) + } +} diff --git a/share/p2p/peers/manager_test.go b/share/p2p/peers/manager_test.go new file mode 100644 index 0000000000..a9034f4586 --- /dev/null +++ b/share/p2p/peers/manager_test.go @@ -0,0 +1,455 @@ +package peers + +import ( + "context" + sync2 "sync" + "testing" + "time" + + "github.com/ipfs/go-datastore" + "github.com/ipfs/go-datastore/sync" + routinghelpers "github.com/libp2p/go-libp2p-routing-helpers" + "github.com/libp2p/go-libp2p/p2p/net/conngater" + + dht "github.com/libp2p/go-libp2p-kad-dht" + pubsub "github.com/libp2p/go-libp2p-pubsub" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/peer" + routingdisc "github.com/libp2p/go-libp2p/p2p/discovery/routing" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/header" + libhead "github.com/celestiaorg/celestia-node/libs/header" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/availability/discovery" + "github.com/celestiaorg/celestia-node/share/p2p/shrexsub" +) + +// TODO: add broadcast to tests +func TestManager(t *testing.T) { + t.Run("validate datahash by headerSub", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*50) + t.Cleanup(cancel) + + // create headerSub mock + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + // wait until header is requested from header sub + err = headerSub.wait(ctx, 1) + require.NoError(t, err) + + // check validation + require.True(t, manager.pools[h.DataHash.String()].isValidatedDataHash.Load()) + stopManager(t, manager) + }) + + t.Run("validate datahash by shrex.Getter", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + peerID := peer.ID("peer1") + result := manager.validate(ctx, peerID, h.DataHash.Bytes()) + require.Equal(t, pubsub.ValidationIgnore, result) + + pID, done, err := manager.Peer(ctx, h.DataHash.Bytes()) + require.NoError(t, err) + require.Equal(t, peerID, pID) + + // 
check pool validation
+ require.True(t, manager.getOrCreatePool(h.DataHash.String()).isValidatedDataHash.Load())
+
+ done(ResultSynced)
+ // pool should not be removed after success
+ require.Len(t, manager.pools, 1)
+ require.Len(t, manager.getOrCreatePool(h.DataHash.String()).pool.peersList, 0)
+ })
+
+ t.Run("validator", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*50)
+ t.Cleanup(cancel)
+
+ // create headerSub mock
+ h := testHeader()
+ headerSub := newSubLock(h, nil)
+
+ // start test manager
+ manager, err := testManager(ctx, headerSub)
+ require.NoError(t, err)
+
+ result := manager.validate(ctx, manager.host.ID(), h.DataHash.Bytes())
+ require.Equal(t, pubsub.ValidationAccept, result)
+
+ peerID := peer.ID("peer1")
+ result = manager.validate(ctx, peerID, h.DataHash.Bytes())
+ require.Equal(t, pubsub.ValidationIgnore, result)
+
+ pID, done, err := manager.Peer(ctx, h.DataHash.Bytes())
+ require.NoError(t, err)
+ require.Equal(t, peerID, pID)
+
+ // mark peer as misbehaved to blacklist it
+ done(ResultBlacklistPeer)
+
+ // the misbehaving peer should now be rejected
+ result = manager.validate(ctx, pID, h.DataHash.Bytes())
+ require.Equal(t, pubsub.ValidationReject, result)
+
+ stopManager(t, manager)
+ })
+
+ t.Run("cleanup", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ t.Cleanup(cancel)
+
+ // create headerSub mock
+ h := testHeader()
+ headerSub := newSubLock(h)
+
+ // start test manager
+ manager, err := testManager(ctx, headerSub)
+ require.NoError(t, err)
+
+ peerID := peer.ID("peer1")
+ manager.validate(ctx, peerID, h.DataHash.Bytes())
+ // set poolValidationTimeout to 0 so cleanup treats the datahash as outdated
+ manager.poolValidationTimeout = 0
+
+ blacklisted := manager.cleanUp()
+ require.Contains(t, blacklisted, peerID)
+ require.True(t, manager.hashIsBlacklisted(h.DataHash.Bytes()))
+ })
+
+ t.Run("no peers from shrex.Sub, get from discovery", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ t.Cleanup(cancel)
+
+ // create headerSub mock
+ h := testHeader()
+ headerSub := newSubLock(h)
+
+ // start test manager
+ manager, err := testManager(ctx, headerSub)
+ require.NoError(t, err)
+
+ // add peers to fullnodes, imitating discovery add
+ peers := []peer.ID{"peer1", "peer2", "peer3"}
+ manager.fullNodes.add(peers...)
+
+ peerID, done, err := manager.Peer(ctx, h.DataHash.Bytes())
+ done(ResultSynced)
+ require.NoError(t, err)
+ require.Contains(t, peers, peerID)
+
+ stopManager(t, manager)
+ })
+
+ t.Run("no peers from shrex.Sub and from discovery. Wait", func(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ t.Cleanup(cancel)
+
+ // create headerSub mock
+ h := testHeader()
+ headerSub := newSubLock(h)
+
+ // start test manager
+ manager, err := testManager(ctx, headerSub)
+ require.NoError(t, err)
+
+ // make sure peers are not returned before timeout
+ timeoutCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+ t.Cleanup(cancel)
+ _, _, err = manager.Peer(timeoutCtx, h.DataHash.Bytes())
+ require.ErrorIs(t, err, context.DeadlineExceeded)
+
+ peers := []peer.ID{"peer1", "peer2", "peer3"}
+
+ // launch wait routine
+ doneCh := make(chan struct{})
+ go func() {
+ defer close(doneCh)
+ peerID, done, err := manager.Peer(ctx, h.DataHash.Bytes())
+ done(ResultSynced)
+ require.NoError(t, err)
+ require.Contains(t, peers, peerID)
+ }()
+
+ // send peers
+ manager.fullNodes.add(peers...)
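+ // adding discovered full nodes unblocks the waiting Peer() call launched above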
+ + // wait for peer to be received + select { + case <-doneCh: + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + + stopManager(t, manager) + }) + + t.Run("get peer from discovery", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + h := testHeader() + headerSub := newSubLock(h, nil) + + // start test manager + manager, err := testManager(ctx, headerSub) + require.NoError(t, err) + + peerID := peer.ID("peer1") + result := manager.validate(ctx, peerID, h.DataHash.Bytes()) + require.Equal(t, pubsub.ValidationIgnore, result) + + pID, done, err := manager.Peer(ctx, h.DataHash.Bytes()) + require.NoError(t, err) + require.Equal(t, peerID, pID) + done(ResultSynced) + + // check pool is soft deleted and marked synced + pool := manager.getOrCreatePool(h.DataHash.String()) + require.Len(t, pool.peersList, 0) + require.True(t, pool.isSynced.Load()) + + // add peer on synced pool should be noop + result = manager.validate(ctx, "peer2", h.DataHash.Bytes()) + require.Equal(t, pubsub.ValidationIgnore, result) + require.Len(t, pool.peersList, 0) + }) +} + +func TestIntegration(t *testing.T) { + t.Run("get peer from shrexsub", func(t *testing.T) { + nw, err := mocknet.FullMeshLinked(2) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + bnPubSub, err := shrexsub.NewPubSub(ctx, nw.Hosts()[0], "test") + require.NoError(t, err) + + fnPubSub, err := shrexsub.NewPubSub(ctx, nw.Hosts()[1], "test") + require.NoError(t, err) + + require.NoError(t, bnPubSub.Start(ctx)) + require.NoError(t, fnPubSub.Start(ctx)) + + fnPeerManager, err := testManager(ctx, newSubLock()) + require.NoError(t, err) + fnPeerManager.host = nw.Hosts()[1] + + require.NoError(t, fnPubSub.AddValidator(fnPeerManager.validate)) + _, err = fnPubSub.Subscribe() + require.NoError(t, err) + + time.Sleep(time.Millisecond * 100) + require.NoError(t, nw.ConnectAllButSelf()) + time.Sleep(time.Millisecond * 100) + + // broadcast from BN + peerHash := share.DataHash("peer1") + require.NoError(t, bnPubSub.Broadcast(ctx, peerHash)) + + // FN should get message + peerID, _, err := fnPeerManager.Peer(ctx, peerHash) + require.NoError(t, err) + + // check that peerID matched bridge node + require.Equal(t, nw.Hosts()[0].ID(), peerID) + }) + + t.Run("get peer from discovery", func(t *testing.T) { + nw, err := mocknet.FullMeshConnected(3) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Cleanup(cancel) + + // set up bootstrapper + bsHost := nw.Hosts()[2] + bs := host.InfoFromHost(bsHost) + opts := []dht.Option{ + dht.Mode(dht.ModeAuto), + dht.BootstrapPeers(*bs), + dht.RoutingTableRefreshPeriod(time.Second), + } + + bsOpts := opts + bsOpts = append(bsOpts, + dht.Mode(dht.ModeServer), // it must accept incoming connections + dht.BootstrapPeers(), // no bootstrappers for a bootstrapper ¯\_(ツ)_/¯ + ) + bsRouter, err := dht.New(ctx, bsHost, bsOpts...) + require.NoError(t, err) + require.NoError(t, bsRouter.Bootstrap(ctx)) + + // set up broadcaster node + bnHost := nw.Hosts()[0] + router1, err := dht.New(ctx, bnHost, opts...) + require.NoError(t, err) + bnDisc := discovery.NewDiscovery( + nw.Hosts()[0], + routingdisc.NewRoutingDiscovery(router1), + 10, + time.Second, + time.Second) + + // set up full node / receiver node + fnHost := nw.Hosts()[0] + router2, err := dht.New(ctx, fnHost, opts...) 
+ require.NoError(t, err) + fnDisc := discovery.NewDiscovery( + nw.Hosts()[1], + routingdisc.NewRoutingDiscovery(router2), + 10, + time.Second, + time.Second) + + // hook peer manager to discovery + connGater, err := conngater.NewBasicConnectionGater(sync.MutexWrap(datastore.NewMapDatastore())) + require.NoError(t, err) + fnPeerManager := NewManager( + nil, + nil, + fnDisc, + nil, + connGater, + WithValidationTimeout(time.Minute), + WithPeerCooldown(time.Second), + ) + + waitCh := make(chan struct{}) + fnDisc.WithOnPeersUpdate(func(peerID peer.ID, isAdded bool) { + defer close(waitCh) + // check that obtained peer id is same as BN + require.Equal(t, nw.Hosts()[0].ID(), peerID) + }) + + require.NoError(t, router1.Bootstrap(ctx)) + require.NoError(t, router2.Bootstrap(ctx)) + + go fnDisc.EnsurePeers(ctx) + go bnDisc.Advertise(ctx) + + select { + case <-waitCh: + require.Contains(t, fnPeerManager.fullNodes.peersList, fnHost.ID()) + case <-ctx.Done(): + require.NoError(t, ctx.Err()) + } + }) +} + +func testManager(ctx context.Context, headerSub libhead.Subscriber[*header.ExtendedHeader]) (*Manager, error) { + host, err := mocknet.New().GenPeer() + if err != nil { + return nil, err + } + shrexSub, err := shrexsub.NewPubSub(ctx, host, "test") + if err != nil { + return nil, err + } + disc := discovery.NewDiscovery(nil, + routingdisc.NewRoutingDiscovery(routinghelpers.Null{}), 0, time.Second, time.Second) + connGater, err := conngater.NewBasicConnectionGater(sync.MutexWrap(datastore.NewMapDatastore())) + if err != nil { + return nil, err + } + manager := NewManager( + headerSub, + shrexSub, + disc, + host, + connGater, + WithValidationTimeout(time.Minute), + WithPeerCooldown(time.Second), + ) + err = manager.Start(ctx) + return manager, err +} + +func stopManager(t *testing.T, m *Manager) { + closeCtx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + require.NoError(t, m.Stop(closeCtx)) +} + +func testHeader() *header.ExtendedHeader { + return &header.ExtendedHeader{ + RawHeader: header.RawHeader{}, + } +} + +type subLock struct { + next chan struct{} + wg *sync2.WaitGroup + expected []*header.ExtendedHeader +} + +func (s subLock) wait(ctx context.Context, count int) error { + s.wg.Add(count) + for i := 0; i < count; i++ { + err := s.release(ctx) + if err != nil { + return err + } + } + s.wg.Wait() + return nil +} + +func (s subLock) release(ctx context.Context) error { + select { + case s.next <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func newSubLock(expected ...*header.ExtendedHeader) *subLock { + wg := &sync2.WaitGroup{} + wg.Add(1) + return &subLock{ + next: make(chan struct{}), + expected: expected, + wg: wg, + } +} + +func (s *subLock) Subscribe() (libhead.Subscription[*header.ExtendedHeader], error) { + return s, nil +} + +func (s *subLock) AddValidator(f func(context.Context, *header.ExtendedHeader) pubsub.ValidationResult) error { + panic("implement me") +} + +func (s *subLock) NextHeader(ctx context.Context) (*header.ExtendedHeader, error) { + s.wg.Done() + + // wait for call to be unlocked by release + select { + case <-s.next: + h := s.expected[0] + s.expected = s.expected[1:] + return h, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (s *subLock) Cancel() { +} diff --git a/share/p2p/peers/options.go b/share/p2p/peers/options.go new file mode 100644 index 0000000000..6858d3f222 --- /dev/null +++ b/share/p2p/peers/options.go @@ -0,0 +1,75 @@ +package peers + +import ( + "fmt" + "time" +) + +// 
Option is a functional option applied to the Manager to configure peer manager
+// parameters (the Parameters struct)
+type Option func(manager *Manager)
+
+type Parameters struct {
+ // ValidationTimeout is the timeout used for validating incoming datahashes. Pools created for
+ // datahashes received from shrexsub that are not confirmed by headersub within this timeout
+ // will be garbage collected.
+ ValidationTimeout time.Duration
+
+ // PeerCooldown is the time a peer is put on cooldown after a ResultCooldownPeer.
+ PeerCooldown time.Duration
+
+ // GcInterval is the interval at which the manager will garbage collect unvalidated pools.
+ GcInterval time.Duration
+}
+
+// Validate validates the values in Parameters
+func (p *Parameters) Validate() error {
+ if p.ValidationTimeout <= 0 {
+ return fmt.Errorf("peer-manager: validation timeout must be positive")
+ }
+
+ if p.PeerCooldown <= 0 {
+ return fmt.Errorf("peer-manager: peer cooldown must be positive")
+ }
+
+ if p.GcInterval <= 0 {
+ return fmt.Errorf("peer-manager: garbage collection interval must be positive")
+ }
+
+ return nil
+}
+
+// DefaultParameters returns the default configuration values for the peer manager parameters
+func DefaultParameters() Parameters {
+ return Parameters{
+ // ValidationTimeout's default value is based on the default daser sampling timeout of 1 minute.
+ // If a received datahash has not been sampled within these two minutes, the pool will be removed.
+ ValidationTimeout: 2 * time.Minute,
+ // PeerCooldown's default value is based on initial network tests that showed a ~3.5 second
+ // sync time for large blocks. This value gives our (discovery) peers enough time to sync
+ // the new block before we ask them again.
+ PeerCooldown: 3 * time.Second,
+ GcInterval: time.Second * 30,
+ }
+}
+
+// WithValidationTimeout configures the manager's pool validation timeout.
+func WithValidationTimeout(timeout time.Duration) Option {
+ return func(manager *Manager) {
+ manager.poolValidationTimeout = timeout
+ }
+}
+
+// WithPeerCooldown configures the manager's peer cooldown time.
+func WithPeerCooldown(cooldown time.Duration) Option {
+ return func(manager *Manager) {
+ manager.peerCooldownTime = cooldown
+ }
+}
+
+// WithGcInterval configures the manager's garbage collection interval.
+func WithGcInterval(interval time.Duration) Option {
+ return func(manager *Manager) {
+ manager.gcInterval = interval
+ }
+}
diff --git a/share/p2p/peers/pool.go b/share/p2p/peers/pool.go
new file mode 100644
index 0000000000..89d79291bf
--- /dev/null
+++ b/share/p2p/peers/pool.go
@@ -0,0 +1,200 @@
+package peers
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ "github.com/libp2p/go-libp2p/core/peer"
+)
+
+const defaultCleanupThreshold = 2
+
+// pool stores peers and provides methods for simple round-robin access.
+type pool struct {
+ m sync.Mutex
+ peersList []peer.ID
+ statuses map[peer.ID]status
+ cooldown *timedQueue
+ activeCount int
+ nextIdx int
+
+ hasPeer bool
+ hasPeerCh chan struct{}
+
+ cleanupThreshold int
+}
+
+type status int
+
+const (
+ active status = iota
+ cooldown
+ removed
+)
+
+// newPool returns a new empty pool.
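+//
+// A minimal usage sketch of the pool API defined below:
+//
+//  p := newPool(time.Second)
+//  p.add("peer1", "peer2")
+//  if id, ok := p.tryGet(); ok {
+//      p.putOnCooldown(id) // id is excluded from tryGet until the cooldown fires
+//  }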
+func newPool(peerCooldownTime time.Duration) *pool { + p := &pool{ + peersList: make([]peer.ID, 0), + statuses: make(map[peer.ID]status), + hasPeerCh: make(chan struct{}), + cleanupThreshold: defaultCleanupThreshold, + } + p.cooldown = newTimedQueue(peerCooldownTime, p.afterCooldown) + return p +} + +// tryGet returns peer along with bool flag indicating success of operation. +func (p *pool) tryGet() (peer.ID, bool) { + p.m.Lock() + defer p.m.Unlock() + + if p.activeCount == 0 { + return "", false + } + + start := p.nextIdx + for { + peerID := p.peersList[p.nextIdx] + + p.nextIdx++ + if p.nextIdx == len(p.peersList) { + p.nextIdx = 0 + } + + if p.statuses[peerID] == active { + return peerID, true + } + + // full circle passed + if p.nextIdx == start { + return "", false + } + } +} + +// next sends a peer to the returned channel when it becomes available. +func (p *pool) next(ctx context.Context) <-chan peer.ID { + peerCh := make(chan peer.ID, 1) + go func() { + for { + if peerID, ok := p.tryGet(); ok { + peerCh <- peerID + return + } + + select { + case <-p.hasPeerCh: + case <-ctx.Done(): + return + } + } + }() + return peerCh +} + +func (p *pool) add(peers ...peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + for _, peerID := range peers { + status, ok := p.statuses[peerID] + if ok && status != removed { + continue + } + + if !ok { + p.peersList = append(p.peersList, peerID) + } + + p.statuses[peerID] = active + p.activeCount++ + } + p.checkHasPeers() +} + +func (p *pool) remove(peers ...peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + for _, peerID := range peers { + if status, ok := p.statuses[peerID]; ok && status != removed { + p.statuses[peerID] = removed + if status == active { + p.activeCount-- + } + } + } + + // do cleanup if too much garbage + if len(p.peersList) >= p.activeCount+p.cleanupThreshold { + p.cleanup() + } + p.checkHasPeers() +} + +// cleanup will reduce memory footprint of pool. +func (p *pool) cleanup() { + newList := make([]peer.ID, 0, p.activeCount) + for idx, peerID := range p.peersList { + status := p.statuses[peerID] + switch status { + case active, cooldown: + newList = append(newList, peerID) + case removed: + delete(p.statuses, peerID) + } + + if idx == p.nextIdx { + // if peer is not active and no more active peers left in list point to first peer + if status != active && len(newList) >= p.activeCount { + p.nextIdx = 0 + continue + } + p.nextIdx = len(newList) + } + } + p.peersList = newList +} + +func (p *pool) putOnCooldown(peerID peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + if status, ok := p.statuses[peerID]; ok && status == active { + p.cooldown.push(peerID) + + p.statuses[peerID] = cooldown + p.activeCount-- + p.checkHasPeers() + } +} + +func (p *pool) afterCooldown(peerID peer.ID) { + p.m.Lock() + defer p.m.Unlock() + + // item could have been already removed by the time afterCooldown is called + if status, ok := p.statuses[peerID]; !ok || status != cooldown { + return + } + + p.statuses[peerID] = active + p.activeCount++ + p.checkHasPeers() +} + +// checkHasPeers will check and indicate if there are peers in the pool. 
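+// It closes hasPeerCh once the first active peer appears, unblocking waiters in next(),
+// and recreates the channel when the pool drains so future calls block again.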
+func (p *pool) checkHasPeers() { + if p.activeCount > 0 && !p.hasPeer { + p.hasPeer = true + close(p.hasPeerCh) + return + } + + if p.activeCount == 0 && p.hasPeer { + p.hasPeerCh = make(chan struct{}) + p.hasPeer = false + } +} diff --git a/share/p2p/peers/pool_test.go b/share/p2p/peers/pool_test.go new file mode 100644 index 0000000000..3f4f6a5b9a --- /dev/null +++ b/share/p2p/peers/pool_test.go @@ -0,0 +1,182 @@ +package peers + +import ( + "context" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestPool(t *testing.T) { + t.Run("add / remove peers", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer1", "peer2", "peer3"} + // adding same peer twice should not produce copies + p.add(peers...) + require.Equal(t, len(peers)-1, p.activeCount) + + p.remove("peer1", "peer2") + require.Equal(t, len(peers)-3, p.activeCount) + + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peers[3], peerID) + + p.remove("peer3") + p.remove("peer3") + require.Equal(t, 0, p.activeCount) + _, ok = p.tryGet() + require.False(t, ok) + }) + + t.Run("round robin", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer1", "peer2", "peer3"} + // adding same peer twice should not produce copies + p.add(peers...) + require.Equal(t, 3, p.activeCount) + + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer2"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer3"), peerID) + + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + + p.remove("peer2", "peer3") + require.Equal(t, 1, p.activeCount) + + // pointer should skip removed items until found active one + peerID, ok = p.tryGet() + require.True(t, ok) + require.Equal(t, peer.ID("peer1"), peerID) + }) + + t.Run("wait for peer", func(t *testing.T) { + timeout := time.Second + shortCtx, cancel := context.WithTimeout(context.Background(), timeout/10) + t.Cleanup(cancel) + + longCtx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + + p := newPool(time.Second) + done := make(chan struct{}) + + go func() { + select { + case <-p.next(shortCtx): + case <-shortCtx.Done(): + require.Error(t, shortCtx.Err()) + // unlock longCtx waiter by adding new peer + p.add("peer1") + } + }() + + go func() { + defer close(done) + select { + case peerID := <-p.next(longCtx): + require.Equal(t, peer.ID("peer1"), peerID) + case <-longCtx.Done(): + require.NoError(t, longCtx.Err()) + } + }() + + select { + case <-done: + case <-longCtx.Done(): + require.NoError(t, longCtx.Err()) + } + }) + + t.Run("nextIdx got removed", func(t *testing.T) { + p := newPool(time.Second) + + peers := []peer.ID{"peer1", "peer2", "peer3"} + p.add(peers...) + p.nextIdx = 2 + p.remove(peers[p.nextIdx]) + + // if previous nextIdx was removed, tryGet should iterate until available peer found + peerID, ok := p.tryGet() + require.True(t, ok) + require.Equal(t, peers[0], peerID) + }) + + t.Run("cleanup", func(t *testing.T) { + p := newPool(time.Second) + p.cleanupThreshold = 3 + + peers := []peer.ID{"peer1", "peer2", "peer3", "peer4", "peer5"} + p.add(peers...) 
+ require.Equal(t, len(peers), p.activeCount)
+
+ // point to the last element, which will be removed, to check how the pointer is updated
+ p.nextIdx = len(peers) - 1
+
+ // remove some, but do not trigger cleanup yet
+ p.remove(peers[3:]...)
+ require.Equal(t, len(peers)-2, p.activeCount)
+ require.Equal(t, len(peers), len(p.statuses))
+
+ // trigger cleanup
+ p.remove(peers[2])
+ require.Equal(t, len(peers)-3, p.activeCount)
+ require.Equal(t, len(peers)-3, len(p.statuses))
+ // nextIdx pointer should be updated
+ require.Equal(t, 0, p.nextIdx)
+ })
+
+ t.Run("cooldown blocks get", func(t *testing.T) {
+ ttl := time.Second / 10
+ p := newPool(ttl)
+
+ peerID := peer.ID("peer1")
+ p.add(peerID)
+
+ _, ok := p.tryGet()
+ require.True(t, ok)
+
+ p.putOnCooldown(peerID)
+ // item should be unavailable
+ _, ok = p.tryGet()
+ require.False(t, ok)
+
+ ctx, cancel := context.WithTimeout(context.Background(), ttl*5)
+ defer cancel()
+ select {
+ case <-p.next(ctx):
+ case <-ctx.Done():
+ t.Fatal("item should already be available")
+ }
+ })
+
+ t.Run("put on cooldown removed item should be noop", func(t *testing.T) {
+ p := newPool(time.Second)
+ p.cleanupThreshold = 3
+
+ peerID := peer.ID("peer1")
+ p.add(peerID)
+
+ p.remove(peerID)
+ p.cleanup()
+ p.putOnCooldown(peerID)
+
+ _, ok := p.tryGet()
+ require.False(t, ok)
+ })
+}
diff --git a/share/p2p/peers/timedqueue.go b/share/p2p/peers/timedqueue.go
new file mode 100644
index 0000000000..3ed7e29a2c
--- /dev/null
+++ b/share/p2p/peers/timedqueue.go
@@ -0,0 +1,91 @@
+package peers
+
+import (
+ "sync"
+ "time"
+
+ "github.com/benbjohnson/clock"
+ "github.com/libp2p/go-libp2p/core/peer"
+)
+
+// timedQueue stores items for a ttl duration and releases them by calling the onPop callback. Each
+// item is tracked independently
+type timedQueue struct {
+ sync.Mutex
+ items []item
+
+ // ttl is the amount of time each item exists in the timedQueue
+ ttl time.Duration
+ clock clock.Clock
+ after *clock.Timer
+ // onPop is called with the item's peer.ID after it is released
+ onPop func(peer.ID)
+}
+
+type item struct {
+ peer.ID
+ createdAt time.Time
+}
+
+func newTimedQueue(ttl time.Duration, onPop func(peer.ID)) *timedQueue {
+ return &timedQueue{
+ items: make([]item, 0),
+ clock: clock.New(),
+ ttl: ttl,
+ onPop: onPop,
+ }
+}
+
+// releaseExpired will release all expired items
+func (q *timedQueue) releaseExpired() {
+ q.Lock()
+ defer q.Unlock()
+ q.releaseUnsafe()
+}
+
+func (q *timedQueue) releaseUnsafe() {
+ if len(q.items) == 0 {
+ return
+ }
+
+ var i int
+ for _, next := range q.items {
+ timeIn := q.clock.Since(next.createdAt)
+ if timeIn < q.ttl {
+ // item is not expired yet, create a timer that will call releaseExpired
+ q.after.Stop()
+ q.after = q.clock.AfterFunc(q.ttl-timeIn, q.releaseExpired)
+ break
+ }
+
+ // item is expired
+ q.onPop(next.ID)
+ i++
+ }
+
+ if i > 0 {
+ copy(q.items, q.items[i:])
+ q.items = q.items[:len(q.items)-i]
+ }
+}
+
+func (q *timedQueue) push(peerID peer.ID) {
+ q.Lock()
+ defer q.Unlock()
+
+ q.items = append(q.items, item{
+ ID: peerID,
+ createdAt: q.clock.Now(),
+ })
+
+ // if it is the first item in queue, create a timer to call releaseExpired after its expiration
+ if len(q.items) == 1 {
+ q.after = q.clock.AfterFunc(q.ttl, q.releaseExpired)
+ }
+}
+
+func (q *timedQueue) len() int {
+ q.Lock()
+ defer q.Unlock()
+ return len(q.items)
+}
diff --git a/share/p2p/peers/timedqueue_test.go b/share/p2p/peers/timedqueue_test.go
new file mode 100644
index 0000000000..fb5ef9629f
--- /dev/null
+++
b/share/p2p/peers/timedqueue_test.go @@ -0,0 +1,61 @@ +package peers + +import ( + "testing" + "time" + + "github.com/benbjohnson/clock" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/require" +) + +func TestTimedQueue(t *testing.T) { + t.Run("push item", func(t *testing.T) { + peers := []peer.ID{"peer1", "peer2"} + ttl := time.Second + + popCh := make(chan struct{}, 1) + queue := newTimedQueue(ttl, func(id peer.ID) { + go func() { + require.Contains(t, peers, id) + popCh <- struct{}{} + }() + }) + mock := clock.NewMock() + queue.clock = mock + + // push first item | global time : 0 + queue.push(peers[0]) + require.Equal(t, queue.len(), 1) + + // push second item with ttl/2 gap | global time : ttl/2 + mock.Add(ttl / 2) + queue.push(peers[1]) + require.Equal(t, queue.len(), 2) + + // advance clock 1 nano sec before first item should expire | global time : ttl - 1 + mock.Add(ttl/2 - 1) + // check that releaseExpired doesn't remove items + queue.releaseExpired() + require.Equal(t, queue.len(), 2) + // first item should be released after its own timeout | global time : ttl + mock.Add(1) + + select { + case <-popCh: + case <-time.After(ttl): + t.Fatal("first item is not released") + + } + require.Equal(t, queue.len(), 1) + + // first item should be released after ttl/2 gap timeout | global time : 3/2*ttl + mock.Add(ttl / 2) + select { + case <-popCh: + case <-time.After(ttl): + t.Fatal("second item is not released") + } + require.Equal(t, queue.len(), 0) + }) +} diff --git a/share/p2p/shrexeds/client.go b/share/p2p/shrexeds/client.go index 3de122227a..ffc59c82c0 100644 --- a/share/p2p/shrexeds/client.go +++ b/share/p2p/shrexeds/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "time" @@ -65,8 +66,9 @@ func (c *Client) RequestEDS( } } if err != p2p.ErrUnavailable { - log.Errorw("client: eds request to peer failed", "peer", peer, "hash", dataHash.String()) + log.Debugw("client: eds request to peer failed", "peer", peer, "hash", dataHash.String()) } + return nil, err } @@ -82,7 +84,7 @@ func (c *Client) doRequest( if dl, ok := ctx.Deadline(); ok { if err = stream.SetDeadline(dl); err != nil { - log.Debugw("error setting deadline: %s", err) + log.Debugf("error setting deadline: %s", err) } } @@ -105,6 +107,10 @@ func (c *Client) doRequest( resp := new(pb.EDSResponse) _, err = serde.Read(stream, resp) if err != nil { + // server is overloaded and closed the stream + if errors.Is(err, io.EOF) { + return nil, p2p.ErrUnavailable + } stream.Reset() //nolint:errcheck return nil, fmt.Errorf("failed to read status from stream: %w", err) } @@ -117,7 +123,7 @@ func (c *Client) doRequest( return nil, fmt.Errorf("failed to read eds from ods bytes: %w", err) } return eds, nil - case pb.Status_NOT_FOUND, pb.Status_REFUSED: + case pb.Status_NOT_FOUND: log.Debugf("client: peer %s couldn't serve eds %s with status %s", to.String(), dataHash.String(), resp.GetStatus()) return nil, p2p.ErrUnavailable case pb.Status_INVALID: diff --git a/share/p2p/shrexeds/exchange_test.go b/share/p2p/shrexeds/exchange_test.go index bfc8f076ee..376d3472c7 100644 --- a/share/p2p/shrexeds/exchange_test.go +++ b/share/p2p/shrexeds/exchange_test.go @@ -2,12 +2,14 @@ package shrexeds import ( "context" + "sync" "testing" "time" "github.com/ipfs/go-datastore" ds_sync "github.com/ipfs/go-datastore/sync" libhost "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" @@ -83,6 +85,47 @@ func TestExchange_RequestEDS(t *testing.T) { assert.ErrorIs(t, err, timeoutCtx.Err()) assert.Nil(t, requestedEDS) }) + + // Testcase: Concurrency limit reached + t.Run("EDS_concurrency_limit", func(t *testing.T) { + store, client, server := makeExchange(t) + + require.NoError(t, store.Start(ctx)) + require.NoError(t, server.Start(ctx)) + + ctx, cancel := context.WithTimeout(ctx, time.Second) + t.Cleanup(cancel) + + rateLimit := 2 + wg := sync.WaitGroup{} + wg.Add(rateLimit) + + // mockHandler will block requests on server side until test is over + lock := make(chan struct{}) + defer close(lock) + mockHandler := func(network.Stream) { + wg.Done() + select { + case <-lock: + case <-ctx.Done(): + t.Fatal("timeout") + } + } + server.host.SetStreamHandler(server.protocolID, + p2p.RateLimitMiddleware(mockHandler, rateLimit)) + + // take server concurrency slots with blocked requests + for i := 0; i < rateLimit; i++ { + go func(i int) { + client.RequestEDS(ctx, nil, server.host.ID()) //nolint:errcheck + }(i) + } + + // wait until all server slots are taken + wg.Wait() + _, err = client.RequestEDS(ctx, nil, server.host.ID()) + require.ErrorIs(t, err, p2p.ErrUnavailable) + }) } func newStore(t *testing.T) *eds.Store { diff --git a/share/p2p/shrexeds/options.go b/share/p2p/shrexeds/options.go index 973de4c530..4239ec5bea 100644 --- a/share/p2p/shrexeds/options.go +++ b/share/p2p/shrexeds/options.go @@ -33,14 +33,18 @@ type Parameters struct { // protocolSuffix is appended to the protocolID and represents the network the protocol is // running on. protocolSuffix string + + // concurrencyLimit is the maximum number of concurrently handled streams + concurrencyLimit int } func DefaultParameters() *Parameters { return &Parameters{ - ReadDeadline: time.Minute, - WriteDeadline: time.Second * 5, - ReadCARDeadline: time.Minute, - BufferSize: 32 * 1024, + ReadDeadline: time.Minute, + WriteDeadline: time.Second * 5, + ReadCARDeadline: time.Minute, + BufferSize: 32 * 1024, + concurrencyLimit: 10, } } @@ -59,6 +63,9 @@ func (p *Parameters) Validate() error { if p.BufferSize <= 0 { return fmt.Errorf("invalid buffer size: %s", errSuffix) } + if p.concurrencyLimit <= 0 { + return fmt.Errorf("invalid concurrency limit: %s", errSuffix) + } return nil } @@ -69,6 +76,13 @@ func WithProtocolSuffix(protocolSuffix string) Option { } } +// WithConcurrencyLimit is a functional option that configures the `concurrencyLimit` parameter +func WithConcurrencyLimit(concurrencyLimit int) Option { + return func(parameters *Parameters) { + parameters.concurrencyLimit = concurrencyLimit + } +} + func protocolID(protocolSuffix string) protocol.ID { return protocol.ID(fmt.Sprintf("%s%s", protocolPrefix, protocolSuffix)) } diff --git a/share/p2p/shrexeds/pb/extended_data_square.pb.go b/share/p2p/shrexeds/pb/extended_data_square.pb.go index 2a62e4b503..c6f7e07123 100644 --- a/share/p2p/shrexeds/pb/extended_data_square.pb.go +++ b/share/p2p/shrexeds/pb/extended_data_square.pb.go @@ -28,21 +28,18 @@ const ( Status_INVALID Status = 0 Status_OK Status = 1 Status_NOT_FOUND Status = 2 - Status_REFUSED Status = 3 ) var Status_name = map[int32]string{ 0: "INVALID", 1: "OK", 2: "NOT_FOUND", - 3: "REFUSED", } var Status_value = map[string]int32{ "INVALID": 0, "OK": 1, "NOT_FOUND": 2, - "REFUSED": 3, } func (x Status) String() string { @@ -152,7 +149,7 @@ func init() { } var fileDescriptor_49d42aa96098056e = []byte{ - // 224 bytes of a gzipped FileDescriptorProto + // 213 bytes of a gzipped 
FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x32, 0x28, 0xce, 0x48, 0x2c, 0x4a, 0xd5, 0x2f, 0x30, 0x2a, 0xd0, 0x2f, 0xce, 0x28, 0x4a, 0xad, 0x48, 0x4d, 0x29, 0xd6, 0x2f, 0x48, 0xd2, 0x4f, 0xad, 0x28, 0x49, 0xcd, 0x4b, 0x49, 0x4d, 0x89, 0x4f, 0x49, 0x2c, 0x49, 0x8c, @@ -160,13 +157,13 @@ var fileDescriptor_49d42aa96098056e = []byte{ 0x72, 0x75, 0x09, 0x0e, 0x4a, 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0x12, 0xe2, 0x62, 0xc9, 0x48, 0x2c, 0xce, 0x90, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, 0xb3, 0x95, 0xf4, 0xb8, 0xb8, 0xc1, 0x2a, 0x8a, 0x0b, 0xf2, 0xf3, 0x8a, 0x53, 0x85, 0xe4, 0xb9, 0xd8, 0x8a, 0x4b, 0x12, 0x4b, 0x4a, - 0x8b, 0xc1, 0x8a, 0xf8, 0x8c, 0xd8, 0xf5, 0x82, 0xc1, 0xdc, 0x20, 0xa8, 0xb0, 0x96, 0x25, 0x17, + 0x8b, 0xc1, 0x8a, 0xf8, 0x8c, 0xd8, 0xf5, 0x82, 0xc1, 0xdc, 0x20, 0xa8, 0xb0, 0x96, 0x0e, 0x17, 0x1b, 0x44, 0x44, 0x88, 0x9b, 0x8b, 0xdd, 0xd3, 0x2f, 0xcc, 0xd1, 0xc7, 0xd3, 0x45, 0x80, 0x41, 0x88, 0x8d, 0x8b, 0xc9, 0xdf, 0x5b, 0x80, 0x51, 0x88, 0x97, 0x8b, 0xd3, 0xcf, 0x3f, 0x24, 0xde, - 0xcd, 0x3f, 0xd4, 0xcf, 0x45, 0x80, 0x09, 0xa4, 0x26, 0xc8, 0xd5, 0x2d, 0x34, 0xd8, 0xd5, 0x45, - 0x80, 0xd9, 0x49, 0xe2, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, 0x18, 0x1f, 0x3c, 0x92, 0x63, - 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, 0xe5, 0x18, 0x92, 0xd8, 0xc0, - 0xae, 0x35, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x8c, 0x57, 0xb5, 0xbb, 0xe1, 0x00, 0x00, 0x00, + 0xcd, 0x3f, 0xd4, 0xcf, 0x45, 0x80, 0xc9, 0x49, 0xe2, 0xc4, 0x23, 0x39, 0xc6, 0x0b, 0x8f, 0xe4, + 0x18, 0x1f, 0x3c, 0x92, 0x63, 0x9c, 0xf0, 0x58, 0x8e, 0xe1, 0xc2, 0x63, 0x39, 0x86, 0x1b, 0x8f, + 0xe5, 0x18, 0x92, 0xd8, 0xc0, 0x0e, 0x34, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x34, 0x8f, 0xa7, + 0x57, 0xd4, 0x00, 0x00, 0x00, } func (m *EDSRequest) Marshal() (dAtA []byte, err error) { diff --git a/share/p2p/shrexeds/pb/extended_data_square.proto b/share/p2p/shrexeds/pb/extended_data_square.proto index f96493bf66..1830d85f8e 100644 --- a/share/p2p/shrexeds/pb/extended_data_square.proto +++ b/share/p2p/shrexeds/pb/extended_data_square.proto @@ -8,7 +8,6 @@ enum Status { INVALID = 0; OK = 1; // data found NOT_FOUND = 2; // data not found - REFUSED = 3; // request refused } message EDSResponse { diff --git a/share/p2p/shrexeds/server.go b/share/p2p/shrexeds/server.go index 1922cf9e00..60086517b0 100644 --- a/share/p2p/shrexeds/server.go +++ b/share/p2p/shrexeds/server.go @@ -12,6 +12,7 @@ import ( "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds" + "github.com/celestiaorg/celestia-node/share/p2p" p2p_pb "github.com/celestiaorg/celestia-node/share/p2p/shrexeds/pb" "github.com/celestiaorg/go-libp2p-messenger/serde" ) @@ -50,7 +51,7 @@ func NewServer(host host.Host, store *eds.Store, opts ...Option) (*Server, error func (s *Server) Start(context.Context) error { s.ctx, s.cancel = context.WithCancel(context.Background()) - s.host.SetStreamHandler(s.protocolID, s.handleStream) + s.host.SetStreamHandler(s.protocolID, p2p.RateLimitMiddleware(s.handleStream, s.params.concurrencyLimit)) return nil } diff --git a/share/p2p/shrexnd/client.go b/share/p2p/shrexnd/client.go index 171c83d45d..d284dba0b3 100644 --- a/share/p2p/shrexnd/client.go +++ b/share/p2p/shrexnd/client.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "net" "time" @@ -73,7 +74,7 @@ func (c *Client) RequestND( } } if err != p2p.ErrUnavailable { - log.Errorw("client-nd: peer returned err", "peer", peer, "err", err) + log.Debugw("client-nd: peer returned err", "peer", 
peer, "err", err) } return nil, err } @@ -111,6 +112,10 @@ func (c *Client) doRequest( var resp pb.GetSharesByNamespaceResponse _, err = serde.Read(stream, &resp) if err != nil { + // server is overloaded and closed the stream + if errors.Is(err, io.EOF) { + return nil, p2p.ErrUnavailable + } stream.Reset() //nolint:errcheck return nil, fmt.Errorf("client-nd: reading response: %w", err) } @@ -192,7 +197,7 @@ func statusToErr(code pb.StatusCode) error { switch code { case pb.StatusCode_OK: return nil - case pb.StatusCode_NOT_FOUND, pb.StatusCode_REFUSED: + case pb.StatusCode_NOT_FOUND: return p2p.ErrUnavailable case pb.StatusCode_INTERNAL, pb.StatusCode_INVALID: fallthrough diff --git a/share/p2p/shrexnd/exchange_test.go b/share/p2p/shrexnd/exchange_test.go new file mode 100644 index 0000000000..3940474025 --- /dev/null +++ b/share/p2p/shrexnd/exchange_test.go @@ -0,0 +1,62 @@ +package shrexnd + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-node/share/p2p" +) + +func TestExchange_RequestND(t *testing.T) { + // Testcase: Concurrency limit reached + t.Run("ND_concurrency_limit", func(t *testing.T) { + net, err := mocknet.FullMeshConnected(2) + require.NoError(t, err) + + client, err := NewClient(net.Hosts()[0]) + require.NoError(t, err) + server, err := NewServer(net.Hosts()[1], nil, nil) + require.NoError(t, err) + + require.NoError(t, server.Start(context.Background())) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + + rateLimit := 2 + wg := sync.WaitGroup{} + wg.Add(rateLimit) + + // mockHandler will block requests on server side until test is over + lock := make(chan struct{}) + defer close(lock) + mockHandler := func(network.Stream) { + wg.Done() + select { + case <-lock: + case <-ctx.Done(): + t.Fatal("timeout") + } + } + server.host.SetStreamHandler(server.protocolID, + p2p.RateLimitMiddleware(mockHandler, rateLimit)) + + // take server concurrency slots with blocked requests + for i := 0; i < rateLimit; i++ { + go func(i int) { + client.RequestND(ctx, nil, nil, server.host.ID()) //nolint:errcheck + }(i) + } + + // wait until all server slots are taken + wg.Wait() + _, err = client.RequestND(ctx, nil, nil, server.host.ID()) + require.ErrorIs(t, err, p2p.ErrUnavailable) + }) +} diff --git a/share/p2p/shrexnd/options.go b/share/p2p/shrexnd/options.go index 3c540a6bb6..44f79e13f2 100644 --- a/share/p2p/shrexnd/options.go +++ b/share/p2p/shrexnd/options.go @@ -30,13 +30,17 @@ type Parameters struct { // protocolSuffix is appended to the protocolID and represents the network the protocol is // running on. 
protocolSuffix string + + // concurrencyLimit is the maximum number of concurrently handled streams + concurrencyLimit int } func DefaultParameters() *Parameters { return &Parameters{ - readTimeout: time.Second * 5, - writeTimeout: time.Second * 10, - serveTimeout: time.Second * 10, + readTimeout: time.Second * 5, + writeTimeout: time.Second * 10, + serveTimeout: time.Second * 10, + concurrencyLimit: 10, } } @@ -52,6 +56,9 @@ func (p *Parameters) Validate() error { if p.serveTimeout <= 0 { return fmt.Errorf("invalid serve timeout: %v, %s", p.serveTimeout, errSuffix) } + if p.concurrencyLimit <= 0 { + return fmt.Errorf("invalid concurrency limit: %v, %s", p.concurrencyLimit, errSuffix) + } return nil } @@ -83,6 +90,13 @@ func WithServeTimeout(serveTimeout time.Duration) Option { } } +// WithConcurrencyLimit is a functional option that configures the `concurrencyLimit` parameter +func WithConcurrencyLimit(concurrencyLimit int) Option { + return func(parameters *Parameters) { + parameters.concurrencyLimit = concurrencyLimit + } +} + func protocolID(protocolSuffix string) protocol.ID { return protocol.ID(fmt.Sprintf("%s%s", protocolPrefix, protocolSuffix)) } diff --git a/share/p2p/shrexnd/pb/share.pb.go b/share/p2p/shrexnd/pb/share.pb.go index e703090d9a..bb9d1e4d40 100644 --- a/share/p2p/shrexnd/pb/share.pb.go +++ b/share/p2p/shrexnd/pb/share.pb.go @@ -29,7 +29,6 @@ const ( StatusCode_OK StatusCode = 1 StatusCode_NOT_FOUND StatusCode = 2 StatusCode_INTERNAL StatusCode = 3 - StatusCode_REFUSED StatusCode = 4 ) var StatusCode_name = map[int32]string{ @@ -37,7 +36,6 @@ var StatusCode_name = map[int32]string{ 1: "OK", 2: "NOT_FOUND", 3: "INTERNAL", - 4: "REFUSED", } var StatusCode_value = map[string]int32{ @@ -45,7 +43,6 @@ var StatusCode_value = map[string]int32{ "OK": 1, "NOT_FOUND": 2, "INTERNAL": 3, - "REFUSED": 4, } func (x StatusCode) String() string { @@ -283,31 +280,31 @@ func init() { func init() { proto.RegisterFile("share/p2p/shrexnd/pb/share.proto", fileDescriptor_ed9f13149b0de397) } var fileDescriptor_ed9f13149b0de397 = []byte{ - // 380 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xc1, 0xae, 0xd2, 0x40, - 0x14, 0x86, 0xdb, 0x0e, 0x54, 0x38, 0xad, 0xa6, 0x99, 0x18, 0xad, 0xc1, 0x34, 0xd8, 0x15, 0xd1, - 0xa4, 0x4d, 0x6a, 0xe2, 0x1e, 0xa4, 0x68, 0x03, 0x19, 0xcc, 0x00, 0xee, 0x0c, 0x29, 0x76, 0x4c, - 0x5d, 0xd8, 0x19, 0x3b, 0x43, 0xd0, 0xb5, 0x2f, 0xe0, 0x63, 0xb9, 0x64, 0xe9, 0xd2, 0xc0, 0x8b, - 0x98, 0x4e, 0xb9, 0xf7, 0x2e, 0x2e, 0xbb, 0xfe, 0xe7, 0x7c, 0xff, 0x7f, 0xce, 0xe9, 0xc0, 0x50, - 0x96, 0x79, 0xcd, 0x62, 0x91, 0x88, 0x58, 0x96, 0x35, 0xfb, 0x51, 0x15, 0xb1, 0xd8, 0xc5, 0xba, - 0x18, 0x89, 0x9a, 0x2b, 0x8e, 0xf1, 0x45, 0x24, 0x22, 0xd2, 0x44, 0x54, 0x15, 0xe1, 0x27, 0x18, - 0xbc, 0x63, 0x6a, 0xd5, 0x34, 0xe4, 0xe4, 0x27, 0xc9, 0xbf, 0x31, 0x29, 0xf2, 0xcf, 0x8c, 0xb2, - 0xef, 0x7b, 0x26, 0x15, 0x1e, 0x40, 0xbf, 0xe6, 0x5c, 0x6d, 0xcb, 0x5c, 0x96, 0xbe, 0x39, 0x34, - 0x47, 0x2e, 0xed, 0x35, 0x85, 0xf7, 0xb9, 0x2c, 0xf1, 0x0b, 0x70, 0xab, 0x1b, 0xc3, 0xf6, 0x6b, - 0xe1, 0x5b, 0xba, 0xef, 0xdc, 0xd6, 0xb2, 0x22, 0xfc, 0x65, 0xc2, 0xf3, 0xeb, 0xf9, 0x52, 0xf0, - 0x4a, 0x32, 0xfc, 0x06, 0x6c, 0xa9, 0x72, 0xb5, 0x97, 0x3a, 0xfd, 0x51, 0x12, 0x44, 0xf7, 0x97, - 0x8c, 0x56, 0x9a, 0x78, 0xcb, 0x0b, 0x46, 0x2f, 0x34, 0x7e, 0x05, 0x9d, 0x9a, 0x1f, 0xa4, 0x6f, - 0x0d, 0xd1, 0xc8, 0x49, 0x9e, 0x5e, 0x73, 0x51, 0x7e, 0xa0, 0x1a, 0x0a, 0x09, 0x20, 0xca, 0x0f, - 0xf8, 0x09, 0xd8, 0x1a, 0x6b, 0x66, 0xa1, 0x91, 0x4b, 0x2f, 
0x0a, 0xc7, 0xd0, 0x15, 0x35, 0xe7, - 0x5f, 0xf4, 0x01, 0x4e, 0xf2, 0xec, 0x5a, 0xd8, 0x87, 0x06, 0xa0, 0x2d, 0x17, 0xa6, 0xd0, 0xd5, - 0x1a, 0x3f, 0x86, 0xae, 0x54, 0x79, 0xad, 0xf4, 0xf2, 0x88, 0xb6, 0x02, 0x7b, 0x80, 0x58, 0xd5, - 0xfe, 0x0e, 0x44, 0x9b, 0xcf, 0x86, 0x23, 0xbc, 0x60, 0xd2, 0x47, 0x7a, 0x70, 0x2b, 0x5e, 0xce, - 0x01, 0xee, 0x2e, 0xc3, 0x0e, 0x3c, 0xc8, 0xc8, 0xc7, 0xf1, 0x22, 0x9b, 0x7a, 0x06, 0xb6, 0xc1, - 0x5a, 0xce, 0x3d, 0x13, 0x3f, 0x84, 0x3e, 0x59, 0xae, 0xb7, 0xb3, 0xe5, 0x86, 0x4c, 0x3d, 0x0b, - 0xbb, 0xd0, 0xcb, 0xc8, 0x3a, 0xa5, 0x64, 0xbc, 0xf0, 0x50, 0xe3, 0xa0, 0xe9, 0x6c, 0xb3, 0x4a, - 0xa7, 0x5e, 0x67, 0xe2, 0xff, 0x39, 0x05, 0xe6, 0xf1, 0x14, 0x98, 0xff, 0x4e, 0x81, 0xf9, 0xfb, - 0x1c, 0x18, 0xc7, 0x73, 0x60, 0xfc, 0x3d, 0x07, 0xc6, 0xce, 0xd6, 0xaf, 0xff, 0xfa, 0x7f, 0x00, - 0x00, 0x00, 0xff, 0xff, 0x33, 0x30, 0xe8, 0xae, 0x21, 0x02, 0x00, 0x00, + // 374 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x6c, 0x91, 0xc1, 0x6a, 0xdb, 0x40, + 0x10, 0x86, 0x25, 0x6d, 0xad, 0xda, 0x23, 0xb5, 0x88, 0xa5, 0xb4, 0x2a, 0x2e, 0xc2, 0xd5, 0xc9, + 0xb4, 0x20, 0x81, 0x0a, 0x3d, 0x16, 0xec, 0xda, 0x6d, 0x45, 0xcd, 0xba, 0xac, 0xdd, 0xdc, 0x82, + 0x91, 0xa3, 0x0d, 0xca, 0x21, 0xda, 0x8d, 0x76, 0x8d, 0x93, 0x73, 0x5e, 0x20, 0x8f, 0x95, 0xa3, + 0x8f, 0x39, 0x06, 0xfb, 0x45, 0x82, 0x56, 0x4e, 0x72, 0x88, 0x6f, 0xfa, 0xff, 0xf9, 0xe6, 0x9f, + 0x19, 0x2d, 0xf4, 0x64, 0x91, 0x55, 0x2c, 0x16, 0x89, 0x88, 0x65, 0x51, 0xb1, 0xcb, 0x32, 0x8f, + 0xc5, 0x32, 0xd6, 0x66, 0x24, 0x2a, 0xae, 0x38, 0xc6, 0x7b, 0x91, 0x88, 0x48, 0x13, 0x51, 0x99, + 0x87, 0xc7, 0xd0, 0xfd, 0xcd, 0xd4, 0xac, 0x2e, 0xc8, 0xe1, 0x15, 0xc9, 0xce, 0x99, 0x14, 0xd9, + 0x09, 0xa3, 0xec, 0x62, 0xc5, 0xa4, 0xc2, 0x5d, 0xe8, 0x54, 0x9c, 0xab, 0x45, 0x91, 0xc9, 0xc2, + 0x37, 0x7b, 0x66, 0xdf, 0xa5, 0xed, 0xda, 0xf8, 0x93, 0xc9, 0x02, 0x7f, 0x06, 0xb7, 0x7c, 0x6c, + 0x58, 0x9c, 0xe5, 0xbe, 0xa5, 0xeb, 0xce, 0x93, 0x97, 0xe6, 0xe1, 0xb5, 0x09, 0x9f, 0x0e, 0xe7, + 0x4b, 0xc1, 0x4b, 0xc9, 0xf0, 0x77, 0xb0, 0xa5, 0xca, 0xd4, 0x4a, 0xea, 0xf4, 0xb7, 0x49, 0x10, + 0xbd, 0x5c, 0x32, 0x9a, 0x69, 0xe2, 0x27, 0xcf, 0x19, 0xdd, 0xd3, 0xf8, 0x2b, 0xbc, 0xaa, 0xf8, + 0x5a, 0xfa, 0x56, 0x0f, 0xf5, 0x9d, 0xe4, 0xc3, 0xa1, 0x2e, 0xca, 0xd7, 0x54, 0x43, 0x21, 0x01, + 0x44, 0xf9, 0x1a, 0xbf, 0x07, 0x5b, 0x63, 0xf5, 0x2c, 0xd4, 0x77, 0xe9, 0x5e, 0xe1, 0x18, 0x5a, + 0xa2, 0xe2, 0xfc, 0x54, 0x1f, 0xe0, 0x24, 0x1f, 0x0f, 0x85, 0xfd, 0xab, 0x01, 0xda, 0x70, 0xe1, + 0x18, 0x5a, 0x5a, 0xe3, 0x77, 0xd0, 0x92, 0x2a, 0xab, 0x94, 0x5e, 0x1e, 0xd1, 0x46, 0x60, 0x0f, + 0x10, 0x2b, 0x9b, 0xdf, 0x81, 0x68, 0xfd, 0x59, 0x73, 0x84, 0xe7, 0x4c, 0xfa, 0x48, 0x0f, 0x6e, + 0xc4, 0x97, 0x1f, 0x00, 0xcf, 0x97, 0x61, 0x07, 0x5e, 0xa7, 0xe4, 0x68, 0x30, 0x49, 0x47, 0x9e, + 0x81, 0x6d, 0xb0, 0xa6, 0x7f, 0x3d, 0x13, 0xbf, 0x81, 0x0e, 0x99, 0xce, 0x17, 0xbf, 0xa6, 0xff, + 0xc9, 0xc8, 0xb3, 0xb0, 0x0b, 0xed, 0x94, 0xcc, 0xc7, 0x94, 0x0c, 0x26, 0x1e, 0x1a, 0xfa, 0xb7, + 0xdb, 0xc0, 0xdc, 0x6c, 0x03, 0xf3, 0x7e, 0x1b, 0x98, 0x37, 0xbb, 0xc0, 0xd8, 0xec, 0x02, 0xe3, + 0x6e, 0x17, 0x18, 0x4b, 0x5b, 0x3f, 0xf8, 0xb7, 0x87, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x23, + 0xba, 0xf9, 0x14, 0x02, 0x00, 0x00, } func (m *GetSharesByNamespaceRequest) Marshal() (dAtA []byte, err error) { diff --git a/share/p2p/shrexnd/pb/share.proto b/share/p2p/shrexnd/pb/share.proto index 59dac8c82e..d181dd7010 100644 --- a/share/p2p/shrexnd/pb/share.proto +++ b/share/p2p/shrexnd/pb/share.proto @@ -17,7 +17,6 @@ enum StatusCode { OK = 1; 
NOT_FOUND = 2; INTERNAL = 3; - REFUSED = 4; }; message Row { diff --git a/share/p2p/shrexnd/server.go b/share/p2p/shrexnd/server.go index df767b7f9d..54b74ddc1e 100644 --- a/share/p2p/shrexnd/server.go +++ b/share/p2p/shrexnd/server.go @@ -13,8 +13,8 @@ import ( "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/ipld" + "github.com/celestiaorg/celestia-node/share/p2p" pb "github.com/celestiaorg/celestia-node/share/p2p/shrexnd/pb" - "github.com/celestiaorg/go-libp2p-messenger/serde" ) @@ -54,19 +54,22 @@ func NewServer(host host.Host, store *eds.Store, getter share.Getter, opts ...Op } // Start starts the server -func (srv *Server) Start() { +func (srv *Server) Start(context.Context) error { ctx, cancel := context.WithCancel(context.Background()) srv.cancel = cancel - srv.host.SetStreamHandler(srv.protocolID, func(s network.Stream) { + handler := func(s network.Stream) { srv.handleNamespacedData(ctx, s) - }) + } + srv.host.SetStreamHandler(srv.protocolID, p2p.RateLimitMiddleware(handler, srv.params.concurrencyLimit)) + return nil } // Stop stops the server -func (srv *Server) Stop() { +func (srv *Server) Stop(context.Context) error { srv.cancel() srv.host.RemoveStreamHandler(srv.protocolID) + return nil } func (srv *Server) handleNamespacedData(ctx context.Context, stream network.Stream) {
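+ // streams opened beyond params.concurrencyLimit are closed by p2p.RateLimitMiddleware
+ // before they reach this handler; clients surface the resulting io.EOF as p2p.ErrUnavailable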