Skip to content

Commit

Permalink
Add RedisCluster.remap_host_port, Update tests for CWE 404 (redis#2706)
Browse files Browse the repository at this point in the history
* Use provided redis address. Bind to IPv4

* Add missing "await" and perform the correct test for pipe eimpty

* Wait for a send event, rather than rely on sleep time. Excpect cancel errors.

* set delay to 0 except for operation we want to cancel
This speeds up the unit tests considerably by eliminating unnecessary delay.

* Release resources in test

* Fix cluster test to use address_remap and multiple proxies.

* Use context manager to manage DelayProxy

* Mark failing pipeline tests

* lint

* Use a common "master_host" test fixture
  • Loading branch information
kristjanvalur authored May 7, 2023
1 parent ffb2b83 commit 3748a8b
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 155 deletions.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def mock_cluster_resp_slaves(request, **kwargs):
def master_host(request):
url = request.config.getoption("--redis-url")
parts = urlparse(url)
yield parts.hostname, parts.port
return parts.hostname, (parts.port or 6379)


@pytest.fixture()
Expand Down
8 changes: 0 additions & 8 deletions tests/test_asyncio/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import random
from contextlib import asynccontextmanager as _asynccontextmanager
from typing import Union
from urllib.parse import urlparse

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -209,13 +208,6 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
return _gen_cluster_mock_resp(r, response)


@pytest_asyncio.fixture(scope="session")
def master_host(request):
url = request.config.getoption("--redis-url")
parts = urlparse(url)
return parts.hostname


async def wait_for_command(
client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None
):
Expand Down
20 changes: 4 additions & 16 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,6 @@ async def pipe(
await writer.drain()


@pytest.fixture
def redis_addr(request):
redis_url = request.config.getoption("--redis-url")
scheme, netloc = urlparse(redis_url)[:2]
assert scheme == "redis"
if ":" in netloc:
host, port = netloc.split(":")
return host, int(port)
else:
return netloc, 6379


@pytest_asyncio.fixture()
async def slowlog(r: RedisCluster) -> None:
"""
Expand Down Expand Up @@ -874,15 +862,16 @@ async def test_default_node_is_replaced_after_exception(self, r):
# Rollback to the old default node
r.replace_default_node(curr_default_node)

async def test_address_remap(self, create_redis, redis_addr):
async def test_address_remap(self, create_redis, master_host):
"""Test that we can create a rediscluster object with
a host-port remapper and map connections through proxy objects
"""

# we remap the first n nodes
offset = 1000
n = 6
ports = [redis_addr[1] + i for i in range(n)]
hostname, master_port = master_host
ports = [master_port + i for i in range(n)]

def address_remap(address):
# remap first three nodes to our local proxy
Expand All @@ -895,8 +884,7 @@ def address_remap(address):

# create the proxies
proxies = [
NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
for port in ports
NodeProxy(("127.0.0.1", port + offset), (hostname, port)) for port in ports
]
await asyncio.gather(*[p.start() for p in proxies])
try:
Expand Down
10 changes: 5 additions & 5 deletions tests/test_asyncio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,14 @@ async def test_connection_creation(self):
assert connection.kwargs == connection_kwargs

async def test_multiple_connections(self, master_host):
connection_kwargs = {"host": master_host}
connection_kwargs = {"host": master_host[0]}
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
c1 = await pool.get_connection("_")
c2 = await pool.get_connection("_")
assert c1 != c2

async def test_max_connections(self, master_host):
connection_kwargs = {"host": master_host}
connection_kwargs = {"host": master_host[0]}
async with self.get_pool(
max_connections=2, connection_kwargs=connection_kwargs
) as pool:
Expand All @@ -153,7 +153,7 @@ async def test_max_connections(self, master_host):
await pool.get_connection("_")

async def test_reuse_previously_released_connection(self, master_host):
connection_kwargs = {"host": master_host}
connection_kwargs = {"host": master_host[0]}
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
c1 = await pool.get_connection("_")
await pool.release(c1)
Expand Down Expand Up @@ -237,7 +237,7 @@ async def test_multiple_connections(self, master_host):

async def test_connection_pool_blocks_until_timeout(self, master_host):
"""When out of connections, block for timeout seconds, then raise"""
connection_kwargs = {"host": master_host}
connection_kwargs = {"host": master_host[0]}
async with self.get_pool(
max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs
) as pool:
Expand Down Expand Up @@ -270,7 +270,7 @@ async def target():
assert asyncio.get_running_loop().time() - start >= 0.1

async def test_reuse_previously_released_connection(self, master_host):
connection_kwargs = {"host": master_host}
connection_kwargs = {"host": master_host[0]}
async with self.get_pool(connection_kwargs=connection_kwargs) as pool:
c1 = await pool.get_connection("_")
await pool.release(c1)
Expand Down
Loading

0 comments on commit 3748a8b

Please sign in to comment.