diff --git a/nodebuilder/share/module.go b/nodebuilder/share/module.go index 58a692095a..15785dc9a3 100644 --- a/nodebuilder/share/module.go +++ b/nodebuilder/share/module.go @@ -4,13 +4,16 @@ import ( "context" "github.com/ipfs/go-datastore" + "github.com/libp2p/go-libp2p-core/host" "go.uber.org/fx" "github.com/celestiaorg/celestia-node/nodebuilder/node" + modp2p "github.com/celestiaorg/celestia-node/nodebuilder/p2p" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/availability/full" "github.com/celestiaorg/celestia-node/share/availability/light" "github.com/celestiaorg/celestia-node/share/eds" + "github.com/celestiaorg/celestia-node/share/p2p/shrexeds" ) func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option { @@ -48,6 +51,23 @@ func ConstructModule(tp node.Type, cfg *Config, options ...fx.Option) fx.Option return fx.Module( "share", baseComponents, + fx.Provide(fx.Annotate( + func(host host.Host, store *eds.Store, network modp2p.Network) (*shrexeds.Server, error) { + return shrexeds.NewServer(host, store, shrexeds.WithProtocolSuffix(string(network))) + }, + fx.OnStart(func(ctx context.Context, server *shrexeds.Server) error { + return server.Start(ctx) + }), + fx.OnStop(func(ctx context.Context, server *shrexeds.Server) error { + return server.Stop(ctx) + }), + )), + // Bridge Nodes need a client as well, for requests over FullAvailability + fx.Provide( + func(host host.Host, network modp2p.Network) (*shrexeds.Client, error) { + return shrexeds.NewClient(host, shrexeds.WithProtocolSuffix(string(network))) + }, + ), fx.Provide(fx.Annotate( func(path node.StorePath, ds datastore.Batching) (*eds.Store, error) { return eds.NewStore(string(path), ds) diff --git a/share/p2p/shrexeds/client.go b/share/p2p/shrexeds/client.go new file mode 100644 index 0000000000..06c04c7ec3 --- /dev/null +++ b/share/p2p/shrexeds/client.go @@ -0,0 +1,147 @@ +package shrexeds + +import ( + "context" + "errors" + "fmt" + "net" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" + p2p_pb "github.com/celestiaorg/celestia-node/share/p2p/shrexeds/pb" + "github.com/celestiaorg/go-libp2p-messenger/serde" + "github.com/celestiaorg/rsmt2d" +) + +var errNoMorePeers = errors.New("all peers returned invalid responses") + +// Client is responsible for requesting EDSs for blocksync over the ShrEx/EDS protocol. +type Client struct { + protocolID protocol.ID + host host.Host +} + +// NewClient creates a new ShrEx/EDS client. +func NewClient(host host.Host, opts ...Option) (*Client, error) { + params := DefaultParameters() + for _, opt := range opts { + opt(params) + } + + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-eds: client creation failed: %w", err) + } + + return &Client{ + host: host, + protocolID: protocolID(params.protocolSuffix), + }, nil +} + +// RequestEDS requests the full ODS from one of the given peers and returns the EDS. +// +// The peers are requested in a round-robin manner with retries until one of them gives a valid +// response, blocking until the context is canceled or a valid response is given. +func (c *Client) RequestEDS( + ctx context.Context, + dataHash share.DataHash, + peers peer.IDSlice, +) (*rsmt2d.ExtendedDataSquare, error) { + req := &p2p_pb.EDSRequest{Hash: dataHash} + + // requests are retried for every peer until a valid response is received + excludedPeers := make(map[peer.ID]struct{}) + for { + // if no peers are left, return + if len(peers) == len(excludedPeers) { + return nil, errNoMorePeers + } + + for _, to := range peers { + // skip over excluded peers + if _, ok := excludedPeers[to]; ok { + continue + } + eds, err := c.doRequest(ctx, req, to) + if eds != nil { + return eds, err + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return nil, ctx.Err() + } + // some net.Errors also mean the context deadline was exceeded, but yamux/mocknet do not + // unwrap to a ctx err + var ne net.Error + if errors.As(err, &ne) && ne.Timeout() { + return nil, context.DeadlineExceeded + } + if err != nil { + // peer has misbehaved, exclude them from round-robin + excludedPeers[to] = struct{}{} + log.Errorw("client: eds request to peer failed", "peer", to, "hash", dataHash.String()) + } + + // no eds was found, continue + } + } +} + +func (c *Client) doRequest( + ctx context.Context, + req *p2p_pb.EDSRequest, + to peer.ID, +) (*rsmt2d.ExtendedDataSquare, error) { + dataHash := share.DataHash(req.Hash) + log.Debugf("client: requesting eds %s from peer %s", dataHash.String(), to) + stream, err := c.host.NewStream(ctx, to, c.protocolID) + if err != nil { + return nil, fmt.Errorf("failed to open stream: %w", err) + } + if dl, ok := ctx.Deadline(); ok { + if err = stream.SetDeadline(dl); err != nil { + log.Debugw("error setting deadline: %s", err) + } + } + + // request ODS + _, err = serde.Write(stream, req) + if err != nil { + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("failed to write request to stream: %w", err) + } + err = stream.CloseWrite() + if err != nil { + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("failed to close write on stream: %w", err) + } + + // read and parse status from peer + resp := new(p2p_pb.EDSResponse) + _, err = serde.Read(stream, resp) + if err != nil { + stream.Reset() //nolint:errcheck + return nil, fmt.Errorf("failed to read status from stream: %w", err) + } + + switch resp.Status { + case p2p_pb.Status_OK: + // use header and ODS bytes to construct EDS and verify it against dataHash + eds, err := eds.ReadEDS(ctx, stream, dataHash) + if err != nil { + return nil, fmt.Errorf("failed to read eds from ods bytes: %w", err) + } + return eds, nil + case p2p_pb.Status_NOT_FOUND, p2p_pb.Status_REFUSED: + log.Debugf("client: peer %s couldn't serve eds %s with status %s", to.String(), dataHash.String(), resp.GetStatus()) + // no eds was returned, but the request was valid and should be retried + return nil, nil + case p2p_pb.Status_INVALID: + fallthrough + default: + return nil, fmt.Errorf("request status %s returned for root %s", resp.GetStatus(), dataHash.String()) + } +} diff --git a/share/p2p/shrexeds/exchange_test.go b/share/p2p/shrexeds/exchange_test.go new file mode 100644 index 0000000000..95edebd543 --- /dev/null +++ b/share/p2p/shrexeds/exchange_test.go @@ -0,0 +1,115 @@ +package shrexeds + +import ( + "context" + "testing" + "time" + + "github.com/ipfs/go-datastore" + ds_sync "github.com/ipfs/go-datastore/sync" + libhost "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/celestiaorg/celestia-app/pkg/da" + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" +) + +func TestExchange_RequestEDS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + store, client, server := makeExchange(t) + + err := store.Start(ctx) + require.NoError(t, err) + + err = server.Start(ctx) + require.NoError(t, err) + + // Testcase: EDS is immediately available + t.Run("EDS_Available", func(t *testing.T) { + eds := share.RandEDS(t, 4) + dah := da.NewDataAvailabilityHeader(eds) + err = store.Put(ctx, dah.Hash(), eds) + require.NoError(t, err) + + requestedEDS, err := client.RequestEDS(ctx, dah.Hash(), []peer.ID{server.host.ID()}) + assert.NoError(t, err) + assert.Equal(t, eds.Flattened(), requestedEDS.Flattened()) + + }) + + // Testcase: EDS is unavailable initially, but is found after multiple requests + t.Run("EDS_AvailableAfterDelay", func(t *testing.T) { + storageDelay := time.Second + eds := share.RandEDS(t, 4) + dah := da.NewDataAvailabilityHeader(eds) + go func() { + time.Sleep(storageDelay) + err = store.Put(ctx, dah.Hash(), eds) + // require.NoError(t, err) + }() + + now := time.Now() + requestedEDS, err := client.RequestEDS(ctx, dah.Hash(), []peer.ID{server.host.ID()}) + finished := time.Now() + + assert.Greater(t, finished.Sub(now), storageDelay) + assert.NoError(t, err) + assert.Equal(t, eds.Flattened(), requestedEDS.Flattened()) + }) + + // Testcase: Invalid request excludes peer from round-robin, stopping request + t.Run("EDS_InvalidRequest", func(t *testing.T) { + dataHash := []byte("invalid") + requestedEDS, err := client.RequestEDS(ctx, dataHash, []peer.ID{server.host.ID()}) + assert.ErrorIs(t, err, errNoMorePeers) + assert.Nil(t, requestedEDS) + }) + + // Testcase: Valid request, which server cannot serve, waits forever + t.Run("EDS_ValidTimeout", func(t *testing.T) { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + t.Cleanup(cancel) + eds := share.RandEDS(t, 4) + dah := da.NewDataAvailabilityHeader(eds) + requestedEDS, err := client.RequestEDS(timeoutCtx, dah.Hash(), []peer.ID{server.host.ID()}) + assert.ErrorIs(t, err, timeoutCtx.Err()) + assert.Nil(t, requestedEDS) + }) +} + +func newStore(t *testing.T) *eds.Store { + t.Helper() + + tmpDir := t.TempDir() + ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) + store, err := eds.NewStore(tmpDir, ds) + require.NoError(t, err) + return store +} + +func createMocknet(t *testing.T, amount int) []libhost.Host { + t.Helper() + + net, err := mocknet.FullMeshConnected(amount) + require.NoError(t, err) + // get host and peer + return net.Hosts() +} + +func makeExchange(t *testing.T) (*eds.Store, *Client, *Server) { + t.Helper() + store := newStore(t) + hosts := createMocknet(t, 2) + + client, err := NewClient(hosts[0]) + require.NoError(t, err) + server, err := NewServer(hosts[1], store) + require.NoError(t, err) + + return store, client, server +} diff --git a/share/p2p/shrexeds/options.go b/share/p2p/shrexeds/options.go new file mode 100644 index 0000000000..23c2d54ef4 --- /dev/null +++ b/share/p2p/shrexeds/options.go @@ -0,0 +1,74 @@ +package shrexeds + +import ( + "fmt" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p-core/protocol" +) + +const protocolPrefix = "/shrex/eds/v0.0.1/" + +var log = logging.Logger("shrex-eds") + +// Option is the functional option that is applied to the shrex/eds protocol to configure its +// parameters. +type Option func(*Parameters) + +// Parameters is the set of parameters that must be configured for the shrex/eds protocol. +type Parameters struct { + // ReadDeadline sets the timeout for reading messages from the stream. + ReadDeadline time.Duration + + // WriteDeadline sets the timeout for writing messages to the stream. + WriteDeadline time.Duration + + // ReadCARDeadline defines the deadline for reading a CAR from disk. + ReadCARDeadline time.Duration + + // BufferSize defines the size of the buffer used for writing an ODS over the stream. + BufferSize uint64 + + // protocolSuffix is appended to the protocolID and represents the network the protocol is + // running on. + protocolSuffix string +} + +func DefaultParameters() *Parameters { + return &Parameters{ + ReadDeadline: time.Minute, + WriteDeadline: time.Second * 5, + ReadCARDeadline: time.Minute, + BufferSize: 32 * 1024, + } +} + +const errSuffix = "value should be positive and non-zero" + +func (p *Parameters) Validate() error { + if p.ReadDeadline <= 0 { + return fmt.Errorf("invalid stream read deadline: %s", errSuffix) + } + if p.WriteDeadline <= 0 { + return fmt.Errorf("invalid write deadline: %s", errSuffix) + } + if p.ReadCARDeadline <= 0 { + return fmt.Errorf("invalid read CAR deadline: %s", errSuffix) + } + if p.BufferSize <= 0 { + return fmt.Errorf("invalid buffer size: %s", errSuffix) + } + return nil +} + +// WithProtocolSuffix is a functional option that configures the `protocolSuffix` parameter +func WithProtocolSuffix(protocolSuffix string) Option { + return func(parameters *Parameters) { + parameters.protocolSuffix = protocolSuffix + } +} + +func protocolID(protocolSuffix string) protocol.ID { + return protocol.ID(fmt.Sprintf("%s%s", protocolPrefix, protocolSuffix)) +} diff --git a/share/p2p/shrexeds/pb/extended_data_square.pb.go b/share/p2p/shrexeds/pb/extended_data_square.pb.go new file mode 100644 index 0000000000..9492cb298b --- /dev/null +++ b/share/p2p/shrexeds/pb/extended_data_square.pb.go @@ -0,0 +1,508 @@ +// Code generated by protoc-gen-gogo. DO NOT EDIT. +// source: share/eds/p2p/pb/extended_data_square.proto + +package extended_data_square + +import ( + fmt "fmt" + proto "github.com/gogo/protobuf/proto" + io "io" + math "math" + math_bits "math/bits" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package + +type Status int32 + +const ( + Status_INVALID Status = 0 + Status_OK Status = 100 + Status_NOT_FOUND Status = 200 + Status_REFUSED Status = 201 +) + +var Status_name = map[int32]string{ + 0: "INVALID", + 100: "OK", + 200: "NOT_FOUND", + 201: "REFUSED", +} + +var Status_value = map[string]int32{ + "INVALID": 0, + "OK": 100, + "NOT_FOUND": 200, + "REFUSED": 201, +} + +func (x Status) String() string { + return proto.EnumName(Status_name, int32(x)) +} + +func (Status) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_e8ddcd8d207cc22e, []int{0} +} + +type EDSRequest struct { + Hash []byte `protobuf:"bytes,1,opt,name=hash,proto3" json:"hash,omitempty"` +} + +func (m *EDSRequest) Reset() { *m = EDSRequest{} } +func (m *EDSRequest) String() string { return proto.CompactTextString(m) } +func (*EDSRequest) ProtoMessage() {} +func (*EDSRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_e8ddcd8d207cc22e, []int{0} +} +func (m *EDSRequest) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EDSRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_EDSRequest.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *EDSRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_EDSRequest.Merge(m, src) +} +func (m *EDSRequest) XXX_Size() int { + return m.Size() +} +func (m *EDSRequest) XXX_DiscardUnknown() { + xxx_messageInfo_EDSRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_EDSRequest proto.InternalMessageInfo + +func (m *EDSRequest) GetHash() []byte { + if m != nil { + return m.Hash + } + return nil +} + +type EDSResponse struct { + Status Status `protobuf:"varint,1,opt,name=status,proto3,enum=Status" json:"status,omitempty"` +} + +func (m *EDSResponse) Reset() { *m = EDSResponse{} } +func (m *EDSResponse) String() string { return proto.CompactTextString(m) } +func (*EDSResponse) ProtoMessage() {} +func (*EDSResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_e8ddcd8d207cc22e, []int{1} +} +func (m *EDSResponse) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *EDSResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_EDSResponse.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *EDSResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_EDSResponse.Merge(m, src) +} +func (m *EDSResponse) XXX_Size() int { + return m.Size() +} +func (m *EDSResponse) XXX_DiscardUnknown() { + xxx_messageInfo_EDSResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_EDSResponse proto.InternalMessageInfo + +func (m *EDSResponse) GetStatus() Status { + if m != nil { + return m.Status + } + return Status_INVALID +} + +func init() { + proto.RegisterEnum("Status", Status_name, Status_value) + proto.RegisterType((*EDSRequest)(nil), "EDSRequest") + proto.RegisterType((*EDSResponse)(nil), "EDSResponse") +} + +func init() { + proto.RegisterFile("share/eds/p2p/pb/extended_data_square.proto", fileDescriptor_e8ddcd8d207cc22e) +} + +var fileDescriptor_e8ddcd8d207cc22e = []byte{ + // 223 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xd2, 0x2e, 0xce, 0x48, 0x2c, + 0x4a, 0xd5, 0x4f, 0x4d, 0x29, 0xd6, 0x2f, 0x30, 0x2a, 0xd0, 0x2f, 0x48, 0xd2, 0x4f, 0xad, 0x28, + 0x49, 0xcd, 0x4b, 0x49, 0x4d, 0x89, 0x4f, 0x49, 0x2c, 0x49, 0x8c, 0x2f, 0x2e, 0x2c, 0x4d, 0x2c, + 0x4a, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x57, 0x52, 0xe0, 0xe2, 0x72, 0x75, 0x09, 0x0e, 0x4a, + 0x2d, 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0x12, 0xe2, 0x62, 0xc9, 0x48, 0x2c, 0xce, 0x90, 0x60, 0x54, + 0x60, 0xd4, 0xe0, 0x09, 0x02, 0xb3, 0x95, 0xf4, 0xb8, 0xb8, 0xc1, 0x2a, 0x8a, 0x0b, 0xf2, 0xf3, + 0x8a, 0x53, 0x85, 0xe4, 0xb9, 0xd8, 0x8a, 0x4b, 0x12, 0x4b, 0x4a, 0x8b, 0xc1, 0x8a, 0xf8, 0x8c, + 0xd8, 0xf5, 0x82, 0xc1, 0xdc, 0x20, 0xa8, 0xb0, 0x96, 0x35, 0x17, 0x1b, 0x44, 0x44, 0x88, 0x9b, + 0x8b, 0xdd, 0xd3, 0x2f, 0xcc, 0xd1, 0xc7, 0xd3, 0x45, 0x80, 0x41, 0x88, 0x8d, 0x8b, 0xc9, 0xdf, + 0x5b, 0x20, 0x45, 0x88, 0x8f, 0x8b, 0xd3, 0xcf, 0x3f, 0x24, 0xde, 0xcd, 0x3f, 0xd4, 0xcf, 0x45, + 0xe0, 0x04, 0xa3, 0x10, 0x0f, 0x17, 0x7b, 0x90, 0xab, 0x5b, 0x68, 0xb0, 0xab, 0x8b, 0xc0, 0x49, + 0x46, 0x27, 0x89, 0x13, 0x8f, 0xe4, 0x18, 0x2f, 0x3c, 0x92, 0x63, 0x7c, 0xf0, 0x48, 0x8e, 0x71, + 0xc2, 0x63, 0x39, 0x86, 0x0b, 0x8f, 0xe5, 0x18, 0x6e, 0x3c, 0x96, 0x63, 0x48, 0x62, 0x03, 0xbb, + 0xd7, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x3c, 0x20, 0x59, 0x9d, 0xde, 0x00, 0x00, 0x00, +} + +func (m *EDSRequest) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EDSRequest) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EDSRequest) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if len(m.Hash) > 0 { + i -= len(m.Hash) + copy(dAtA[i:], m.Hash) + i = encodeVarintExtendedDataSquare(dAtA, i, uint64(len(m.Hash))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *EDSResponse) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *EDSResponse) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *EDSResponse) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Status != 0 { + i = encodeVarintExtendedDataSquare(dAtA, i, uint64(m.Status)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func encodeVarintExtendedDataSquare(dAtA []byte, offset int, v uint64) int { + offset -= sovExtendedDataSquare(v) + base := offset + for v >= 1<<7 { + dAtA[offset] = uint8(v&0x7f | 0x80) + v >>= 7 + offset++ + } + dAtA[offset] = uint8(v) + return base +} +func (m *EDSRequest) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + l = len(m.Hash) + if l > 0 { + n += 1 + l + sovExtendedDataSquare(uint64(l)) + } + return n +} + +func (m *EDSResponse) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Status != 0 { + n += 1 + sovExtendedDataSquare(uint64(m.Status)) + } + return n +} + +func sovExtendedDataSquare(x uint64) (n int) { + return (math_bits.Len64(x|1) + 6) / 7 +} +func sozExtendedDataSquare(x uint64) (n int) { + return sovExtendedDataSquare(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} +func (m *EDSRequest) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EDSRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EDSRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Hash", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthExtendedDataSquare + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthExtendedDataSquare + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Hash = append(m.Hash[:0], dAtA[iNdEx:postIndex]...) + if m.Hash == nil { + m.Hash = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipExtendedDataSquare(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthExtendedDataSquare + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *EDSResponse) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: EDSResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: EDSResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Status", wireType) + } + m.Status = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Status |= Status(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipExtendedDataSquare(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthExtendedDataSquare + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func skipExtendedDataSquare(dAtA []byte) (n int, err error) { + l := len(dAtA) + iNdEx := 0 + depth := 0 + for iNdEx < l { + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + wireType := int(wire & 0x7) + switch wireType { + case 0: + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + iNdEx++ + if dAtA[iNdEx-1] < 0x80 { + break + } + } + case 1: + iNdEx += 8 + case 2: + var length int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 0, ErrIntOverflowExtendedDataSquare + } + if iNdEx >= l { + return 0, io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + length |= (int(b) & 0x7F) << shift + if b < 0x80 { + break + } + } + if length < 0 { + return 0, ErrInvalidLengthExtendedDataSquare + } + iNdEx += length + case 3: + depth++ + case 4: + if depth == 0 { + return 0, ErrUnexpectedEndOfGroupExtendedDataSquare + } + depth-- + case 5: + iNdEx += 4 + default: + return 0, fmt.Errorf("proto: illegal wireType %d", wireType) + } + if iNdEx < 0 { + return 0, ErrInvalidLengthExtendedDataSquare + } + if depth == 0 { + return iNdEx, nil + } + } + return 0, io.ErrUnexpectedEOF +} + +var ( + ErrInvalidLengthExtendedDataSquare = fmt.Errorf("proto: negative length found during unmarshaling") + ErrIntOverflowExtendedDataSquare = fmt.Errorf("proto: integer overflow") + ErrUnexpectedEndOfGroupExtendedDataSquare = fmt.Errorf("proto: unexpected end of group") +) diff --git a/share/p2p/shrexeds/pb/extended_data_square.proto b/share/p2p/shrexeds/pb/extended_data_square.proto new file mode 100644 index 0000000000..c7f826aea5 --- /dev/null +++ b/share/p2p/shrexeds/pb/extended_data_square.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +message EDSRequest { + bytes hash = 1; // identifies the requested EDS. +} + +enum Status { + INVALID = 0; + OK = 100; // data found + NOT_FOUND = 200; // data not found + REFUSED = 201; // request refused +} + +message EDSResponse { + Status status = 1; +} diff --git a/share/p2p/shrexeds/server.go b/share/p2p/shrexeds/server.go new file mode 100644 index 0000000000..a261f8f5cf --- /dev/null +++ b/share/p2p/shrexeds/server.go @@ -0,0 +1,167 @@ +package shrexeds + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/protocol" + + "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" + p2p_pb "github.com/celestiaorg/celestia-node/share/p2p/shrexeds/pb" + "github.com/celestiaorg/go-libp2p-messenger/serde" +) + +// Server is responsible for serving ODSs for blocksync over the ShrEx/EDS protocol. +type Server struct { + ctx context.Context + cancel context.CancelFunc + + host host.Host + protocolID protocol.ID + + store *eds.Store + + params *Parameters +} + +// NewServer creates a new ShrEx/EDS server. +func NewServer(host host.Host, store *eds.Store, opts ...Option) (*Server, error) { + params := DefaultParameters() + for _, opt := range opts { + opt(params) + } + + if err := params.Validate(); err != nil { + return nil, fmt.Errorf("shrex-eds: server creation failed: %w", err) + } + + return &Server{ + host: host, + store: store, + protocolID: protocolID(params.protocolSuffix), + params: params, + }, nil +} + +func (s *Server) Start(context.Context) error { + s.ctx, s.cancel = context.WithCancel(context.Background()) + s.host.SetStreamHandler(s.protocolID, s.handleStream) + return nil +} + +func (s *Server) Stop(context.Context) error { + defer s.cancel() + s.host.RemoveStreamHandler(s.protocolID) + return nil +} + +func (s *Server) handleStream(stream network.Stream) { + log.Debug("server: handling eds request") + + // read request from stream to get the dataHash for store lookup + req, err := s.readRequest(stream) + if err != nil { + log.Errorw("server: reading request from stream", "err", err) + stream.Reset() //nolint:errcheck + return + } + + // ensure the requested dataHash is a valid root + hash := share.DataHash(req.Hash) + err = hash.Validate() + if err != nil { + stream.Reset() //nolint:errcheck + return + } + + ctx, cancel := context.WithTimeout(s.ctx, s.params.ReadCARDeadline) + defer cancel() + status := p2p_pb.Status_OK + // determine whether the EDS is available in our store + edsReader, err := s.store.GetCAR(ctx, hash) + if err != nil { + status = p2p_pb.Status_NOT_FOUND + } else { + defer edsReader.Close() + } + + // inform the client of our status + err = s.writeStatus(status, stream) + if err != nil { + log.Errorw("server: writing status to stream", "err", err) + stream.Reset() //nolint:errcheck + return + } + // if we cannot serve the EDS, we are already done + if status != p2p_pb.Status_OK { + stream.Close() + return + } + + // start streaming the ODS to the client + err = s.writeODS(edsReader, stream) + if err != nil { + log.Errorw("server: writing ods to stream", "err", err) + stream.Reset() //nolint:errcheck + return + } + + err = stream.Close() + if err != nil { + log.Errorw("server: closing stream", "err", err) + } +} + +func (s *Server) readRequest(stream network.Stream) (*p2p_pb.EDSRequest, error) { + err := stream.SetReadDeadline(time.Now().Add(s.params.ReadDeadline)) + if err != nil { + log.Debug(err) + } + + req := new(p2p_pb.EDSRequest) + _, err = serde.Read(stream, req) + if err != nil { + return nil, err + } + err = stream.CloseRead() + if err != nil { + log.Error(err) + } + + return req, nil +} + +func (s *Server) writeStatus(status p2p_pb.Status, stream network.Stream) error { + err := stream.SetWriteDeadline(time.Now().Add(s.params.WriteDeadline)) + if err != nil { + log.Debug(err) + } + + resp := &p2p_pb.EDSResponse{Status: status} + _, err = serde.Write(stream, resp) + return err +} + +func (s *Server) writeODS(edsReader io.ReadCloser, stream network.Stream) error { + err := stream.SetWriteDeadline(time.Now().Add(s.params.WriteDeadline)) + if err != nil { + log.Debug(err) + } + + odsReader, err := eds.ODSReader(edsReader) + if err != nil { + return fmt.Errorf("creating ODS reader: %w", err) + } + buf := make([]byte, s.params.BufferSize) + _, err = io.CopyBuffer(stream, odsReader, buf) + if err != nil { + return fmt.Errorf("writing ODS bytes: %w", err) + } + + return nil +}