diff --git a/src/apple.rs b/src/apple.rs index 5942ce2..a5826a0 100644 --- a/src/apple.rs +++ b/src/apple.rs @@ -61,16 +61,17 @@ impl IfWatcher { } pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - while let Poll::Ready(_) = Pin::new(&mut self.rx).poll_next(cx) { + loop { + if let Some(event) = self.queue.pop_front() { + return Poll::Ready(Ok(event)); + } + if Pin::new(&mut self.rx).poll_next(cx).is_pending() { + return Poll::Pending; + } if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } } - if let Some(event) = self.queue.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } } } diff --git a/src/linux.rs b/src/linux.rs index ce0c85d..a35cd9c 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -2,6 +2,7 @@ use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net}; use fnv::FnvHashSet; use futures::channel::mpsc::UnboundedReceiver; use futures::future::Either; +use futures::ready; use futures::stream::{Stream, TryStreamExt}; use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR}; use rtnetlink::packet::address::nlas::Nla; @@ -12,7 +13,6 @@ use std::collections::VecDeque; use std::future::Future; use std::io::{Error, ErrorKind, Result}; use std::net::{Ipv4Addr, Ipv6Addr}; -use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -95,14 +95,15 @@ impl IfWatcher { } pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - log::trace!("polling IfWatcher {:p}", self.deref_mut()); - if Pin::new(&mut self.conn).poll(cx).is_ready() { - return Poll::Ready(Err(std::io::Error::new( - ErrorKind::BrokenPipe, - "rtnetlink socket closed", - ))); - } - while let Poll::Ready(Some((message, _))) = Pin::new(&mut self.messages).poll_next(cx) { + loop { + if let Some(event) = self.queue.pop_front() { + return Poll::Ready(Ok(event)); + } + if Pin::new(&mut self.conn).poll(cx).is_ready() { + return Poll::Ready(Err(socket_err())); + } + let (message, _) = + ready!(Pin::new(&mut self.messages).poll_next(cx)).ok_or_else(socket_err)?; match message.payload { NetlinkPayload::Error(err) => return Poll::Ready(Err(err.to_io())), NetlinkPayload::InnerMessage(msg) => match msg { @@ -113,13 +114,13 @@ impl IfWatcher { _ => {} } } - if let Some(event) = self.queue.pop_front() { - return Poll::Ready(Ok(event)); - } - Poll::Pending } } +fn socket_err() -> std::io::Error { + std::io::Error::new(ErrorKind::BrokenPipe, "rtnetlink socket closed") +} + fn iter_nets(msg: AddressMessage) -> impl Iterator { let prefix = msg.header.prefix_len; let family = msg.header.family; diff --git a/src/win.rs b/src/win.rs index d2b4f69..c55ec6d 100644 --- a/src/win.rs +++ b/src/win.rs @@ -70,17 +70,18 @@ impl IfWatcher { } pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - self.waker.register(cx.waker()); - if self.resync.swap(false, Ordering::Relaxed) { + loop { + if let Some(event) = self.queue.pop_front() { + Poll::Ready(Ok(event)) + } + if !self.resync.swap(false, Ordering::Relaxed) { + self.waker.register(cx.waker()); + return Poll::Pending; + } if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } } - if let Some(event) = self.queue.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } } }