Skip to content

Commit

Permalink
[server] Upgrade to axum 0.7 #808 (#810)
Browse files Browse the repository at this point in the history
upgrade crates
  • Loading branch information
michaelvlach authored Nov 27, 2023
1 parent a348748 commit f0968b7
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 96 deletions.
8 changes: 4 additions & 4 deletions agdb_benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ categories = ["database", "database-implementations"]

[dependencies]
agdb = { version = "0.5.1", path = "../agdb" }
num-format = { version = "0.4.4", features = ["with-serde"] }
serde = { version = "1.0.188", features = ["derive"] }
serde_yaml = "0.9.25"
tokio = { version = "1.32.0", features = ["full"] }
num-format = { version = "0.4", features = ["with-serde"] }
serde = { version = "1", features = ["derive"] }
serde_yaml = "0.9"
tokio = { version = "1", features = ["full"] }
4 changes: 2 additions & 2 deletions agdb_derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ categories = ["database", "database-implementations"]
proc-macro = true

[dependencies]
quote = "1.0.32"
syn = "2.0.28"
quote = "1"
syn = "2"
34 changes: 17 additions & 17 deletions agdb_server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ categories = ["database", "database-interfaces"]

[dependencies]
agdb = { version = "0.5.1", path = "../agdb" }
anyhow = "1.0.75"
axum = { version = "0.6.20", features = ["headers"] }
hyper = "0.14.27"
ring = "0.17.5"
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
serde_yaml = "0.9.27"
tokio = { version = "1.34.0", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.4.4", features = ["map-request-body"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
utoipa = "4.1.0"
utoipa-swagger-ui = { version = "4.0.0", features = ["axum"] }
uuid = { version = "1.6.1", features = ["v4"] }
anyhow = "1"
axum = { version = "0.7", features = ["http2"] }
axum-extra = { version = "0.9", features = ["typed-header"] }
http-body-util = "0.1"
ring = "0.17"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
serde_yaml = "0.9"
tokio = { version = "1", features = ["full"] }
tower = "0.4"
tracing = "0.1"
tracing-subscriber = "0.3"
utoipa = "4"
#utoipa-swagger-ui = { version = "4", features = ["axum"] }
uuid = { version = "1", features = ["v4"] }

[dev-dependencies]
assert_cmd = "2.0.12"
reqwest = { version = "0.11.22", features = ["json", "blocking"] }
assert_cmd = "2"
reqwest = { version = "0.11", features = ["json", "blocking"] }
31 changes: 12 additions & 19 deletions agdb_server/src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use std::fmt::Display;

use crate::api::Api;
use crate::db::Database;
use crate::db::DbPool;
use crate::db::User;
Expand All @@ -11,27 +8,24 @@ use crate::utilities;
use agdb::DbId;
use anyhow::anyhow;
use axum::async_trait;
use axum::body;
use axum::extract::FromRef;
use axum::extract::FromRequestParts;
use axum::extract::State;
use axum::headers::authorization::Bearer;
use axum::headers::Authorization;
use axum::http::request::Parts;
use axum::http::StatusCode;
use axum::middleware;
use axum::routing;
use axum::Json;
use axum::RequestPartsExt;
use axum::Router;
use axum::TypedHeader;
use axum_extra::headers::authorization::Bearer;
use axum_extra::headers::Authorization;
use axum_extra::TypedHeader;
use serde::Deserialize;
use serde::Serialize;
use std::fmt::Display;
use tokio::sync::broadcast::Sender;
use tower::ServiceBuilder;
use tower_http::map_request_body::MapRequestBodyLayer;
use utoipa::OpenApi;
use utoipa::ToSchema;
use utoipa_swagger_ui::SwaggerUi;
use uuid::Uuid;

#[derive(Clone)]
Expand Down Expand Up @@ -109,8 +103,8 @@ where
type Rejection = StatusCode;

async fn from_request_parts(parts: &mut Parts, db_pool: &S) -> Result<Self, Self::Rejection> {
let header = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, db_pool);
let bearer = header.await.map_err(unauthorized_error)?;
let bearer: TypedHeader<Authorization<Bearer>> =
parts.extract().await.map_err(unauthorized_error)?;
let id = DbPool::from_ref(db_pool)
.find_user_id(utilities::unquote(bearer.token()))
.map_err(unauthorized_error)?;
Expand All @@ -131,17 +125,14 @@ impl FromRef<ServerState> for Sender<()> {
}

pub(crate) fn app(shutdown_sender: Sender<()>, db_pool: DbPool) -> Router {
let logger = ServiceBuilder::new()
.layer(MapRequestBodyLayer::new(body::boxed))
.layer(middleware::from_fn(logger::logger));

let state = ServerState {
db_pool,
shutdown_sender,
};

Router::new()
.merge(SwaggerUi::new("/openapi").url("/openapi/openapi.json", Api::openapi()))
//.merge(SwaggerUi::new("/openapi").url("/openapi/openapi.json", Api::openapi()))
.route("/openapi", routing::get(StatusCode::OK))
.route("/shutdown", routing::get(shutdown))
.route("/error", routing::get(test_error))
.route("/add_db", routing::post(add_db))
Expand All @@ -151,7 +142,7 @@ pub(crate) fn app(shutdown_sender: Sender<()>, db_pool: DbPool) -> Router {
.route("/list", routing::get(list))
.route("/login", routing::post(login))
.route("/remove_db", routing::post(remove_db))
.layer(logger)
.layer(middleware::from_fn(logger::logger))
.with_state(state)
}

Expand Down Expand Up @@ -386,6 +377,7 @@ fn unauthorized_error<E>(_: E) -> StatusCode {
#[cfg(test)]
mod tests {
use super::*;
use crate::api::Api;
use crate::db::DbPoolImpl;
use crate::db::ServerDb;
use axum::body::Body;
Expand All @@ -397,6 +389,7 @@ mod tests {
use std::sync::Arc;
use std::sync::RwLock;
use tower::ServiceExt;
use utoipa::OpenApi;

fn test_db_pool() -> anyhow::Result<DbPool> {
Ok(DbPool(Arc::new(DbPoolImpl {
Expand Down
16 changes: 10 additions & 6 deletions agdb_server/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,21 @@ pub(crate) struct DbPool(pub(crate) Arc<DbPoolImpl>);

impl DbPool {
pub(crate) fn new() -> anyhow::Result<Self> {
let db_exists = Path::new("agdb_server.agdb").exists();

let db_pool = Self(Arc::new(DbPoolImpl {
server_db: ServerDb::new(SERVER_DB_NAME)?,
pool: RwLock::new(HashMap::new()),
}));

db_pool.0.server_db.get_mut()?.exec_mut(
&QueryBuilder::insert()
.nodes()
.aliases(vec!["users", "dbs"])
.query(),
)?;
if !db_exists {
db_pool.0.server_db.get_mut()?.exec_mut(
&QueryBuilder::insert()
.nodes()
.aliases(vec!["users", "dbs"])
.query(),
)?;
}

Ok(db_pool)
}
Expand Down
2 changes: 1 addition & 1 deletion agdb_server/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::response::Response;
use hyper::StatusCode;

pub(crate) struct ServerError {
pub(crate) status: StatusCode,
Expand Down
43 changes: 18 additions & 25 deletions agdb_server/src/logger.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use axum::body::BoxBody;
use axum::body::Full;
use axum::body::Body;
use axum::extract::Request;
use axum::http::StatusCode;
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::response::Response;
use axum::Error as AxumError;
use hyper::Request;
use hyper::StatusCode;
use http_body_util::BodyExt;
use serde::Serialize;
use std::collections::HashMap;
use std::time::Instant;
Expand Down Expand Up @@ -35,10 +35,7 @@ impl LogRecord {
}
}

pub(crate) async fn logger(
request: Request<BoxBody>,
next: Next<BoxBody>,
) -> Result<impl IntoResponse, Response> {
pub(crate) async fn logger(request: Request, next: Next) -> Result<impl IntoResponse, Response> {
let mut log_record = LogRecord::default();
let skip_body = request.uri().path().starts_with("/openapi");
let request = request_log(request, &mut log_record, skip_body).await?;
Expand All @@ -53,10 +50,10 @@ pub(crate) async fn logger(
}

async fn request_log(
request: Request<BoxBody>,
request: Request,
log_record: &mut LogRecord,
skip_body: bool,
) -> Result<Request<BoxBody>, Response> {
) -> Result<Request, Response> {
log_record.method = request.method().to_string();
log_record.uri = request.uri().to_string();
log_record.version = format!("{:?}", request.version());
Expand All @@ -68,23 +65,20 @@ async fn request_log(

if !skip_body {
let (parts, body) = request.into_parts();
let bytes = hyper::body::to_bytes(body).await.map_err(map_error)?;
let bytes = body.collect().await.map_err(map_error)?.to_bytes();
log_record.request_body = String::from_utf8_lossy(&bytes).to_string();

return Ok(Request::from_parts(
parts,
axum::body::boxed(Full::from(bytes)),
));
return Ok(Request::from_parts(parts, Body::from(bytes)));
}

Ok(request)
}

async fn response_log(
response: Response<BoxBody>,
response: Response,
log_record: &mut LogRecord,
skip_body: bool,
) -> Result<impl IntoResponse, Response> {
) -> Result<Response, Response> {
log_record.status = response.status().as_u16();
log_record.response_headers = response
.headers()
Expand All @@ -94,20 +88,19 @@ async fn response_log(

if !skip_body {
let (parts, body) = response.into_parts();
let resposne = hyper::body::to_bytes(body).await.map_err(map_error)?;
log_record.response = String::from_utf8_lossy(&resposne).to_string();
let bytes = body.collect().await.map_err(map_error)?.to_bytes();
log_record.response = String::from_utf8_lossy(&bytes).to_string();

return Ok(Response::from_parts(
parts,
axum::body::boxed(Full::from(resposne)),
));
return Ok(Response::from_parts(parts, Body::from(bytes)));
}

Ok(response)
}

fn map_error(error: AxumError) -> Response {
(StatusCode::INTERNAL_SERVER_ERROR, error.to_string()).into_response()
let mut response = Response::new(Body::from(error.to_string()));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response
}

#[cfg(test)]
Expand All @@ -119,7 +112,7 @@ mod tests {
let error = AxumError::new(anyhow::Error::msg("error"));
let response = map_error(error);
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body = hyper::body::to_bytes(response.into_body()).await?;
let body = response.into_body().collect().await?.to_bytes();
assert_eq!(&body[..], b"error");
Ok(())
}
Expand Down
27 changes: 9 additions & 18 deletions agdb_server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,29 @@ mod password;
mod utilities;

use crate::db::DbPool;
use axum::Server;
use std::net::SocketAddr;
use tokio::signal;
use tokio::sync::broadcast::Receiver;
use tracing::Level;

const BIND_ADDRESS_ARRAY: [u8; 4] = [127, 0, 0, 1];
const BIND_ADDRESS: &str = "127.0.0.1";

async fn shutdown_signal(mut shutdown_shutdown: Receiver<()>) {
tokio::select! {
_ = signal::ctrl_c() => {},
_ = shutdown_shutdown.recv() => {}
}
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().with_max_level(Level::INFO).init();

let (shutdown_sender, shutdown_receiver) = tokio::sync::broadcast::channel::<()>(1);
let (shutdown_sender, mut shutdown_receiver) = tokio::sync::broadcast::channel::<()>(1);
let config = config::new()?;
let db_pool = DbPool::new()?;
let app = app::app(shutdown_sender, db_pool);
tracing::info!("Listening at {BIND_ADDRESS}:{}", config.port);
let addr = SocketAddr::from((BIND_ADDRESS_ARRAY, config.port));
let shutdown = shutdown_signal(shutdown_receiver);
let address = format!("{BIND_ADDRESS}:{}", config.port);
let listener = tokio::net::TcpListener::bind(address).await?;

Server::bind(&addr)
.serve(app.into_make_service())
.with_graceful_shutdown(shutdown)
.await?;
// Use actual graceful shutdown once it becomes available again...
tokio::select! {
_ = signal::ctrl_c() => {},
_ = shutdown_receiver.recv() => {},
_ = async { axum::serve(listener, app).await } => {},
};

Ok(())
}
16 changes: 12 additions & 4 deletions agdb_server/tests/server_test.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
mod framework;

use crate::framework::TestServer;
use std::{collections::HashMap, path::Path};
use assert_cmd::cargo::CommandCargoExt;
use std::collections::HashMap;
use std::path::Path;
use std::process::Command;

#[tokio::test]
async fn config_port() -> anyhow::Result<()> {
let server = TestServer::new()?;
async fn db_reuse_and_error() -> anyhow::Result<()> {
let mut server = TestServer::new()?;
assert_eq!(server.get("/error").await?, 500);
assert_eq!(server.get("/shutdown").await?, 200);
assert!(server.process.wait().unwrap().success());
server.process = Command::cargo_bin("agdb_server")?
.current_dir(&server.dir)
.spawn()?;
Ok(())
}

#[tokio::test]
async fn openapi() -> anyhow::Result<()> {
let server = TestServer::new()?;
assert_eq!(server.get("/openapi/").await?, 200);
assert_eq!(server.get("/openapi").await?, 200);
Ok(())
}

Expand Down

0 comments on commit f0968b7

Please sign in to comment.