From f83f43b450afce7a08c7a528aa8cc996df2aaf8a Mon Sep 17 00:00:00 2001 From: Ben Cutler <90279826+ben-cutler-datarobot@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:52:19 -0400 Subject: [PATCH] Change RedisBackend to accept Redis client directly (#755) Co-authored-by: Sam Bull --- CHANGES.rst | 1 + aiocache/backends/memcached.py | 8 +-- aiocache/backends/redis.py | 46 +++------------ aiocache/factory.py | 30 ++++++++-- examples/cached_alias_config.py | 15 +++-- examples/cached_decorator.py | 5 +- examples/multicached_decorator.py | 11 ++-- examples/optimistic_lock.py | 5 +- examples/python_object.py | 6 +- examples/redlock.py | 5 +- examples/serializer_class.py | 4 +- examples/serializer_function.py | 4 +- examples/simple_redis.py | 3 +- tests/acceptance/conftest.py | 4 +- tests/acceptance/test_factory.py | 17 +++--- tests/conftest.py | 27 +++++++++ tests/performance/conftest.py | 7 ++- tests/performance/server.py | 13 ++++- tests/ut/backends/test_memcached.py | 6 +- tests/ut/backends/test_redis.py | 87 ++++++----------------------- tests/ut/conftest.py | 4 +- tests/ut/test_factory.py | 51 +++++++---------- 22 files changed, 171 insertions(+), 188 deletions(-) create mode 100644 tests/conftest.py diff --git a/CHANGES.rst b/CHANGES.rst index 31ba5b7d..8962edf4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -12,6 +12,7 @@ Migration instructions There are a number of backwards-incompatible changes. These points should help with migrating from an older release: +* ``RedisBackend`` now expects a ``redis.Redis`` instance as an argument, instead of creating one internally from keyword arguments. * The ``key_builder`` parameter for caches now expects a callback which accepts 2 strings and returns a string in all cache implementations, making the builders simpler and interchangeable. * The ``key`` parameter has been removed from the ``cached`` decorator. The behaviour can be easily reimplemented with ``key_builder=lambda *a, **kw: "foo"`` * When using the ``key_builder`` parameter in ``@multicached``, the function will now return the original, unmodified keys, only using the transformed keys in the cache (this has always been the documented behaviour, but not the implemented behaviour). 
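As a rough illustration of the ``RedisBackend`` migration noted above — assuming aiocache 0.12+, a ``redis`` package that provides ``redis.asyncio``, and a Redis server on the default ``127.0.0.1:6379`` — caller code now constructs the ``redis.Redis`` client itself and hands it to the cache. The client has to be created with ``decode_responses=False``, because values such as pickle-serialized objects are stored as raw bytes (this is also enforced by the new check in ``RedisBackend.__init__``)::

    import asyncio

    import redis.asyncio as redis

    from aiocache import Cache


    async def main():
        # Before 0.12: Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main")
        # Now the redis.Redis client is created by the caller and passed in directly.
        async with redis.Redis(host="127.0.0.1", port=6379, decode_responses=False) as client:
            async with Cache(Cache.REDIS, namespace="main", client=client) as cache:
                await cache.set("key", "value")
                assert await cache.get("key") == "value"


    asyncio.run(main())

Note that since ``_close`` was removed from the backend, closing the cache no longer closes the client; the client's lifetime is managed by the caller, as in the ``async with`` block above.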
diff --git a/aiocache/backends/memcached.py b/aiocache/backends/memcached.py index 06b33b42..76ac34e1 100644 --- a/aiocache/backends/memcached.py +++ b/aiocache/backends/memcached.py @@ -8,13 +8,13 @@ class MemcachedBackend(BaseCache[bytes]): - def __init__(self, endpoint="127.0.0.1", port=11211, pool_size=2, **kwargs): + def __init__(self, host="127.0.0.1", port=11211, pool_size=2, **kwargs): super().__init__(**kwargs) - self.endpoint = endpoint + self.host = host self.port = port self.pool_size = int(pool_size) self.client = aiomcache.Client( - self.endpoint, self.port, pool_size=self.pool_size + self.host, self.port, pool_size=self.pool_size ) async def _get(self, key, encoding="utf-8", _conn=None): @@ -153,4 +153,4 @@ def parse_uri_path(cls, path): return {} def __repr__(self): # pragma: no cover - return "MemcachedCache ({}:{})".format(self.endpoint, self.port) + return "MemcachedCache ({}:{})".format(self.host, self.port) diff --git a/aiocache/backends/redis.py b/aiocache/backends/redis.py index d30202ff..b150fbdd 100644 --- a/aiocache/backends/redis.py +++ b/aiocache/backends/redis.py @@ -1,5 +1,4 @@ import itertools -import warnings from typing import Any, Callable, Optional, TYPE_CHECKING import redis.asyncio as redis @@ -38,41 +37,19 @@ class RedisBackend(BaseCache[str]): def __init__( self, - endpoint="127.0.0.1", - port=6379, - db=0, - password=None, - pool_min_size=_NOT_SET, - pool_max_size=None, - create_connection_timeout=None, + client: redis.Redis, **kwargs, ): super().__init__(**kwargs) - if pool_min_size is not _NOT_SET: - warnings.warn( - "Parameter 'pool_min_size' is deprecated since aiocache 0.12", - DeprecationWarning, stacklevel=2 - ) - - self.endpoint = endpoint - self.port = int(port) - self.db = int(db) - self.password = password - # TODO: Remove int() call some time after adding type annotations. - self.pool_max_size = None if pool_max_size is None else int(pool_max_size) - self.create_connection_timeout = ( - float(create_connection_timeout) if create_connection_timeout else None - ) # NOTE: decoding can't be controlled on API level after switching to # redis, we need to disable decoding on global/connection level # (decode_responses=False), because some of the values are saved as # bytes directly, like pickle serialized values, which may raise an # exception when decoded with 'utf-8'. - self.client = redis.Redis(host=self.endpoint, port=self.port, db=self.db, - password=self.password, decode_responses=False, - socket_connect_timeout=self.create_connection_timeout, - max_connections=self.pool_max_size) + if client.connection_pool.connection_kwargs['decode_responses']: + raise ValueError("redis client must be constructed with decode_responses set to False") + self.client = client async def _get(self, key, encoding="utf-8", _conn=None): value = await self.client.get(key) @@ -175,9 +152,6 @@ async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs): async def _redlock_release(self, key, value): return await self._raw("eval", self.RELEASE_SCRIPT, 1, key, value) - async def _close(self, *args, _conn=None, **kwargs): - await self.client.close() - def build_key(self, key: str, namespace: Optional[str] = None) -> str: return self._str_build_key(key, namespace) @@ -196,24 +170,21 @@ class RedisCache(RedisBackend): the backend. Default is an empty string, "". :param timeout: int or float in seconds specifying maximum timeout for the operations to last. By default its 5. - :param endpoint: str with the endpoint to connect to. Default is "127.0.0.1". 
- :param port: int with the port to connect to. Default is 6379. - :param db: int indicating database to use. Default is 0. - :param password: str indicating password to use. Default is None. - :param pool_max_size: int maximum pool size for the redis connections pool. Default is None. - :param create_connection_timeout: int timeout for the creation of connection. Default is None + :param client: redis.Redis which is an active client for working with redis """ NAME = "redis" def __init__( self, + client: redis.Redis, serializer: Optional["BaseSerializer"] = None, namespace: str = "", key_builder: Callable[[str, str], str] = lambda k, ns: f"{ns}:{k}" if ns else k, **kwargs: Any, ): super().__init__( + client=client, serializer=serializer or JsonSerializer(), namespace=namespace, key_builder=key_builder, @@ -237,4 +208,5 @@ def parse_uri_path(cls, path): return options def __repr__(self): # pragma: no cover - return "RedisCache ({}:{})".format(self.endpoint, self.port) + connection_kwargs = self.client.connection_pool.connection_kwargs + return "RedisCache ({}:{})".format(connection_kwargs['host'], connection_kwargs['port']) diff --git a/aiocache/factory.py b/aiocache/factory.py index e1ebac33..1a4346a4 100644 --- a/aiocache/factory.py +++ b/aiocache/factory.py @@ -1,5 +1,6 @@ import logging import urllib +from contextlib import suppress from copy import deepcopy from typing import Dict @@ -7,6 +8,9 @@ from aiocache.base import BaseCache from aiocache.exceptions import InvalidCacheType +with suppress(ImportError): + import redis.asyncio as redis + logger = logging.getLogger(__name__) @@ -18,6 +22,7 @@ def _class_from_string(class_path): def _create_cache(cache, serializer=None, plugins=None, **kwargs): + kwargs = deepcopy(kwargs) if serializer is not None: cls = serializer.pop("class") cls = _class_from_string(cls) if isinstance(cls, str) else cls @@ -29,10 +34,17 @@ def _create_cache(cache, serializer=None, plugins=None, **kwargs): cls = plugin.pop("class") cls = _class_from_string(cls) if isinstance(cls, str) else cls plugins_instances.append(cls(**plugin)) - cache = _class_from_string(cache) if isinstance(cache, str) else cache - instance = cache(serializer=serializer, plugins=plugins_instances, **kwargs) - return instance + if cache == AIOCACHE_CACHES.get("redis"): + return cache( + serializer=serializer, + plugins=plugins_instances, + namespace=kwargs.pop('namespace', ''), + ttl=kwargs.pop('ttl', None), + client=redis.Redis(**kwargs) + ) + else: + return cache(serializer=serializer, plugins=plugins_instances, **kwargs) class Cache: @@ -112,7 +124,7 @@ def from_url(cls, url): kwargs.update(cache_class.parse_uri_path(parsed_url.path)) if parsed_url.hostname: - kwargs["endpoint"] = parsed_url.hostname + kwargs["host"] = parsed_url.hostname if parsed_url.port: kwargs["port"] = parsed_url.port @@ -120,7 +132,13 @@ def from_url(cls, url): if parsed_url.password: kwargs["password"] = parsed_url.password - return Cache(cache_class, **kwargs) + for arg in ['max_connections', 'socket_connect_timeout']: + if arg in kwargs: + kwargs[arg] = int(kwargs[arg]) + if cache_class == cls.REDIS: + return Cache(cache_class, client=redis.Redis(**kwargs)) + else: + return Cache(cache_class, **kwargs) class CacheHandler: @@ -214,7 +232,7 @@ def set_config(self, config): }, 'redis_alt': { 'cache': "aiocache.RedisCache", - 'endpoint': "127.0.0.10", + 'host': "127.0.0.10", 'port': 6378, 'serializer': { 'class': "aiocache.serializers.PickleSerializer" diff --git a/examples/cached_alias_config.py 
b/examples/cached_alias_config.py index a22678ad..27aea69b 100644 --- a/examples/cached_alias_config.py +++ b/examples/cached_alias_config.py @@ -1,5 +1,7 @@ import asyncio +import redis.asyncio as redis + from aiocache import caches, Cache from aiocache.serializers import StringSerializer, PickleSerializer @@ -12,9 +14,9 @@ }, 'redis_alt': { 'cache': "aiocache.RedisCache", - 'endpoint': "127.0.0.1", + "host": "127.0.0.1", 'port': 6379, - 'timeout': 1, + "socket_connect_timeout": 1, 'serializer': { 'class': "aiocache.serializers.PickleSerializer" }, @@ -45,9 +47,10 @@ async def alt_cache(): assert isinstance(cache, Cache.REDIS) assert isinstance(cache.serializer, PickleSerializer) assert len(cache.plugins) == 2 - assert cache.endpoint == "127.0.0.1" - assert cache.timeout == 1 - assert cache.port == 6379 + connection_args = cache.client.connection_pool.connection_kwargs + assert connection_args["host"] == "127.0.0.1" + assert connection_args["socket_connect_timeout"] == 1 + assert connection_args["port"] == 6379 await cache.close() @@ -55,7 +58,7 @@ async def test_alias(): await default_cache() await alt_cache() - cache = Cache(Cache.REDIS) + cache = Cache(Cache.REDIS, client=redis.Redis()) await cache.delete("key") await cache.close() diff --git a/examples/cached_decorator.py b/examples/cached_decorator.py index 01c5a46a..78d1cb11 100644 --- a/examples/cached_decorator.py +++ b/examples/cached_decorator.py @@ -1,6 +1,7 @@ import asyncio from collections import namedtuple +import redis.asyncio as redis from aiocache import cached, Cache from aiocache.serializers import PickleSerializer @@ -10,13 +11,13 @@ @cached( ttl=10, cache=Cache.REDIS, key_builder=lambda *args, **kw: "key", - serializer=PickleSerializer(), port=6379, namespace="main") + serializer=PickleSerializer(), namespace="main", client=redis.Redis()) async def cached_call(): return Result("content", 200) async def test_cached(): - async with Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main") as cache: + async with Cache(Cache.REDIS, namespace="main", client=redis.Redis()) as cache: await cached_call() exists = await cache.exists("key") assert exists is True diff --git a/examples/multicached_decorator.py b/examples/multicached_decorator.py index d05d5f4a..59c0db80 100644 --- a/examples/multicached_decorator.py +++ b/examples/multicached_decorator.py @@ -1,5 +1,7 @@ import asyncio +import redis.asyncio as redis + from aiocache import multi_cached, Cache DICT = { @@ -9,20 +11,19 @@ 'd': "W" } +cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis()) + -@multi_cached("ids", cache=Cache.REDIS, namespace="main") +@multi_cached("ids", cache=Cache.REDIS, namespace="main", client=cache.client) async def multi_cached_ids(ids=None): return {id_: DICT[id_] for id_ in ids} -@multi_cached("keys", cache=Cache.REDIS, namespace="main") +@multi_cached("keys", cache=Cache.REDIS, namespace="main", client=cache.client) async def multi_cached_keys(keys=None): return {id_: DICT[id_] for id_ in keys} -cache = Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main") - - async def test_multi_cached(): await multi_cached_ids(ids=("a", "b")) await multi_cached_ids(ids=("a", "c")) diff --git a/examples/optimistic_lock.py b/examples/optimistic_lock.py index 20624907..422973e4 100644 --- a/examples/optimistic_lock.py +++ b/examples/optimistic_lock.py @@ -2,12 +2,13 @@ import logging import random +import redis.asyncio as redis + from aiocache import Cache from aiocache.lock import OptimisticLock, OptimisticLockError - 
logger = logging.getLogger(__name__) -cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main') +cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis()) async def expensive_function(): diff --git a/examples/python_object.py b/examples/python_object.py index 8eea8b3b..984fad4c 100644 --- a/examples/python_object.py +++ b/examples/python_object.py @@ -1,12 +1,14 @@ import asyncio from collections import namedtuple +import redis.asyncio as redis + + from aiocache import Cache from aiocache.serializers import PickleSerializer - MyObject = namedtuple("MyObject", ["x", "y"]) -cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main") +cache = Cache(Cache.REDIS, serializer=PickleSerializer(), namespace="main", client=redis.Redis()) async def complex_object(): diff --git a/examples/redlock.py b/examples/redlock.py index e763ddb3..38d703d7 100644 --- a/examples/redlock.py +++ b/examples/redlock.py @@ -1,12 +1,13 @@ import asyncio import logging +import redis.asyncio as redis + from aiocache import Cache from aiocache.lock import RedLock - logger = logging.getLogger(__name__) -cache = Cache(Cache.REDIS, endpoint='127.0.0.1', port=6379, namespace='main') +cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis()) async def expensive_function(): diff --git a/examples/serializer_class.py b/examples/serializer_class.py index 50562b12..a9154843 100644 --- a/examples/serializer_class.py +++ b/examples/serializer_class.py @@ -1,6 +1,8 @@ import asyncio import zlib +import redis.asyncio as redis + from aiocache import Cache from aiocache.serializers import BaseSerializer @@ -25,7 +27,7 @@ def loads(self, value): return decompressed -cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main") +cache = Cache(Cache.REDIS, serializer=CompressionSerializer(), namespace="main", client=redis.Redis()) async def serializer(): diff --git a/examples/serializer_function.py b/examples/serializer_function.py index affa0b3b..05c5ba04 100644 --- a/examples/serializer_function.py +++ b/examples/serializer_function.py @@ -1,6 +1,8 @@ import asyncio import json +import redis.asyncio as redis + from marshmallow import Schema, fields, post_load from aiocache import Cache @@ -28,7 +30,7 @@ def loads(value): return MyTypeSchema().loads(value) -cache = Cache(Cache.REDIS, namespace="main") +cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis()) async def serializer_function(): diff --git a/examples/simple_redis.py b/examples/simple_redis.py index 2ff6278d..1f429623 100644 --- a/examples/simple_redis.py +++ b/examples/simple_redis.py @@ -2,8 +2,9 @@ from aiocache import Cache +import redis.asyncio as redis -cache = Cache(Cache.REDIS, endpoint="127.0.0.1", port=6379, namespace="main") +cache = Cache(Cache.REDIS, namespace="main", client=redis.Redis()) async def redis(): diff --git a/tests/acceptance/conftest.py b/tests/acceptance/conftest.py index e4b2ba5c..0d5a306f 100644 --- a/tests/acceptance/conftest.py +++ b/tests/acceptance/conftest.py @@ -20,8 +20,8 @@ def reset_caches(): @pytest.fixture -async def redis_cache(): - async with Cache(Cache.REDIS, namespace="test") as cache: +async def redis_cache(redis_client): + async with Cache(Cache.REDIS, namespace="test", client=redis_client) as cache: yield cache await asyncio.gather(*(cache.delete(k) for k in (*Keys, KEY_LOCK))) diff --git a/tests/acceptance/test_factory.py b/tests/acceptance/test_factory.py index f39a9119..4a3bb4f8 100644 --- a/tests/acceptance/test_factory.py +++ 
b/tests/acceptance/test_factory.py @@ -11,22 +11,23 @@ async def test_from_url_memory(self): def test_from_url_memory_no_endpoint(self): with pytest.raises(TypeError): - Cache.from_url("memory://endpoint:10") + Cache.from_url("memory://host:10") @pytest.mark.redis async def test_from_url_redis(self): from aiocache.backends.redis import RedisCache url = ("redis://endpoint:1000/0/?password=pass" - + "&pool_max_size=50&create_connection_timeout=20") + + "&max_connections=50&socket_connect_timeout=20") async with Cache.from_url(url) as cache: assert isinstance(cache, RedisCache) - assert cache.endpoint == "endpoint" - assert cache.port == 1000 - assert cache.password == "pass" - assert cache.pool_max_size == 50 - assert cache.create_connection_timeout == 20 + connection_args = cache.client.connection_pool.connection_kwargs + assert connection_args["host"] == "endpoint" + assert connection_args["port"] == 1000 + assert connection_args["password"] == "pass" + assert cache.client.connection_pool.max_connections == 50 + assert connection_args["socket_connect_timeout"] == 20 @pytest.mark.memcached async def test_from_url_memcached(self): @@ -36,7 +37,7 @@ async def test_from_url_memcached(self): async with Cache.from_url(url) as cache: assert isinstance(cache, MemcachedCache) - assert cache.endpoint == "endpoint" + assert cache.host == "endpoint" assert cache.port == 1000 assert cache.pool_size == 10 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..4482701d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import pytest + + +@pytest.fixture() +def max_conns(): + return None + + +@pytest.fixture() +def decode_responses(): + return False + + +@pytest.fixture +async def redis_client(max_conns, decode_responses): + import redis.asyncio as redis + + async with redis.Redis( + host="127.0.0.1", + port=6379, + db=0, + password=None, + decode_responses=decode_responses, + socket_connect_timeout=None, + max_connections=max_conns + ) as r: + yield r diff --git a/tests/performance/conftest.py b/tests/performance/conftest.py index 6df8dbc9..c7a1ccd2 100644 --- a/tests/performance/conftest.py +++ b/tests/performance/conftest.py @@ -4,10 +4,11 @@ @pytest.fixture -async def redis_cache(): +@pytest.mark.parametrize("max_conns", 1) +async def redis_cache(redis_client): # redis connection pool raises ConnectionError but doesn't wait for conn reuse - # when exceeding max pool size. - async with Cache(Cache.REDIS, namespace="test", pool_max_size=1) as cache: + # when exceeding max pool size. 
+ async with Cache(Cache.REDIS, namespace="test", client=redis_client) as cache: yield cache diff --git a/tests/performance/server.py b/tests/performance/server.py index 8de8c6b8..679fdd23 100644 --- a/tests/performance/server.py +++ b/tests/performance/server.py @@ -2,6 +2,7 @@ import logging import uuid +import redis.asyncio as redis from aiohttp import web from aiocache import Cache @@ -16,7 +17,17 @@ def __init__(self, backend: str): "redis": Cache.REDIS, "memcached": Cache.MEMCACHED, } - self.cache = Cache(backends[backend]) + if backend == "redis": + cache_kwargs = {"client": redis.Redis( + host="127.0.0.1", + port=6379, + db=0, + password=None, + decode_responses=False, + )} + else: + cache_kwargs = dict() + self.cache = Cache(backends[backend], **cache_kwargs) async def get(self, key): return await self.cache.get(key, timeout=0.1) diff --git a/tests/ut/backends/test_memcached.py b/tests/ut/backends/test_memcached.py index 81cd744c..f0de04cb 100644 --- a/tests/ut/backends/test_memcached.py +++ b/tests/ut/backends/test_memcached.py @@ -32,17 +32,17 @@ def test_setup(self): aiomcache_client.assert_called_with("127.0.0.1", 11211, pool_size=2) - assert memcached.endpoint == "127.0.0.1" + assert memcached.host == "127.0.0.1" assert memcached.port == 11211 assert memcached.pool_size == 2 def test_setup_override(self): with patch.object(aiomcache, "Client", autospec=True) as aiomcache_client: - memcached = MemcachedBackend(endpoint="127.0.0.2", port=2, pool_size=10) + memcached = MemcachedBackend(host="127.0.0.2", port=2, pool_size=10) aiomcache_client.assert_called_with("127.0.0.2", 2, pool_size=10) - assert memcached.endpoint == "127.0.0.2" + assert memcached.host == "127.0.0.2" assert memcached.port == 2 assert memcached.pool_size == 10 diff --git a/tests/ut/backends/test_redis.py b/tests/ut/backends/test_redis.py index c6ad755a..10e5d2de 100644 --- a/tests/ut/backends/test_redis.py +++ b/tests/ut/backends/test_redis.py @@ -11,8 +11,8 @@ @pytest.fixture -def redis(): - redis = RedisBackend() +def redis(redis_client): + redis = RedisBackend(client=redis_client) with patch.object(redis, "client", autospec=True) as m: # These methods actually return an awaitable. 
for method in ( @@ -29,64 +29,15 @@ def redis(): class TestRedisBackend: - default_redis_kwargs = { - "host": "127.0.0.1", - "port": 6379, - "db": 0, - "password": None, - "socket_connect_timeout": None, - "decode_responses": False, - "max_connections": None, - } - - @patch("redis.asyncio.Redis", name="mock_class", autospec=True) - def test_setup(self, mock_class): - redis_backend = RedisBackend() - kwargs = self.default_redis_kwargs.copy() - mock_class.assert_called_with(**kwargs) - assert redis_backend.endpoint == "127.0.0.1" - assert redis_backend.port == 6379 - assert redis_backend.db == 0 - assert redis_backend.password is None - assert redis_backend.pool_max_size is None - - @patch("redis.asyncio.Redis", name="mock_class", autospec=True) - def test_setup_override(self, mock_class): - override = {"db": 2, "password": "pass"} - redis_backend = RedisBackend(**override) - - kwargs = self.default_redis_kwargs.copy() - kwargs.update(override) - mock_class.assert_called_with(**kwargs) - - assert redis_backend.endpoint == "127.0.0.1" - assert redis_backend.port == 6379 - assert redis_backend.db == 2 - assert redis_backend.password == "pass" - - @patch("redis.asyncio.Redis", name="mock_class", autospec=True) - def test_setup_casts(self, mock_class): - override = { - "db": "2", - "port": "6379", - "pool_max_size": "10", - "create_connection_timeout": "1.5", - } - redis_backend = RedisBackend(**override) - - kwargs = self.default_redis_kwargs.copy() - kwargs.update({ - "db": 2, - "port": 6379, - "max_connections": 10, - "socket_connect_timeout": 1.5, - }) - mock_class.assert_called_with(**kwargs) - - assert redis_backend.db == 2 - assert redis_backend.port == 6379 - assert redis_backend.pool_max_size == 10 - assert redis_backend.create_connection_timeout == 1.5 + + @pytest.mark.parametrize("decode_responses", [True]) + async def test_redis_backend_requires_client_decode_responses(self, redis_client): + with pytest.raises(ValueError) as ve: + RedisBackend(client=redis_client) + + assert str(ve.value) == ( + "redis client must be constructed with decode_responses set to False" + ) async def test_get(self, redis): redis.client.get.return_value = b"value" @@ -224,10 +175,6 @@ async def test_redlock_release(self, mocker, redis): await redis._redlock_release(Keys.KEY, "random") redis._raw.assert_called_with("eval", redis.RELEASE_SCRIPT, 1, Keys.KEY, "random") - async def test_close(self, redis): - await redis._close() - assert redis.client.close.call_count == 1 - class TestRedisCache: @pytest.fixture @@ -239,17 +186,17 @@ def set_test_namespace(self, redis_cache): def test_name(self): assert RedisCache.NAME == "redis" - def test_inheritance(self): - assert isinstance(RedisCache(), BaseCache) + def test_inheritance(self, redis_client): + assert isinstance(RedisCache(client=redis_client), BaseCache) - def test_default_serializer(self): - assert isinstance(RedisCache().serializer, JsonSerializer) + def test_default_serializer(self, redis_client): + assert isinstance(RedisCache(client=redis_client).serializer, JsonSerializer) @pytest.mark.parametrize( "path,expected", [("", {}), ("/", {}), ("/1", {"db": "1"}), ("/1/2/3", {"db": "1"})] ) - def test_parse_uri_path(self, path, expected): - assert RedisCache().parse_uri_path(path) == expected + def test_parse_uri_path(self, path, expected, redis_client): + assert RedisCache(client=redis_client).parse_uri_path(path) == expected @pytest.mark.parametrize( "namespace, expected", diff --git a/tests/ut/conftest.py b/tests/ut/conftest.py index 9323567b..591f1c44 
100644 --- a/tests/ut/conftest.py +++ b/tests/ut/conftest.py @@ -53,10 +53,10 @@ def base_cache(): @pytest.fixture -async def redis_cache(): +async def redis_cache(redis_client): from aiocache.backends.redis import RedisCache - async with RedisCache() as cache: + async with RedisCache(client=redis_client) as cache: yield cache diff --git a/tests/ut/test_factory.py b/tests/ut/test_factory.py index 56de47e7..7b33b8b3 100644 --- a/tests/ut/test_factory.py +++ b/tests/ut/test_factory.py @@ -34,15 +34,6 @@ def test_class_from_string(): assert _class_from_string("aiocache.RedisCache") == RedisCache -@pytest.mark.redis -def test_create_simple_cache(): - redis = _create_cache(RedisCache, endpoint="127.0.0.10", port=6378) - - assert isinstance(redis, RedisCache) - assert redis.endpoint == "127.0.0.10" - assert redis.port == 6378 - - def test_create_cache_with_everything(): cache = _create_cache( SimpleMemoryCache, @@ -97,26 +88,26 @@ def test_from_url_returns_cache_from_scheme(self, scheme): "url,expected_args", [ ("redis://", {}), - ("redis://localhost", {"endpoint": "localhost"}), - ("redis://localhost/", {"endpoint": "localhost"}), - ("redis://localhost:6379", {"endpoint": "localhost", "port": 6379}), + ("redis://localhost", {"host": "localhost"}), + ("redis://localhost/", {"host": "localhost"}), + ("redis://localhost:6379", {"host": "localhost", "port": 6379}), ( "redis://localhost/?arg1=arg1&arg2=arg2", - {"endpoint": "localhost", "arg1": "arg1", "arg2": "arg2"}, + {"host": "localhost", "arg1": "arg1", "arg2": "arg2"}, ), ( "redis://localhost:6379/?arg1=arg1&arg2=arg2", - {"endpoint": "localhost", "port": 6379, "arg1": "arg1", "arg2": "arg2"}, + {"host": "localhost", "port": 6379, "arg1": "arg1", "arg2": "arg2"}, ), ("redis:///?arg1=arg1", {"arg1": "arg1"}), ("redis:///?arg2=arg2", {"arg2": "arg2"}), ( "redis://:password@localhost:6379", - {"endpoint": "localhost", "password": "password", "port": 6379}, + {"host": "localhost", "password": "password", "port": 6379}, ), ( "redis://:password@localhost:6379?password=pass", - {"endpoint": "localhost", "password": "password", "port": 6379}, + {"host": "localhost", "password": "password", "port": 6379}, ), ], ) @@ -185,16 +176,16 @@ def test_create_extra_args(self): { "default": { "cache": "aiocache.RedisCache", - "endpoint": "127.0.0.9", + "host": "127.0.0.9", "db": 10, "port": 6378, } } ) - cache = caches.create("default", namespace="whatever", endpoint="127.0.0.10", db=10) + cache = caches.create("default", namespace="whatever", host="127.0.0.10", db=10) assert cache.namespace == "whatever" - assert cache.endpoint == "127.0.0.10" - assert cache.db == 10 + assert cache.client.connection_pool.connection_kwargs["host"] == "127.0.0.10" + assert cache.client.connection_pool.connection_kwargs["db"] == 10 @pytest.mark.redis def test_retrieve_cache(self): @@ -202,7 +193,7 @@ def test_retrieve_cache(self): { "default": { "cache": "aiocache.RedisCache", - "endpoint": "127.0.0.10", + "host": "127.0.0.10", "port": 6378, "ttl": 10, "serializer": { @@ -219,8 +210,8 @@ def test_retrieve_cache(self): cache = caches.get("default") assert isinstance(cache, RedisCache) - assert cache.endpoint == "127.0.0.10" - assert cache.port == 6378 + assert cache.client.connection_pool.connection_kwargs["host"] == "127.0.0.10" + assert cache.client.connection_pool.connection_kwargs["port"] == 6378 assert cache.ttl == 10 assert isinstance(cache.serializer, PickleSerializer) assert cache.serializer.encoding == "encoding" @@ -232,7 +223,7 @@ def 
test_retrieve_cache_new_instance(self): { "default": { "cache": "aiocache.RedisCache", - "endpoint": "127.0.0.10", + "host": "127.0.0.10", "port": 6378, "serializer": { "class": "aiocache.serializers.PickleSerializer", @@ -248,8 +239,8 @@ def test_retrieve_cache_new_instance(self): cache = caches.create("default") assert isinstance(cache, RedisCache) - assert cache.endpoint == "127.0.0.10" - assert cache.port == 6378 + assert cache.client.connection_pool.connection_kwargs["host"] == "127.0.0.10" + assert cache.client.connection_pool.connection_kwargs["port"] == 6378 assert isinstance(cache.serializer, PickleSerializer) assert cache.serializer.encoding == "encoding" assert len(cache.plugins) == 2 @@ -260,7 +251,7 @@ def test_multiple_caches(self): { "default": { "cache": "aiocache.RedisCache", - "endpoint": "127.0.0.10", + "host": "127.0.0.10", "port": 6378, "serializer": {"class": "aiocache.serializers.PickleSerializer"}, "plugins": [ @@ -276,8 +267,8 @@ def test_multiple_caches(self): alt = caches.get("alt") assert isinstance(default, RedisCache) - assert default.endpoint == "127.0.0.10" - assert default.port == 6378 + assert default.client.connection_pool.connection_kwargs["host"] == "127.0.0.10" + assert default.client.connection_pool.connection_kwargs["port"] == 6378 assert isinstance(default.serializer, PickleSerializer) assert len(default.plugins) == 2 @@ -338,7 +329,7 @@ def test_set_config_no_default(self): { "no_default": { "cache": "aiocache.RedisCache", - "endpoint": "127.0.0.10", + "host": "127.0.0.10", "port": 6378, "serializer": {"class": "aiocache.serializers.PickleSerializer"}, "plugins": [
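Rounding out the factory changes above, here is a small sketch of the two construction paths after this patch — URL-based and config-based — mirroring the test fixtures in this diff; the host, port, and alias values are illustrative only. With this change, ``Cache.from_url`` and the ``caches`` handler build the ``redis.Redis`` client internally from ``host``/``port`` style arguments instead of the old ``endpoint`` keyword::

    import redis.asyncio as redis

    from aiocache import Cache, caches

    # URL path: the factory parses host/port/db, casts max_connections and
    # socket_connect_timeout to int, and constructs the redis.Redis client itself.
    cache = Cache.from_url(
        "redis://localhost:6379/0/?max_connections=50&socket_connect_timeout=20"
    )
    assert isinstance(cache.client, redis.Redis)

    # Config path: the former "endpoint" key is now "host"; the remaining keyword
    # arguments (host, port, socket_connect_timeout, ...) are forwarded to redis.Redis().
    caches.set_config({
        "default": {
            "cache": "aiocache.SimpleMemoryCache",
        },
        "redis_alt": {
            "cache": "aiocache.RedisCache",
            "host": "127.0.0.1",
            "port": 6379,
            "socket_connect_timeout": 1,
            "serializer": {"class": "aiocache.serializers.PickleSerializer"},
        },
    })
    redis_cache = caches.get("redis_alt")
    assert isinstance(redis_cache.client, redis.Redis)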