Skip to content

Commit

Permalink
Allow to create request scopes outside of the middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
soldag authored and matyasrichter committed May 19, 2023
1 parent a24801a commit e2e4488
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 9 deletions.
2 changes: 2 additions & 0 deletions fastapi_injector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fastapi_injector.request_scope import (
InjectorMiddleware,
RequestScope,
RequestScopeFactory,
request_scope,
)

Expand All @@ -20,5 +21,6 @@
"InjectorNotAttached",
"request_scope",
"RequestScope",
"RequestScopeFactory",
"InjectorMiddleware",
]
42 changes: 33 additions & 9 deletions fastapi_injector/request_scope.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import uuid
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Type

from injector import Injector, InstanceProvider, Provider, Scope, ScopeDecorator, T
from injector import (
Inject,
Injector,
InstanceProvider,
Provider,
Scope,
ScopeDecorator,
T,
)
from starlette.types import Receive, Send

from fastapi_injector.exceptions import RequestScopeError
Expand Down Expand Up @@ -66,25 +75,40 @@ def clear_key(self, key: uuid.UUID) -> None:
request_scope = ScopeDecorator(RequestScope)


class RequestScopeFactory:
"""
Allows to create request scopes.
"""

def __init__(self, request_scope_instance: Inject[RequestScope]) -> None:
self.request_scope_instance = request_scope_instance

@contextmanager
def create_scope(self):
"""Creates a new request scope within dependencies are cached."""
rid = uuid.uuid4()
rid_ctx = _request_id_ctx.set(rid)
self.request_scope_instance.add_key(rid)
try:
yield
finally:
self.request_scope_instance.clear_key(rid)
_request_id_ctx.reset(rid_ctx)


class InjectorMiddleware:
"""
Middleware that enables request-scoped injection through ContextVar-based cache.
"""

def __init__(self, app, injector: Injector) -> None:
self.app = app
self.request_scope_instance = injector.get(RequestScope)
self.request_scope_factory = injector.get(RequestScopeFactory)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
Add an identifier to the request
that can be used retrieve scoped dependencies.
"""
rid = uuid.uuid4()
rid_ctx = _request_id_ctx.set(rid)
self.request_scope_instance.add_key(rid)
try:
with self.request_scope_factory.create_scope():
await self.app(scope, receive, send)
finally:
self.request_scope_instance.clear_key(rid)
_request_id_ctx.reset(rid_ctx)
23 changes: 23 additions & 0 deletions tests/test_request_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Injected,
InjectorMiddleware,
RequestScope,
RequestScopeFactory,
attach_injector,
request_scope,
)
Expand Down Expand Up @@ -209,3 +210,25 @@ def get_root(
await client.get("/")

assert len(scope_instance.cache) == 0


async def test_caches_instances_with_scope_factory():
class DummyInterface:
pass

class DummyImpl:
pass

inj = Injector()
inj.binder.bind(DummyInterface, to=DummyImpl, scope=request_scope)

factory = inj.get(RequestScopeFactory)

with factory.create_scope():
dummy1 = inj.get(DummyInterface)
dummy2 = inj.get(DummyInterface)
assert dummy1 is dummy2

with factory.create_scope():
dummy3 = inj.get(DummyInterface)
assert dummy1 is not dummy3

0 comments on commit e2e4488

Please sign in to comment.