Skip to content

Commit

Permalink
Merge pull request hyperium#28 from quininer/early-data
Browse files Browse the repository at this point in the history
Add 0-RTT support
  • Loading branch information
quininer authored Feb 19, 2019
2 parents 7d6ed0a + 527db99 commit 8b8647b
Show file tree
Hide file tree
Showing 9 changed files with 482 additions and 204 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ webpki = "0.19"
[dev-dependencies]
tokio = "0.1.6"
lazy_static = "1"
webpki-roots = "0.16"
196 changes: 196 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use super::*;
use std::io::Write;
use rustls::Session;


/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientSession,
pub(crate) state: TlsState,
pub(crate) early_data: (usize, Vec<u8>)
}

#[derive(Debug)]
pub(crate) enum TlsState {
EarlyData,
Stream,
Eof,
Shutdown
}

pub(crate) enum MidHandshake<IO> {
Handshaking(TlsStream<IO>),
EarlyData(TlsStream<IO>),
End
}

impl<IO> TlsStream<IO> {
#[inline]
pub fn get_ref(&self) -> (&IO, &ClientSession) {
(&self.io, &self.session)
}

#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
(&mut self.io, &mut self.session)
}

#[inline]
pub fn into_inner(self) -> (IO, ClientSession) {
(self.io, self.session)
}
}

impl<IO> Future for MidHandshake<IO>
where IO: AsyncRead + AsyncWrite,
{
type Item = TlsStream<IO>;
type Error = io::Error;

#[inline]
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self {
MidHandshake::Handshaking(stream) => {
let (io, session) = stream.get_mut();
let mut stream = Stream::new(io, session);

if stream.session.is_handshaking() {
try_nb!(stream.complete_io());
}

if stream.session.wants_write() {
try_nb!(stream.complete_io());
}
},
_ => ()
}

match mem::replace(self, MidHandshake::End) {
MidHandshake::Handshaking(stream)
| MidHandshake::EarlyData(stream) => Ok(Async::Ready(stream)),
MidHandshake::End => panic!()
}
}
}

impl<IO> io::Read for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);

match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;

// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}

// end
self.state = TlsState::Stream;
data.clear();
stream.read(buf)
},
TlsState::Stream => match stream.read(buf) {
Ok(0) => {
self.state = TlsState::Eof;
Ok(0)
},
Ok(n) => Ok(n),
Err(ref e) if e.kind() == io::ErrorKind::ConnectionAborted => {
self.state = TlsState::Shutdown;
stream.session.send_close_notify();
Ok(0)
},
Err(e) => Err(e)
},
TlsState::Eof | TlsState::Shutdown => Ok(0),
}
}
}

impl<IO> io::Write for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut stream = Stream::new(&mut self.io, &mut self.session);

match self.state {
TlsState::EarlyData => {
let (pos, data) = &mut self.early_data;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = early_data.write(buf)?;
data.extend_from_slice(&buf[..len]);
return Ok(len);
}

// complete handshake
if stream.session.is_handshaking() {
stream.complete_io()?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = stream.write(&data[*pos..])?;
*pos += len;
}
}

// end
self.state = TlsState::Stream;
data.clear();
stream.write(buf)
},
_ => stream.write(buf)
}
}

fn flush(&mut self) -> io::Result<()> {
Stream::new(&mut self.io, &mut self.session).flush()?;
self.io.flush()
}
}

impl<IO> AsyncRead for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [u8]) -> bool {
false
}
}

impl<IO> AsyncWrite for TlsStream<IO>
where IO: AsyncRead + AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
match self.state {
TlsState::Shutdown => (),
_ => {
self.session.send_close_notify();
self.state = TlsState::Shutdown;
}
}

{
let mut stream = Stream::new(&mut self.io, &mut self.session);
try_nb!(stream.complete_io());
}
self.io.shutdown()
}
}
20 changes: 10 additions & 10 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ use rustls::WriteV;
use tokio_io::{ AsyncRead, AsyncWrite };


pub struct Stream<'a, S: 'a, IO: 'a> {
pub session: &'a mut S,
pub io: &'a mut IO
pub struct Stream<'a, IO: 'a, S: 'a> {
pub io: &'a mut IO,
pub session: &'a mut S
}

pub trait WriteTls<'a, S: Session, IO: AsyncRead + AsyncWrite>: Read + Write {
pub trait WriteTls<'a, IO: AsyncRead + AsyncWrite, S: Session>: Read + Write {
fn write_tls(&mut self) -> io::Result<usize>;
}

impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> {
pub fn new(session: &'a mut S, io: &'a mut IO) -> Self {
Stream { session, io }
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Stream<'a, IO, S> {
pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
Stream { io, session }
}

pub fn complete_io(&mut self) -> io::Result<(usize, usize)> {
Expand Down Expand Up @@ -66,7 +66,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Stream<'a, S, IO> {
}
}

impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<'a, S, IO> {
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> WriteTls<'a, IO, S> for Stream<'a, IO, S> {
fn write_tls(&mut self) -> io::Result<usize> {
use futures::Async;
use self::vecbuf::VecBuf;
Expand All @@ -89,7 +89,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> WriteTls<'a, S, IO> for Stream<
}
}

impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> {
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Read for Stream<'a, IO, S> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
while self.session.wants_read() {
if let (0, 0) = self.complete_io()? {
Expand All @@ -100,7 +100,7 @@ impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Read for Stream<'a, S, IO> {
}
}

impl<'a, S: Session, IO: AsyncRead + AsyncWrite> Write for Stream<'a, S, IO> {
impl<'a, IO: AsyncRead + AsyncWrite, S: Session> Write for Stream<'a, IO, S> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let len = self.session.write(buf)?;
while self.session.wants_write() {
Expand Down
10 changes: 5 additions & 5 deletions src/common/test_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn stream_good() -> io::Result<()> {

{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good);
let mut stream = Stream::new(&mut good, &mut client);

let mut buf = Vec::new();
stream.read_to_end(&mut buf)?;
Expand All @@ -102,7 +102,7 @@ fn stream_bad() -> io::Result<()> {
client.set_buffer_limit(1024);

let mut bad = Bad(true);
let mut stream = Stream::new(&mut client, &mut bad);
let mut stream = Stream::new(&mut bad, &mut client);
assert_eq!(stream.write(&[0x42; 8])?, 8);
assert_eq!(stream.write(&[0x42; 8])?, 8);
let r = stream.write(&[0x00; 1024])?; // fill buffer
Expand All @@ -121,7 +121,7 @@ fn stream_handshake() -> io::Result<()> {

{
let mut good = Good(&mut server);
let mut stream = Stream::new(&mut client, &mut good);
let mut stream = Stream::new(&mut good, &mut client);
let (r, w) = stream.complete_io()?;

assert!(r > 0);
Expand All @@ -141,7 +141,7 @@ fn stream_handshake_eof() -> io::Result<()> {
let (_, mut client) = make_pair();

let mut bad = Bad(false);
let mut stream = Stream::new(&mut client, &mut bad);
let mut stream = Stream::new(&mut bad, &mut client);
let r = stream.complete_io();

assert_eq!(r.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
Expand Down Expand Up @@ -171,7 +171,7 @@ fn make_pair() -> (ServerSession, ClientSession) {

fn do_handshake(client: &mut ClientSession, server: &mut ServerSession) {
let mut good = Good(server);
let mut stream = Stream::new(client, &mut good);
let mut stream = Stream::new(&mut good, client);
stream.complete_io().unwrap();
stream.complete_io().unwrap();
}
Loading

0 comments on commit 8b8647b

Please sign in to comment.