From ef5a833e1a8445d5f6b535b350907f08ac0e4f53 Mon Sep 17 00:00:00 2001 From: Santiago Fraire Willemoes Date: Mon, 6 Jan 2025 14:32:46 +0100 Subject: [PATCH] fix: improve DI performance avoids allocating a dict each time the middleware is called --- kstreams/middleware/udf_middleware.py | 32 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/kstreams/middleware/udf_middleware.py b/kstreams/middleware/udf_middleware.py index 7ba4d2d..12b48f4 100644 --- a/kstreams/middleware/udf_middleware.py +++ b/kstreams/middleware/udf_middleware.py @@ -15,6 +15,18 @@ async def anext(async_gen: typing.AsyncGenerator): class UdfHandler(BaseMiddleware): + """User Defined Function Handler Middleware + + Manages dependency injection for user defined functions (UDFs) that are + defined as coroutines. The UDFs are defined by the user and are passed + to the stream engine to be executed when a consumer record is received. + + The UDFs can have different signatures and the middleware is responsible + for managing the dependency injection for the UDFs. + + UdfHandler tries to stay small and performant. + """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) signature = inspect.signature(self.next_call) @@ -23,20 +35,17 @@ def __init__(self, *args, **kwargs) -> None: for param in signature.parameters.values() ] self.type: UDFType = setup_type(self.params) - - def get_type(self) -> UDFType: - return self.type - - def bind_udf_params(self, cr: types.ConsumerRecord) -> typing.List: - # NOTE: When `no typing` support is deprecated then this can - # be more eficient as the CR will be always there. 
- ANNOTATIONS_TO_PARAMS = { - types.ConsumerRecord: cr, + self.annotations_to_params: dict[type, typing.Any] = { + types.ConsumerRecord: None, Stream: self.stream, types.Send: self.send, } - return [ANNOTATIONS_TO_PARAMS[param_type] for param_type in self.params] + def get_type(self) -> UDFType: + return self.type + + def bind_cr(self, cr: types.ConsumerRecord) -> None: + self.annotations_to_params[types.ConsumerRecord] = cr async def __call__(self, cr: types.ConsumerRecord) -> typing.Any: """ @@ -58,7 +67,8 @@ async def consume(cr: ConsumerRecord): async def consume(cr: ConsumerRecord, stream: Stream): ... """ - params = self.bind_udf_params(cr) + self.bind_cr(cr) + params = [self.annotations_to_params[param_type] for param_type in self.params] if inspect.isasyncgenfunction(self.next_call): return await anext(self.next_call(*params))