From 8ddf200ca765b35b66e031d661df708aa7ca84ad Mon Sep 17 00:00:00 2001 From: Elena Frank Date: Wed, 10 Aug 2022 10:10:12 +0200 Subject: [PATCH] *: Replace Future impl with poll_next method (#23) Remove Future impl on all platform IfWatcher's, instead add `poll_next` method. Implement `Stream` and `FusedStream` for user-facing IfWatcher. --- CHANGELOG.md | 7 +++++++ Cargo.toml | 2 +- examples/if_watch.rs | 5 +++-- src/apple.rs | 20 ++++++++------------ src/fallback.rs | 6 +----- src/lib.rs | 24 ++++++++++++++++++------ src/linux.rs | 33 +++++++++++++++------------------ src/win.rs | 22 +++++++++------------- 8 files changed, 62 insertions(+), 57 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7219911..1b54da8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [2.0.0] [Unreleased] + +### Changed +- Add `IfWatcher::poll_next`. Implement `Stream` instead of `Future` for `IfWatcher`. See [PR 23]. + +[PR 23]: https://github.com/mxinden/if-watch/pull/23 + ## [1.1.1] ### Fixed diff --git a/Cargo.toml b/Cargo.toml index ef5f2d9..0f6039a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "if-watch" -version = "1.1.1" +version = "2.0.0" authors = ["David Craven ", "Parity Technologies Limited "] edition = "2021" keywords = ["asynchronous", "routing"] diff --git a/examples/if_watch.rs b/examples/if_watch.rs index 9a0cc7e..49a3e3a 100644 --- a/examples/if_watch.rs +++ b/examples/if_watch.rs @@ -1,12 +1,13 @@ +use futures::StreamExt; use if_watch::IfWatcher; -use std::pin::Pin; fn main() { env_logger::init(); futures::executor::block_on(async { let mut set = IfWatcher::new().await.unwrap(); loop { - println!("Got event {:?}", Pin::new(&mut set).await); + let event = set.select_next_some().await; + println!("Got event {:?}", event); } }); } diff --git a/src/apple.rs b/src/apple.rs index aeb8972..74beb5b 100644 --- a/src/apple.rs +++ b/src/apple.rs @@ -7,7 +7,6 @@ use futures::channel::mpsc; use futures::stream::Stream; use if_addrs::IfAddr; use std::collections::VecDeque; -use std::future::Future; use std::io::Result; use std::pin::Pin; use std::task::{Context, Poll}; @@ -59,22 +58,19 @@ impl IfWatcher { pub fn iter(&self) -> impl Iterator { self.addrs.iter() } -} - -impl Future for IfWatcher { - type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - while let Poll::Ready(_) = Pin::new(&mut self.rx).poll_next(cx) { + pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + if let Some(event) = self.queue.pop_front() { + return Poll::Ready(Ok(event)); + } + if Pin::new(&mut self.rx).poll_next(cx).is_pending() { + return Poll::Pending; + } if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } } - if let Some(event) = self.queue.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } } } diff --git a/src/fallback.rs b/src/fallback.rs index 22d54f2..bc4886e 100644 --- a/src/fallback.rs +++ b/src/fallback.rs @@ -48,12 +48,8 @@ impl IfWatcher { pub fn iter(&self) -> impl Iterator { self.addrs.iter() } -} - -impl Future for IfWatcher { - type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { + pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { if let Some(event) = self.queue.pop_front() { return Poll::Ready(Ok(event)); diff --git a/src/lib.rs b/src/lib.rs index 3198818..aad7077 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,9 @@ #![deny(missing_docs)] #![deny(warnings)] +use futures::stream::FusedStream; +use futures::Stream; pub use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use std::future::Future; use std::io::Result; use std::pin::Pin; use std::task::{Context, Poll}; @@ -63,25 +64,36 @@ impl IfWatcher { pub fn iter(&self) -> impl Iterator { self.0.iter() } + + /// Poll for an address change event. + pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.0).poll_next(cx) + } } -impl Future for IfWatcher { - type Output = Result; +impl Stream for IfWatcher { + type Item = Result; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next(cx).map(Some) + } +} - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - Pin::new(&mut self.0).poll(cx) +impl FusedStream for IfWatcher { + fn is_terminated(&self) -> bool { + false } } #[cfg(test)] mod tests { use super::*; + use futures::StreamExt; #[test] fn test_ip_watch() { futures::executor::block_on(async { let mut set = IfWatcher::new().await.unwrap(); - let event = Pin::new(&mut set).await.unwrap(); + let event = set.select_next_some().await.unwrap(); println!("Got event {:?}", event); }); } diff --git a/src/linux.rs b/src/linux.rs index c0afed1..a35cd9c 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -2,6 +2,7 @@ use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net}; use fnv::FnvHashSet; use futures::channel::mpsc::UnboundedReceiver; use futures::future::Either; +use futures::ready; use futures::stream::{Stream, TryStreamExt}; use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR}; use rtnetlink::packet::address::nlas::Nla; @@ -12,7 +13,6 @@ use std::collections::VecDeque; use std::future::Future; use std::io::{Error, ErrorKind, Result}; use std::net::{Ipv4Addr, Ipv6Addr}; -use std::ops::DerefMut; use std::pin::Pin; use std::task::{Context, Poll}; @@ -93,20 +93,17 @@ impl IfWatcher { } } } -} - -impl Future for IfWatcher { - type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - log::trace!("polling IfWatcher {:p}", self.deref_mut()); - if Pin::new(&mut self.conn).poll(cx).is_ready() { - return Poll::Ready(Err(std::io::Error::new( - ErrorKind::BrokenPipe, - "rtnetlink socket closed", - ))); - } - while let Poll::Ready(Some((message, _))) = Pin::new(&mut self.messages).poll_next(cx) { + pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + if let Some(event) = self.queue.pop_front() { + return Poll::Ready(Ok(event)); + } + if Pin::new(&mut self.conn).poll(cx).is_ready() { + return Poll::Ready(Err(socket_err())); + } + let (message, _) = + ready!(Pin::new(&mut self.messages).poll_next(cx)).ok_or_else(socket_err)?; match message.payload { NetlinkPayload::Error(err) => return Poll::Ready(Err(err.to_io())), NetlinkPayload::InnerMessage(msg) => match msg { @@ -117,13 +114,13 @@ impl Future for IfWatcher { _ => {} } } - if let Some(event) = self.queue.pop_front() { - return Poll::Ready(Ok(event)); - } - Poll::Pending } } +fn socket_err() -> std::io::Error { + std::io::Error::new(ErrorKind::BrokenPipe, "rtnetlink socket closed") +} + fn iter_nets(msg: AddressMessage) -> impl Iterator { let prefix = msg.header.prefix_len; let family = msg.header.family; diff --git a/src/win.rs b/src/win.rs index c9c3ff7..8444c70 100644 --- a/src/win.rs +++ b/src/win.rs @@ -4,7 +4,6 @@ use futures::task::AtomicWaker; use if_addrs::IfAddr; use std::collections::VecDeque; use std::ffi::c_void; -use std::future::Future; use std::io::{Error, ErrorKind, Result}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; @@ -68,23 +67,20 @@ impl IfWatcher { pub fn iter(&self) -> impl Iterator { self.addrs.iter() } -} - -impl Future for IfWatcher { - type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - self.waker.register(cx.waker()); - if self.resync.swap(false, Ordering::Relaxed) { + pub fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + loop { + if let Some(event) = self.queue.pop_front() { + return Poll::Ready(Ok(event)); + } + if !self.resync.swap(false, Ordering::Relaxed) { + self.waker.register(cx.waker()); + return Poll::Pending; + } if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } } - if let Some(event) = self.queue.pop_front() { - Poll::Ready(Ok(event)) - } else { - Poll::Pending - } } }