Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(tls): add native-tls adapters. #119

Merged
merged 10 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ members = [
"compio-signal",
"compio-dispatcher",
"compio-io",
"compio-tls",
]
resolver = "2"

Expand Down
2 changes: 1 addition & 1 deletion azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:

- script: |
rustup toolchain install nightly
cargo +nightly test --workspace --features all,polling,nightly --no-default-features
cargo +nightly test --workspace --features all,polling,native-tls,nightly --no-default-features
displayName: TestNightly-polling
# - script: |
# cargo test --workspace --features all,polling --no-default-features
Expand Down
20 changes: 20 additions & 0 deletions compio-io/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl IoBufMut for Inner {
pub struct Buffer(Option<Inner>);

impl Buffer {
/// Create a buffer with capacity.
#[inline]
pub fn with_capacity(cap: usize) -> Self {
Self(Some(Inner {
Expand All @@ -87,11 +88,18 @@ impl Buffer {
}))
}

/// Get the initialized but not consumed part of the buffer.
#[inline]
pub fn slice(&self) -> &[u8] {
self.inner().slice()
}

/// If the inner buffer is empty.
#[inline]
pub fn is_empty(&self) -> bool {
self.inner().as_slice().is_empty()
}

/// All bytes in the buffer have been read
#[inline]
pub fn all_done(&self) -> bool {
Expand All @@ -113,6 +121,7 @@ impl Buffer {
buf.len() > buf.capacity() * 2 / 3
}

/// Clear the inner buffer and reset the position to the start.
#[inline]
pub fn reset(&mut self) {
self.inner_mut().reset();
Expand All @@ -130,6 +139,17 @@ impl Buffer {
res
}

/// Execute a funcition with ownership of the buffer, and restore the buffer
/// afterwards
pub fn with_sync<R>(
&mut self,
func: impl FnOnce(Inner) -> BufResult<R, Inner>,
) -> std::io::Result<R> {
let BufResult(res, buf) = func(self.take_inner());
self.restore_inner(buf);
res
}

/// Mark some bytes as read by advancing the progress tracker, return a
/// `bool` indicating if all bytes are read.
#[inline]
Expand Down
1 change: 1 addition & 0 deletions compio-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ mod write;

pub(crate) type IoResult<T> = std::io::Result<T>;

pub use buffer::Buffer;
pub use read::*;
pub use util::{copy, null, repeat};
pub use write::*;
9 changes: 7 additions & 2 deletions compio-io/src/read/buf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use compio_buf::{buf_try, BufResult, IntoInner, IoBufMut, IoVectoredBufMut};
use compio_buf::{buf_try, BufResult, IntoInner, IoBuf, IoBufMut, IoVectoredBufMut};

use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE, AsyncRead, IoResult};
/// # AsyncBufRead
Expand Down Expand Up @@ -96,7 +96,12 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
}

if buf.need_fill() {
buf.with(|b| reader.read(b)).await?;
buf.with(|b| async move {
Berrysoft marked this conversation as resolved.
Show resolved Hide resolved
let len = b.buf_len();
let b = b.slice(len..);
reader.read(b).await.into_inner()
})
.await?;
}

Ok(buf.slice())
Expand Down
22 changes: 11 additions & 11 deletions compio-io/src/write/buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use compio_buf::{buf_try, BufResult, IntoInner, IoBuf};
use crate::{
buffer::Buffer,
util::{slice_to_buf, DEFAULT_BUF_SIZE},
AsyncWrite, IoResult,
AsyncWrite, AsyncWriteExt, IoResult,
};

/// Wraps a writer and buffers its output.
Expand Down Expand Up @@ -52,11 +52,12 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
async fn write<T: IoBuf>(&mut self, mut buf: T) -> compio_buf::BufResult<usize, T> {
let written = self
.buf
.with(|mut w| {
.with_sync(|w| {
let len = w.buf_len();
let mut w = w.slice(len..);
let written = slice_to_buf(buf.as_slice(), &mut w);
ready(BufResult(Ok(written), w))
BufResult(Ok(written), w.into_inner())
})
.await
.expect("Closure always return Ok");

if self.buf.need_flush() {
Expand All @@ -75,7 +76,10 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
.with(|mut w| {
let mut written = 0;
for buf in buf.as_dyn_bufs() {
written += slice_to_buf(buf.as_slice(), &mut w);
let len = w.buf_len();
let mut slice = w.slice(len..);
written += slice_to_buf(buf.as_slice(), &mut slice);
w = slice.into_inner();

if w.buf_len() == w.buf_capacity() {
break;
Expand All @@ -96,12 +100,8 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
async fn flush(&mut self) -> IoResult<()> {
let Self { writer, buf } = self;

let len = buf.with(|w| writer.write(w)).await?;
buf.advance(len);

if buf.all_done() {
buf.reset();
}
buf.with(|w| writer.write_all(w)).await?;
buf.reset();

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion compio-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ rustdoc-args = ["--cfg docsrs"]
compio-buf = { workspace = true }
compio-driver = { workspace = true }
compio-io = { workspace = true, optional = true }
compio-runtime = { workspace = true, optional = true }
compio-runtime = { workspace = true, optional = true, features = ["event"] }
Berrysoft marked this conversation as resolved.
Show resolved Hide resolved

cfg-if = "1"
either = "1"
Expand Down
23 changes: 23 additions & 0 deletions compio-tls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
[package]
name = "compio-tls"
version = "0.1.0"
categories = ["asynchronous", "network-programming"]
keywords = ["async", "net", "tls"]
edition = { workspace = true }
authors = { workspace = true }
readme = { workspace = true }
license = { workspace = true }
repository = { workspace = true }

[dependencies]
compio-buf = { workspace = true }
compio-io = { workspace = true }

native-tls = { version = "0.2", optional = true }

[dev-dependencies]
compio-net = { workspace = true, features = ["runtime"] }
compio-runtime = { workspace = true }

[features]
default = ["native-tls"]
109 changes: 109 additions & 0 deletions compio-tls/src/adapter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use std::io;

use compio_io::{AsyncRead, AsyncWrite};
use native_tls::HandshakeError;

use crate::{wrapper::StreamWrapper, TlsStream};

/// A wrapper around a [`native_tls::TlsConnector`], providing an async
/// `connect` method.
///
/// ```rust
/// use compio_io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// use compio_net::TcpStream;
/// use compio_tls::TlsConnector;
///
/// # compio_runtime::block_on(async {
/// let connector = TlsConnector::from(native_tls::TlsConnector::new().unwrap());
///
/// let stream = TcpStream::connect("www.example.com:443").await.unwrap();
/// let mut stream = connector.connect("www.example.com", stream).await.unwrap();
///
/// stream
/// .write_all("GET / HTTP/1.1\r\nHost:www.example.com\r\nConnection: close\r\n\r\n")
/// .await
/// .unwrap();
/// stream.flush().await.unwrap();
/// let (_, res) = stream.read_to_end(vec![]).await.unwrap();
/// println!("{}", String::from_utf8_lossy(&res));
/// # })
/// ```
#[derive(Debug, Clone)]
pub struct TlsConnector(native_tls::TlsConnector);

impl From<native_tls::TlsConnector> for TlsConnector {
fn from(value: native_tls::TlsConnector) -> Self {
Self(value)
}
}

impl TlsConnector {
/// Connects the provided stream with this connector, assuming the provided
/// domain.
///
/// This function will internally call `TlsConnector::connect` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used for clients who have already established, for
/// example, a TCP connection to a remote server. That stream is then
/// provided here to perform the client half of a connection to a
/// TLS-powered server.
pub async fn connect<S: AsyncRead + AsyncWrite>(
&self,
domain: &str,
stream: S,
) -> io::Result<TlsStream<S>> {
handshake(self.0.connect(domain, StreamWrapper::new(stream))).await
}
}

/// A wrapper around a [`native_tls::TlsAcceptor`], providing an async `accept`
/// method.
#[derive(Clone)]
pub struct TlsAcceptor(native_tls::TlsAcceptor);

impl From<native_tls::TlsAcceptor> for TlsAcceptor {
fn from(value: native_tls::TlsAcceptor) -> Self {
Self(value)
}
}

impl TlsAcceptor {
/// Accepts a new client connection with the provided stream.
///
/// This function will internally call `TlsAcceptor::accept` to connect
/// the stream and returns a future representing the resolution of the
/// connection operation. The returned future will resolve to either
/// `TlsStream<S>` or `Error` depending if it's successful or not.
///
/// This is typically used after a new socket has been accepted from a
/// `TcpListener`. That socket is then passed to this function to perform
/// the server half of accepting a client connection.
pub async fn accept<S: AsyncRead + AsyncWrite>(&self, stream: S) -> io::Result<TlsStream<S>> {
handshake(self.0.accept(StreamWrapper::new(stream))).await
}
}

async fn handshake<S: AsyncRead + AsyncWrite>(
mut res: Result<native_tls::TlsStream<StreamWrapper<S>>, HandshakeError<StreamWrapper<S>>>,
) -> io::Result<TlsStream<S>> {
loop {
match res {
Ok(mut s) => {
s.get_mut().flush_write_buf().await?;
return Ok(TlsStream::from(s));
}
Err(e) => match e {
HandshakeError::Failure(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
HandshakeError::WouldBlock(mut mid_stream) => {
if mid_stream.get_mut().flush_write_buf().await? == 0 {
mid_stream.get_mut().fill_read_buf().await?;
}
res = mid_stream.handshake();
}
},
}
}
}
11 changes: 11 additions & 0 deletions compio-tls/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//! Async TLS streams.

#![warn(missing_docs)]

mod adapter;
mod stream;
mod wrapper;

pub use adapter::*;
pub use stream::*;
pub(crate) use wrapper::*;
73 changes: 73 additions & 0 deletions compio-tls/src/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::{io, mem::MaybeUninit};

use compio_buf::{BufResult, IoBuf, IoBufMut};
use compio_io::{AsyncRead, AsyncWrite};

use crate::StreamWrapper;

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
///
/// A `TlsStream<S>` represents a handshake that has been completed successfully
/// and both the server and the client are ready for receiving and sending
/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
/// to a `TlsStream` are encrypted when passing through to `S`.
#[derive(Debug)]
pub struct TlsStream<S>(native_tls::TlsStream<StreamWrapper<S>>);

impl<S> From<native_tls::TlsStream<StreamWrapper<S>>> for TlsStream<S> {
fn from(value: native_tls::TlsStream<StreamWrapper<S>>) -> Self {
Self(value)
}
}

impl<S: AsyncRead> AsyncRead for TlsStream<S> {
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
let slice: &mut [MaybeUninit<u8>] = buf.as_mut_slice();
slice.fill(MaybeUninit::new(0));
let slice =
unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast(), slice.len()) };
loop {
let res = io::Read::read(&mut self.0, slice);
match res {
Ok(res) => {
unsafe { buf.set_buf_init(res) };
return BufResult(Ok(res), buf);
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
match self.0.get_mut().fill_read_buf().await {
Ok(_) => continue,
Err(e) => return BufResult(Err(e), buf),
}
}
_ => return BufResult(res, buf),
}
}
}
}

impl<S: AsyncWrite> AsyncWrite for TlsStream<S> {
async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
let slice = buf.as_slice();
loop {
let res = io::Write::write(&mut self.0, slice);
match res {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
Ok(_) => continue,
Err(e) => return BufResult(Err(e), buf),
},
_ => return BufResult(res, buf),
}
}
}

async fn flush(&mut self) -> io::Result<()> {
self.0.get_mut().flush_write_buf().await?;
Ok(())
}

async fn shutdown(&mut self) -> io::Result<()> {
self.flush().await?;
self.0.get_mut().get_mut().shutdown().await
}
}
Loading