Skip to content
This repository has been archived by the owner on Sep 21, 2023. It is now read-only.

Add blocking on the first event in the batch #91

Merged
merged 1 commit into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package queue

import (
"context"
"fmt"

beatsqueue "github.com/elastic/beats/v7/libbeat/publisher/queue"
Expand Down Expand Up @@ -42,7 +43,10 @@ type MetricsSource interface {
Metrics() (Metrics, error)
}

var ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full")
var (
ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full")
ErrQueueIsClosed = fmt.Errorf("couldn't publish: queue is closed")
)

func New(c Config) (*Queue, error) {
var eventQueue beatsqueue.Queue
Expand All @@ -60,8 +64,17 @@ func New(c Config) (*Queue, error) {
return &Queue{eventQueue: eventQueue, producer: producer}, nil
}

func (queue *Queue) Publish(event *messages.Event) (EntryID, error) {
if !queue.producer.Publish(event) {
func (queue *Queue) Publish(ctx context.Context, event *messages.Event) (EntryID, error) {
_ = ctx.Done()
// TODO pass the real channel once libbeat supports it
if !queue.producer.Publish(event /*, cancelCh*/) {
return EntryID(0), ErrQueueIsClosed
}
return EntryID(0), nil
}

func (queue *Queue) TryPublish(event *messages.Event) (EntryID, error) {
if !queue.producer.TryPublish(event) {
return EntryID(0), ErrQueueIsFull
}
return EntryID(0), nil
Expand Down
5 changes: 3 additions & 2 deletions queue/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package queue

import (
"context"
"os"
"testing"
"time"
Expand All @@ -27,7 +28,7 @@ func TestMemoryQueueSimpleBatch(t *testing.T) {
eventCount := 100
events := make([]messages.Event, eventCount)
for i := 0; i < eventCount; i++ {
_, err = queue.Publish(&events[i])
_, err = queue.Publish(context.Background(), &events[i])
assert.NoError(t, err, "couldn't publish to queue")
}

Expand Down Expand Up @@ -142,7 +143,7 @@ func TestQueueTypes(t *testing.T) {
tracker := [10]bool{}
for idx := range tracker {
e := makeEvent(idx)
_, err = queue.Publish(e)
_, err = queue.Publish(context.Background(), e)
assert.NoError(t, err, "couldn't publish to queue")
}

Expand Down
33 changes: 24 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ type Publisher interface {
AcceptedIndex() queue.EntryID
// PersistedIndex returns the current sequential index of the persisted events
PersistedIndex() queue.EntryID
// Publish publishes the given event and returns the current accepted index (after this event)
Publish(*messages.Event) (queue.EntryID, error)
// Publish publishes the given event and returns the current accepted index (after this event).
// This operation is blocking until the event is published or the target queue is closed.
Publish(context.Context, *messages.Event) (queue.EntryID, error)
// TryPublish tries to immediately publish the given event and returns the current accepted index (after this event).
// Returns an error if it was not possible to publish the event without blocking.
TryPublish(*messages.Event) (queue.EntryID, error)
}

// ShipperServer contains all the gRPC operations for the shipper endpoints.
Expand Down Expand Up @@ -93,20 +97,23 @@ func (serv *shipperServer) GetPersistedIndex() uint64 {
}

// PublishEvents is the server implementation of the gRPC PublishEvents call.
func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) {
func (serv *shipperServer) PublishEvents(ctx context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) {
resp := &messages.PublishReply{
Uuid: serv.uuid,
Uuid: serv.uuid,
AcceptedIndex: serv.GetAcceptedIndex(),
PersistedIndex: serv.GetPersistedIndex(),
}

// the value in the request is optional
if req.Uuid != "" && req.Uuid != serv.uuid {
resp.AcceptedIndex = serv.GetAcceptedIndex()
resp.PersistedIndex = serv.GetPersistedIndex()
serv.logger.Debugf("shipper UUID does not match, all events rejected. Expected = %s, actual = %s", serv.uuid, req.Uuid)

return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid))
}

if len(req.Events) == 0 {
return resp, nil
}

if serv.cfg.StrictMode {
for _, e := range req.Events {
err := serv.validateEvent(e)
Expand All @@ -116,8 +123,16 @@ func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.Publis
}
}

for _, e := range req.Events {
_, err := serv.publisher.Publish(e)
// we block until at least one event from the batch is published
_, err := serv.publisher.Publish(ctx, req.Events[0])
if err != nil {
return nil, status.Error(codes.Unavailable, err.Error())
}
resp.AcceptedCount++

// then we try to publish the rest without blocking
for _, e := range req.Events[1:] {
_, err := serv.publisher.TryPublish(e)
if err == nil {
resp.AcceptedCount++
continue
Expand Down
40 changes: 21 additions & 19 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(len(events)), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), len(events), int(publisher.persistedIndex))
})

t.Run("should grow accepted index", func(t *testing.T) {
Expand All @@ -86,25 +84,19 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(1), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 1, int(publisher.persistedIndex))
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(2), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 2, int(publisher.persistedIndex))
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(3), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 3, int(publisher.persistedIndex))
})

t.Run("should return different count when queue is full", func(t *testing.T) {
Expand All @@ -115,9 +107,7 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(1), reply.AcceptedCount)
require.Equal(t, uint64(1), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, 1, 1, int(publisher.persistedIndex))
})

t.Run("should return an error when uuid does not match", func(t *testing.T) {
Expand Down Expand Up @@ -153,7 +143,7 @@ func TestPublish(t *testing.T) {
Metadata: sampleValues,
Fields: sampleValues,
},
expectedMsg: "timestamp: proto:\u00a0invalid nil Timestamp",
expectedMsg: "invalid nil Timestamp",
},
{
name: "no source",
Expand Down Expand Up @@ -240,15 +230,15 @@ func TestPublish(t *testing.T) {
status, ok := status.FromError(err)
require.True(t, ok, "expected gRPC error")
require.Equal(t, codes.InvalidArgument, status.Code())
require.Equal(t, tc.expectedMsg, status.Message())
require.Contains(t, status.Message(), tc.expectedMsg)
Copy link
Member Author

@rdner rdner Aug 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to fix #90 (comment)


// no validation in non-strict mode
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: []*messages.Event{tc.event},
})
require.NoError(t, err)
require.Equal(t, uint32(1), reply.AcceptedCount)
require.Equal(t, uint32(1), reply.AcceptedCount, "should accept in non-strict mode")
})
}
})
Expand Down Expand Up @@ -350,6 +340,14 @@ func createConsumers(t *testing.T, ctx context.Context, client pb.ProducerClient
return cl
}

func assertIndices(t *testing.T, reply *messages.PublishReply, pir *messages.PersistedIndexReply, acceptedCount int, acceptedIndex int, persistedIndex int) {
require.NotNil(t, reply, "reply cannot be nil")
require.Equal(t, uint32(acceptedCount), reply.AcceptedCount, "accepted count does not match")
require.Equal(t, uint64(acceptedIndex), reply.AcceptedIndex, "accepted index does not match")
require.Equal(t, uint64(persistedIndex), reply.PersistedIndex, "persisted index does not match")
require.Equal(t, uint64(persistedIndex), pir.PersistedIndex, "persisted index reply does not match")
}

func getPersistedIndex(t *testing.T, ctx context.Context, client pb.ProducerClient) *messages.PersistedIndexReply {
pirCtx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down Expand Up @@ -397,7 +395,11 @@ type publisherMock struct {
persistedIndex queue.EntryID
}

func (p *publisherMock) Publish(event *messages.Event) (queue.EntryID, error) {
func (p *publisherMock) Publish(_ context.Context, event *messages.Event) (queue.EntryID, error) {
return p.TryPublish(event)
}

func (p *publisherMock) TryPublish(event *messages.Event) (queue.EntryID, error) {
if len(p.q) == cap(p.q) {
return queue.EntryID(0), queue.ErrQueueIsFull
}
Expand Down