diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index e1ce5c9d..12ead10f 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -14,6 +14,7 @@ from collections.abc import ( AsyncIterator, Awaitable, + Coroutine as AbstractCoroutine, Generator, Iterable, Iterator, @@ -410,7 +411,11 @@ def _get_event_loop_fixture_id_for_async_fixture( return event_loop_fixture_id -def _create_task_in_context(loop, coro, context): +def _create_task_in_context( + loop: asyncio.AbstractEventLoop, + coro: AbstractCoroutine[Any, Any, _T], + context: contextvars.Context, +) -> asyncio.Task[_T]: """ Return an asyncio task that runs the coro in the specified context, if possible. diff --git a/tests/async_fixtures/test_async_fixtures_contextvars.py b/tests/async_fixtures/test_async_fixtures_contextvars.py index 8d67db49..ff79e17e 100644 --- a/tests/async_fixtures/test_async_fixtures_contextvars.py +++ b/tests/async_fixtures/test_async_fixtures_contextvars.py @@ -6,68 +6,242 @@ from __future__ import annotations import sys -from contextlib import contextmanager -from contextvars import ContextVar +from textwrap import dedent import pytest - -_context_var = ContextVar("context_var") +from pytest import Pytester + +_prelude = dedent( + """ + import pytest + import pytest_asyncio + from contextlib import contextmanager + from contextvars import ContextVar + + _context_var = ContextVar("context_var") + + @contextmanager + def context_var_manager(value): + token = _context_var.set(value) + try: + yield + finally: + _context_var.reset(token) +""" +) -@contextmanager -def context_var_manager(value): - token = _context_var.set(value) - try: - yield - finally: - _context_var.reset(token) +def test_var_from_sync_generator_propagates_to_async(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest.fixture + def var_fixture(): + with context_var_manager("value"): + yield + @pytest_asyncio.fixture + async def check_var_fixture(var_fixture): + assert _context_var.get() == "value" -@pytest.fixture(scope="function") -async def no_var_fixture(): - with pytest.raises(LookupError): - _context_var.get() - yield - with pytest.raises(LookupError): - _context_var.get() + @pytest.mark.asyncio + async def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.fixture(scope="function") -async def var_fixture_1(no_var_fixture): - with context_var_manager("value1"): - yield +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_async_generator_propagates_to_sync(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + with context_var_manager("value"): + yield + + @pytest.fixture + def check_var_fixture(var_fixture): + assert _context_var.get() == "value" + + @pytest.mark.asyncio + async def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.fixture(scope="function") -async def var_nop_fixture(var_fixture_1): - with context_var_manager(_context_var.get()): - yield +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_async_fixture_propagates_to_sync(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + _context_var.set("value") + # Rely on async fixture teardown to reset the context var. + + @pytest.fixture + def check_var_fixture(var_fixture): + assert _context_var.get() == "value" + + def test(check_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.fixture(scope="function") -def var_fixture_2(var_nop_fixture): - assert _context_var.get() == "value1" - with context_var_manager("value2"): - yield +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_generator_reset_before_previous_fixture_cleanup(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def no_var_fixture(): + with pytest.raises(LookupError): + _context_var.get() + yield + with pytest.raises(LookupError): + _context_var.get() + + @pytest_asyncio.fixture + async def var_fixture(no_var_fixture): + with context_var_manager("value"): + yield + + @pytest.mark.asyncio + async def test(var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.fixture(scope="function") -async def var_fixture_3(var_fixture_2): - assert _context_var.get() == "value2" - with context_var_manager("value3"): - yield +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_from_fixture_reset_before_previous_fixture_cleanup(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def no_var_fixture(): + with pytest.raises(LookupError): + _context_var.get() + yield + with pytest.raises(LookupError): + _context_var.get() + + @pytest_asyncio.fixture + async def var_fixture(no_var_fixture): + _context_var.set("value") + # Rely on async fixture teardown to reset the context var. + + @pytest.mark.asyncio + async def test(var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.fixture(scope="function") -async def var_fixture_4(var_fixture_3, request): - assert _context_var.get() == "value3" - _context_var.set("value4") - # Rely on fixture teardown to reset the context var. +@pytest.mark.xfail( + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, +) +def test_var_previous_value_restored_after_fixture(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture_1(): + with context_var_manager("value1"): + yield + assert _context_var.get() == "value1" + + @pytest_asyncio.fixture + async def var_fixture_2(var_fixture_1): + with context_var_manager("value2"): + yield + assert _context_var.get() == "value2" + + @pytest.mark.asyncio + async def test(var_fixture_2): + assert _context_var.get() == "value2" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1) -@pytest.mark.asyncio @pytest.mark.xfail( - sys.version_info < (3, 11), reason="requires asyncio Task context support" + sys.version_info < (3, 11), + reason="requires asyncio Task context support", + strict=True, ) -async def test(var_fixture_4): - assert _context_var.get() == "value4" +def test_var_set_to_existing_value_ok(pytester: Pytester): + pytester.makeini("[pytest]\nasyncio_default_fixture_loop_scope = function") + pytester.makepyfile( + _prelude + + dedent( + """ + @pytest_asyncio.fixture + async def var_fixture(): + with context_var_manager("value"): + yield + + @pytest_asyncio.fixture + async def same_var_fixture(var_fixture): + with context_var_manager(_context_var.get()): + yield + + @pytest.mark.asyncio + async def test(same_var_fixture): + assert _context_var.get() == "value" + """ + ) + ) + result = pytester.runpytest("--asyncio-mode=strict") + result.assert_outcomes(passed=1)