diff --git a/src/aiofiles/base.py b/src/aiofiles/base.py index 64f7d6b..1827f7d 100644 --- a/src/aiofiles/base.py +++ b/src/aiofiles/base.py @@ -1,7 +1,7 @@ -"""Various base classes.""" +import asyncio from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager -from asyncio import get_running_loop +from typing import Optional class AsyncBase: @@ -12,7 +12,7 @@ def __init__(self, file, loop, executor): @property def _loop(self): - return self._ref_loop or get_running_loop() + return self._ref_loop or asyncio.get_running_loop() def __aiter__(self): """We are our own iterator.""" @@ -23,11 +23,10 @@ def __repr__(self): async def __anext__(self): """Simulate normal file iteration.""" - line = await self.readline() - if line: + + if (line := await self.readline()): return line - else: - raise StopAsyncIteration + raise StopAsyncIteration class AsyncIndirectBase(AsyncBase): @@ -63,7 +62,11 @@ async def __aenter__(self): return await self async def __aexit__(self, exc_type, exc_val, exc_tb): - await get_running_loop().run_in_executor( - None, self._obj._file.__exit__, exc_type, exc_val, exc_tb + await asyncio.get_running_loop().run_in_executor( + None, + self._obj._file.__exit__, + exc_type, + exc_val, + exc_tb, ) self._obj = None diff --git a/tests/threadpool/test_open.py b/tests/threadpool/test_open.py index 654e06b..5e2c40a 100644 --- a/tests/threadpool/test_open.py +++ b/tests/threadpool/test_open.py @@ -1,8 +1,16 @@ -"""Test the open functionality.""" -from aiofiles.threadpool import open as aioopen, wrap +import asyncio +from pathlib import Path + +from aiofiles.threadpool import open as aioopen + import pytest +RESOURCES_DIR = Path(__file__).parent.parent / "resources" +TEST_FILE = RESOURCES_DIR / "test_file1.txt" +TEST_FILE_CONTENTS = "0123456789" + + @pytest.mark.parametrize("mode", ["r", "rb"]) async def test_file_not_found(mode): filename = "non_existent" @@ -25,7 +33,37 @@ async def test_file_not_found(mode): assert str(actual) == str(expected) -def test_unsupported_wrap(): - """A type error should be raised when wrapping something unsupported.""" - with pytest.raises(TypeError): - wrap(int) +async def test_file_async_context_aexit(): + async with aioopen(TEST_FILE) as fp: + pass + + with pytest.raises(ValueError): + line = await fp.read() + + async with aioopen(TEST_FILE) as fp: + line = await fp.read() + assert line == TEST_FILE_CONTENTS + + +async def test_filetask_async_context_aexit(): + async def _process_test_file(file_ctx, sleep_time: float = 1.0): + nonlocal file_ref + async with file_ctx: + file_ref = file_ctx._obj + await asyncio.sleep(sleep_time) + + cancel_time, sleep_time = 0.1, 10 + assert cancel_time <= (sleep_time / 10) + + file_ref = None + file_ctx = aioopen(TEST_FILE) + + task = asyncio.create_task( + _process_test_file(file_ctx=file_ctx, sleep_time=sleep_time) + ) + try: + await asyncio.wait_for(task, timeout=cancel_time) + except asyncio.TimeoutError: + assert task.cancelled + + assert file_ref.closed diff --git a/tests/threadpool/test_wrap.py b/tests/threadpool/test_wrap.py new file mode 100644 index 0000000..114503e --- /dev/null +++ b/tests/threadpool/test_wrap.py @@ -0,0 +1,10 @@ +from aiofiles.threadpool import wrap + +import pytest + + +def test_unsupported_wrap(): + """Raising TypeError when wrapping unsupported entities.""" + + with pytest.raises(TypeError): + wrap(int)