From 803ffb51d0942cf30e30bc33b66bfdc26f92ad71 Mon Sep 17 00:00:00 2001 From: Riccardo Zaglia Date: Tue, 8 Aug 2023 15:45:11 +0800 Subject: [PATCH] Finish ControlSocket sync conversion --- alvr/client_core/src/connection.rs | 124 ++++++------- alvr/server/src/connection.rs | 268 ++++++++++++++--------------- alvr/sockets/src/backend/mod.rs | 19 ++ alvr/sockets/src/backend/tcp.rs | 97 +++++++++++ alvr/sockets/src/backend/udp.rs | 40 +++++ alvr/sockets/src/control_socket.rs | 201 +++++++++++++--------- alvr/sockets/src/lib.rs | 50 ++++++ 7 files changed, 507 insertions(+), 292 deletions(-) create mode 100644 alvr/sockets/src/backend/mod.rs create mode 100644 alvr/sockets/src/backend/tcp.rs create mode 100644 alvr/sockets/src/backend/udp.rs diff --git a/alvr/client_core/src/connection.rs b/alvr/client_core/src/connection.rs index 68937d60ad..1d06aeb465 100644 --- a/alvr/client_core/src/connection.rs +++ b/alvr/client_core/src/connection.rs @@ -125,7 +125,7 @@ fn connection_pipeline( let (mut proto_control_socket, server_ip) = { let config = Config::load(); let announcer_socket = AnnouncerSocket::new(&config.hostname).to_con()?; - let listener_socket = alvr_sockets::get_server_listener(&runtime).to_con()?; + let listener_socket = alvr_sockets::get_server_listener(Duration::from_secs(1)).to_con()?; loop { if !IS_ALIVE.value() { @@ -145,7 +145,6 @@ fn connection_pipeline( } if let Ok(pair) = ProtoControlSocket::connect_to( - &runtime, DISCOVERY_RETRY_PAUSE, PeerType::Server(&listener_socket), ) { @@ -171,22 +170,18 @@ fn connection_pipeline( .unwrap(); proto_control_socket - .send( - &runtime, - &ClientConnectionResult::ConnectionAccepted { - client_protocol_id: alvr_common::protocol_id(), - display_name: platform::device_model(), - server_ip, - streaming_capabilities: Some(VideoStreamingCapabilities { - default_view_resolution: recommended_view_resolution, - supported_refresh_rates, - microphone_sample_rate, - }), - }, - ) + .send(&ClientConnectionResult::ConnectionAccepted { + client_protocol_id: alvr_common::protocol_id(), + display_name: platform::device_model(), + server_ip, + streaming_capabilities: Some(VideoStreamingCapabilities { + default_view_resolution: recommended_view_resolution, + supported_refresh_rates, + microphone_sample_rate, + }), + }) .to_con()?; - let config_packet = - proto_control_socket.recv::(&runtime, Duration::from_secs(1))?; + let config_packet = proto_control_socket.recv::()?; let settings = { let mut session_desc = SessionConfig::default(); @@ -228,9 +223,11 @@ fn connection_pipeline( }, )); - let (mut control_sender, mut control_receiver) = proto_control_socket.split(); + let (mut control_sender, mut control_receiver) = proto_control_socket + .split(Duration::from_millis(500)) + .to_con()?; - match control_receiver.recv(&runtime, Duration::from_secs(1)) { + match control_receiver.recv() { Ok(ServerControlPacket::StartStream) => { info!("Stream starting"); set_hud_message(STREAM_STARTING_MESSAGE); @@ -261,7 +258,7 @@ fn connection_pipeline( ) .to_con()?; - if let Err(e) = control_sender.send(&runtime, &ClientControlPacket::StreamReady) { + if let Err(e) = control_sender.send(&ClientControlPacket::StreamReady) { info!("Server disconnected. Cause: {e}"); set_hud_message(SERVER_DISCONNECTED_MESSAGE); return Ok(()); @@ -416,41 +413,30 @@ fn connection_pipeline( while IS_STREAMING.value() && IS_RESUMED.value() && IS_ALIVE.value() { if let Ok(packet) = control_channel_receiver.recv_timeout(Duration::from_millis(500)) { - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - if let Err(e) = control_sender.send(runtime, &packet) { - info!("Server disconnected. Cause: {e}"); - set_hud_message(SERVER_DISCONNECTED_MESSAGE); + if let Err(e) = control_sender.send(&packet) { + info!("Server disconnected. Cause: {e}"); + set_hud_message(SERVER_DISCONNECTED_MESSAGE); - break; - } + break; } } if Instant::now() > keepalive_deadline { - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - control_sender - .send(runtime, &ClientControlPacket::KeepAlive) - .ok(); + control_sender.send(&ClientControlPacket::KeepAlive).ok(); - keepalive_deadline = Instant::now() + KEEPALIVE_INTERVAL; - } + keepalive_deadline = Instant::now() + KEEPALIVE_INTERVAL; } #[cfg(target_os = "android")] if Instant::now() > battery_deadline { - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let (gauge_value, is_plugged) = battery_manager.status(); - control_sender - .send( - runtime, - &ClientControlPacket::Battery(crate::BatteryPacket { - device_id: *alvr_common::HEAD_ID, - gauge_value, - is_plugged, - }), - ) - .ok(); - } + let (gauge_value, is_plugged) = battery_manager.status(); + control_sender + .send(&ClientControlPacket::Battery(crate::BatteryPacket { + device_id: *alvr_common::HEAD_ID, + gauge_value, + is_plugged, + })) + .ok(); battery_deadline = Instant::now() + Duration::from_secs(5); } @@ -461,36 +447,34 @@ fn connection_pipeline( } }); - let control_receive_thread = thread::spawn(move || loop { - let maybe_packet = if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - control_receiver.recv(runtime, Duration::from_millis(500)) - } else { - return; - }; + let control_receive_thread = thread::spawn(move || { + while IS_STREAMING.value() { + let maybe_packet = control_receiver.recv(); - match maybe_packet { - Ok(ServerControlPacket::InitializeDecoder(config)) => { - decoder::create_decoder(config); - } - Ok(ServerControlPacket::Restarting) => { - info!("{SERVER_RESTART_MESSAGE}"); - set_hud_message(SERVER_RESTART_MESSAGE); - if let Some(notifier) = &*DISCONNECT_SERVER_NOTIFIER.lock() { - notifier.send(()).ok(); + match maybe_packet { + Ok(ServerControlPacket::InitializeDecoder(config)) => { + decoder::create_decoder(config); } + Ok(ServerControlPacket::Restarting) => { + info!("{SERVER_RESTART_MESSAGE}"); + set_hud_message(SERVER_RESTART_MESSAGE); + if let Some(notifier) = &*DISCONNECT_SERVER_NOTIFIER.lock() { + notifier.send(()).ok(); + } - return; - } - Ok(_) => (), - Err(ConnectionError::TryAgain) => (), - Err(e) => { - info!("{SERVER_DISCONNECTED_MESSAGE} Cause: {e}"); - set_hud_message(SERVER_DISCONNECTED_MESSAGE); - if let Some(notifier) = &*DISCONNECT_SERVER_NOTIFIER.lock() { - notifier.send(()).ok(); + return; } + Ok(_) => (), + Err(ConnectionError::TryAgain) => (), + Err(e) => { + info!("{SERVER_DISCONNECTED_MESSAGE} Cause: {e}"); + set_hud_message(SERVER_DISCONNECTED_MESSAGE); + if let Some(notifier) = &*DISCONNECT_SERVER_NOTIFIER.lock() { + notifier.send(()).ok(); + } - return; + return; + } } } }); diff --git a/alvr/server/src/connection.rs b/alvr/server/src/connection.rs index d67eef021f..b24ed87666 100644 --- a/alvr/server/src/connection.rs +++ b/alvr/server/src/connection.rs @@ -311,7 +311,6 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { let runtime = Runtime::new().to_con()?; let (mut proto_socket, client_ip) = ProtoControlSocket::connect_to( - &runtime, Duration::from_secs(1), PeerType::AnyClient(client_ips.keys().cloned().collect()), )?; @@ -368,7 +367,7 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { display_name, streaming_capabilities, .. - } = proto_socket.recv(&runtime, Duration::from_secs(1))? + } = proto_socket.recv()? { SERVER_DATA_MANAGER.write().update_client_list( client_hostname.clone(), @@ -486,9 +485,10 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { }) .to_string(), }; - proto_socket.send(&runtime, &client_config).to_con()?; + proto_socket.send(&client_config).to_con()?; - let (mut control_sender, mut control_receiver) = proto_socket.split(); + let (mut control_sender, mut control_receiver) = + proto_socket.split(Duration::from_millis(500)).to_con()?; let mut new_openvr_config = contruct_openvr_config(); new_openvr_config.eye_resolution_width = stream_view_resolution.x; @@ -500,18 +500,16 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { if SERVER_DATA_MANAGER.read().session().openvr_config != new_openvr_config { SERVER_DATA_MANAGER.write().session_mut().openvr_config = new_openvr_config; - control_sender - .send(&runtime, &ServerControlPacket::Restarting) - .ok(); + control_sender.send(&ServerControlPacket::Restarting).ok(); crate::notify_restart_driver(); } control_sender - .send(&runtime, &ServerControlPacket::StartStream) + .send(&ServerControlPacket::StartStream) .to_con()?; - match control_receiver.recv(&runtime, Duration::from_secs(1)) { + match control_receiver.recv() { Ok(ClientControlPacket::StreamReady) => (), Ok(_) => { con_bail!("Got unexpected packet waiting for stream ack"); @@ -562,19 +560,21 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { *VIDEO_CHANNEL_SENDER.lock() = Some(video_channel_sender); *HAPTICS_SENDER.lock() = Some(haptics_sender); - let video_send_thread = thread::spawn(move || loop { - let VideoPacket { header, payload } = - match video_channel_receiver.recv_timeout(Duration::from_millis(500)) { - Ok(packet) => packet, - Err(RecvTimeoutError::Timeout) => continue, - Err(RecvTimeoutError::Disconnected) => return, - }; + let video_send_thread = thread::spawn(move || { + while IS_STREAMING.value() { + let VideoPacket { header, payload } = + match video_channel_receiver.recv_timeout(Duration::from_millis(500)) { + Ok(packet) => packet, + Err(RecvTimeoutError::Timeout) => continue, + Err(RecvTimeoutError::Disconnected) => return, + }; - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - // IMPORTANT: The only error that can happen here is socket closed. For this reason it's - // acceptable to call .ok() and ignore the error. The connection would already be - // closing so no corruption handling is necessary - video_sender.send(runtime, &header, payload).ok(); + if let Some(runtime) = &*CONNECTION_RUNTIME.read() { + // IMPORTANT: The only error that can happen here is socket closed. For this reason it's + // acceptable to call .ok() and ignore the error. The connection would already be + // closing so no corruption handling is necessary + video_sender.send(runtime, &header, payload).ok(); + } } }); @@ -829,12 +829,9 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { let keepalive_thread = thread::spawn({ let control_sender = Arc::clone(&control_sender); let client_hostname = client_hostname.clone(); - move || loop { - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - let res = control_sender - .lock() - .send(runtime, &ServerControlPacket::KeepAlive); - if let Err(e) = res { + move || { + while IS_STREAMING.value() { + if let Err(e) = control_sender.lock().send(&ServerControlPacket::KeepAlive) { info!("Client disconnected. Cause: {e}"); SERVER_DATA_MANAGER.write().update_client_list( @@ -849,20 +846,18 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { return; } - } else { - return; - } - thread::sleep(KEEPALIVE_INTERVAL); + thread::sleep(KEEPALIVE_INTERVAL); + } } }); let control_receive_thread = thread::spawn({ let control_sender = Arc::clone(&control_sender); let client_hostname = client_hostname.clone(); - move || loop { - let packet = if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - match control_receiver.recv(runtime, Duration::from_millis(500)) { + move || { + while IS_STREAMING.value() { + let packet = match control_receiver.recv() { Ok(packet) => packet, Err(ConnectionError::TryAgain) => continue, Err(ConnectionError::Other(e)) => { @@ -880,118 +875,113 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { return; } - } - } else { - return; - }; + }; - match packet { - ClientControlPacket::PlayspaceSync(packet) => { - if !settings.headset.tracking_ref_only { - let area = packet.unwrap_or(Vec2::new(2.0, 2.0)); - unsafe { crate::SetChaperone(area.x, area.y) }; - - let data_manager_lock = SERVER_DATA_MANAGER.read(); - let config = &data_manager_lock.settings().headset; - tracking_manager.lock().recenter( - config.position_recentering_mode, - config.rotation_recentering_mode, - ); - } - } - ClientControlPacket::RequestIdr => { - let maybe_config = DECODER_CONFIG.lock().clone(); - if let (Some(runtime), Some(config)) = - (&*CONNECTION_RUNTIME.read(), maybe_config) - { - control_sender - .lock() - .send(runtime, &ServerControlPacket::InitializeDecoder(config)) - .ok(); - } - unsafe { crate::RequestIDR() } - } - ClientControlPacket::VideoErrorReport => { - if let Some(stats) = &mut *STATISTICS_MANAGER.lock() { - stats.report_packet_loss(); + match packet { + ClientControlPacket::PlayspaceSync(packet) => { + if !settings.headset.tracking_ref_only { + let area = packet.unwrap_or(Vec2::new(2.0, 2.0)); + unsafe { crate::SetChaperone(area.x, area.y) }; + + let data_manager_lock = SERVER_DATA_MANAGER.read(); + let config = &data_manager_lock.settings().headset; + tracking_manager.lock().recenter( + config.position_recentering_mode, + config.rotation_recentering_mode, + ); + } } - unsafe { crate::VideoErrorReportReceive() }; - } - ClientControlPacket::ViewsConfig(config) => unsafe { - crate::SetViewsConfig(FfiViewsConfig { - fov: [ - FfiFov { - left: config.fov[0].left, - right: config.fov[0].right, - up: config.fov[0].up, - down: config.fov[0].down, - }, - FfiFov { - left: config.fov[1].left, - right: config.fov[1].right, - up: config.fov[1].up, - down: config.fov[1].down, - }, - ], - ipd_m: config.ipd_m, - }); - }, - ClientControlPacket::Battery(packet) => unsafe { - crate::SetBattery(packet.device_id, packet.gauge_value, packet.is_plugged); - - if let Some(stats) = &mut *STATISTICS_MANAGER.lock() { - stats.report_battery( - packet.device_id, - packet.gauge_value, - packet.is_plugged, - ); + ClientControlPacket::RequestIdr => { + if let Some(config) = DECODER_CONFIG.lock().clone() { + control_sender + .lock() + .send(&ServerControlPacket::InitializeDecoder(config)) + .ok(); + } + unsafe { crate::RequestIDR() } } - }, - ClientControlPacket::Buttons(entries) => { - { - let data_manager_lock = SERVER_DATA_MANAGER.read(); - if data_manager_lock.settings().logging.log_button_presses { - alvr_events::send_event(EventType::Buttons( - entries - .iter() - .map(|e| ButtonEvent { - path: BUTTON_PATH_FROM_ID - .get(&e.path_id) - .cloned() - .unwrap_or_else(|| { - format!("Unknown (ID: {:#16x})", e.path_id) - }), - value: e.value, - }) - .collect(), - )); + ClientControlPacket::VideoErrorReport => { + if let Some(stats) = &mut *STATISTICS_MANAGER.lock() { + stats.report_packet_loss(); } + unsafe { crate::VideoErrorReportReceive() }; } + ClientControlPacket::ViewsConfig(config) => unsafe { + crate::SetViewsConfig(FfiViewsConfig { + fov: [ + FfiFov { + left: config.fov[0].left, + right: config.fov[0].right, + up: config.fov[0].up, + down: config.fov[0].down, + }, + FfiFov { + left: config.fov[1].left, + right: config.fov[1].right, + up: config.fov[1].up, + down: config.fov[1].down, + }, + ], + ipd_m: config.ipd_m, + }); + }, + ClientControlPacket::Battery(packet) => unsafe { + crate::SetBattery(packet.device_id, packet.gauge_value, packet.is_plugged); + + if let Some(stats) = &mut *STATISTICS_MANAGER.lock() { + stats.report_battery( + packet.device_id, + packet.gauge_value, + packet.is_plugged, + ); + } + }, + ClientControlPacket::Buttons(entries) => { + { + let data_manager_lock = SERVER_DATA_MANAGER.read(); + if data_manager_lock.settings().logging.log_button_presses { + alvr_events::send_event(EventType::Buttons( + entries + .iter() + .map(|e| ButtonEvent { + path: BUTTON_PATH_FROM_ID + .get(&e.path_id) + .cloned() + .unwrap_or_else(|| { + format!("Unknown (ID: {:#16x})", e.path_id) + }), + value: e.value, + }) + .collect(), + )); + } + } - for entry in entries { - let value = match entry.value { - ButtonValue::Binary(value) => FfiButtonValue { - type_: crate::FfiButtonType_BUTTON_TYPE_BINARY, - __bindgen_anon_1: crate::FfiButtonValue__bindgen_ty_1 { - binary: value.into(), + for entry in entries { + let value = match entry.value { + ButtonValue::Binary(value) => FfiButtonValue { + type_: crate::FfiButtonType_BUTTON_TYPE_BINARY, + __bindgen_anon_1: crate::FfiButtonValue__bindgen_ty_1 { + binary: value.into(), + }, }, - }, - ButtonValue::Scalar(value) => FfiButtonValue { - type_: crate::FfiButtonType_BUTTON_TYPE_SCALAR, - __bindgen_anon_1: crate::FfiButtonValue__bindgen_ty_1 { - scalar: value, + ButtonValue::Scalar(value) => FfiButtonValue { + type_: crate::FfiButtonType_BUTTON_TYPE_SCALAR, + __bindgen_anon_1: crate::FfiButtonValue__bindgen_ty_1 { + scalar: value, + }, }, - }, - }; + }; - unsafe { crate::SetButton(entry.path_id, value) }; + unsafe { crate::SetButton(entry.path_id, value) }; + } } + ClientControlPacket::Log { level, message } => { + info!("Client {client_hostname}: [{level:?}] {message}") + } + _ => (), } - ClientControlPacket::Log { level, message } => { - info!("Client {client_hostname}: [{level:?}] {message}") - } - _ => (), } } }); @@ -1068,12 +1058,10 @@ fn try_connect(mut client_ips: HashMap) -> ConResult { let res = disconnect_receiver.recv(); if matches!(res, Ok(ClientDisconnectRequest::ServerRestart)) { - if let Some(runtime) = &*CONNECTION_RUNTIME.read() { - control_sender - .lock() - .send(runtime, &ServerControlPacket::Restarting) - .ok(); - } + control_sender + .lock() + .send(&ServerControlPacket::Restarting) + .ok(); } // This requests shutdown from threads diff --git a/alvr/sockets/src/backend/mod.rs b/alvr/sockets/src/backend/mod.rs new file mode 100644 index 0000000000..3cd9e7b6e5 --- /dev/null +++ b/alvr/sockets/src/backend/mod.rs @@ -0,0 +1,19 @@ +pub mod tcp; +pub mod udp; + +use alvr_common::{anyhow::Result, ConResult}; + +pub trait SocketWriter { + fn send(&mut self, buffer: &[u8]) -> Result<()>; +} + +// Trait used to abstract different socket (or other input/output) implementations. The funtionality +// is the intersection of the functionality of each implementation, that is it inheirits all +// limitations +pub trait SocketReader { + // Returns number of bytes written. buffer must be big enough to be able to receive a full + // packet (size of MTU) otherwise data will be corrupted. The size of the data is + fn recv(&mut self, buffer: &mut [u8]) -> ConResult; + + fn peek(&self, buffer: &mut [u8]) -> ConResult; +} diff --git a/alvr/sockets/src/backend/tcp.rs b/alvr/sockets/src/backend/tcp.rs new file mode 100644 index 0000000000..4bcf8ec713 --- /dev/null +++ b/alvr/sockets/src/backend/tcp.rs @@ -0,0 +1,97 @@ +use crate::LOCAL_IP; + +use super::{SocketReader, SocketWriter}; +use alvr_common::{anyhow::Result, con_bail, ConResult, IOToCon, ToCon}; +use alvr_session::SocketBufferSize; +use std::{ + io::Read, + io::Write, + net::{IpAddr, SocketAddr, TcpListener, TcpStream}, + time::Duration, +}; + +pub fn bind( + timeout: Duration, + port: u16, + send_buffer_bytes: SocketBufferSize, + recv_buffer_bytes: SocketBufferSize, +) -> Result { + let socket = TcpListener::bind((LOCAL_IP, port))?.into(); + + crate::set_socket_buffers(&socket, send_buffer_bytes, recv_buffer_bytes).ok(); + socket.set_read_timeout(Some(timeout))?; + + Ok(socket.into()) +} + +pub fn accept_from_server( + listener: &TcpListener, + server_ip: Option, +) -> ConResult<(TcpStream, TcpStream)> { + // Uses timeout set during bind() + let (socket, server_address) = listener.accept().io_to_con()?; + + if let Some(ip) = server_ip { + if server_address.ip() != ip { + con_bail!( + "Connected to wrong client: Expected: {ip}, Found {}", + server_address.ip() + ); + } + } + + socket.set_nodelay(true).to_con()?; + + Ok((socket.try_clone().to_con()?, socket)) +} + +pub fn connect_to_client( + timeout: Duration, + client_ips: &[IpAddr], + port: u16, + send_buffer_bytes: SocketBufferSize, + recv_buffer_bytes: SocketBufferSize, +) -> ConResult<(TcpStream, TcpStream)> { + let split_timeout = timeout / client_ips.len() as u32; + + let mut res = alvr_common::try_again(); + for ip in client_ips { + res = TcpStream::connect_timeout(&SocketAddr::new(*ip, port), split_timeout).io_to_con(); + + if res.is_ok() { + break; + } + } + let socket = res?.into(); + + crate::set_socket_buffers(&socket, send_buffer_bytes, recv_buffer_bytes).ok(); + socket.set_read_timeout(Some(timeout)).to_con()?; + + let socket = TcpStream::from(socket); + + socket.set_nodelay(true).to_con()?; + + Ok((socket.try_clone().to_con()?, socket)) +} + +impl SocketWriter for TcpStream { + fn send(&mut self, buffer: &[u8]) -> Result<()> { + self.write_all(buffer)?; + + Ok(()) + } +} + +impl SocketReader for TcpStream { + fn recv(&mut self, buffer: &mut [u8]) -> ConResult { + let bytes = Read::read(self, buffer).io_to_con()?; + + Ok(bytes) + } + + fn peek(&self, buffer: &mut [u8]) -> ConResult { + let bytes = TcpStream::peek(self, buffer).io_to_con()?; + + Ok(bytes) + } +} diff --git a/alvr/sockets/src/backend/udp.rs b/alvr/sockets/src/backend/udp.rs new file mode 100644 index 0000000000..83f7fb45ea --- /dev/null +++ b/alvr/sockets/src/backend/udp.rs @@ -0,0 +1,40 @@ +use crate::LOCAL_IP; + +use super::SocketReader; +use alvr_common::{anyhow::Result, ConResult, IOToCon}; +use alvr_session::SocketBufferSize; +use std::net::{IpAddr, UdpSocket}; + +// Create tokio socket, convert to socket2, apply settings, convert back to tokio. This is done to +// let tokio set all the internal parameters it needs from the start. +pub fn bind( + port: u16, + send_buffer_bytes: SocketBufferSize, + recv_buffer_bytes: SocketBufferSize, +) -> Result { + let socket = UdpSocket::bind((LOCAL_IP, port))?.into(); + + crate::set_socket_buffers(&socket, send_buffer_bytes, recv_buffer_bytes).ok(); + + Ok(socket.into()) +} + +pub fn connect(socket: &UdpSocket, peer_ip: IpAddr, port: u16) -> Result<(UdpSocket, UdpSocket)> { + socket.connect((peer_ip, port))?; + + Ok((socket.try_clone()?, socket.try_clone()?)) +} + +impl SocketReader for UdpSocket { + fn recv(&mut self, buffer: &mut [u8]) -> ConResult { + let bytes = UdpSocket::recv(self, buffer).io_to_con()?; + + Ok(bytes) + } + + fn peek(&self, buffer: &mut [u8]) -> ConResult { + let bytes = UdpSocket::peek(self, buffer).io_to_con()?; + + Ok(bytes) + } +} diff --git a/alvr/sockets/src/control_socket.rs b/alvr/sockets/src/control_socket.rs index 17a4543384..26c96b8b25 100644 --- a/alvr/sockets/src/control_socket.rs +++ b/alvr/sockets/src/control_socket.rs @@ -1,52 +1,116 @@ -use super::{Ldc, CONTROL_PORT, LOCAL_IP}; -use alvr_common::{anyhow::Result, ConResult, ToCon}; -use bytes::Bytes; -use futures::{ - stream::{SplitSink, SplitStream}, - SinkExt, StreamExt, -}; +use crate::backend::{tcp, SocketReader, SocketWriter}; + +use super::CONTROL_PORT; +use alvr_common::{anyhow::Result, ConResult, IOToCon, ToCon}; +use alvr_session::SocketBufferSize; use serde::{de::DeserializeOwned, Serialize}; -use std::{marker::PhantomData, net::IpAddr, time::Duration}; -use tokio::{ - net::{TcpListener, TcpStream}, - runtime::Runtime, - time, +use std::{ + marker::PhantomData, + mem, + net::{IpAddr, TcpListener, TcpStream}, + time::Duration, }; -use tokio_util::codec::Framed; + +// This corresponds to the length of the payload +const FRAMED_PREFIX_LENGTH: usize = mem::size_of::(); + +pub struct RecvState { + packet_cursor: usize, // counts also the length prefix bytes + packet_length: usize, // contains length prefix +} + +fn framed_send( + socket: &mut TcpStream, + buffer: &mut Vec, + packet: &S, +) -> Result<()> { + let serialized_size = bincode::serialized_size(&packet)? as usize; + let packet_size = serialized_size + FRAMED_PREFIX_LENGTH; + + if buffer.len() < packet_size { + buffer.resize(packet_size, 0); + } + + buffer[0..FRAMED_PREFIX_LENGTH].copy_from_slice(&(serialized_size as u32).to_be_bytes()); + bincode::serialize_into(&mut buffer[FRAMED_PREFIX_LENGTH..packet_size], &packet)?; + + socket.send(&buffer[0..packet_size])?; + + Ok(()) +} + +fn framed_recv( + socket: &mut TcpStream, + buffer: &mut Vec, + maybe_recv_state: &mut Option, +) -> ConResult { + let recv_state_mut = if let Some(state) = maybe_recv_state { + state + } else { + let mut payload_length_bytes = [0; FRAMED_PREFIX_LENGTH]; + let count = socket.peek(&mut payload_length_bytes).io_to_con()?; + if count != FRAMED_PREFIX_LENGTH { + return alvr_common::try_again(); + } + let packet_length = + FRAMED_PREFIX_LENGTH + u32::from_be_bytes(payload_length_bytes) as usize; + + if buffer.len() < packet_length { + buffer.resize(packet_length, 0); + } + + maybe_recv_state.insert(RecvState { + packet_length, + packet_cursor: 0, + }) + }; + + recv_state_mut.packet_cursor += + socket.recv(&mut buffer[recv_state_mut.packet_cursor..recv_state_mut.packet_length])?; + if recv_state_mut.packet_cursor != recv_state_mut.packet_length { + return alvr_common::try_again(); + } + + let data = bincode::deserialize(&buffer[FRAMED_PREFIX_LENGTH..recv_state_mut.packet_length]) + .to_con()?; + + *maybe_recv_state = None; + + Ok(data) +} pub struct ControlSocketSender { - inner: SplitSink, Bytes>, + inner: TcpStream, + buffer: Vec, _phantom: PhantomData, } impl ControlSocketSender { - pub fn send(&mut self, runtime: &Runtime, packet: &S) -> Result<()> { - let packet_bytes = bincode::serialize(packet)?; - runtime.block_on(self.inner.send(packet_bytes.into()))?; - - Ok(()) + pub fn send(&mut self, packet: &S) -> Result<()> { + framed_send(&mut self.inner, &mut self.buffer, packet) } } pub struct ControlSocketReceiver { - inner: SplitStream>, + inner: TcpStream, + buffer: Vec, + recv_state: Option, _phantom: PhantomData, } impl ControlSocketReceiver { - pub fn recv(&mut self, runtime: &Runtime, timeout: Duration) -> ConResult { - let packet_bytes = runtime.block_on(async { - tokio::select! { - res = self.inner.next() => res.map(|p| p.to_con()).to_con(), - _ = time::sleep(timeout) => alvr_common::try_again(), - } - })??; - bincode::deserialize(&packet_bytes).to_con() + pub fn recv(&mut self) -> ConResult { + framed_recv(&mut self.inner, &mut self.buffer, &mut self.recv_state) } } -pub fn get_server_listener(runtime: &Runtime) -> Result { - let listener = runtime.block_on(TcpListener::bind((LOCAL_IP, CONTROL_PORT)))?; +pub fn get_server_listener(timeout: Duration) -> Result { + let listener = tcp::bind( + timeout, + CONTROL_PORT, + SocketBufferSize::Default, + SocketBufferSize::Default, + )?; Ok(listener) } @@ -54,7 +118,7 @@ pub fn get_server_listener(runtime: &Runtime) -> Result { // Proto-control-socket that can send and receive any packet. After the split, only the packets of // the specified types can be exchanged pub struct ProtoControlSocket { - inner: Framed, + inner: TcpStream, } pub enum PeerType<'a> { @@ -63,79 +127,52 @@ pub enum PeerType<'a> { } impl ProtoControlSocket { - pub fn connect_to( - runtime: &Runtime, - timeout: Duration, - peer: PeerType<'_>, - ) -> ConResult<(Self, IpAddr)> { + pub fn connect_to(timeout: Duration, peer: PeerType<'_>) -> ConResult<(Self, IpAddr)> { let socket = match peer { PeerType::AnyClient(ips) => { - let client_addresses = ips - .iter() - .map(|&ip| (ip, CONTROL_PORT).into()) - .collect::>(); - runtime.block_on(async { - tokio::select! { - res = TcpStream::connect(client_addresses.as_slice()) => res.to_con(), - _ = time::sleep(timeout) => alvr_common::try_again(), - } - })? - } - PeerType::Server(listener) => { - let (socket, _) = runtime.block_on(async { - tokio::select! { - res = listener.accept() => res.to_con(), - _ = time::sleep(timeout) => alvr_common::try_again(), - } - })?; - socket + tcp::connect_to_client( + timeout, + &ips, + CONTROL_PORT, + SocketBufferSize::Default, + SocketBufferSize::Default, + )? + .0 } + PeerType::Server(listener) => tcp::accept_from_server(listener, None)?.0, }; - socket.set_nodelay(true).to_con()?; let peer_ip = socket.peer_addr().to_con()?.ip(); - let socket = Framed::new(socket, Ldc::new()); Ok((Self { inner: socket }, peer_ip)) } - pub fn send(&mut self, runtime: &Runtime, packet: &S) -> Result<()> { - runtime.block_on(self.inner.send(bincode::serialize(packet)?.into()))?; - - Ok(()) + pub fn send(&mut self, packet: &S) -> Result<()> { + framed_send(&mut self.inner, &mut vec![], packet) } - pub fn recv( - &mut self, - runtime: &Runtime, - timeout: Duration, - ) -> ConResult { - let packet_bytes = runtime - .block_on(async { - tokio::select! { - res = self.inner.next() => res.map(|p| p.to_con()), - _ = time::sleep(timeout) => Some(alvr_common::try_again()), - } - }) - .to_con()??; - - bincode::deserialize(&packet_bytes).to_con() + pub fn recv(&mut self) -> ConResult { + framed_recv(&mut self.inner, &mut vec![], &mut None) } pub fn split( self, - ) -> (ControlSocketSender, ControlSocketReceiver) { - let (sender, receiver) = self.inner.split(); + timeout: Duration, + ) -> Result<(ControlSocketSender, ControlSocketReceiver)> { + self.inner.set_read_timeout(Some(timeout))?; - ( + Ok(( ControlSocketSender { - inner: sender, + inner: self.inner.try_clone()?, + buffer: vec![], _phantom: PhantomData, }, ControlSocketReceiver { - inner: receiver, + inner: self.inner, + buffer: vec![], + recv_state: None, _phantom: PhantomData, }, - ) + )) } } diff --git a/alvr/sockets/src/lib.rs b/alvr/sockets/src/lib.rs index 7136ece9ae..e58e421fc2 100644 --- a/alvr/sockets/src/lib.rs +++ b/alvr/sockets/src/lib.rs @@ -1,3 +1,4 @@ +mod backend; mod control_socket; mod stream_socket; @@ -15,3 +16,52 @@ pub const HANDSHAKE_PACKET_SIZE_BYTES: usize = 56; // this may change in future pub const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(500); type Ldc = tokio_util::codec::LengthDelimitedCodec; + +// Memory buffer that contains a hidden prefix +#[derive(Default)] +pub struct Buffer { + inner: Vec, + cursor: usize, + length: usize, +} + +impl Buffer { + // Length of payload (without prefix) + #[must_use] + pub fn len(&self) -> usize { + self.length + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + // Note: this will not advance the cursor. Allocations are handled automatically + // In case of reallocation, do not remove the cursor offset. This buffer is expected to be + // reused and the total allocation size will not change after the running start. + pub fn get_mut(&mut self, offset: usize, size: usize) -> &mut [u8] { + let required_size = self.cursor + offset + size; + if required_size > self.inner.len() { + self.inner.resize(required_size, 0); + } + + self.length = self.length.max(offset + size); + + &mut self.inner + } + + pub fn get(&self) -> &[u8] { + &self.inner[self.cursor..self.cursor + self.length] + } + + pub fn advance_cursor(&mut self, count: usize) { + self.cursor += count + } + + // Clear buffer and cursor + pub fn clear(&mut self) { + self.cursor = 0; + self.length = 0; + } +}