diff --git a/src/router/internal/server/v2/egress_server.go b/src/router/internal/server/v2/egress_server.go index 009435f71..584e98709 100644 --- a/src/router/internal/server/v2/egress_server.go +++ b/src/router/internal/server/v2/egress_server.go @@ -28,7 +28,7 @@ type DataSetter interface { type EgressServer struct { subscriber Subscriber egressMetric *metricemitter.Counter - droppedMetric *metricemitter.Counter + droppedMetric *metricemitter.Counter subscriptionsMetric *metricemitter.Gauge batchInterval time.Duration batchSize uint @@ -52,7 +52,7 @@ func NewEgressServer( return &EgressServer{ subscriber: s, egressMetric: egressMetric, - droppedMetric: droppedMetric, + droppedMetric: droppedMetric, subscriptionsMetric: subscriptionsMetric, batchInterval: batchInterval, batchSize: batchSize, @@ -80,9 +80,12 @@ func (s *EgressServer) BatchedReceiver( s.subscriptionsMetric.Increment(1.0) defer s.subscriptionsMetric.Decrement(1.0) - d := diodes.NewOneToOneWaiterEnvelopeV2(1000, gendiode.AlertFunc(func(missed int) { - log.Printf("Dropped %d envelopes (v2 buffer) ShardID: %s", missed, req.ShardId) - s.Alert(missed)}), + d := diodes.NewOneToOneWaiterEnvelopeV2( + 1000, + gendiode.AlertFunc(func(missed int) { + log.Printf("Dropped %d envelopes (v2 buffer) ShardID: %s", missed, req.ShardId) + s.Alert(missed) + }), gendiode.WithWaiterContext(sender.Context()), ) cancel := s.subscriber.Subscribe(req, d) diff --git a/src/router/internal/server/v2/egress_server_test.go b/src/router/internal/server/v2/egress_server_test.go index 7b686a468..952e1b4c2 100644 --- a/src/router/internal/server/v2/egress_server_test.go +++ b/src/router/internal/server/v2/egress_server_test.go @@ -35,9 +35,11 @@ var _ = Describe("EgressServer", func() { Describe("BatchedReceiver", func() { It("forwards messages to a connected client", func() { + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, } + defer cancel() subscriber := &spySubscriber{} server := v2.NewEgressServer( subscriber, @@ -54,9 +56,11 @@ var _ = Describe("EgressServer", func() { }) It("flushes if there are no more envelopes in the diode", func() { + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, } + defer cancel() subscriber := &spySubscriber{ wait: time.Hour, } @@ -98,10 +102,12 @@ var _ = Describe("EgressServer", func() { }) It("returns an error if one occurrs", func() { + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, err: errors.New("some error"), } + defer cancel() subscriber := &spySubscriber{} server := v2.NewEgressServer( subscriber, @@ -117,10 +123,12 @@ var _ = Describe("EgressServer", func() { }) It("calls cleanup when exiting", func() { + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, err: errors.New("some error"), } + defer cancel() subscriber := &spySubscriber{} server := v2.NewEgressServer( subscriber, @@ -137,9 +145,11 @@ var _ = Describe("EgressServer", func() { }) It("passes the request to the subscriber", func() { + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, } + defer cancel() subscriber := &spySubscriber{} server := v2.NewEgressServer( subscriber, @@ -193,9 +203,12 @@ var _ = Describe("EgressServer", func() { It("emits a metric for the number of envelopes sent", func() { metricClient := testhelper.NewMetricClient() + + ctx, cancel := context.WithCancel(context.Background()) spyReceiver := &spyBatchReceiver{ - _context: context.Background(), + _context: ctx, } + defer cancel() subscriber := &spySubscriber{} server := v2.NewEgressServer( @@ -230,13 +243,13 @@ var _ = Describe("EgressServer", func() { 10, ) - br := newSlowBatchReciever(100 * time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) + br := newSlowBatchReciever(100*time.Millisecond, ctx) + defer cancel() go server.BatchedReceiver(&loggregator_v2.EgressBatchRequest{}, br) - Eventually( - egressDropped.GetDelta, - 40).Should(BeNumerically(">", 1)) + Eventually(egressDropped.GetDelta, 10).Should(BeNumerically(">", 1)) }) }) }) @@ -244,11 +257,14 @@ var _ = Describe("EgressServer", func() { type slowBatchReceiver struct { grpc.ServerStream sendDelay time.Duration + + _context context.Context } -func newSlowBatchReciever(sendDelay time.Duration) *slowBatchReceiver { +func newSlowBatchReciever(sendDelay time.Duration, ctx context.Context) *slowBatchReceiver { return &slowBatchReceiver{ sendDelay: sendDelay, + _context: ctx, } } @@ -258,7 +274,7 @@ func (s *slowBatchReceiver) Send(batch *loggregator_v2.EnvelopeBatch) error { } func (s *slowBatchReceiver) Context() context.Context { - return context.Background() + return s._context } type spyBatchReceiver struct {