Skip to content

Commit

Permalink
Add deadlock prevention
Browse files Browse the repository at this point in the history
Detects and prevents deadlocks during the library misuse, eg. by
injecting code into the critical sections that itself might want to
obtain the relevant lock.

A follow up to prometheus#1076.

Signed-off-by: Przemysław Suliga <mail@suligap.net>
  • Loading branch information
suligap committed Dec 8, 2024
1 parent ef95c4b commit fa51a75
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 13 deletions.
3 changes: 3 additions & 0 deletions prometheus_client/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

class PrometheusClientRuntimeError(RuntimeError):
pass
9 changes: 4 additions & 5 deletions prometheus_client/metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from threading import Lock
import time
import types
from typing import (
Expand All @@ -13,7 +12,7 @@
from .metrics_core import Metric
from .registry import Collector, CollectorRegistry, REGISTRY
from .samples import Exemplar, Sample
from .utils import floatToGoString, INF
from .utils import floatToGoString, INF, WarnLock
from .validation import (
_validate_exemplar, _validate_labelnames, _validate_metric_name,
)
Expand Down Expand Up @@ -120,7 +119,7 @@ def __init__(self: T,

if self._is_parent():
# Prepare the fields needed for child metrics.
self._lock = Lock()
self._lock = WarnLock()
self._metrics: Dict[Sequence[str], T] = {}

if self._is_observable():
Expand Down Expand Up @@ -673,7 +672,7 @@ class Info(MetricWrapperBase):

def _metric_init(self):
self._labelname_set = set(self._labelnames)
self._lock = Lock()
self._lock = WarnLock()
self._value = {}

def info(self, val: Dict[str, str]) -> None:
Expand Down Expand Up @@ -735,7 +734,7 @@ def __init__(self,

def _metric_init(self) -> None:
self._value = 0
self._lock = Lock()
self._lock = WarnLock()

def state(self, state: str) -> None:
"""Set enum metric state."""
Expand Down
4 changes: 2 additions & 2 deletions prometheus_client/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
import copy
from threading import Lock
from typing import Dict, Iterable, List, Optional

from .metrics_core import Metric
from .utils import WarnLock


# Ideally this would be a Protocol, but Protocols are only available in Python >= 3.8.
Expand All @@ -30,7 +30,7 @@ def __init__(self, auto_describe: bool = False, target_info: Optional[Dict[str,
self._collector_to_names: Dict[Collector, List[str]] = {}
self._names_to_collectors: Dict[str, Collector] = {}
self._auto_describe = auto_describe
self._lock = Lock()
self._lock = WarnLock()
self._target_info: Optional[Dict[str, str]] = {}
self.set_target_info(target_info)

Expand Down
40 changes: 40 additions & 0 deletions prometheus_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import math
from threading import Lock, RLock

from .errors import PrometheusClientRuntimeError

INF = float("inf")
MINUS_INF = float("-inf")
Expand All @@ -22,3 +25,40 @@ def floatToGoString(d):
mantissa = f'{s[0]}.{s[1:dot]}{s[dot + 1:]}'.rstrip('0.')
return f'{mantissa}e+0{dot - 1}'
return s


class WarnLock:
"""A wrapper around RLock and Lock that prevents deadlocks.
Raises a RuntimeError when it detects attempts to re-enter the critical
section from a single thread. Intended to be used as a context manager.
"""
error_msg = (
'Attempt to enter a non reentrant context from a single thread.'
' It is possible that the client code is trying to register or update'
' metrics from within metric registration code or from a signal handler'
' while metrics are being registered or updated.'
' This is unsafe and cannot be allowed. It would result in a deadlock'
' if this exception was not raised.'
)

def __init__(self):
self._rlock = RLock()
self._lock = Lock()

def __enter__(self):
self._rlock.acquire()
if not self._lock.acquire(blocking=False):
self._rlock.release()
raise PrometheusClientRuntimeError(self.error_msg)

def __exit__(self, exc_type, exc_value, traceback):
self._lock.release()
self._rlock.release()

def _locked(self):
# For use in tests.
if self._rlock.acquire(blocking=False):
self._rlock.release()
return False
return True
9 changes: 4 additions & 5 deletions prometheus_client/values.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from threading import Lock
import warnings

from .mmap_dict import mmap_key, MmapedDict
from .utils import WarnLock


class MutexValue:
Expand All @@ -13,7 +13,7 @@ class MutexValue:
def __init__(self, typ, metric_name, name, labelnames, labelvalues, help_text, **kwargs):
self._value = 0.0
self._exemplar = None
self._lock = Lock()
self._lock = WarnLock()

def inc(self, amount):
with self._lock:
Expand Down Expand Up @@ -47,10 +47,9 @@ def MultiProcessValue(process_identifier=os.getpid):
files = {}
values = []
pid = {'value': process_identifier()}
# Use a single global lock when in multi-processing mode
# as we presume this means there is no threading going on.
# Use a single global lock when in multi-processing mode.
# This avoids the need to also have mutexes in __MmapDict.
lock = Lock()
lock = WarnLock()

class MmapedValue:
"""A float protected by a mutex backed by a per-process mmaped file."""
Expand Down
29 changes: 28 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StateSetMetricFamily, Summary, SummaryMetricFamily, UntypedMetricFamily,
)
from prometheus_client.decorator import getargspec
from prometheus_client.errors import PrometheusClientRuntimeError
from prometheus_client.metrics import _get_use_created
from prometheus_client.validation import (
disable_legacy_validation, enable_legacy_validation,
Expand Down Expand Up @@ -134,6 +135,19 @@ def test_exemplar_too_long(self):
'y123456': '7+15 characters',
})

def test_single_thread_deadlock_detection(self):
counter = self.counter

class Tracked(float):
def __radd__(self, other):
counter.inc(10)
return self + other

expected_msg = 'Attempt to enter a non reentrant context from a single thread.'
self.assertRaisesRegex(
PrometheusClientRuntimeError, expected_msg, counter.inc, Tracked(100)
)


class TestDisableCreated(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1004,7 +1018,20 @@ def test_restricted_registry_does_not_yield_while_locked(self):
m = Metric('target', 'Target metadata', 'info')
m.samples = [Sample('target_info', {'foo': 'bar'}, 1)]
for _ in registry.restricted_registry(['target_info', 's_sum']).collect():
self.assertFalse(registry._lock.locked())
self.assertFalse(registry._lock._locked())

def test_registry_deadlock_detection(self):
registry = CollectorRegistry(auto_describe=True)

class RecursiveCollector:
def collect(self):
Counter('x', 'help', registry=registry)
return [CounterMetricFamily('c_total', 'help', value=1)]

expected_msg = 'Attempt to enter a non reentrant context from a single thread.'
self.assertRaisesRegex(
PrometheusClientRuntimeError, expected_msg, registry.register, RecursiveCollector()
)


if __name__ == '__main__':
Expand Down

0 comments on commit fa51a75

Please sign in to comment.