From e2e4488bd0922e13767da1ded012ad2a910409c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Oldag?= Date: Mon, 15 May 2023 15:56:15 +0200 Subject: [PATCH] Allow to create request scopes outside of the middleware --- fastapi_injector/__init__.py | 2 ++ fastapi_injector/request_scope.py | 42 ++++++++++++++++++++++++------- tests/test_request_scope.py | 23 +++++++++++++++++ 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/fastapi_injector/__init__.py b/fastapi_injector/__init__.py index 98260c7..14784eb 100644 --- a/fastapi_injector/__init__.py +++ b/fastapi_injector/__init__.py @@ -9,6 +9,7 @@ from fastapi_injector.request_scope import ( InjectorMiddleware, RequestScope, + RequestScopeFactory, request_scope, ) @@ -20,5 +21,6 @@ "InjectorNotAttached", "request_scope", "RequestScope", + "RequestScopeFactory", "InjectorMiddleware", ] diff --git a/fastapi_injector/request_scope.py b/fastapi_injector/request_scope.py index 0fa015e..99c1ef3 100644 --- a/fastapi_injector/request_scope.py +++ b/fastapi_injector/request_scope.py @@ -1,8 +1,17 @@ import uuid +from contextlib import contextmanager from contextvars import ContextVar from typing import Any, Dict, Type -from injector import Injector, InstanceProvider, Provider, Scope, ScopeDecorator, T +from injector import ( + Inject, + Injector, + InstanceProvider, + Provider, + Scope, + ScopeDecorator, + T, +) from starlette.types import Receive, Send from fastapi_injector.exceptions import RequestScopeError @@ -66,6 +75,27 @@ def clear_key(self, key: uuid.UUID) -> None: request_scope = ScopeDecorator(RequestScope) +class RequestScopeFactory: + """ + Allows to create request scopes. + """ + + def __init__(self, request_scope_instance: Inject[RequestScope]) -> None: + self.request_scope_instance = request_scope_instance + + @contextmanager + def create_scope(self): + """Creates a new request scope within dependencies are cached.""" + rid = uuid.uuid4() + rid_ctx = _request_id_ctx.set(rid) + self.request_scope_instance.add_key(rid) + try: + yield + finally: + self.request_scope_instance.clear_key(rid) + _request_id_ctx.reset(rid_ctx) + + class InjectorMiddleware: """ Middleware that enables request-scoped injection through ContextVar-based cache. @@ -73,18 +103,12 @@ class InjectorMiddleware: def __init__(self, app, injector: Injector) -> None: self.app = app - self.request_scope_instance = injector.get(RequestScope) + self.request_scope_factory = injector.get(RequestScopeFactory) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ Add an identifier to the request that can be used retrieve scoped dependencies. """ - rid = uuid.uuid4() - rid_ctx = _request_id_ctx.set(rid) - self.request_scope_instance.add_key(rid) - try: + with self.request_scope_factory.create_scope(): await self.app(scope, receive, send) - finally: - self.request_scope_instance.clear_key(rid) - _request_id_ctx.reset(rid_ctx) diff --git a/tests/test_request_scope.py b/tests/test_request_scope.py index 8c28b4b..c70197d 100644 --- a/tests/test_request_scope.py +++ b/tests/test_request_scope.py @@ -14,6 +14,7 @@ Injected, InjectorMiddleware, RequestScope, + RequestScopeFactory, attach_injector, request_scope, ) @@ -209,3 +210,25 @@ def get_root( await client.get("/") assert len(scope_instance.cache) == 0 + + +async def test_caches_instances_with_scope_factory(): + class DummyInterface: + pass + + class DummyImpl: + pass + + inj = Injector() + inj.binder.bind(DummyInterface, to=DummyImpl, scope=request_scope) + + factory = inj.get(RequestScopeFactory) + + with factory.create_scope(): + dummy1 = inj.get(DummyInterface) + dummy2 = inj.get(DummyInterface) + assert dummy1 is dummy2 + + with factory.create_scope(): + dummy3 = inj.get(DummyInterface) + assert dummy1 is not dummy3