diff --git a/queue/queue.go b/queue/queue.go index e997321..0e75ffa 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -5,6 +5,7 @@ package queue import ( + "context" "fmt" beatsqueue "github.com/elastic/beats/v7/libbeat/publisher/queue" @@ -42,7 +43,10 @@ type MetricsSource interface { Metrics() (Metrics, error) } -var ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full") +var ( + ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full") + ErrQueueIsClosed = fmt.Errorf("couldn't publish: queue is closed") +) func New(c Config) (*Queue, error) { var eventQueue beatsqueue.Queue @@ -60,8 +64,17 @@ func New(c Config) (*Queue, error) { return &Queue{eventQueue: eventQueue, producer: producer}, nil } -func (queue *Queue) Publish(event *messages.Event) (EntryID, error) { - if !queue.producer.Publish(event) { +func (queue *Queue) Publish(ctx context.Context, event *messages.Event) (EntryID, error) { + _ = ctx.Done() + // TODO pass the real channel once libbeat supports it + if !queue.producer.Publish(event /*, cancelCh*/) { + return EntryID(0), ErrQueueIsClosed + } + return EntryID(0), nil +} + +func (queue *Queue) TryPublish(event *messages.Event) (EntryID, error) { + if !queue.producer.TryPublish(event) { return EntryID(0), ErrQueueIsFull } return EntryID(0), nil diff --git a/queue/queue_test.go b/queue/queue_test.go index cebf6f5..0c4bff2 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -5,6 +5,7 @@ package queue import ( + "context" "os" "testing" "time" @@ -27,7 +28,7 @@ func TestMemoryQueueSimpleBatch(t *testing.T) { eventCount := 100 events := make([]messages.Event, eventCount) for i := 0; i < eventCount; i++ { - _, err = queue.Publish(&events[i]) + _, err = queue.Publish(context.Background(), &events[i]) assert.NoError(t, err, "couldn't publish to queue") } @@ -142,7 +143,7 @@ func TestQueueTypes(t *testing.T) { tracker := [10]bool{} for idx := range tracker { e := makeEvent(idx) - _, err = queue.Publish(e) + _, err = queue.Publish(context.Background(), e) assert.NoError(t, err, "couldn't publish to queue") } diff --git a/server/server.go b/server/server.go index 0830e1d..716ce97 100644 --- a/server/server.go +++ b/server/server.go @@ -33,8 +33,12 @@ type Publisher interface { AcceptedIndex() queue.EntryID // PersistedIndex returns the current sequential index of the persisted events PersistedIndex() queue.EntryID - // Publish publishes the given event and returns the current accepted index (after this event) - Publish(*messages.Event) (queue.EntryID, error) + // Publish publishes the given event and returns the current accepted index (after this event). + // This operation is blocking until the event is published or the target queue is closed. + Publish(context.Context, *messages.Event) (queue.EntryID, error) + // TryPublish tries to immediately publish the given event and returns the current accepted index (after this event). + // Returns an error if it was not possible to publish the event without blocking. + TryPublish(*messages.Event) (queue.EntryID, error) } // ShipperServer contains all the gRPC operations for the shipper endpoints. @@ -93,20 +97,23 @@ func (serv *shipperServer) GetPersistedIndex() uint64 { } // PublishEvents is the server implementation of the gRPC PublishEvents call. -func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) { +func (serv *shipperServer) PublishEvents(ctx context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) { resp := &messages.PublishReply{ - Uuid: serv.uuid, + Uuid: serv.uuid, + AcceptedIndex: serv.GetAcceptedIndex(), + PersistedIndex: serv.GetPersistedIndex(), } // the value in the request is optional if req.Uuid != "" && req.Uuid != serv.uuid { - resp.AcceptedIndex = serv.GetAcceptedIndex() - resp.PersistedIndex = serv.GetPersistedIndex() serv.logger.Debugf("shipper UUID does not match, all events rejected. Expected = %s, actual = %s", serv.uuid, req.Uuid) - return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid)) } + if len(req.Events) == 0 { + return resp, nil + } + if serv.cfg.StrictMode { for _, e := range req.Events { err := serv.validateEvent(e) @@ -116,8 +123,16 @@ func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.Publis } } - for _, e := range req.Events { - _, err := serv.publisher.Publish(e) + // we block until at least one event from the batch is published + _, err := serv.publisher.Publish(ctx, req.Events[0]) + if err != nil { + return nil, status.Error(codes.Unavailable, err.Error()) + } + resp.AcceptedCount++ + + // then we try to publish the rest without blocking + for _, e := range req.Events[1:] { + _, err := serv.publisher.TryPublish(e) if err == nil { resp.AcceptedCount++ continue diff --git a/server/server_test.go b/server/server_test.go index e870ed1..4db9af8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -73,9 +73,7 @@ func TestPublish(t *testing.T) { Events: events, }) require.NoError(t, err) - require.Equal(t, uint32(len(events)), reply.AcceptedCount) - require.Equal(t, uint64(len(events)), reply.AcceptedIndex) - require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + assertIndices(t, reply, pir, len(events), len(events), int(publisher.persistedIndex)) }) t.Run("should grow accepted index", func(t *testing.T) { @@ -86,25 +84,19 @@ func TestPublish(t *testing.T) { Events: events, }) require.NoError(t, err) - require.Equal(t, uint32(len(events)), reply.AcceptedCount) - require.Equal(t, uint64(1), reply.AcceptedIndex) - require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + assertIndices(t, reply, pir, len(events), 1, int(publisher.persistedIndex)) reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ Uuid: pir.Uuid, Events: events, }) require.NoError(t, err) - require.Equal(t, uint32(len(events)), reply.AcceptedCount) - require.Equal(t, uint64(2), reply.AcceptedIndex) - require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + assertIndices(t, reply, pir, len(events), 2, int(publisher.persistedIndex)) reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ Uuid: pir.Uuid, Events: events, }) require.NoError(t, err) - require.Equal(t, uint32(len(events)), reply.AcceptedCount) - require.Equal(t, uint64(3), reply.AcceptedIndex) - require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + assertIndices(t, reply, pir, len(events), 3, int(publisher.persistedIndex)) }) t.Run("should return different count when queue is full", func(t *testing.T) { @@ -115,9 +107,7 @@ func TestPublish(t *testing.T) { Events: events, }) require.NoError(t, err) - require.Equal(t, uint32(1), reply.AcceptedCount) - require.Equal(t, uint64(1), reply.AcceptedIndex) - require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex) + assertIndices(t, reply, pir, 1, 1, int(publisher.persistedIndex)) }) t.Run("should return an error when uuid does not match", func(t *testing.T) { @@ -153,7 +143,7 @@ func TestPublish(t *testing.T) { Metadata: sampleValues, Fields: sampleValues, }, - expectedMsg: "timestamp: proto:\u00a0invalid nil Timestamp", + expectedMsg: "invalid nil Timestamp", }, { name: "no source", @@ -240,7 +230,7 @@ func TestPublish(t *testing.T) { status, ok := status.FromError(err) require.True(t, ok, "expected gRPC error") require.Equal(t, codes.InvalidArgument, status.Code()) - require.Equal(t, tc.expectedMsg, status.Message()) + require.Contains(t, status.Message(), tc.expectedMsg) // no validation in non-strict mode reply, err = client.PublishEvents(ctx, &messages.PublishRequest{ @@ -248,7 +238,7 @@ func TestPublish(t *testing.T) { Events: []*messages.Event{tc.event}, }) require.NoError(t, err) - require.Equal(t, uint32(1), reply.AcceptedCount) + require.Equal(t, uint32(1), reply.AcceptedCount, "should accept in non-strict mode") }) } }) @@ -350,6 +340,14 @@ func createConsumers(t *testing.T, ctx context.Context, client pb.ProducerClient return cl } +func assertIndices(t *testing.T, reply *messages.PublishReply, pir *messages.PersistedIndexReply, acceptedCount int, acceptedIndex int, persistedIndex int) { + require.NotNil(t, reply, "reply cannot be nil") + require.Equal(t, uint32(acceptedCount), reply.AcceptedCount, "accepted count does not match") + require.Equal(t, uint64(acceptedIndex), reply.AcceptedIndex, "accepted index does not match") + require.Equal(t, uint64(persistedIndex), reply.PersistedIndex, "persisted index does not match") + require.Equal(t, uint64(persistedIndex), pir.PersistedIndex, "persisted index reply does not match") +} + func getPersistedIndex(t *testing.T, ctx context.Context, client pb.ProducerClient) *messages.PersistedIndexReply { pirCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -397,7 +395,11 @@ type publisherMock struct { persistedIndex queue.EntryID } -func (p *publisherMock) Publish(event *messages.Event) (queue.EntryID, error) { +func (p *publisherMock) Publish(_ context.Context, event *messages.Event) (queue.EntryID, error) { + return p.TryPublish(event) +} + +func (p *publisherMock) TryPublish(event *messages.Event) (queue.EntryID, error) { if len(p.q) == cap(p.q) { return queue.EntryID(0), queue.ErrQueueIsFull }