diff --git a/Cargo.toml b/Cargo.toml index 7dc2f644..2b6608a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "compio-signal", "compio-dispatcher", "compio-io", + "compio-tls", ] resolver = "2" diff --git a/azure-pipelines.yml b/azure-pipelines.yml index b0468474..4aecc6d2 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -47,7 +47,7 @@ jobs: - script: | rustup toolchain install nightly - cargo +nightly test --workspace --features all,polling,nightly --no-default-features + cargo +nightly test --workspace --features all,polling,native-tls,nightly --no-default-features displayName: TestNightly-polling # - script: | # cargo test --workspace --features all,polling --no-default-features diff --git a/compio-io/src/buffer.rs b/compio-io/src/buffer.rs index 03edfc71..73efb33b 100644 --- a/compio-io/src/buffer.rs +++ b/compio-io/src/buffer.rs @@ -79,6 +79,7 @@ impl IoBufMut for Inner { pub struct Buffer(Option); impl Buffer { + /// Create a buffer with capacity. #[inline] pub fn with_capacity(cap: usize) -> Self { Self(Some(Inner { @@ -87,11 +88,18 @@ impl Buffer { })) } + /// Get the initialized but not consumed part of the buffer. #[inline] pub fn slice(&self) -> &[u8] { self.inner().slice() } + /// If the inner buffer is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.inner().as_slice().is_empty() + } + /// All bytes in the buffer have been read #[inline] pub fn all_done(&self) -> bool { @@ -113,6 +121,7 @@ impl Buffer { buf.len() > buf.capacity() * 2 / 3 } + /// Clear the inner buffer and reset the position to the start. #[inline] pub fn reset(&mut self) { self.inner_mut().reset(); @@ -130,6 +139,17 @@ impl Buffer { res } + /// Execute a funcition with ownership of the buffer, and restore the buffer + /// afterwards + pub fn with_sync( + &mut self, + func: impl FnOnce(Inner) -> BufResult, + ) -> std::io::Result { + let BufResult(res, buf) = func(self.take_inner()); + self.restore_inner(buf); + res + } + /// Mark some bytes as read by advancing the progress tracker, return a /// `bool` indicating if all bytes are read. #[inline] diff --git a/compio-io/src/lib.rs b/compio-io/src/lib.rs index fff61a35..8f9e6754 100644 --- a/compio-io/src/lib.rs +++ b/compio-io/src/lib.rs @@ -95,6 +95,7 @@ mod write; pub(crate) type IoResult = std::io::Result; +pub use buffer::Buffer; pub use read::*; pub use util::{copy, null, repeat}; pub use write::*; diff --git a/compio-io/src/read/buf.rs b/compio-io/src/read/buf.rs index 8f202b7e..12d07235 100644 --- a/compio-io/src/read/buf.rs +++ b/compio-io/src/read/buf.rs @@ -1,4 +1,4 @@ -use compio_buf::{buf_try, BufResult, IntoInner, IoBufMut, IoVectoredBufMut}; +use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut}; use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE, AsyncRead, IoResult}; /// # AsyncBufRead @@ -96,7 +96,12 @@ impl AsyncBufRead for BufReader { } if buf.need_fill() { - buf.with(|b| reader.read(b)).await?; + buf.with(|b| async move { + let len = b.buf_len(); + let b = b.slice(len..); + reader.read(b).await.into_inner() + }) + .await?; } Ok(buf.slice()) diff --git a/compio-io/src/write/buf.rs b/compio-io/src/write/buf.rs index c4e0c588..35af2e7e 100644 --- a/compio-io/src/write/buf.rs +++ b/compio-io/src/write/buf.rs @@ -5,7 +5,7 @@ use compio_buf::{buf_try, BufResult, IntoInner, IoBuf}; use crate::{ buffer::Buffer, util::{slice_to_buf, DEFAULT_BUF_SIZE}, - AsyncWrite, IoResult, + AsyncWrite, AsyncWriteExt, IoResult, }; /// Wraps a writer and buffers its output. @@ -52,11 +52,12 @@ impl AsyncWrite for BufWriter { async fn write(&mut self, mut buf: T) -> compio_buf::BufResult { let written = self .buf - .with(|mut w| { + .with_sync(|w| { + let len = w.buf_len(); + let mut w = w.slice(len..); let written = slice_to_buf(buf.as_slice(), &mut w); - ready(BufResult(Ok(written), w)) + BufResult(Ok(written), w.into_inner()) }) - .await .expect("Closure always return Ok"); if self.buf.need_flush() { @@ -75,7 +76,10 @@ impl AsyncWrite for BufWriter { .with(|mut w| { let mut written = 0; for buf in buf.as_dyn_bufs() { - written += slice_to_buf(buf.as_slice(), &mut w); + let len = w.buf_len(); + let mut slice = w.slice(len..); + written += slice_to_buf(buf.as_slice(), &mut slice); + w = slice.into_inner(); if w.buf_len() == w.buf_capacity() { break; @@ -96,12 +100,8 @@ impl AsyncWrite for BufWriter { async fn flush(&mut self) -> IoResult<()> { let Self { writer, buf } = self; - let len = buf.with(|w| writer.write(w)).await?; - buf.advance(len); - - if buf.all_done() { - buf.reset(); - } + buf.with(|w| writer.write_all(w)).await?; + buf.reset(); Ok(()) } diff --git a/compio-net/Cargo.toml b/compio-net/Cargo.toml index 434993c4..389612c1 100644 --- a/compio-net/Cargo.toml +++ b/compio-net/Cargo.toml @@ -19,7 +19,7 @@ rustdoc-args = ["--cfg docsrs"] compio-buf = { workspace = true } compio-driver = { workspace = true } compio-io = { workspace = true, optional = true } -compio-runtime = { workspace = true, optional = true } +compio-runtime = { workspace = true, optional = true, features = ["event"] } cfg-if = "1" either = "1" diff --git a/compio-tls/Cargo.toml b/compio-tls/Cargo.toml new file mode 100644 index 00000000..16319e6a --- /dev/null +++ b/compio-tls/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "compio-tls" +version = "0.1.0" +categories = ["asynchronous", "network-programming"] +keywords = ["async", "net", "tls"] +edition = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +repository = { workspace = true } + +[dependencies] +compio-buf = { workspace = true } +compio-io = { workspace = true } + +native-tls = { version = "0.2", optional = true } + +[dev-dependencies] +compio-net = { workspace = true, features = ["runtime"] } +compio-runtime = { workspace = true } + +[features] +default = ["native-tls"] diff --git a/compio-tls/src/adapter.rs b/compio-tls/src/adapter.rs new file mode 100644 index 00000000..bf523a5c --- /dev/null +++ b/compio-tls/src/adapter.rs @@ -0,0 +1,109 @@ +use std::io; + +use compio_io::{AsyncRead, AsyncWrite}; +use native_tls::HandshakeError; + +use crate::{wrapper::StreamWrapper, TlsStream}; + +/// A wrapper around a [`native_tls::TlsConnector`], providing an async +/// `connect` method. +/// +/// ```rust +/// use compio_io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +/// use compio_net::TcpStream; +/// use compio_tls::TlsConnector; +/// +/// # compio_runtime::block_on(async { +/// let connector = TlsConnector::from(native_tls::TlsConnector::new().unwrap()); +/// +/// let stream = TcpStream::connect("www.example.com:443").await.unwrap(); +/// let mut stream = connector.connect("www.example.com", stream).await.unwrap(); +/// +/// stream +/// .write_all("GET / HTTP/1.1\r\nHost:www.example.com\r\nConnection: close\r\n\r\n") +/// .await +/// .unwrap(); +/// stream.flush().await.unwrap(); +/// let (_, res) = stream.read_to_end(vec![]).await.unwrap(); +/// println!("{}", String::from_utf8_lossy(&res)); +/// # }) +/// ``` +#[derive(Debug, Clone)] +pub struct TlsConnector(native_tls::TlsConnector); + +impl From for TlsConnector { + fn from(value: native_tls::TlsConnector) -> Self { + Self(value) + } +} + +impl TlsConnector { + /// Connects the provided stream with this connector, assuming the provided + /// domain. + /// + /// This function will internally call `TlsConnector::connect` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used for clients who have already established, for + /// example, a TCP connection to a remote server. That stream is then + /// provided here to perform the client half of a connection to a + /// TLS-powered server. + pub async fn connect( + &self, + domain: &str, + stream: S, + ) -> io::Result> { + handshake(self.0.connect(domain, StreamWrapper::new(stream))).await + } +} + +/// A wrapper around a [`native_tls::TlsAcceptor`], providing an async `accept` +/// method. +#[derive(Clone)] +pub struct TlsAcceptor(native_tls::TlsAcceptor); + +impl From for TlsAcceptor { + fn from(value: native_tls::TlsAcceptor) -> Self { + Self(value) + } +} + +impl TlsAcceptor { + /// Accepts a new client connection with the provided stream. + /// + /// This function will internally call `TlsAcceptor::accept` to connect + /// the stream and returns a future representing the resolution of the + /// connection operation. The returned future will resolve to either + /// `TlsStream` or `Error` depending if it's successful or not. + /// + /// This is typically used after a new socket has been accepted from a + /// `TcpListener`. That socket is then passed to this function to perform + /// the server half of accepting a client connection. + pub async fn accept(&self, stream: S) -> io::Result> { + handshake(self.0.accept(StreamWrapper::new(stream))).await + } +} + +async fn handshake( + mut res: Result>, HandshakeError>>, +) -> io::Result> { + loop { + match res { + Ok(mut s) => { + s.get_mut().flush_write_buf().await?; + return Ok(TlsStream::from(s)); + } + Err(e) => match e { + HandshakeError::Failure(e) => return Err(io::Error::new(io::ErrorKind::Other, e)), + HandshakeError::WouldBlock(mut mid_stream) => { + if mid_stream.get_mut().flush_write_buf().await? == 0 { + mid_stream.get_mut().fill_read_buf().await?; + } + res = mid_stream.handshake(); + } + }, + } + } +} diff --git a/compio-tls/src/lib.rs b/compio-tls/src/lib.rs new file mode 100644 index 00000000..6622724b --- /dev/null +++ b/compio-tls/src/lib.rs @@ -0,0 +1,11 @@ +//! Async TLS streams. + +#![warn(missing_docs)] + +mod adapter; +mod stream; +mod wrapper; + +pub use adapter::*; +pub use stream::*; +pub(crate) use wrapper::*; diff --git a/compio-tls/src/stream.rs b/compio-tls/src/stream.rs new file mode 100644 index 00000000..c648246e --- /dev/null +++ b/compio-tls/src/stream.rs @@ -0,0 +1,73 @@ +use std::{io, mem::MaybeUninit}; + +use compio_buf::{BufResult, IoBuf, IoBufMut}; +use compio_io::{AsyncRead, AsyncWrite}; + +use crate::StreamWrapper; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +/// +/// A `TlsStream` represents a handshake that has been completed successfully +/// and both the server and the client are ready for receiving and sending +/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written +/// to a `TlsStream` are encrypted when passing through to `S`. +#[derive(Debug)] +pub struct TlsStream(native_tls::TlsStream>); + +impl From>> for TlsStream { + fn from(value: native_tls::TlsStream>) -> Self { + Self(value) + } +} + +impl AsyncRead for TlsStream { + async fn read(&mut self, mut buf: B) -> BufResult { + let slice: &mut [MaybeUninit] = buf.as_mut_slice(); + slice.fill(MaybeUninit::new(0)); + let slice = + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) }; + loop { + let res = io::Read::read(&mut self.0, slice); + match res { + Ok(res) => { + unsafe { buf.set_buf_init(res) }; + return BufResult(Ok(res), buf); + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.0.get_mut().fill_read_buf().await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + } + } + _ => return BufResult(res, buf), + } + } + } +} + +impl AsyncWrite for TlsStream { + async fn write(&mut self, buf: T) -> BufResult { + let slice = buf.as_slice(); + loop { + let res = io::Write::write(&mut self.0, slice); + match res { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await { + Ok(_) => continue, + Err(e) => return BufResult(Err(e), buf), + }, + _ => return BufResult(res, buf), + } + } + } + + async fn flush(&mut self) -> io::Result<()> { + self.0.get_mut().flush_write_buf().await?; + Ok(()) + } + + async fn shutdown(&mut self) -> io::Result<()> { + self.flush().await?; + self.0.get_mut().get_mut().shutdown().await + } +} diff --git a/compio-tls/src/wrapper.rs b/compio-tls/src/wrapper.rs new file mode 100644 index 00000000..0e9e2c3f --- /dev/null +++ b/compio-tls/src/wrapper.rs @@ -0,0 +1,134 @@ +use std::io::{self, BufRead, Read, Write}; + +use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit}; +use compio_io::{AsyncWriteExt, Buffer}; + +const DEFAULT_BUF_SIZE: usize = 8 * 1024; + +#[derive(Debug)] +pub struct StreamWrapper { + stream: S, + eof: bool, + read_buffer: Buffer, + write_buffer: Buffer, +} + +impl StreamWrapper { + pub fn new(stream: S) -> Self { + Self::with_capacity(stream, DEFAULT_BUF_SIZE) + } + + pub fn with_capacity(stream: S, cap: usize) -> Self { + Self { + stream, + eof: false, + read_buffer: Buffer::with_capacity(cap), + write_buffer: Buffer::with_capacity(cap), + } + } + + pub fn get_ref(&self) -> &S { + &self.stream + } + + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } + + fn flush_impl(&mut self) -> io::Result<()> { + if !self.write_buffer.is_empty() { + Err(would_block("need to flush the write buffer")) + } else { + Ok(()) + } + } +} + +impl Read for StreamWrapper { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut slice = self.fill_buf()?; + slice.read(buf).map(|res| { + self.consume(res); + res + }) + } +} + +impl BufRead for StreamWrapper { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + if self.read_buffer.all_done() { + self.read_buffer.reset(); + } + + if self.read_buffer.slice().is_empty() && !self.eof { + return Err(would_block("need to fill the read buffer")); + } + + Ok(self.read_buffer.slice()) + } + + fn consume(&mut self, amt: usize) { + self.read_buffer.advance(amt); + } +} + +impl Write for StreamWrapper { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.write_buffer.need_flush() { + self.flush_impl()?; + } + + let written = self.write_buffer.with_sync(|mut inner| { + let len = buf.len().min(inner.buf_capacity() - inner.buf_len()); + unsafe { + std::ptr::copy_nonoverlapping( + buf.as_ptr(), + inner.as_buf_mut_ptr().add(inner.buf_len()), + len, + ); + inner.set_buf_init(inner.buf_len() + len); + } + BufResult(Ok(len), inner) + })?; + + Ok(written) + } + + fn flush(&mut self) -> io::Result<()> { + // Related PR: + // https://github.com/sfackler/rust-openssl/pull/1922 + // After this PR merged, we can use self.flush_impl() + Ok(()) + } +} + +fn would_block(msg: &str) -> io::Error { + io::Error::new(io::ErrorKind::WouldBlock, msg) +} + +impl StreamWrapper { + pub async fn fill_read_buf(&mut self) -> io::Result { + let stream = &mut self.stream; + let len = self + .read_buffer + .with(|b| async move { + let len = b.buf_len(); + let b = b.slice(len..); + stream.read(b).await.into_inner() + }) + .await?; + if len == 0 { + self.eof = true; + } + Ok(len) + } +} + +impl StreamWrapper { + pub async fn flush_write_buf(&mut self) -> io::Result { + let stream = &mut self.stream; + let len = self.write_buffer.with(|b| stream.write_all(b)).await?; + self.write_buffer.reset(); + Ok(len) + } +}