Skip to content

Commit

Permalink
[Python] Allow to run pure Python request middlewares inside a Tower …
Browse files Browse the repository at this point in the history
…service (#1734)

## Motivation and Context
* Customers want to be able to implement simple middleware directly in Python. This PR aims to add the initial support for it.
* Improve the idiomatic experience of logging by exposing a handler compatible with Python's standard library `logging` module.

## Description
### Middleware
A middleware is defined as a sync or async Python function that can return multiple values, following these rules:

* Middleware not returning will let the execution continue without changing the original request.
* Middleware returning a modified Request will update the original request before continuing the execution.
* Middleware returning a Response will immediately terminate the request handling and return the response constructed from Python.
* Middleware raising MiddlewareException will immediately terminate the request handling and return a protocol specific error, with the option of setting the HTTP return code.
* Middleware raising any other exception will immediately terminate the request handling and return a protocol specific error, with HTTP status code 500.

Middlewares are registered into the Python application and executed in order of registration.

Example:

from sdk import App
from sdk.middleware import Request, MiddlewareException

app = App()

@app.request_middleware
def inject_header(request: Request):
    request.set_header("x-amzn-answer", "42")
    return request

@app.request_middleare
def check_header(request: Request):
    if request.get_header("x-amzn-answer") != "42":
        raise MiddlewareException("Wrong answer", 404)

@app.request_middleware
def dump_headers(request: Request):
    logging.debug(f"Request headers after middlewares: {request.headers()}")

**NOTE: this PR only adds support for request middlewares, which are executed before the operation handler. Response middlewares, executed after the operation are tracked here: #1754

### Logging
To improve the idiomatic experience, now logging need to be configured from the Python side by using the standard `logging` module. This allows customers to opt-out of our `tracing` based logging implementation and use their own and logging level is now driven by Python.

import logging
from sdk.logging import TracingHandler

logging.basicConfig(level=logging.INFO, handlers=[TracingHandler.handle()])

Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com>
Co-authored-by: Burak <burakvar@amazon.co.uk>
  • Loading branch information
crisidev and unexge authored Sep 22, 2022
1 parent 997beeb commit e5c8cf3
Show file tree
Hide file tree
Showing 16 changed files with 1,319 additions and 187 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
# references = ["smithy-rs#920"]
# meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"}
# author = "rcoh"

[[rust-runtime]]
message = "Pokémon Service example code now runs clippy during build."
references = ["smithy-rs#1727"]
meta = { "breaking" = false, "tada" = false, "bug" = false }
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" }
author = "GeneralSwiss"

[[smithy-rs]]
message = "Implement support for pure Python request middleware. Improve idiomatic logging support over tracing."
references = ["smithy-rs#1734"]
meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" }
author = "crisidev"
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency
* Example:
* from pool import DatabasePool
* from my_library import App, OperationInput, OperationOutput
* @dataclass
* class Context:
* db = DatabasePool()
Expand Down Expand Up @@ -69,6 +68,7 @@ class PythonApplicationGenerator(
private val libName = "lib${coreCodegenContext.settings.moduleName.toSnakeCase()}"
private val runtimeConfig = coreCodegenContext.runtimeConfig
private val model = coreCodegenContext.model
private val protocol = coreCodegenContext.protocol
private val codegenScope =
arrayOf(
"SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(),
Expand Down Expand Up @@ -101,6 +101,7 @@ class PythonApplicationGenerator(
##[derive(Debug, Default)]
pub struct App {
handlers: #{HashMap}<String, #{SmithyPython}::PyHandler>,
middlewares: #{SmithyPython}::PyMiddlewares,
context: Option<#{pyo3}::PyObject>,
workers: #{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>>,
}
Expand All @@ -116,6 +117,7 @@ class PythonApplicationGenerator(
fn clone(&self) -> Self {
Self {
handlers: self.handlers.clone(),
middlewares: self.middlewares.clone(),
context: self.context.clone(),
workers: #{parking_lot}::Mutex::new(vec![]),
}
Expand Down Expand Up @@ -151,7 +153,7 @@ class PythonApplicationGenerator(
val name = operationName.toSnakeCase()
rustTemplate(
"""
let ${name}_locals = pyo3_asyncio::TaskLocals::new(event_loop);
let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop);
let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone();
let router = router.$name(move |input, state| {
#{pyo3_asyncio}::tokio::scope(${name}_locals, crate::operation_handler::$name(input, state, handler))
Expand All @@ -162,11 +164,20 @@ class PythonApplicationGenerator(
}
rustTemplate(
"""
let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop);
use #{SmithyPython}::PyApp;
let service = #{tower}::ServiceBuilder::new().layer(
#{SmithyPython}::PyMiddlewareLayer::new(
self.middlewares.clone(),
self.protocol(),
middleware_locals
)?,
);
let router: #{SmithyServer}::routing::Router = router
.build()
.expect("Unable to build operation registry")
.into();
Ok(router)
Ok(router.layer(service))
""",
*codegenScope,
)
Expand All @@ -175,20 +186,25 @@ class PythonApplicationGenerator(
}

private fun renderPyAppTrait(writer: RustWriter) {
val protocol = protocol.toString().replace("#", "##")
writer.rustTemplate(
"""
impl #{SmithyPython}::PyApp for App {
fn workers(&self) -> &#{parking_lot}::Mutex<Vec<#{pyo3}::PyObject>> {
&self.workers
}
fn context(&self) -> &Option<#{pyo3}::PyObject> {
&self.context
}
fn handlers(&mut self) -> &mut #{HashMap}<String, #{SmithyPython}::PyHandler> {
&mut self.handlers
}
fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares {
&mut self.middlewares
}
fn protocol(&self) -> &'static str {
"$protocol"
}
}
""",
*codegenScope,
Expand All @@ -207,16 +223,20 @@ class PythonApplicationGenerator(
"""
/// Create a new [App].
##[new]
pub fn new(py: #{pyo3}::Python, log_level: Option<#{SmithyPython}::LogLevel>) -> #{pyo3}::PyResult<Self> {
let log_level = log_level.unwrap_or(#{SmithyPython}::LogLevel::Info);
#{SmithyPython}::logging::setup(py, log_level)?;
Ok(Self::default())
pub fn new() -> Self {
Self::default()
}
/// Register a context object that will be shared between handlers.
##[pyo3(text_signature = "(${'$'}self, context)")]
pub fn context(&mut self, context: #{pyo3}::PyObject) {
self.context = Some(context);
}
/// Register a request middleware function that will be run inside a Tower layer, without cloning the body.
##[pyo3(text_signature = "(${'$'}self, func)")]
pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> {
use #{SmithyPython}::PyApp;
self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request)
}
/// Main entrypoint: start the server on multiple workers.
##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")]
pub fn run(
Expand All @@ -235,7 +255,7 @@ class PythonApplicationGenerator(
pub fn start_worker(
&mut self,
py: pyo3::Python,
socket: &pyo3::PyCell<aws_smithy_http_server_python::PySocket>,
socket: &pyo3::PyCell<#{SmithyPython}::PySocket>,
worker_number: isize,
) -> pyo3::PyResult<()> {
use #{SmithyPython}::PyApp;
Expand Down Expand Up @@ -280,21 +300,17 @@ class PythonApplicationGenerator(
""".trimIndent(),
)
writer.rust(
if (operations.any { it.errors.isNotEmpty() }) {
"""
/// from $libName import ${Inputs.namespace}
/// from $libName import ${Outputs.namespace}
/// from $libName import ${Errors.namespace}
""".trimIndent()
} else {
"""
/// from $libName import ${Inputs.namespace}
/// from $libName import ${Outputs.namespace}
""".trimIndent()
},
"""
/// from $libName import ${Inputs.namespace}
/// from $libName import ${Outputs.namespace}
""".trimIndent(),
)
if (operations.any { it.errors.isNotEmpty() }) {
writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent())
}
writer.rust(
"""
/// from $libName import middleware
/// from $libName import App
///
/// @dataclass
Expand All @@ -304,6 +320,11 @@ class PythonApplicationGenerator(
/// app = App()
/// app.context(Context())
///
/// @app.request_middleware
/// def request_middleware(request: middleware::Request):
/// if request.get_header("x-amzn-id") != "secret":
/// raise middleware.MiddlewareException("Unsupported `x-amz-id` header", 401)
///
""".trimIndent(),
)
writer.operationImplementationStubs(operations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class PythonServerModuleGenerator(
renderPyCodegeneratedTypes()
renderPyWrapperTypes()
renderPySocketType()
renderPyLogging()
renderPyMiddlewareTypes()
renderPyApplicationType()
}
}
Expand Down Expand Up @@ -125,6 +127,43 @@ class PythonServerModuleGenerator(
)
}

// Render Python shared socket type.
private fun RustWriter.renderPyLogging() {
rustTemplate(
"""
let logging = #{pyo3}::types::PyModule::new(py, "logging")?;
logging.add_function(#{pyo3}::wrap_pyfunction!(#{SmithyPython}::py_tracing_event, m)?)?;
logging.add_class::<#{SmithyPython}::PyTracingHandler>()?;
#{pyo3}::py_run!(
py,
logging,
"import sys; sys.modules['$libName.logging'] = logging"
);
m.add_submodule(logging)?;
""",
*codegenScope,
)
}

private fun RustWriter.renderPyMiddlewareTypes() {
rustTemplate(
"""
let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?;
middleware.add_class::<#{SmithyPython}::PyRequest>()?;
middleware.add_class::<#{SmithyPython}::PyResponse>()?;
middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?;
middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?;
pyo3::py_run!(
py,
middleware,
"import sys; sys.modules['$libName.middleware'] = middleware"
);
m.add_submodule(middleware)?;
""",
*codegenScope,
)
}

// Render Python application type.
private fun RustWriter.renderPyApplicationType() {
rustTemplate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,14 @@ class PythonServerOperationHandlerGenerator(
rustTemplate(
"""
#{tracing}::debug!("Executing Python handler function `$name()`");
#{tokio}::task::block_in_place(move || {
#{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let output = if handler.args == 1 {
pyhandler.call1((input,))?
} else {
pyhandler.call1((input, state.0))?
};
output.extract::<$output>()
})
#{pyo3}::Python::with_gil(|py| {
let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?;
let output = if handler.args == 1 {
pyhandler.call1((input,))?
} else {
pyhandler.call1((input, state.0))?
};
output.extract::<$output>()
})
""",
*codegenScope,
Expand Down
10 changes: 8 additions & 2 deletions rust-runtime/aws-smithy-http-server-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,33 @@ Python server runtime for Smithy Rust Server Framework.
publish = true

[dependencies]
aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-http-server = { path = "../aws-smithy-http-server" }
aws-smithy-json = { path = "../aws-smithy-json" }
aws-smithy-types = { path = "../aws-smithy-types" }
aws-smithy-http = { path = "../aws-smithy-http" }
aws-smithy-xml = { path = "../aws-smithy-xml" }
bytes = "1.2"
futures = "0.3"
http = "0.2"
hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] }
num_cpus = "1.13.1"
parking_lot = "0.12.1"
pin-project-lite = "0.2"
pyo3 = "0.16.5"
pyo3-asyncio = { version = "0.16.0", features = ["tokio-runtime"] }
signal-hook = { version = "0.3.14", features = ["extended-siginfo"] }
socket2 = { version = "0.4.4", features = ["all"] }
thiserror = "1.0.32"
tokio = { version = "1.20.1", features = ["full"] }
tokio-stream = "0.1"
tower = "0.4.13"
tower = { version = "0.4.13", features = ["util"] }
tracing = "0.1.36"
tracing-subscriber = { version = "0.3.15", features = ["env-filter"] }
tracing-appender = { version = "0.2.2"}

[dev-dependencies]
pretty_assertions = "1"
futures-util = "0.3"

[package.metadata.docs.rs]
all-features = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,25 @@
from typing import List, Optional

import aiohttp

from libpokemon_service_server_sdk import App
from libpokemon_service_server_sdk.error import ResourceNotFoundException
from libpokemon_service_server_sdk.input import (
EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput,
HealthCheckOperationInput, StreamPokemonRadioOperationInput)
from libpokemon_service_server_sdk.logging import TracingHandler
from libpokemon_service_server_sdk.middleware import (MiddlewareException,
Request)
from libpokemon_service_server_sdk.model import FlavorText, Language
from libpokemon_service_server_sdk.output import (
EmptyOperationOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput,
HealthCheckOperationOutput, StreamPokemonRadioOperationOutput)
from libpokemon_service_server_sdk.types import ByteStream

# Logging can bee setup using standard Python tooling. We provide
# fast logging handler, Tracingandler based on Rust tracing crate.
logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()])


# A slightly more atomic counter using a threading lock.
class FastWriteCounter:
Expand Down Expand Up @@ -111,6 +119,55 @@ def get_random_radio_stream(self) -> str:
app.context(Context())


###########################################################
# Middleware
############################################################
# Middlewares are sync or async function decorated by `@app.middleware`.
# They are executed in order and take as input the HTTP request object.
# A middleware can return multiple values, following these rules:
# * Middleware not returning will let the execution continue without
# changing the original request.
# * Middleware returning a modified Request will update the original
# request before continuing the execution.
# * Middleware returning a Response will immediately terminate the request
# handling and return the response constructed from Python.
# * Middleware raising MiddlewareException will immediately terminate the
# request handling and return a protocol specific error, with the option of
# setting the HTTP return code.
# * Middleware raising any other exception will immediately terminate the
# request handling and return a protocol specific error, with HTTP status
# code 500.
@app.request_middleware
def check_content_type_header(request: Request):
content_type = request.get_header("content-type")
if content_type == "application/json":
logging.debug("Found valid `application/json` content type")
else:
logging.warning(
f"Invalid content type {content_type}, dumping headers: {request.headers()}"
)


# This middleware adds a new header called `x-amzn-answer` to the
# request. We expect to see this header to be populated in the next
# middleware.
@app.request_middleware
def add_x_amzn_answer_header(request: Request):
request.set_header("x-amzn-answer", "42")
logging.debug("Setting `x-amzn-answer` header to 42")
return request


# This middleware checks if the header `x-amzn-answer` is correctly set
# to 42, otherwise it returns an exception with a set status code.
@app.request_middleware
async def check_x_amzn_answer_header(request: Request):
# Check that `x-amzn-answer` is 42.
if request.get_header("x-amzn-answer") != "42":
# Return an HTTP 401 Unauthorized if the content type is not JSON.
raise MiddlewareException("Invalid answer", 401)


###########################################################
# App handlers definition
###########################################################
Expand All @@ -131,6 +188,7 @@ def get_pokemon_species(
if flavor_text_entries:
logging.debug("Total requests executed: %s", context.get_calls_count())
logging.info("Found description for Pokémon %s", input.name)
logging.error("Found some stuff")
return GetPokemonSpeciesOutput(
name=input.name, flavor_text_entries=flavor_text_entries
)
Expand Down
Loading

0 comments on commit e5c8cf3

Please sign in to comment.