From a15feded71dd47202db169613effdafc468a8cf3 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Tue, 23 Jul 2024 13:06:03 -0400 Subject: [PATCH] gh-120974: Make _asyncio._leave_task atomic in the free-threaded build (#122139) * gh-120974: Make _asyncio._leave_task atomic in the free-threaded build Update `_PyDict_DelItemIf` to allow for an argument to be passed to the predicate. --- Include/internal/pycore_dict.h | 8 +++++-- Modules/_asynciomodule.c | 42 +++++++++++++++++++--------------- Modules/_weakref.c | 13 +++-------- Objects/dictobject.c | 30 ++++++++++++------------ 4 files changed, 48 insertions(+), 45 deletions(-) diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h index a4bdf0d7ad8283..fc304aca7fea10 100644 --- a/Include/internal/pycore_dict.h +++ b/Include/internal/pycore_dict.h @@ -14,8 +14,12 @@ extern "C" { // Unsafe flavor of PyDict_GetItemWithError(): no error checking extern PyObject* _PyDict_GetItemWithError(PyObject *dp, PyObject *key); -extern int _PyDict_DelItemIf(PyObject *mp, PyObject *key, - int (*predicate)(PyObject *value)); +// Delete an item from a dict if a predicate is true +// Returns -1 on error, 1 if the item was deleted, 0 otherwise +// Export for '_asyncio' shared extension +PyAPI_FUNC(int) _PyDict_DelItemIf(PyObject *mp, PyObject *key, + int (*predicate)(PyObject *value, void *arg), + void *arg); // "KnownHash" variants // Export for '_asyncio' shared extension diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c index c44e89d98256fe..1a223f9bd0cbae 100644 --- a/Modules/_asynciomodule.c +++ b/Modules/_asynciomodule.c @@ -1994,30 +1994,36 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task) return 0; } +static int +err_leave_task(PyObject *item, PyObject *task) +{ + PyErr_Format( + PyExc_RuntimeError, + "Leaving task %R does not match the current task %R.", + task, item); + return -1; +} + +static int +leave_task_predicate(PyObject *item, void *task) +{ + if (item != task) { + return err_leave_task(item, (PyObject *)task); + } + return 1; +} static int leave_task(asyncio_state *state, PyObject *loop, PyObject *task) /*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/ { - PyObject *item; - Py_hash_t hash; - hash = PyObject_Hash(loop); - if (hash == -1) { - return -1; - } - item = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash); - if (item != task) { - if (item == NULL) { - /* Not entered, replace with None */ - item = Py_None; - } - PyErr_Format( - PyExc_RuntimeError, - "Leaving task %R does not match the current task %R.", - task, item, NULL); - return -1; + int res = _PyDict_DelItemIf(state->current_tasks, loop, + leave_task_predicate, task); + if (res == 0) { + // task was not found + return err_leave_task(Py_None, task); } - return _PyDict_DelItem_KnownHash(state->current_tasks, loop, hash); + return res; } static PyObject * diff --git a/Modules/_weakref.c b/Modules/_weakref.c index a5c15c0f10b930..ecaa08ff60f203 100644 --- a/Modules/_weakref.c +++ b/Modules/_weakref.c @@ -31,7 +31,7 @@ _weakref_getweakrefcount_impl(PyObject *module, PyObject *object) static int -is_dead_weakref(PyObject *value) +is_dead_weakref(PyObject *value, void *unused) { if (!PyWeakref_Check(value)) { PyErr_SetString(PyExc_TypeError, "not a weakref"); @@ -56,15 +56,8 @@ _weakref__remove_dead_weakref_impl(PyObject *module, PyObject *dct, PyObject *key) /*[clinic end generated code: output=d9ff53061fcb875c input=19fc91f257f96a1d]*/ { - if (_PyDict_DelItemIf(dct, key, is_dead_weakref) < 0) { - if (PyErr_ExceptionMatches(PyExc_KeyError)) - /* This function is meant to allow safe weak-value dicts - with GC in another thread (see issue #28427), so it's - ok if the key doesn't exist anymore. - */ - PyErr_Clear(); - else - return NULL; + if (_PyDict_DelItemIf(dct, key, is_dead_weakref, NULL) < 0) { + return NULL; } Py_RETURN_NONE; } diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 7310c3c8e13b5b..ee88576cc77dec 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2508,7 +2508,7 @@ delete_index_from_values(PyDictValues *values, Py_ssize_t ix) values->size = size; } -static int +static void delitem_common(PyDictObject *mp, Py_hash_t hash, Py_ssize_t ix, PyObject *old_value, uint64_t new_version) { @@ -2550,7 +2550,6 @@ delitem_common(PyDictObject *mp, Py_hash_t hash, Py_ssize_t ix, Py_DECREF(old_value); ASSERT_CONSISTENT(mp); - return 0; } int @@ -2593,7 +2592,8 @@ delitem_knownhash_lock_held(PyObject *op, PyObject *key, Py_hash_t hash) PyInterpreterState *interp = _PyInterpreterState_GET(); uint64_t new_version = _PyDict_NotifyEvent( interp, PyDict_EVENT_DELETED, mp, key, NULL); - return delitem_common(mp, hash, ix, old_value, new_version); + delitem_common(mp, hash, ix, old_value, new_version); + return 0; } int @@ -2608,7 +2608,8 @@ _PyDict_DelItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) static int delitemif_lock_held(PyObject *op, PyObject *key, - int (*predicate)(PyObject *value)) + int (*predicate)(PyObject *value, void *arg), + void *arg) { Py_ssize_t ix; PyDictObject *mp; @@ -2618,24 +2619,20 @@ delitemif_lock_held(PyObject *op, PyObject *key, ASSERT_DICT_LOCKED(op); - if (!PyDict_Check(op)) { - PyErr_BadInternalCall(); - return -1; - } assert(key); hash = PyObject_Hash(key); if (hash == -1) return -1; mp = (PyDictObject *)op; ix = _Py_dict_lookup(mp, key, hash, &old_value); - if (ix == DKIX_ERROR) + if (ix == DKIX_ERROR) { return -1; + } if (ix == DKIX_EMPTY || old_value == NULL) { - _PyErr_SetKeyError(key); - return -1; + return 0; } - res = predicate(old_value); + res = predicate(old_value, arg); if (res == -1) return -1; @@ -2643,7 +2640,8 @@ delitemif_lock_held(PyObject *op, PyObject *key, PyInterpreterState *interp = _PyInterpreterState_GET(); uint64_t new_version = _PyDict_NotifyEvent( interp, PyDict_EVENT_DELETED, mp, key, NULL); - return delitem_common(mp, hash, ix, old_value, new_version); + delitem_common(mp, hash, ix, old_value, new_version); + return 1; } else { return 0; } @@ -2655,11 +2653,13 @@ delitemif_lock_held(PyObject *op, PyObject *key, */ int _PyDict_DelItemIf(PyObject *op, PyObject *key, - int (*predicate)(PyObject *value)) + int (*predicate)(PyObject *value, void *arg), + void *arg) { + assert(PyDict_Check(op)); int res; Py_BEGIN_CRITICAL_SECTION(op); - res = delitemif_lock_held(op, key, predicate); + res = delitemif_lock_held(op, key, predicate, arg); Py_END_CRITICAL_SECTION(); return res; }