From d7399608ee8abc87fe551a3236c55fdb040bd880 Mon Sep 17 00:00:00 2001 From: Andrei Nesterov Date: Sun, 2 Jun 2024 20:49:28 +0300 Subject: [PATCH] Enable the development of custom consumers It is currently impossible to develop a custom consumer based on `BaseConsumer` because its `queue` property, which is necessary to receive notifications about new incoming messages, is private. This defines `set_nonempty_callback` method on `BaseConsumer` similarly to how it has already been done for `PartitionQueue`. That will allow setting `rdkafka_sys::rd_kafka_queue_cb_event_enable` callback from within a custom consumer implementation. --- src/consumer/base_consumer.rs | 34 +++++++++++++ tests/test_low_consumers.rs | 95 +++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/src/consumer/base_consumer.rs b/src/consumer/base_consumer.rs index c67c90cb2..0beafd04c 100644 --- a/src/consumer/base_consumer.rs +++ b/src/consumer/base_consumer.rs @@ -38,6 +38,7 @@ where client: Client, queue: NativeQueue, group_id: Option, + nonempty_callback: Option>>, } impl FromClientConfig for BaseConsumer { @@ -98,6 +99,7 @@ where client, queue, group_id, + nonempty_callback: None, }) } @@ -360,6 +362,36 @@ where pub(crate) fn native_client(&self) -> &NativeClient { self.client.native_client() } + + /// Sets a callback that will be invoked whenever the queue becomes + /// nonempty. + pub fn set_nonempty_callback(&mut self, f: F) + where + F: Fn() + Send + Sync + 'static, + { + // SAFETY: we keep `F` alive until the next call to + // `rd_kafka_queue_cb_event_enable`. That might be the next call to + // `set_nonempty_callback` or it might be when the queue is dropped. The + // double indirection is required because `&dyn Fn` is a fat pointer. + + unsafe extern "C" fn native_message_queue_nonempty_cb( + _: *mut RDKafka, + opaque_ptr: *mut c_void, + ) { + let f = opaque_ptr as *const *const (dyn Fn() + Send + Sync); + (**f)(); + } + + let f: Box> = Box::new(Box::new(f)); + unsafe { + rdsys::rd_kafka_queue_cb_event_enable( + self.queue.ptr(), + Some(native_message_queue_nonempty_cb), + &*f as *const _ as *mut c_void, + ) + } + self.nonempty_callback = Some(f); + } } impl Consumer for BaseConsumer @@ -722,6 +754,8 @@ where C: ConsumerContext, { fn drop(&mut self) { + unsafe { rdsys::rd_kafka_queue_cb_event_enable(self.queue.ptr(), None, ptr::null_mut()) } + trace!("Destroying consumer: {:?}", self.client.native_ptr()); if self.group_id.is_some() { if let Err(err) = self.close_queue() { diff --git a/tests/test_low_consumers.rs b/tests/test_low_consumers.rs index 3b4cb19e8..aaecffe96 100644 --- a/tests/test_low_consumers.rs +++ b/tests/test_low_consumers.rs @@ -447,6 +447,101 @@ async fn test_produce_consume_message_queue_nonempty_callback() { assert_eq!(wakeups.load(Ordering::SeqCst), 2); } +#[tokio::test] +async fn test_produce_consume_consumer_nonempty_callback() { + let _r = env_logger::try_init(); + + let topic_name = rand_test_topic("test_produce_consume_consumer_nonempty_callback"); + + create_topic(&topic_name, 1).await; + + // Turn off statistics to prevent interference with the wakeups counter. + let mut config_overrides = HashMap::new(); + config_overrides.insert("statistics.interval.ms", "0"); + + let mut consumer: BaseConsumer<_> = consumer_config(&rand_test_group(), Some(config_overrides)) + .create_with_context(ConsumerTestContext { _n: 64 }) + .expect("Consumer creation failed"); + + let mut tpl = TopicPartitionList::new(); + tpl.add_partition_offset(&topic_name, 0, Offset::Beginning) + .unwrap(); + consumer.assign(&tpl).unwrap(); + + let wakeups = Arc::new(AtomicUsize::new(0)); + consumer.set_nonempty_callback({ + let wakeups = wakeups.clone(); + move || { + wakeups.fetch_add(1, Ordering::SeqCst); + } + }); + + let wait_for_wakeups = |target| { + let start = Instant::now(); + let timeout = Duration::from_secs(15); + loop { + let w = wakeups.load(Ordering::SeqCst); + match w.cmp(&target) { + std::cmp::Ordering::Equal => break, + std::cmp::Ordering::Greater => panic!("wakeups {} exceeds target {}", w, target), + std::cmp::Ordering::Less => (), + }; + thread::sleep(Duration::from_millis(100)); + if start.elapsed() > timeout { + panic!("timeout exceeded while waiting for wakeup"); + } + } + }; + + // Initiate connection. + assert!(consumer.poll(Duration::from_secs(0)).is_none()); + + // Expect no wakeups for 1s. + thread::sleep(Duration::from_secs(1)); + assert_eq!(wakeups.load(Ordering::SeqCst), 0); + + // Verify there are no messages waiting. + assert!(consumer.poll(Duration::from_secs(0)).is_none()); + + // Populate the topic, and expect a wakeup notifying us of the new messages. + populate_topic(&topic_name, 2, &value_fn, &key_fn, None, None).await; + wait_for_wakeups(1); + + // Read one of the messages. + assert!(consumer.poll(Duration::from_secs(0)).is_some()); + + // Add more messages to the topic. Expect no additional wakeups, as the + // queue is not fully drained, for 1s. + populate_topic(&topic_name, 2, &value_fn, &key_fn, None, None).await; + thread::sleep(Duration::from_secs(1)); + assert_eq!(wakeups.load(Ordering::SeqCst), 1); + + // Drain the queue. + assert!(consumer.poll(None).is_some()); + assert!(consumer.poll(None).is_some()); + assert!(consumer.poll(None).is_some()); + + // Expect no additional wakeups for 1s. + thread::sleep(Duration::from_secs(1)); + assert_eq!(wakeups.load(Ordering::SeqCst), 1); + + // Add another message, and expect a wakeup. + populate_topic(&topic_name, 1, &value_fn, &key_fn, None, None).await; + wait_for_wakeups(2); + + // Expect no additional wakeups for 1s. + thread::sleep(Duration::from_secs(1)); + assert_eq!(wakeups.load(Ordering::SeqCst), 2); + + // Disable the queue and add another message. + consumer.set_nonempty_callback(|| ()); + populate_topic(&topic_name, 1, &value_fn, &key_fn, None, None).await; + + // Expect no additional wakeups for 1s. + thread::sleep(Duration::from_secs(1)); + assert_eq!(wakeups.load(Ordering::SeqCst), 2); +} + #[tokio::test] async fn test_invalid_consumer_position() { // Regression test for #360, in which calling `position` on a consumer which