Skip to content

Commit

Permalink
Convert authentication trait to associated async trait
Browse files Browse the repository at this point in the history
  • Loading branch information
nemosupremo committed May 9, 2023
1 parent b792d00 commit 8327af0
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 56 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ tokio = { version = "1.25.0", features = ["io-util", "net", "time", "macros"] }
anyhow = "1.0"
thiserror = "1.0"
tokio-stream = "0.1.9"
futures = "0.3"

# Dependencies for examples and tests
[dev-dependencies]
Expand Down
22 changes: 13 additions & 9 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
extern crate log;

use fast_socks5::{
server::{Config, SimpleUserPassword, Socks5Server, Socks5Socket},
server::{Authentication, Config, SimpleUserPassword, Socks5Server, Socks5Socket},
Result, SocksError,
};
use std::future::Future;
Expand Down Expand Up @@ -76,22 +76,25 @@ async fn spawn_socks_server() -> Result<()> {
config.set_request_timeout(opt.request_timeout);
config.set_skip_auth(opt.skip_auth);

match opt.auth {
AuthMode::NoAuth => warn!("No authentication has been set!"),
let config = match opt.auth {
AuthMode::NoAuth => {
warn!("No authentication has been set!");
config
}
AuthMode::Password { username, password } => {
if opt.skip_auth {
return Err(SocksError::ArgumentInputError(
"Can't use skip-auth flag and authentication altogether.",
));
}

config.set_authentication(SimpleUserPassword { username, password });
info!("Simple auth system has been set.");
config.with_authentication(SimpleUserPassword { username, password })
}
}
};

let mut listener = Socks5Server::bind(&opt.listen_addr).await?;
listener.set_config(config);
let listener = <Socks5Server>::bind(&opt.listen_addr).await?;
let listener = listener.with_config(config);

let mut incoming = listener.incoming();

Expand All @@ -112,10 +115,11 @@ async fn spawn_socks_server() -> Result<()> {
Ok(())
}

fn spawn_and_log_error<F, T>(fut: F) -> task::JoinHandle<()>
fn spawn_and_log_error<F, T, A>(fut: F) -> task::JoinHandle<()>
where
F: Future<Output = Result<Socks5Socket<T>>> + Send + 'static,
F: Future<Output = Result<Socks5Socket<T, A>>> + Send + 'static,
T: AsyncRead + AsyncWrite + Unpin,
A: Authentication,
{
task::spawn(async move {
if let Err(e) = fut.await {
Expand Down
18 changes: 11 additions & 7 deletions examples/simple_tcp_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
extern crate log;

use fast_socks5::{
server::{Config, SimpleUserPassword, Socks5Socket},
server::{Authentication, Config, SimpleUserPassword, Socks5Socket},
Result,
};
use std::future::Future;
Expand Down Expand Up @@ -76,13 +76,16 @@ async fn spawn_socks_server() -> Result<()> {
let mut config = Config::default();
config.set_request_timeout(opt.request_timeout);

match opt.auth {
AuthMode::NoAuth => warn!("No authentication has been set!"),
let config = match opt.auth {
AuthMode::NoAuth => {
warn!("No authentication has been set!");
config
}
AuthMode::Password { username, password } => {
config.set_authentication(SimpleUserPassword { username, password });
info!("Simple auth system has been set.");
config.with_authentication(SimpleUserPassword { username, password })
}
}
};

let config = Arc::new(config);

Expand All @@ -105,10 +108,11 @@ async fn spawn_socks_server() -> Result<()> {
}
}

fn spawn_and_log_error<F, T>(fut: F) -> task::JoinHandle<()>
fn spawn_and_log_error<F, T, A>(fut: F) -> task::JoinHandle<()>
where
F: Future<Output = Result<Socks5Socket<T>>> + Send + 'static,
F: Future<Output = Result<Socks5Socket<T, A>>> + Send + 'static,
T: AsyncRead + AsyncWrite + Unpin,
A: Authentication,
{
task::spawn(async move {
if let Err(e) = fut.await {
Expand Down
10 changes: 4 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,10 @@ mod test {
) -> Result<()> {
let mut config = server::Config::default();
config.set_udp_support(true);
match auth {
None => {}
Some(up) => {
config.set_authentication(up);
}
}
let config = match auth {
None => config,
Some(up) => config.with_authentication(up),
};

let config = Arc::new(config);
let listener = TcpListener::bind(proxy_addr).await?;
Expand Down
108 changes: 74 additions & 34 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::new_udp_header;
use crate::parse_udp_request;
use crate::read_exact;
use crate::ready;
use crate::util::target_addr::{read_address, TargetAddr};
use crate::util::stream::tcp_connect_with_timeout;
use crate::util::target_addr::{read_address, TargetAddr};
use crate::Socks5Command;
use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError};
use anyhow::Context;
Expand All @@ -22,7 +22,7 @@ use tokio::try_join;
use tokio_stream::Stream;

#[derive(Clone)]
pub struct Config {
pub struct Config<A: Authentication = DenyAuthentication> {
/// Timeout of the command request
request_timeout: u64,
/// Avoid useless roundtrips if we don't need the Authentication layer
Expand All @@ -33,10 +33,10 @@ pub struct Config {
execute_command: bool,
/// Enable UDP support
allow_udp: bool,
auth: Option<Arc<dyn Authentication>>,
auth: Option<Arc<A>>,
}

impl Default for Config {
impl<A: Authentication> Default for Config<A> {
fn default() -> Self {
Config {
request_timeout: 10,
Expand All @@ -51,7 +51,11 @@ impl Default for Config {

/// Use this trait to handle a custom authentication on your end.
pub trait Authentication: Send + Sync {
fn authenticate(&self, username: &str, password: &str) -> bool;
type Item<'a>: Future<Output = bool> + 'a
where
Self: 'a;

fn authenticate<'a>(&'a self, username: &str, password: &str) -> Self::Item<'a>;
}

/// Basic user/pass auth method provided.
Expand All @@ -61,12 +65,36 @@ pub struct SimpleUserPassword {
}

impl Authentication for SimpleUserPassword {
fn authenticate(&self, username: &str, password: &str) -> bool {
username == &self.username && password == &self.password
type Item<'a> = futures::future::Ready<bool>;

fn authenticate<'a>(&'a self, username: &str, password: &str) -> Self::Item<'a> {
futures::future::ready(username == &self.username && password == &self.password)
}
}

impl Config {
#[derive(Copy, Clone, Default)]
pub struct DenyAuthentication {}

impl Authentication for DenyAuthentication {
type Item<'a> = futures::future::Ready<bool>;

fn authenticate<'a>(&'a self, _username: &str, _password: &str) -> Self::Item<'a> {
futures::future::ready(false)
}
}

#[derive(Copy, Clone, Default)]
pub struct AcceptAuthentication {}

impl Authentication for AcceptAuthentication {
type Item<'a> = futures::future::Ready<bool>;

fn authenticate<'a>(&'a self, _username: &str, _password: &str) -> Self::Item<'a> {
futures::future::ready(true)
}
}

impl<A: Authentication> Config<A> {
/// How much time it should wait until the request timeout.
pub fn set_request_timeout(&mut self, n: u64) -> &mut Self {
self.request_timeout = n;
Expand All @@ -83,12 +111,15 @@ impl Config {
/// Enable authentication
/// 'static lifetime for Authentication avoid us to use `dyn Authentication`
/// and set the Arc before calling the function.
pub fn set_authentication<T: Authentication + 'static>(
&mut self,
authentication: T,
) -> &mut Self {
self.auth = Some(Arc::new(authentication));
self
pub fn with_authentication<T: Authentication + 'static>(self, authentication: T) -> Config<T> {
Config {
request_timeout: self.request_timeout,
skip_auth: self.skip_auth,
dns_resolve: self.dns_resolve,
execute_command: self.execute_command,
allow_udp: self.allow_udp,
auth: Some(Arc::new(authentication)),
}
}

/// Set whether or not to execute commands
Expand All @@ -112,40 +143,45 @@ impl Config {

/// Wrapper of TcpListener
/// Useful if you don't use any existing TcpListener's streams.
pub struct Socks5Server {
pub struct Socks5Server<A: Authentication = DenyAuthentication> {
listener: TcpListener,
config: Arc<Config>,
config: Arc<Config<A>>,
}

impl Socks5Server {
pub async fn bind<A: AsyncToSocketAddrs>(addr: A) -> io::Result<Socks5Server> {
impl<A: Authentication + Default> Socks5Server<A> {
pub async fn bind<S: AsyncToSocketAddrs>(addr: S) -> io::Result<Self> {
let listener = TcpListener::bind(&addr).await?;
let config = Arc::new(Config::default());

Ok(Socks5Server { listener, config })
}
}

impl<A: Authentication> Socks5Server<A> {
/// Set a custom config
pub fn set_config(&mut self, config: Config) {
self.config = Arc::new(config);
pub fn with_config<T: Authentication>(self, config: Config<T>) -> Socks5Server<T> {
Socks5Server {
listener: self.listener,
config: Arc::new(config),
}
}

/// Can loop on `incoming().next()` to iterate over incoming connections.
pub fn incoming(&self) -> Incoming<'_> {
pub fn incoming(&self) -> Incoming<'_, A> {
Incoming(self, None)
}
}

/// `Incoming` implements [`futures::stream::Stream`].
pub struct Incoming<'a>(
&'a Socks5Server,
pub struct Incoming<'a, A: Authentication>(
&'a Socks5Server<A>,
Option<Pin<Box<dyn Future<Output = io::Result<(TcpStream, SocketAddr)>> + Send + Sync + 'a>>>,
);

/// Iterator for each incoming stream connection
/// this wrapper will convert async_std TcpStream into Socks5Socket.
impl<'a> Stream for Incoming<'a> {
type Item = Result<Socks5Socket<TcpStream>>;
impl<'a, A: Authentication> Stream for Incoming<'a, A> {
type Item = Result<Socks5Socket<TcpStream, A>>;

/// this code is mainly borrowed from [`Incoming::poll_next()` of `TcpListener`][tcpListener]
/// [tcpListener]: https://docs.rs/async-std/1.8.0/async_std/net/struct.TcpListener.html#method.incoming
Expand Down Expand Up @@ -176,18 +212,18 @@ impl<'a> Stream for Incoming<'a> {
}

/// Wrap TcpStream and contains Socks5 protocol implementation.
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin> {
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin, A: Authentication> {
inner: T,
config: Arc<Config>,
config: Arc<Config<A>>,
auth: AuthenticationMethod,
target_addr: Option<TargetAddr>,
cmd: Option<Socks5Command>,
/// Socket address which will be used in the reply message.
reply_ip: Option<IpAddr>,
}

impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {
pub fn new(socket: T, config: Arc<Config>) -> Self {
impl<T: AsyncRead + AsyncWrite + Unpin, A: Authentication> Socks5Socket<T, A> {
pub fn new(socket: T, config: Arc<Config<A>>) -> Self {
Socks5Socket {
inner: socket,
config,
Expand Down Expand Up @@ -215,7 +251,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {

/// Process clients SOCKS requests
/// This is the entry point where a whole request is processed.
pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T>> {
pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T, A>> {
trace!("upgrading to socks5...");

// Handshake
Expand Down Expand Up @@ -386,7 +422,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {
let password = String::from_utf8(password).context("Failed to convert password")?;
let auth = self.config.auth.as_ref().context("No auth module")?;

if auth.authenticate(&username, &password) {
if auth.authenticate(&username, &password).await {
self.inner
.write(&[1, consts::SOCKS5_REPLY_SUCCEEDED])
.await
Expand Down Expand Up @@ -681,7 +717,7 @@ async fn transfer_udp(inbound: UdpSocket) -> Result<()> {
}

/// Allow us to read directly from the struct
impl<T> AsyncRead for Socks5Socket<T>
impl<T, A: Authentication> AsyncRead for Socks5Socket<T, A>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -695,7 +731,7 @@ where
}

/// Allow us to write directly into the struct
impl<T> AsyncWrite for Socks5Socket<T>
impl<T, A: Authentication> AsyncWrite for Socks5Socket<T, A>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand Down Expand Up @@ -754,10 +790,14 @@ mod test {
use crate::server::Socks5Server;
use tokio_test::block_on;

use super::AcceptAuthentication;

#[test]
fn test_bind() {
let f = async {
let _server = Socks5Server::bind("127.0.0.1:1080").await.unwrap();
let _server = Socks5Server::<AcceptAuthentication>::bind("127.0.0.1:1080")
.await
.unwrap();
};

block_on(f);
Expand Down

0 comments on commit 8327af0

Please sign in to comment.