Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement read limit #153

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions flow-test/src/server_tester.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ impl ServerTester {
}
}

pub async fn receive_error_because_command_too_long(&mut self, expected_bytes: &[u8]) {
let server = self.connection_state.greeted();
let error = server.progress().await.unwrap_err();
match error {
ServerFlowError::CommandTooLong { discarded_bytes } => {
assert_eq!(
expected_bytes.as_bstr(),
discarded_bytes.declassify().as_bstr()
);
}
error => {
panic!("Server has unexpected error: {error:?}");
}
}
}

/// Progresses internal responses without expecting any results.
pub async fn progress_internal_responses<T>(&mut self) -> T {
let server = self.connection_state.greeted();
Expand Down
30 changes: 30 additions & 0 deletions flow-test/tests/flow-test-server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,33 @@ fn login_with_rejected_literal() {
rt.run2_and_select(client.receive(status), server.progress_internal_responses());
}
}

#[test]
fn command_larger_than_max_command_size() {
// The server will reject the command because it's larger than the max size
let max_command_size_tests = [9, 10, 20, 100, 10 * 1024 * 1024];

for max_command_size in max_command_size_tests {
let mut setup = TestSetup::default();
setup.server_flow_options.max_command_size = max_command_size as u32;

let (rt, mut server, mut client) = setup.setup_server();

let greeting = b"* OK ...\r\n";
rt.run2(server.send_greeting(greeting), client.receive(greeting));

// Command smaller than the max size can be received
let small_command = b"A1 NOOP\r\n";
rt.run2(
client.send(small_command),
server.receive_command(small_command),
);

// Command larger than the max size triggers an error
let large_command = &vec![b'.'; max_command_size + 1];
rt.run2(
client.send(large_command),
server.receive_error_because_command_too_long(&large_command[..max_command_size]),
);
}
}
19 changes: 15 additions & 4 deletions proxy/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,20 @@ fn handle_client_event(
}
| ServerFlowError::LiteralTooLong {
ref discarded_bytes,
}
| ServerFlowError::CommandTooLong {
ref discarded_bytes,
}),
) => {
error!(role = "c2p", %error, ?discarded_bytes, "Discard client message");
return ControlFlow::Continue;
}
Err(ServerFlowError::Stream(error)) => {
error!(role = "c2p", %error, "Connection terminated");
Err(ServerFlowError::StreamClosed) => {
error!(role = "c2p", "Stream closed");
return ControlFlow::Abort;
}
Err(ServerFlowError::Io(error)) => {
error!(role = "c2p", %error, "IO error");
return ControlFlow::Abort;
}
};
Expand Down Expand Up @@ -341,8 +348,12 @@ fn handle_server_event(
error!(role = "c2p", %error, ?discarded_bytes, "Discard server message");
return ControlFlow::Continue;
}
Err(ClientFlowError::Stream(error)) => {
error!(role = "s2p", %error, "Connection terminated");
Err(ClientFlowError::StreamClosed) => {
error!(role = "s2p", "Stream closed");
return ControlFlow::Abort;
}
Err(ClientFlowError::Io(error)) => {
error!(role = "s2p", %error, "IO error");
return ControlFlow::Abort;
}
};
Expand Down
94 changes: 55 additions & 39 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::{
handle::{Handle, HandleGenerator, HandleGeneratorGenerator, RawHandle},
receive::{ReceiveEvent, ReceiveState},
send_command::{SendCommandEvent, SendCommandState, SendCommandTermination},
stream::{AnyStream, StreamError},
stream::{AnyStream, ReadBuffer, ReadError, WriteBuffer, WriteError},
types::CommandAuthenticate,
};

Expand Down Expand Up @@ -61,39 +61,46 @@ impl ClientFlow {
options: ClientFlowOptions,
) -> Result<(Self, Greeting<'static>), ClientFlowError> {
// Receive greeting.
let mut receive_greeting_state = ReceiveState::new(
GreetingCodec::default(),
options.crlf_relaxed,
BytesMut::new(),
);
let read_buffer = ReadBuffer {
bytes: BytesMut::new(),
// No limit is set because we trust the server
limit: None,
};
let mut receive_greeting_state =
ReceiveState::new(GreetingCodec::default(), options.crlf_relaxed, read_buffer);

let greeting = match receive_greeting_state.progress(&mut stream).await? {
ReceiveEvent::DecodingSuccess(greeting) => {
let greeting = match receive_greeting_state.progress(&mut stream).await {
Ok(ReceiveEvent::DecodingSuccess(greeting)) => {
receive_greeting_state.finish_message();
greeting
}
ReceiveEvent::DecodingFailure(
Ok(ReceiveEvent::DecodingFailure(
GreetingDecodeError::Failed | GreetingDecodeError::Incomplete,
) => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::MalformedMessage {
discarded_bytes: Secret::new(discarded_bytes),
});
)) => {
let discarded_bytes = receive_greeting_state.discard_message().into();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
Ok(ReceiveEvent::ExpectedCrlfGotLf) => {
let discarded_bytes = receive_greeting_state.discard_message().into();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = receive_greeting_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf {
discarded_bytes: Secret::new(discarded_bytes),
});
Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed),
Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)),
Err(ReadError::BufferLimitReached) => {
// Unreachable because limit is not set for read buffer.
unreachable!()
}
};

// Create state to send commands ...
let write_buffer = WriteBuffer {
bytes: BytesMut::new(),
};
let send_command_state = SendCommandState::new(
CommandCodec::default(),
AuthenticateDataCodec::default(),
IdleDoneCodec::default(),
BytesMut::new(),
write_buffer,
);

// ..., and state to receive responses.
Expand Down Expand Up @@ -171,33 +178,31 @@ impl ClientFlow {

async fn progress_receive(&mut self) -> Result<Option<ClientFlowEvent>, ClientFlowError> {
let event = loop {
let response = match self
.receive_response_state
.progress(&mut self.stream)
.await?
{
ReceiveEvent::DecodingSuccess(response) => {
let response = match self.receive_response_state.progress(&mut self.stream).await {
Ok(ReceiveEvent::DecodingSuccess(response)) => {
self.receive_response_state.finish_message();
response
}
ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length }) => {
Ok(ReceiveEvent::DecodingFailure(ResponseDecodeError::LiteralFound { length })) => {
// The client must accept the literal in any case.
self.receive_response_state.start_literal(length);
continue;
}
ReceiveEvent::DecodingFailure(
Ok(ReceiveEvent::DecodingFailure(
ResponseDecodeError::Failed | ResponseDecodeError::Incomplete,
) => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::MalformedMessage {
discarded_bytes: Secret::new(discarded_bytes),
});
)) => {
let discarded_bytes = self.receive_response_state.discard_message().into();
return Err(ClientFlowError::MalformedMessage { discarded_bytes });
}
Ok(ReceiveEvent::ExpectedCrlfGotLf) => {
let discarded_bytes = self.receive_response_state.discard_message().into();
return Err(ClientFlowError::ExpectedCrlfGotLf { discarded_bytes });
}
ReceiveEvent::ExpectedCrlfGotLf => {
let discarded_bytes = self.receive_response_state.discard_message();
return Err(ClientFlowError::ExpectedCrlfGotLf {
discarded_bytes: Secret::new(discarded_bytes),
});
Err(ReadError::Closed) => return Err(ClientFlowError::StreamClosed),
Err(ReadError::Io(err)) => return Err(ClientFlowError::Io(err)),
Err(ReadError::BufferLimitReached) => {
// Unreachable because limit is not set for read buffer.
unreachable!()
}
};

Expand Down Expand Up @@ -386,10 +391,21 @@ pub enum ClientFlowEvent {

#[derive(Debug, Error)]
pub enum ClientFlowError {
#[error("Stream was closed")]
StreamClosed,
#[error(transparent)]
Stream(#[from] StreamError),
Io(#[from] tokio::io::Error),
#[error("Expected `\\r\\n`, got `\\n`")]
ExpectedCrlfGotLf { discarded_bytes: Secret<Box<[u8]>> },
#[error("Received malformed message")]
MalformedMessage { discarded_bytes: Secret<Box<[u8]>> },
}

impl From<WriteError> for ClientFlowError {
fn from(err: WriteError) -> Self {
match err {
WriteError::Closed => Self::StreamClosed,
WriteError::Io(err) => Self::Io(err),
}
}
}
36 changes: 22 additions & 14 deletions src/receive.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use bounded_static::IntoBoundedStatic;
use bytes::{Buf, BytesMut};
use bytes::Buf;
use imap_codec::decode::Decoder;

use crate::stream::{AnyStream, StreamError};
use crate::stream::{AnyStream, ReadBuffer, ReadError};

pub struct ReceiveState<C> {
codec: C,
Expand All @@ -14,11 +14,11 @@ pub struct ReceiveState<C> {
seen_bytes: usize,
// Used for reading the current message from the stream.
// Its length should always be equal to or greater than `seen_bytes`.
read_buffer: BytesMut,
read_buffer: ReadBuffer,
}

impl<C> ReceiveState<C> {
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: BytesMut) -> Self {
pub fn new(codec: C, crlf_relaxed: bool, read_buffer: ReadBuffer) -> Self {
Self {
codec,
crlf_relaxed,
Expand All @@ -30,22 +30,27 @@ impl<C> ReceiveState<C> {

pub fn start_literal(&mut self, length: u32) {
self.next_fragment = NextFragment::Literal { length };
self.read_buffer.reserve(length as usize);
self.read_buffer.bytes.reserve(length as usize);
}

pub fn finish_message(&mut self) {
self.read_buffer.advance(self.seen_bytes);
self.read_buffer.bytes.advance(self.seen_bytes);
self.seen_bytes = 0;
self.next_fragment = NextFragment::start_new_line();
}

pub fn discard_message(&mut self) -> Box<[u8]> {
let discarded_bytes = self.read_buffer[..self.seen_bytes].into();
let discarded_bytes = self.read_buffer.bytes[..self.seen_bytes].into();
self.finish_message();
discarded_bytes
}

pub async fn progress(&mut self, stream: &mut AnyStream) -> Result<ReceiveEvent<C>, StreamError>
pub fn discard_all_bytes(&mut self) -> Box<[u8]> {
self.seen_bytes = self.read_buffer.bytes.len();
self.discard_message()
}

pub async fn progress(&mut self, stream: &mut AnyStream) -> Result<ReceiveEvent<C>, ReadError>
where
C: Decoder,
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
Expand All @@ -69,21 +74,21 @@ impl<C> ReceiveState<C> {
&mut self,
stream: &mut AnyStream,
seen_bytes_in_line: usize,
) -> Result<Option<ReceiveEvent<C>>, StreamError>
) -> Result<Option<ReceiveEvent<C>>, ReadError>
where
C: Decoder,
for<'a> C::Message<'a>: IntoBoundedStatic<Static = C::Message<'static>>,
for<'a> C::Error<'a>: IntoBoundedStatic<Static = C::Error<'static>>,
{
let Some(crlf_result) = find_crlf(
&self.read_buffer[self.seen_bytes..],
&self.read_buffer.bytes[self.seen_bytes..],
seen_bytes_in_line,
self.crlf_relaxed,
) else {
// No full line received yet, more data needed.

// Mark the bytes of the partial line as seen.
let seen_bytes_in_line = self.read_buffer.len() - self.seen_bytes;
let seen_bytes_in_line = self.read_buffer.bytes.len() - self.seen_bytes;
self.next_fragment = NextFragment::Line { seen_bytes_in_line };

// Read more data.
Expand All @@ -104,7 +109,10 @@ impl<C> ReceiveState<C> {
// TODO(#129): If the message is really long and we need multiple attempts to receive it,
// then this is O(n^2). IMO this can be only fixed by using a generator-like
// decoder.
match self.codec.decode(&self.read_buffer[..self.seen_bytes]) {
match self
.codec
.decode(&self.read_buffer.bytes[..self.seen_bytes])
{
Ok((remaining, message)) => {
assert!(remaining.is_empty());
Ok(Some(ReceiveEvent::DecodingSuccess(message.into_static())))
Expand All @@ -117,8 +125,8 @@ impl<C> ReceiveState<C> {
&mut self,
stream: &mut AnyStream,
literal_length: u32,
) -> Result<(), StreamError> {
let unseen_bytes = self.read_buffer.len() - self.seen_bytes;
) -> Result<(), ReadError> {
let unseen_bytes = self.read_buffer.bytes.len() - self.seen_bytes;

if unseen_bytes < literal_length as usize {
// We did not receive enough bytes for the literal yet.
Expand Down
Loading
Loading