Skip to content
This repository has been archived by the owner on Sep 21, 2023. It is now read-only.

Commit

Permalink
Add blocking on the first event in the batch (#91)
Browse files Browse the repository at this point in the history
Also:

* changed the queue wrapper to support blocking `Publish` and
non-blocking `TryPublish`.
* removed returning `ErrQueueIsFull` from `Publish` since it was
incorrect (`Publish` is blocking until the queue is available).
* Added context to `Publish` for future change in libbeat (adding
`cancel` channel to each producer.
* Fix flaky timestamp error check
  • Loading branch information
rdner authored Aug 15, 2022
1 parent 8935aec commit 34571c4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 33 deletions.
19 changes: 16 additions & 3 deletions queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package queue

import (
"context"
"fmt"

beatsqueue "github.com/elastic/beats/v7/libbeat/publisher/queue"
Expand Down Expand Up @@ -42,7 +43,10 @@ type MetricsSource interface {
Metrics() (Metrics, error)
}

var ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full")
var (
ErrQueueIsFull = fmt.Errorf("couldn't publish: queue is full")
ErrQueueIsClosed = fmt.Errorf("couldn't publish: queue is closed")
)

func New(c Config) (*Queue, error) {
var eventQueue beatsqueue.Queue
Expand All @@ -60,8 +64,17 @@ func New(c Config) (*Queue, error) {
return &Queue{eventQueue: eventQueue, producer: producer}, nil
}

func (queue *Queue) Publish(event *messages.Event) (EntryID, error) {
if !queue.producer.Publish(event) {
func (queue *Queue) Publish(ctx context.Context, event *messages.Event) (EntryID, error) {
_ = ctx.Done()
// TODO pass the real channel once libbeat supports it
if !queue.producer.Publish(event /*, cancelCh*/) {
return EntryID(0), ErrQueueIsClosed
}
return EntryID(0), nil
}

func (queue *Queue) TryPublish(event *messages.Event) (EntryID, error) {
if !queue.producer.TryPublish(event) {
return EntryID(0), ErrQueueIsFull
}
return EntryID(0), nil
Expand Down
5 changes: 3 additions & 2 deletions queue/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package queue

import (
"context"
"os"
"testing"
"time"
Expand All @@ -27,7 +28,7 @@ func TestMemoryQueueSimpleBatch(t *testing.T) {
eventCount := 100
events := make([]messages.Event, eventCount)
for i := 0; i < eventCount; i++ {
_, err = queue.Publish(&events[i])
_, err = queue.Publish(context.Background(), &events[i])
assert.NoError(t, err, "couldn't publish to queue")
}

Expand Down Expand Up @@ -142,7 +143,7 @@ func TestQueueTypes(t *testing.T) {
tracker := [10]bool{}
for idx := range tracker {
e := makeEvent(idx)
_, err = queue.Publish(e)
_, err = queue.Publish(context.Background(), e)
assert.NoError(t, err, "couldn't publish to queue")
}

Expand Down
33 changes: 24 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@ type Publisher interface {
AcceptedIndex() queue.EntryID
// PersistedIndex returns the current sequential index of the persisted events
PersistedIndex() queue.EntryID
// Publish publishes the given event and returns the current accepted index (after this event)
Publish(*messages.Event) (queue.EntryID, error)
// Publish publishes the given event and returns the current accepted index (after this event).
// This operation is blocking until the event is published or the target queue is closed.
Publish(context.Context, *messages.Event) (queue.EntryID, error)
// TryPublish tries to immediately publish the given event and returns the current accepted index (after this event).
// Returns an error if it was not possible to publish the event without blocking.
TryPublish(*messages.Event) (queue.EntryID, error)
}

// ShipperServer contains all the gRPC operations for the shipper endpoints.
Expand Down Expand Up @@ -93,20 +97,23 @@ func (serv *shipperServer) GetPersistedIndex() uint64 {
}

// PublishEvents is the server implementation of the gRPC PublishEvents call.
func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) {
func (serv *shipperServer) PublishEvents(ctx context.Context, req *messages.PublishRequest) (*messages.PublishReply, error) {
resp := &messages.PublishReply{
Uuid: serv.uuid,
Uuid: serv.uuid,
AcceptedIndex: serv.GetAcceptedIndex(),
PersistedIndex: serv.GetPersistedIndex(),
}

// the value in the request is optional
if req.Uuid != "" && req.Uuid != serv.uuid {
resp.AcceptedIndex = serv.GetAcceptedIndex()
resp.PersistedIndex = serv.GetPersistedIndex()
serv.logger.Debugf("shipper UUID does not match, all events rejected. Expected = %s, actual = %s", serv.uuid, req.Uuid)

return resp, status.Error(codes.FailedPrecondition, fmt.Sprintf("UUID does not match. Expected = %s, actual = %s", serv.uuid, req.Uuid))
}

if len(req.Events) == 0 {
return resp, nil
}

if serv.cfg.StrictMode {
for _, e := range req.Events {
err := serv.validateEvent(e)
Expand All @@ -116,8 +123,16 @@ func (serv *shipperServer) PublishEvents(_ context.Context, req *messages.Publis
}
}

for _, e := range req.Events {
_, err := serv.publisher.Publish(e)
// we block until at least one event from the batch is published
_, err := serv.publisher.Publish(ctx, req.Events[0])
if err != nil {
return nil, status.Error(codes.Unavailable, err.Error())
}
resp.AcceptedCount++

// then we try to publish the rest without blocking
for _, e := range req.Events[1:] {
_, err := serv.publisher.TryPublish(e)
if err == nil {
resp.AcceptedCount++
continue
Expand Down
40 changes: 21 additions & 19 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(len(events)), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), len(events), int(publisher.persistedIndex))
})

t.Run("should grow accepted index", func(t *testing.T) {
Expand All @@ -86,25 +84,19 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(1), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 1, int(publisher.persistedIndex))
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(2), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 2, int(publisher.persistedIndex))
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(len(events)), reply.AcceptedCount)
require.Equal(t, uint64(3), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, len(events), 3, int(publisher.persistedIndex))
})

t.Run("should return different count when queue is full", func(t *testing.T) {
Expand All @@ -115,9 +107,7 @@ func TestPublish(t *testing.T) {
Events: events,
})
require.NoError(t, err)
require.Equal(t, uint32(1), reply.AcceptedCount)
require.Equal(t, uint64(1), reply.AcceptedIndex)
require.Equal(t, uint64(publisher.persistedIndex), pir.PersistedIndex)
assertIndices(t, reply, pir, 1, 1, int(publisher.persistedIndex))
})

t.Run("should return an error when uuid does not match", func(t *testing.T) {
Expand Down Expand Up @@ -153,7 +143,7 @@ func TestPublish(t *testing.T) {
Metadata: sampleValues,
Fields: sampleValues,
},
expectedMsg: "timestamp: proto:\u00a0invalid nil Timestamp",
expectedMsg: "invalid nil Timestamp",
},
{
name: "no source",
Expand Down Expand Up @@ -240,15 +230,15 @@ func TestPublish(t *testing.T) {
status, ok := status.FromError(err)
require.True(t, ok, "expected gRPC error")
require.Equal(t, codes.InvalidArgument, status.Code())
require.Equal(t, tc.expectedMsg, status.Message())
require.Contains(t, status.Message(), tc.expectedMsg)

// no validation in non-strict mode
reply, err = client.PublishEvents(ctx, &messages.PublishRequest{
Uuid: pir.Uuid,
Events: []*messages.Event{tc.event},
})
require.NoError(t, err)
require.Equal(t, uint32(1), reply.AcceptedCount)
require.Equal(t, uint32(1), reply.AcceptedCount, "should accept in non-strict mode")
})
}
})
Expand Down Expand Up @@ -350,6 +340,14 @@ func createConsumers(t *testing.T, ctx context.Context, client pb.ProducerClient
return cl
}

func assertIndices(t *testing.T, reply *messages.PublishReply, pir *messages.PersistedIndexReply, acceptedCount int, acceptedIndex int, persistedIndex int) {
require.NotNil(t, reply, "reply cannot be nil")
require.Equal(t, uint32(acceptedCount), reply.AcceptedCount, "accepted count does not match")
require.Equal(t, uint64(acceptedIndex), reply.AcceptedIndex, "accepted index does not match")
require.Equal(t, uint64(persistedIndex), reply.PersistedIndex, "persisted index does not match")
require.Equal(t, uint64(persistedIndex), pir.PersistedIndex, "persisted index reply does not match")
}

func getPersistedIndex(t *testing.T, ctx context.Context, client pb.ProducerClient) *messages.PersistedIndexReply {
pirCtx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down Expand Up @@ -397,7 +395,11 @@ type publisherMock struct {
persistedIndex queue.EntryID
}

func (p *publisherMock) Publish(event *messages.Event) (queue.EntryID, error) {
func (p *publisherMock) Publish(_ context.Context, event *messages.Event) (queue.EntryID, error) {
return p.TryPublish(event)
}

func (p *publisherMock) TryPublish(event *messages.Event) (queue.EntryID, error) {
if len(p.q) == cap(p.q) {
return queue.EntryID(0), queue.ErrQueueIsFull
}
Expand Down

0 comments on commit 34571c4

Please sign in to comment.