Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(driver): generic SharedFd #247

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 42 additions & 53 deletions compio-driver/src/fd.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
#[cfg(unix)]
use std::os::fd::FromRawFd;
#[cfg(windows)]
use std::os::windows::io::{
FromRawHandle, FromRawSocket, OwnedHandle, OwnedSocket, RawHandle, RawSocket,
};
use std::os::windows::io::{FromRawHandle, FromRawSocket, RawHandle, RawSocket};
use std::{
future::{poll_fn, Future},
mem::ManuallyDrop,
ops::Deref,
panic::RefUnwindSafe,
sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -17,35 +16,35 @@ use std::{

use futures_util::task::AtomicWaker;

use crate::{AsRawFd, OwnedFd, RawFd};
use crate::{AsRawFd, RawFd};

#[derive(Debug)]
struct Inner {
fd: OwnedFd,
struct Inner<T> {
fd: T,
// whether there is a future waiting
waits: AtomicBool,
waker: AtomicWaker,
}

impl RefUnwindSafe for Inner {}
impl<T> RefUnwindSafe for Inner<T> {}

/// A shared fd. It is passed to the operations to make sure the fd won't be
/// closed before the operations complete.
#[derive(Debug, Clone)]
pub struct SharedFd(Arc<Inner>);
#[derive(Debug)]
pub struct SharedFd<T>(Arc<Inner<T>>);

impl SharedFd {
impl<T> SharedFd<T> {
/// Create the shared fd from an owned fd.
pub fn new(fd: impl Into<OwnedFd>) -> Self {
pub fn new(fd: T) -> Self {
Self(Arc::new(Inner {
fd: fd.into(),
fd,
waits: AtomicBool::new(false),
waker: AtomicWaker::new(),
}))
}

/// Try to take the inner owned fd.
pub fn try_unwrap(self) -> Result<OwnedFd, Self> {
pub fn try_unwrap(self) -> Result<T, Self> {
let this = ManuallyDrop::new(self);
if let Some(fd) = unsafe { Self::try_unwrap_inner(&this) } {
Ok(fd)
Expand All @@ -55,7 +54,7 @@ impl SharedFd {
}

// SAFETY: if `Some` is returned, the method should not be called again.
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<OwnedFd> {
unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<T> {
let ptr = ManuallyDrop::new(std::ptr::read(&this.0));
// The ptr is duplicated without increasing the strong count, should forget.
match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
Expand All @@ -68,7 +67,7 @@ impl SharedFd {
}

/// Wait and take the inner owned fd.
pub fn take(self) -> impl Future<Output = Option<OwnedFd>> {
pub fn take(self) -> impl Future<Output = Option<T>> {
let this = ManuallyDrop::new(self);
async move {
if !this.0.waits.swap(true, Ordering::AcqRel) {
Expand All @@ -93,7 +92,7 @@ impl SharedFd {
}
}

impl Drop for SharedFd {
impl<T> Drop for SharedFd<T> {
fn drop(&mut self) {
// It's OK to wake multiple times.
if Arc::strong_count(&self.0) == 2 {
Expand All @@ -102,71 +101,61 @@ impl Drop for SharedFd {
}
}

#[cfg(windows)]
#[doc(hidden)]
impl SharedFd {
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
ManuallyDrop::new(std::fs::File::from_raw_handle(self.as_raw_fd() as _))
}

pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
ManuallyDrop::new(socket2::Socket::from_raw_socket(self.as_raw_fd() as _))
}
}

#[cfg(unix)]
#[doc(hidden)]
impl SharedFd {
pub unsafe fn to_file(&self) -> ManuallyDrop<std::fs::File> {
ManuallyDrop::new(std::fs::File::from_raw_fd(self.as_raw_fd() as _))
}

pub unsafe fn to_socket(&self) -> ManuallyDrop<socket2::Socket> {
ManuallyDrop::new(socket2::Socket::from_raw_fd(self.as_raw_fd() as _))
}
}

impl AsRawFd for SharedFd {
impl<T: AsRawFd> AsRawFd for SharedFd<T> {
fn as_raw_fd(&self) -> RawFd {
self.0.fd.as_raw_fd()
}
}

#[cfg(windows)]
impl FromRawHandle for SharedFd {
impl<T: FromRawHandle> FromRawHandle for SharedFd<T> {
unsafe fn from_raw_handle(handle: RawHandle) -> Self {
Self::new(OwnedFd::File(OwnedHandle::from_raw_handle(handle)))
Self::new(T::from_raw_handle(handle))
}
}

#[cfg(windows)]
impl FromRawSocket for SharedFd {
impl<T: FromRawSocket> FromRawSocket for SharedFd<T> {
unsafe fn from_raw_socket(sock: RawSocket) -> Self {
Self::new(OwnedFd::Socket(OwnedSocket::from_raw_socket(sock)))
Self::new(T::from_raw_socket(sock))
}
}

#[cfg(unix)]
impl FromRawFd for SharedFd {
impl<T: FromRawFd> FromRawFd for SharedFd<T> {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self::new(OwnedFd::from_raw_fd(fd))
Self::new(T::from_raw_fd(fd))
}
}

impl From<OwnedFd> for SharedFd {
fn from(value: OwnedFd) -> Self {
impl<T> From<T> for SharedFd<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}

impl<T> Clone for SharedFd<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T> Deref for SharedFd<T> {
type Target = T;

fn deref(&self) -> &Self::Target {
&self.0.fd
}
}

/// Get a clone of [`SharedFd`].
pub trait ToSharedFd {
pub trait ToSharedFd<T> {
/// Return a cloned [`SharedFd`].
fn to_shared_fd(&self) -> SharedFd;
fn to_shared_fd(&self) -> SharedFd<T>;
}

impl ToSharedFd for SharedFd {
fn to_shared_fd(&self) -> SharedFd {
impl<T> ToSharedFd<T> for SharedFd<T> {
fn to_shared_fd(&self) -> SharedFd<T> {
self.clone()
}
}
12 changes: 6 additions & 6 deletions compio-driver/src/fusion/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub use crate::unix::op::*;
use crate::SharedFd;

macro_rules! op {
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ident),* $(,)? )) => {
(<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? )) => {
::paste::paste!{
enum [< $name Inner >] <$($ty: $trait),*> {
Poll(poll::$name<$($ty),*>),
Expand Down Expand Up @@ -92,9 +92,9 @@ mod iour { pub use crate::sys::iour::{op::*, OpCode}; }
#[rustfmt::skip]
mod poll { pub use crate::sys::poll::{op::*, OpCode}; }

op!(<T: IoBufMut> RecvFrom(fd: SharedFd, buffer: T));
op!(<T: IoBuf> SendTo(fd: SharedFd, buffer: T, addr: SockAddr));
op!(<T: IoVectoredBufMut> RecvFromVectored(fd: SharedFd, buffer: T));
op!(<T: IoVectoredBuf> SendToVectored(fd: SharedFd, buffer: T, addr: SockAddr));
op!(<> FileStat(fd: SharedFd));
op!(<T: IoBufMut, S: AsRawFd> RecvFrom(fd: SharedFd<S>, buffer: T));
op!(<T: IoBuf, S: AsRawFd> SendTo(fd: SharedFd<S>, buffer: T, addr: SockAddr));
op!(<T: IoVectoredBufMut, S: AsRawFd> RecvFromVectored(fd: SharedFd<S>, buffer: T));
op!(<T: IoVectoredBuf, S: AsRawFd> SendToVectored(fd: SharedFd<S>, buffer: T, addr: SockAddr));
op!(<S: AsRawFd> FileStat(fd: SharedFd<S>));
op!(<> PathStat(path: CString, follow_symlink: bool));
30 changes: 30 additions & 0 deletions compio-driver/src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ impl AsRawFd for OwnedFd {
}
}

impl AsRawFd for RawFd {
fn as_raw_fd(&self) -> RawFd {
*self
}
}

impl AsRawFd for std::fs::File {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_handle() as _
}
}

impl AsRawFd for OwnedHandle {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_handle() as _
}
}

impl AsRawFd for socket2::Socket {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_socket() as _
}
}

impl AsRawFd for OwnedSocket {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_socket() as _
}
}

impl From<OwnedHandle> for OwnedFd {
fn from(value: OwnedHandle) -> Self {
Self::File(value)
Expand Down
Loading