Skip to content

Commit

Permalink
Add support for custom service connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
operutka committed Aug 6, 2024
1 parent 5d7a27c commit 453856c
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 43 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ all = ["discovery", "threads", "exports"]
crate-type = ["rlib", "cdylib", "staticlib"]

[dependencies]
bytes = "1"
farmhash = "1.1"
fs2 = "0.4"
json = "0.12"
openssl = "0.10"
time = "0.1"
bytes = "1"
farmhash = "1.1"
fs2 = "0.4"
json = "0.12"
openssl = "0.10"
time = "0.1"
trait-variant = "0.1"

[dependencies.futures]
version = "0.3"
Expand Down
36 changes: 31 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,19 @@ use crate::svc_table::Service;
use crate::utils::logger::{BoxLogger, Logger};
use crate::ArrowClientEventListener;

pub use crate::net::arrow::{DefaultServiceConnector, ServiceConnection, ServiceConnector};

/// Connection retry timeout.
const RETRY_TIMEOUT: Duration = Duration::from_secs(60);

/// Get maximum duration of the pairing mode.
const PAIRING_MODE_TIMEOUT: Duration = Duration::from_secs(1200);

/// This future ensures maintaining connection with a remote Arrow Service.
struct ArrowMainTask {
struct ArrowMainTask<C> {
app_context: ApplicationContext,
cmd_channel: CommandChannel,
svc_connector: C,
logger: BoxLogger,
default_addr: String,
current_addr: String,
Expand All @@ -52,9 +55,13 @@ struct ArrowMainTask {
diagnostic_mode: bool,
}

impl ArrowMainTask {
impl<C> ArrowMainTask<C>
where
C: ServiceConnector + Clone + Unpin + 'static,
C::Connection: Send + Unpin,
{
/// Create a new task.
async fn start(app_context: ApplicationContext, cmd_channel: CommandChannel) {
async fn start(app_context: ApplicationContext, cmd_channel: CommandChannel, svc_connector: C) {
let logger = app_context.get_logger();
let addr = app_context.get_arrow_service_address();
let diagnostic_mode = app_context.get_diagnostic_mode();
Expand All @@ -66,6 +73,7 @@ impl ArrowMainTask {
let mut task = ArrowMainTask {
app_context,
cmd_channel,
svc_connector,
logger,
default_addr: addr.clone(),
current_addr: addr,
Expand Down Expand Up @@ -97,6 +105,7 @@ impl ArrowMainTask {
arrow::connect(
self.app_context.clone(),
self.cmd_channel.clone(),
self.svc_connector.clone(),
&self.current_addr,
)
.await
Expand Down Expand Up @@ -128,7 +137,7 @@ impl ArrowMainTask {

let retry = process_connection_error(err, self.last_attempt, self.pairing_mode_timeout);

self.current_addr = self.default_addr.clone();
self.current_addr.clone_from(&self.default_addr);

let fut = wait_for_retry(&mut self.logger, retry);

Expand Down Expand Up @@ -273,14 +282,31 @@ pub struct ArrowClient {

impl ArrowClient {
/// Create a new Arrow client from a given config.
///
/// # Arguments
/// * `config` - Arrow client configuration
pub fn new(config: Config) -> (ArrowClient, ArrowClientTask) {
Self::new_with_connector(config, DefaultServiceConnector::new())
}

/// Create a new Arrow client from a given config.
///
/// # Arguments
/// * `config` - Arrow client configuration
/// * `svc_connector` - custom service connector
pub fn new_with_connector<C>(config: Config, svc_connector: C) -> (ArrowClient, ArrowClientTask)
where
C: ServiceConnector + Clone + Unpin + 'static,
C::Connection: Send + Unpin,
{
let context = ApplicationContext::new(config);

// create command handler
let (cmd_channel, cmd_handler) = cmd_handler::new(context.clone());

// create Arrow client main task
let arrow_main_task = ArrowMainTask::start(context.clone(), cmd_channel.clone());
let arrow_main_task =
ArrowMainTask::start(context.clone(), cmd_channel.clone(), svc_connector);

let nw_scan_cmd_channel = cmd_channel.clone();

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ pub mod storage;
#[doc(hidden)]
pub mod svc_table;

pub use client::{ArrowClient, ArrowClientTask};
pub use client::{
ArrowClient, ArrowClientTask, DefaultServiceConnector, ServiceConnection, ServiceConnector,
};
pub use context::ApplicationEventListener as ArrowClientEventListener;
pub use context::ConnectionState;

Expand Down
54 changes: 54 additions & 0 deletions src/net/arrow/connector.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use std::{future::Future, io, net::SocketAddr};

use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
};

use crate::net::raw::ether::MacAddr;
use crate::svc_table::ServiceType;

/// Service connection.
pub trait ServiceConnection: AsyncRead + AsyncWrite {}

impl<T> ServiceConnection for T where T: AsyncRead + AsyncWrite {}

/// Service connector.
#[trait_variant::make(Send)]
pub trait ServiceConnector {
type Connection: ServiceConnection;

/// Connect to a given service.
async fn connect(
&self,
svc_type: ServiceType,
mac: MacAddr,
addr: SocketAddr,
) -> io::Result<Self::Connection>;
}

/// Default service connector.
#[derive(Default, Copy, Clone)]
pub struct DefaultServiceConnector(());

impl DefaultServiceConnector {
/// Create a new service connector.
#[inline]
pub const fn new() -> Self {
Self(())
}
}

impl ServiceConnector for DefaultServiceConnector {
type Connection = TcpStream;

#[inline]
fn connect(
&self,
_: ServiceType,
_: MacAddr,
addr: SocketAddr,
) -> impl Future<Output = io::Result<Self::Connection>> {
TcpStream::connect(addr)
}
}
52 changes: 38 additions & 14 deletions src/net/arrow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod connector;
mod error;
mod proto;
mod session;
Expand Down Expand Up @@ -44,6 +45,7 @@ use crate::net::raw::ether::MacAddr;
use crate::svc_table::SharedServiceTableRef;
use crate::utils::logger::{BoxLogger, Logger};

pub use self::connector::{DefaultServiceConnector, ServiceConnection, ServiceConnector};
pub use self::error::{ArrowError, ErrorKind};

const ACK_TIMEOUT: Duration = Duration::from_secs(20);
Expand Down Expand Up @@ -82,13 +84,13 @@ impl ExpectedAck {
}

/// Arrow Client implementation.
struct ArrowClientContext {
struct ArrowClientContext<C> {
logger: BoxLogger,
app_context: ApplicationContext,
cmd_channel: CommandChannel,
svc_table: SharedServiceTableRef,
cmsg_factory: ControlMessageFactory,
sessions: SessionManager,
sessions: SessionManager<C>,
messages: VecDeque<ArrowMessage>,
expected_acks: VecDeque<ExpectedAck>,
state: ProtocolState,
Expand All @@ -100,9 +102,13 @@ struct ArrowClientContext {
last_stable_ver: u32,
}

impl ArrowClientContext {
impl<C> ArrowClientContext<C>
where
C: ServiceConnector + Clone + Send + 'static,
C::Connection: Send + Unpin,
{
/// Create a new Arrow Client.
fn new(app_context: ApplicationContext, cmd_channel: CommandChannel) -> Self {
fn new(app_context: ApplicationContext, cmd_channel: CommandChannel, svc_connector: C) -> Self {
let logger = app_context.get_logger();
let svc_table = app_context.get_service_table();

Expand All @@ -115,6 +121,7 @@ impl ArrowClientContext {
let session_manager = SessionManager::new(
app_context.clone(),
cmsg_factory.clone(),
svc_connector,
SESSION_WINDOW_SIZE,
gateway_mode,
);
Expand Down Expand Up @@ -535,7 +542,11 @@ impl ArrowClientContext {
}
}

impl Stream for ArrowClientContext {
impl<C> Stream for ArrowClientContext<C>
where
C: ServiceConnector + Clone + Unpin + 'static,
C::Connection: Send + Unpin,
{
type Item = Result<ArrowMessage, ArrowError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
Expand All @@ -558,19 +569,24 @@ impl Stream for ArrowClientContext {
}
}

struct ArrowClient<S> {
context: Arc<Mutex<ArrowClientContext>>,
struct ArrowClient<C, S> {
context: Arc<Mutex<ArrowClientContext<C>>>,
stream: S,
}

impl<S> ArrowClient<S> {
impl<C, S> ArrowClient<C, S>
where
C: ServiceConnector + Clone + 'static,
C::Connection: Send + Unpin,
{
/// Create a new instance of Arrow Client.
fn new(
app_context: ApplicationContext,
cmd_channel: CommandChannel,
svc_connector: C,
stream: S,
) -> (Self, JoinHandle<Result<(), ArrowError>>) {
let context = ArrowClientContext::new(app_context, cmd_channel);
let context = ArrowClientContext::new(app_context, cmd_channel, svc_connector);

let context = Arc::new(Mutex::new(context));

Expand Down Expand Up @@ -603,7 +619,7 @@ impl<S> ArrowClient<S> {
}
}

impl<S> Drop for ArrowClient<S> {
impl<C, S> Drop for ArrowClient<C, S> {
fn drop(&mut self) {
let mut context = self.context.lock().unwrap();

Expand All @@ -613,8 +629,10 @@ impl<S> Drop for ArrowClient<S> {
}
}

impl<S> Stream for ArrowClient<S>
impl<C, S> Stream for ArrowClient<C, S>
where
C: ServiceConnector + Clone + Unpin + 'static,
C::Connection: Send + Unpin,
S: Stream<Item = Result<ArrowMessage, ArrowError>> + Unpin,
{
type Item = Result<ArrowMessage, ArrowError>;
Expand Down Expand Up @@ -650,11 +668,16 @@ where
}

/// Connect Arrow Client to a given address and return either a redirect address or an error.
pub async fn connect(
pub async fn connect<C>(
app_context: ApplicationContext,
cmd_channel: CommandChannel,
svc_connector: C,
addr: &str,
) -> Result<String, ArrowError> {
) -> Result<String, ArrowError>
where
C: ServiceConnector + Clone + Unpin + 'static,
C::Connection: Send + Unpin,
{
let tls_connector = app_context
.get_tls_connector()
.map_err(|err| ArrowError::other(format!("unable to get TLS context: {}", err)))?;
Expand All @@ -678,7 +701,8 @@ pub async fn connect(

let (mut sink, stream) = framed.split();

let (mut arrow_client, watchdog) = ArrowClient::new(app_context, cmd_channel, stream);
let (mut arrow_client, watchdog) =
ArrowClient::new(app_context, cmd_channel, svc_connector, stream);

let send = sink.send_all(&mut arrow_client);

Expand Down
Loading

0 comments on commit 453856c

Please sign in to comment.