From e5c8cf3061a21f4d33423fd7df642d9c5fb70729 Mon Sep 17 00:00:00 2001 From: Matteo Bigoi <1781140+crisidev@users.noreply.github.com> Date: Thu, 22 Sep 2022 19:01:48 +0100 Subject: [PATCH] [Python] Allow to run pure Python request middlewares inside a Tower service (#1734) ## Motivation and Context * Customers want to be able to implement simple middleware directly in Python. This PR aims to add the initial support for it. * Improve the idiomatic experience of logging by exposing a handler compatible with Python's standard library `logging` module. ## Description ### Middleware A middleware is defined as a sync or async Python function that can return multiple values, following these rules: * Middleware not returning will let the execution continue without changing the original request. * Middleware returning a modified Request will update the original request before continuing the execution. * Middleware returning a Response will immediately terminate the request handling and return the response constructed from Python. * Middleware raising MiddlewareException will immediately terminate the request handling and return a protocol specific error, with the option of setting the HTTP return code. * Middleware raising any other exception will immediately terminate the request handling and return a protocol specific error, with HTTP status code 500. Middlewares are registered into the Python application and executed in order of registration. 
Example: from sdk import App from sdk.middleware import Request, MiddlewareException app = App() @app.request_middleware def inject_header(request: Request): request.set_header("x-amzn-answer", "42") return request @app.request_middleware def check_header(request: Request): if request.get_header("x-amzn-answer") != "42": raise MiddlewareException("Wrong answer", 404) @app.request_middleware def dump_headers(request: Request): logging.debug(f"Request headers after middlewares: {request.headers()}") **NOTE: this PR only adds support for request middlewares, which are executed before the operation handler. Response middlewares, executed after the operation, are tracked here: https://github.com/awslabs/smithy-rs/issues/1754.** ### Logging To improve the idiomatic experience, logging now needs to be configured from the Python side by using the standard `logging` module. This allows customers to opt out of our `tracing`-based logging implementation and use their own; the logging level is now driven by Python. 
import logging from sdk.logging import TracingHandler logging.basicConfig(level=logging.INFO, handlers=[TracingHandler.handle()]) Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> Co-authored-by: Burak --- CHANGELOG.next.toml | 9 +- .../generators/PythonApplicationGenerator.kt | 65 ++-- .../generators/PythonServerModuleGenerator.kt | 39 ++ .../PythonServerOperationHandlerGenerator.kt | 18 +- .../aws-smithy-http-server-python/Cargo.toml | 10 +- .../examples/pokemon_service.py | 58 +++ .../src/error.rs | 105 +++++- .../aws-smithy-http-server-python/src/lib.rs | 22 +- .../src/logging.rs | 275 +++++++------- .../src/middleware/handler.rs | 352 ++++++++++++++++++ .../src/middleware/layer.rs | 251 +++++++++++++ .../src/middleware/mod.rs | 23 ++ .../src/middleware/request.rs | 123 ++++++ .../src/middleware/response.rs | 79 ++++ .../src/server.rs | 67 +++- .../src/types.rs | 10 +- 16 files changed, 1319 insertions(+), 187 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index 3b1b94fedc..49cdd29e26 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -10,9 +10,14 @@ # references = ["smithy-rs#920"] # meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"} # author = "rcoh" - [[rust-runtime]] message = "Pokémon Service example code now runs clippy during build." 
references = ["smithy-rs#1727"] -meta = { "breaking" = false, "tada" = false, "bug" = false } +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" } author = "GeneralSwiss" + +[[smithy-rs]] +message = "Implement support for pure Python request middleware. Improve idiomatic logging support over tracing." +references = ["smithy-rs#1734"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" } +author = "crisidev" diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 55cff80bc4..8854a09447 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -30,7 +30,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency * Example: * from pool import DatabasePool * from my_library import App, OperationInput, OperationOutput - * @dataclass * class Context: * db = DatabasePool() @@ -69,6 +68,7 @@ class PythonApplicationGenerator( private val libName = "lib${coreCodegenContext.settings.moduleName.toSnakeCase()}" private val runtimeConfig = coreCodegenContext.runtimeConfig private val model = coreCodegenContext.model + private val protocol = coreCodegenContext.protocol private val codegenScope = arrayOf( "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(), @@ -101,6 +101,7 @@ class PythonApplicationGenerator( ##[derive(Debug, Default)] pub struct App { handlers: #{HashMap}, + middlewares: #{SmithyPython}::PyMiddlewares, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, } @@ 
-116,6 +117,7 @@ class PythonApplicationGenerator( fn clone(&self) -> Self { Self { handlers: self.handlers.clone(), + middlewares: self.middlewares.clone(), context: self.context.clone(), workers: #{parking_lot}::Mutex::new(vec![]), } @@ -151,7 +153,7 @@ class PythonApplicationGenerator( val name = operationName.toSnakeCase() rustTemplate( """ - let ${name}_locals = pyo3_asyncio::TaskLocals::new(event_loop); + let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone(); let router = router.$name(move |input, state| { #{pyo3_asyncio}::tokio::scope(${name}_locals, crate::operation_handler::$name(input, state, handler)) @@ -162,11 +164,20 @@ class PythonApplicationGenerator( } rustTemplate( """ + let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop); + use #{SmithyPython}::PyApp; + let service = #{tower}::ServiceBuilder::new().layer( + #{SmithyPython}::PyMiddlewareLayer::new( + self.middlewares.clone(), + self.protocol(), + middleware_locals + )?, + ); let router: #{SmithyServer}::routing::Router = router .build() .expect("Unable to build operation registry") .into(); - Ok(router) + Ok(router.layer(service)) """, *codegenScope, ) @@ -175,20 +186,25 @@ class PythonApplicationGenerator( } private fun renderPyAppTrait(writer: RustWriter) { + val protocol = protocol.toString().replace("#", "##") writer.rustTemplate( """ impl #{SmithyPython}::PyApp for App { fn workers(&self) -> &#{parking_lot}::Mutex> { &self.workers } - fn context(&self) -> &Option<#{pyo3}::PyObject> { &self.context } - fn handlers(&mut self) -> &mut #{HashMap} { &mut self.handlers } + fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares { + &mut self.middlewares + } + fn protocol(&self) -> &'static str { + "$protocol" + } } """, *codegenScope, @@ -207,16 +223,20 @@ class PythonApplicationGenerator( """ /// Create a new [App]. 
##[new] - pub fn new(py: #{pyo3}::Python, log_level: Option<#{SmithyPython}::LogLevel>) -> #{pyo3}::PyResult { - let log_level = log_level.unwrap_or(#{SmithyPython}::LogLevel::Info); - #{SmithyPython}::logging::setup(py, log_level)?; - Ok(Self::default()) + pub fn new() -> Self { + Self::default() } /// Register a context object that will be shared between handlers. ##[pyo3(text_signature = "(${'$'}self, context)")] pub fn context(&mut self, context: #{pyo3}::PyObject) { self.context = Some(context); } + /// Register a request middleware function that will be run inside a Tower layer, without cloning the body. + ##[pyo3(text_signature = "(${'$'}self, func)")] + pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { + use #{SmithyPython}::PyApp; + self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request) + } /// Main entrypoint: start the server on multiple workers. ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")] pub fn run( @@ -235,7 +255,7 @@ class PythonApplicationGenerator( pub fn start_worker( &mut self, py: pyo3::Python, - socket: &pyo3::PyCell, + socket: &pyo3::PyCell<#{SmithyPython}::PySocket>, worker_number: isize, ) -> pyo3::PyResult<()> { use #{SmithyPython}::PyApp; @@ -280,21 +300,17 @@ class PythonApplicationGenerator( """.trimIndent(), ) writer.rust( - if (operations.any { it.errors.isNotEmpty() }) { - """ - /// from $libName import ${Inputs.namespace} - /// from $libName import ${Outputs.namespace} - /// from $libName import ${Errors.namespace} - """.trimIndent() - } else { - """ - /// from $libName import ${Inputs.namespace} - /// from $libName import ${Outputs.namespace} - """.trimIndent() - }, + """ + /// from $libName import ${Inputs.namespace} + /// from $libName import ${Outputs.namespace} + """.trimIndent(), ) + if (operations.any { it.errors.isNotEmpty() }) { + writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent()) + } 
writer.rust( """ + /// from $libName import middleware /// from $libName import App /// /// @dataclass @@ -304,6 +320,11 @@ class PythonApplicationGenerator( /// app = App() /// app.context(Context()) /// + /// @app.request_middleware + /// def request_middleware(request: middleware::Request): + /// if request.get_header("x-amzn-id") != "secret": + /// raise middleware.MiddlewareException("Unsupported `x-amz-id` header", 401) + /// """.trimIndent(), ) writer.operationImplementationStubs(operations) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index cf9631047d..654e7c10e0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -47,6 +47,8 @@ class PythonServerModuleGenerator( renderPyCodegeneratedTypes() renderPyWrapperTypes() renderPySocketType() + renderPyLogging() + renderPyMiddlewareTypes() renderPyApplicationType() } } @@ -125,6 +127,43 @@ class PythonServerModuleGenerator( ) } + // Render Python shared socket type. 
+ private fun RustWriter.renderPyLogging() { + rustTemplate( + """ + let logging = #{pyo3}::types::PyModule::new(py, "logging")?; + logging.add_function(#{pyo3}::wrap_pyfunction!(#{SmithyPython}::py_tracing_event, m)?)?; + logging.add_class::<#{SmithyPython}::PyTracingHandler>()?; + #{pyo3}::py_run!( + py, + logging, + "import sys; sys.modules['$libName.logging'] = logging" + ); + m.add_submodule(logging)?; + """, + *codegenScope, + ) + } + + private fun RustWriter.renderPyMiddlewareTypes() { + rustTemplate( + """ + let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?; + middleware.add_class::<#{SmithyPython}::PyRequest>()?; + middleware.add_class::<#{SmithyPython}::PyResponse>()?; + middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?; + middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?; + pyo3::py_run!( + py, + middleware, + "import sys; sys.modules['$libName.middleware'] = middleware" + ); + m.add_submodule(middleware)?; + """, + *codegenScope, + ) + } + // Render Python application type. 
private fun RustWriter.renderPyApplicationType() { rustTemplate( diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index c624cb1385..a158d67722 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -90,16 +90,14 @@ class PythonServerOperationHandlerGenerator( rustTemplate( """ #{tracing}::debug!("Executing Python handler function `$name()`"); - #{tokio}::task::block_in_place(move || { - #{pyo3}::Python::with_gil(|py| { - let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; - let output = if handler.args == 1 { - pyhandler.call1((input,))? - } else { - pyhandler.call1((input, state.0))? - }; - output.extract::<$output>() - }) + #{pyo3}::Python::with_gil(|py| { + let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; + let output = if handler.args == 1 { + pyhandler.call1((input,))? + } else { + pyhandler.call1((input, state.0))? + }; + output.extract::<$output>() }) """, *codegenScope, diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index 7106b30c4a..b77292c324 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -13,14 +13,18 @@ Python server runtime for Smithy Rust Server Framework. 
publish = true [dependencies] +aws-smithy-http = { path = "../aws-smithy-http" } aws-smithy-http-server = { path = "../aws-smithy-http-server" } +aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-types = { path = "../aws-smithy-types" } -aws-smithy-http = { path = "../aws-smithy-http" } +aws-smithy-xml = { path = "../aws-smithy-xml" } bytes = "1.2" futures = "0.3" +http = "0.2" hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] } num_cpus = "1.13.1" parking_lot = "0.12.1" +pin-project-lite = "0.2" pyo3 = "0.16.5" pyo3-asyncio = { version = "0.16.0", features = ["tokio-runtime"] } signal-hook = { version = "0.3.14", features = ["extended-siginfo"] } @@ -28,12 +32,14 @@ socket2 = { version = "0.4.4", features = ["all"] } thiserror = "1.0.32" tokio = { version = "1.20.1", features = ["full"] } tokio-stream = "0.1" -tower = "0.4.13" +tower = { version = "0.4.13", features = ["util"] } tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } +tracing-appender = { version = "0.2.2"} [dev-dependencies] pretty_assertions = "1" +futures-util = "0.3" [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index cfad7c1309..67345fb629 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -11,17 +11,25 @@ from typing import List, Optional import aiohttp + from libpokemon_service_server_sdk import App from libpokemon_service_server_sdk.error import ResourceNotFoundException from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) +from libpokemon_service_server_sdk.logging import TracingHandler +from 
libpokemon_service_server_sdk.middleware import (MiddlewareException, + Request) from libpokemon_service_server_sdk.model import FlavorText, Language from libpokemon_service_server_sdk.output import ( EmptyOperationOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, HealthCheckOperationOutput, StreamPokemonRadioOperationOutput) from libpokemon_service_server_sdk.types import ByteStream +# Logging can be set up using standard Python tooling. We provide a +# fast logging handler, TracingHandler, based on the Rust tracing crate. +logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()]) + # A slightly more atomic counter using a threading lock. class FastWriteCounter: @@ -111,6 +119,55 @@ def get_random_radio_stream(self) -> str: app.context(Context()) +########################################################### +# Middleware +############################################################ +# Middlewares are sync or async functions decorated by `@app.request_middleware`. +# They are executed in order and take as input the HTTP request object. +# A middleware can return multiple values, following these rules: +# * Middleware not returning will let the execution continue without +# changing the original request. +# * Middleware returning a modified Request will update the original +# request before continuing the execution. +# * Middleware returning a Response will immediately terminate the request +# handling and return the response constructed from Python. +# * Middleware raising MiddlewareException will immediately terminate the +# request handling and return a protocol specific error, with the option of +# setting the HTTP return code. +# * Middleware raising any other exception will immediately terminate the +# request handling and return a protocol specific error, with HTTP status +# code 500. 
+@app.request_middleware +def check_content_type_header(request: Request): + content_type = request.get_header("content-type") + if content_type == "application/json": + logging.debug("Found valid `application/json` content type") + else: + logging.warning( + f"Invalid content type {content_type}, dumping headers: {request.headers()}" + ) + + +# This middleware adds a new header called `x-amzn-answer` to the +# request. We expect to see this header to be populated in the next +# middleware. +@app.request_middleware +def add_x_amzn_answer_header(request: Request): + request.set_header("x-amzn-answer", "42") + logging.debug("Setting `x-amzn-answer` header to 42") + return request + + +# This middleware checks if the header `x-amzn-answer` is correctly set +# to 42, otherwise it raises an exception with a set status code. +@app.request_middleware +async def check_x_amzn_answer_header(request: Request): + # Check that `x-amzn-answer` is 42. + if request.get_header("x-amzn-answer") != "42": + # Return an HTTP 401 Unauthorized if the answer header is not 42. + raise MiddlewareException("Invalid answer", 401) + + ########################################################### # App handlers definition ########################################################### @@ -131,6 +188,7 @@ def get_pokemon_species( if flavor_text_entries: logging.debug("Total requests executed: %s", context.get_calls_count()) logging.info("Found description for Pokémon %s", input.name) + logging.error("Found some stuff") return GetPokemonSpeciesOutput( name=input.name, flavor_text_entries=flavor_text_entries ) diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 42800954be..519e75501a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -5,13 +5,15 @@ //! Python error definition. 
+use aws_smithy_http_server::protocols::Protocol; +use aws_smithy_http_server::{body::to_boxed, response::Response}; use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; -use pyo3::{exceptions::PyException, PyErr}; +use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr}; use thiserror::Error; /// Python error that implements foreign errors. #[derive(Error, Debug)] -pub enum Error { +pub enum PyError { /// Implements `From`. #[error("DateTimeConversion: {0}")] DateTimeConversion(#[from] ConversionError), @@ -20,8 +22,103 @@ pub enum Error { DateTimeParse(#[from] DateTimeParseError), } -impl From for PyErr { - fn from(other: Error) -> PyErr { +create_exception!(smithy, PyException, BasePyException); + +impl From for PyErr { + fn from(other: PyError) -> PyErr { PyException::new_err(other.to_string()) } } + +/// Exception that can be thrown from a Python middleware. +/// +/// It allows to specify a message and HTTP status code and implementing protocol specific capabilities +/// to build a [aws_smithy_http_server::response::Response] from it. +#[pyclass(name = "MiddlewareException", extends = BasePyException)] +#[pyo3(text_signature = "(message, status_code)")] +#[derive(Debug, Clone)] +pub struct PyMiddlewareException { + #[pyo3(get, set)] + message: String, + #[pyo3(get, set)] + status_code: u16, +} + +#[pymethods] +impl PyMiddlewareException { + /// Create a new [PyMiddlewareException]. + #[new] + fn newpy(message: String, status_code: Option) -> Self { + Self { + message, + status_code: status_code.unwrap_or(500), + } + } +} + +impl From for PyMiddlewareException { + fn from(other: PyErr) -> Self { + Self::newpy(other.to_string(), None) + } +} + +impl PyMiddlewareException { + /// Convert the exception into a [Response], following the [Protocol] specification. 
+ pub(crate) fn into_response(self, protocol: Protocol) -> Response { + let body = to_boxed(match protocol { + Protocol::RestJson1 => self.json_body(), + Protocol::RestXml => self.xml_body(), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization + Protocol::AwsJson10 => self.json_body(), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + Protocol::AwsJson11 => self.json_body(), + }); + + let mut builder = http::Response::builder(); + builder = builder.status(self.status_code); + + match protocol { + Protocol::RestJson1 => { + builder = builder + .header("Content-Type", "application/json") + .header("X-Amzn-Errortype", "MiddlewareException"); + } + Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), + Protocol::AwsJson10 => { + builder = builder.header("Content-Type", "application/x-amz-json-1.0") + } + Protocol::AwsJson11 => { + builder = builder.header("Content-Type", "application/x-amz-json-1.1") + } + } + + builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + } + + /// Serialize the body into a JSON object. + fn json_body(&self) -> String { + let mut out = String::new(); + let mut object = aws_smithy_json::serialize::JsonObjectWriter::new(&mut out); + object.key("message").string(self.message.as_str()); + object.finish(); + out + } + + /// Serialize the body into a XML object. 
+ fn xml_body(&self) -> String { + let mut out = String::new(); + { + let mut writer = aws_smithy_xml::encode::XmlWriter::new(&mut out); + let root = writer + .start_el("Error") + .write_ns("http://s3.amazonaws.com/doc/2006-03-01/", None); + let mut scope = root.finish(); + { + let mut inner_writer = scope.start_el("Message").finish(); + inner_writer.data(self.message.as_ref()); + } + scope.finish(); + } + out + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 1f5c738194..793104af59 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -13,14 +13,19 @@ mod error; pub mod logging; +pub mod middleware; mod server; mod socket; pub mod types; #[doc(inline)] -pub use error::Error; +pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] -pub use logging::LogLevel; +pub use logging::{py_tracing_event, PyTracingHandler}; +#[doc(inline)] +pub use middleware::{ + PyHttpVersion, PyMiddlewareLayer, PyMiddlewareType, PyMiddlewares, PyRequest, PyResponse, +}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] @@ -30,11 +35,22 @@ pub use socket::PySocket; mod tests { use std::sync::Once; + use pyo3::{PyErr, Python}; + use pyo3_asyncio::TaskLocals; + static INIT: Once = Once::new(); - pub(crate) fn initialize() { + pub(crate) fn initialize() -> TaskLocals { INIT.call_once(|| { pyo3::prepare_freethreaded_python(); }); + + Python::with_gil(|py| { + let asyncio = py.import("asyncio")?; + let event_loop = asyncio.call_method0("new_event_loop")?; + asyncio.call_method1("set_event_loop", (event_loop,))?; + Ok::(TaskLocals::new(event_loop)) + }) + .unwrap() } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 4c690f3f7f..c7894d0c5e 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ 
b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -4,87 +4,78 @@ */ //! Rust `tracing` and Python `logging` setup and utilities. +use std::path::PathBuf; use pyo3::prelude::*; +#[cfg(not(test))] +use tracing::span; use tracing::Level; -use tracing_subscriber::filter::LevelFilter; -use tracing_subscriber::{prelude::*, EnvFilter}; - -/// Setup `tracing::subscriber` reading the log level from RUST_LOG environment variable -/// and inject the custom Python `logger` into the interpreter. -pub fn setup(py: Python, level: LogLevel) -> PyResult<()> { - let format = tracing_subscriber::fmt::layer() - .with_ansi(true) - .with_line_number(true) - .with_level(true); - match EnvFilter::try_from_default_env() { - Ok(filter) => { - let level: LogLevel = filter.to_string().into(); - tracing_subscriber::registry() - .with(format) - .with(filter) - .init(); - setup_python_logging(py, level)?; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::{ + fmt::{self, writer::MakeWriterExt}, + layer::SubscriberExt, + util::SubscriberInitExt, +}; + +use crate::error::PyException; + +/// Setup tracing-subscriber to log on console or to a hourly rolling file. 
+fn setup_tracing_subscriber( + level: Option, + logfile: Option, +) -> PyResult> { + let appender = match logfile { + Some(logfile) => { + let parent = logfile.parent().ok_or_else(|| { + PyException::new_err(format!( + "Tracing setup failed: unable to extract dirname from path {}", + logfile.display() + )) + })?; + let filename = logfile.file_name().ok_or_else(|| { + PyException::new_err(format!( + "Tracing setup failed: unable to extract basename from path {}", + logfile.display() + )) + })?; + let file_appender = tracing_appender::rolling::hourly(parent, filename); + let (appender, guard) = tracing_appender::non_blocking(file_appender); + Some((appender, guard)) } - Err(_) => { - tracing_subscriber::registry() - .with(format) - .with(LevelFilter::from_level(level.into())) - .init(); - setup_python_logging(py, level)?; - } - } - Ok(()) -} - -/// This custom logger enum exported to Python can be used to configure the -/// both the Rust `tracing` and Python `logging` levels. -/// We cannot export directly `tracing::Level` to Python. -#[pyclass] -#[derive(Debug, Clone, Copy)] -pub enum LogLevel { - Trace, - Debug, - Info, - Warn, - Error, -} + None => None, + }; -/// `From` is used to convert `LogLevel` to the correct string -/// needed by Python `logging` module. -impl From for String { - fn from(other: LogLevel) -> String { - match other { - LogLevel::Error => "ERROR".into(), - LogLevel::Warn => "WARN".into(), - LogLevel::Info => "INFO".into(), - _ => "DEBUG".into(), - } - } -} + let tracing_level = match level { + Some(40u8) => Level::ERROR, + Some(30u8) => Level::WARN, + Some(20u8) => Level::INFO, + Some(10u8) => Level::DEBUG, + None => Level::INFO, + _ => Level::TRACE, + }; -/// `From` is used to covert `tracing::EnvFilter` into `LogLevel`. 
-impl From for LogLevel { - fn from(other: String) -> LogLevel { - match other.as_str() { - "error" => LogLevel::Error, - "warn" => LogLevel::Warn, - "info" => LogLevel::Info, - "debug" => LogLevel::Debug, - _ => LogLevel::Trace, + match appender { + Some((appender, guard)) => { + let layer = Some( + fmt::Layer::new() + .with_writer(appender.with_max_level(tracing_level)) + .with_ansi(true) + .with_line_number(true) + .with_level(true), + ); + tracing_subscriber::registry().with(layer).init(); + Ok(Some(guard)) } - } -} - -/// `From` is used to covert `LogLevel` into `tracing::EnvFilter`. -impl From for Level { - fn from(other: LogLevel) -> Level { - match other { - LogLevel::Debug => Level::DEBUG, - LogLevel::Info => Level::INFO, - LogLevel::Warn => Level::WARN, - LogLevel::Error => Level::ERROR, - _ => Level::TRACE, + None => { + let layer = Some( + fmt::Layer::new() + .with_writer(std::io::stdout.with_max_level(tracing_level)) + .with_ansi(true) + .with_line_number(true) + .with_level(true), + ); + tracing_subscriber::registry().with(layer).init(); + Ok(None) } } } @@ -92,88 +83,116 @@ impl From for Level { /// Modifies the Python `logging` module to deliver its log messages using [tracing::Subscriber] events. /// /// To achieve this goal, the following changes are made to the module: -/// - A new builtin function `logging.python_tracing` transcodes `logging.LogRecord`s to `tracing::Event`s. This function +/// - A new builtin function `logging.py_tracing_event` transcodes `logging.LogRecord`s to `tracing::Event`s. This function /// is not exported in `logging.__all__`, as it is not intended to be called directly. -/// - A new class `logging.RustTracing` provides a `logging.Handler` that delivers all records to `python_tracing`. -/// - `logging.basicConfig` is changed to use `logging.HostHandler` by default. 
-/// -/// Since any call like `logging.warn(...)` sets up logging via `logging.basicConfig`, all log messages are now -/// delivered to `crate::logging`, which will send them to `tracing::event!`. -fn setup_python_logging(py: Python, level: LogLevel) -> PyResult<()> { - let logging = py.import("logging")?; - logging.setattr("python_tracing", wrap_pyfunction!(python_tracing, logging)?)?; - - let level: String = level.into(); - let pycode = format!( - r#" -class RustTracing(Handler): - """ Python logging to Rust tracing handler. """ - def __init__(self, level=0): - super().__init__(level=level) - - def emit(self, record): - python_tracing(record) +/// - A new class `logging.TracingHandler` provides a `logging.Handler` that delivers all records to `python_tracing`. +#[pyclass(name = "TracingHandler")] +#[derive(Debug)] +pub struct PyTracingHandler { + _guard: Option, +} -# Store the old basicConfig in the local namespace. -oldBasicConfig = basicConfig +#[pymethods] +impl PyTracingHandler { + #[new] + fn newpy(py: Python, level: Option, logfile: Option) -> PyResult { + let _guard = setup_tracing_subscriber(level, logfile)?; + let logging = py.import("logging")?; + let root = logging.getattr("root")?; + root.setattr("level", level)?; + // TODO(Investigate why the file appender just create the file and does not write anything, event after holding the guard) + Ok(Self { _guard }) + } -def basicConfig(*pargs, **kwargs): - """ Reimplement basicConfig to hijack the root logger. 
""" - if "handlers" not in kwargs: - kwargs["handlers"] = [RustTracing()] - kwargs["level"] = {level} - return oldBasicConfig(*pargs, **kwargs) -"#, - ); + fn handler(&self, py: Python) -> PyResult> { + let logging = py.import("logging")?; + logging.setattr( + "py_tracing_event", + wrap_pyfunction!(py_tracing_event, logging)?, + )?; - py.run(&pycode, Some(logging.dict()), None)?; - let all = logging.index()?; - all.append("RustTracing")?; - Ok(()) + let pycode = r#" +class TracingHandler(Handler): + """ Python logging to Rust tracing handler. """ + def emit(self, record): + py_tracing_event( + record.levelno, record.getMessage(), record.module, + record.filename, record.lineno, record.process + ) +"#; + py.run(pycode, Some(logging.dict()), None)?; + let all = logging.index()?; + all.append("TracingHandler")?; + let handler = logging.getattr("TracingHandler")?; + Ok(handler.call0()?.into_py(py)) + } } /// Consumes a Python `logging.LogRecord` and emits a Rust [tracing::Event] instead. #[cfg(not(test))] #[pyfunction] -#[pyo3(text_signature = "(record)")] -fn python_tracing(record: &PyAny) -> PyResult<()> { - let level = record.getattr("levelno")?; - let message = record.getattr("getMessage")?.call0()?; - let module = record.getattr("module")?; - let filename = record.getattr("filename")?; - let line = record.getattr("lineno")?; - let pid = record.getattr("process")?; - - match level.extract()? 
{ - 40u8 => tracing::event!(Level::ERROR, %pid, %module, %filename, %line, "{message}"), - 30u8 => tracing::event!(Level::WARN, %pid, %module, %filename, %line, "{message}"), - 20u8 => tracing::event!(Level::INFO, %pid, %module, %filename, %line, "{message}"), - 10u8 => tracing::event!(Level::DEBUG, %pid, %module, %filename, %line, "{message}"), - _ => tracing::event!(Level::TRACE, %pid, %module, %filename, %line, "{message}"), +#[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] +pub fn py_tracing_event( + level: u8, + message: &str, + module: &str, + filename: &str, + lineno: usize, + pid: usize, +) -> PyResult<()> { + let span = span!( + Level::TRACE, + "python", + pid = pid, + module = module, + filename = filename, + lineno = lineno + ); + println!("message2: {message}"); + let _guard = span.enter(); + match level { + 40 => tracing::error!("{message}"), + 30 => tracing::warn!("{message}"), + 20 => tracing::info!("{message}"), + 10 => tracing::debug!("{message}"), + _ => tracing::trace!("{message}"), }; - Ok(()) } #[cfg(test)] #[pyfunction] -#[pyo3(text_signature = "(record)")] -fn python_tracing(record: &PyAny) -> PyResult<()> { - let message = record.getattr("getMessage")?.call0()?; +#[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] +pub fn py_tracing_event( + _level: u8, + message: &str, + _module: &str, + _filename: &str, + _line: usize, + _pid: usize, +) -> PyResult<()> { pretty_assertions::assert_eq!(message.to_string(), "a message"); Ok(()) } #[cfg(test)] mod tests { + use pyo3::types::PyDict; + use super::*; #[test] fn tracing_handler_is_injected_in_python() { crate::tests::initialize(); Python::with_gil(|py| { - setup_python_logging(py, LogLevel::Info).unwrap(); + let handler = PyTracingHandler::newpy(py, Some(10), None).unwrap(); + let kwargs = PyDict::new(py); + kwargs + .set_item("handlers", vec![handler.handler(py).unwrap()]) + .unwrap(); let logging = py.import("logging").unwrap(); + 
+ let basic_config = logging.getattr("basicConfig").unwrap(); + basic_config.call((), Some(kwargs)).unwrap(); logging.call_method1("info", ("a message",)).unwrap(); }); } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs new file mode 100644 index 0000000000..00f99312e8 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -0,0 +1,352 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Execute Python middleware handlers. +use aws_smithy_http_server::body::Body; +use http::Request; +use pyo3::prelude::*; + +use aws_smithy_http_server::protocols::Protocol; +use pyo3_asyncio::TaskLocals; + +use crate::{PyMiddlewareException, PyRequest, PyResponse}; + +use super::PyFuture; + +#[derive(Debug, Clone, Copy)] +pub enum PyMiddlewareType { + Request, + Response, +} + +/// A Python middleware handler function representation. +/// +/// The Python business logic implementation needs to carry some information +/// to be executed properly, such as whether it is a coroutine. +#[derive(Debug, Clone)] +pub struct PyMiddlewareHandler { + pub name: String, + pub func: PyObject, + pub is_coroutine: bool, + pub _type: PyMiddlewareType, +} + +/// Structure holding the list of Python middlewares that will be executed by this server. +/// +/// Middlewares are executed one after another inside the [crate::PyMiddlewareLayer] Tower layer. +#[derive(Debug, Clone, Default)] +pub struct PyMiddlewares(Vec); + +impl PyMiddlewares { + /// Create a new instance of `PyMiddlewares` from a list of handlers. + pub fn new(handlers: Vec) -> Self { + Self(handlers) + } + + /// Add a new handler to the list. + pub fn push(&mut self, handler: PyMiddlewareHandler) { + self.0.push(handler); + } + + /// Execute a single middleware handler. 
+ /// + /// The handler is scheduled on the Python interpreter synchronously or asynchronously, + /// depending on the handler signature. + async fn execute_middleware( + request: PyRequest, + handler: PyMiddlewareHandler, + ) -> Result<(Option, Option), PyMiddlewareException> { + let handle: PyResult> = if handler.is_coroutine { + tracing::debug!("Executing Python middleware coroutine `{}`", handler.name); + let result = pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let coroutine = pyhandler.call1((request,))?; + pyo3_asyncio::tokio::into_future(coroutine) + })?; + let output = result.await?; + Ok(output) + } else { + tracing::debug!("Executing Python middleware function `{}`", handler.name); + pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let output = pyhandler.call1((request,))?; + Ok(output.into_py(py)) + }) + }; + Python::with_gil(|py| match handle { + Ok(result) => { + if let Ok(request) = result.extract::(py) { + return Ok((Some(request), None)); + } + if let Ok(response) = result.extract::(py) { + return Ok((None, Some(response))); + } + Ok((None, None)) + } + Err(e) => pyo3::Python::with_gil(|py| { + let traceback = match e.traceback(py) { + Some(t) => t.format().unwrap_or_else(|e| e.to_string()), + None => "Unknown traceback\n".to_string(), + }; + tracing::error!("{}{}", traceback, e); + let variant = e.value(py); + if let Ok(v) = variant.extract::() { + Err(v) + } else { + Err(e.into()) + } + }), + }) + } + + /// Execute all the available Python middlewares in order of registration. + /// + /// Once the response is returned by the Python interpreter, different scenarios can happen: + /// * Middleware not returning will let the execution continue to the next middleware without + /// changing the original request. 
+ /// * Middleware returning a modified [PyRequest] will update the original request before + /// continuing the execution of the next middleware. + /// * Middleware returning a [PyResponse] will immediately terminate the request handling and + /// return the response constructed from Python. + /// * Middleware raising [PyMiddlewareException] will immediately terminate the request handling + /// and return a protocol specific error, with the option of setting the HTTP return code. + /// * Middleware raising any other exception will immediately terminate the request handling and + /// return a protocol specific error, with HTTP status code 500. + pub fn run( + &mut self, + mut request: Request, + protocol: Protocol, + locals: TaskLocals, + ) -> PyFuture { + let handlers = self.0.clone(); + // Run all Python handlers in a loop. + Box::pin(async move { + tracing::debug!("Executing Python middleware stack"); + for handler in handlers { + let name = handler.name.clone(); + let pyrequest = PyRequest::new(&request); + let loop_locals = locals.clone(); + let result = pyo3_asyncio::tokio::scope( + loop_locals, + Self::execute_middleware(pyrequest, handler), + ) + .await; + match result { + Ok((pyrequest, pyresponse)) => { + if let Some(pyrequest) = pyrequest { + if let Ok(headers) = (&pyrequest.headers).try_into() { + tracing::debug!("Python middleware `{name}` returned an HTTP request, override headers with middleware's one"); + *request.headers_mut() = headers; + } + } + if let Some(pyresponse) = pyresponse { + tracing::debug!( + "Python middleware `{name}` returned a HTTP response, exit middleware loop" + ); + return Err(pyresponse.into()); + } + } + Err(e) => { + tracing::debug!( + "Middleware `{name}` returned an error, exit middleware loop" + ); + return Err(e.into_response(protocol)); + } + } + } + tracing::debug!( + "Python middleware execution finised, returning the request to operation handler" + ); + Ok(request) + }) + } +} + +#[cfg(test)] +mod tests { + use 
http::HeaderValue; + use hyper::body::to_bytes; + use pretty_assertions::assert_eq; + + use super::*; + + #[tokio::test] + async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def first_middleware(request: Request): + request.set_header("x-amzn-answer", "42") + return request + +def second_middleware(request: Request): + if request.get_header("x-amzn-answer") != "42": + raise MiddlewareException("wrong answer", 401) +"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let first_middleware = PyMiddlewareHandler { + func: middleware.getattr("first_middleware")?.into_py(py), + is_coroutine: false, + name: "first".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("first_middleware")?; + middlewares.push(first_middleware); + let second_middleware = PyMiddlewareHandler { + func: middleware.getattr("second_middleware")?.into_py(py), + is_coroutine: false, + name: "second".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("second_middleware")?; + middlewares.push(second_middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap(); + assert_eq!( + result.headers().get("x-amzn-answer"), + Some(&HeaderValue::from_static("42")) + ); + Ok(()) + } + + #[tokio::test] + async fn request_middleware_return_response() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = 
r#" +def middleware(request: Request): + return Response(200, {}, b"something")"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("middleware")?; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 200); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, "something".as_bytes()); + Ok(()) + } + + #[tokio::test] + async fn request_middleware_raise_middleware_exception() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def middleware(request: Request): + raise MiddlewareException("error", 503)"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("middleware")?; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 503); + assert_eq!( + result.headers().get("X-Amzn-Errortype"), + Some(&HeaderValue::from_static("MiddlewareException")) + ); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"error"}"#.as_bytes()); + Ok(()) 
+ } + + #[tokio::test] + async fn request_middleware_raise_python_exception() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::from_code( + py, + r#" +def middleware(request): + raise ValueError("error")"#, + "", + "", + )?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + _type: PyMiddlewareType::Request, + }; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 500); + assert_eq!( + result.headers().get("X-Amzn-Errortype"), + Some(&HeaderValue::from_static("MiddlewareException")) + ); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"ValueError: error"}"#.as_bytes()); + Ok(()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs new file mode 100644 index 0000000000..73508541a2 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -0,0 +1,251 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Tower layer implementation of Python middleware handling. +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use aws_smithy_http_server::{ + body::{Body, BoxBody}, + protocols::Protocol, +}; +use futures::{ready, Future}; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use pyo3::PyResult; +use pyo3_asyncio::TaskLocals; +use tower::{Layer, Service}; + +use crate::{error::PyException, middleware::PyFuture, PyMiddlewares}; + +/// Tower [Layer] implementation of Python middleware handling. 
+/// +/// Middleware stored in the `handlers` attribute will be executed, in order, +/// inside an async Tower middleware. +#[derive(Debug, Clone)] +pub struct PyMiddlewareLayer { + handlers: PyMiddlewares, + protocol: Protocol, + locals: TaskLocals, +} + +impl PyMiddlewareLayer { + pub fn new( + handlers: PyMiddlewares, + protocol: &str, + locals: TaskLocals, + ) -> PyResult { + let protocol = match protocol { + "aws.protocols#restJson1" => Protocol::RestJson1, + "aws.protocols#restXml" => Protocol::RestXml, + "aws.protocols#awsjson10" => Protocol::AwsJson10, + "aws.protocols#awsjson11" => Protocol::AwsJson11, + _ => { + return Err(PyException::new_err(format!( + "Protocol {protocol} is not supported" + ))) + } + }; + Ok(Self { + handlers, + protocol, + locals, + }) + } +} + +impl Layer for PyMiddlewareLayer { + type Service = PyMiddlewareService; + + fn layer(&self, inner: S) -> Self::Service { + PyMiddlewareService::new( + inner, + self.handlers.clone(), + self.protocol, + self.locals.clone(), + ) + } +} + +// Tower [Service] wrapping the Python middleware [Layer]. +#[derive(Clone, Debug)] +pub struct PyMiddlewareService { + inner: S, + handlers: PyMiddlewares, + protocol: Protocol, + locals: TaskLocals, +} + +impl PyMiddlewareService { + pub fn new( + inner: S, + handlers: PyMiddlewares, + protocol: Protocol, + locals: TaskLocals, + ) -> PyMiddlewareService { + Self { + inner, + handlers, + protocol, + locals, + } + } +} + +impl Service> for PyMiddlewareService +where + S: Service, Response = Response> + Clone, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // TODO(Should we make this clone less expensive by wrapping inner in a Arc?) 
+ let clone = self.inner.clone(); + // See https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let inner = std::mem::replace(&mut self.inner, clone); + let run = self.handlers.run(req, self.protocol, self.locals.clone()); + + ResponseFuture { + middleware: State::Running { run }, + service: inner, + } + } +} + +pin_project! { + /// Response future handling the state transition between a running and a done future. + pub struct ResponseFuture + where + S: Service>, + { + #[pin] + middleware: State, + service: S, + } +} + +pin_project! { + /// Representation of the result of the middleware execution. + #[project = StateProj] + enum State { + Running { + #[pin] + run: A, + }, + Done { + #[pin] + fut: Fut + } + } +} + +impl Future for ResponseFuture +where + S: Service, Response = Response>, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + match this.middleware.as_mut().project() { + // Run the handler and store the future inside the inner state. + StateProj::Running { run } => { + let run = ready!(run.poll(cx)); + match run { + Ok(req) => { + let fut = this.service.call(req); + this.middleware.set(State::Done { fut }); + } + Err(res) => return Poll::Ready(Ok(res)), + } + } + // Execute the future returned by the layer. 
+ StateProj::Done { fut } => return fut.poll(cx), + } + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error; + + use super::*; + + use aws_smithy_http_server::body::to_boxed; + use pyo3::prelude::*; + use tower::{Service, ServiceBuilder, ServiceExt}; + + use crate::middleware::PyMiddlewareHandler; + use crate::{PyMiddlewareException, PyMiddlewareType, PyRequest}; + + async fn echo(req: Request) -> Result, Box> { + Ok(Response::new(to_boxed(req.into_body()))) + } + + #[tokio::test] + async fn request_middlewares_are_chained_inside_layer() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares::new(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def first_middleware(request: Request): + request.set_header("x-amzn-answer", "42") + return request + +def second_middleware(request: Request): + if request.get_header("x-amzn-answer") != "42": + raise MiddlewareException("wrong answer", 401) +"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let first_middleware = PyMiddlewareHandler { + func: middleware.getattr("first_middleware")?.into_py(py), + is_coroutine: false, + name: "first".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("first_middleware")?; + middlewares.push(first_middleware); + let second_middleware = PyMiddlewareHandler { + func: middleware.getattr("second_middleware")?.into_py(py), + is_coroutine: false, + name: "second".to_string(), + _type: PyMiddlewareType::Request, + }; + all.append("second_middleware")?; + middlewares.push(second_middleware); + Ok::<(), PyErr>(()) + })?; + + let mut service = ServiceBuilder::new() + .layer(PyMiddlewareLayer::new( + middlewares, + "aws.protocols#restJson1", + locals, + )?) 
+ .service_fn(echo); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), 200); + Ok(()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs new file mode 100644 index 0000000000..a1a2d14ced --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -0,0 +1,23 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Schedule pure Python middlewares as `Tower` layers. +mod handler; +mod layer; +mod request; +mod response; + +use aws_smithy_http_server::body::{Body, BoxBody}; +use futures::future::BoxFuture; +use http::{Request, Response}; + +pub use self::handler::{PyMiddlewareType, PyMiddlewares}; +pub use self::layer::PyMiddlewareLayer; +pub use self::request::{PyHttpVersion, PyRequest}; +pub use self::response::PyResponse; + +pub(crate) use self::handler::PyMiddlewareHandler; +/// Future type returned by the Python middleware handler. +pub(crate) type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs new file mode 100644 index 0000000000..467d7dbb7b --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -0,0 +1,123 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Python-compatible middleware [http::Request] implementation. +use std::collections::HashMap; + +use aws_smithy_http_server::body::Body; +use http::{Request, Version}; +use pyo3::prelude::*; + +/// Python compabible HTTP [Version]. 
+#[pyclass(name = "HttpVersion")] +#[derive(PartialEq, PartialOrd, Copy, Clone, Eq, Ord, Hash)] +pub struct PyHttpVersion(Version); + +#[pymethods] +impl PyHttpVersion { + /// Extract the value of the HTTP [Version] into a string that + /// can be used by Python. + #[pyo3(text_signature = "($self)")] + fn value(&self) -> &str { + match self.0 { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + Version::HTTP_3 => "HTTP/3.0", + _ => unreachable!(), + } + } +} + +/// Python-compatible [Request] object. +/// +/// For performance reasons, there is no support yet to pass the body to the Python middleware, +/// as it requires consuming and cloning the body, which is a very expensive operation. +/// +/// TODO(if customers request it, we can implement an opt-in functionality to also pass +/// the body around). +#[pyclass(name = "Request")] +#[pyo3(text_signature = "(request)")] +#[derive(Debug, Clone)] +pub struct PyRequest { + #[pyo3(get, set)] + method: String, + #[pyo3(get, set)] + uri: String, + // TODO(investigate if using a PyDict can make the experience more idiomatic) + // I'd like to be able to do request.headers.get("my-header") and + // request.headers["my-header"] = 42 instead of implementing set_header() and get_header() + // under pymethods. The same applies to response. + pub(crate) headers: HashMap, + version: Version, +} + +impl PyRequest { + /// Create a new Python-compatible [Request] structure from the Rust side. + /// + /// This is done by cloning the headers, method, URI and HTTP version to let them be owned by Python. 
+ pub fn new(request: &Request) -> Self { + Self { + method: request.method().to_string(), + uri: request.uri().to_string(), + headers: request + .headers() + .into_iter() + .map(|(k, v)| -> (String, String) { + let name: String = k.as_str().to_string(); + let value: String = String::from_utf8_lossy(v.as_bytes()).to_string(); + (name, value) + }) + .collect(), + version: request.version(), + } + } +} + +#[pymethods] +impl PyRequest { + #[new] + /// Create a new Python-compatible `Request` object from the Python side. + fn newpy( + method: String, + uri: String, + headers: Option>, + version: Option, + ) -> Self { + let version = version.map(|v| v.0).unwrap_or(Version::HTTP_11); + Self { + method, + uri, + headers: headers.unwrap_or_default(), + version, + } + } + + /// Return the HTTP version of this request. + #[pyo3(text_signature = "($self)")] + fn version(&self) -> String { + PyHttpVersion(self.version).value().to_string() + } + + /// Return the HTTP headers of this request. + /// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) + #[pyo3(text_signature = "($self)")] + fn headers(&self) -> HashMap { + self.headers.clone() + } + + /// Insert a new key/value into this request's headers. + #[pyo3(text_signature = "($self, key, value)")] + fn set_header(&mut self, key: &str, value: &str) { + self.headers.insert(key.to_string(), value.to_string()); + } + + /// Return a header value of this request. + #[pyo3(text_signature = "($self, key)")] + fn get_header(&self, key: &str) -> Option<&String> { + self.headers.get(key) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs new file mode 100644 index 0000000000..773fe76327 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -0,0 +1,79 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +//! Python-compatible middleware [http::Response] implementation. +use std::{collections::HashMap, convert::TryInto}; + +use aws_smithy_http_server::body::{to_boxed, BoxBody}; +use http::{Response, StatusCode}; +use pyo3::prelude::*; + +/// Python-compatible [Response] object. +/// +/// For performance reasons, there is no support yet to pass the body to the Python middleware, +/// as it requires consuming and cloning the body, which is a very expensive operation. +/// +// TODO(if customers request it, we can implement an opt-in functionality to also pass +// the body around). +#[pyclass(name = "Response")] +#[pyo3(text_signature = "(status, headers, body)")] +#[derive(Debug, Clone)] +pub struct PyResponse { + #[pyo3(get, set)] + status: u16, + #[pyo3(get, set)] + body: Vec, + headers: HashMap, +} + +#[pymethods] +impl PyResponse { + /// Create a new Python-compatible [Response] object from the Python side. + #[new] + fn newpy(status: u16, headers: Option>, body: Option>) -> Self { + Self { + status, + body: body.unwrap_or_default(), + headers: headers.unwrap_or_default(), + } + } + + /// Return the HTTP headers of this response. + // TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) + #[pyo3(text_signature = "($self)")] + fn headers(&self) -> HashMap { + self.headers.clone() + } + + /// Insert a new key/value into this response's headers. + #[pyo3(text_signature = "($self, key, value)")] + fn set_header(&mut self, key: &str, value: &str) { + self.headers.insert(key.to_string(), value.to_string()); + } + + /// Return a header value of this response. + #[pyo3(text_signature = "($self, key)")] + fn get_header(&self, key: &str) -> Option<&String> { + self.headers.get(key) + } +} + +/// Allow converting a [PyResponse] into a [Response]. 
+impl From for Response { + fn from(pyresponse: PyResponse) -> Self { + let mut response = Response::builder() + .status( + StatusCode::from_u16(pyresponse.status) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + ) + .body(to_boxed(pyresponse.body)) + .unwrap_or_default(); + match (&pyresponse.headers).try_into() { + Ok(headers) => *response.headers_mut() = headers, + Err(e) => tracing::error!("Error extracting HTTP headers from PyResponse: {e}"), + }; + response + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 3f1dac23ce..94680ce3bf 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::PySocket; +use crate::{middleware::PyMiddlewareHandler, PyMiddlewareType, PyMiddlewares, PySocket}; /// A Python handler function representation. /// @@ -61,6 +61,10 @@ pub trait PyApp: Clone + pyo3::IntoPy { /// Mapping between operation names and their `PyHandler` representation. fn handlers(&mut self) -> &mut HashMap; + fn middlewares(&mut self) -> &mut PyMiddlewares; + + fn protocol(&self) -> &'static str; + /// Handle the graceful termination of Python workers by looping through all the /// active workers and calling `terminate()` on them. If termination fails, this /// method will try to `kill()` any failed worker. @@ -144,7 +148,7 @@ pub trait PyApp: Clone + pyo3::IntoPy { self.graceful_termination(self.workers()); } _ => { - tracing::warn!("Signal {sig:?} is ignored by this application"); + tracing::debug!("Signal {sig:?} is ignored by this application"); } } } @@ -162,6 +166,10 @@ import functools import signal async def shutdown(sig, event_loop): + # reimport asyncio and logging to be sure they are available when + # this handler runs on signal catching. 
+ import asyncio + import logging logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop") tasks = [task for task in asyncio.all_tasks() if task is not asyncio.current_task()] @@ -252,6 +260,45 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } + // Check if a Python function is a coroutine. Since the function has not run yet, + // we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. + fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult { + let inspect = py.import("inspect")?; + // NOTE: that `asyncio.iscoroutine()` doesn't work here. + inspect + .call_method1("iscoroutinefunction", (func,))? + .extract::() + } + + /// Register a Python function to be executed inside a Tower middleware layer. + /// + /// Some information is needed to execute the Python code from a Rust handler, + /// such as whether the registered function needs to be awaited (if it is a coroutine). + fn register_middleware( + &mut self, + py: Python, + func: PyObject, + _type: PyMiddlewareType, + ) -> PyResult<()> { + let name = func.getattr(py, "__name__")?.extract::(py)?; + let is_coroutine = self.is_coroutine(py, &func)?; + // Find number of expected methods (a Python implementation could not accept the context). + let handler = PyMiddlewareHandler { + name, + func, + is_coroutine, + _type, + }; + tracing::info!( + "Registering middleware function `{}`, coroutine: {}, type: {:?}", + handler.name, + handler.is_coroutine, + handler._type + ); + self.middlewares().push(handler); + Ok(()) + } + /// Register a Python function to be executed inside the Smithy Rust handler. /// /// There are some information needed to execute the Python code from a Rust handler, @@ -259,13 +306,9 @@ event_loop.add_signal_handler(signal.SIGINT, /// the number of arguments available, which tells us if the handler wants the state to be /// passed or not. 
fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> { + let is_coroutine = self.is_coroutine(py, &func)?; + // Find number of expected methods (a Python implementation could not accept the context). let inspect = py.import("inspect")?; - // Check if the function is a coroutine. - // NOTE: that `asyncio.iscoroutine()` doesn't work here. - let is_coroutine = inspect - .call_method1("iscoroutinefunction", (&func,))? - .extract::()?; - // Find number of expected methods (a Pythzzon implementation could not accept the context). let func_args = inspect .call_method1("getargs", (func.getattr(py, "__code__")?,))? .getattr("args")? @@ -276,12 +319,12 @@ event_loop.add_signal_handler(signal.SIGINT, args: func_args.len(), }; tracing::info!( - "Registering function `{name}`, coroutine: {}, arguments: {}", + "Registering handler function `{name}`, coroutine: {}, arguments: {}", handler.is_coroutine, handler.args, ); // Insert the handler in the handlers map. - self.handlers().insert(String::from(name), handler); + self.handlers().insert(name.to_string(), handler); Ok(()) } @@ -326,7 +369,7 @@ event_loop.add_signal_handler(signal.SIGINT, /// ```no_run /// use std::collections::HashMap; /// use pyo3::prelude::*; - /// use aws_smithy_http_server_python::{PyApp, PyHandler}; + /// use aws_smithy_http_server_python::{PyApp, PyHandler, PyMiddlewares}; /// use parking_lot::Mutex; /// /// #[pyclass] @@ -341,6 +384,8 @@ event_loop.add_signal_handler(signal.SIGINT, /// fn workers(&self) -> &Mutex> { todo!() } /// fn context(&self) -> &Option { todo!() } /// fn handlers(&mut self) -> &mut HashMap { todo!() } + /// fn middlewares(&mut self) -> &mut PyMiddlewares { todo!() } + /// fn protocol(&self) -> &'static str { "proto1" } /// } /// /// #[pymethods] diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index 9e288badbc..98adbd4119 100644 --- 
a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -22,7 +22,7 @@ use pyo3::{ use tokio::sync::Mutex; use tokio_stream::StreamExt; -use crate::Error; +use crate::PyError; /// Python Wrapper for [aws_smithy_types::Blob]. #[pyclass] @@ -152,7 +152,7 @@ impl DateTime { pub fn from_nanos(epoch_nanos: i128) -> PyResult { Ok(Self( aws_smithy_types::date_time::DateTime::from_nanos(epoch_nanos) - .map_err(Error::DateTimeConversion)?, + .map_err(PyError::DateTimeConversion)?, )) } @@ -160,7 +160,7 @@ impl DateTime { #[staticmethod] pub fn read(s: &str, format: Format, delim: char) -> PyResult<(Self, &str)> { let (self_, next) = aws_smithy_types::date_time::DateTime::read(s, format.into(), delim) - .map_err(Error::DateTimeParse)?; + .map_err(PyError::DateTimeParse)?; Ok((Self(self_), next)) } @@ -195,7 +195,7 @@ impl DateTime { pub fn from_str(s: &str, format: Format) -> PyResult { Ok(Self( aws_smithy_types::date_time::DateTime::from_str(s, format.into()) - .map_err(Error::DateTimeParse)?, + .map_err(PyError::DateTimeParse)?, )) } @@ -226,7 +226,7 @@ impl DateTime { /// Converts the `DateTime` to the number of milliseconds since the Unix epoch. pub fn to_millis(&self) -> PyResult { - Ok(self.0.to_millis().map_err(Error::DateTimeConversion)?) + Ok(self.0.to_millis().map_err(PyError::DateTimeConversion)?) } }