From 548b5e730c5c0e9f67b24b2ed3196ddcf184485d Mon Sep 17 00:00:00 2001 From: Ivan Belyaev Date: Mon, 15 Jul 2024 20:54:13 +0300 Subject: [PATCH] Overriding with context manager (one and multiple providers) (#53) * Implement overriding with context manager for one provider * Implement batch overriding with context manager for container Co-authored-by: ivan --- tests/providers/test_providers_overriding.py | 79 ++++++++++++++++++++ that_depends/container.py | 24 ++++++ that_depends/providers/base.py | 9 +++ 3 files changed, 112 insertions(+) diff --git a/tests/providers/test_providers_overriding.py b/tests/providers/test_providers_overriding.py index ccefbcf..a6dd0b5 100644 --- a/tests/providers/test_providers_overriding.py +++ b/tests/providers/test_providers_overriding.py @@ -1,8 +1,87 @@ import datetime +import pytest + from tests import container +async def test_batch_providers_overriding() -> None: + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") + async_factory_mock = datetime.datetime.fromisoformat("2025-01-01") + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + singleton_mock = container.SingletonFactory(dep1=False) + + providers_for_overriding = { + "async_resource": async_resource_mock, + "sync_resource": sync_resource_mock, + "simple_factory": simple_factory_mock, + "singleton": singleton_mock, + "async_factory": async_factory_mock, + } + + with container.DIContainer.override_providers(providers_for_overriding): + await container.DIContainer.simple_factory() + dependent_factory = await container.DIContainer.dependent_factory() + singleton = await container.DIContainer.singleton() + async_factory = await container.DIContainer.async_factory() + + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 + assert dependent_factory.sync_resource == sync_resource_mock + assert dependent_factory.async_resource == async_resource_mock + assert singleton is singleton_mock + assert async_factory is async_factory_mock + + assert (await container.DIContainer.async_resource()) != async_resource_mock + + +async def test_batch_providers_overriding_sync_resolve() -> None: + async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") + sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + singleton_mock = container.SingletonFactory(dep1=False) + + providers_for_overriding = { + "async_resource": async_resource_mock, + "sync_resource": sync_resource_mock, + "simple_factory": simple_factory_mock, + "singleton": singleton_mock, + } + + with container.DIContainer.override_providers(providers_for_overriding): + container.DIContainer.simple_factory.sync_resolve() + await container.DIContainer.async_resource.async_resolve() + dependent_factory = container.DIContainer.dependent_factory.sync_resolve() + singleton = container.DIContainer.singleton.sync_resolve() + + assert dependent_factory.simple_factory.dep1 == simple_factory_mock.dep1 + assert dependent_factory.simple_factory.dep2 == simple_factory_mock.dep2 + assert dependent_factory.sync_resource == sync_resource_mock + assert dependent_factory.async_resource == async_resource_mock + assert singleton is singleton_mock + + assert container.DIContainer.sync_resource.sync_resolve() != sync_resource_mock + + +def test_providers_overriding_with_context_manager() -> None: + simple_factory_mock = container.SimpleFactory(dep1="override", dep2=999) + + with container.DIContainer.simple_factory.override_context(simple_factory_mock): + assert container.DIContainer.simple_factory.sync_resolve() is simple_factory_mock + + assert container.DIContainer.simple_factory.sync_resolve() is not simple_factory_mock + + +def test_providers_overriding_fail_with_unknown_provider() -> None: + unknown_provider_name = "unknown_provider_name" + match = f"Provider with name {unknown_provider_name!r} not found" + providers_for_overriding = {unknown_provider_name: None} + + with pytest.raises(RuntimeError, match=match), container.DIContainer.override_providers(providers_for_overriding): + ... # pragma: no cover + + async def test_providers_overriding() -> None: async_resource_mock = datetime.datetime.fromisoformat("2023-01-01") sync_resource_mock = datetime.datetime.fromisoformat("2024-01-01") diff --git a/that_depends/container.py b/that_depends/container.py index f4c874c..392efa4 100644 --- a/that_depends/container.py +++ b/that_depends/container.py @@ -1,5 +1,6 @@ import inspect import typing +from contextlib import contextmanager from that_depends.providers import AbstractProvider, AbstractResource, Singleton @@ -92,3 +93,26 @@ async def resolve(cls, object_to_resolve: type[T] | typing.Callable[..., T]) -> kwargs[field_name] = await providers[field_name].async_resolve() return object_to_resolve(**kwargs) + + @classmethod + @contextmanager + def override_providers(cls, providers_for_overriding: dict[str, typing.Any]) -> typing.Iterator[None]: + current_providers = cls.get_providers() + current_provider_names = set(current_providers.keys()) + given_provider_names = set(providers_for_overriding.keys()) + + for given_name in given_provider_names: + if given_name not in current_provider_names: + msg = f"Provider with name {given_name!r} not found" + raise RuntimeError(msg) + + for provider_name, mock in providers_for_overriding.items(): + provider = current_providers[provider_name] + provider.override(mock) + + try: + yield + finally: + for provider_name in providers_for_overriding: + provider = current_providers[provider_name] + provider.reset_override() diff --git a/that_depends/providers/base.py b/that_depends/providers/base.py index 1b99970..353d1a7 100644 --- a/that_depends/providers/base.py +++ b/that_depends/providers/base.py @@ -1,5 +1,6 @@ import abc import typing +from contextlib import contextmanager T = typing.TypeVar("T") @@ -24,6 +25,14 @@ async def __call__(self) -> T_co: def override(self, mock: object) -> None: self._override = mock + @contextmanager + def override_context(self, mock: object) -> typing.Iterator[None]: + self.override(mock) + try: + yield + finally: + self.reset_override() + def reset_override(self) -> None: self._override = None