diff --git a/posttroll/backends/zmq/socket.py b/posttroll/backends/zmq/socket.py index cebfd72..fcc9411 100644 --- a/posttroll/backends/zmq/socket.py +++ b/posttroll/backends/zmq/socket.py @@ -1,5 +1,6 @@ """ZMQ socket handling functions.""" +from functools import cache from urllib.parse import urlsplit, urlunsplit import zmq @@ -101,6 +102,12 @@ def bind(sock, destination, port_interval): port_number = port return port_number +@cache +def get_auth_thread(ctx): + """Get the authenticator thread for the context.""" + thr = ThreadAuthenticator(ctx) + thr.start() + return thr def create_secure_server_socket(socket_type): """Create a secure server socket.""" @@ -109,10 +116,8 @@ def create_secure_server_socket(socket_type): authorized_sub_addresses = config.get("authorized_client_addresses", []) ctx = get_context() - # Start an authenticator for this context. - authenticator_thread = ThreadAuthenticator(ctx) - authenticator_thread.start() + authenticator_thread = get_auth_thread(ctx) authenticator_thread.allow(*authorized_sub_addresses) # Tell authenticator to use the certificate in a directory authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory) diff --git a/posttroll/message_broadcaster.py b/posttroll/message_broadcaster.py index 4990c36..214302b 100644 --- a/posttroll/message_broadcaster.py +++ b/posttroll/message_broadcaster.py @@ -43,6 +43,8 @@ def __init__(self, default_port, receivers): if backend == "unsecure_zmq": from posttroll.backends.zmq.message_broadcaster import ZMQDesignatedReceiversSender self._sender = ZMQDesignatedReceiversSender(default_port, receivers) + else: + raise NotImplementedError() def __call__(self, data): """Send messages from all receivers.""" diff --git a/posttroll/tests/test_nameserver.py b/posttroll/tests/test_nameserver.py index 88a11ed..bd3f709 100644 --- a/posttroll/tests/test_nameserver.py +++ b/posttroll/tests/test_nameserver.py @@ -60,9 +60,6 @@ def create_nameserver_instance(max_age=3, multicast_enabled=True): ns.stop() thr.join() -def fake_nameserver(): - config.set(nameserver_port=1111) - config.set(address_publish_port()) class TestAddressReceiver(unittest.TestCase): """Test the AddressReceiver.""" diff --git a/posttroll/tests/test_secure_zmq_backend.py b/posttroll/tests/test_secure_zmq_backend.py index f912125..7c4987f 100644 --- a/posttroll/tests/test_secure_zmq_backend.py +++ b/posttroll/tests/test_secure_zmq_backend.py @@ -6,8 +6,10 @@ import time from threading import Thread +import pytest import zmq.auth +import posttroll.backends.zmq from posttroll import config from posttroll.backends.zmq import generate_keys from posttroll.message import Message @@ -17,6 +19,15 @@ from posttroll.tests.test_nameserver import create_nameserver_instance +@pytest.fixture(autouse=True) +def new_context(monkeypatch): + """Create a new context for each test.""" + context = zmq.Context() + def get_context(): + return context + monkeypatch.setattr(posttroll.backends.zmq, "get_context", get_context) + + def create_keys(tmp_path): """Create keys.""" base_dir = tmp_path @@ -111,8 +122,6 @@ def test_switch_to_secure_zmq_backend(tmp_path): def test_ipc_pubsub_with_sec_and_factory_sub(tmp_path): """Test pub-sub on a secure ipc socket.""" - # create_keys(tmp_path) - server_public_key_file, server_secret_key_file = zmq.auth.create_certificates(tmp_path, "server") client_public_key_file, client_secret_key_file = zmq.auth.create_certificates(tmp_path, "client")