diff --git a/channel.go b/channel.go index b7d4d14..a4afc98 100644 --- a/channel.go +++ b/channel.go @@ -1330,8 +1330,19 @@ internal counter for DeliveryTags with the first confirmation starts at 1. */ func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg Publishing) error { + _, err := ch.PublishWithDeferredConfirm(exchange, key, mandatory, immediate, msg) + return err +} + +/* +PublishWithDeferredConfirm behaves identically to Publish but additionally returns a +DeferredConfirmation, allowing the caller to wait on the publisher confirmation +for this message. If the channel has not been put into confirm mode, +the DeferredConfirmation will be nil. +*/ +func (ch *Channel) PublishWithDeferredConfirm(exchange, key string, mandatory, immediate bool, msg Publishing) (*DeferredConfirmation, error) { if err := msg.Headers.Validate(); err != nil { - return err + return nil, err } ch.m.Lock() @@ -1359,14 +1370,14 @@ func (ch *Channel) Publish(exchange, key string, mandatory, immediate bool, msg AppId: msg.AppId, }, }); err != nil { - return err + return nil, err } if ch.confirming { - ch.confirms.Publish() + return ch.confirms.Publish(), nil } - return nil + return nil, nil } /* diff --git a/client_test.go b/client_test.go index ab036d4..748ff14 100644 --- a/client_test.go +++ b/client_test.go @@ -454,6 +454,69 @@ func TestConfirmMultipleOrdersDeliveryTags(t *testing.T) { } +func TestDeferredConfirmations(t *testing.T) { + rwc, srv := newSession(t) + defer rwc.Close() + + go func() { + srv.connectionOpen() + srv.channelOpen(1) + + srv.recv(1, &confirmSelect{}) + srv.send(1, &confirmSelectOk{}) + + srv.recv(1, &basicPublish{}) + srv.recv(1, &basicPublish{}) + srv.recv(1, &basicPublish{}) + srv.recv(1, &basicPublish{}) + }() + + c, err := Open(rwc, defaultConfig()) + if err != nil { + t.Fatalf("could not create connection: %v (%s)", c, err) + } + + ch, err := c.Channel() + if err != nil { + t.Fatalf("could not open channel: %v (%s)", ch, err) + } + + ch.Confirm(false) + + var results []*DeferredConfirmation + for i := 1; i < 5; i++ { + dc, err := ch.PublishWithDeferredConfirm("", "q", false, false, Publishing{Body: []byte("pub")}) + if err != nil { + t.Fatalf("failed to PublishWithDeferredConfirm: %v", err) + } + results = append(results, dc) + } + + acks := make(chan Confirmation, 4) + for _, result := range results { + go func(r *DeferredConfirmation) { + acks <- Confirmation{Ack: r.Wait(), DeliveryTag: r.DeliveryTag} + }(result) + } + + // received out of order, consumed out of order + assertReceive := func(ack Confirmation, tags ...uint64) { + for _, tag := range tags { + if tag == ack.DeliveryTag { + return + } + } + t.Fatalf("failed ack, expected ack to be in set %v, got %d", tags, ack.DeliveryTag) + } + srv.send(1, &basicAck{DeliveryTag: 2}) + assertReceive(<-acks, 2) + srv.send(1, &basicAck{DeliveryTag: 1}) + assertReceive(<-acks, 1) + srv.send(1, &basicAck{DeliveryTag: 4, Multiple: true}) + assertReceive(<-acks, 3, 4) // 3 and 4 are non-determistic due to map ordering + assertReceive(<-acks, 3, 4) +} + func TestNotifyClosesReusedPublisherConfirmChan(t *testing.T) { rwc, srv := newSession(t) diff --git a/confirms.go b/confirms.go index 299b8f0..654d755 100644 --- a/confirms.go +++ b/confirms.go @@ -5,24 +5,28 @@ package amqp091 -import "sync" +import ( + "sync" +) // confirms resequences and notifies one or multiple publisher confirmation listeners type confirms struct { - m sync.Mutex - listeners []chan Confirmation - sequencer map[uint64]Confirmation - published uint64 - publishedMut sync.Mutex - expecting uint64 + m sync.Mutex + listeners []chan Confirmation + sequencer map[uint64]Confirmation + deferredConfirmations *deferredConfirmations + published uint64 + publishedMut sync.Mutex + expecting uint64 } // newConfirms allocates a confirms func newConfirms() *confirms { return &confirms{ - sequencer: map[uint64]Confirmation{}, - published: 0, - expecting: 1, + sequencer: map[uint64]Confirmation{}, + deferredConfirmations: newDeferredConfirmations(), + published: 0, + expecting: 1, } } @@ -34,12 +38,12 @@ func (c *confirms) Listen(l chan Confirmation) { } // Publish increments the publishing counter -func (c *confirms) Publish() uint64 { +func (c *confirms) Publish() *DeferredConfirmation { c.publishedMut.Lock() defer c.publishedMut.Unlock() c.published++ - return c.published + return c.deferredConfirmations.Add(c.published) } // confirm confirms one publishing, increments the expecting delivery tag, and @@ -71,6 +75,8 @@ func (c *confirms) One(confirmed Confirmation) { c.m.Lock() defer c.m.Unlock() + c.deferredConfirmations.Confirm(confirmed) + if c.expecting == confirmed.DeliveryTag { c.confirm(confirmed) } else { @@ -84,6 +90,8 @@ func (c *confirms) Multiple(confirmed Confirmation) { c.m.Lock() defer c.m.Unlock() + c.deferredConfirmations.ConfirmMultiple(confirmed) + for c.expecting <= confirmed.DeliveryTag { c.confirm(Confirmation{c.expecting, confirmed.Ack}) } @@ -101,3 +109,56 @@ func (c *confirms) Close() error { c.listeners = nil return nil } + +type deferredConfirmations struct { + m sync.Mutex + confirmations map[uint64]*DeferredConfirmation +} + +func newDeferredConfirmations() *deferredConfirmations { + return &deferredConfirmations{ + confirmations: map[uint64]*DeferredConfirmation{}, + } +} + +func (d *deferredConfirmations) Add(tag uint64) *DeferredConfirmation { + d.m.Lock() + defer d.m.Unlock() + + dc := &DeferredConfirmation{DeliveryTag: tag} + dc.wg.Add(1) + d.confirmations[tag] = dc + return dc +} + +func (d *deferredConfirmations) Confirm(confirmation Confirmation) { + d.m.Lock() + defer d.m.Unlock() + + dc, found := d.confirmations[confirmation.DeliveryTag] + if !found { + // we should never receive a confirmation for a tag that hasn't been published, but a test causes this to happen + return + } + dc.confirmation = confirmation + dc.wg.Done() + delete(d.confirmations, confirmation.DeliveryTag) +} + +func (d *deferredConfirmations) ConfirmMultiple(confirmation Confirmation) { + d.m.Lock() + defer d.m.Unlock() + + for k, v := range d.confirmations { + if k <= confirmation.DeliveryTag { + v.confirmation = Confirmation{DeliveryTag: k, Ack: confirmation.Ack} + v.wg.Done() + delete(d.confirmations, k) + } + } +} + +func (d *DeferredConfirmation) Wait() bool { + d.wg.Wait() + return d.confirmation.Ack +} diff --git a/confirms_test.go b/confirms_test.go index d36e25d..ea8acdb 100644 --- a/confirms_test.go +++ b/confirms_test.go @@ -6,6 +6,7 @@ package amqp091 import ( + "sync" "testing" "time" ) @@ -24,8 +25,8 @@ func TestConfirmOneResequences(t *testing.T) { c.Listen(l) for i := range fixtures { - if want, got := uint64(i+1), c.Publish(); want != got { - t.Fatalf("expected publish to return the 1 based delivery tag published, want: %d, got: %d", want, got) + if want, got := uint64(i+1), c.Publish(); want != got.DeliveryTag { + t.Fatalf("expected publish to return the 1 based delivery tag published, want: %d, got: %d", want, got.DeliveryTag) } } @@ -139,7 +140,7 @@ func BenchmarkSequentialBufferedConfirms(t *testing.B) { if i > cap(l)-1 { <-l } - c.One(Confirmation{c.Publish(), true}) + c.One(Confirmation{c.Publish().DeliveryTag, true}) } } @@ -157,7 +158,7 @@ func TestConfirmsIsThreadSafe(t *testing.T) { c.Listen(l) for i := 0; i < count; i++ { - go func() { pub <- Confirmation{c.Publish(), true} }() + go func() { pub <- Confirmation{c.Publish().DeliveryTag, true} }() } for i := 0; i < count; i++ { @@ -176,3 +177,42 @@ func TestConfirmsIsThreadSafe(t *testing.T) { } } } + +func TestDeferredConfirmationsConfirm(t *testing.T) { + dcs := newDeferredConfirmations() + var wg sync.WaitGroup + for i, ack := range []bool{true, false} { + var result bool + deliveryTag := uint64(i + 1) + dc := dcs.Add(deliveryTag) + wg.Add(1) + go func() { + result = dc.Wait() + wg.Done() + }() + dcs.Confirm(Confirmation{deliveryTag, ack}) + wg.Wait() + if result != ack { + t.Fatalf("expected to receive matching ack got %v", result) + } + } +} + +func TestDeferredConfirmationsConfirmMultiple(t *testing.T) { + dcs := newDeferredConfirmations() + var wg sync.WaitGroup + var result bool + dc1 := dcs.Add(1) + dc2 := dcs.Add(2) + dc3 := dcs.Add(3) + wg.Add(1) + go func() { + result = dc1.Wait() && dc2.Wait() && dc3.Wait() + wg.Done() + }() + dcs.ConfirmMultiple(Confirmation{4, true}) + wg.Wait() + if !result { + t.Fatal("expected to receive true for result, received false") + } +} diff --git a/types.go b/types.go index eb256b1..3319990 100644 --- a/types.go +++ b/types.go @@ -8,6 +8,7 @@ package amqp091 import ( "fmt" "io" + "sync" "time" ) @@ -179,6 +180,15 @@ type Blocking struct { Reason string // Server reason for activation } +// DeferredConfirmation represents a future publisher confirm for a message. It +// allows users to directly correlate a publishing to a confirmation. These are +// returned from PublishWithDeferredConfirm on Channels. +type DeferredConfirmation struct { + wg sync.WaitGroup + DeliveryTag uint64 + confirmation Confirmation +} + // Confirmation notifies the acknowledgment or negative acknowledgement of a // publishing identified by its delivery tag. Use NotifyPublish on the Channel // to consume these events.