Skip to content

Commit

Permalink
Fix race condition in authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
mraspaud committed Feb 3, 2025
1 parent 010be39 commit d336876
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 59 deletions.
13 changes: 0 additions & 13 deletions posttroll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,3 @@
config = Config("posttroll", defaults=[dict(backend="unsecure_zmq")])

logger = logging.getLogger(__name__)


def get_context():
"""Provide the context to use.
This function takes care of creating new contexts in case of forks.
"""
backend = config["backend"]
if "zmq" in backend:
from posttroll.backends.zmq import get_context
return get_context()
else:
raise NotImplementedError(f"No support for backend {backend} implemented (yet?).")
15 changes: 10 additions & 5 deletions posttroll/backends/zmq/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from contextlib import suppress
from functools import cache
from threading import Lock
from urllib.parse import urlsplit, urlunsplit

import zmq
from zmq.auth import load_certificate
from zmq.auth.thread import ThreadAuthenticator

from posttroll import config, get_context
from posttroll import config
from posttroll.backends.zmq import get_context
from posttroll.message import Message

authenticator_lock = Lock()

def close_socket(sock):
"""Close a zmq socket."""
Expand Down Expand Up @@ -49,11 +53,11 @@ def create_secure_client_socket(socket_type):

client_secret_key_file = config["client_secret_key_file"]
server_public_key_file = config["server_public_key_file"]
client_public, client_secret = zmq.auth.load_certificate(client_secret_key_file)
client_public, client_secret = load_certificate(client_secret_key_file)
subscriber.curve_secretkey = client_secret
subscriber.curve_publickey = client_public

server_public, _ = zmq.auth.load_certificate(server_public_key_file)
server_public, _ = load_certificate(server_public_key_file)
# The client must know the server's public key to make a CURVE connection.
subscriber.curve_serverkey = server_public
return subscriber
Expand Down Expand Up @@ -119,14 +123,15 @@ def create_secure_server_socket(socket_type):

ctx = get_context()
# Start an authenticator for this context.
authenticator_thread = get_auth_thread(ctx)
with authenticator_lock:
authenticator_thread = get_auth_thread(ctx)
authenticator_thread.allow(*authorized_sub_addresses)
# Tell authenticator to use the certificate in a directory
authenticator_thread.configure_curve(domain="*", location=clients_public_keys_directory)

server_socket = ctx.socket(socket_type)

server_public, server_secret = zmq.auth.load_certificate(server_secret_key)
server_public, server_secret = load_certificate(server_secret_key)
server_socket.curve_secretkey = server_secret
server_socket.curve_publickey = server_public
server_socket.curve_server = True
Expand Down
13 changes: 0 additions & 13 deletions posttroll/tests/test_nameserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
from unittest import mock

import pytest
import zmq

import posttroll.backends.zmq
from posttroll import config
from posttroll.backends.zmq.ns import create_nameserver_address
from posttroll.message import Message
Expand All @@ -21,17 +19,6 @@
from posttroll.tests.test_bbmcast import random_valid_mc_address


@pytest.fixture(autouse=True)
def new_context(monkeypatch):
"""Create a new context for each test."""
context = zmq.Context()
def get_context():
return context
monkeypatch.setattr(posttroll.backends.zmq, "get_context", get_context)
yield
context.term()


@pytest.fixture(autouse=True)
def new_mc_group():
"""Create a unique mc group for each test."""
Expand Down
15 changes: 0 additions & 15 deletions posttroll/tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,13 @@
from donfig import Config

import posttroll
import posttroll.backends.zmq
from posttroll import config
from posttroll.message import Message
from posttroll.publisher import Publish, Publisher, create_publisher_from_dict_config
from posttroll.subscriber import Subscribe, Subscriber

test_lock = Lock()


@pytest.fixture(autouse=True)
def new_context(monkeypatch):
"""Create a new context for each test."""
context = zmq.Context()
def get_context():
return context
monkeypatch.setattr(posttroll.backends.zmq, "get_context", get_context)
yield
context.term()


def free_port():
"""Get a free port.
Expand Down Expand Up @@ -501,7 +488,6 @@ def test_subscriber_tcp_keepalive_not_set():


def _assert_tcp_keepalive(socket):
import zmq

assert socket.getsockopt(zmq.TCP_KEEPALIVE) == 1
assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == 10
Expand All @@ -510,7 +496,6 @@ def _assert_tcp_keepalive(socket):


def _assert_no_tcp_keepalive(socket):
import zmq

assert socket.getsockopt(zmq.TCP_KEEPALIVE) == -1
assert socket.getsockopt(zmq.TCP_KEEPALIVE_CNT) == -1
Expand Down
13 changes: 0 additions & 13 deletions posttroll/tests/test_secure_zmq_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
import time
from threading import Thread

import pytest
import zmq.auth

import posttroll.backends.zmq
from posttroll import config
from posttroll.backends.zmq import generate_keys
from posttroll.message import Message
Expand All @@ -19,17 +17,6 @@
from posttroll.tests.test_nameserver import create_nameserver_instance


@pytest.fixture(autouse=True)
def new_context(monkeypatch):
"""Create a new context for each test."""
context = zmq.Context()
def get_context():
return context
monkeypatch.setattr(posttroll.backends.zmq, "get_context", get_context)
yield
context.term()


def create_keys(tmp_path):
"""Create keys."""
base_dir = tmp_path
Expand Down

0 comments on commit d336876

Please sign in to comment.