diff --git a/api/docgen/examples.go b/api/docgen/examples.go index 3456880c4f..b873e7e050 100644 --- a/api/docgen/examples.go +++ b/api/docgen/examples.go @@ -56,7 +56,7 @@ var ExampleValues = map[reflect.Type]interface{}{ reflect.TypeOf(node.Full): node.Full, reflect.TypeOf(auth.Permission("admin")): auth.Permission("admin"), reflect.TypeOf(byzantine.BadEncoding): byzantine.BadEncoding, - reflect.TypeOf((*fraud.Proof)(nil)).Elem(): byzantine.CreateBadEncodingProof( + reflect.TypeOf((*fraud.Proof[*header.ExtendedHeader])(nil)).Elem(): byzantine.CreateBadEncodingProof( []byte("bad encoding proof"), 42, &byzantine.ErrByzantine{ diff --git a/core/exchange.go b/core/exchange.go index bed2404195..f8e1606a3e 100644 --- a/core/exchange.go +++ b/core/exchange.go @@ -78,7 +78,7 @@ func (ce *Exchange) GetVerifiedRange( from *header.ExtendedHeader, amount uint64, ) ([]*header.ExtendedHeader, error) { - headers, err := ce.GetRangeByHeight(ctx, uint64(from.Height())+1, amount) + headers, err := ce.GetRangeByHeight(ctx, from.Height()+1, amount) if err != nil { return nil, err } @@ -115,7 +115,7 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende return nil, fmt.Errorf("extending block data for height %d: %w", &block.Height, err) } // construct extended header - eh, err := ce.construct(ctx, &block.Header, comm, vals, eds) + eh, err := ce.construct(&block.Header, comm, vals, eds) if err != nil { panic(fmt.Errorf("constructing extended header for height %d: %w", &block.Height, err)) } @@ -133,7 +133,10 @@ func (ce *Exchange) Get(ctx context.Context, hash libhead.Hash) (*header.Extende return eh, nil } -func (ce *Exchange) Head(ctx context.Context) (*header.ExtendedHeader, error) { +func (ce *Exchange) Head( + ctx context.Context, + _ ...libhead.HeadOption[*header.ExtendedHeader], +) (*header.ExtendedHeader, error) { log.Debug("requesting head") return ce.getExtendedHeaderByHeight(ctx, nil) } @@ -157,7 +160,7 @@ func (ce *Exchange) getExtendedHeaderByHeight(ctx context.Context, height *int64 return nil, fmt.Errorf("extending block data for height %d: %w", b.Header.Height, err) } // create extended header - eh, err := ce.construct(ctx, &b.Header, &b.Commit, &b.ValidatorSet, eds) + eh, err := ce.construct(&b.Header, &b.Commit, &b.ValidatorSet, eds) if err != nil { panic(fmt.Errorf("constructing extended header for height %d: %w", b.Header.Height, err)) } diff --git a/core/header_test.go b/core/header_test.go index 1c89db9d6b..c942ea7875 100644 --- a/core/header_test.go +++ b/core/header_test.go @@ -33,7 +33,7 @@ func TestMakeExtendedHeaderForEmptyBlock(t *testing.T) { eds, err := extendBlock(b.Data, b.Header.Version.App) require.NoError(t, err) - headerExt, err := header.MakeExtendedHeader(ctx, &b.Header, comm, val, eds) + headerExt, err := header.MakeExtendedHeader(&b.Header, comm, val, eds) require.NoError(t, err) assert.Equal(t, header.EmptyDAH(), *headerExt.DAH) diff --git a/core/listener.go b/core/listener.go index 565fc62032..1c79fbbe71 100644 --- a/core/listener.go +++ b/core/listener.go @@ -160,7 +160,7 @@ func (cl *Listener) handleNewSignedBlock(ctx context.Context, b types.EventDataS return fmt.Errorf("extending block data: %w", err) } // generate extended header - eh, err := cl.construct(ctx, &b.Header, &b.Commit, &b.ValidatorSet, eds) + eh, err := cl.construct(&b.Header, &b.Commit, &b.ValidatorSet, eds) if err != nil { panic(fmt.Errorf("making extended header: %w", err)) } @@ -181,7 +181,7 @@ func (cl *Listener) handleNewSignedBlock(ctx context.Context, b types.EventDataS if !syncing { err = cl.hashBroadcaster(ctx, shrexsub.Notification{ DataHash: eh.DataHash.Bytes(), - Height: uint64(eh.Height()), + Height: eh.Height(), }) if err != nil && !errors.Is(err, context.Canceled) { log.Errorw("listener: broadcasting data hash", diff --git a/core/listener_test.go b/core/listener_test.go index 7d4b12310a..8b3d05bea9 100644 --- a/core/listener_test.go +++ b/core/listener_test.go @@ -31,8 +31,8 @@ func TestListener(t *testing.T) { // create mocknet with two pubsub endpoints ps0, ps1 := createMocknetWithTwoPubsubEndpoints(ctx, t) subscriber := p2p.NewSubscriber[*header.ExtendedHeader](ps1, header.MsgID, networkID) - err := subscriber.AddValidator(func(context.Context, *header.ExtendedHeader) pubsub.ValidationResult { - return pubsub.ValidationAccept + err := subscriber.SetVerifier(func(context.Context, *header.ExtendedHeader) error { + return nil }) require.NoError(t, err) require.NoError(t, subscriber.Start(ctx)) diff --git a/das/coordinator_test.go b/das/coordinator_test.go index 188cb0d222..55ed01dd4e 100644 --- a/das/coordinator_test.go +++ b/das/coordinator_test.go @@ -366,7 +366,7 @@ func (m *mockSampler) sample(ctx context.Context, h *header.ExtendedHeader) erro m.lock.Lock() defer m.lock.Unlock() - height := uint64(h.Height()) + height := h.Height() m.done[height]++ if len(m.done) > int(m.NetworkHead-m.SampleFrom) && !m.isFinished { @@ -503,7 +503,7 @@ func (o *checkOrder) middleWare(out sampleFn) sampleFn { if len(o.queue) > 0 { // check last item in queue to be same as input - if o.queue[0] != uint64(h.Height()) { + if o.queue[0] != h.Height() { defer o.lock.Unlock() return fmt.Errorf("expected height: %v,got: %v", o.queue[0], h.Height()) } @@ -573,7 +573,7 @@ func (l *lock) releaseAll(except ...uint64) { func (l *lock) middleWare(out sampleFn) sampleFn { return func(ctx context.Context, h *header.ExtendedHeader) error { l.m.Lock() - ch, blocked := l.blockList[uint64(h.Height())] + ch, blocked := l.blockList[h.Height()] l.m.Unlock() if !blocked { return out(ctx, h) @@ -589,7 +589,7 @@ func (l *lock) middleWare(out sampleFn) sampleFn { } func onceMiddleWare(out sampleFn) sampleFn { - db := make(map[int64]int) + db := make(map[uint64]int) m := sync.Mutex{} return func(ctx context.Context, h *header.ExtendedHeader) error { m.Lock() diff --git a/das/daser.go b/das/daser.go index d4ad0ee641..9d3e43a91b 100644 --- a/das/daser.go +++ b/das/daser.go @@ -25,7 +25,7 @@ type DASer struct { params Parameters da share.Availability - bcast fraud.Broadcaster + bcast fraud.Broadcaster[*header.ExtendedHeader] hsub libhead.Subscriber[*header.ExtendedHeader] // listens for new headers in the network getter libhead.Getter[*header.ExtendedHeader] // retrieves past headers @@ -47,7 +47,7 @@ func NewDASer( hsub libhead.Subscriber[*header.ExtendedHeader], getter libhead.Getter[*header.ExtendedHeader], dstore datastore.Datastore, - bcast fraud.Broadcaster, + bcast fraud.Broadcaster[*header.ExtendedHeader], shrexBroadcast shrexsub.BroadcastFn, options ...Option, ) (*DASer, error) { @@ -99,7 +99,7 @@ func (d *DASer) Start(ctx context.Context) error { // attempt to get head info. No need to handle error, later DASer // will be able to find new head from subscriber after it is started if h, err := d.getter.Head(ctx); err == nil { - cp.NetworkHead = uint64(h.Height()) + cp.NetworkHead = h.Height() } } log.Info("starting DASer from checkpoint: ", cp.String()) @@ -152,7 +152,7 @@ func (d *DASer) sample(ctx context.Context, h *header.ExtendedHeader) error { var byzantineErr *byzantine.ErrByzantine if errors.As(err, &byzantineErr) { log.Warn("Propagating proof...") - sendErr := d.bcast.Broadcast(ctx, byzantine.CreateBadEncodingProof(h.Hash(), uint64(h.Height()), byzantineErr)) + sendErr := d.bcast.Broadcast(ctx, byzantine.CreateBadEncodingProof(h.Hash(), h.Height(), byzantineErr)) if sendErr != nil { log.Errorw("fraud proof propagating failed", "err", sendErr) } diff --git a/das/daser_test.go b/das/daser_test.go index 7398310a6b..68f6e01ef2 100644 --- a/das/daser_test.go +++ b/das/daser_test.go @@ -159,21 +159,37 @@ func TestDASer_stopsAfter_BEFP(t *testing.T) { mockGet, sub, _ := createDASerSubcomponents(t, bServ, 15, 15) // create fraud service and break one header - getter := func(ctx context.Context, height uint64) (libhead.Header, error) { + getter := func(ctx context.Context, height uint64) (*header.ExtendedHeader, error) { return mockGet.GetByHeight(ctx, height) } - f := fraudserv.NewProofService(ps, net.Hosts()[0], getter, ds, false, "private") - require.NoError(t, f.Start(ctx)) + unmarshaler := fraud.MultiUnmarshaler[*header.ExtendedHeader]{ + Unmarshalers: map[fraud.ProofType]func([]byte) (fraud.Proof[*header.ExtendedHeader], error){ + byzantine.BadEncoding: func(data []byte) (fraud.Proof[*header.ExtendedHeader], error) { + befp := &byzantine.BadEncodingProof{} + return befp, befp.UnmarshalBinary(data) + }, + }, + } + + fserv := fraudserv.NewProofService[*header.ExtendedHeader](ps, + net.Hosts()[0], + getter, + unmarshaler, + ds, + false, + "private", + ) + require.NoError(t, fserv.Start(ctx)) mockGet.headers[1], _ = headertest.CreateFraudExtHeader(t, mockGet.headers[1], bServ) newCtx := context.Background() // create and start DASer - daser, err := NewDASer(avail, sub, mockGet, ds, f, newBroadcastMock(1)) + daser, err := NewDASer(avail, sub, mockGet, ds, fserv, newBroadcastMock(1)) require.NoError(t, err) resultCh := make(chan error) - go fraud.OnProof(newCtx, f, byzantine.BadEncoding, - func(fraud.Proof) { + go fraud.OnProof[*header.ExtendedHeader](newCtx, fserv, byzantine.BadEncoding, + func(fraud.Proof[*header.ExtendedHeader]) { resultCh <- daser.Stop(newCtx) }) @@ -210,10 +226,10 @@ func TestDASerSampleTimeout(t *testing.T) { ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) sub := new(headertest.Subscriber) - f := new(fraudtest.DummyService) + fserv := &fraudtest.DummyService[*header.ExtendedHeader]{} // create and start DASer - daser, err := NewDASer(avail, sub, getter, ds, f, newBroadcastMock(1), WithSampleTimeout(1)) + daser, err := NewDASer(avail, sub, getter, ds, fserv, newBroadcastMock(1), WithSampleTimeout(1)) require.NoError(t, err) require.NoError(t, daser.Start(ctx)) @@ -235,9 +251,9 @@ func createDASerSubcomponents( bServ blockservice.BlockService, numGetter, numSub int, -) (*mockGetter, *headertest.Subscriber, *fraudtest.DummyService) { +) (*mockGetter, *headertest.Subscriber, *fraudtest.DummyService[*header.ExtendedHeader]) { mockGet, sub := createMockGetterAndSub(t, bServ, numGetter, numSub) - fraud := new(fraudtest.DummyService) + fraud := &fraudtest.DummyService[*header.ExtendedHeader]{} return mockGet, sub, fraud } @@ -313,7 +329,10 @@ func (m *mockGetter) generateHeaders(t *testing.T, bServ blockservice.BlockServi m.head = int64(startHeight + endHeight) } -func (m *mockGetter) Head(context.Context) (*header.ExtendedHeader, error) { +func (m *mockGetter) Head( + context.Context, + ...libhead.HeadOption[*header.ExtendedHeader], +) (*header.ExtendedHeader, error) { return m.headers[m.head], nil } @@ -354,7 +373,10 @@ func (m benchGetterStub) GetByHeight(context.Context, uint64) (*header.ExtendedH type getterStub struct{} -func (m getterStub) Head(context.Context) (*header.ExtendedHeader, error) { +func (m getterStub) Head( + context.Context, + ...libhead.HeadOption[*header.ExtendedHeader], +) (*header.ExtendedHeader, error) { return &header.ExtendedHeader{RawHeader: header.RawHeader{Height: 1}}, nil } diff --git a/das/state.go b/das/state.go index 6af0b7d8d8..bd3a018a40 100644 --- a/das/state.go +++ b/das/state.go @@ -132,30 +132,29 @@ func (s *coordinatorState) handleRetryResult(res result) { } } -func (s *coordinatorState) isNewHead(newHead int64) bool { +func (s *coordinatorState) isNewHead(newHead uint64) bool { // seen this header before - if uint64(newHead) <= s.networkHead { + if newHead <= s.networkHead { log.Warnf("received head height: %v, which is lower or the same as previously known: %v", newHead, s.networkHead) return false } return true } -func (s *coordinatorState) updateHead(newHead int64) { +func (s *coordinatorState) updateHead(newHead uint64) { if s.networkHead == s.sampleFrom { log.Infow("found first header, starting sampling") } - s.networkHead = uint64(newHead) + s.networkHead = newHead log.Debugw("updated head", "from_height", s.networkHead, "to_height", newHead) s.checkDone() } // recentJob creates a job to process a recent header. func (s *coordinatorState) recentJob(header *header.ExtendedHeader) job { - height := uint64(header.Height()) // move next, to prevent catchup job from processing same height - if s.next == height { + if s.next == header.Height() { s.next++ } s.nextJobID++ @@ -163,8 +162,8 @@ func (s *coordinatorState) recentJob(header *header.ExtendedHeader) job { id: s.nextJobID, jobType: recentJob, header: header, - from: height, - to: height, + from: header.Height(), + to: header.Height(), } } diff --git a/das/worker.go b/das/worker.go index 746324ec48..f2e8c4d821 100644 --- a/das/worker.go +++ b/das/worker.go @@ -135,7 +135,7 @@ func (w *worker) sample(ctx context.Context, timeout time.Duration, height uint6 if w.state.job.jobType == recentJob { err = w.broadcast(ctx, shrexsub.Notification{ DataHash: h.DataHash.Bytes(), - Height: uint64(h.Height()), + Height: h.Height(), }) if err != nil { log.Warn("failed to broadcast availability message", diff --git a/go.mod b/go.mod index 0d26924835..0bbebc96ed 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,8 @@ require ( github.com/benbjohnson/clock v1.3.5 github.com/celestiaorg/celestia-app v1.0.0-rc12 github.com/celestiaorg/go-ds-badger4 v0.0.0-20230712104058-7ede1c814ac5 - github.com/celestiaorg/go-fraud v0.1.2 - github.com/celestiaorg/go-header v0.2.13 + github.com/celestiaorg/go-fraud v0.2.0 + github.com/celestiaorg/go-header v0.3.0 github.com/celestiaorg/go-libp2p-messenger v0.2.0 github.com/celestiaorg/nmt v0.18.1 github.com/celestiaorg/rsmt2d v0.11.0 diff --git a/go.sum b/go.sum index 6adeef7f7c..09bd8ab26e 100644 --- a/go.sum +++ b/go.sum @@ -366,10 +366,10 @@ github.com/celestiaorg/dagstore v0.0.0-20230824094345-537c012aa403 h1:Lj73O3S+KJ github.com/celestiaorg/dagstore v0.0.0-20230824094345-537c012aa403/go.mod h1:cCGM1UoMvyTk8k62mkc+ReVu8iHBCtSBAAL4wYU7KEI= github.com/celestiaorg/go-ds-badger4 v0.0.0-20230712104058-7ede1c814ac5 h1:MJgXvhJP1Au8rXTvMMlBXodu9jplEK1DxiLtMnEphOs= github.com/celestiaorg/go-ds-badger4 v0.0.0-20230712104058-7ede1c814ac5/go.mod h1:r6xB3nvGotmlTACpAr3SunxtoXeesbqb57elgMJqflY= -github.com/celestiaorg/go-fraud v0.1.2 h1:Bf7yIN3lZ4IR/Vlu5OtmcVCVNESBKEJ/xwu28rRKGA8= -github.com/celestiaorg/go-fraud v0.1.2/go.mod h1:kHZXQY+6gd1kYkoWRFFKgWyrLPWRgDN3vd1Ll9gE/oo= -github.com/celestiaorg/go-header v0.2.13 h1:sUJLXYs8ViPpxLXyIIaW3h4tPFgtVYMhzsLC4GHfS8I= -github.com/celestiaorg/go-header v0.2.13/go.mod h1:NhiWq97NtAYyRBu8quzYOUghQULjgOzO2Ql0iVEFOf0= +github.com/celestiaorg/go-fraud v0.2.0 h1:aaq2JiW0gTnhEdac3l51UCqSyJ4+VjFGTTpN83V4q7I= +github.com/celestiaorg/go-fraud v0.2.0/go.mod h1:lNY1i4K6kUeeE60Z2VK8WXd+qXb8KRzfBhvwPkK6aUc= +github.com/celestiaorg/go-header v0.3.0 h1:9fhxSgldPiWWq3yd9u7oSk5vYqaLV1JkeTnJdGcisFo= +github.com/celestiaorg/go-header v0.3.0/go.mod h1:H8xhnDLDLbkpwmWPhCaZyTnIV3dlVxBHPnxNXS2Qu6c= github.com/celestiaorg/go-libp2p-messenger v0.2.0 h1:/0MuPDcFamQMbw9xTZ73yImqgTO3jHV7wKHvWD/Irao= github.com/celestiaorg/go-libp2p-messenger v0.2.0/go.mod h1:s9PIhMi7ApOauIsfBcQwbr7m+HBzmVfDIS+QLdgzDSo= github.com/celestiaorg/go-verifcid v0.0.1-lazypatch h1:9TSe3w1cmJmbWlweCwCTIZkan7jV8M+KwglXpdD+UG8= diff --git a/header/header.go b/header/header.go index d69b11d998..92f8538696 100644 --- a/header/header.go +++ b/header/header.go @@ -2,12 +2,12 @@ package header import ( "bytes" - "context" "encoding/json" "fmt" "time" tmjson "github.com/tendermint/tendermint/libs/json" + "github.com/tendermint/tendermint/light" core "github.com/tendermint/tendermint/types" "github.com/celestiaorg/celestia-app/pkg/appconsts" @@ -18,7 +18,6 @@ import ( // ConstructFn aliases a function that creates an ExtendedHeader. type ConstructFn = func( - context.Context, *core.Header, *core.Commit, *core.ValidatorSet, @@ -45,31 +44,8 @@ type ExtendedHeader struct { DAH *DataAvailabilityHeader `json:"dah"` } -func (eh *ExtendedHeader) New() libhead.Header { - return new(ExtendedHeader) -} - -func (eh *ExtendedHeader) IsZero() bool { - return eh == nil -} - -func (eh *ExtendedHeader) ChainID() string { - return eh.RawHeader.ChainID -} - -func (eh *ExtendedHeader) Height() int64 { - return eh.RawHeader.Height -} - -func (eh *ExtendedHeader) Time() time.Time { - return eh.RawHeader.Time -} - -var _ libhead.Header = &ExtendedHeader{} - // MakeExtendedHeader assembles new ExtendedHeader. func MakeExtendedHeader( - _ context.Context, h *core.Header, comm *core.Commit, vals *core.ValidatorSet, @@ -95,14 +71,34 @@ func MakeExtendedHeader( Commit: comm, ValidatorSet: vals, } - return eh, eh.Validate() + return eh, nil +} + +func (eh *ExtendedHeader) New() *ExtendedHeader { + return new(ExtendedHeader) +} + +func (eh *ExtendedHeader) IsZero() bool { + return eh == nil +} + +func (eh *ExtendedHeader) ChainID() string { + return eh.RawHeader.ChainID +} + +func (eh *ExtendedHeader) Height() uint64 { + return uint64(eh.RawHeader.Height) +} + +func (eh *ExtendedHeader) Time() time.Time { + return eh.RawHeader.Time } // Hash returns Hash of the wrapped RawHeader. // NOTE: It purposely overrides Hash method of RawHeader to get it directly from Commit without // recomputing. func (eh *ExtendedHeader) Hash() libhead.Hash { - return libhead.Hash(eh.Commit.BlockID.Hash) + return eh.Commit.BlockID.Hash.Bytes() } // LastHeader returns the Hash of the last wrapped RawHeader. @@ -158,7 +154,8 @@ func (eh *ExtendedHeader) Validate() error { return fmt.Errorf("commit signs block %X, header is block %X", chash, hhash) } - if err := eh.ValidatorSet.VerifyCommitLight(eh.ChainID(), eh.Commit.BlockID, eh.Height(), eh.Commit); err != nil { + err = eh.ValidatorSet.VerifyCommitLight(eh.ChainID(), eh.Commit.BlockID, int64(eh.Height()), eh.Commit) + if err != nil { return fmt.Errorf("VerifyCommitLight error at height %d: %w", eh.Height(), err) } @@ -169,6 +166,42 @@ func (eh *ExtendedHeader) Validate() error { return nil } +// Verify validates given untrusted Header against trusted ExtendedHeader. +func (eh *ExtendedHeader) Verify(untrst *ExtendedHeader) error { + isAdjacent := eh.Height()+1 == untrst.Height() + if isAdjacent { + // Optimized verification for adjacent headers + // Check the validator hashes are the same + if !bytes.Equal(untrst.ValidatorsHash, eh.NextValidatorsHash) { + return &libhead.VerifyError{ + Reason: fmt.Errorf("expected old header next validators (%X) to match those from new header (%X)", + eh.NextValidatorsHash, + untrst.ValidatorsHash, + ), + } + } + + if !bytes.Equal(untrst.LastHeader(), eh.Hash()) { + return &libhead.VerifyError{ + Reason: fmt.Errorf("expected new header to point to last header hash (%X), but got %X)", + eh.Hash(), + untrst.LastHeader(), + ), + } + } + + return nil + } + + if err := eh.ValidatorSet.VerifyCommitLightTrusting(eh.ChainID(), untrst.Commit, light.DefaultTrustLevel); err != nil { + return &libhead.VerifyError{ + Reason: err, + SoftFailure: true, + } + } + return nil +} + // MarshalBinary marshals ExtendedHeader to binary. func (eh *ExtendedHeader) MarshalBinary() ([]byte, error) { return MarshalExtendedHeader(eh) @@ -240,3 +273,5 @@ func (eh *ExtendedHeader) UnmarshalJSON(data []byte) error { eh.RawHeader = *rawHeader return nil } + +var _ libhead.Header[*ExtendedHeader] = &ExtendedHeader{} diff --git a/header/headertest/testing.go b/header/headertest/testing.go index b20d389452..65ae8c950f 100644 --- a/header/headertest/testing.go +++ b/header/headertest/testing.go @@ -158,9 +158,9 @@ func (s *TestSuite) NextHeader() *header.ExtendedHeader { } func (s *TestSuite) GenRawHeader( - height int64, lastHeader, lastCommit, dataHash libhead.Hash) *header.RawHeader { + height uint64, lastHeader, lastCommit, dataHash libhead.Hash) *header.RawHeader { rh := RandRawHeader(s.t) - rh.Height = height + rh.Height = int64(height) rh.Time = time.Now() rh.LastBlockID = types.BlockID{Hash: bytes.HexBytes(lastHeader)} rh.LastCommitHash = bytes.HexBytes(lastCommit) @@ -299,7 +299,7 @@ func RandBlockID(*testing.T) types.BlockID { // FraudMaker creates a custom ConstructFn that breaks the block at the given height. func FraudMaker(t *testing.T, faultHeight int64, bServ blockservice.BlockService) header.ConstructFn { log.Warn("Corrupting block...", "height", faultHeight) - return func(ctx context.Context, + return func( h *types.Header, comm *types.Commit, vals *types.ValidatorSet, @@ -318,7 +318,7 @@ func FraudMaker(t *testing.T, faultHeight int64, bServ blockservice.BlockService } return eh, nil } - return header.MakeExtendedHeader(ctx, h, comm, vals, eds) + return header.MakeExtendedHeader(h, comm, vals, eds) } } diff --git a/header/headertest/verify_test.go b/header/headertest/verify_test.go index 33bcf72642..7ef16afc8d 100644 --- a/header/headertest/verify_test.go +++ b/header/headertest/verify_test.go @@ -3,34 +3,32 @@ package headertest import ( "strconv" "testing" - "time" "github.com/stretchr/testify/assert" tmrand "github.com/tendermint/tendermint/libs/rand" - "github.com/celestiaorg/celestia-app/pkg/appconsts" - libhead "github.com/celestiaorg/go-header" + "github.com/celestiaorg/celestia-node/header" ) func TestVerify(t *testing.T) { h := NewTestSuite(t, 2).GenExtendedHeaders(3) trusted, untrustedAdj, untrustedNonAdj := h[0], h[1], h[2] tests := []struct { - prepare func() libhead.Header + prepare func() *header.ExtendedHeader err bool }{ { - prepare: func() libhead.Header { return untrustedAdj }, + prepare: func() *header.ExtendedHeader { return untrustedAdj }, err: false, }, { - prepare: func() libhead.Header { + prepare: func() *header.ExtendedHeader { return untrustedNonAdj }, err: false, }, { - prepare: func() libhead.Header { + prepare: func() *header.ExtendedHeader { untrusted := *untrustedAdj untrusted.ValidatorsHash = tmrand.Bytes(32) return &untrusted @@ -38,7 +36,7 @@ func TestVerify(t *testing.T) { err: true, }, { - prepare: func() libhead.Header { + prepare: func() *header.ExtendedHeader { untrusted := *untrustedAdj untrusted.RawHeader.LastBlockID.Hash = tmrand.Bytes(32) return &untrusted @@ -46,37 +44,10 @@ func TestVerify(t *testing.T) { err: true, }, { - prepare: func() libhead.Header { - untrustedAdj.RawHeader.Time = untrustedAdj.RawHeader.Time.Add(time.Minute) - return untrustedAdj - }, - err: true, - }, - { - prepare: func() libhead.Header { - untrustedAdj.RawHeader.Time = untrustedAdj.RawHeader.Time.Truncate(time.Hour) - return untrustedAdj - }, - err: true, - }, - { - prepare: func() libhead.Header { - untrustedAdj.RawHeader.ChainID = "toaster" - return untrustedAdj - }, - err: true, - }, - { - prepare: func() libhead.Header { - untrustedAdj.RawHeader.Height++ - return untrustedAdj - }, - err: true, - }, - { - prepare: func() libhead.Header { - untrustedAdj.RawHeader.Version.App = appconsts.LatestVersion + 1 - return untrustedAdj + prepare: func() *header.ExtendedHeader { + untrusted := *untrustedNonAdj + untrusted.Commit = NewTestSuite(t, 2).Commit(RandRawHeader(t)) + return &untrusted }, err: true, }, diff --git a/header/serde.go b/header/serde.go index f4763e3b3b..a511a1352b 100644 --- a/header/serde.go +++ b/header/serde.go @@ -61,7 +61,7 @@ func UnmarshalExtendedHeader(data []byte) (*ExtendedHeader, error) { return nil, err } - return out, out.Validate() + return out, nil } func ExtendedHeaderToProto(eh *ExtendedHeader) (*header_pb.ExtendedHeader, error) { diff --git a/header/verify.go b/header/verify.go deleted file mode 100644 index 827f6c1d1b..0000000000 --- a/header/verify.go +++ /dev/null @@ -1,76 +0,0 @@ -package header - -import ( - "bytes" - "fmt" - "time" - - libhead "github.com/celestiaorg/go-header" -) - -// Verify validates given untrusted Header against trusted ExtendedHeader. -func (eh *ExtendedHeader) Verify(untrusted libhead.Header) error { - untrst, ok := untrusted.(*ExtendedHeader) - if !ok { - // if the header of the type was given, something very wrong happens - panic(fmt.Sprintf("invalid header type: expected %T, got %T", eh, untrusted)) - } - - if err := eh.verify(untrst); err != nil { - return &libhead.VerifyError{Reason: err} - } - - isAdjacent := eh.Height()+1 == untrst.Height() - if isAdjacent { - // Optimized verification for adjacent headers - // Check the validator hashes are the same - if !bytes.Equal(untrst.ValidatorsHash, eh.NextValidatorsHash) { - return &libhead.VerifyError{ - Reason: fmt.Errorf("expected old header next validators (%X) to match those from new header (%X)", - eh.NextValidatorsHash, - untrst.ValidatorsHash, - ), - } - } - - if !bytes.Equal(untrst.LastHeader(), eh.Hash()) { - return &libhead.VerifyError{ - Reason: fmt.Errorf("expected new header to point to last header hash (%X), but got %X)", - eh.Hash(), - untrst.LastHeader(), - ), - } - } - - return nil - } - - return nil -} - -// clockDrift defines how much new header's time can drift into -// the future relative to the now time during verification. -var clockDrift = 10 * time.Second - -// verify performs basic verification of untrusted header. -func (eh *ExtendedHeader) verify(untrst libhead.Header) error { - if untrst.Height() <= eh.Height() { - return fmt.Errorf("untrusted header height(%d) <= current trusted header(%d)", untrst.Height(), eh.Height()) - } - - if untrst.ChainID() != eh.ChainID() { - return fmt.Errorf("untrusted header has different chain %s, not %s", untrst.ChainID(), eh.ChainID()) - } - - if !untrst.Time().After(eh.Time()) { - return fmt.Errorf("untrusted header time(%v) must be after current trusted header(%v)", untrst.Time(), eh.Time()) - } - - now := time.Now() - if !untrst.Time().Before(now.Add(clockDrift)) { - return fmt.Errorf( - "new untrusted header has a time from the future %v (now: %v, clockDrift: %v)", untrst.Time(), now, clockDrift) - } - - return nil -} diff --git a/nodebuilder/das/constructors.go b/nodebuilder/das/constructors.go index 18f6962f40..7c6b5bed4f 100644 --- a/nodebuilder/das/constructors.go +++ b/nodebuilder/das/constructors.go @@ -42,16 +42,16 @@ func newDASer( hsub libhead.Subscriber[*header.ExtendedHeader], store libhead.Store[*header.ExtendedHeader], batching datastore.Batching, - fraudServ fraud.Service, + fraudServ fraud.Service[*header.ExtendedHeader], bFn shrexsub.BroadcastFn, options ...das.Option, -) (*das.DASer, *modfraud.ServiceBreaker[*das.DASer], error) { +) (*das.DASer, *modfraud.ServiceBreaker[*das.DASer, *header.ExtendedHeader], error) { ds, err := das.NewDASer(da, hsub, store, batching, fraudServ, bFn, options...) if err != nil { return nil, nil, err } - return ds, &modfraud.ServiceBreaker[*das.DASer]{ + return ds, &modfraud.ServiceBreaker[*das.DASer, *header.ExtendedHeader]{ Service: ds, FraudServ: fraudServ, FraudType: byzantine.BadEncoding, diff --git a/nodebuilder/das/module.go b/nodebuilder/das/module.go index 61c935fd40..d9f7e700e2 100644 --- a/nodebuilder/das/module.go +++ b/nodebuilder/das/module.go @@ -6,6 +6,7 @@ import ( "go.uber.org/fx" "github.com/celestiaorg/celestia-node/das" + "github.com/celestiaorg/celestia-node/header" modfraud "github.com/celestiaorg/celestia-node/nodebuilder/fraud" "github.com/celestiaorg/celestia-node/nodebuilder/node" ) @@ -41,10 +42,10 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { baseComponents, fx.Provide(fx.Annotate( newDASer, - fx.OnStart(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*das.DASer]) error { + fx.OnStart(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*das.DASer, *header.ExtendedHeader]) error { return breaker.Start(ctx) }), - fx.OnStop(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*das.DASer]) error { + fx.OnStop(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*das.DASer, *header.ExtendedHeader]) error { return breaker.Stop(ctx) }), )), diff --git a/nodebuilder/fraud/constructors.go b/nodebuilder/fraud/constructors.go index a70ee3e3d4..eee85d4139 100644 --- a/nodebuilder/fraud/constructors.go +++ b/nodebuilder/fraud/constructors.go @@ -1,8 +1,6 @@ package fraud import ( - "context" - "github.com/ipfs/go-datastore" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" @@ -16,32 +14,46 @@ import ( "github.com/celestiaorg/celestia-node/nodebuilder/p2p" ) -func newFraudService(syncerEnabled bool) func( - fx.Lifecycle, - *pubsub.PubSub, - host.Host, - libhead.Store[*header.ExtendedHeader], - datastore.Batching, - p2p.Network, -) (Module, fraud.Service, error) { - return func( - lc fx.Lifecycle, - sub *pubsub.PubSub, - host host.Host, - hstore libhead.Store[*header.ExtendedHeader], - ds datastore.Batching, - network p2p.Network, - ) (Module, fraud.Service, error) { - getter := func(ctx context.Context, height uint64) (libhead.Header, error) { - return hstore.GetByHeight(ctx, height) - } - pservice := fraudserv.NewProofService(sub, host, getter, ds, syncerEnabled, network.String()) - lc.Append(fx.Hook{ - OnStart: pservice.Start, - OnStop: pservice.Stop, - }) - return &Service{ - Service: pservice, - }, pservice, nil - } +func fraudUnmarshaler() fraud.ProofUnmarshaler[*header.ExtendedHeader] { + return defaultProofUnmarshaler +} + +func newFraudServiceWithSync( + lc fx.Lifecycle, + sub *pubsub.PubSub, + host host.Host, + hstore libhead.Store[*header.ExtendedHeader], + registry fraud.ProofUnmarshaler[*header.ExtendedHeader], + ds datastore.Batching, + network p2p.Network, +) (Module, fraud.Service[*header.ExtendedHeader], error) { + syncerEnabled := true + pservice := fraudserv.NewProofService(sub, host, hstore.GetByHeight, registry, ds, syncerEnabled, network.String()) + lc.Append(fx.Hook{ + OnStart: pservice.Start, + OnStop: pservice.Stop, + }) + return &module{ + Service: pservice, + }, pservice, nil +} + +func newFraudServiceWithoutSync( + lc fx.Lifecycle, + sub *pubsub.PubSub, + host host.Host, + hstore libhead.Store[*header.ExtendedHeader], + registry fraud.ProofUnmarshaler[*header.ExtendedHeader], + ds datastore.Batching, + network p2p.Network, +) (Module, fraud.Service[*header.ExtendedHeader], error) { + syncerEnabled := false + pservice := fraudserv.NewProofService(sub, host, hstore.GetByHeight, registry, ds, syncerEnabled, network.String()) + lc.Append(fx.Hook{ + OnStart: pservice.Start, + OnStop: pservice.Stop, + }) + return &module{ + Service: pservice, + }, pservice, nil } diff --git a/nodebuilder/fraud/fraud.go b/nodebuilder/fraud/fraud.go index 8d10d34e88..45c3863d6f 100644 --- a/nodebuilder/fraud/fraud.go +++ b/nodebuilder/fraud/fraud.go @@ -2,8 +2,12 @@ package fraud import ( "context" + "encoding/json" + "errors" "github.com/celestiaorg/go-fraud" + + "github.com/celestiaorg/celestia-node/header" ) var _ Module = (*API)(nil) @@ -35,3 +39,83 @@ func (api *API) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-cha func (api *API) Get(ctx context.Context, proofType fraud.ProofType) ([]Proof, error) { return api.Internal.Get(ctx, proofType) } + +var _ Module = (*module)(nil) + +// module is an implementation of Module that uses fraud.module as a backend. It is used to +// provide fraud proofs as a non-interface type to the API, and wrap fraud.Subscriber with a +// channel of Proofs. +type module struct { + fraud.Service[*header.ExtendedHeader] +} + +func (s *module) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan Proof, error) { + subscription, err := s.Service.Subscribe(proofType) + if err != nil { + return nil, err + } + proofs := make(chan Proof) + go func() { + defer close(proofs) + defer subscription.Cancel() + for { + proof, err := subscription.Proof(ctx) + if err != nil { + if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { + log.Errorw("fetching proof from subscription", "err", err) + } + return + } + select { + case <-ctx.Done(): + return + case proofs <- Proof{Proof: proof}: + } + } + }() + return proofs, nil +} + +func (s *module) Get(ctx context.Context, proofType fraud.ProofType) ([]Proof, error) { + originalProofs, err := s.Service.Get(ctx, proofType) + if err != nil { + return nil, err + } + proofs := make([]Proof, len(originalProofs)) + for i, originalProof := range originalProofs { + proofs[i].Proof = originalProof + } + return proofs, nil +} + +// Proof embeds the fraud.Proof interface type to provide a concrete type for JSON serialization. +type Proof struct { + fraud.Proof[*header.ExtendedHeader] +} + +type fraudProofJSON struct { + ProofType fraud.ProofType `json:"proof_type"` + Data []byte `json:"data"` +} + +func (f *Proof) UnmarshalJSON(data []byte) error { + var fp fraudProofJSON + err := json.Unmarshal(data, &fp) + if err != nil { + return err + } + f.Proof, err = defaultProofUnmarshaler.Unmarshal(fp.ProofType, fp.Data) + return err +} + +func (f *Proof) MarshalJSON() ([]byte, error) { + marshaledProof, err := f.MarshalBinary() + if err != nil { + return nil, err + } + fraudProof := &fraudProofJSON{ + ProofType: f.Type(), + Data: marshaledProof, + } + return json.Marshal(fraudProof) +} diff --git a/nodebuilder/fraud/lifecycle.go b/nodebuilder/fraud/lifecycle.go index cffa4d0f56..1a6702aafa 100644 --- a/nodebuilder/fraud/lifecycle.go +++ b/nodebuilder/fraud/lifecycle.go @@ -2,11 +2,13 @@ package fraud import ( "context" + "errors" "fmt" "github.com/ipfs/go-datastore" "github.com/celestiaorg/go-fraud" + libhead "github.com/celestiaorg/go-header" ) // service defines minimal interface with service lifecycle methods @@ -18,30 +20,30 @@ type service interface { // ServiceBreaker wraps any service with fraud proof subscription of a specific type. // If proof happens the service is Stopped automatically. // TODO(@Wondertan): Support multiple fraud types. -type ServiceBreaker[S service] struct { +type ServiceBreaker[S service, H libhead.Header[H]] struct { Service S FraudType fraud.ProofType - FraudServ fraud.Service + FraudServ fraud.Service[H] ctx context.Context cancel context.CancelFunc - sub fraud.Subscription + sub fraud.Subscription[H] } // Start starts the inner service if there are no fraud proofs stored. // Subscribes for fraud and stops the service whenever necessary. -func (breaker *ServiceBreaker[S]) Start(ctx context.Context) error { +func (breaker *ServiceBreaker[S, H]) Start(ctx context.Context) error { if breaker == nil { return nil } proofs, err := breaker.FraudServ.Get(ctx, breaker.FraudType) - switch err { + switch { default: return fmt.Errorf("getting proof(%s): %w", breaker.FraudType, err) - case nil: - return &fraud.ErrFraudExists{Proof: proofs} - case datastore.ErrNotFound: + case err == nil: + return &fraud.ErrFraudExists[H]{Proof: proofs} + case errors.Is(err, datastore.ErrNotFound): } err = breaker.Service.Start(ctx) @@ -60,7 +62,7 @@ func (breaker *ServiceBreaker[S]) Start(ctx context.Context) error { } // Stop stops the service and cancels subscription. -func (breaker *ServiceBreaker[S]) Stop(ctx context.Context) error { +func (breaker *ServiceBreaker[S, H]) Stop(ctx context.Context) error { if breaker == nil { return nil } @@ -75,13 +77,13 @@ func (breaker *ServiceBreaker[S]) Stop(ctx context.Context) error { return breaker.Service.Stop(ctx) } -func (breaker *ServiceBreaker[S]) awaitProof() { +func (breaker *ServiceBreaker[S, H]) awaitProof() { _, err := breaker.sub.Proof(breaker.ctx) if err != nil { return } - if err := breaker.Stop(breaker.ctx); err != nil && err != context.Canceled { + if err := breaker.Stop(breaker.ctx); err != nil && !errors.Is(err, context.Canceled) { log.Errorw("stopping service: %s", err.Error()) } } diff --git a/nodebuilder/fraud/module.go b/nodebuilder/fraud/module.go index 718b702f84..bf353f63c6 100644 --- a/nodebuilder/fraud/module.go +++ b/nodebuilder/fraud/module.go @@ -6,27 +6,31 @@ import ( "github.com/celestiaorg/go-fraud" + "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/nodebuilder/node" ) var log = logging.Logger("module/fraud") func ConstructModule(tp node.Type) fx.Option { - baseComponent := fx.Provide(func(serv fraud.Service) fraud.Getter { - return serv - }) + baseComponent := fx.Options( + fx.Provide(fraudUnmarshaler), + fx.Provide(func(serv fraud.Service[*header.ExtendedHeader]) fraud.Getter[*header.ExtendedHeader] { + return serv + }), + ) switch tp { case node.Light: return fx.Module( "fraud", baseComponent, - fx.Provide(newFraudService(true)), + fx.Provide(newFraudServiceWithSync), ) case node.Full, node.Bridge: return fx.Module( "fraud", baseComponent, - fx.Provide(newFraudService(false)), + fx.Provide(newFraudServiceWithoutSync), ) default: panic("invalid node type") diff --git a/nodebuilder/fraud/service.go b/nodebuilder/fraud/service.go deleted file mode 100644 index 0337c375ef..0000000000 --- a/nodebuilder/fraud/service.go +++ /dev/null @@ -1,87 +0,0 @@ -package fraud - -import ( - "context" - "encoding/json" - - "github.com/celestiaorg/go-fraud" -) - -var _ Module = (*Service)(nil) - -// Service is an implementation of Module that uses fraud.Service as a backend. It is used to -// provide fraud proofs as a non-interface type to the API, and wrap fraud.Subscriber with a -// channel of Proofs. -type Service struct { - fraud.Service -} - -func (s *Service) Subscribe(ctx context.Context, proofType fraud.ProofType) (<-chan Proof, error) { - subscription, err := s.Service.Subscribe(proofType) - if err != nil { - return nil, err - } - proofs := make(chan Proof) - go func() { - defer close(proofs) - for { - proof, err := subscription.Proof(ctx) - if err != nil { - if err != context.DeadlineExceeded && err != context.Canceled { - log.Errorw("fetching proof from subscription", "err", err) - } - return - } - select { - case <-ctx.Done(): - return - case proofs <- Proof{Proof: proof}: - } - } - }() - return proofs, nil -} - -func (s *Service) Get(ctx context.Context, proofType fraud.ProofType) ([]Proof, error) { - originalProofs, err := s.Service.Get(ctx, proofType) - if err != nil { - return nil, err - } - proofs := make([]Proof, len(originalProofs)) - for i, originalProof := range originalProofs { - proofs[i].Proof = originalProof - } - return proofs, nil -} - -// Proof embeds the fraud.Proof interface type to provide a concrete type for JSON serialization. -type Proof struct { - fraud.Proof -} - -type fraudProofJSON struct { - ProofType fraud.ProofType `json:"proof_type"` - Data []byte `json:"data"` -} - -func (f *Proof) UnmarshalJSON(data []byte) error { - var fp fraudProofJSON - err := json.Unmarshal(data, &fp) - if err != nil { - return err - } - f.Proof, err = fraud.Unmarshal(fp.ProofType, fp.Data) - return err -} - -func (f *Proof) MarshalJSON() ([]byte, error) { - marshaledProof, err := f.MarshalBinary() - if err != nil { - return nil, err - } - fraudProof := &fraudProofJSON{ - ProofType: f.Type(), - Data: marshaledProof, - } - return json.Marshal(fraudProof) -} diff --git a/nodebuilder/fraud/unmarshaler.go b/nodebuilder/fraud/unmarshaler.go new file mode 100644 index 0000000000..d5e0461f01 --- /dev/null +++ b/nodebuilder/fraud/unmarshaler.go @@ -0,0 +1,32 @@ +package fraud + +import ( + "github.com/celestiaorg/go-fraud" + + "github.com/celestiaorg/celestia-node/header" + "github.com/celestiaorg/celestia-node/share/eds/byzantine" +) + +var defaultProofUnmarshaler proofRegistry + +type proofRegistry struct{} + +func (pr proofRegistry) List() []fraud.ProofType { + return []fraud.ProofType{ + byzantine.BadEncoding, + } +} + +func (pr proofRegistry) Unmarshal(proofType fraud.ProofType, data []byte) (fraud.Proof[*header.ExtendedHeader], error) { + switch proofType { + case byzantine.BadEncoding: + befp := &byzantine.BadEncodingProof{} + err := befp.UnmarshalBinary(data) + if err != nil { + return nil, err + } + return befp, nil + default: + return nil, &fraud.ErrNoUnmarshaler{ProofType: proofType} + } +} diff --git a/nodebuilder/header/constructors.go b/nodebuilder/header/constructors.go index 7d70f0f5a8..267f0c30f7 100644 --- a/nodebuilder/header/constructors.go +++ b/nodebuilder/header/constructors.go @@ -15,21 +15,20 @@ import ( "github.com/celestiaorg/go-header/store" "github.com/celestiaorg/go-header/sync" - "github.com/celestiaorg/celestia-node/header" modfraud "github.com/celestiaorg/celestia-node/nodebuilder/fraud" modp2p "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/share/eds/byzantine" ) // newP2PExchange constructs a new Exchange for headers. -func newP2PExchange( +func newP2PExchange[H libhead.Header[H]]( lc fx.Lifecycle, bpeers modp2p.Bootstrappers, network modp2p.Network, host host.Host, conngater *conngater.BasicConnectionGater, cfg Config, -) (libhead.Exchange[*header.ExtendedHeader], error) { +) (libhead.Exchange[H], error) { peers, err := cfg.trustedPeers(bpeers) if err != nil { return nil, err @@ -39,7 +38,7 @@ func newP2PExchange( ids[index] = peer.ID host.Peerstore().AddAddrs(peer.ID, peer.Addrs, peerstore.PermanentAddrTTL) } - exchange, err := p2p.NewExchange[*header.ExtendedHeader](host, ids, conngater, + exchange, err := p2p.NewExchange[H](host, ids, conngater, p2p.WithParams(cfg.Client), p2p.WithNetworkID[p2p.ClientParameters](network.String()), p2p.WithChainID(network.String()), @@ -60,14 +59,14 @@ func newP2PExchange( } // newSyncer constructs new Syncer for headers. -func newSyncer( - ex libhead.Exchange[*header.ExtendedHeader], - fservice libfraud.Service, - store InitStore, - sub libhead.Subscriber[*header.ExtendedHeader], +func newSyncer[H libhead.Header[H]]( + ex libhead.Exchange[H], + fservice libfraud.Service[H], + store InitStore[H], + sub libhead.Subscriber[H], cfg Config, -) (*sync.Syncer[*header.ExtendedHeader], *modfraud.ServiceBreaker[*sync.Syncer[*header.ExtendedHeader]], error) { - syncer, err := sync.NewSyncer[*header.ExtendedHeader](ex, store, sub, +) (*sync.Syncer[H], *modfraud.ServiceBreaker[*sync.Syncer[H], H], error) { + syncer, err := sync.NewSyncer[H](ex, store, sub, sync.WithParams(cfg.Syncer), sync.WithBlockTime(modp2p.BlockTime), ) @@ -75,7 +74,7 @@ func newSyncer( return nil, nil, err } - return syncer, &modfraud.ServiceBreaker[*sync.Syncer[*header.ExtendedHeader]]{ + return syncer, &modfraud.ServiceBreaker[*sync.Syncer[H], H]{ Service: syncer, FraudType: byzantine.BadEncoding, FraudServ: fservice, @@ -84,16 +83,16 @@ func newSyncer( // InitStore is a type representing initialized header store. // NOTE: It is needed to ensure that Store is always initialized before Syncer is started. -type InitStore libhead.Store[*header.ExtendedHeader] +type InitStore[H libhead.Header[H]] libhead.Store[H] // newInitStore constructs an initialized store -func newInitStore( +func newInitStore[H libhead.Header[H]]( lc fx.Lifecycle, cfg Config, net modp2p.Network, - s libhead.Store[*header.ExtendedHeader], - ex libhead.Exchange[*header.ExtendedHeader], -) (InitStore, error) { + s libhead.Store[H], + ex libhead.Exchange[H], +) (InitStore[H], error) { trustedHash, err := cfg.trustedHash(net) if err != nil { return nil, err diff --git a/nodebuilder/header/module.go b/nodebuilder/header/module.go index 77e7c5eb99..5e02e94fe1 100644 --- a/nodebuilder/header/module.go +++ b/nodebuilder/header/module.go @@ -22,7 +22,7 @@ import ( var log = logging.Logger("module/header") -func ConstructModule(tp node.Type, cfg *Config) fx.Option { +func ConstructModule[H libhead.Header[H]](tp node.Type, cfg *Config) fx.Option { // sanitize config values before constructing module cfgErr := cfg.Validate(tp) @@ -31,61 +31,63 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { fx.Error(cfgErr), fx.Provide(newHeaderService), fx.Provide(fx.Annotate( - func(ds datastore.Batching) (libhead.Store[*header.ExtendedHeader], error) { - return store.NewStore[*header.ExtendedHeader](ds, store.WithParams(cfg.Store)) + func(ds datastore.Batching) (libhead.Store[H], error) { + return store.NewStore[H](ds, store.WithParams(cfg.Store)) }, - fx.OnStart(func(ctx context.Context, store libhead.Store[*header.ExtendedHeader]) error { - return store.Start(ctx) + fx.OnStart(func(ctx context.Context, str libhead.Store[H]) error { + s := str.(*store.Store[H]) + return s.Start(ctx) }), - fx.OnStop(func(ctx context.Context, store libhead.Store[*header.ExtendedHeader]) error { - return store.Stop(ctx) + fx.OnStop(func(ctx context.Context, str libhead.Store[H]) error { + s := str.(*store.Store[H]) + return s.Stop(ctx) }), )), - fx.Provide(newInitStore), - fx.Provide(func(subscriber *p2p.Subscriber[*header.ExtendedHeader]) libhead.Subscriber[*header.ExtendedHeader] { + fx.Provide(newInitStore[H]), + fx.Provide(func(subscriber *p2p.Subscriber[H]) libhead.Subscriber[H] { return subscriber }), fx.Provide(fx.Annotate( - newSyncer, + newSyncer[H], fx.OnStart(func( ctx context.Context, - breaker *modfraud.ServiceBreaker[*sync.Syncer[*header.ExtendedHeader]], + breaker *modfraud.ServiceBreaker[*sync.Syncer[H], H], ) error { return breaker.Start(ctx) }), fx.OnStop(func( ctx context.Context, - breaker *modfraud.ServiceBreaker[*sync.Syncer[*header.ExtendedHeader]], + breaker *modfraud.ServiceBreaker[*sync.Syncer[H], H], ) error { return breaker.Stop(ctx) }), )), fx.Provide(fx.Annotate( - func(ps *pubsub.PubSub, network modp2p.Network) *p2p.Subscriber[*header.ExtendedHeader] { - return p2p.NewSubscriber[*header.ExtendedHeader](ps, header.MsgID, network.String()) + func(ps *pubsub.PubSub, network modp2p.Network) *p2p.Subscriber[H] { + return p2p.NewSubscriber[H](ps, header.MsgID, network.String()) }, - fx.OnStart(func(ctx context.Context, sub *p2p.Subscriber[*header.ExtendedHeader]) error { + fx.OnStart(func(ctx context.Context, sub *p2p.Subscriber[H]) error { return sub.Start(ctx) }), - fx.OnStop(func(ctx context.Context, sub *p2p.Subscriber[*header.ExtendedHeader]) error { + fx.OnStop(func(ctx context.Context, sub *p2p.Subscriber[H]) error { return sub.Stop(ctx) }), )), fx.Provide(fx.Annotate( func( host host.Host, - store libhead.Store[*header.ExtendedHeader], + store libhead.Store[H], network modp2p.Network, - ) (*p2p.ExchangeServer[*header.ExtendedHeader], error) { - return p2p.NewExchangeServer[*header.ExtendedHeader](host, store, + ) (*p2p.ExchangeServer[H], error) { + return p2p.NewExchangeServer[H](host, store, p2p.WithParams(cfg.Server), p2p.WithNetworkID[p2p.ServerParameters](network.String()), ) }, - fx.OnStart(func(ctx context.Context, server *p2p.ExchangeServer[*header.ExtendedHeader]) error { + fx.OnStart(func(ctx context.Context, server *p2p.ExchangeServer[H]) error { return server.Start(ctx) }), - fx.OnStop(func(ctx context.Context, server *p2p.ExchangeServer[*header.ExtendedHeader]) error { + fx.OnStop(func(ctx context.Context, server *p2p.ExchangeServer[H]) error { return server.Stop(ctx) }), )), @@ -96,13 +98,13 @@ func ConstructModule(tp node.Type, cfg *Config) fx.Option { return fx.Module( "header", baseComponents, - fx.Provide(newP2PExchange), + fx.Provide(newP2PExchange[H]), ) case node.Bridge: return fx.Module( "header", baseComponents, - fx.Provide(func(subscriber *p2p.Subscriber[*header.ExtendedHeader]) libhead.Broadcaster[*header.ExtendedHeader] { + fx.Provide(func(subscriber *p2p.Subscriber[H]) libhead.Broadcaster[H] { return subscriber }), fx.Supply(header.MakeExtendedHeader), diff --git a/nodebuilder/header/module_test.go b/nodebuilder/header/module_test.go index 89293e4ab4..6a35e35284 100644 --- a/nodebuilder/header/module_test.go +++ b/nodebuilder/header/module_test.go @@ -38,7 +38,7 @@ func TestConstructModule_StoreParams(t *testing.T) { fx.Provide(func() datastore.Batching { return datastore.NewMapDatastore() }), - ConstructModule(node.Light, &cfg), + ConstructModule[*header.ExtendedHeader](node.Light, &cfg), fx.Invoke( func(s libhead.Store[*header.ExtendedHeader]) { ss := s.(*store.Store[*header.ExtendedHeader]) @@ -72,10 +72,10 @@ func TestConstructModule_SyncerParams(t *testing.T) { fx.Provide(func() datastore.Batching { return datastore.NewMapDatastore() }), - fx.Provide(func() fraud.Service { + fx.Provide(func() fraud.Service[*header.ExtendedHeader] { return nil }), - ConstructModule(node.Light, &cfg), + ConstructModule[*header.ExtendedHeader](node.Light, &cfg), fx.Invoke(func(s *sync.Syncer[*header.ExtendedHeader]) { syncer = s }), @@ -100,7 +100,7 @@ func TestConstructModule_ExchangeParams(t *testing.T) { fx.Provide(func() datastore.Batching { return datastore.NewMapDatastore() }), - ConstructModule(node.Light, &cfg), + ConstructModule[*header.ExtendedHeader](node.Light, &cfg), fx.Provide(func(b datastore.Batching) (*conngater.BasicConnectionGater, error) { return conngater.NewBasicConnectionGater(b) }), diff --git a/nodebuilder/header/service.go b/nodebuilder/header/service.go index f410c04f04..2b208cb88d 100644 --- a/nodebuilder/header/service.go +++ b/nodebuilder/header/service.go @@ -65,9 +65,9 @@ func (s *Service) GetByHeight(ctx context.Context, height uint64) (*header.Exten switch { case err != nil: return nil, err - case uint64(head.Height()) == height: + case head.Height() == height: return head, nil - case uint64(head.Height())+1 < height: + case head.Height()+1 < height: return nil, fmt.Errorf("header: given height is from the future: "+ "networkHeight: %d, requestedHeight: %d", head.Height(), height) } @@ -78,10 +78,10 @@ func (s *Service) GetByHeight(ctx context.Context, height uint64) (*header.Exten switch { case err != nil: return nil, err - case uint64(head.Height()) == height: + case head.Height() == height: return head, nil // `+1` allows for one header network lag, e.g. user request header that is milliseconds away - case uint64(head.Height())+1 < height: + case head.Height()+1 < height: return nil, fmt.Errorf("header: syncing in progress: "+ "localHeadHeight: %d, requestedHeight: %d", head.Height(), height) default: diff --git a/nodebuilder/header/service_test.go b/nodebuilder/header/service_test.go index 6493d3d51d..14d5ada87d 100644 --- a/nodebuilder/header/service_test.go +++ b/nodebuilder/header/service_test.go @@ -25,9 +25,9 @@ func TestGetByHeightHandlesError(t *testing.T) { }) } -type errorSyncer[H libhead.Header] struct{} +type errorSyncer[H libhead.Header[H]] struct{} -func (d *errorSyncer[H]) Head(context.Context) (H, error) { +func (d *errorSyncer[H]) Head(context.Context, ...libhead.HeadOption[H]) (H, error) { var zero H return zero, fmt.Errorf("dummy error") } diff --git a/nodebuilder/module.go b/nodebuilder/module.go index 719705e35c..3068113102 100644 --- a/nodebuilder/module.go +++ b/nodebuilder/module.go @@ -5,13 +5,14 @@ import ( "go.uber.org/fx" + "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/libs/fxutil" "github.com/celestiaorg/celestia-node/nodebuilder/blob" "github.com/celestiaorg/celestia-node/nodebuilder/core" "github.com/celestiaorg/celestia-node/nodebuilder/das" "github.com/celestiaorg/celestia-node/nodebuilder/fraud" "github.com/celestiaorg/celestia-node/nodebuilder/gateway" - "github.com/celestiaorg/celestia-node/nodebuilder/header" + modhead "github.com/celestiaorg/celestia-node/nodebuilder/header" "github.com/celestiaorg/celestia-node/nodebuilder/node" "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/nodebuilder/rpc" @@ -46,7 +47,7 @@ func ConstructModule(tp node.Type, network p2p.Network, cfg *Config, store Store // modules provided by the node p2p.ConstructModule(tp, &cfg.P2P), state.ConstructModule(tp, &cfg.State, &cfg.Core), - header.ConstructModule(tp, &cfg.Header), + modhead.ConstructModule[*header.ExtendedHeader](tp, &cfg.Header), share.ConstructModule(tp, &cfg.Share), rpc.ConstructModule(tp, &cfg.RPC), gateway.ConstructModule(tp, &cfg.Gateway), diff --git a/nodebuilder/p2p/pubsub.go b/nodebuilder/p2p/pubsub.go index 0061ab9eea..13d812e3ce 100644 --- a/nodebuilder/p2p/pubsub.go +++ b/nodebuilder/p2p/pubsub.go @@ -18,6 +18,8 @@ import ( "github.com/celestiaorg/go-fraud" "github.com/celestiaorg/go-fraud/fraudserv" headp2p "github.com/celestiaorg/go-header/p2p" + + "github.com/celestiaorg/celestia-node/header" ) func init() { @@ -66,7 +68,7 @@ func pubSub(cfg Config, params pubSubParams) (*pubsub.PubSub, error) { // * https://github.com/libp2p/specs/blob/master/pubsub/gossipsub/gossipsub-v1.1.md#peer-scoring // * lotus // * prysm - topicScores := topicScoreParams(params.Network) + topicScores := topicScoreParams(params) peerScores, err := peerScoreParams(params.Bootstrappers, cfg) if err != nil { return nil, err @@ -105,15 +107,16 @@ type pubSubParams struct { Host hst.Host Bootstrappers Bootstrappers Network Network + Unmarshaler fraud.ProofUnmarshaler[*header.ExtendedHeader] } -func topicScoreParams(network Network) map[string]*pubsub.TopicScoreParams { +func topicScoreParams(params pubSubParams) map[string]*pubsub.TopicScoreParams { mp := map[string]*pubsub.TopicScoreParams{ - headp2p.PubsubTopicID(network.String()): &headp2p.GossibSubScore, + headp2p.PubsubTopicID(params.Network.String()): &headp2p.GossibSubScore, } - for _, pt := range fraud.Registered() { - mp[fraudserv.PubsubTopicID(pt.String(), network.String())] = &fraudserv.GossibSubScore + for _, pt := range params.Unmarshaler.List() { + mp[fraudserv.PubsubTopicID(pt.String(), params.Network.String())] = &fraudserv.GossibSubScore } return mp diff --git a/nodebuilder/settings.go b/nodebuilder/settings.go index 97440fa7dc..d56125209c 100644 --- a/nodebuilder/settings.go +++ b/nodebuilder/settings.go @@ -22,6 +22,7 @@ import ( "github.com/celestiaorg/go-fraud" + "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/nodebuilder/das" modheader "github.com/celestiaorg/celestia-node/nodebuilder/header" "github.com/celestiaorg/celestia-node/nodebuilder/node" @@ -80,7 +81,7 @@ func WithMetrics(metricOpts []otlpmetrichttp.Option, nodeType node.Type) fx.Opti } state.WithMetrics(ca) }), - fx.Invoke(fraud.WithMetrics), + fx.Invoke(fraud.WithMetrics[*header.ExtendedHeader]), fx.Invoke(node.WithMetrics), fx.Invoke(modheader.WithMetrics), fx.Invoke(share.WithDiscoveryMetrics), diff --git a/nodebuilder/state/core.go b/nodebuilder/state/core.go index 4636e0f099..f8f8508540 100644 --- a/nodebuilder/state/core.go +++ b/nodebuilder/state/core.go @@ -18,11 +18,11 @@ func coreAccessor( corecfg core.Config, signer *apptypes.KeyringSigner, sync *sync.Syncer[*header.ExtendedHeader], - fraudServ libfraud.Service, -) (*state.CoreAccessor, Module, *modfraud.ServiceBreaker[*state.CoreAccessor]) { + fraudServ libfraud.Service[*header.ExtendedHeader], +) (*state.CoreAccessor, Module, *modfraud.ServiceBreaker[*state.CoreAccessor, *header.ExtendedHeader]) { ca := state.NewCoreAccessor(signer, sync, corecfg.IP, corecfg.RPCPort, corecfg.GRPCPort) - return ca, ca, &modfraud.ServiceBreaker[*state.CoreAccessor]{ + return ca, ca, &modfraud.ServiceBreaker[*state.CoreAccessor, *header.ExtendedHeader]{ Service: ca, FraudType: byzantine.BadEncoding, FraudServ: fraudServ, diff --git a/nodebuilder/state/module.go b/nodebuilder/state/module.go index fe90d023eb..733419a918 100644 --- a/nodebuilder/state/module.go +++ b/nodebuilder/state/module.go @@ -6,6 +6,7 @@ import ( logging "github.com/ipfs/go-log/v2" "go.uber.org/fx" + "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/libs/fxutil" "github.com/celestiaorg/celestia-node/nodebuilder/core" modfraud "github.com/celestiaorg/celestia-node/nodebuilder/fraud" @@ -26,10 +27,12 @@ func ConstructModule(tp node.Type, cfg *Config, coreCfg *core.Config) fx.Option fx.Error(cfgErr), fxutil.ProvideIf(coreCfg.IsEndpointConfigured(), fx.Annotate( coreAccessor, - fx.OnStart(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*state.CoreAccessor]) error { + fx.OnStart(func(ctx context.Context, + breaker *modfraud.ServiceBreaker[*state.CoreAccessor, *header.ExtendedHeader]) error { return breaker.Start(ctx) }), - fx.OnStop(func(ctx context.Context, breaker *modfraud.ServiceBreaker[*state.CoreAccessor]) error { + fx.OnStop(func(ctx context.Context, + breaker *modfraud.ServiceBreaker[*state.CoreAccessor, *header.ExtendedHeader]) error { return breaker.Stop(ctx) }), )), diff --git a/nodebuilder/testing.go b/nodebuilder/testing.go index 6cb40a2b6c..36f2c2f47f 100644 --- a/nodebuilder/testing.go +++ b/nodebuilder/testing.go @@ -11,9 +11,10 @@ import ( apptypes "github.com/celestiaorg/celestia-app/x/blob/types" "github.com/celestiaorg/celestia-node/core" + "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/header/headertest" "github.com/celestiaorg/celestia-node/libs/fxutil" - "github.com/celestiaorg/celestia-node/nodebuilder/header" + modhead "github.com/celestiaorg/celestia-node/nodebuilder/header" "github.com/celestiaorg/celestia-node/nodebuilder/node" "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/nodebuilder/state" @@ -47,7 +48,7 @@ func TestNodeWithConfig(t *testing.T, tp node.Type, cfg *Config, opts ...fx.Opti // temp dir for the eds store FIXME: Should be in mem fx.Replace(node.StorePath(t.TempDir())), // avoid requesting trustedPeer during initialization - fxutil.ReplaceAs(headertest.NewStore(t), new(header.InitStore)), + fxutil.ReplaceAs(headertest.NewStore(t), new(modhead.InitStore[*header.ExtendedHeader])), ) // in fact, we don't need core.Client in tests, but Bridge requires is a valid one diff --git a/nodebuilder/tests/api_test.go b/nodebuilder/tests/api_test.go index 1bc1c261de..3a66c4e58c 100644 --- a/nodebuilder/tests/api_test.go +++ b/nodebuilder/tests/api_test.go @@ -80,12 +80,12 @@ func TestGetByHeight(t *testing.T) { networkHead, err := client.Header.NetworkHead(ctx) require.NoError(t, err) - _, err = client.Header.GetByHeight(ctx, uint64(networkHead.Height()+1)) + _, err = client.Header.GetByHeight(ctx, networkHead.Height()+1) require.Nil(t, err, "Requesting syncer.Head()+1 shouldn't return an error") networkHead, err = client.Header.NetworkHead(ctx) require.NoError(t, err) - _, err = client.Header.GetByHeight(ctx, uint64(networkHead.Height()+2)) + _, err = client.Header.GetByHeight(ctx, networkHead.Height()+2) require.ErrorContains(t, err, "given height is from the future") } diff --git a/nodebuilder/tests/sync_test.go b/nodebuilder/tests/sync_test.go index dfa3577599..234556a3aa 100644 --- a/nodebuilder/tests/sync_test.go +++ b/nodebuilder/tests/sync_test.go @@ -295,7 +295,7 @@ func TestSyncLightAgainstFull(t *testing.T) { require.NoError(t, err) bridgeHead, err := bridge.HeaderServ.LocalHead(ctx) require.NoError(t, err) - _, err = full.HeaderServ.WaitForHeight(ctx, uint64(bridgeHead.Height())) + _, err = full.HeaderServ.WaitForHeight(ctx, bridgeHead.Height()) require.NoError(t, err) // reset suite bootstrapper list and set full node as a bootstrapper for @@ -316,7 +316,7 @@ func TestSyncLightAgainstFull(t *testing.T) { require.NoError(t, err) fullHead, err := full.HeaderServ.LocalHead(ctx) require.NoError(t, err) - _, err = light.HeaderServ.WaitForHeight(ctx, uint64(fullHead.Height())) + _, err = light.HeaderServ.WaitForHeight(ctx, fullHead.Height()) require.NoError(t, err) // wait for the core block filling process to exit diff --git a/share/eds/byzantine/bad_encoding.go b/share/eds/byzantine/bad_encoding.go index 3c5bc6951b..e3a862e38a 100644 --- a/share/eds/byzantine/bad_encoding.go +++ b/share/eds/byzantine/bad_encoding.go @@ -7,7 +7,6 @@ import ( "github.com/celestiaorg/celestia-app/pkg/wrapper" "github.com/celestiaorg/go-fraud" - libhead "github.com/celestiaorg/go-header" "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/header" @@ -22,10 +21,6 @@ const ( BadEncoding fraud.ProofType = "badencoding" + version ) -func init() { - fraud.Register(&BadEncodingProof{}) -} - type BadEncodingProof struct { headerHash []byte BlockHeight uint64 @@ -46,8 +41,7 @@ func CreateBadEncodingProof( hash []byte, height uint64, errByzantine *ErrByzantine, -) fraud.Proof { - +) fraud.Proof[*header.ExtendedHeader] { return &BadEncodingProof{ headerHash: hash, BlockHeight: height, @@ -112,34 +106,29 @@ func (p *BadEncodingProof) UnmarshalBinary(data []byte) error { // Validate checks that provided Merkle Proofs correspond to the shares, // rebuilds bad row or col from received shares, computes Merkle Root // and compares it with block's Merkle Root. -func (p *BadEncodingProof) Validate(hdr libhead.Header) error { - header, ok := hdr.(*header.ExtendedHeader) - if !ok { - panic(fmt.Sprintf("invalid header type received during BEFP validation: expected %T, got %T", header, hdr)) - } - - if header.Height() != int64(p.BlockHeight) { +func (p *BadEncodingProof) Validate(hdr *header.ExtendedHeader) error { + if hdr.Height() != p.BlockHeight { return fmt.Errorf("incorrect block height during BEFP validation: expected %d, got %d", - p.BlockHeight, header.Height(), + p.BlockHeight, hdr.Height(), ) } - if len(header.DAH.RowRoots) != len(header.DAH.ColumnRoots) { + if len(hdr.DAH.RowRoots) != len(hdr.DAH.ColumnRoots) { // NOTE: This should never happen as callers of this method should not feed it with a // malformed extended header. panic(fmt.Sprintf( "invalid extended header: length of row and column roots do not match. (rowRoots=%d) (colRoots=%d)", - len(header.DAH.RowRoots), - len(header.DAH.ColumnRoots)), + len(hdr.DAH.RowRoots), + len(hdr.DAH.ColumnRoots)), ) } // merkleRoots are the roots against which we are going to check the inclusion of the received // shares. Changing the order of the roots to prove the shares relative to the orthogonal axis, // because inside the rsmt2d library rsmt2d.Row = 0 and rsmt2d.Col = 1 - merkleRoots := header.DAH.RowRoots + merkleRoots := hdr.DAH.RowRoots if p.Axis == rsmt2d.Row { - merkleRoots = header.DAH.ColumnRoots + merkleRoots = hdr.DAH.ColumnRoots } if int(p.Index) >= len(merkleRoots) { @@ -196,7 +185,7 @@ func (p *BadEncodingProof) Validate(hdr libhead.Header) error { rebuiltShares, err := codec.Decode(shares) if err != nil { log.Infow("failed to decode shares at height", - "height", header.Height(), "err", err, + "height", hdr.Height(), "err", err, ) return nil } @@ -204,7 +193,7 @@ func (p *BadEncodingProof) Validate(hdr libhead.Header) error { rebuiltExtendedShares, err := codec.Encode(rebuiltShares[0:odsWidth]) if err != nil { log.Infow("failed to encode shares at height", - "height", header.Height(), "err", err, + "height", hdr.Height(), "err", err, ) return nil } @@ -215,7 +204,7 @@ func (p *BadEncodingProof) Validate(hdr libhead.Header) error { err = tree.Push(share) if err != nil { log.Infow("failed to build a tree from the reconstructed shares at height", - "height", header.Height(), "err", err, + "height", hdr.Height(), "err", err, ) return nil } @@ -224,15 +213,15 @@ func (p *BadEncodingProof) Validate(hdr libhead.Header) error { expectedRoot, err := tree.Root() if err != nil { log.Infow("failed to build a tree root at height", - "height", header.Height(), "err", err, + "height", hdr.Height(), "err", err, ) return nil } // root is a merkle root of the row/col where ErrByzantine occurred - root := header.DAH.RowRoots[p.Index] + root := hdr.DAH.RowRoots[p.Index] if p.Axis == rsmt2d.Col { - root = header.DAH.ColumnRoots[p.Index] + root = hdr.DAH.ColumnRoots[p.Index] } // comparing rebuilt Merkle Root of bad row/col with respective Merkle Root of row/col from block. diff --git a/share/eds/byzantine/bad_encoding_test.go b/share/eds/byzantine/bad_encoding_test.go index 49cf64c2c2..b5dcea3452 100644 --- a/share/eds/byzantine/bad_encoding_test.go +++ b/share/eds/byzantine/bad_encoding_test.go @@ -86,7 +86,7 @@ func TestIncorrectBadEncodingFraudProof(t *testing.T) { }, } - proof := CreateBadEncodingProof(h.Hash(), uint64(h.Height()), &fakeError) + proof := CreateBadEncodingProof(h.Hash(), h.Height(), &fakeError) err = proof.Validate(h) require.Error(t, err) } diff --git a/share/eds/retriever_test.go b/share/eds/retriever_test.go index ebccf0e384..12b1c11083 100644 --- a/share/eds/retriever_test.go +++ b/share/eds/retriever_test.go @@ -145,7 +145,7 @@ func TestFraudProofValidation(t *testing.T) { faultHeader, err := generateByzantineError(ctx, t, size, bServ) require.True(t, errors.As(err, &errByz)) - p := byzantine.CreateBadEncodingProof([]byte("hash"), uint64(faultHeader.Height()), errByz) + p := byzantine.CreateBadEncodingProof([]byte("hash"), faultHeader.Height(), errByz) err = p.Validate(faultHeader) require.NoError(t, err) }) @@ -197,7 +197,7 @@ func BenchmarkBEFPValidation(b *testing.B) { for i := 0; i < b.N; i++ { b.ReportAllocs() - p := byzantine.CreateBadEncodingProof([]byte("hash"), uint64(h.Height()), errByz) + p := byzantine.CreateBadEncodingProof([]byte("hash"), h.Height(), errByz) err = p.Validate(h) require.NoError(b, err) } diff --git a/share/p2p/discovery/discovery.go b/share/p2p/discovery/discovery.go index c880b9b3c2..f24df2c88b 100644 --- a/share/p2p/discovery/discovery.go +++ b/share/p2p/discovery/discovery.go @@ -199,7 +199,8 @@ func (d *Discovery) Advertise(ctx context.Context) { } // discoveryLoop ensures we always have '~peerLimit' connected peers. -// It initiates peer discovery upon request and restarts the process until the soft limit is reached. +// It initiates peer discovery upon request and restarts the process until the soft limit is +// reached. func (d *Discovery) discoveryLoop(ctx context.Context) { t := time.NewTicker(discoveryRetryTimeout) defer t.Stop() diff --git a/share/p2p/peers/manager.go b/share/p2p/peers/manager.go index 2a7c1fee18..87f9361ee2 100644 --- a/share/p2p/peers/manager.go +++ b/share/p2p/peers/manager.go @@ -293,7 +293,7 @@ func (m *Manager) subscribeHeader(ctx context.Context, headerSub libhead.Subscri m.validatedPool(h.DataHash.String()) // store first header for validation purposes - if m.initialHeight.CompareAndSwap(0, uint64(h.Height())) { + if m.initialHeight.CompareAndSwap(0, h.Height()) { log.Debugw("stored initial height", "height", h.Height()) } } diff --git a/share/p2p/peers/manager_test.go b/share/p2p/peers/manager_test.go index e10e820e84..ad04d2c7bd 100644 --- a/share/p2p/peers/manager_test.go +++ b/share/p2p/peers/manager_test.go @@ -274,7 +274,7 @@ func TestManager(t *testing.T) { // create shrexSub msg with height lower than first header from headerSub msg := shrexsub.Notification{ DataHash: share.DataHash("datahash"), - Height: uint64(h.Height() - 1), + Height: h.Height() - 1, } result := manager.Validate(ctx, "peer", msg) require.Equal(t, pubsub.ValidationIgnore, result) @@ -298,7 +298,7 @@ func TestManager(t *testing.T) { // create shrexSub msg with height lower than first header from headerSub msg := shrexsub.Notification{ DataHash: share.DataHash("datahash"), - Height: uint64(h.Height() - 1), + Height: h.Height() - 1, } result := manager.Validate(ctx, "peer", msg) require.Equal(t, pubsub.ValidationIgnore, result) @@ -537,7 +537,7 @@ func (s *subLock) Subscribe() (libhead.Subscription[*header.ExtendedHeader], err return s, nil } -func (s *subLock) AddValidator(func(context.Context, *header.ExtendedHeader) pubsub.ValidationResult) error { +func (s *subLock) SetVerifier(func(context.Context, *header.ExtendedHeader) error) error { panic("implement me") } @@ -561,6 +561,6 @@ func (s *subLock) Cancel() { func newShrexSubMsg(h *header.ExtendedHeader) shrexsub.Notification { return shrexsub.Notification{ DataHash: h.DataHash.Bytes(), - Height: uint64(h.Height()), + Height: h.Height(), } } diff --git a/state/core_access.go b/state/core_access.go index 43f5313c5f..d8f6894e24 100644 --- a/state/core_access.go +++ b/state/core_access.go @@ -180,9 +180,10 @@ func (ca *CoreAccessor) constructSignedTx( return ca.signer.EncodeTx(tx) } -// SubmitPayForBlob builds, signs, and synchronously submits a MsgPayForBlob. It blocks until the transaction -// is committed and returns the TxReponse. If gasLim is set to 0, the method will automatically estimate the -// gas limit. If the fee is negative, the method will use the nodes min gas price multiplied by the gas limit. +// SubmitPayForBlob builds, signs, and synchronously submits a MsgPayForBlob. It blocks until the +// transaction is committed and returns the TxReponse. If gasLim is set to 0, the method will +// automatically estimate the gas limit. If the fee is negative, the method will use the nodes min +// gas price multiplied by the gas limit. func (ca *CoreAccessor) SubmitPayForBlob( ctx context.Context, fee Int, @@ -201,8 +202,8 @@ func (ca *CoreAccessor) SubmitPayForBlob( appblobs[i] = &b.Blob } - // we only estimate gas if the user wants us to (by setting the gasLim to 0). In the future we may want - // to make these arguments optional. + // we only estimate gas if the user wants us to (by setting the gasLim to 0). In the future we may + // want to make these arguments optional. if gasLim == 0 { blobSizes := make([]uint32, len(blobs)) for i, blob := range blobs { @@ -294,7 +295,7 @@ func (ca *CoreAccessor) BalanceForAddress(ctx context.Context, addr Address) (*B abciReq := abci.RequestQuery{ // TODO @renayay: once https://github.com/cosmos/cosmos-sdk/pull/12674 is merged, use const instead Path: fmt.Sprintf("store/%s/key", banktypes.StoreKey), - Height: head.Height() - 1, + Height: int64(head.Height() - 1), Data: prefixedAccountKey, Prove: true, } diff --git a/state/integration_test.go b/state/integration_test.go index 8862de1bf8..193e7bddc7 100644 --- a/state/integration_test.go +++ b/state/integration_test.go @@ -20,6 +20,7 @@ import ( "github.com/celestiaorg/celestia-app/test/util/testfactory" "github.com/celestiaorg/celestia-app/test/util/testnode" blobtypes "github.com/celestiaorg/celestia-app/x/blob/types" + libhead "github.com/celestiaorg/go-header" "github.com/celestiaorg/celestia-node/core" "github.com/celestiaorg/celestia-node/header" @@ -95,7 +96,10 @@ type localHeader struct { client rpcclient.Client } -func (l localHeader) Head(ctx context.Context) (*header.ExtendedHeader, error) { +func (l localHeader) Head( + ctx context.Context, + _ ...libhead.HeadOption[*header.ExtendedHeader], +) (*header.ExtendedHeader, error) { latest, err := l.client.Block(ctx, nil) if err != nil { return nil, err