From 086126fd8d0065336e40f7ea05546c5e62313e1b Mon Sep 17 00:00:00 2001 From: Delgan <4193924+Delgan@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:19:41 +0200 Subject: [PATCH] Fix error using "set_start_method()" after "logger" import (#974) Calling "multiprocessing.get_context(method=None)" had the unexpected side effect of also fixing the global start method (which can't be changed afterwards). --- CHANGELOG.rst | 1 + loguru/_handler.py | 12 ++++++-- loguru/_logger.py | 4 +-- tests/test_add_option_context.py | 52 ++++++++++++++++++-------------- 4 files changed, 42 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9ed06198..7a8d112d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,7 @@ ============= - Add support for formatting of ``ExceptionGroup`` errors (`#805 `_). +- Fix possible ``RuntimeError`` when using ``multiprocessing.set_start_method()`` after importing the ``logger`` (`#974 `_) - Fix formatting of possible ``__notes__`` attached to an ``Exception`` (`#980 `_). diff --git a/loguru/_handler.py b/loguru/_handler.py index 6b684d21..81a3dca0 100644 --- a/loguru/_handler.py +++ b/loguru/_handler.py @@ -1,5 +1,6 @@ import functools import json +import multiprocessing import os import threading from contextlib import contextmanager @@ -88,10 +89,15 @@ def __init__( self._decolorized_format = self._formatter.strip() if self._enqueue: - self._queue = self._multiprocessing_context.SimpleQueue() + if self._multiprocessing_context is None: + self._queue = multiprocessing.SimpleQueue() + self._confirmation_event = multiprocessing.Event() + self._confirmation_lock = multiprocessing.Lock() + else: + self._queue = self._multiprocessing_context.SimpleQueue() + self._confirmation_event = self._multiprocessing_context.Event() + self._confirmation_lock = self._multiprocessing_context.Lock() self._queue_lock = create_handler_lock() - self._confirmation_event = self._multiprocessing_context.Event() - self._confirmation_lock = self._multiprocessing_context.Lock() self._owner_process_pid = os.getpid() self._thread = Thread( target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id diff --git a/loguru/_logger.py b/loguru/_logger.py index 09713f64..f750967a 100644 --- a/loguru/_logger.py +++ b/loguru/_logger.py @@ -967,9 +967,9 @@ def add( if not isinstance(encoding, str): encoding = "ascii" - if context is None or isinstance(context, str): + if isinstance(context, str): context = get_context(context) - elif not isinstance(context, BaseContext): + elif context is not None and not isinstance(context, BaseContext): raise TypeError( "Invalid context, it should be a string or a multiprocessing context, " "not: '%s'" % type(context).__name__ diff --git a/tests/test_add_option_context.py b/tests/test_add_option_context.py index ee6d7c1f..c9d3c27e 100644 --- a/tests/test_add_option_context.py +++ b/tests/test_add_option_context.py @@ -1,55 +1,63 @@ import multiprocessing import os -from unittest.mock import MagicMock +from unittest.mock import patch import pytest from loguru import logger -def get_handler_context(): - # No better way to test correct value than to access the private attribute. - handler = next(iter(logger._core.handlers.values())) - return handler._multiprocessing_context +@pytest.fixture +def reset_start_method(): + yield + multiprocessing.set_start_method(None, force=True) -def test_default_context(): - logger.add(lambda _: None, context=None) - assert get_handler_context() == multiprocessing.get_context(None) +@pytest.mark.usefixtures("reset_start_method") +def test_using_multiprocessing_directly_if_context_is_none(): + logger.add(lambda _: None, enqueue=True, context=None) + assert multiprocessing.get_start_method(allow_none=True) is not None @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") @pytest.mark.parametrize("context_name", ["fork", "forkserver"]) def test_fork_context_as_string(context_name): - logger.add(lambda _: None, context=context_name) - assert get_handler_context() == multiprocessing.get_context(context_name) + context = multiprocessing.get_context(context_name) + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: + logger.add(lambda _: None, context=context_name, enqueue=True) + assert mock.called + assert multiprocessing.get_start_method(allow_none=True) is None def test_spawn_context_as_string(): - logger.add(lambda _: None, context="spawn") - assert get_handler_context() == multiprocessing.get_context("spawn") + context = multiprocessing.get_context("spawn") + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: + logger.add(lambda _: None, context="spawn", enqueue=True) + assert mock.called + assert multiprocessing.get_start_method(allow_none=True) is None @pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking") @pytest.mark.parametrize("context_name", ["fork", "forkserver"]) def test_fork_context_as_object(context_name): context = multiprocessing.get_context(context_name) - logger.add(lambda _: None, context=context) - assert get_handler_context() == context + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: + logger.add(lambda _: None, context=context, enqueue=True) + assert mock.called + assert multiprocessing.get_start_method(allow_none=True) is None def test_spawn_context_as_object(): context = multiprocessing.get_context("spawn") - logger.add(lambda _: None, context=context) - assert get_handler_context() == context + with patch.object(type(context), "Lock", wraps=context.Lock) as mock: + logger.add(lambda _: None, context=context, enqueue=True) + assert mock.called + assert multiprocessing.get_start_method(allow_none=True) is None -def test_context_effectively_used(): - default_context = multiprocessing.get_context() - mocked_context = MagicMock(spec=default_context, wraps=default_context) - logger.add(lambda _: None, context=mocked_context, enqueue=True) - logger.complete() - assert mocked_context.Lock.called +def test_global_start_method_is_none_if_enqueue_is_false(): + logger.add(lambda _: None, enqueue=False, context=None) + assert multiprocessing.get_start_method(allow_none=True) is None def test_invalid_context_name():