Skip to content

Commit

Permalink
refac: minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
lffg committed Jun 21, 2024
1 parent 288184a commit d597435
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 122 deletions.
121 changes: 60 additions & 61 deletions ctl/src/balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,55 @@ use proto::{
use tracing::{instrument, trace, warn};
use utils::http::{self, OptionExt as _, ResultExt as _};

#[instrument(skip_all)]
pub async fn proxy(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(balancer): State<BalancerState>,
mut req: Request,
) -> http::Result<impl IntoResponse> {
let service_id = extract_service_id(&mut req)?;

let (instance_id, server_addr) = balancer
.next(&service_id)
.or_http_error(StatusCode::NOT_FOUND, "service not found")?;
trace!(%service_id, %instance_id, %server_addr, "received and balanced user request");

*req.uri_mut() = {
let uri = req.uri();
let mut parts = uri.clone().into_parts();
parts.authority = Authority::from_str(&format!("{server_addr}:{WORKER_PROXY_PORT}")).ok();
parts.scheme = Some(Scheme::HTTP);
Uri::from_parts(parts).unwrap()
};

req.headers_mut().insert(
PROXY_INSTANCE_HEADER_NAME,
HeaderValue::from_str(&instance_id.to_string()).unwrap(),
);
req.headers_mut().insert(
PROXY_FORWARDED_HEADER_NAME,
HeaderValue::from_str(&addr.ip().to_string()).unwrap(),
);

balancer
.client
.request(req)
.await
.http_error(StatusCode::BAD_GATEWAY, "bad gateway")
}

fn extract_service_id(req: &mut Request) -> http::Result<ServiceId> {
let inner = req
.headers()
.get("Host")
.unwrap()
.to_str()
.ok()
.and_then(|s| s.parse().ok())
.or_http_error(StatusCode::BAD_REQUEST, "invalid service name")?;
Ok(ServiceId(inner))
}

#[derive(Default)]
pub struct InstanceBag {
pub instances: Vec<(InstanceId, IpAddr)>,
Expand All @@ -44,19 +93,18 @@ pub struct BalancerState {
impl BalancerState {
#[must_use]
pub fn new() -> (Self, BalancerHandle) {
let map = Arc::new(Mutex::new(HashMap::default()));
(
BalancerState {
addrs: map.clone(),
client: {
let mut connector = HttpConnector::new();
connector.set_keepalive(Some(Duration::from_secs(60)));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new()).build::<_, Body>(connector)
},
let addrs = Arc::new(Mutex::new(HashMap::default()));
let state = BalancerState {
addrs: addrs.clone(),
client: {
let mut connector = HttpConnector::new();
connector.set_keepalive(Some(Duration::from_secs(60)));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new()).build::<_, Body>(connector)
},
BalancerHandle { addrs: map },
)
};
let handle = BalancerHandle { addrs };
(state, handle)
}

pub fn next(&self, service: &ServiceId) -> Option<(InstanceId, IpAddr)> {
Expand Down Expand Up @@ -90,52 +138,3 @@ impl BalancerHandle {
bag.instances.retain(|(inst, _)| inst != &instance_id);
}
}

#[instrument(skip_all)]
pub async fn proxy(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
State(balancer): State<BalancerState>,
mut req: Request,
) -> http::Result<impl IntoResponse> {
let service_id = extract_service_id(&mut req)?;

let (instance_id, server_addr) = balancer
.next(&service_id)
.or_http_error(StatusCode::NOT_FOUND, "service not found")?;
trace!(%service_id, %instance_id, %server_addr, "received and balanced user request");

*req.uri_mut() = {
let uri = req.uri();
let mut parts = uri.clone().into_parts();
parts.authority = Authority::from_str(&format!("{server_addr}:{WORKER_PROXY_PORT}")).ok();
parts.scheme = Some(Scheme::HTTP);
Uri::from_parts(parts).unwrap()
};

req.headers_mut().insert(
PROXY_INSTANCE_HEADER_NAME,
HeaderValue::from_str(&instance_id.to_string()).unwrap(),
);
req.headers_mut().insert(
PROXY_FORWARDED_HEADER_NAME,
HeaderValue::from_str(&addr.ip().to_string()).unwrap(),
);

balancer
.client
.request(req)
.await
.http_error(StatusCode::BAD_GATEWAY, "bad gateway")
}

fn extract_service_id(req: &mut Request) -> http::Result<ServiceId> {
let inner = req
.headers()
.get("Host")
.unwrap()
.to_str()
.ok()
.and_then(|s| s.parse().ok())
.or_http_error(StatusCode::BAD_REQUEST, "invalid service name")?;
Ok(ServiceId(inner))
}
24 changes: 11 additions & 13 deletions worker/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,18 @@ pub struct ProxyState {
impl ProxyState {
#[must_use]
pub fn new() -> (Self, ProxyHandle) {
let map = Arc::new(RwLock::new(HashMap::default()));
(
ProxyState {
ports: map.clone(),
client: {
let mut connector = HttpConnector::new();
connector.set_keepalive(Some(Duration::from_secs(60)));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new()).build::<_, Body>(connector)
},
let ports = Arc::new(RwLock::new(HashMap::default()));
let state = ProxyState {
ports: ports.clone(),
client: {
let mut connector = HttpConnector::new();
connector.set_keepalive(Some(Duration::from_secs(60)));
connector.set_nodelay(true);
Client::builder(TokioExecutor::new()).build::<_, Body>(connector)
},
ProxyHandle { ports: map },
)
};
let handle = ProxyHandle { ports };
(state, handle)
}
}

Expand All @@ -97,7 +96,6 @@ impl ProxyHandle {
}

fn extract_instance_id(req: &mut Request) -> http::Result<InstanceId> {
// i'm so sorry
let inner = req
.headers_mut()
.get(PROXY_INSTANCE_HEADER_NAME)
Expand Down
38 changes: 0 additions & 38 deletions worker/src/runner/handle.rs

This file was deleted.

52 changes: 42 additions & 10 deletions worker/src/runner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ use tracing::{error, trace};
mod container_rt;
use crate::proxy::ProxyHandle;

mod handle;
pub use handle::RunnerHandle;

pub struct Runner {
rx: mpsc::Receiver<Msg>,
instances: HashMap<InstanceId, u16>,
Expand Down Expand Up @@ -77,7 +74,9 @@ impl Runner {
}

async fn deploy_instance(&mut self, spec: InstanceSpec) -> eyre::Result<()> {
let port = self.get_port_for_instance(spec.instance_id).await?;
let port = self.get_available_instance_port().await?;
self.add_instance(spec.instance_id, port);

let rt = self.container_runtime.clone();
let handle = self.handle.clone();
tokio::spawn(async move {
Expand All @@ -98,8 +97,7 @@ impl Runner {
use instance::Status::*;
match &status {
Started => (),
Terminated => self.remove_instance(instance_id),
Crashed { error: _ } | Killed { reason: _ } | FailedToStart { error: _ } => {
Terminated | Crashed { .. } | Killed { .. } | FailedToStart { .. } => {
self.remove_instance(instance_id);
}
}
Expand All @@ -113,17 +111,20 @@ impl Runner {
});
}

async fn get_port_for_instance(&mut self, id: InstanceId) -> eyre::Result<u16> {
async fn get_available_instance_port(&mut self) -> eyre::Result<u16> {
let port = loop {
let port = get_port().await?;
let port = get_available_port().await?;
if !self.ports.contains(&port) {
break port;
}
};
Ok(port)
}

fn add_instance(&mut self, id: InstanceId, port: u16) {
self.instances.insert(id, port);
self.ports.insert(port);
self.proxy_handle.add_instance(id, port);
Ok(port)
}

fn remove_instance(&mut self, id: InstanceId) {
Expand All @@ -133,6 +134,37 @@ impl Runner {
}
}

#[derive(Clone)]
pub struct RunnerHandle(pub mpsc::Sender<Msg>);

impl RunnerHandle {
async fn send(&self, msg: Msg) {
_ = self.0.send(msg).await;
}

/// Sends a message and waits for a reply.
async fn send_wait<F, R>(&self, f: F) -> R
where
F: FnOnce(oneshot::Sender<R>) -> Msg,
{
let (tx, rx) = oneshot::channel();
self.send(f(tx)).await;
rx.await.expect("actor must be alive")
}

pub async fn deploy_instance(&self, spec: InstanceSpec) -> Result<(), Report> {
self.send_wait(|tx| Msg::DeployInstance(spec, tx)).await
}

pub async fn terminate_instance(&self, id: InstanceId) -> Result<(), Report> {
self.send_wait(|tx| Msg::TerminateInstance(id, tx)).await
}

pub async fn report_instance_status(&self, id: InstanceId, status: instance::Status) {
self.send(Msg::ReportInstanceStatus(id, status)).await;
}
}

#[allow(dead_code)]
pub enum Msg {
DeployInstance(InstanceSpec, oneshot::Sender<Result<(), Report>>),
Expand All @@ -142,7 +174,7 @@ pub enum Msg {
ReportInstanceStatus(InstanceId, instance::Status),
}

async fn get_port() -> eyre::Result<u16> {
async fn get_available_port() -> eyre::Result<u16> {
let listener = TcpListener::bind(("0.0.0.0", 0))
.await
.wrap_err("failed to bind while deciding port")?;
Expand Down

0 comments on commit d597435

Please sign in to comment.