Skip to content

Commit

Permalink
Allow bearer tokens in client, not just typed socket
Browse files Browse the repository at this point in the history
  • Loading branch information
paulgb committed Jan 9, 2024
1 parent 58dd4be commit e5216eb
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 95 deletions.
54 changes: 54 additions & 0 deletions plane/src/client/controller_address.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use url::Url;

/// An authorized address combines a URL with an optional bearer token.
#[derive(Clone, Debug)]
pub struct AuthorizedAddress {
pub url: Url,
pub bearer_token: Option<String>,
}

impl AuthorizedAddress {
pub fn join(&self, path: &str) -> AuthorizedAddress {
let url = self.url.clone();
let url = url.join(path).expect("URL is always valid");

Self {
url,
bearer_token: self.bearer_token.clone(),
}
}

pub fn to_websocket_address(mut self) -> AuthorizedAddress {
if self.url.scheme() == "http" {
self.url
.set_scheme("ws")
.expect("should always be able to set URL scheme to static value ws");
} else if self.url.scheme() == "https" {
self.url
.set_scheme("wss")
.expect("should always be able to set URL scheme to static value wss");
}

self
}

pub fn bearer_header(&self) -> Option<String> {
self.bearer_token
.as_ref()
.map(|token| format!("Bearer {}", token))
}
}

impl From<Url> for AuthorizedAddress {
fn from(url: Url) -> Self {
let bearer_token = match url.username() {
"" => None,
username => Some(username.to_string()),
};

let mut url = url;
url.set_username("").expect("URL is always valid");

Self { url, bearer_token }
}
}
157 changes: 85 additions & 72 deletions plane/src/client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use self::controller_address::AuthorizedAddress;
use crate::{
controller::error::ApiError,
controller::{error::ApiError, StatusResponse},
names::{BackendName, DroneName},
protocol::{MessageFromDns, MessageFromDrone, MessageFromProxy},
typed_socket::client::TypedSocketConnector,
types::{ClusterName, ConnectRequest, ConnectResponse, TimestampedBackendStatus},
};
use reqwest::{Response, StatusCode};
use serde::de::DeserializeOwned;
use serde_json::Value;
use url::Url;

pub mod controller_address;
mod sse;

#[derive(thiserror::Error, Debug)]
Expand All @@ -33,78 +34,52 @@ pub enum PlaneClientError {
#[derive(Clone)]
pub struct PlaneClient {
client: reqwest::Client,
base_url: Url,
}

async fn get_response<T: DeserializeOwned>(response: Response) -> Result<T, PlaneClientError> {
if response.status().is_success() {
Ok(response.json::<T>().await?)
} else {
let url = response.url().to_string();
tracing::error!(?url, "Got error response from API server.");
let status = response.status();
if let Ok(api_error) = response.json::<ApiError>().await {
Err(PlaneClientError::PlaneError(api_error, status))
} else {
Err(PlaneClientError::UnexpectedStatus(status))
}
}
}

fn http_to_ws_url(url: &mut Url) {
if url.scheme() == "http" {
url.set_scheme("ws")
.expect("should always be able to set URL scheme to static value ws");
} else if url.scheme() == "https" {
url.set_scheme("wss")
.expect("should always be able to set URL scheme to static value wss");
}
controller_address: AuthorizedAddress,
}

impl PlaneClient {
pub fn new(base_url: Url) -> Self {
let client = reqwest::Client::new();
Self { client, base_url }
}
let controller_address = AuthorizedAddress::from(base_url);

pub async fn status(&self) -> Result<(), PlaneClientError> {
let url = self.base_url.join("/ctrl/status")?;
Self {
client,
controller_address,
}
}

let response = self.client.get(url).send().await?;
get_response::<Value>(response).await?;
Ok(())
pub async fn status(&self) -> Result<StatusResponse, PlaneClientError> {
let addr = self.controller_address.join("/ctrl/status");
authed_get(&self.client, &addr).await
}

pub fn drone_connection(
&self,
cluster: &ClusterName,
) -> TypedSocketConnector<MessageFromDrone> {
let mut url = self
.base_url
let addr = self
.controller_address
.join(&format!("/ctrl/c/{}/drone-socket", cluster))
.expect("url is always valid");
http_to_ws_url(&mut url);
TypedSocketConnector::new(url)
.to_websocket_address();
TypedSocketConnector::new(addr)
}

pub fn proxy_connection(
&self,
cluster: &ClusterName,
) -> TypedSocketConnector<MessageFromProxy> {
let mut url = self
.base_url
let addr = self
.controller_address
.join(&format!("/ctrl/c/{}/proxy-socket", cluster))
.expect("url is always valid");
http_to_ws_url(&mut url);
TypedSocketConnector::new(url)
.to_websocket_address();
TypedSocketConnector::new(addr)
}

pub fn dns_connection(&self) -> TypedSocketConnector<MessageFromDns> {
let mut url = self
.base_url
let url = self
.controller_address
.join("/ctrl/dns-socket")
.expect("url is always valid");
http_to_ws_url(&mut url);
.to_websocket_address();
TypedSocketConnector::new(url)
}

Expand All @@ -113,26 +88,24 @@ impl PlaneClient {
cluster: &ClusterName,
connect_request: &ConnectRequest,
) -> Result<ConnectResponse, PlaneClientError> {
let url = self
.base_url
.join(&format!("/ctrl/c/{}/connect", cluster))?;
let addr = self
.controller_address
.join(&format!("/ctrl/c/{}/connect", cluster));

let respose = self.client.post(url).json(connect_request).send().await?;
let connect_response: ConnectResponse = get_response(respose).await?;
Ok(connect_response)
let response = authed_post(&self.client, &addr, connect_request).await?;
Ok(response)
}

pub async fn drain(
&self,
cluster: &ClusterName,
drone: &DroneName,
) -> Result<(), PlaneClientError> {
let url = self
.base_url
.join(&format!("/ctrl/c/{}/d/{}/drain", cluster, drone))?;
let addr = self
.controller_address
.join(&format!("/ctrl/c/{}/d/{}/drain", cluster, drone));

let response = self.client.post(url).send().await?;
get_response::<Value>(response).await?;
authed_post(&self.client, &addr, &()).await?;
Ok(())
}

Expand All @@ -141,13 +114,12 @@ impl PlaneClient {
cluster: &ClusterName,
backend_id: &BackendName,
) -> Result<(), PlaneClientError> {
let url = self.base_url.join(&format!(
let addr = self.controller_address.join(&format!(
"/ctrl/c/{}/b/{}/soft-terminate",
cluster, backend_id
))?;
));

let response = self.client.post(url).send().await?;
get_response::<Value>(response).await?;
authed_post(&self.client, &addr, &()).await?;
Ok(())
}

Expand All @@ -156,20 +128,19 @@ impl PlaneClient {
cluster: &ClusterName,
backend_id: &BackendName,
) -> Result<(), PlaneClientError> {
let url = self.base_url.join(&format!(
let addr = self.controller_address.join(&format!(
"/ctrl/c/{}/b/{}/hard-terminate",
cluster, backend_id
))?;
));

let response = self.client.post(url).send().await?;
get_response::<Value>(response).await?;
authed_post(&self.client, &addr, &()).await?;
Ok(())
}

pub fn backend_status_url(&self, cluster: &ClusterName, backend_id: &BackendName) -> Url {
self.base_url
self.controller_address
.join(&format!("/pub/c/{}/b/{}/status", cluster, backend_id))
.expect("Constructed URL is always valid.")
.url
}

pub async fn backend_status(
Expand All @@ -189,12 +160,12 @@ impl PlaneClient {
cluster: &ClusterName,
backend_id: &BackendName,
) -> Url {
self.base_url
self.controller_address
.join(&format!(
"/pub/c/{}/b/{}/status-stream",
cluster, backend_id
))
.expect("Constructed URL is always valid.")
.url
}

pub async fn backend_status_stream(
Expand All @@ -208,3 +179,45 @@ impl PlaneClient {
Ok(stream)
}
}

async fn get_response<T: DeserializeOwned>(response: Response) -> Result<T, PlaneClientError> {
if response.status().is_success() {
Ok(response.json::<T>().await?)
} else {
let url = response.url().to_string();
tracing::error!(?url, "Got error response from API server.");
let status = response.status();
if let Ok(api_error) = response.json::<ApiError>().await {
Err(PlaneClientError::PlaneError(api_error, status))
} else {
Err(PlaneClientError::UnexpectedStatus(status))
}
}
}

async fn authed_get<T: DeserializeOwned>(
client: &reqwest::Client,
addr: &AuthorizedAddress,
) -> Result<T, PlaneClientError> {
let mut req = client.get(addr.url.clone());
if let Some(header) = addr.bearer_header() {
req = req.header("Authorization", header);
}

let response = req.send().await?;
get_response(response).await
}

async fn authed_post<T: DeserializeOwned>(
client: &reqwest::Client,
addr: &AuthorizedAddress,
body: &impl serde::Serialize,
) -> Result<T, PlaneClientError> {
let mut req = client.post(addr.url.clone());
if let Some(header) = addr.bearer_header() {
req = req.header("Authorization", header);
}

let response = req.json(body).send().await?;
get_response(response).await
}
21 changes: 14 additions & 7 deletions plane/src/controller/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use axum::{
routing::{get, post},
Json, Router, Server,
};
use serde_json::{json, Value};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, TcpListener};
use tokio::{
sync::oneshot::{self},
Expand All @@ -38,12 +38,19 @@ pub mod error;
mod proxy;
mod terminate;

pub async fn status() -> Json<Value> {
Json(json!({
"status": "ok",
"version": PLANE_VERSION,
"hash": PLANE_GIT_HASH,
}))
#[derive(Serialize, Deserialize)]
pub struct StatusResponse {
pub status: String,
pub version: String,
pub hash: String,
}

pub async fn status() -> Json<StatusResponse> {
Json(StatusResponse {
status: "ok".to_string(),
version: PLANE_VERSION.to_string(),
hash: PLANE_GIT_HASH.to_string(),
})
}

struct HeartbeatSender {
Expand Down
Loading

0 comments on commit e5216eb

Please sign in to comment.