Skip to content

Commit

Permalink
Merge pull request #38 from ideal-world/tcp-gate
Browse files Browse the repository at this point in the history
Tcp gate
  • Loading branch information
4t145 authored Oct 15, 2024
2 parents f95c1d8 + 99500b0 commit e4df383
Show file tree
Hide file tree
Showing 20 changed files with 533 additions and 166 deletions.
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"crates/config",
"crates/shell",
"examples/sayhello",
"examples/socks5-proxy",
]
resolver = "2"
[profile.release]
Expand Down Expand Up @@ -70,8 +71,8 @@ hyper-util = { version = "0" }
# ws
tokio-tungstenite = { version = "0" }
tower-layer = { version = "0.3" }
tower-http = { version = "0.5" }
tower = { version = "0.4" }
tower-http = { version = "0.6" }
tower = { version = "0.5" }

# K8s
kube = { version = "0.85", features = ["runtime", "derive"] }
Expand Down Expand Up @@ -108,4 +109,4 @@ ipnet = { version = "2" }
notify = { version = "6.1.1" }

# web-server
axum = "0.7.5"
axum = "0.7.6"
4 changes: 2 additions & 2 deletions crates/config/src/service/redis/retrieve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ where
}

async fn retrieve_config_item_route(&self, gateway_name: &str, route_name: &str) -> BoxResult<Option<crate::model::SgHttpRoute>> {
let http_route_config: Option<String> = self.get_con().await?.hget(&format!("{CONF_HTTP_ROUTE_KEY}{}", gateway_name), route_name).await?;
let http_route_config: Option<String> = self.get_con().await?.hget(format!("{CONF_HTTP_ROUTE_KEY}{}", gateway_name), route_name).await?;
http_route_config.map(|config| self.format.de::<SgHttpRoute>(config.as_bytes()).map_err(|e| format!("[SG.Config] Route Config parse error {}", e).into())).transpose()
}

async fn retrieve_config_item_route_names(&self, name: &str) -> BoxResult<Vec<String>> {
let http_route_configs: HashMap<String, String> = self.get_con().await?.hgetall(&format!("{CONF_HTTP_ROUTE_KEY}{}", name)).await?;
let http_route_configs: HashMap<String, String> = self.get_con().await?.hgetall(format!("{CONF_HTTP_ROUTE_KEY}{}", name)).await?;

Ok(http_route_configs.into_keys().collect())
}
Expand Down
2 changes: 1 addition & 1 deletion crates/kernel/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ crossbeam-utils = "0.8"
[dev-dependencies]
tokio = { version = "1", features = ["net", "time", "rt", "macros"] }
axum = { workspace = true, features = ["multipart"] }
axum-server = { version = "0.6", features = ["tls-rustls"] }
axum-server = { version = "0.7", features = ["tls-rustls"] }
md5 = { version = "0.7.0" }
reqwest = { version = "0.12", features = ["multipart", "stream"] }
tokio-tungstenite = { workspace = true }
Expand Down
4 changes: 2 additions & 2 deletions crates/kernel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ pub use backend_service::ArcHyperService;
pub use body::SgBody;
use extension::Reflect;
pub use extractor::Extract;
use hyper::{body::Bytes, Request, Response, StatusCode};
use std::{convert::Infallible, fmt};
pub use tokio_util::sync::CancellationToken;
pub use tower_layer::Layer;

use hyper::{body::Bytes, Request, Response, StatusCode};

use tower_layer::layer_fn;

pub type BoxResult<T> = Result<T, BoxError>;
Expand Down
191 changes: 61 additions & 130 deletions crates/kernel/src/listener.rs
Original file line number Diff line number Diff line change
@@ -1,183 +1,114 @@
use futures_util::future::BoxFuture;
use hyper::{body::Incoming, Request, Response};
use hyper_util::rt::{self, TokioIo};
use std::{net::SocketAddr, sync::Arc};

use std::{convert::Infallible, net::SocketAddr, sync::Arc};
use tokio::net::TcpStream;
use tokio_rustls::rustls;
use futures_util::TryFutureExt;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use tracing::{instrument, Instrument};

use crate::{
extension::{EnterTime, PeerAddr, Reflect},
utils::with_length_or_chunked,
BoxError, SgBody,
};
use crate::{service::TcpService, BoxError, BoxResult};

/// Listener embodies the concept of a logical endpoint where a Gateway accepts network connections.
#[derive(Clone)]
pub struct SgListen<S> {
conn_builder: hyper_util::server::conn::auto::Builder<rt::TokioExecutor>,
pub struct SgListen {
pub socket_addr: SocketAddr,
pub service: S,
pub tls_cfg: Option<Arc<rustls::ServerConfig>>,
pub cancel_token: CancellationToken,
pub services: Vec<Arc<dyn TcpService>>,
pub listener_id: String,
cancel_token: CancellationToken,
}

impl<S> std::fmt::Debug for SgListen<S> {
impl std::fmt::Debug for SgListen {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SgListen")
.field("socket_addr", &self.socket_addr)
.field("tls_enabled", &self.tls_cfg.is_some())
.field("listener_id", &self.listener_id)
.field("services", &self.services.iter().map(|s| s.protocol_name()).collect::<Vec<_>>())
.finish_non_exhaustive()
}
}

impl<S> SgListen<S> {
impl SgListen {
/// we only have 65535 ports for a console, so it's a safe size
pub fn new(socket_addr: SocketAddr, service: S, cancel_token: CancellationToken) -> Self {
let listener_id = format!("{socket_addr}");
pub fn new(socket_addr: SocketAddr, cancel_token: CancellationToken) -> Self {
Self {
conn_builder: hyper_util::server::conn::auto::Builder::new(rt::TokioExecutor::new()),
socket_addr,
service,
tls_cfg: None,
services: Vec::new(),
cancel_token,
listener_id,
listener_id: Default::default(),
}
}

/// Set the TLS config for this listener.
/// see [rustls::ServerConfig](https://docs.rs/rustls/latest/rustls/server/struct.ServerConfig.html)
#[must_use]
pub fn with_tls_config(mut self, tls_cfg: impl Into<Arc<rustls::ServerConfig>>) -> Self {
self.tls_cfg = Some(tls_cfg.into());
pub fn with_service<S: TcpService>(mut self, service: S) -> Self {
self.services.push(Arc::new(service));
self
}
}

#[derive(Clone)]
struct HyperServiceAdapter<S>
where
S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
service: S,
peer: SocketAddr,
}
impl<S> HyperServiceAdapter<S>
where
S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
pub fn new(service: S, peer: SocketAddr) -> Self {
Self { service, peer }
pub fn add_service<S: TcpService>(&mut self, service: S) {
self.services.push(Arc::new(service));
}
}

impl<S> hyper::service::Service<Request<Incoming>> for HyperServiceAdapter<S>
where
S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
type Response = Response<SgBody>;
type Error = Infallible;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
pub fn with_services(mut self, services: Vec<Arc<dyn TcpService>>) -> Self {
self.services.extend(services);
self
}

#[inline]
fn call(&self, mut req: Request<Incoming>) -> Self::Future {
req.extensions_mut().insert(self.peer);
// here we will clone underlying service,
// so it's important that underlying service is cheap to clone.
// here, the service are likely to be a `ArcHyperService` so it's ok
// but if underlying service is big, it will be expensive to clone.
// especially the router is big and the too many plugins are installed.
// so we should avoid that
let enter_time = EnterTime::new();
let service = self.service.clone();
let mut req = req.map(SgBody::new);
let mut reflect = Reflect::default();
reflect.insert(enter_time);
req.extensions_mut().insert(reflect);
req.extensions_mut().insert(PeerAddr(self.peer));
req.extensions_mut().insert(enter_time);
pub fn extend_services(&mut self, services: Vec<Arc<dyn TcpService>>) {
self.services.extend(services);
}

Box::pin(async move {
let mut resp = service.call(req).await.expect("infallible");
with_length_or_chunked(&mut resp);
let status = resp.status();
if status.is_server_error() {
tracing::warn!(status = ?status, headers = ?resp.headers(), "server error response");
} else if status.is_client_error() {
tracing::debug!(status = ?status, headers = ?resp.headers(), "client error response");
} else if status.is_success() {
tracing::trace!(status = ?status, headers = ?resp.headers(), "success response");
}
tracing::trace!(latency = ?enter_time.elapsed(), "request finished");
Ok(resp)
})
pub fn with_listener_id(mut self, listener_id: impl Into<String>) -> Self {
self.listener_id = listener_id.into();
self
}
}

impl<S> SgListen<S>
where
S: hyper::service::Service<Request<SgBody>, Error = Infallible, Response = Response<SgBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
{
#[instrument(skip(stream, service, tls_cfg, conn_builder))]
async fn accept(
conn_builder: hyper_util::server::conn::auto::Builder<rt::TokioExecutor>,
stream: TcpStream,
peer_addr: SocketAddr,
tls_cfg: Option<Arc<rustls::ServerConfig>>,
service: S,
) {
tracing::debug!("[Sg.Listen] Accepted connection");
let service = HyperServiceAdapter::new(service, peer_addr);
let conn_result = if let Some(tls_cfg) = tls_cfg {
let connector = tokio_rustls::TlsAcceptor::from(tls_cfg);
let Ok(accepted) = connector.accept(stream).await.inspect_err(|e| tracing::warn!("[Sg.Listen] Tls connect error: {}", e)) else {
return;
};
let io = TokioIo::new(accepted);
let conn = conn_builder.serve_connection_with_upgrades(io, service);
conn.await
} else {
let io = TokioIo::new(stream);
let conn = conn_builder.serve_connection_with_upgrades(io, service);
conn.await
};
if let Err(e) = conn_result {
tracing::warn!("[Sg.Listen] Connection closed with error {e}")
} else {
tracing::debug!("[Sg.Listen] Connection closed");
}
impl SgListen {
/// Spawn the listener on the tokio runtime.
///
/// It's a shortcut for `tokio::spawn(listener.listen())`.
pub fn spawn(self) -> tokio::task::JoinHandle<Result<(), BoxError>> {
tokio::spawn(self.listen())
}
#[instrument()]

/// Listen on the socket address.
#[instrument(skip(self), fields(bind=%self.socket_addr))]
pub async fn listen(self) -> Result<(), BoxError> {
tracing::debug!("[Sg.Listen] start binding...");
tracing::debug!("start binding...");
let listener = tokio::net::TcpListener::bind(self.socket_addr).await?;
let cancel_token = self.cancel_token;
tracing::debug!("[Sg.Listen] start listening...");
tracing::debug!("start listening...");
let peek_size = self.services.iter().fold(0, |acc, s| acc.max(s.sniff_peek_size()));
let services: Arc<[Arc<dyn TcpService>]> = self.services.clone().into();
loop {
let accepted = tokio::select! {
() = cancel_token.cancelled() => {
tracing::warn!("[Sg.Listen] cancelled");
tracing::warn!("cancelled");
return Ok(());
},
accepted = listener.accept() => accepted
};
match accepted {
Ok((stream, peer_addr)) => {
let tls_cfg = self.tls_cfg.clone();
let service = self.service.clone();
let builder = self.conn_builder.clone();
tokio::spawn(Self::accept(builder, stream, peer_addr, tls_cfg, service));
let services = services.clone();
let _task = tokio::spawn(
async move {
let mut peek_buf = vec![0u8; peek_size];
stream.peek(&mut peek_buf).await?;
for s in services.iter() {
if s.sniff(&peek_buf) {
tracing::debug!(tcp_service=%s.protocol_name(), "accepted");
s.handle(stream, peer_addr).await?;
break;
}
}
BoxResult::Ok(())
}
.inspect_err(|e| {
tracing::warn!("TcpService error: {:?}", e);
})
.instrument(tracing::info_span!("connection", peer = %peer_addr)),
);
}
Err(e) => {
tracing::warn!("[Sg.Listen] Accept tcp connection error: {:?}", e);
tracing::warn!("Accept tcp connection error: {:?}", e);
}
}
}
Expand Down
Loading

0 comments on commit e4df383

Please sign in to comment.