Skip to content

Commit

Permalink
Overriding with context manager (one and multiple providers) (#53)
Browse files Browse the repository at this point in the history
* Implement overriding with context manager for one provider

* Implement batch overriding with context manager for container

Co-authored-by: ivan <ivan.belyaev@ailet.com>
  • Loading branch information
nightblure and ivan authored Jul 15, 2024
1 parent 18a1857 commit 548b5e7
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/providers/test_providers_overriding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,87 @@
import datetime

import pytest

from tests import container


async def test_batch_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
async_factory_mock = datetime.datetime.fromisoformat("2025-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
"async_factory": async_factory_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
await container.DIContainer.simple_factory()
dependent_factory = await container.DIContainer.dependent_factory()
singleton = await container.DIContainer.singleton()
async_factory = await container.DIContainer.async_factory()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock
assert async_factory is async_factory_mock

assert (await container.DIContainer.async_resource()) != async_resource_mock


async def test_batch_providers_overriding_sync_resolve() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)
singleton_mock = container.SingletonFactory(dep1=False)

providers_for_overriding = {
"async_resource": async_resource_mock,
"sync_resource": sync_resource_mock,
"simple_factory": simple_factory_mock,
"singleton": singleton_mock,
}

with container.DIContainer.override_providers(providers_for_overriding):
container.DIContainer.simple_factory.sync_resolve()
await container.DIContainer.async_resource.async_resolve()
dependent_factory = container.DIContainer.dependent_factory.sync_resolve()
singleton = container.DIContainer.singleton.sync_resolve()

assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1
assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2
assert dependent_factory.sync_resource == sync_resource_mock
assert dependent_factory.async_resource == async_resource_mock
assert singleton is singleton_mock

assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock


def test_providers_overriding_with_context_manager() -> None:
simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999)

with container.DIContainer.simple_factory.override_context(simple_factory_mock):
assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock

assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock


def test_providers_overriding_fail_with_unknown_provider() -> None:
unknown_provider_name = "unknown_provider_name"
match = f"Provider with name {unknown_provider_name!r} not found"
providers_for_overriding = {unknown_provider_name: None}

with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding):
... # pragma: no cover


async def test_providers_overriding() -> None:
async_resource_mock = datetime.datetime.fromisoformat("2023-01-01")
sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01")
Expand Down
24 changes: 24 additions & 0 deletions that_depends/container.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import typing
from contextlib import contextmanager

from that_depends.providers import AbstractProvider, AbstractResource, Singleton

Expand Down Expand Up @@ -92,3 +93,26 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) ->
kwargs[field_name] = await providers[field_name].async_resolve()

return object_to_resolve(**kwargs)

@classmethod
@contextmanager
def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]:
current_providers = cls.get_providers()
current_provider_names = set(current_providers.keys())
given_provider_names = set(providers_for_overriding.keys())

for given_name in given_provider_names:
if given_name not in current_provider_names:
msg = f"Provider with name {given_name!r} not found"
raise RuntimeError(msg)

for provider_name, mock in providers_for_overriding.items():
provider = current_providers[provider_name]
provider.override(mock)

try:
yield
finally:
for provider_name in providers_for_overriding:
provider = current_providers[provider_name]
provider.reset_override()
9 changes: 9 additions & 0 deletions that_depends/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import typing
from contextlib import contextmanager


T = typing.TypeVar("T")
Expand All @@ -24,6 +25,14 @@ async def __call__(self) -> T_co:
def override(self, mock: object) -> None:
self._override = mock

@contextmanager
def override_context(self, mock: object) -> typing.Iterator[None]:
self.override(mock)
try:
yield
finally:
self.reset_override()

def reset_override(self) -> None:
self._override = None

Expand Down

0 comments on commit 548b5e7

Please sign in to comment.