diff --git a/.circleci/config.yml b/.circleci/config.yml
index f8f86c2ad..4bbc03100 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -9,10 +9,12 @@ executors:
   docker-rust:
     docker:
       - image: cimg/rust:1.65.0
+    resource_class: small
   image-ubuntu:
     machine:
       image: ubuntu-2204:2022.04.1
       docker_layer_caching: true
+    resource_class: xlarge

 # sscache steps are from this guide
 # https://medium.com/@edouard.oger/rust-caching-on-circleci-using-sccache-c996344f0115
@@ -159,8 +161,9 @@ commands:
             - << parameters.target >>.env

 jobs:
-  workspace-fmt:
+  workspace:
     executor: docker-rust
+    resource_class: xlarge
     steps:
       - checkout
       - restore-cargo-cache
@@ -170,14 +173,6 @@ jobs:
       # https://github.com/DevinR528/cargo-sort/pull/29 is merged
       # - run: cargo install cargo-sort
       # - run: cargo sort --check --workspace
-      - run: cargo check --workspace --all-targets
-      - save-cargo-cache
-  workspace-clippy:
-    executor: docker-rust
-    steps:
-      - checkout
-      - restore-cargo-cache
-      - install-protoc
       - run: |
           cargo clippy --tests \
             --all-targets \
@@ -214,20 +209,6 @@ jobs:
             -A clippy::format-push-string
       - run: cargo test --all-features --manifest-path << parameters.path >>/Cargo.toml -- --nocapture
       - save-cargo-cache
-  service-test:
-    # Using an image since tests will start a docker container
-    executor: image-ubuntu
-    steps:
-      - install-rust
-      - checkout
-      - restore-cargo-cache
-      - run:
-          name: Run unit tests
-          command: cargo test --package shuttle-service --features="codegen,builder" --lib -- --nocapture
-      - run:
-          name: Run integration tests
-          command: cargo test --package shuttle-service --features="codegen,builder" --test '*' -- --nocapture
-      - save-cargo-cache
   platform-test:
     parameters:
       crate:
@@ -254,7 +235,6 @@ jobs:
             (cargo test --package << parameters.crate >> --all-features --test '*' -- --list 2>&1 | grep -q "no test target matches pattern") && echo "nothing to test" || cargo test --package << parameters.crate >> --all-features --test '*' -- --nocapture
       - save-cargo-cache
   e2e-test:
-    resource_class: xlarge
     executor: image-ubuntu
     steps:
       - install-rust
@@ -286,7 +266,6 @@ jobs:
             key: docker-buildx-{{ .Branch }}
           when: always
   build-and-push:
-    resource_class: xlarge
     executor: image-ubuntu
     steps:
       - checkout
@@ -360,7 +339,7 @@ jobs:
   build-binaries-mac:
     macos:
       xcode: 12.5.1
-    resource_class: medium
+    resource_class: xlarge
     steps:
       - checkout
      - run:
@@ -398,11 +377,9 @@
 workflows:
   ci:
     jobs:
-      - workspace-fmt
-      - workspace-clippy:
-          requires:
-            - workspace-fmt
+      - workspace
       - check-standalone:
+          name: << matrix.path >>
          matrix:
            parameters:
              path:
@@ -422,28 +399,26 @@ workflows:
              - services/shuttle-tide
              - services/shuttle-tower
              - services/shuttle-warp
-      - service-test:
-          requires:
-            - workspace-clippy
       - platform-test:
+          name: << matrix.crate >>
          requires:
-            - workspace-clippy
+            - workspace
          matrix:
            parameters:
              crate:
                [
                  "shuttle-auth",
-                  "shuttle-deployer",
                  "cargo-shuttle",
                  "shuttle-codegen",
                  "shuttle-common",
+                  "shuttle-deployer",
                  "shuttle-proto",
                  "shuttle-provisioner",
-                  "shuttle-runtime"
+                  "shuttle-runtime",
+                  "shuttle-service",
                ]
       - e2e-test:
          requires:
-            - service-test
            - platform-test
            - check-standalone
          filters:
@@ -459,7 +434,7 @@ workflows:
          name: build-binaries-x86_64-gnu
          image: ubuntu-2204:2022.04.1
          target: x86_64-unknown-linux-gnu
-          resource_class: medium
+          resource_class: xlarge
          filters:
            branches:
              only: production
@@ -467,7 +442,7 @@ workflows:
          name: build-binaries-x86_64-musl
          image: ubuntu-2204:2022.04.1
          target: x86_64-unknown-linux-musl
-          resource_class: medium
+          resource_class: xlarge
          filters:
            branches:
              only: production
@@ -475,7 +450,7 @@ workflows:
          name: build-binaries-aarch64
          image: ubuntu-2004:202101-01
          target: aarch64-unknown-linux-musl
-          resource_class: arm.medium
+          resource_class: arm.xlarge
          filters:
            branches:
              only: production
diff --git a/Cargo.lock b/Cargo.lock
index ec6ba3832..0ace53bce 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4840,7 +4840,6 @@ version = "0.11.0"
 dependencies = [
  "anyhow",
  "async-trait",
- "axum",
  "cargo",
  "cargo_metadata",
  "crossbeam-channel",
diff --git a/Cargo.toml b/Cargo.toml
index ad2d9a375..86627999b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -41,7 +41,7 @@ shuttle-service = { path = "service", version = "0.11.0" }
 anyhow = "1.0.66"
 async-trait = "0.1.58"
 axum = { version = "0.6.0", default-features = false }
-chrono = { version = "0.4.23", default-features = false, features = ["clock"] }
+chrono = { version = "0.4.23", default-features = false }
 clap = { version = "4.0.27", features = [ "derive" ] }
 headers = "0.3.8"
 http = "0.2.8"
diff --git a/auth/src/api/handlers.rs b/auth/src/api/handlers.rs
index 2d08221ce..b45061bf6 100644
--- a/auth/src/api/handlers.rs
+++ b/auth/src/api/handlers.rs
@@ -9,7 +9,7 @@ use axum::{
 use axum_sessions::extractors::{ReadableSession, WritableSession};
 use http::StatusCode;
 use serde::{Deserialize, Serialize};
-use shuttle_common::{backends::auth::Claim, models::user};
+use shuttle_common::{claims::Claim, models::user};
 use tracing::instrument;

 use super::{
diff --git a/auth/src/user.rs b/auth/src/user.rs
index 264adfb60..8d9b5a0c9 100644
--- a/auth/src/user.rs
+++ b/auth/src/user.rs
@@ -9,7 +9,7 @@ use axum::{
 };
 use rand::distributions::{Alphanumeric, DistString};
 use serde::{Deserialize, Deserializer, Serialize};
-use shuttle_common::backends::auth::Scope;
+use shuttle_common::claims::Scope;
 use sqlx::{query, Row, SqlitePool};
 use tracing::{trace, Span};
diff --git a/auth/tests/api/session.rs b/auth/tests/api/session.rs
index aa07e8bdd..b768a0096 100644
--- a/auth/tests/api/session.rs
+++ b/auth/tests/api/session.rs
@@ -2,7 +2,7 @@
 use axum_extra::extract::cookie::{self, Cookie};
 use http::{Request, StatusCode};
 use hyper::Body;
 use serde_json::{json, Value};
-use shuttle_common::backends::auth::Claim;
+use shuttle_common::claims::Claim;

 use crate::helpers::app;
diff --git a/cargo-shuttle/src/lib.rs b/cargo-shuttle/src/lib.rs
index 730d11360..2194391db 100644
--- a/cargo-shuttle/src/lib.rs
+++ b/cargo-shuttle/src/lib.rs
@@ -1049,7 +1049,7 @@ mod tests {
                 "Secrets.toml",
                 "Secrets.toml.example",
                 "Shuttle.toml",
-                "src/lib.rs",
+                "src/main.rs",
             ]
         );
     }
diff --git a/cargo-shuttle/tests/integration/run.rs b/cargo-shuttle/tests/integration/run.rs
index e11dbe3dd..cc2244520 100644
--- a/cargo-shuttle/tests/integration/run.rs
+++ b/cargo-shuttle/tests/integration/run.rs
@@ -58,6 +58,7 @@ async fn cargo_shuttle_run(working_directory: &str, external: bool) -> String {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn rocket_hello_world() {
     let url = cargo_shuttle_run("../examples/rocket/hello-world", false).await;

@@ -120,6 +121,7 @@ async fn rocket_postgres() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn rocket_authentication() {
     let url = cargo_shuttle_run("../examples/rocket/authentication", false).await;
     let client = reqwest::Client::new();
@@ -176,6 +178,7 @@ async fn rocket_authentication() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn actix_web_hello_world() {
     let url = cargo_shuttle_run("../examples/actix-web/hello-world", false).await;

@@ -192,6 +195,7 @@ async fn actix_web_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn axum_hello_world() {
     let url = cargo_shuttle_run("../examples/axum/hello-world", false).await;

@@ -208,6 +212,7 @@ async fn axum_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn tide_hello_world() {
     let url = cargo_shuttle_run("../examples/tide/hello-world", false).await;

@@ -224,6 +229,7 @@ async fn tide_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn tower_hello_world() {
     let url = cargo_shuttle_run("../examples/tower/hello-world", false).await;

@@ -240,6 +246,7 @@ async fn tower_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn warp_hello_world() {
     let url = cargo_shuttle_run("../examples/warp/hello-world", false).await;

@@ -256,6 +263,7 @@ async fn warp_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn poem_hello_world() {
     let url = cargo_shuttle_run("../examples/poem/hello-world", false).await;

@@ -273,6 +281,7 @@ async fn poem_hello_world() {

 // This example uses a shared Postgres. Thus local runs should create a docker container for it.
 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn poem_postgres() {
     let url = cargo_shuttle_run("../examples/poem/postgres", false).await;
     let client = reqwest::Client::new();
@@ -336,6 +345,7 @@ async fn poem_mongodb() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn salvo_hello_world() {
     let url = cargo_shuttle_run("../examples/salvo/hello-world", false).await;

@@ -352,6 +362,7 @@ async fn salvo_hello_world() {
 }

 #[tokio::test(flavor = "multi_thread")]
+#[ignore]
 async fn thruster_hello_world() {
     let url = cargo_shuttle_run("../examples/thruster/hello-world", false).await;
diff --git a/codegen/tests/ui/main/missing-return.stderr b/codegen/tests/ui/main/missing-return.stderr
index e84ed9923..7ec24dffd 100644
--- a/codegen/tests/ui/main/missing-return.stderr
+++ b/codegen/tests/ui/main/missing-return.stderr
@@ -1,7 +1,7 @@
-error: shuttle_service::main functions need to return a service
+error: shuttle_runtime::main functions need to return a service

   = help: See the docs for services with first class support
-  = note: https://docs.rs/shuttle-service/latest/shuttle_service/attr.main.html#shuttle-supported-services
+  = note: https://docs.rs/shuttle-service/latest/shuttle_runtime/attr.main.html#shuttle-supported-services

 --> tests/ui/main/missing-return.rs:2:1
  |
diff --git a/codegen/tests/ui/main/return-tuple.stderr b/codegen/tests/ui/main/return-tuple.stderr
index b9fed7820..94da0bcb5 100644
--- a/codegen/tests/ui/main/return-tuple.stderr
+++ b/codegen/tests/ui/main/return-tuple.stderr
@@ -1,7 +1,7 @@
-error: shuttle_service::main functions need to return a first class service or 'Result
+error: shuttle_runtime::main functions need to return a first class service or 'Result

   = help: See the docs for services with first class support
-  = note: https://docs.rs/shuttle-service/latest/shuttle_service/attr.main.html#shuttle-supported-services
+  = note: https://docs.rs/shuttle-service/latest/shuttle_runtime/attr.main.html#shuttle-supported-services

 --> tests/ui/main/return-tuple.rs:2:28
  |
diff --git a/common/Cargo.toml b/common/Cargo.toml
index 05b0494b5..49c4f736f 100644
--- a/common/Cargo.toml
+++ b/common/Cargo.toml
@@ -12,10 +12,10 @@ anyhow = { workspace = true, optional = true }
 async-trait = { workspace = true , optional = true }
 axum = { workspace = true, optional = true }
 bytes = { version = "1.3.0", optional = true }
-chrono = { workspace = true, features = ["serde"] } +chrono = { workspace = true } comfy-table = { version = "6.1.3", optional = true } crossterm = { version = "0.25.0", optional = true } -headers = { workspace = true } +headers = { workspace = true, optional = true } http = { workspace = true, optional = true } http-body = { version = "0.4.5", optional = true } http-serde = { version = "1.1.2", optional = true } @@ -25,31 +25,33 @@ once_cell = { workspace = true, optional = true } opentelemetry = { workspace = true, optional = true } opentelemetry-http = { workspace = true, optional = true } opentelemetry-otlp = { version = "0.11.0", optional = true } -pin-project = { workspace = true } +pin-project = { workspace = true, optional = true } prost-types = { workspace = true, optional = true } reqwest = { version = "0.11.13", optional = true } rmp-serde = { version = "1.1.1", optional = true } rustrict = { version = "0.5.5", optional = true } -serde = { workspace = true } +serde = { workspace = true, features = ["derive", "std"] } serde_json = { workspace = true, optional = true } strum = { workspace = true, features = ["derive"], optional = true } thiserror = { workspace = true, optional = true } tonic = { version = "0.8.3", optional = true } tower = { workspace = true, optional = true } tower-http = { workspace = true, optional = true } -tracing = { workspace = true } +tracing = { workspace = true, features = ["std"] } tracing-opentelemetry = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } ttl_cache = { workspace = true, optional = true } uuid = { workspace = true, features = ["v4", "serde"], optional = true } [features] -backend = ["async-trait", "axum/matched-path", "bytes", "http", "http-body", "hyper/client", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", "thiserror", "tower", "tower-http", "tracing-opentelemetry", "tracing-subscriber/env-filter", "tracing-subscriber/fmt", "ttl_cache"] -display = ["comfy-table", "crossterm"] -models = ["anyhow", "async-trait", "display", "http", "prost-types", "reqwest", "serde_json", "service", "thiserror"] +backend = ["async-trait", "axum/matched-path", "claims", "hyper/client", "opentelemetry-otlp", "thiserror", "tower-http", "tracing-subscriber/env-filter", "tracing-subscriber/fmt", "ttl_cache"] +claims = ["bytes", "chrono/clock", "headers", "http", "http-body", "jsonwebtoken", "opentelemetry", "opentelemetry-http", "pin-project", "tower", "tracing", "tracing-opentelemetry"] +display = ["chrono/clock", "comfy-table", "crossterm"] +error = ["prost-types", "serde_json", "thiserror", "uuid"] +models = ["anyhow", "async-trait", "display", "http", "reqwest", "serde_json", "service"] service = ["chrono/serde", "once_cell", "rustrict", "serde/derive", "strum", "uuid"] tracing = ["serde_json"] -wasm = ["http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"] +wasm = ["chrono/clock", "http-serde", "http", "rmp-serde", "tracing", "tracing-subscriber"] [dev-dependencies] axum = { workspace = true } diff --git a/common/src/backends/auth.rs b/common/src/backends/auth.rs index dc293c24b..f87bd3fed 100644 --- a/common/src/backends/auth.rs +++ b/common/src/backends/auth.rs @@ -1,13 +1,11 @@ -use std::{convert::Infallible, future::Future, ops::Add, pin::Pin, sync::Arc}; +use std::{convert::Infallible, future::Future, pin::Pin, sync::Arc}; use async_trait::async_trait; use bytes::Bytes; -use chrono::{Duration, Utc}; use headers::{authorization::Bearer, Authorization, 
HeaderMapExt}; use http::{Request, Response, StatusCode, Uri}; use http_body::combinators::UnsyncBoxBody; use hyper::{body, Body, Client}; -use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header as JwtHeader, Validation}; use opentelemetry::global; use opentelemetry_http::HeaderInjector; use serde::{Deserialize, Serialize}; @@ -16,14 +14,14 @@ use tower::{Layer, Service}; use tracing::{error, trace, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; +use crate::claims::{Claim, Scope}; + use super::{ cache::{CacheManagement, CacheManager}, - future::{ResponseFuture, StatusCodeFuture}, + future::StatusCodeFuture, headers::XShuttleAdminSecret, }; -pub const EXP_MINUTES: i64 = 5; -const ISS: &str = "shuttle"; const PUBLIC_KEY_CACHE_KEY: &str = "shuttle.public-key"; /// Layer to check the admin secret set by deployer is correct @@ -86,164 +84,12 @@ where } } -/// The scope of operations that can be performed on shuttle -/// Every scope defaults to read and will use a suffix for updating tasks -#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum Scope { - /// Read the details, such as status and address, of a deployment - Deployment, - - /// Push a new deployment - DeploymentPush, - - /// Read the logs of a deployment - Logs, - - /// Read the details of a service - Service, - - /// Create a new service - ServiceCreate, - - /// Read the status of a project - Project, - - /// Create a new project - ProjectCreate, - - /// Get the resources for a project - Resources, - - /// Provision new resources for a project or update existing ones - ResourcesWrite, - - /// List the secrets of a project - Secret, - - /// Add or update secrets of a project - SecretWrite, - - /// Get list of users - User, - - /// Add or update users - UserCreate, - - /// Create an ACME account - AcmeCreate, - - /// Create a custom domain, - CustomDomainCreate, - - /// Admin level scope to internals - Admin, -} - #[derive(Deserialize, Serialize)] /// Response used internally to pass around JWT token pub struct ConvertResponse { pub token: String, } -#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] -pub struct Claim { - /// Expiration time (as UTC timestamp). - pub exp: usize, - /// Issued at (as UTC timestamp). - iat: usize, - /// Issuer. - iss: String, - /// Not Before (as UTC timestamp). - nbf: usize, - /// Subject (whom token refers to). 
- pub sub: String, - /// Scopes this token can access - pub scopes: Vec, - /// The original token that was parsed - token: Option, -} - -impl Claim { - /// Create a new claim for a user with the given scopes - pub fn new(sub: String, scopes: Vec) -> Self { - let iat = Utc::now(); - let exp = iat.add(Duration::minutes(EXP_MINUTES)); - - Self { - exp: exp.timestamp() as usize, - iat: iat.timestamp() as usize, - iss: ISS.to_string(), - nbf: iat.timestamp() as usize, - sub, - scopes, - token: None, - } - } - - pub fn into_token(self, encoding_key: &EncodingKey) -> Result { - if let Some(token) = self.token { - Ok(token) - } else { - encode( - &JwtHeader::new(jsonwebtoken::Algorithm::EdDSA), - &self, - encoding_key, - ) - .map_err(|err| { - error!( - error = &err as &dyn std::error::Error, - "failed to convert claim to token" - ); - match err.kind() { - jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR, - jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - }) - } - } - - pub fn from_token(token: &str, public_key: &[u8]) -> Result { - let decoding_key = DecodingKey::from_ed_der(public_key); - let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA); - validation.set_issuer(&[ISS]); - - trace!(token, "converting token to claim"); - let mut claim: Self = decode(token, &decoding_key, &validation) - .map_err(|err| { - error!( - error = &err as &dyn std::error::Error, - "failed to convert token to claim" - ); - match err.kind() { - jsonwebtoken::errors::ErrorKind::InvalidSignature - | jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName - | jsonwebtoken::errors::ErrorKind::ExpiredSignature - | jsonwebtoken::errors::ErrorKind::InvalidIssuer - | jsonwebtoken::errors::ErrorKind::ImmatureSignature => { - StatusCode::UNAUTHORIZED - } - jsonwebtoken::errors::ErrorKind::InvalidToken - | jsonwebtoken::errors::ErrorKind::InvalidAlgorithm - | jsonwebtoken::errors::ErrorKind::Base64(_) - | jsonwebtoken::errors::ErrorKind::Json(_) - | jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST, - jsonwebtoken::errors::ErrorKind::MissingAlgorithm => { - StatusCode::INTERNAL_SERVER_ERROR - } - jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - })? 
- .claims; - - claim.token = Some(token.to_string()); - - Ok(claim) - } -} - /// Trait to get a public key asynchronously #[async_trait] pub trait PublicKeyFn: Send + Sync + Clone { @@ -439,53 +285,6 @@ where } } -/// This layer takes a claim on a request extension and uses it's internal token to set the Authorization Bearer -#[derive(Clone)] -pub struct ClaimLayer; - -impl Layer for ClaimLayer { - type Service = ClaimService; - - fn layer(&self, inner: S) -> Self::Service { - ClaimService { inner } - } -} - -#[derive(Clone)] -pub struct ClaimService { - inner: S, -} - -impl Service>> for ClaimService -where - S: Service>> + Send + 'static, - S::Future: Send + 'static, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request>) -> Self::Future { - if let Some(claim) = req.extensions().get::() { - if let Some(token) = claim.token.clone() { - req.headers_mut() - .typed_insert(Authorization::bearer(&token).expect("to set JWT token")); - } - } - - let future = self.inner.call(req); - - ResponseFuture(future) - } -} - /// Check that the required scopes are set on the [Claim] extension on a [Request] #[derive(Clone)] pub struct ScopedLayer { @@ -568,7 +367,9 @@ mod tests { use serde_json::json; use tower::{ServiceBuilder, ServiceExt}; - use super::{Claim, JwtAuthenticationLayer, Scope, ScopedLayer}; + use crate::claims::{Claim, Scope}; + + use super::{JwtAuthenticationLayer, ScopedLayer}; #[test] fn to_token_and_back() { diff --git a/common/src/backends/future.rs b/common/src/backends/future.rs index 5603fdaa0..e50bd41d4 100644 --- a/common/src/backends/future.rs +++ b/common/src/backends/future.rs @@ -8,23 +8,6 @@ use axum::response::Response; use http::StatusCode; use pin_project::pin_project; -// Future for layers that just return the inner response -#[pin_project] -pub struct ResponseFuture(#[pin] pub F); - -impl Future for ResponseFuture -where - F: Future>, -{ - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - - this.0.poll(cx) - } -} - /// Future for layers that might return a different status code #[pin_project(project = StatusCodeProj)] pub enum StatusCodeFuture { diff --git a/common/src/backends/metrics.rs b/common/src/backends/metrics.rs index 9d7fc2b77..0ecc7ebec 100644 --- a/common/src/backends/metrics.rs +++ b/common/src/backends/metrics.rs @@ -54,6 +54,9 @@ impl TraceLayer { /// /// # Example /// ``` + /// use shuttle_common::{request_span, backends::metrics::TraceLayer}; + /// use tracing::field; + /// /// TraceLayer::new(|request| { /// request_span!( /// request, @@ -206,7 +209,7 @@ mod tests { use tower::ServiceExt; use tracing::field; use tracing_fluent_assertions::{AssertionRegistry, AssertionsLayer}; - use tracing_subscriber::{layer::SubscriberExt, Registry}; + use tracing_subscriber::layer::SubscriberExt; use super::{Metrics, TraceLayer}; @@ -221,9 +224,9 @@ mod tests { #[tokio::test] async fn trace_layer() { let assertion_registry = AssertionRegistry::default(); - let base_subscriber = Registry::default(); - let subscriber = base_subscriber.with(AssertionsLayer::new(&assertion_registry)); - tracing::subscriber::set_global_default(subscriber).unwrap(); + let subscriber = + tracing_subscriber::registry().with(AssertionsLayer::new(&assertion_registry)); + let _guard = 
tracing::subscriber::set_default(subscriber); // Put in own block to make sure assertion to not interfere with the next test { diff --git a/common/src/backends/tracing.rs b/common/src/backends/tracing.rs index 64a95cb8f..a3fa04eb7 100644 --- a/common/src/backends/tracing.rs +++ b/common/src/backends/tracing.rs @@ -11,7 +11,7 @@ use opentelemetry::{ sdk::{propagation::TraceContextPropagator, trace, Resource}, KeyValue, }; -use opentelemetry_http::{HeaderExtractor, HeaderInjector}; +use opentelemetry_http::HeaderExtractor; use opentelemetry_otlp::WithExportConfig; use pin_project::pin_project; use tower::{Layer, Service}; @@ -19,8 +19,6 @@ use tracing::{debug_span, instrument::Instrumented, Instrument, Span, Subscriber use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::{fmt, prelude::*, registry::LookupSpan, EnvFilter}; -use super::future::ResponseFuture; - pub fn setup_tracing(subscriber: S, service_name: &str) where S: Subscriber + for<'a> LookupSpan<'a> + Send + Sync, @@ -139,49 +137,3 @@ where ExtractPropagationFuture { response_future } } } - -/// This layer adds the current tracing span to any outgoing request -#[derive(Clone)] -pub struct InjectPropagationLayer; - -impl Layer for InjectPropagationLayer { - type Service = InjectPropagation; - - fn layer(&self, inner: S) -> Self::Service { - InjectPropagation { inner } - } -} - -#[derive(Clone)] -pub struct InjectPropagation { - inner: S, -} - -impl Service> for InjectPropagation -where - S: Service> + Send + 'static, - S::Future: Send + 'static, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; - - fn poll_ready( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - let cx = Span::current().context(); - - global::get_text_map_propagator(|propagator| { - propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) - }); - - let future = self.inner.call(req); - - ResponseFuture(future) - } -} diff --git a/common/src/claims.rs b/common/src/claims.rs new file mode 100644 index 000000000..c9e4d35df --- /dev/null +++ b/common/src/claims.rs @@ -0,0 +1,285 @@ +use std::{ + future::Future, + ops::Add, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use chrono::{Duration, Utc}; +use headers::{Authorization, HeaderMapExt}; +use http::{Request, StatusCode}; +use http_body::combinators::UnsyncBoxBody; +use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; +use opentelemetry::global; +use opentelemetry_http::HeaderInjector; +use pin_project::pin_project; +use serde::{Deserialize, Serialize}; +use tower::{Layer, Service}; +use tracing::{error, trace, Span}; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +pub const EXP_MINUTES: i64 = 5; +const ISS: &str = "shuttle"; + +/// The scope of operations that can be performed on shuttle +/// Every scope defaults to read and will use a suffix for updating tasks +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum Scope { + /// Read the details, such as status and address, of a deployment + Deployment, + + /// Push a new deployment + DeploymentPush, + + /// Read the logs of a deployment + Logs, + + /// Read the details of a service + Service, + + /// Create a new service + ServiceCreate, + + /// Read the status of a project + Project, + + /// Create a new project + ProjectCreate, + + /// Get the resources for a project + 
Resources, + + /// Provision new resources for a project or update existing ones + ResourcesWrite, + + /// List the secrets of a project + Secret, + + /// Add or update secrets of a project + SecretWrite, + + /// Get list of users + User, + + /// Add or update users + UserCreate, + + /// Create an ACME account + AcmeCreate, + + /// Create a custom domain, + CustomDomainCreate, + + /// Admin level scope to internals + Admin, +} + +#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] +pub struct Claim { + /// Expiration time (as UTC timestamp). + pub exp: usize, + /// Issued at (as UTC timestamp). + iat: usize, + /// Issuer. + iss: String, + /// Not Before (as UTC timestamp). + nbf: usize, + /// Subject (whom token refers to). + pub sub: String, + /// Scopes this token can access + pub scopes: Vec, + /// The original token that was parsed + pub(crate) token: Option, +} + +impl Claim { + /// Create a new claim for a user with the given scopes + pub fn new(sub: String, scopes: Vec) -> Self { + let iat = Utc::now(); + let exp = iat.add(Duration::minutes(EXP_MINUTES)); + + Self { + exp: exp.timestamp() as usize, + iat: iat.timestamp() as usize, + iss: ISS.to_string(), + nbf: iat.timestamp() as usize, + sub, + scopes, + token: None, + } + } + + pub fn into_token(self, encoding_key: &EncodingKey) -> Result { + if let Some(token) = self.token { + Ok(token) + } else { + encode( + &Header::new(jsonwebtoken::Algorithm::EdDSA), + &self, + encoding_key, + ) + .map_err(|err| { + error!( + error = &err as &dyn std::error::Error, + "failed to convert claim to token" + ); + match err.kind() { + jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR, + jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + }) + } + } + + pub fn from_token(token: &str, public_key: &[u8]) -> Result { + let decoding_key = DecodingKey::from_ed_der(public_key); + let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA); + validation.set_issuer(&[ISS]); + + trace!(token, "converting token to claim"); + let mut claim: Self = decode(token, &decoding_key, &validation) + .map_err(|err| { + error!( + error = &err as &dyn std::error::Error, + "failed to convert token to claim" + ); + match err.kind() { + jsonwebtoken::errors::ErrorKind::InvalidSignature + | jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName + | jsonwebtoken::errors::ErrorKind::ExpiredSignature + | jsonwebtoken::errors::ErrorKind::InvalidIssuer + | jsonwebtoken::errors::ErrorKind::ImmatureSignature => { + StatusCode::UNAUTHORIZED + } + jsonwebtoken::errors::ErrorKind::InvalidToken + | jsonwebtoken::errors::ErrorKind::InvalidAlgorithm + | jsonwebtoken::errors::ErrorKind::Base64(_) + | jsonwebtoken::errors::ErrorKind::Json(_) + | jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST, + jsonwebtoken::errors::ErrorKind::MissingAlgorithm => { + StatusCode::INTERNAL_SERVER_ERROR + } + jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + })? 
+ .claims; + + claim.token = Some(token.to_string()); + + Ok(claim) + } +} + +// Future for layers that just return the inner response +#[pin_project] +pub struct ResponseFuture(#[pin] pub F); + +impl Future for ResponseFuture +where + F: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + + this.0.poll(cx) + } +} + +/// This layer takes a claim on a request extension and uses it's internal token to set the Authorization Bearer +#[derive(Clone)] +pub struct ClaimLayer; + +impl Layer for ClaimLayer { + type Service = ClaimService; + + fn layer(&self, inner: S) -> Self::Service { + ClaimService { inner } + } +} + +#[derive(Clone)] +pub struct ClaimService { + inner: S, +} + +impl Service>> for ClaimService +where + S: Service>> + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request>) -> Self::Future { + if let Some(claim) = req.extensions().get::() { + if let Some(token) = claim.token.clone() { + req.headers_mut() + .typed_insert(Authorization::bearer(&token).expect("to set JWT token")); + } + } + + let future = self.inner.call(req); + + ResponseFuture(future) + } +} + +/// This layer adds the current tracing span to any outgoing request +#[derive(Clone)] +pub struct InjectPropagationLayer; + +impl Layer for InjectPropagationLayer { + type Service = InjectPropagation; + + fn layer(&self, inner: S) -> Self::Service { + InjectPropagation { inner } + } +} + +#[derive(Clone)] +pub struct InjectPropagation { + inner: S, +} + +impl Service> for InjectPropagation +where + S: Service> + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let cx = Span::current().context(); + + global::get_text_map_propagator(|propagator| { + propagator.inject_context(&cx, &mut HeaderInjector(req.headers_mut())) + }); + + let future = self.inner.call(req); + + ResponseFuture(future) + } +} diff --git a/common/src/lib.rs b/common/src/lib.rs index 11d3741c1..a68067084 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -1,5 +1,7 @@ #[cfg(feature = "backend")] pub mod backends; +#[cfg(feature = "claims")] +pub mod claims; #[cfg(feature = "service")] pub mod database; #[cfg(feature = "service")] @@ -38,6 +40,18 @@ pub type Host = String; #[cfg(feature = "service")] pub type DeploymentId = Uuid; +#[cfg(feature = "error")] +/// Errors that can occur when changing types. 
Especially from prost +#[derive(thiserror::Error, Debug)] +pub enum ParseError { + #[error("failed to parse UUID: {0}")] + Uuid(#[from] uuid::Error), + #[error("failed to parse timestamp: {0}")] + Timestamp(#[from] prost_types::TimestampError), + #[error("failed to parse serde: {0}")] + Serde(#[from] serde_json::Error), +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DatabaseReadyInfo { engine: String, diff --git a/common/src/models/mod.rs b/common/src/models/mod.rs index 2d8042501..328a33cf8 100644 --- a/common/src/models/mod.rs +++ b/common/src/models/mod.rs @@ -11,7 +11,6 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use http::StatusCode; use serde::de::DeserializeOwned; -use thiserror::Error; use tracing::trace; /// A to_json wrapper for handling our error states @@ -50,14 +49,3 @@ impl ToJson for reqwest::Response { } } } - -/// Errors that can occur when changing types. Especially from prost -#[derive(Error, Debug)] -pub enum ParseError { - #[error("failed to parse UUID: {0}")] - Uuid(#[from] uuid::Error), - #[error("failed to parse timestamp: {0}")] - Timestamp(#[from] prost_types::TimestampError), - #[error("failed to parse serde: {0}")] - Serde(#[from] serde_json::Error), -} diff --git a/common/src/wasm.rs b/common/src/wasm.rs index edaa8808b..1508644dd 100644 --- a/common/src/wasm.rs +++ b/common/src/wasm.rs @@ -366,7 +366,7 @@ where } #[cfg(test)] -mod test { +mod tests { use cap_std::os::unix::net::UnixStream; use serde_json::json; use std::io::{Read, Write}; @@ -499,7 +499,7 @@ mod test { (message, log.level) }; - tracing_subscriber::registry().with(logger).init(); + let _guard = tracing_subscriber::registry().with(logger).set_default(); tracing::debug!("this is"); tracing::info!("hi"); diff --git a/deployer/src/deployment/deploy_layer.rs b/deployer/src/deployment/deploy_layer.rs index ad0877360..f756a56de 100644 --- a/deployer/src/deployment/deploy_layer.rs +++ b/deployer/src/deployment/deploy_layer.rs @@ -21,7 +21,7 @@ use chrono::{DateTime, Utc}; use serde_json::json; -use shuttle_common::{models::ParseError, tracing::JsonVisitor, STATE_MESSAGE}; +use shuttle_common::{tracing::JsonVisitor, ParseError, STATE_MESSAGE}; use shuttle_proto::runtime; use std::{convert::TryFrom, str::FromStr, time::SystemTime}; use tracing::{field::Visit, span, warn, Metadata, Subscriber}; @@ -618,7 +618,7 @@ mod tests { ); select! { - _ = sleep(Duration::from_secs(240)) => { + _ = sleep(Duration::from_secs(360)) => { let states = RECORDER.lock().unwrap().get_deployment_states(&id); panic!("states should go into 'Running' for a valid service: {:#?}", states); }, @@ -702,7 +702,7 @@ mod tests { ); select! { - _ = sleep(Duration::from_secs(240)) => { + _ = sleep(Duration::from_secs(360)) => { let states = RECORDER.lock().unwrap().get_deployment_states(&id); panic!("states should go into 'Completed' when a service stops by itself: {:#?}", states); } @@ -749,7 +749,7 @@ mod tests { ); select! { - _ = sleep(Duration::from_secs(240)) => { + _ = sleep(Duration::from_secs(360)) => { let states = RECORDER.lock().unwrap().get_deployment_states(&id); panic!("states should go into 'Crashed' panicking in bind: {:#?}", states); } @@ -784,10 +784,6 @@ mod tests { id, state: State::Loading, }, - StateLog { - id, - state: State::Running, - }, StateLog { id, state: State::Crashed, @@ -796,7 +792,7 @@ mod tests { ); select! 
{ - _ = sleep(Duration::from_secs(240)) => { + _ = sleep(Duration::from_secs(360)) => { let states = RECORDER.lock().unwrap().get_deployment_states(&id); panic!("states should go into 'Crashed' when panicking in main: {:#?}", states); } diff --git a/deployer/src/deployment/queue.rs b/deployer/src/deployment/queue.rs index 1c7263f3d..c48a20306 100644 --- a/deployer/src/deployment/queue.rs +++ b/deployer/src/deployment/queue.rs @@ -11,7 +11,7 @@ use chrono::Utc; use crossbeam_channel::Sender; use opentelemetry::global; use serde_json::json; -use shuttle_common::backends::auth::Claim; +use shuttle_common::claims::Claim; use shuttle_service::builder::{build_crate, get_config, Runtime}; use tokio::time::{sleep, timeout}; use tracing::{debug, debug_span, error, info, instrument, trace, warn, Instrument, Span}; diff --git a/deployer/src/deployment/run.rs b/deployer/src/deployment/run.rs index 4df39521c..5f520b06f 100644 --- a/deployer/src/deployment/run.rs +++ b/deployer/src/deployment/run.rs @@ -8,7 +8,7 @@ use std::{ use async_trait::async_trait; use opentelemetry::global; use portpicker::pick_unused_port; -use shuttle_common::{backends::auth::Claim, storage_manager::ArtifactsStorageManager}; +use shuttle_common::{claims::Claim, storage_manager::ArtifactsStorageManager}; use shuttle_proto::runtime::{ runtime_client::RuntimeClient, LoadRequest, StartRequest, StopReason, SubscribeStopRequest, @@ -276,8 +276,15 @@ async fn load( match response { Ok(response) => { - info!(response = ?response.into_inner(), "loading response: "); - Ok(()) + let response = response.into_inner(); + info!(?response, "loading response"); + + if response.success { + Ok(()) + } else { + error!(error = %response.message, "failed to load service"); + Err(Error::Load(response.message)) + } } Err(error) => { error!(%error, "failed to load service"); @@ -370,7 +377,6 @@ mod tests { use uuid::Uuid; use crate::{ - error::Error, persistence::{DeploymentUpdater, Secret, SecretGetter}, RuntimeManager, }; @@ -424,7 +430,13 @@ mod tests { let tmp_dir = Builder::new().prefix("shuttle_run_test").tempdir().unwrap(); let path = tmp_dir.into_path(); - let (tx, _rx) = crossbeam_channel::unbounded(); + let (tx, rx) = crossbeam_channel::unbounded(); + + tokio::runtime::Handle::current().spawn_blocking(move || { + while let Ok(log) = rx.recv() { + println!("test log: {log:?}"); + } + }); RuntimeManager::new(path, format!("http://{}", provisioner_addr), tx) } @@ -464,7 +476,7 @@ mod tests { // This test uses the kill signal to make sure a service does stop when asked to #[tokio::test] async fn can_be_killed() { - let (built, storage_manager) = make_so_and_built("sleep-async"); + let (built, storage_manager) = make_and_built("sleep-async"); let id = built.id; let runtime_manager = get_runtime_manager(); let (cleanup_send, cleanup_recv) = oneshot::channel(); @@ -506,7 +518,7 @@ mod tests { // This test does not use a kill signal to stop the service. 
Rather the service decided to stop on its own without errors #[tokio::test] async fn self_stop() { - let (built, storage_manager) = make_so_and_built("sleep-async"); + let (built, storage_manager) = make_and_built("sleep-async"); let runtime_manager = get_runtime_manager(); let (cleanup_send, cleanup_recv) = oneshot::channel(); @@ -544,7 +556,7 @@ mod tests { // Test for panics in Service::bind #[tokio::test] async fn panic_in_bind() { - let (built, storage_manager) = make_so_and_built("bind-panic"); + let (built, storage_manager) = make_and_built("bind-panic"); let runtime_manager = get_runtime_manager(); let (cleanup_send, cleanup_recv) = oneshot::channel(); @@ -552,7 +564,7 @@ mod tests { StopReason::from_i32(response.reason).unwrap(), response.message, ) { - (StopReason::Crash, mes) if mes.contains("Panic occurred in `Service::bind`") => { + (StopReason::Crash, mes) if mes.contains("panic in bind") => { cleanup_send.send(()).unwrap() } (_, mes) => panic!("expected stop due to crash: {mes}"), @@ -583,20 +595,12 @@ mod tests { // Test for panics in the main function #[tokio::test] + #[should_panic(expected = "Load(\"main panic\")")] async fn panic_in_main() { - let (built, storage_manager) = make_so_and_built("main-panic"); + let (built, storage_manager) = make_and_built("main-panic"); let runtime_manager = get_runtime_manager(); - let (cleanup_send, cleanup_recv) = oneshot::channel(); - let handle_cleanup = |response: SubscribeStopResponse| match ( - StopReason::from_i32(response.reason).unwrap(), - response.message, - ) { - (StopReason::Crash, mes) if mes.contains("Panic occurred in shuttle_service::main") => { - cleanup_send.send(()).unwrap() - } - (_, mes) => panic!("expected stop due to crash: {mes}"), - }; + let handle_cleanup = |_result| panic!("service should never be started"); let secret_getter = get_secret_getter(); @@ -611,50 +615,9 @@ mod tests { ) .await .unwrap(); - - tokio::select! 
{ - _ = sleep(Duration::from_secs(5)) => panic!("cleanup should have been called"), - Ok(()) = cleanup_recv => {} - } - - // Prevent the runtime manager from dropping earlier, which will kill the processes it manages - drop(runtime_manager); } - #[tokio::test] - async fn missing_so() { - let built = Built { - id: Uuid::new_v4(), - service_name: "test".to_string(), - service_id: Uuid::new_v4(), - tracing_context: Default::default(), - is_next: false, - claim: None, - }; - - let handle_cleanup = |_result| panic!("no service means no cleanup"); - let secret_getter = get_secret_getter(); - let storage_manager = get_storage_manager(); - - let result = built - .handle( - storage_manager, - secret_getter, - get_runtime_manager(), - StubDeploymentUpdater, - kill_old_deployments(), - handle_cleanup, - ) - .await; - - assert!( - matches!(result, Err(Error::Load(_))), - "expected missing 'so' error: {:?}", - result - ); - } - - fn make_so_and_built(crate_name: &str) -> (Built, ArtifactsStorageManager) { + fn make_and_built(crate_name: &str) -> (Built, ArtifactsStorageManager) { let crate_dir: PathBuf = [RESOURCES_PATH, crate_name].iter().collect(); Command::new("cargo") @@ -665,12 +628,10 @@ mod tests { .wait() .unwrap(); - let dashes_replaced = crate_name.replace('-', "_"); - let lib_name = if cfg!(target_os = "windows") { - format!("{}.dll", dashes_replaced) + format!("{}.exe", crate_name) } else { - format!("lib{}.so", dashes_replaced) + crate_name.to_string() }; let id = Uuid::new_v4(); diff --git a/deployer/src/handlers/mod.rs b/deployer/src/handlers/mod.rs index c7a01f1c3..f0185f02f 100644 --- a/deployer/src/handlers/mod.rs +++ b/deployer/src/handlers/mod.rs @@ -13,10 +13,11 @@ use fqdn::FQDN; use futures::StreamExt; use hyper::Uri; use shuttle_common::backends::auth::{ - AdminSecretLayer, AuthPublicKey, Claim, JwtAuthenticationLayer, Scope, ScopedLayer, + AdminSecretLayer, AuthPublicKey, JwtAuthenticationLayer, ScopedLayer, }; use shuttle_common::backends::headers::XShuttleAccountName; use shuttle_common::backends::metrics::{Metrics, TraceLayer}; +use shuttle_common::claims::{Claim, Scope}; use shuttle_common::models::secret; use shuttle_common::project::ProjectName; use shuttle_common::storage_manager::StorageManager; diff --git a/deployer/tests/deploy_layer/bind-panic/Cargo.toml b/deployer/tests/deploy_layer/bind-panic/Cargo.toml index 5475a549c..632c09176 100644 --- a/deployer/tests/deploy_layer/bind-panic/Cargo.toml +++ b/deployer/tests/deploy_layer/bind-panic/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" [workspace] [dependencies] -shuttle-runtime = { path = "../../../../runtime" } +shuttle-runtime = "0.1.0" tokio = "1.22" diff --git a/deployer/tests/deploy_layer/main-panic/Cargo.toml b/deployer/tests/deploy_layer/main-panic/Cargo.toml index 9e068f31d..9632a72d7 100644 --- a/deployer/tests/deploy_layer/main-panic/Cargo.toml +++ b/deployer/tests/deploy_layer/main-panic/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" [workspace] [dependencies] -shuttle-runtime = { path = "../../../../runtime" } +shuttle-runtime = "0.1.0" tokio = "1.22" diff --git a/deployer/tests/deploy_layer/self-stop/Cargo.toml b/deployer/tests/deploy_layer/self-stop/Cargo.toml index f740b87b7..bab0ac511 100644 --- a/deployer/tests/deploy_layer/self-stop/Cargo.toml +++ b/deployer/tests/deploy_layer/self-stop/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" [workspace] [dependencies] -shuttle-runtime = { path = "../../../../runtime" } +shuttle-runtime = "0.1.0" tokio = "1.22" diff --git 
a/deployer/tests/deploy_layer/self-stop/src/main.rs b/deployer/tests/deploy_layer/self-stop/src/main.rs index 9150538b5..ddbba66d2 100644 --- a/deployer/tests/deploy_layer/self-stop/src/main.rs +++ b/deployer/tests/deploy_layer/self-stop/src/main.rs @@ -8,6 +8,6 @@ impl shuttle_runtime::Service for MyService { } #[shuttle_runtime::main] -async fn self_stop() -> Result { +async fn self_stop() -> Result { Ok(MyService) } diff --git a/deployer/tests/deploy_layer/sleep-async/Cargo.toml b/deployer/tests/deploy_layer/sleep-async/Cargo.toml index edd2ea6bc..ed33c2630 100644 --- a/deployer/tests/deploy_layer/sleep-async/Cargo.toml +++ b/deployer/tests/deploy_layer/sleep-async/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" [workspace] [dependencies] -shuttle-runtime = { path = "../../../../runtime" } +shuttle-runtime = "0.1.0" tokio = { version = "1.0", features = ["time"]} diff --git a/examples b/examples index a5c78703a..c35653890 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit a5c78703ab676bf7ed1649ef19cb4bfe43c5cc29 +Subproject commit c35653890aa7b0d8e4cb70131027990d4ed6afa6 diff --git a/gateway/src/api/latest.rs b/gateway/src/api/latest.rs index 2b64bc4fb..42f1ff3a6 100644 --- a/gateway/src/api/latest.rs +++ b/gateway/src/api/latest.rs @@ -16,11 +16,10 @@ use futures::Future; use http::{StatusCode, Uri}; use instant_acme::{AccountCredentials, ChallengeType}; use serde::{Deserialize, Serialize}; -use shuttle_common::backends::auth::{ - AuthPublicKey, JwtAuthenticationLayer, Scope, ScopedLayer, EXP_MINUTES, -}; +use shuttle_common::backends::auth::{AuthPublicKey, JwtAuthenticationLayer, ScopedLayer}; use shuttle_common::backends::cache::CacheManager; use shuttle_common::backends::metrics::{Metrics, TraceLayer}; +use shuttle_common::claims::{Scope, EXP_MINUTES}; use shuttle_common::models::error::ErrorKind; use shuttle_common::models::{project, stats}; use shuttle_common::request_span; @@ -593,7 +592,8 @@ pub mod tests { Request::builder() .method("POST") .uri(format!("/projects/{project}")) - .body(Body::empty()) + .header("Content-Type", "application/json") + .body("{\"idle_minutes\": 3}".into()) .unwrap() }; @@ -762,7 +762,8 @@ pub mod tests { let create_project = Request::builder() .method("POST") .uri(format!("/projects/{matrix}")) - .body(Body::empty()) + .header("Content-Type", "application/json") + .body("{\"idle_minutes\": 3}".into()) .unwrap() .with_header(&authorization); diff --git a/gateway/src/auth.rs b/gateway/src/auth.rs index 679890b0c..0c9910aa9 100644 --- a/gateway/src/auth.rs +++ b/gateway/src/auth.rs @@ -4,7 +4,7 @@ use std::str::FromStr; use axum::extract::{FromRef, FromRequestParts, Path}; use axum::http::request::Parts; use serde::{Deserialize, Serialize}; -use shuttle_common::backends::auth::{Claim, Scope}; +use shuttle_common::claims::{Claim, Scope}; use tracing::{trace, Span}; use crate::api::latest::RouterState; diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index b1a6e6175..000700142 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -344,7 +344,8 @@ pub mod tests { use jsonwebtoken::EncodingKey; use rand::distributions::{Alphanumeric, DistString, Distribution, Uniform}; use ring::signature::{self, Ed25519KeyPair, KeyPair}; - use shuttle_common::backends::auth::{Claim, ConvertResponse, Scope}; + use shuttle_common::backends::auth::ConvertResponse; + use shuttle_common::claims::{Claim, Scope}; use shuttle_common::models::project; use sqlx::SqlitePool; use tokio::sync::mpsc::channel; @@ -801,7 +802,8 @@ pub mod tests { .request( 
Request::post("/projects/matrix") .with_header(&authorization) - .body(Body::empty()) + .header("Content-Type", "application/json") + .body("{\"idle_minutes\": 3}".into()) .unwrap(), ) .map_ok(|resp| { diff --git a/proto/Cargo.toml b/proto/Cargo.toml index aadf6f749..bb6725807 100644 --- a/proto/Cargo.toml +++ b/proto/Cargo.toml @@ -19,7 +19,7 @@ uuid = { workspace = true, features = ["v4"] } [dependencies.shuttle-common] workspace = true -features = ["models", "service", "wasm"] +features = ["error", "service", "wasm"] [build-dependencies] tonic-build = "0.8.3" diff --git a/proto/runtime.proto b/proto/runtime.proto index d75505419..8a0ee293d 100644 --- a/proto/runtime.proto +++ b/proto/runtime.proto @@ -34,6 +34,8 @@ message LoadRequest { message LoadResponse { // Could the service be loaded bool success = 1; + // Error message if not successful + string message = 2; } message StartRequest { diff --git a/proto/src/lib.rs b/proto/src/lib.rs index f27281d13..39275985d 100644 --- a/proto/src/lib.rs +++ b/proto/src/lib.rs @@ -104,7 +104,7 @@ pub mod runtime { use anyhow::Context; use chrono::DateTime; use prost_types::Timestamp; - use shuttle_common::models::ParseError; + use shuttle_common::ParseError; use tokio::process; use tonic::transport::{Channel, Endpoint}; use tracing::info; diff --git a/provisioner/src/lib.rs b/provisioner/src/lib.rs index 7c732149a..cfa9ebdc8 100644 --- a/provisioner/src/lib.rs +++ b/provisioner/src/lib.rs @@ -6,7 +6,7 @@ use aws_sdk_rds::{error::ModifyDBInstanceErrorKind, model::DbInstance, types::Sd pub use error::Error; use mongodb::{bson::doc, options::ClientOptions}; use rand::Rng; -use shuttle_common::backends::auth::{Claim, Scope}; +use shuttle_common::claims::{Claim, Scope}; pub use shuttle_proto::provisioner::provisioner_server::ProvisionerServer; use shuttle_proto::provisioner::{ aws_rds, database_request::DbType, shared, AwsRds, DatabaseRequest, DatabaseResponse, Shared, diff --git a/runtime/Cargo.toml b/runtime/Cargo.toml index 6ab22e9c3..0a417f502 100644 --- a/runtime/Cargo.toml +++ b/runtime/Cargo.toml @@ -14,7 +14,7 @@ required-features = ["next"] anyhow = { workspace = true } async-trait = { workspace = true } chrono = { workspace = true } -clap ={ version = "4.0.18", features = ["derive"] } +clap = { workspace = true } serde_json = { workspace = true } strfmt = "0.2.2" thiserror = { workspace = true } @@ -37,19 +37,19 @@ wasmtime-wasi = { version = "4.0.0", optional = true } [dependencies.shuttle-common] workspace = true -features = ["service", "backend"] +features = ["claims"] [dependencies.shuttle-proto] workspace = true [dependencies.shuttle-service] workspace = true -features = ["builder"] [dev-dependencies] crossbeam-channel = "0.5.6" portpicker = "0.1.1" futures = { version = "0.3.25" } +shuttle-service = { workspace = true, features = ["builder"] } [features] default = [] diff --git a/runtime/src/legacy/mod.rs b/runtime/src/legacy/mod.rs index 5d430bbc1..521669051 100644 --- a/runtime/src/legacy/mod.rs +++ b/runtime/src/legacy/mod.rs @@ -13,7 +13,7 @@ use async_trait::async_trait; use clap::Parser; use core::future::Future; use shuttle_common::{ - backends::{auth::ClaimLayer, tracing::InjectPropagationLayer}, + claims::{ClaimLayer, InjectPropagationLayer}, storage_manager::{ArtifactsStorageManager, StorageManager, WorkingDirStorageManager}, LogItem, }; @@ -38,7 +38,7 @@ use tonic::{ Request, Response, Status, }; use tower::ServiceBuilder; -use tracing::{error, instrument, trace}; +use tracing::{error, info, trace}; use uuid::Uuid; use 
crate::{provisioner_factory::ProvisionerFactory, Logger}; @@ -203,11 +203,51 @@ where let loader = self.loader.lock().unwrap().deref_mut().take().unwrap(); - let service = loader.load(factory, logger).await.unwrap(); + let service = match tokio::spawn(loader.load(factory, logger)).await { + Ok(res) => match res { + Ok(service) => service, + Err(error) => { + error!(%error, "loading service failed"); + + let message = LoadResponse { + success: false, + message: error.to_string(), + }; + return Ok(Response::new(message)); + } + }, + Err(error) => { + if error.is_panic() { + let panic = error.into_panic(); + let msg = panic + .downcast_ref::<&str>() + .map(|x| x.to_string()) + .unwrap_or_else(|| "".to_string()); + + error!(error = msg, "loading service panicked"); + + let message = LoadResponse { + success: false, + message: msg, + }; + return Ok(Response::new(message)); + } else { + error!(%error, "loading service crashed"); + let message = LoadResponse { + success: false, + message: error.to_string(), + }; + return Ok(Response::new(message)); + } + } + }; *self.service.lock().unwrap() = Some(service); - let message = LoadResponse { success: true }; + let message = LoadResponse { + success: true, + message: String::new(), + }; Ok(Response::new(message)) } @@ -231,13 +271,56 @@ where let (kill_tx, kill_rx) = tokio::sync::oneshot::channel(); *self.kill_tx.lock().unwrap() = Some(kill_tx); + let stopped_tx = self.stopped_tx.clone(); + + let handle = tokio::runtime::Handle::current(); + // start service as a background task with a kill receiver - tokio::spawn(run_until_stopped( - service, - service_address, - self.stopped_tx.clone(), - kill_rx, - )); + tokio::spawn(async move { + let mut background = handle.spawn(service.bind(service_address)); + + tokio::select! { + res = &mut background => { + match res { + Ok(_) => { + info!("service stopped all on its own"); + stopped_tx.send((StopReason::End, String::new())).unwrap(); + }, + Err(error) => { + if error.is_panic() { + let panic = error.into_panic(); + let msg = panic.downcast_ref::<&str>() + .map(|x| x.to_string()) + .unwrap_or_else(|| "".to_string()); + + error!(error = msg, "service panicked"); + + stopped_tx + .send((StopReason::Crash, msg)) + .unwrap(); + } else { + error!(%error, "service crashed"); + stopped_tx + .send((StopReason::Crash, error.to_string())) + .unwrap(); + } + }, + } + }, + message = kill_rx => { + match message { + Ok(_) => { + stopped_tx.send((StopReason::Request, String::new())).unwrap(); + } + Err(_) => trace!("the sender dropped") + }; + + info!("will now abort the service"); + background.abort(); + background.await.unwrap().expect("to stop service"); + } + } + }); let message = StartResponse { success: true }; @@ -307,35 +390,3 @@ where } } } - -/// Run the service until a stop signal is received -#[instrument(skip(service, stopped_tx, kill_rx))] -async fn run_until_stopped( - // service: LoadedService, - service: impl Service, - addr: SocketAddr, - stopped_tx: tokio::sync::broadcast::Sender<(StopReason, String)>, - kill_rx: tokio::sync::oneshot::Receiver, -) { - trace!("starting deployment on {}", &addr); - tokio::select! 
{ - res = service.bind(addr) => { - match res { - Ok(_) => { - stopped_tx.send((StopReason::End, String::new())).unwrap(); - } - Err(error) => { - stopped_tx.send((StopReason::Crash, error.to_string())).unwrap(); - } - } - }, - message = kill_rx => { - match message { - Ok(_) => { - stopped_tx.send((StopReason::Request, String::new())).unwrap(); - } - Err(_) => trace!("the sender dropped") - }; - } - } -} diff --git a/runtime/src/logger.rs b/runtime/src/logger.rs index 79b2f6a49..8a82c2508 100644 --- a/runtime/src/logger.rs +++ b/runtime/src/logger.rs @@ -64,7 +64,7 @@ mod tests { let logger = Logger::new(s, Default::default()); - tracing_subscriber::registry().with(logger).init(); + let _guard = tracing_subscriber::registry().with(logger).set_default(); tracing::debug!("this is"); tracing::info!("hi"); diff --git a/runtime/src/next/mod.rs b/runtime/src/next/mod.rs index 6c43948d5..060c7b4b9 100644 --- a/runtime/src/next/mod.rs +++ b/runtime/src/next/mod.rs @@ -89,7 +89,10 @@ impl Runtime for AxumWasm { *self.router.lock().unwrap() = Some(router); - let message = LoadResponse { success: true }; + let message = LoadResponse { + success: true, + message: String::new(), + }; Ok(tonic::Response::new(message)) } @@ -426,7 +429,7 @@ pub mod tests { .unwrap(); } - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn axum() { compile_module(); diff --git a/runtime/src/provisioner_factory.rs b/runtime/src/provisioner_factory.rs index 8193af78e..56d91e94c 100644 --- a/runtime/src/provisioner_factory.rs +++ b/runtime/src/provisioner_factory.rs @@ -2,7 +2,7 @@ use std::{collections::BTreeMap, path::PathBuf, sync::Arc}; use async_trait::async_trait; use shuttle_common::{ - backends::{auth::ClaimService, tracing::InjectPropagation}, + claims::{ClaimService, InjectPropagation}, database, storage_manager::StorageManager, DatabaseReadyInfo, diff --git a/runtime/tests/integration/loader.rs b/runtime/tests/integration/loader.rs index d4051818a..3ef17f059 100644 --- a/runtime/tests/integration/loader.rs +++ b/runtime/tests/integration/loader.rs @@ -1,15 +1,9 @@ -use std::time::Duration; - -use shuttle_proto::runtime::{LoadRequest, StartRequest}; +use shuttle_proto::runtime::{LoadRequest, StartRequest, StopReason, SubscribeStopRequest}; use uuid::Uuid; use crate::helpers::{spawn_runtime, TestRuntime}; -/// This test does panic, but the panic happens in a spawned task inside the project runtime, -/// so we get this output: `thread 'tokio-runtime-worker' panicked at 'panic in bind', src/main.rs:6:9`, -/// but `should_panic(expected = "panic in bind")` doesn't catch it. #[tokio::test] -#[should_panic(expected = "panic in bind")] async fn bind_panic() { let project_path = format!("{}/tests/resources/bind-panic", env!("CARGO_MANIFEST_DIR")); @@ -29,20 +23,24 @@ async fn bind_panic() { let _ = runtime_client.load(load_request).await.unwrap(); + let mut stream = runtime_client + .subscribe_stop(tonic::Request::new(SubscribeStopRequest {})) + .await + .unwrap() + .into_inner(); + let start_request = StartRequest { deployment_id: Uuid::default().as_bytes().to_vec(), ip: runtime_address.to_string(), }; - // I also tried this without spawning, but it gave the same result. Panic but it isn't caught. - tokio::spawn(async move { - runtime_client - .start(tonic::Request::new(start_request)) - .await - .unwrap(); - // Give it a second to panic. 
- tokio::time::sleep(Duration::from_secs(1)).await; - }) - .await - .unwrap(); + runtime_client + .start(tonic::Request::new(start_request)) + .await + .unwrap(); + + let reason = stream.message().await.unwrap().unwrap(); + + assert_eq!(reason.reason, StopReason::Crash as i32); + assert_eq!(reason.message, "panic in bind"); } diff --git a/service/Cargo.toml b/service/Cargo.toml index c895c042e..64178f863 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -13,17 +13,16 @@ doctest = false [dependencies] anyhow = { workspace = true } async-trait = { workspace = true } -axum = { workspace = true, optional = true } # TODO: debug the libgit2-sys conflict with cargo-edit when upgrading cargo to 0.66 cargo = { version = "0.65.0", optional = true } -cargo_metadata = "0.15.2" -crossbeam-channel = "0.5.6" -pipe = "0.4.0" -serde_json = { workspace = true } +cargo_metadata = { version = "0.15.2", optional = true } +crossbeam-channel = { version = "0.5.6", optional = true } +pipe = { version = "0.4.0", optional = true } +serde_json = { workspace = true, optional = true } strfmt = "0.2.2" thiserror = { workspace = true } -tokio = { version = "1.26.0", features = ["sync"] } -tracing = { workspace = true } +tokio = { version = "1.26.0", features = ["sync"], optional = true } +tracing = { workspace = true, optional = true } [dependencies.shuttle-codegen] workspace = true @@ -31,7 +30,7 @@ optional = true [dependencies.shuttle-common] workspace = true -features = ["tracing", "service"] +features = ["service"] [dev-dependencies] tokio = { version = "1.26.0", features = ["macros", "rt"] } @@ -40,4 +39,4 @@ tokio = { version = "1.26.0", features = ["macros", "rt"] } default = ["codegen"] codegen = ["shuttle-codegen/frameworks"] -builder = ["cargo"] +builder = ["cargo", "cargo_metadata", "crossbeam-channel", "pipe", "tokio", "tracing"]