Skip to content

Commit

Permalink
feat: Implement read limit
Browse files Browse the repository at this point in the history
  • Loading branch information
jakoschiko committed Mar 20, 2024
1 parent dce759a commit d656139
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 150 deletions.
16 changes: 16 additions & 0 deletions flow-test/src/server_tester.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ impl ServerTester {
}
}

pub async fn receive_error_because_command_too_long(&mut self, expected_bytes: &[u8]) {
let server = self.connection_state.greeted();
let error = server.progress().await.unwrap_err();
match error {
ServerFlowError::CommandTooLong { discarded_bytes } => {
assert_eq!(
expected_bytes.as_bstr(),
discarded_bytes.declassify().as_bstr()
);
}
error => {
panic!("Server has unexpected error: {error:?}");
}
}
}

/// Progresses internal responses without expecting any results.
pub async fn progress_internal_responses<T>(&mut self) -> T {
let server = self.connection_state.greeted();
Expand Down
30 changes: 30 additions & 0 deletions flow-test/tests/flow-test-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,33 @@ fn login_with_rejected_literal() {
rt.run2_and_select(client.receive(status), server.progress_internal_responses());
}
}

#[test]
fn command_larger_than_max_command_size() {
// The server will reject the command because it's larger than the max size
let max_command_size_tests = [9, 10, 20, 100, 10 * 1024 * 1024];

for max_command_size in max_command_size_tests {
let mut setup = TestSetup::default();
setup.server_flow_options.max_command_size = max_command_size as u32;

let (rt, mut server, mut client) = setup.setup_server();

let greeting = b"* OK ...\r\n";
rt.run2(server.send_greeting(greeting), client.receive(greeting));

// Command smaller than the max size can be received
let small_command = b"A1 NOOP\r\n";
rt.run2(
client.send(small_command),
server.receive_command(small_command),
);

// Command larger than the max size triggers an error
let large_command = &vec![b'.'; max_command_size + 1];
rt.run2(
client.send(large_command),
server.receive_error_because_command_too_long(&large_command[..max_command_size]),
);
}
}
19 changes: 15 additions & 4 deletions proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,20 @@ fn handle_client_event(
}
| ServerFlowError::LiteralTooLong {
ref discarded_bytes,
}
| ServerFlowError::CommandTooLong {
ref discarded_bytes,
}),
) => {
error!(role = "c2p", %error, ?discarded_bytes, "Discard client message");
return ControlFlow::Continue;
}
Err(ServerFlowError::Stream(error)) => {
error!(role = "c2p", %error, "Connection terminated");
Err(ServerFlowError::StreamClosed) => {
error!(role = "c2p", "Stream closed");
return ControlFlow::Abort;
}
Err(ServerFlowError::Io(error)) => {
error!(role = "c2p", %error, "IO error");
return ControlFlow::Abort;
}
};
Expand Down Expand Up @@ -341,8 +348,12 @@ fn handle_server_event(
error!(role = "c2p", %error, ?discarded_bytes, "Discard server message");
return ControlFlow::Continue;
}
Err(ClientFlowError::Stream(error)) => {
error!(role = "s2p", %error, "Connection terminated");
Err(ClientFlowError::StreamClosed) => {
error!(role = "s2p", "Stream closed");
return ControlFlow::Abort;
}
Err(ClientFlowError::Io(error)) => {
error!(role = "s2p", %error, "IO error");
return ControlFlow::Abort;
}
};
Expand Down
94 changes: 55 additions & 39 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle},
receive::{ReceiveEvent, ReceiveState},
send_command::{SendCommandEvent, SendCommandState, SendCommandTermination},
stream::{AnyStream, StreamError},
stream::{AnyStream, ReadBuffer, ReadError, WriteBuffer, WriteError},
types::CommandAuthenticate,
};

Expand Down Expand Up @@ -61,39 +61,46 @@ impl ClientFlow {
options: ClientFlowOptions,
) -> Result<(Self, Greeting<'static>), ClientFlowError> {
// Receive greeting.
let mut receive_greeting_state = ReceiveState::new(
GreetingCodec::default(),
options.crlf_relaxed,
BytesMut::new(),
);
let read_buffer = ReadBuffer {
bytes: BytesMut::new(),
// No limit is set because we trust the server
limit: None,
};
let mut receive_greeting_state =
ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer);

let greeting = match receive_greeting_state.progress(&mut stream).await? {
ReceiveEvent::DecodingSuccess(greeting) => {
let greeting = match receive_greeting_state.progress(&mut stream).await {
Ok(ReceiveEvent::DecodingSuccess(greeting)) => {
receive_greeting_state.finish_message();
greeting
}
ReceiveEvent::DecodingFailure(
Ok(ReceiveEvent::DecodingFailure(
GreetingDecodeError::Failed | GreetingDecodeError::Incomplete,
) => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::MalformedMessage {
discarded_bytes: Secret::new(discarded_bytes),
});
)) => {
let discarded_bytes = receive_greeting_state.discard_message().into();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
Ok(ReceiveEvent::ExpectedCrlfGotLf) => {
let discarded_bytes = receive_greeting_state.discard_message().into();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf {
discarded_bytes: Secret::new(discarded_bytes),
});
Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed),
Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)),
Err(ReadError::ReadBufferOverflow) => {
// Unreachable because limit is not set for read buffer.
unreachable!()
}
};

// Create state to send commands ...
let write_buffer = WriteBuffer {
bytes: BytesMut::new(),
};
let send_command_state = SendCommandState::new(
CommandCodec::default(),
AuthenticateDataCodec::default(),
IdleDoneCodec::default(),
BytesMut::new(),
write_buffer,
);

// ..., and state to receive responses.
Expand Down Expand Up @@ -171,33 +178,31 @@ impl ClientFlow {

async fn progress_receive(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> {
let event = loop {
let response = match self
.receive_response_state
.progress(&mut self.stream)
.await?
{
ReceiveEvent::DecodingSuccess(response) => {
let response = match self.receive_response_state.progress(&mut self.stream).await {
Ok(ReceiveEvent::DecodingSuccess(response)) => {
self.receive_response_state.finish_message();
response
}
ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => {
Ok(ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length })) => {
// The client must accept the literal in any case.
self.receive_response_state.start_literal(length);
continue;
}
ReceiveEvent::DecodingFailure(
Ok(ReceiveEvent::DecodingFailure(
ResponseDecodeError::Failed | ResponseDecodeError::Incomplete,
) => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::MalformedMessage {
discarded_bytes: Secret::new(discarded_bytes),
});
)) => {
let discarded_bytes = self.receive_response_state.discard_message().into();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
Ok(ReceiveEvent::ExpectedCrlfGotLf) => {
let discarded_bytes = self.receive_response_state.discard_message().into();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf {
discarded_bytes: Secret::new(discarded_bytes),
});
Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed),
Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)),
Err(ReadError::ReadBufferOverflow) => {
// Unreachable because limit is not set for read buffer.
unreachable!()
}
};

Expand Down Expand Up @@ -386,10 +391,21 @@ pub enum ClientFlowEvent {

#[derive(Debug, Error)]
pub enum ClientFlowError {
#[error("Stream was closed")]
StreamClosed,
#[error(transparent)]
Stream(#[from] StreamError),
Io(#[from] tokio::io::Error),
#[error("Expected `\\r\\n`, got `\\n`")]
ExpectedCrlfGotLf { discarded_bytes: Secret<Box<[u8]>> },
#[error("Received malformed message")]
MalformedMessage { discarded_bytes: Secret<Box<[u8]>> },
}

impl From<WriteError> for ClientFlowError {
fn from(err: WriteError) -> Self {
match err {
WriteError::Closed => Self::StreamClosed,
WriteError::Io(err) => Self::Io(err),
}
}
}
36 changes: 22 additions & 14 deletions src/receive.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use bounded_static::IntoBoundedStatic;
use bytes::{Buf, BytesMut};
use bytes::Buf;
use imap_codec::decode::Decoder;

use crate::stream::{AnyStream, StreamError};
use crate::stream::{AnyStream, ReadBuffer, ReadError};

pub struct ReceiveState<C> {
codec: C,
Expand All @@ -14,11 +14,11 @@ pub struct ReceiveState<C> {
seen_bytes: usize,
// Used for reading the current message from the stream.
// Its length should always be equal to or greater than `seen_bytes`.
read_buffer: BytesMut,
read_buffer: ReadBuffer,
}

impl<C> ReceiveState<C> {
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self {
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: ReadBuffer) -> Self {
Self {
codec,
crlf_relaxed,
Expand All @@ -30,22 +30,27 @@ impl<C> ReceiveState<C> {

pub fn start_literal(&mut self, length: u32) {
self.next_fragment = NextFragment::Literal { length };
self.read_buffer.reserve(length as usize);
self.read_buffer.bytes.reserve(length as usize);
}

pub fn finish_message(&mut self) {
self.read_buffer.advance(self.seen_bytes);
self.read_buffer.bytes.advance(self.seen_bytes);
self.seen_bytes = 0;
self.next_fragment = NextFragment::start_new_line();
}

pub fn discard_message(&mut self) -> Box<[u8]> {
let discarded_bytes = self.read_buffer[..self.seen_bytes].into();
let discarded_bytes = self.read_buffer.bytes[..self.seen_bytes].into();
self.finish_message();
discarded_bytes
}

pub async fn progress(&mut self, stream: &mut AnyStream) -> Result<ReceiveEvent<C>, StreamError>
pub fn discard_all_bytes(&mut self) -> Box<[u8]> {
self.seen_bytes = self.read_buffer.bytes.len();
self.discard_message()
}

pub async fn progress(&mut self, stream: &mut AnyStream) -> Result<ReceiveEvent<C>, ReadError>
where
C: Decoder,
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
Expand All @@ -69,21 +74,21 @@ impl<C> ReceiveState<C> {
&mut self,
stream: &mut AnyStream,
seen_bytes_in_line: usize,
) -> Result<Option<ReceiveEvent<C>>, StreamError>
) -> Result<Option<ReceiveEvent<C>>, ReadError>
where
C: Decoder,
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>,
{
let Some(crlf_result) = find_crlf(
&self.read_buffer[self.seen_bytes..],
&self.read_buffer.bytes[self.seen_bytes..],
seen_bytes_in_line,
self.crlf_relaxed,
) else {
// No full line received yet, more data needed.

// Mark the bytes of the partial line as seen.
let seen_bytes_in_line = self.read_buffer.len() - self.seen_bytes;
let seen_bytes_in_line = self.read_buffer.bytes.len() - self.seen_bytes;
self.next_fragment = NextFragment::Line { seen_bytes_in_line };

// Read more data.
Expand All @@ -104,7 +109,10 @@ impl<C> ReceiveState<C> {
// TODO(#129): If the message is really long and we need multiple attempts to receive it,
// then this is O(n^2). IMO this can be only fixed by using a generator-like
// decoder.
match self.codec.decode(&self.read_buffer[..self.seen_bytes]) {
match self
.codec
.decode(&self.read_buffer.bytes[..self.seen_bytes])
{
Ok((remaining, message)) => {
assert!(remaining.is_empty());
Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static())))
Expand All @@ -117,8 +125,8 @@ impl<C> ReceiveState<C> {
&mut self,
stream: &mut AnyStream,
literal_length: u32,
) -> Result<(), StreamError> {
let unseen_bytes = self.read_buffer.len() - self.seen_bytes;
) -> Result<(), ReadError> {
let unseen_bytes = self.read_buffer.bytes.len() - self.seen_bytes;

if unseen_bytes < literal_length as usize {
// We did not receive enough bytes for the literal yet.
Expand Down
Loading

0 comments on commit d656139

Please sign in to comment.