From 3ec00a3f0a258d06971252a8df9fcabd76abb29b Mon Sep 17 00:00:00 2001 From: Jakob Schikowski Date: Mon, 4 Mar 2024 22:59:20 +0100 Subject: [PATCH] feat: Implement read limit --- flow-test/src/server_tester.rs | 16 ++++ flow-test/tests/flow-test-server.rs | 30 ++++++ proxy/src/proxy.rs | 19 +++- src/client.rs | 94 ++++++++++-------- src/receive.rs | 36 ++++--- src/send_command.rs | 29 +++--- src/send_response.rs | 19 ++-- src/server.rs | 142 +++++++++++++++++----------- src/stream.rs | 77 ++++++++++++--- 9 files changed, 311 insertions(+), 151 deletions(-) diff --git a/flow-test/src/server_tester.rs b/flow-test/src/server_tester.rs index e3241312..436b49bb 100644 --- a/flow-test/src/server_tester.rs +++ b/flow-test/src/server_tester.rs @@ -148,6 +148,22 @@ impl ServerTester { } } + pub async fn receive_error_because_command_too_long(&mut self, expected_bytes: &[u8]) { + let server = self.connection_state.greeted(); + let error = server.progress().await.unwrap_err(); + match error { + ServerFlowError::CommandTooLong { discarded_bytes } => { + assert_eq!( + expected_bytes.as_bstr(), + discarded_bytes.declassify().as_bstr() + ); + } + error => { + panic!("Server has unexpected error: {error:?}"); + } + } + } + /// Progresses internal responses without expecting any results. pub async fn progress_internal_responses(&mut self) -> T { let server = self.connection_state.greeted(); diff --git a/flow-test/tests/flow-test-server.rs b/flow-test/tests/flow-test-server.rs index eefeb9ff..4e6c8ea8 100644 --- a/flow-test/tests/flow-test-server.rs +++ b/flow-test/tests/flow-test-server.rs @@ -155,3 +155,33 @@ fn login_with_rejected_literal() { rt.run2_and_select(client.receive(status), server.progress_internal_responses()); } } + +#[test] +fn command_larger_than_max_command_size() { + // The server will reject the command because it's larger than the max size + let max_command_size_tests = [9, 10, 20, 100, 10 * 1024 * 1024]; + + for max_command_size in max_command_size_tests { + let mut setup = TestSetup::default(); + setup.server_flow_options.max_command_size = max_command_size as u32; + + let (rt, mut server, mut client) = setup.setup_server(); + + let greeting = b"* OK ...\r\n"; + rt.run2(server.send_greeting(greeting), client.receive(greeting)); + + // Command smaller than the max size can be received + let small_command = b"A1 NOOP\r\n"; + rt.run2( + client.send(small_command), + server.receive_command(small_command), + ); + + // Command larger than the max size triggers an error + let large_command = &vec![b'.'; max_command_size + 1]; + rt.run2( + client.send(large_command), + server.receive_error_because_command_too_long(&large_command[..max_command_size]), + ); + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 1eeb110e..97004f9a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -262,13 +262,20 @@ fn handle_client_event( } | ServerFlowError::LiteralTooLong { ref discarded_bytes, + } + | ServerFlowError::CommandTooLong { + ref discarded_bytes, }), ) => { error!(role = "c2p", %error, ?discarded_bytes, "Discard client message"); return ControlFlow::Continue; } - Err(ServerFlowError::Stream(error)) => { - error!(role = "c2p", %error, "Connection terminated"); + Err(ServerFlowError::StreamClosed) => { + error!(role = "c2p", "Stream closed"); + return ControlFlow::Abort; + } + Err(ServerFlowError::Io(error)) => { + error!(role = "c2p", %error, "IO error"); return ControlFlow::Abort; } }; @@ -341,8 +348,12 @@ fn handle_server_event( error!(role = "c2p", %error, ?discarded_bytes, "Discard server message"); return ControlFlow::Continue; } - Err(ClientFlowError::Stream(error)) => { - error!(role = "s2p", %error, "Connection terminated"); + Err(ClientFlowError::StreamClosed) => { + error!(role = "s2p", "Stream closed"); + return ControlFlow::Abort; + } + Err(ClientFlowError::Io(error)) => { + error!(role = "s2p", %error, "IO error"); return ControlFlow::Abort; } }; diff --git a/src/client.rs b/src/client.rs index f5eacf66..d5fcde6a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -17,7 +17,7 @@ use crate::{ handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle}, receive::{ReceiveEvent, ReceiveState}, send_command::{SendCommandEvent, SendCommandState, SendCommandTermination}, - stream::{AnyStream, StreamError}, + stream::{AnyStream, ReadBuffer, ReadError, WriteBuffer, WriteError}, types::CommandAuthenticate, }; @@ -61,39 +61,46 @@ impl ClientFlow { options: ClientFlowOptions, ) -> Result<(Self, Greeting<'static>), ClientFlowError> { // Receive greeting. - let mut receive_greeting_state = ReceiveState::new( - GreetingCodec::default(), - options.crlf_relaxed, - BytesMut::new(), - ); + let read_buffer = ReadBuffer { + bytes: BytesMut::new(), + // No limit is set because we trust the server + limit: None, + }; + let mut receive_greeting_state = + ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer); - let greeting = match receive_greeting_state.progress(&mut stream).await? { - ReceiveEvent::DecodingSuccess(greeting) => { + let greeting = match receive_greeting_state.progress(&mut stream).await { + Ok(ReceiveEvent::DecodingSuccess(greeting)) => { receive_greeting_state.finish_message(); greeting } - ReceiveEvent::DecodingFailure( + Ok(ReceiveEvent::DecodingFailure( GreetingDecodeError::Failed | GreetingDecodeError::Incomplete, - ) => { - let discarded_bytes = receive_greeting_state.discard_message(); - return Err(ClientFlowError::MalformedMessage { - discarded_bytes: Secret::new(discarded_bytes), - }); + )) => { + let discarded_bytes = receive_greeting_state.discard_message().into(); + return Err(ClientFlowError::MalformedMessage { discarded_bytes }); + } + Ok(ReceiveEvent::ExpectedCrlfGotLf) => { + let discarded_bytes = receive_greeting_state.discard_message().into(); + return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); } - ReceiveEvent::ExpectedCrlfGotLf => { - let discarded_bytes = receive_greeting_state.discard_message(); - return Err(ClientFlowError::ExpectedCrlfGotLf { - discarded_bytes: Secret::new(discarded_bytes), - }); + Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed), + Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)), + Err(ReadError::BufferLimitReached) => { + // Unreachable because limit is not set for read buffer. + unreachable!() } }; // Create state to send commands ... + let write_buffer = WriteBuffer { + bytes: BytesMut::new(), + }; let send_command_state = SendCommandState::new( CommandCodec::default(), AuthenticateDataCodec::default(), IdleDoneCodec::default(), - BytesMut::new(), + write_buffer, ); // ..., and state to receive responses. @@ -171,33 +178,31 @@ impl ClientFlow { async fn progress_receive(&mut self) -> Result, ClientFlowError> { let event = loop { - let response = match self - .receive_response_state - .progress(&mut self.stream) - .await? - { - ReceiveEvent::DecodingSuccess(response) => { + let response = match self.receive_response_state.progress(&mut self.stream).await { + Ok(ReceiveEvent::DecodingSuccess(response)) => { self.receive_response_state.finish_message(); response } - ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => { + Ok(ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length })) => { // The client must accept the literal in any case. self.receive_response_state.start_literal(length); continue; } - ReceiveEvent::DecodingFailure( + Ok(ReceiveEvent::DecodingFailure( ResponseDecodeError::Failed | ResponseDecodeError::Incomplete, - ) => { - let discarded_bytes = self.receive_response_state.discard_message(); - return Err(ClientFlowError::MalformedMessage { - discarded_bytes: Secret::new(discarded_bytes), - }); + )) => { + let discarded_bytes = self.receive_response_state.discard_message().into(); + return Err(ClientFlowError::MalformedMessage { discarded_bytes }); + } + Ok(ReceiveEvent::ExpectedCrlfGotLf) => { + let discarded_bytes = self.receive_response_state.discard_message().into(); + return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes }); } - ReceiveEvent::ExpectedCrlfGotLf => { - let discarded_bytes = self.receive_response_state.discard_message(); - return Err(ClientFlowError::ExpectedCrlfGotLf { - discarded_bytes: Secret::new(discarded_bytes), - }); + Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed), + Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)), + Err(ReadError::BufferLimitReached) => { + // Unreachable because limit is not set for read buffer. + unreachable!() } }; @@ -386,10 +391,21 @@ pub enum ClientFlowEvent { #[derive(Debug, Error)] pub enum ClientFlowError { + #[error("Stream was closed")] + StreamClosed, #[error(transparent)] - Stream(#[from] StreamError), + Io(#[from] tokio::io::Error), #[error("Expected `\\r\\n`, got `\\n`")] ExpectedCrlfGotLf { discarded_bytes: Secret> }, #[error("Received malformed message")] MalformedMessage { discarded_bytes: Secret> }, } + +impl From for ClientFlowError { + fn from(err: WriteError) -> Self { + match err { + WriteError::Closed => Self::StreamClosed, + WriteError::Io(err) => Self::Io(err), + } + } +} diff --git a/src/receive.rs b/src/receive.rs index 15e0a5d1..2884f971 100644 --- a/src/receive.rs +++ b/src/receive.rs @@ -1,8 +1,8 @@ use bounded_static::IntoBoundedStatic; -use bytes::{Buf, BytesMut}; +use bytes::Buf; use imap_codec::decode::Decoder; -use crate::stream::{AnyStream, StreamError}; +use crate::stream::{AnyStream, ReadBuffer, ReadError}; pub struct ReceiveState { codec: C, @@ -14,11 +14,11 @@ pub struct ReceiveState { seen_bytes: usize, // Used for reading the current message from the stream. // Its length should always be equal to or greater than `seen_bytes`. - read_buffer: BytesMut, + read_buffer: ReadBuffer, } impl ReceiveState { - pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self { + pub fn new(codec: C, crlf_relaxed: bool, read_buffer: ReadBuffer) -> Self { Self { codec, crlf_relaxed, @@ -30,22 +30,27 @@ impl ReceiveState { pub fn start_literal(&mut self, length: u32) { self.next_fragment = NextFragment::Literal { length }; - self.read_buffer.reserve(length as usize); + self.read_buffer.bytes.reserve(length as usize); } pub fn finish_message(&mut self) { - self.read_buffer.advance(self.seen_bytes); + self.read_buffer.bytes.advance(self.seen_bytes); self.seen_bytes = 0; self.next_fragment = NextFragment::start_new_line(); } pub fn discard_message(&mut self) -> Box<[u8]> { - let discarded_bytes = self.read_buffer[..self.seen_bytes].into(); + let discarded_bytes = self.read_buffer.bytes[..self.seen_bytes].into(); self.finish_message(); discarded_bytes } - pub async fn progress(&mut self, stream: &mut AnyStream) -> Result, StreamError> + pub fn discard_all_bytes(&mut self) -> Box<[u8]> { + self.seen_bytes = self.read_buffer.bytes.len(); + self.discard_message() + } + + pub async fn progress(&mut self, stream: &mut AnyStream) -> Result, ReadError> where C: Decoder, for<'a> C::Message<'a>: IntoBoundedStatic>, @@ -69,21 +74,21 @@ impl ReceiveState { &mut self, stream: &mut AnyStream, seen_bytes_in_line: usize, - ) -> Result>, StreamError> + ) -> Result>, ReadError> where C: Decoder, for<'a> C::Message<'a>: IntoBoundedStatic>, for<'a> C::Error<'a>: IntoBoundedStatic>, { let Some(crlf_result) = find_crlf( - &self.read_buffer[self.seen_bytes..], + &self.read_buffer.bytes[self.seen_bytes..], seen_bytes_in_line, self.crlf_relaxed, ) else { // No full line received yet, more data needed. // Mark the bytes of the partial line as seen. - let seen_bytes_in_line = self.read_buffer.len() - self.seen_bytes; + let seen_bytes_in_line = self.read_buffer.bytes.len() - self.seen_bytes; self.next_fragment = NextFragment::Line { seen_bytes_in_line }; // Read more data. @@ -104,7 +109,10 @@ impl ReceiveState { // TODO(#129): If the message is really long and we need multiple attempts to receive it, // then this is O(n^2). IMO this can be only fixed by using a generator-like // decoder. - match self.codec.decode(&self.read_buffer[..self.seen_bytes]) { + match self + .codec + .decode(&self.read_buffer.bytes[..self.seen_bytes]) + { Ok((remaining, message)) => { assert!(remaining.is_empty()); Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static()))) @@ -117,8 +125,8 @@ impl ReceiveState { &mut self, stream: &mut AnyStream, literal_length: u32, - ) -> Result<(), StreamError> { - let unseen_bytes = self.read_buffer.len() - self.seen_bytes; + ) -> Result<(), ReadError> { + let unseen_bytes = self.read_buffer.bytes.len() - self.seen_bytes; if unseen_bytes < literal_length as usize { // We did not receive enough bytes for the literal yet. diff --git a/src/send_command.rs b/src/send_command.rs index 6c1eed6e..c7cf0b75 100644 --- a/src/send_command.rs +++ b/src/send_command.rs @@ -1,6 +1,5 @@ use std::collections::VecDeque; -use bytes::BytesMut; use imap_codec::{ encode::{Encoder, Fragment}, AuthenticateDataCodec, CommandCodec, IdleDoneCodec, @@ -16,7 +15,7 @@ use tracing::warn; use crate::{ client::ClientFlowCommandHandle, - stream::{AnyStream, StreamError}, + stream::{AnyStream, WriteBuffer, WriteError}, types::CommandAuthenticate, }; @@ -32,7 +31,7 @@ pub struct SendCommandState { /// Note that this buffer can be non-empty even if `current_command` is `None` /// because commands can be aborted (see `maybe_terminate`) but partially sent /// fragment must never be aborted. - write_buffer: BytesMut, + write_buffer: WriteBuffer, } impl SendCommandState { @@ -40,7 +39,7 @@ impl SendCommandState { command_codec: CommandCodec, authenticate_data_codec: AuthenticateDataCodec, idle_done_codec: IdleDoneCodec, - write_buffer: BytesMut, + write_buffer: WriteBuffer, ) -> Self { Self { command_codec, @@ -285,7 +284,7 @@ impl SendCommandState { pub async fn progress( &mut self, stream: &mut AnyStream, - ) -> Result, StreamError> { + ) -> Result, WriteError> { let current_command = match self.current_command.take() { Some(current_command) => { // We are currently sending a command but the sending process was aborted for one @@ -408,7 +407,7 @@ enum CurrentCommand { impl CurrentCommand { /// Pushes as many bytes as possible from the command to the buffer. - fn push_to_buffer(self, write_buffer: &mut BytesMut) -> Self { + fn push_to_buffer(self, write_buffer: &mut WriteBuffer) -> Self { match self { Self::Command(state) => Self::Command(state.push_to_buffer(write_buffer)), Self::Authenticate(state) => Self::Authenticate(state.push_to_buffer(write_buffer)), @@ -463,13 +462,13 @@ struct CommandState { } impl CommandState { - fn push_to_buffer(self, write_buffer: &mut BytesMut) -> Self { + fn push_to_buffer(self, write_buffer: &mut WriteBuffer) -> Self { let mut fragments = self.fragments; let activity = match self.activity { CommandActivity::PushingFragments { accepted_literal } => { // First push the accepted literal if available if let Some(data) = accepted_literal { - write_buffer.extend(data); + write_buffer.bytes.extend(data); } // Push as many fragments as possible @@ -482,7 +481,7 @@ impl CommandState { mode: LiteralMode::NonSync, }, ) => { - write_buffer.extend(data); + write_buffer.bytes.extend(data); } Some(Fragment::Literal { data, @@ -560,14 +559,14 @@ struct AuthenticateState { } impl AuthenticateState { - fn push_to_buffer(self, write_buffer: &mut BytesMut) -> Self { + fn push_to_buffer(self, write_buffer: &mut WriteBuffer) -> Self { let activity = match self.activity { AuthenticateActivity::PushingAuthenticate { authenticate } => { - write_buffer.extend(authenticate); + write_buffer.bytes.extend(authenticate); AuthenticateActivity::WaitingForAuthenticateSent } AuthenticateActivity::PushingAuthenticateData { authenticate_data } => { - write_buffer.extend(authenticate_data); + write_buffer.bytes.extend(authenticate_data); AuthenticateActivity::WaitingForAuthenticateDataSent } activity => activity, @@ -629,14 +628,14 @@ struct IdleState { } impl IdleState { - fn push_to_buffer(self, write_buffer: &mut BytesMut) -> Self { + fn push_to_buffer(self, write_buffer: &mut WriteBuffer) -> Self { let activity = match self.activity { IdleActivity::PushingIdle { idle } => { - write_buffer.extend(idle); + write_buffer.bytes.extend(idle); IdleActivity::WaitingForIdleSent } IdleActivity::PushingIdleDone { idle_done } => { - write_buffer.extend(idle_done); + write_buffer.bytes.extend(idle_done); IdleActivity::WaitingForIdleDoneSent } activity => activity, diff --git a/src/send_response.rs b/src/send_response.rs index e696dba6..906e55f6 100644 --- a/src/send_response.rs +++ b/src/send_response.rs @@ -1,11 +1,10 @@ use std::collections::VecDeque; -use bytes::BytesMut; use imap_codec::encode::{Encoder, Fragment}; use crate::{ server::ServerFlowResponseHandle, - stream::{AnyStream, StreamError}, + stream::{AnyStream, WriteBuffer, WriteError}, }; pub struct SendResponseState { @@ -16,11 +15,11 @@ pub struct SendResponseState { current_response: Option>, // Used for writing the current response to the stream. // Should be empty if `current_response` is `None`. - write_buffer: BytesMut, + write_buffer: WriteBuffer, } impl SendResponseState { - pub fn new(codec: C, write_buffer: BytesMut) -> Self { + pub fn new(codec: C, write_buffer: WriteBuffer) -> Self { Self { codec, queued_responses: VecDeque::new(), @@ -38,15 +37,15 @@ impl SendResponseState { .push_back(QueuedResponse { handle, response }); } - pub fn finish(mut self) -> BytesMut { - self.write_buffer.clear(); + pub fn finish(mut self) -> WriteBuffer { + self.write_buffer.bytes.clear(); self.write_buffer } pub async fn progress( &mut self, stream: &mut AnyStream, - ) -> Result>, StreamError> { + ) -> Result>, WriteError> { let current_response = match self.current_response.take() { Some(current_response) => { // We are currently sending a response but the sending process was cancelled. @@ -54,7 +53,7 @@ impl SendResponseState { current_response } None => { - assert!(self.write_buffer.is_empty()); + assert!(self.write_buffer.bytes.is_empty()); let Some(queued_response) = self.queued_responses.pop_front() else { // There is currently no response that needs to be sent @@ -89,7 +88,7 @@ struct QueuedResponse { } impl QueuedResponse { - fn push_to_buffer(self, write_buffer: &mut BytesMut, codec: &C) -> CurrentResponse { + fn push_to_buffer(self, write_buffer: &mut WriteBuffer, codec: &C) -> CurrentResponse { for fragment in codec.encode(&self.response) { let data = match fragment { Fragment::Line { data } => data, @@ -99,7 +98,7 @@ impl QueuedResponse { // see https://github.com/duesee/imap-codec/issues/332 Fragment::Literal { data, .. } => data, }; - write_buffer.extend(data); + write_buffer.bytes.extend(data); } CurrentResponse { diff --git a/src/server.rs b/src/server.rs index fbce0481..914ecb44 100644 --- a/src/server.rs +++ b/src/server.rs @@ -22,7 +22,7 @@ use crate::{ handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle}, receive::{ReceiveEvent, ReceiveState}, send_response::{SendResponseEvent, SendResponseState}, - stream::{AnyStream, StreamError}, + stream::{AnyStream, ReadBuffer, ReadError, WriteBuffer, WriteError}, types::CommandAuthenticate, }; @@ -33,7 +33,18 @@ static HANDLE_GENERATOR_GENERATOR: HandleGeneratorGenerator, pub literal_reject_text: Text<'static>, } @@ -43,8 +54,11 @@ impl Default for ServerFlowOptions { Self { // Lean towards conformity crlf_relaxed: false, - // 25 MiB is a common maximum email size (Oct. 2023) + // 25 MiB is a common maximum email size (Oct. 2023). max_literal_size: 25 * 1024 * 1024, + // Must be bigger than `max_literal_size`. + // 64 KiB is used by Dovecot. + max_command_size: (25 * 1024 * 1024) + (64 * 1024), // Short unmeaning text literal_accept_text: Text::unvalidated("..."), // Short unmeaning text @@ -77,7 +91,9 @@ impl ServerFlow { greeting: Greeting<'static>, ) -> Result<(Self, Greeting<'static>), ServerFlowError> { // Send greeting - let write_buffer = BytesMut::new(); + let write_buffer = WriteBuffer { + bytes: BytesMut::new(), + }; let mut send_greeting_state = SendResponseState::new(GreetingCodec::default(), write_buffer); send_greeting_state.enqueue(None, greeting); @@ -93,7 +109,10 @@ impl ServerFlow { // Successfully sent greeting, construct instance let write_buffer = send_greeting_state.finish(); let send_response_state = SendResponseState::new(ResponseCodec::default(), write_buffer); - let read_buffer = BytesMut::new(); + let read_buffer = ReadBuffer { + bytes: BytesMut::new(), + limit: Some(options.max_command_size as usize), + }; let receive_command_state = ReceiveState::new(CommandCodec::default(), options.crlf_relaxed, read_buffer); let server_flow = Self { @@ -202,8 +221,8 @@ impl ServerFlow { async fn progress_receive(&mut self) -> Result, ServerFlowError> { match &mut self.receive_command_state { ServerReceiveState::Command(state) => { - match state.progress(&mut self.stream).await? { - ReceiveEvent::DecodingSuccess(command) => { + match state.progress(&mut self.stream).await { + Ok(ReceiveEvent::DecodingSuccess(command)) => { state.finish_message(); match command.body { @@ -238,11 +257,11 @@ impl ServerFlow { })), } } - ReceiveEvent::DecodingFailure(CommandDecodeError::LiteralFound { + Ok(ReceiveEvent::DecodingFailure(CommandDecodeError::LiteralFound { tag, length, mode, - }) => { + })) => { if length > self.options.max_literal_size { match mode { LiteralMode::Sync => { @@ -258,11 +277,9 @@ impl ServerFlow { self.send_response_state .enqueue(None, Response::Status(status)); - let discarded_bytes = state.discard_message(); + let discarded_bytes = state.discard_message().into(); - Err(ServerFlowError::LiteralTooLong { - discarded_bytes: Secret::new(discarded_bytes), - }) + Err(ServerFlowError::LiteralTooLong { discarded_bytes }) } LiteralMode::NonSync => { // TODO: We can't (reliably) make the client stop sending data. @@ -273,11 +290,9 @@ impl ServerFlow { // * ... // // The LITERAL+ RFC has some recommendations. - let discarded_bytes = state.discard_message(); + let discarded_bytes = state.discard_message().into(); - Err(ServerFlowError::LiteralTooLong { - discarded_bytes: Secret::new(discarded_bytes), - }) + Err(ServerFlowError::LiteralTooLong { discarded_bytes }) } } } else { @@ -305,44 +320,48 @@ impl ServerFlow { Ok(None) } } - ReceiveEvent::DecodingFailure( + Ok(ReceiveEvent::DecodingFailure( CommandDecodeError::Failed | CommandDecodeError::Incomplete, - ) => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::MalformedMessage { - discarded_bytes: Secret::new(discarded_bytes), - }) + )) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::MalformedMessage { discarded_bytes }) + } + Ok(ReceiveEvent::ExpectedCrlfGotLf) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::ExpectedCrlfGotLf { discarded_bytes }) } - ReceiveEvent::ExpectedCrlfGotLf => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::ExpectedCrlfGotLf { - discarded_bytes: Secret::new(discarded_bytes), - }) + Err(ReadError::Closed) => Err(ServerFlowError::StreamClosed), + Err(ReadError::Io(err)) => Err(ServerFlowError::Io(err)), + Err(ReadError::BufferLimitReached) => { + let discarded_bytes = state.discard_all_bytes().into(); + Err(ServerFlowError::CommandTooLong { discarded_bytes }) } } } ServerReceiveState::AuthenticateData(state) => { - match state.progress(&mut self.stream).await? { - ReceiveEvent::DecodingSuccess(authenticate_data) => { + match state.progress(&mut self.stream).await { + Ok(ReceiveEvent::DecodingSuccess(authenticate_data)) => { state.finish_message(); Ok(Some(ServerFlowEvent::AuthenticateDataReceived { authenticate_data, })) } - ReceiveEvent::DecodingFailure( + Ok(ReceiveEvent::DecodingFailure( AuthenticateDataDecodeError::Failed | AuthenticateDataDecodeError::Incomplete, - ) => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::MalformedMessage { - discarded_bytes: Secret::new(discarded_bytes), - }) + )) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::MalformedMessage { discarded_bytes }) } - ReceiveEvent::ExpectedCrlfGotLf => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::ExpectedCrlfGotLf { - discarded_bytes: Secret::new(discarded_bytes), - }) + Ok(ReceiveEvent::ExpectedCrlfGotLf) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::ExpectedCrlfGotLf { discarded_bytes }) + } + Err(ReadError::Closed) => Err(ServerFlowError::StreamClosed), + Err(ReadError::Io(err)) => Err(ServerFlowError::Io(err)), + Err(ReadError::BufferLimitReached) => { + let discarded_bytes = state.discard_all_bytes().into(); + Err(ServerFlowError::CommandTooLong { discarded_bytes }) } } } @@ -352,8 +371,8 @@ impl ServerFlow { // `idle_accept` or `idle_reject`. pending().await } - ServerReceiveState::IdleDone(state) => match state.progress(&mut self.stream).await? { - ReceiveEvent::DecodingSuccess(IdleDone) => { + ServerReceiveState::IdleDone(state) => match state.progress(&mut self.stream).await { + Ok(ReceiveEvent::DecodingSuccess(IdleDone)) => { state.finish_message(); self.receive_command_state @@ -361,19 +380,21 @@ impl ServerFlow { Ok(Some(ServerFlowEvent::IdleDoneReceived)) } - ReceiveEvent::DecodingFailure( + Ok(ReceiveEvent::DecodingFailure( IdleDoneDecodeError::Failed | IdleDoneDecodeError::Incomplete, - ) => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::MalformedMessage { - discarded_bytes: Secret::new(discarded_bytes), - }) + )) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::MalformedMessage { discarded_bytes }) + } + Ok(ReceiveEvent::ExpectedCrlfGotLf) => { + let discarded_bytes = state.discard_message().into(); + Err(ServerFlowError::ExpectedCrlfGotLf { discarded_bytes }) } - ReceiveEvent::ExpectedCrlfGotLf => { - let discarded_bytes = state.discard_message(); - Err(ServerFlowError::ExpectedCrlfGotLf { - discarded_bytes: Secret::new(discarded_bytes), - }) + Err(ReadError::Closed) => Err(ServerFlowError::StreamClosed), + Err(ReadError::Io(err)) => Err(ServerFlowError::Io(err)), + Err(ReadError::BufferLimitReached) => { + let discarded_bytes = state.discard_all_bytes().into(); + Err(ServerFlowError::CommandTooLong { discarded_bytes }) } }, ServerReceiveState::Dummy => { @@ -587,14 +608,27 @@ pub enum ServerFlowEvent { #[derive(Debug, Error)] pub enum ServerFlowError { + #[error("Stream was closed")] + StreamClosed, #[error(transparent)] - Stream(#[from] StreamError), + Io(#[from] tokio::io::Error), #[error("Expected `\\r\\n`, got `\\n`")] ExpectedCrlfGotLf { discarded_bytes: Secret> }, #[error("Received malformed message")] MalformedMessage { discarded_bytes: Secret> }, #[error("Literal was rejected because it was too long")] LiteralTooLong { discarded_bytes: Secret> }, + #[error("Command was rejected because it was too long")] + CommandTooLong { discarded_bytes: Secret> }, +} + +impl From for ServerFlowError { + fn from(err: WriteError) -> Self { + match err { + WriteError::Closed => Self::StreamClosed, + WriteError::Io(err) => Self::Io(err), + } + } } /// A dummy codec we use for technical reasons when we don't want to receive anything at all. diff --git a/src/stream.rs b/src/stream.rs index 1a6053c3..0af6cfe2 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -24,14 +24,27 @@ impl AnyStream { /// Reads at least one byte into the buffer and returns the number of read bytes. /// /// Returns [`StreamError::Closed`] when no bytes could be read. - pub async fn read(&mut self, read_buffer: &mut BytesMut) -> Result { - #[cfg(debug_assertions)] - let old_len = read_buffer.len(); - let byte_count = self.0.read_buf(read_buffer).await?; + pub async fn read(&mut self, read_buffer: &mut ReadBuffer) -> Result { + let current_len = read_buffer.bytes.len(); + + let byte_count = match read_buffer.limit { + None => self.0.read_buf(&mut read_buffer.bytes).await?, + Some(limit) => { + let remaining_byte_count = limit.saturating_sub(current_len); + if remaining_byte_count == 0 { + return Err(ReadError::BufferLimitReached); + } + + (&mut self.0) + .take(remaining_byte_count as u64) + .read_buf(&mut read_buffer.bytes) + .await? + } + }; #[cfg(debug_assertions)] trace!( - data = escape_byte_string(&read_buffer[old_len..]), + data = escape_byte_string(&read_buffer.bytes[current_len..]), "io/read/raw" ); @@ -40,7 +53,7 @@ impl AnyStream { // The result is 0 if the stream reached "end of file" or the read buffer was // already full before calling `read_buf`. Because we use an unlimited buffer we // know that the first case occurred. - Err(StreamError::Closed) + Err(ReadError::Closed) } Some(byte_count) => Ok(byte_count), } @@ -49,21 +62,21 @@ impl AnyStream { /// Writes all bytes from the write buffer. /// /// Returns [`StreamError::Closed`] when not all bytes could be written. - pub async fn write_all(&mut self, write_buffer: &mut BytesMut) -> Result<(), StreamError> { - while !write_buffer.is_empty() { - let byte_count = self.0.write(write_buffer).await?; + pub async fn write_all(&mut self, write_buffer: &mut WriteBuffer) -> Result<(), WriteError> { + while !write_buffer.bytes.is_empty() { + let byte_count = self.0.write(&write_buffer.bytes).await?; #[cfg(debug_assertions)] trace!( - data = escape_byte_string(&write_buffer[..byte_count]), + data = escape_byte_string(&write_buffer.bytes[..byte_count]), "io/write/raw" ); - write_buffer.advance(byte_count); + write_buffer.bytes.advance(byte_count); if byte_count == 0 { // The result is 0 if the stream doesn't accept bytes anymore or the write buffer // was already empty before calling `write_buf`. Because we checked the buffer // we know that the first case occurred. - return Err(StreamError::Closed); + return Err(WriteError::Closed); } } @@ -71,12 +84,29 @@ impl AnyStream { } } -/// Error during reading from or writing to a [`Stream`]. +/// Error raised by [`AnyStream::read`]. +#[derive(Debug, Error)] +pub enum ReadError { + /// The operation failed because the stream is closed. + /// + /// We detect this by checking if the read byte count is 0. Whether the stream is + /// closed indefinitely or temporarily depend on the actual stream implementation. + #[error("Stream was closed")] + Closed, + /// An I/O error occurred in the underlying stream. + #[error(transparent)] + Io(#[from] tokio::io::Error), + /// Can't read more bytes because the buffer limit is already reached. + #[error("Read buffer has overflown")] + BufferLimitReached, +} + +/// Error raised by [`AnyStream::write_all`]. #[derive(Debug, Error)] -pub enum StreamError { +pub enum WriteError { /// The operation failed because the stream is closed. /// - /// We detect this by checking if the read or written byte count is 0. Whether the stream is + /// We detect this by checking if the written byte count is 0. Whether the stream is /// closed indefinitely or temporarily depend on the actual stream implementation. #[error("Stream was closed")] Closed, @@ -84,3 +114,20 @@ pub enum StreamError { #[error(transparent)] Io(#[from] tokio::io::Error), } + +/// Buffer for reading bytes with [`AnyStream::read`]. +#[derive(Default)] +pub struct ReadBuffer { + pub bytes: BytesMut, + /// The max number of bytes to be stored in `bytes`. + /// + /// If the maximum number is reached and [`AnyStream::read`] is called, + /// it results in [`ReadError::BufferLimitReached`]. + pub limit: Option, +} + +/// Buffer for writing bytes with [`AnyStream::write`]. +#[derive(Default)] +pub struct WriteBuffer { + pub bytes: BytesMut, +}