diff --git a/apollo-router/src/ageing_priority_queue.rs b/apollo-router/src/ageing_priority_queue.rs index 4dfb9b4c51..6af1c1c672 100644 --- a/apollo-router/src/ageing_priority_queue.rs +++ b/apollo-router/src/ageing_priority_queue.rs @@ -15,7 +15,10 @@ pub(crate) enum Priority { } #[derive(Debug, Clone)] -pub(crate) struct QueueIsFullError; +pub(crate) enum SendError { + QueueIsFull, + Disconnected, +} const INNER_QUEUES_COUNT: usize = Priority::P8 as usize - Priority::P1 as usize + 1; @@ -66,15 +69,16 @@ where } /// Panics if `priority` is not in `AVAILABLE_PRIORITIES` - pub(crate) fn send(&self, priority: Priority, message: T) -> Result<(), QueueIsFullError> { + pub(crate) fn send(&self, priority: Priority, message: T) -> Result<(), SendError> { self.queued_count .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |previous_count| { (previous_count < self.capacity).then_some(previous_count + 1) }) - .map_err(|_| QueueIsFullError)?; + .map_err(|_| SendError::QueueIsFull)?; let (inner_sender, _) = &self.inner_queues[index_from_priority(priority)]; - inner_sender.send(message).expect("disconnected channel"); - Ok(()) + inner_sender + .send(message) + .map_err(|crossbeam_channel::SendError(_)| SendError::Disconnected) } pub(crate) fn receiver(&self) -> Receiver<'_, T> { @@ -93,17 +97,17 @@ impl Receiver<'_, T> where T: Send + 'static, { - pub(crate) fn blocking_recv(&mut self) -> T { + pub(crate) fn blocking_recv(&mut self) -> Result { // Because we used `Select::new_biased` above, // `select()` will not shuffle receivers as it would with `Select::new` (for fairness) // but instead will try each one in priority order. let selected = self.select.select(); let index = selected.index(); let (_tx, rx) = &self.shared.inner_queues[index]; - let item = selected.recv(rx).expect("disconnected channel"); + let item = selected.recv(rx)?; self.shared.queued_count.fetch_sub(1, Ordering::Relaxed); self.age(index); - item + Ok(item) } // Promote some messages from priorities lower (higher indices) than `message_consumed_at_index` @@ -138,9 +142,9 @@ fn test_priorities() { assert_eq!(queue.queued_count(), 4); let mut receiver = queue.receiver(); - assert_eq!(receiver.blocking_recv(), "p3"); - assert_eq!(receiver.blocking_recv(), "p2"); - assert_eq!(receiver.blocking_recv(), "p2 again"); - assert_eq!(receiver.blocking_recv(), "p1"); + assert_eq!(receiver.blocking_recv().unwrap(), "p3"); + assert_eq!(receiver.blocking_recv().unwrap(), "p2"); + assert_eq!(receiver.blocking_recv().unwrap(), "p2 again"); + assert_eq!(receiver.blocking_recv().unwrap(), "p1"); assert_eq!(queue.queued_count(), 0); } diff --git a/apollo-router/src/compute_job.rs b/apollo-router/src/compute_job.rs index 3e28e7e670..2a7e0774d3 100644 --- a/apollo-router/src/compute_job.rs +++ b/apollo-router/src/compute_job.rs @@ -8,7 +8,7 @@ use tokio::sync::oneshot; use crate::ageing_priority_queue::AgeingPriorityQueue; pub(crate) use crate::ageing_priority_queue::Priority; -use crate::ageing_priority_queue::QueueIsFullError; +use crate::ageing_priority_queue::SendError; use crate::metrics::meter_provider; /// We generate backpressure in tower `poll_ready` when the number of queued jobs @@ -78,7 +78,10 @@ pub(crate) fn queue() -> &'static AgeingPriorityQueue { let mut receiver = queue.receiver(); loop { - let job = receiver.blocking_recv(); + // This `expect` never panics because this channel can never be disconnect: + // the sender is owned by `queue` which we can access here: + let _proof_of_life: &'static AgeingPriorityQueue<_> = queue; + let job = receiver.blocking_recv().expect("disconnected channel"); job(); } }); @@ -101,10 +104,23 @@ where // Ignore the error if the oneshot receiver was dropped let _ = tx.send(std::panic::catch_unwind(job)); }); - queue() - .send(priority, job) - .map_err(|QueueIsFullError| ComputeBackPressureError)?; - Ok(async { rx.await.expect("channel disconnected") }) + let queue = queue(); + queue.send(priority, job).map_err(|e| match e { + SendError::QueueIsFull => ComputeBackPressureError, + SendError::Disconnected => { + // This never panics because this channel can never be disconnect: + // the receiver is owned by `queue` which we can access here: + let _proof_of_life: &'static AgeingPriorityQueue<_> = queue; + unreachable!() + } + })?; + Ok(async move { + // This `expect` never panics because this oneshot channel can never be disconnect: + // the sender is owned by `job` which, if we reach here, was successfully sent to the queue. + // The queue or thread pool never drop a job without executing it. + // When executing, `catch_unwind` ensures that the sender cannot be dropped without sending. + rx.await.expect("channel disconnected") + }) } pub(crate) fn create_queue_size_gauge() -> ObservableGauge {