Skip to content

Commit

Permalink
Merge pull request #3 from oyyd/dev
Browse files Browse the repository at this point in the history
fix: fix frame sending order
  • Loading branch information
oyyd authored Jan 20, 2024
2 parents 31ae802 + a70167c commit a7f0185
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 67 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

## 0.2.0
chore: add frame trace log
fix: fix frame sending order

## 0.1.1

chore: change package description
Expand Down
12 changes: 11 additions & 1 deletion src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
pub type Sid = u32;

/// Frame commands of smux protocal.
#[derive(Clone, Copy)]
#[derive(Clone, Copy, Debug)]
pub enum Cmd {
/// Stream open.
Sync,
Expand Down Expand Up @@ -66,6 +66,16 @@ pub struct Frame {
pub data: Option<Vec<u8>>,
}

impl std::fmt::Debug for Frame {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Frame")
.field("sid", &self.sid)
.field("cmd", &self.cmd)
.field("len", &self.length)
.finish()
}
}

impl Frame {
pub fn new(ver: u8, cmd: Cmd, sid: Sid) -> Self {
Self {
Expand Down
52 changes: 32 additions & 20 deletions src/read_frame_grouper.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::frame::{Cmd, Frame, Sid};
use crate::session_inner::ReadRequest;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Mutex};

// Consume reading frames and split them into:
// - Sync frames
Expand All @@ -18,11 +18,11 @@ pub(crate) struct ReadFrameGrouper {
pub sync_tx: mpsc::Sender<Frame>,

// Session could also operate `sid_tx_map` and `sid_rx_map`.
pub sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
pub sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,

// `sid_rx_map` is shared with session.
// Items of `sid_rx_map` will be taken away by the session when accepting new streams.
pub sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
pub sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,

pub sid_frame_buffer_size: usize,
}
Expand Down Expand Up @@ -57,11 +57,18 @@ impl ReadFrameGrouper {

async fn handle_sync(&mut self, read_req: ReadRequest) {
let sid = read_req.frame.sid;
if !self.sid_tx_map.contains_key(&sid) {
let (tx, rx) = mpsc::channel(self.sid_frame_buffer_size);
self.sid_tx_map.insert(sid, tx);
self.sid_rx_map.insert(sid, rx);
}
{
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
let (tx, rx) = mpsc::channel(self.sid_frame_buffer_size);
self
.sid_tx_map
.lock()
.await
.insert(sid, Arc::new(Mutex::new(tx)));
self.sid_rx_map.lock().await.insert(sid, rx);
}
};
let send_sync_tx_res = self.sync_tx.send(read_req.frame).await;
if send_sync_tx_res.is_err() {
// session closed
Expand All @@ -71,13 +78,18 @@ impl ReadFrameGrouper {

async fn handle_fin_push(&mut self, read_req: ReadRequest) {
let sid = read_req.frame.sid;
if !self.sid_tx_map.contains_key(&sid) {
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
// unexpected, ignore the frame
log::warn!("[grouper] receive unexecpted frame, sid: {}", sid,);
return;
}
let tx = self.sid_tx_map.get(&sid).unwrap();
let _ = tx.send(read_req.frame).await;
let tx = {
let lock = self.sid_tx_map.lock().await;
let tx = lock.get(&sid).unwrap();
tx.clone()
};
let _ = tx.lock().await.send(read_req.frame).await;
}
}

Expand All @@ -86,21 +98,22 @@ mod test {
use crate::frame::{Cmd, Frame};
use crate::read_frame_grouper::ReadFrameGrouper;
use crate::session_inner::ReadRequest;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Mutex};

#[tokio::test]
async fn test_grouper() {
let (new_frame_tx, new_frame_rx) = mpsc::channel(1024);
let (sync_tx, mut sync_rx) = mpsc::channel(1024);
let sid_rx_map = Arc::new(DashMap::new());
let sid_rx_map = HashMap::new();
let sid_rx_map = Arc::new(Mutex::new(sid_rx_map));
let mut grouper = ReadFrameGrouper {
new_frame_rx,
sid_frame_buffer_size: 1024,
sync_tx,
sid_rx_map: sid_rx_map.clone(),
sid_tx_map: Arc::new(DashMap::new()),
sid_tx_map: Arc::new(Mutex::new(HashMap::new())),
};

// Should create correspond sid_tx_map when receive sync frames.
Expand All @@ -121,10 +134,9 @@ mod test {

let frame = sync_rx.recv().await.unwrap();
assert!(matches!(frame.cmd, Cmd::Sync));
let item = sid_rx_map.remove(&sid);
let item = { sid_rx_map.lock().await.remove(&sid) };
assert!(item.is_some());
let (id, mut item_frame_rx) = item.unwrap();
assert_eq!(id, sid);
let mut item_frame_rx = item.unwrap();
let frame = item_frame_rx.recv().await.unwrap();
assert_eq!(frame.sid, sid);

Expand All @@ -141,7 +153,7 @@ mod test {
frame.with_data(vec![0; 10]);
new_frame_tx.send(ReadRequest { frame }).await.unwrap();
item_frame_rx.recv().await.unwrap();
assert!(sid_rx_map.remove(&sid).is_none());
assert!(sid_rx_map.lock().await.remove(&sid).is_none());

// Cloud new_frame_rx should
drop(new_frame_tx);
Expand Down
82 changes: 44 additions & 38 deletions src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::frame::{Cmd, Frame, Sid};
use crate::read_frame_grouper::ReadFrameGrouper;
use crate::session_inner::{SessionInner, WriteRequest};
use crate::stream::Stream;
use dashmap::DashMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::{mpsc, oneshot, Mutex};
Expand All @@ -24,9 +24,9 @@ pub struct Session {
go_away: bool,

sync_rx: mpsc::Receiver<Frame>,
sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
sid_close_tx_map: Arc<DashMap<Sid, oneshot::Sender<()>>>,
sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,
sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,
sid_close_tx_map: Arc<Mutex<HashMap<Sid, oneshot::Sender<()>>>>,
sid_drop_tx: mpsc::Sender<Sid>,

inner_err: Arc<Mutex<Option<TokioSmuxError>>>,
Expand All @@ -35,16 +35,18 @@ pub struct Session {
impl Drop for Session {
fn drop(&mut self) {
// close all streams
let mut keys: Vec<Sid> = vec![];
for kv in self.sid_close_tx_map.iter() {
let sid = kv.key();
keys.push(*sid);
}
for id in keys {
let item = self.sid_close_tx_map.remove(&id);
let (_, tx) = item.unwrap();
let _ = tx.send(());
}
let sid_close_tx_map = self.sid_close_tx_map.clone();
tokio::spawn(async move {
let mut keys: Vec<Sid> = vec![];
for (kv, _) in sid_close_tx_map.lock().await.iter() {
keys.push(*kv);
}
for id in keys {
let item = sid_close_tx_map.lock().await.remove(&id);
let tx = item.unwrap();
let _ = tx.send(());
}
});
}
}

Expand Down Expand Up @@ -91,8 +93,8 @@ impl Session {

// init ReadFrameGrouper
let (sync_tx, sync_rx) = mpsc::channel(MAX_IN_QUEUE_SYNC_FRAMES);
let sid_tx_map = Arc::new(DashMap::new());
let sid_rx_map = Arc::new(DashMap::new());
let sid_tx_map = Arc::new(Mutex::new(HashMap::new()));
let sid_rx_map = Arc::new(Mutex::new(HashMap::new()));
let mut spliter = ReadFrameGrouper {
new_frame_rx,
sync_tx,
Expand All @@ -119,7 +121,7 @@ impl Session {

sid_tx_map,
sid_rx_map,
sid_close_tx_map: Arc::new(DashMap::new()),
sid_close_tx_map: Arc::new(Mutex::new(HashMap::new())),
sid_drop_tx,

sync_rx,
Expand Down Expand Up @@ -181,17 +183,22 @@ impl Session {

// Update sid_tx_map and sid_rx_map when open_stream.
{
if !self.sid_tx_map.contains_key(&sid) {
let contained = { self.sid_tx_map.lock().await.contains_key(&sid) };
if !contained {
let (tx, rx) = mpsc::channel(self.config.stream_reading_frame_channel_capacity);
self.sid_tx_map.insert(sid, tx);
self.sid_rx_map.insert(sid, rx);
self
.sid_tx_map
.lock()
.await
.insert(sid, Arc::new(Mutex::new(tx)));
self.sid_rx_map.lock().await.insert(sid, rx);
}
}
let stream = self.new_stream(sid);
let stream = self.new_stream(sid).await;
if stream.is_err() {
// not likely
self.sid_tx_map.remove(&sid);
self.sid_rx_map.remove(&sid);
self.sid_tx_map.lock().await.remove(&sid);
self.sid_rx_map.lock().await.remove(&sid);
}
Ok(stream.unwrap())
}
Expand All @@ -209,25 +216,24 @@ impl Session {
let frame = frame.unwrap();
let sid = frame.sid;

let stream = self.new_stream(sid)?;
let stream = self.new_stream(sid).await?;

Ok(stream)
}

fn new_stream(&mut self, sid: Sid) -> Result<Stream> {
async fn new_stream(&mut self, sid: Sid) -> Result<Stream> {
let frame_rx = {
let rx = self.sid_rx_map.remove(&sid);
let rx = self.sid_rx_map.lock().await.remove(&sid);
if rx.is_none() {
return Err(TokioSmuxError::Default {
msg: "unexpected empty sid in sid_rx_map".to_string(),
});
}
let (_id, rx) = rx.unwrap();
rx
rx.unwrap()
};

let (close_tx, close_rx) = oneshot::channel();
self.sid_close_tx_map.insert(sid, close_tx);
self.sid_close_tx_map.lock().await.insert(sid, close_tx);

let mut stream = Stream::new(sid, frame_rx, self.write_tx.clone(), close_rx);
stream.with_drop_tx(Some(self.sid_drop_tx.clone()));
Expand Down Expand Up @@ -255,9 +261,9 @@ impl Session {
struct SessionCleaner {
sid_drop_rx: mpsc::Receiver<Sid>,

sid_tx_map: Arc<DashMap<Sid, mpsc::Sender<Frame>>>,
sid_rx_map: Arc<DashMap<Sid, mpsc::Receiver<Frame>>>,
sid_close_tx_map: Arc<DashMap<Sid, oneshot::Sender<()>>>,
sid_tx_map: Arc<Mutex<HashMap<Sid, Arc<Mutex<mpsc::Sender<Frame>>>>>>,
sid_rx_map: Arc<Mutex<HashMap<Sid, mpsc::Receiver<Frame>>>>,
sid_close_tx_map: Arc<Mutex<HashMap<Sid, oneshot::Sender<()>>>>,
}

impl SessionCleaner {
Expand All @@ -278,9 +284,9 @@ impl SessionCleaner {
}

let sid = sid.unwrap();
self.sid_tx_map.remove(&sid);
self.sid_rx_map.remove(&sid);
self.sid_close_tx_map.remove(&sid);
self.sid_tx_map.lock().await.remove(&sid);
self.sid_rx_map.lock().await.remove(&sid);
self.sid_close_tx_map.lock().await.remove(&sid);
}
}
}
Expand Down Expand Up @@ -531,9 +537,9 @@ pub mod test {

// clean up after a while
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
assert!(client.sid_tx_map.get(&sid).is_none());
assert!(client.sid_rx_map.get(&sid).is_none());
assert!(client.sid_close_tx_map.get(&sid).is_none());
assert!(client.sid_tx_map.lock().await.get(&sid).is_none());
assert!(client.sid_rx_map.lock().await.get(&sid).is_none());
assert!(client.sid_close_tx_map.lock().await.get(&sid).is_none());

// the remote should also receive the fin
let data = write_rx.recv().await.unwrap();
Expand Down
17 changes: 10 additions & 7 deletions src/session_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::future;
use crate::error::Result;
use crate::frame::HEADER_SIZE;
use crate::frame::{Cmd, Frame};
use log;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tokio::time;
Expand Down Expand Up @@ -115,7 +116,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
self.read_finished = true;
continue;
}
self.handle_read_data(&data[0..size])?;
self.handle_read_data(&data[0..size]).await?;
}
// write
req = self.write_rx.recv() => {
Expand All @@ -137,6 +138,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {

async fn handle_keep_alive_interval_tick(&mut self) -> Result<()> {
let frame = Frame::new_v1(Cmd::Nop, 0);
log::trace!("send frame: {:?}", frame);
let buf = frame.get_buf()?;
self.conn.write_all(&buf).await?;

Expand All @@ -152,14 +154,16 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
return Ok(());
}

log::trace!("send frame: {:?}", req.frame);

let finish_tx = req.finish_tx.take().unwrap();
// ignore stream closed error
let _ = finish_tx.send(());

Ok(())
}

fn handle_read_data(&mut self, data: &[u8]) -> Result<()> {
async fn handle_read_data(&mut self, data: &[u8]) -> Result<()> {
if data.len() == 0 {
// Remote write side closed, no more data.
return Ok(());
Expand All @@ -179,6 +183,9 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {
break;
}
let mut frame = frame.unwrap();

log::trace!("receive frame: {:?}", frame);

let frame_length = frame.length;
// check if all data ready
if (frame_length as u32 + HEADER_SIZE as u32) > (self.read_buf.len() as u32) {
Expand All @@ -196,11 +203,7 @@ impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SessionInner<T> {

// output frame
let recv_tx = self.recv_tx.clone();
tokio::spawn(async move {
// Will block if the tx capability is empty.
// is_err() means the session is closed, therefore ignore the error.
let _ = recv_tx.send(read_req).await;
});
let _ = recv_tx.send(read_req).await;

// continue
}
Expand Down
Loading

0 comments on commit a7f0185

Please sign in to comment.