Skip to content

Commit

Permalink
refactor: trim dependencies
Browse files Browse the repository at this point in the history
Trim dependencies to minimize the build times of user projects.
  • Loading branch information
chesedo committed Mar 14, 2023
1 parent 60b1c02 commit a0ea049
Show file tree
Hide file tree
Showing 25 changed files with 61 additions and 322 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ shuttle-service = { path = "service", version = "0.11.0" }
anyhow = "1.0.66"
async-trait = "0.1.58"
axum = { version = "0.6.0", default-features = false }
chrono = { version = "0.4.23", default-features = false, features = ["clock"] }
chrono = { version = "0.4.23", default-features = false }
clap = { version = "4.0.27", features = [ "derive" ] }
headers = "0.3.8"
http = "0.2.8"
Expand Down
2 changes: 1 addition & 1 deletion auth/src/api/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
use axum_sessions::extractors::{ReadableSession, WritableSession};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use shuttle_common::{backends::auth::Claim, models::user};
use shuttle_common::{claims::Claim, models::user};
use tracing::instrument;

use super::{
Expand Down
2 changes: 1 addition & 1 deletion auth/src/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use axum::{
};
use rand::distributions::{Alphanumeric, DistString};
use serde::{Deserialize, Deserializer, Serialize};
use shuttle_common::backends::auth::Scope;
use shuttle_common::claims::Scope;
use sqlx::{query, Row, SqlitePool};
use tracing::{trace, Span};

Expand Down
2 changes: 1 addition & 1 deletion auth/tests/api/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use axum_extra::extract::cookie::{self, Cookie};
use http::{Request, StatusCode};
use hyper::Body;
use serde_json::{json, Value};
use shuttle_common::backends::auth::Claim;
use shuttle_common::claims::Claim;

use crate::helpers::app;

Expand Down
16 changes: 9 additions & 7 deletions common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ anyhow = { workspace = true, optional = true }
async-trait = { workspace = true , optional = true }
axum = { workspace = true, optional = true }
bytes = { version = "1.3.0", optional = true }
chrono = { workspace = true, features = ["serde"] }
chrono = { workspace = true }
comfy-table = { version = "6.1.3", optional = true }
crossterm = { version = "0.25.0", optional = true }
headers = { workspace = true }
headers = { workspace = true, optional = true }
http = { workspace = true, optional = true }
http-body = { version = "0.4.5", optional = true }
http-serde = { version = "1.1.2", optional = true }
Expand All @@ -25,7 +25,7 @@ once_cell = { workspace = true, optional = true }
opentelemetry = { workspace = true, optional = true }
opentelemetry-http = { workspace = true, optional = true }
opentelemetry-otlp = { version = "0.11.0", optional = true }
pin-project = { workspace = true }
pin-project = { workspace = true, optional = true }
prost-types = { workspace = true, optional = true }
reqwest = { version = "0.11.13", optional = true }
rmp-serde = { version = "1.1.1", optional = true }
Expand All @@ -44,12 +44,14 @@ ttl_cache = { workspace = true, optional = true }
uuid = { workspace = true, features = ["v4", "serde"], optional = true }

[features]
backend = ["async-trait", "axum/matched-path", "bytes", "http", "http-body", "hyper/client", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", "thiserror", "tower", "tower-http", "tracing-opentelemetry", "tracing-subscriber/env-filter", "tracing-subscriber/fmt", "ttl_cache"]
display = ["comfy-table", "crossterm"]
models = ["anyhow", "async-trait", "display", "http", "prost-types", "reqwest", "serde_json", "service", "thiserror"]
backend = ["async-trait", "axum/matched-path", "claims", "hyper/client", "opentelemetry-otlp", "thiserror", "tower-http", "tracing-subscriber/env-filter", "tracing-subscriber/fmt", "ttl_cache"]
claims = ["bytes", "chrono/clock", "headers", "http", "http-body", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "pin-project", "tower", "tracing", "tracing-opentelemetry"]
display = ["chrono/clock", "comfy-table", "crossterm"]
error = ["prost-types", "serde_json", "thiserror", "uuid"]
models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json", "service"]
service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "strum", "uuid"]
tracing = ["serde_json"]
wasm = ["http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"]
wasm = ["chrono/clock", "http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"]

[dev-dependencies]
axum = { workspace = true }
Expand Down
213 changes: 7 additions & 206 deletions common/src/backends/auth.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::{convert::Infallible, future::Future, ops::Add, pin::Pin, sync::Arc};
use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc};

use async_trait::async_trait;
use bytes::Bytes;
use chrono::{Duration, Utc};
use headers::{authorization::Bearer, Authorization, HeaderMapExt};
use http::{Request, Response, StatusCode, Uri};
use http_body::combinators::UnsyncBoxBody;
use hyper::{body, Body, Client};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header as JwtHeader, Validation};
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use serde::{Deserialize, Serialize};
Expand All @@ -16,14 +14,14 @@ use tower::{Layer, Service};
use tracing::{error, trace, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

use crate::claims::{Claim, Scope};

use super::{
cache::{CacheManagement, CacheManager},
future::{ResponseFuture, StatusCodeFuture},
future::StatusCodeFuture,
headers::XShuttleAdminSecret,
};

pub const EXP_MINUTES: i64 = 5;
const ISS: &str = "shuttle";
const PUBLIC_KEY_CACHE_KEY: &str = "shuttle.public-key";

/// Layer to check the admin secret set by deployer is correct
Expand Down Expand Up @@ -86,164 +84,12 @@ where
}
}

/// The scope of operations that can be performed on shuttle
/// Every scope defaults to read and will use a suffix for updating tasks
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Scope {
/// Read the details, such as status and address, of a deployment
Deployment,

/// Push a new deployment
DeploymentPush,

/// Read the logs of a deployment
Logs,

/// Read the details of a service
Service,

/// Create a new service
ServiceCreate,

/// Read the status of a project
Project,

/// Create a new project
ProjectCreate,

/// Get the resources for a project
Resources,

/// Provision new resources for a project or update existing ones
ResourcesWrite,

/// List the secrets of a project
Secret,

/// Add or update secrets of a project
SecretWrite,

/// Get list of users
User,

/// Add or update users
UserCreate,

/// Create an ACME account
AcmeCreate,

/// Create a custom domain,
CustomDomainCreate,

/// Admin level scope to internals
Admin,
}

#[derive(Deserialize, Serialize)]
/// Response used internally to pass around JWT token
pub struct ConvertResponse {
pub token: String,
}

#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct Claim {
/// Expiration time (as UTC timestamp).
pub exp: usize,
/// Issued at (as UTC timestamp).
iat: usize,
/// Issuer.
iss: String,
/// Not Before (as UTC timestamp).
nbf: usize,
/// Subject (whom token refers to).
pub sub: String,
/// Scopes this token can access
pub scopes: Vec<Scope>,
/// The original token that was parsed
token: Option<String>,
}

impl Claim {
/// Create a new claim for a user with the given scopes
pub fn new(sub: String, scopes: Vec<Scope>) -> Self {
let iat = Utc::now();
let exp = iat.add(Duration::minutes(EXP_MINUTES));

Self {
exp: exp.timestamp() as usize,
iat: iat.timestamp() as usize,
iss: ISS.to_string(),
nbf: iat.timestamp() as usize,
sub,
scopes,
token: None,
}
}

pub fn into_token(self, encoding_key: &EncodingKey) -> Result<String, StatusCode> {
if let Some(token) = self.token {
Ok(token)
} else {
encode(
&JwtHeader::new(jsonwebtoken::Algorithm::EdDSA),
&self,
encoding_key,
)
.map_err(|err| {
error!(
error = &err as &dyn std::error::Error,
"failed to convert claim to token"
);
match err.kind() {
jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR,
jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
})
}
}

pub fn from_token(token: &str, public_key: &[u8]) -> Result<Self, StatusCode> {
let decoding_key = DecodingKey::from_ed_der(public_key);
let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA);
validation.set_issuer(&[ISS]);

trace!(token, "converting token to claim");
let mut claim: Self = decode(token, &decoding_key, &validation)
.map_err(|err| {
error!(
error = &err as &dyn std::error::Error,
"failed to convert token to claim"
);
match err.kind() {
jsonwebtoken::errors::ErrorKind::InvalidSignature
| jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName
| jsonwebtoken::errors::ErrorKind::ExpiredSignature
| jsonwebtoken::errors::ErrorKind::InvalidIssuer
| jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
StatusCode::UNAUTHORIZED
}
jsonwebtoken::errors::ErrorKind::InvalidToken
| jsonwebtoken::errors::ErrorKind::InvalidAlgorithm
| jsonwebtoken::errors::ErrorKind::Base64(_)
| jsonwebtoken::errors::ErrorKind::Json(_)
| jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST,
jsonwebtoken::errors::ErrorKind::MissingAlgorithm => {
StatusCode::INTERNAL_SERVER_ERROR
}
jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
})?
.claims;

claim.token = Some(token.to_string());

Ok(claim)
}
}

/// Trait to get a public key asynchronously
#[async_trait]
pub trait PublicKeyFn: Send + Sync + Clone {
Expand Down Expand Up @@ -439,53 +285,6 @@ where
}
}

/// This layer takes a claim on a request extension and uses it's internal token to set the Authorization Bearer
#[derive(Clone)]
pub struct ClaimLayer;

impl<S> Layer<S> for ClaimLayer {
type Service = ClaimService<S>;

fn layer(&self, inner: S) -> Self::Service {
ClaimService { inner }
}
}

#[derive(Clone)]
pub struct ClaimService<S> {
inner: S,
}

impl<S, RequestError> Service<Request<UnsyncBoxBody<Bytes, RequestError>>> for ClaimService<S>
where
S: Service<Request<UnsyncBoxBody<Bytes, RequestError>>> + Send + 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;

fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: Request<UnsyncBoxBody<Bytes, RequestError>>) -> Self::Future {
if let Some(claim) = req.extensions().get::<Claim>() {
if let Some(token) = claim.token.clone() {
req.headers_mut()
.typed_insert(Authorization::bearer(&token).expect("to set JWT token"));
}
}

let future = self.inner.call(req);

ResponseFuture(future)
}
}

/// Check that the required scopes are set on the [Claim] extension on a [Request]
#[derive(Clone)]
pub struct ScopedLayer {
Expand Down Expand Up @@ -568,7 +367,9 @@ mod tests {
use serde_json::json;
use tower::{ServiceBuilder, ServiceExt};

use super::{Claim, JwtAuthenticationLayer, Scope, ScopedLayer};
use crate::claims::{Claim, Scope};

use super::{JwtAuthenticationLayer, ScopedLayer};

#[test]
fn to_token_and_back() {
Expand Down
17 changes: 0 additions & 17 deletions common/src/backends/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,6 @@ use axum::response::Response;
use http::StatusCode;
use pin_project::pin_project;

// Future for layers that just return the inner response
#[pin_project]
pub struct ResponseFuture<F>(#[pin] pub F);

impl<F, Response, Error> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response, Error>>,
{
type Output = Result<Response, Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

this.0.poll(cx)
}
}

/// Future for layers that might return a different status code
#[pin_project(project = StatusCodeProj)]
pub enum StatusCodeFuture<F> {
Expand Down
Loading

0 comments on commit a0ea049

Please sign in to comment.