diff --git a/src/aiofiles/base.py b/src/aiofiles/base.py index 64f7d6b..35b26d2 100644 --- a/src/aiofiles/base.py +++ b/src/aiofiles/base.py @@ -1,7 +1,6 @@ -"""Various base classes.""" +from asyncio import get_running_loop from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager -from asyncio import get_running_loop class AsyncBase: @@ -23,11 +22,10 @@ def __repr__(self): async def __anext__(self): """Simulate normal file iteration.""" - line = await self.readline() - if line: + + if line := await self.readline(): return line - else: - raise StopAsyncIteration + raise StopAsyncIteration class AsyncIndirectBase(AsyncBase): diff --git a/tests/threadpool/test_open.py b/tests/threadpool/test_open.py index 654e06b..497a403 100644 --- a/tests/threadpool/test_open.py +++ b/tests/threadpool/test_open.py @@ -1,8 +1,16 @@ -"""Test the open functionality.""" -from aiofiles.threadpool import open as aioopen, wrap +import asyncio +from pathlib import Path + +from aiofiles.threadpool import open as aioopen + import pytest +RESOURCES_DIR = Path(__file__).parent.parent / "resources" +TEST_FILE = RESOURCES_DIR / "test_file1.txt" +TEST_FILE_CONTENTS = "0123456789" + + @pytest.mark.parametrize("mode", ["r", "rb"]) async def test_file_not_found(mode): filename = "non_existent" @@ -25,7 +33,38 @@ async def test_file_not_found(mode): assert str(actual) == str(expected) -def test_unsupported_wrap(): - """A type error should be raised when wrapping something unsupported.""" - with pytest.raises(TypeError): - wrap(int) +async def test_file_async_context_aexit(): + async with aioopen(TEST_FILE) as fp: + pass + + with pytest.raises(ValueError): + line = await fp.read() + + async with aioopen(TEST_FILE) as fp: + line = await fp.read() + assert line == TEST_FILE_CONTENTS + + +async def test_filetask_async_context_aexit(): + async def _process_test_file(file_ctx, sleep_time: float = 1.0): + nonlocal file_ref + async with file_ctx as fp: + file_ref = file_ctx._obj + await asyncio.sleep(sleep_time) + await fp.read() + + cancel_time, sleep_time = 0.1, 10 + assert cancel_time <= (sleep_time / 10) + + file_ref = None + file_ctx = aioopen(TEST_FILE) + + task = asyncio.create_task( + _process_test_file(file_ctx=file_ctx, sleep_time=sleep_time) + ) + try: + await asyncio.wait_for(task, timeout=cancel_time) + except asyncio.TimeoutError: + assert task.cancelled + + assert file_ref.closed diff --git a/tests/threadpool/test_wrap.py b/tests/threadpool/test_wrap.py new file mode 100644 index 0000000..114503e --- /dev/null +++ b/tests/threadpool/test_wrap.py @@ -0,0 +1,10 @@ +from aiofiles.threadpool import wrap + +import pytest + + +def test_unsupported_wrap(): + """Raising TypeError when wrapping unsupported entities.""" + + with pytest.raises(TypeError): + wrap(int)