-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CHIA-710] Add the concept of 'action scopes' (#18124)
* Add the concept of 'action scopes' * pylint and test coverage * add try/finally * Address comments by @altendky * Address comments by @altendky * 86 memos * Only one callback * pylint * Address more comments by @altendky * remove unused variable * add comment
- Loading branch information
1 parent
a6fca99
commit a36c0b8
Showing
2 changed files
with
292 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import AsyncIterator, final | ||
|
||
import pytest | ||
|
||
from chia.util.action_scope import ActionScope, StateInterface | ||
|
||
|
||
@final | ||
@dataclass | ||
class TestSideEffects: | ||
buf: bytes = b"" | ||
|
||
def __bytes__(self) -> bytes: | ||
return self.buf | ||
|
||
@classmethod | ||
def from_bytes(cls, blob: bytes) -> TestSideEffects: | ||
return cls(blob) | ||
|
||
|
||
async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None: | ||
return None # pragma: no cover | ||
|
||
|
||
# Test adding a callback | ||
def test_set_callback() -> None: | ||
state_interface = StateInterface(TestSideEffects(), True) | ||
state_interface.set_callback(default_async_callback) | ||
assert state_interface._callback == default_async_callback | ||
state_interface_no_callbacks = StateInterface(TestSideEffects(), False) | ||
with pytest.raises(RuntimeError, match="Callback cannot be edited from inside itself"): | ||
state_interface_no_callbacks.set_callback(None) | ||
|
||
|
||
@pytest.fixture(name="action_scope") | ||
async def action_scope_fixture() -> AsyncIterator[ActionScope[TestSideEffects]]: | ||
async with ActionScope.new_scope(TestSideEffects) as scope: | ||
yield scope | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> None: | ||
""" | ||
Assert we can immediately check out some initial state | ||
""" | ||
async with action_scope.use() as interface: | ||
assert interface == StateInterface(TestSideEffects(), True) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) -> None: | ||
async with action_scope.use() as interface: | ||
interface.side_effects.buf = b"baz" | ||
|
||
async with action_scope.use() as interface: | ||
assert interface.side_effects.buf == b"baz" | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> None: | ||
async with action_scope.use() as interface: | ||
interface.side_effects.buf = b"baz" | ||
|
||
with pytest.raises(Exception, match="Going to be caught"): | ||
async with action_scope.use() as interface: | ||
interface.side_effects.buf = b"qat" | ||
raise RuntimeError("Going to be caught") | ||
|
||
async with action_scope.use() as interface: | ||
assert interface.side_effects.buf == b"baz" | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_callbacks() -> None: | ||
async with ActionScope.new_scope(TestSideEffects) as action_scope: | ||
async with action_scope.use() as interface: | ||
|
||
async def callback(interface: StateInterface[TestSideEffects]) -> None: | ||
interface.side_effects.buf = b"bar" | ||
|
||
interface.set_callback(callback) | ||
|
||
assert action_scope.side_effects.buf == b"bar" | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_callback_in_callback_error() -> None: | ||
with pytest.raises(RuntimeError, match="Callback"): | ||
async with ActionScope.new_scope(TestSideEffects) as action_scope: | ||
async with action_scope.use() as interface: | ||
|
||
async def callback(interface: StateInterface[TestSideEffects]) -> None: | ||
interface.set_callback(default_async_callback) | ||
|
||
interface.set_callback(callback) | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_no_callbacks_if_error() -> None: | ||
with pytest.raises(Exception, match="This should prevent the callbacks from being called"): | ||
async with ActionScope.new_scope(TestSideEffects) as action_scope: | ||
async with action_scope.use() as interface: | ||
|
||
async def callback(interface: StateInterface[TestSideEffects]) -> None: | ||
raise NotImplementedError("Should not get here") # pragma: no cover | ||
|
||
interface.set_callback(callback) | ||
|
||
async with action_scope.use() as interface: | ||
raise RuntimeError("This should prevent the callbacks from being called") | ||
|
||
with pytest.raises(Exception, match="This should prevent the callbacks from being called"): | ||
async with ActionScope.new_scope(TestSideEffects) as action_scope: | ||
async with action_scope.use() as interface: | ||
|
||
async def callback2(interface: StateInterface[TestSideEffects]) -> None: | ||
raise NotImplementedError("Should not get here") # pragma: no cover | ||
|
||
interface.set_callback(callback2) | ||
|
||
raise RuntimeError("This should prevent the callbacks from being called") | ||
|
||
|
||
# TODO: add suport, change this test to test it and add a test for nested transactionality | ||
@pytest.mark.anyio | ||
async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects]) -> None: | ||
async with action_scope.use(): | ||
with pytest.raises(RuntimeError, match="cannot currently support nested transactions"): | ||
async with action_scope.use(): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from __future__ import annotations | ||
|
||
import contextlib | ||
from dataclasses import dataclass, field | ||
from typing import AsyncIterator, Awaitable, Callable, Generic, Optional, Protocol, Type, TypeVar, final | ||
|
||
import aiosqlite | ||
|
||
from chia.util.db_wrapper import DBWrapper2, execute_fetchone | ||
|
||
|
||
class ResourceManager(Protocol): | ||
@classmethod | ||
@contextlib.asynccontextmanager | ||
async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: # pragma: no cover | ||
# yield included to make this a generator as expected by @contextlib.asynccontextmanager | ||
yield # type: ignore[misc] | ||
|
||
@contextlib.asynccontextmanager | ||
async def use(self) -> AsyncIterator[None]: # pragma: no cover | ||
# yield included to make this a generator as expected by @contextlib.asynccontextmanager | ||
yield | ||
|
||
async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: ... | ||
|
||
async def save_resource(self, resource: SideEffects) -> None: ... | ||
|
||
|
||
@dataclass | ||
class SQLiteResourceManager: | ||
|
||
_db: DBWrapper2 | ||
_active_writer: Optional[aiosqlite.Connection] = field(init=False, default=None) | ||
|
||
def get_active_writer(self) -> aiosqlite.Connection: | ||
if self._active_writer is None: | ||
raise RuntimeError("Can only access resources while under `use()` context manager") | ||
|
||
return self._active_writer | ||
|
||
@classmethod | ||
@contextlib.asynccontextmanager | ||
async def managed(cls, initial_resource: SideEffects) -> AsyncIterator[ResourceManager]: | ||
async with DBWrapper2.managed(":memory:", reader_count=0) as db: | ||
self = cls(db) | ||
async with self._db.writer() as conn: | ||
await conn.execute("CREATE TABLE side_effects(total blob)") | ||
await conn.execute( | ||
"INSERT INTO side_effects VALUES(?)", | ||
(bytes(initial_resource),), | ||
) | ||
yield self | ||
|
||
@contextlib.asynccontextmanager | ||
async def use(self) -> AsyncIterator[None]: | ||
if self._active_writer is not None: | ||
raise RuntimeError("SQLiteResourceManager cannot currently support nested transactions") | ||
async with self._db.writer() as conn: | ||
self._active_writer = conn | ||
try: | ||
yield | ||
finally: | ||
self._active_writer = None | ||
|
||
async def get_resource(self, resource_type: Type[_T_SideEffects]) -> _T_SideEffects: | ||
row = await execute_fetchone(self.get_active_writer(), "SELECT total FROM side_effects") | ||
assert row is not None | ||
side_effects = resource_type.from_bytes(row[0]) | ||
return side_effects | ||
|
||
async def save_resource(self, resource: SideEffects) -> None: | ||
# This sets all rows (there's only one) to the new serialization | ||
await self.get_active_writer().execute( | ||
"UPDATE side_effects SET total=?", | ||
(bytes(resource),), | ||
) | ||
|
||
|
||
class SideEffects(Protocol): | ||
def __bytes__(self) -> bytes: ... | ||
|
||
@classmethod | ||
def from_bytes(cls: Type[_T_SideEffects], blob: bytes) -> _T_SideEffects: ... | ||
|
||
|
||
_T_SideEffects = TypeVar("_T_SideEffects", bound=SideEffects) | ||
|
||
|
||
@final | ||
@dataclass | ||
class ActionScope(Generic[_T_SideEffects]): | ||
""" | ||
The idea of an "action" is to map a single client input to many potentially distributed functions and side | ||
effects. The action holds on to a temporary state that the many callers modify at will but only one at a time. | ||
When the action is closed, the state is still available and can be committed elsewhere or discarded. | ||
Utilizes a "resource manager" to hold the state in order to take advantage of rollbacks and prevent concurrent tasks | ||
from interferring with each other. | ||
""" | ||
|
||
_resource_manager: ResourceManager | ||
_side_effects_format: Type[_T_SideEffects] | ||
_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = None | ||
_final_side_effects: Optional[_T_SideEffects] = field(init=False, default=None) | ||
|
||
@property | ||
def side_effects(self) -> _T_SideEffects: | ||
if self._final_side_effects is None: | ||
raise RuntimeError( | ||
"Can only request ActionScope.side_effects after exiting context manager. " | ||
"While in context manager, use ActionScope.use()." | ||
) | ||
|
||
return self._final_side_effects | ||
|
||
@classmethod | ||
@contextlib.asynccontextmanager | ||
async def new_scope( | ||
cls, | ||
side_effects_format: Type[_T_SideEffects], | ||
resource_manager_backend: Type[ResourceManager] = SQLiteResourceManager, | ||
) -> AsyncIterator[ActionScope[_T_SideEffects]]: | ||
async with resource_manager_backend.managed(side_effects_format()) as resource_manager: | ||
self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format) | ||
|
||
yield self | ||
|
||
async with self.use(_callbacks_allowed=False) as interface: | ||
if self._callback is not None: | ||
await self._callback(interface) | ||
self._final_side_effects = interface.side_effects | ||
|
||
@contextlib.asynccontextmanager | ||
async def use(self, _callbacks_allowed: bool = True) -> AsyncIterator[StateInterface[_T_SideEffects]]: | ||
async with self._resource_manager.use(): | ||
side_effects = await self._resource_manager.get_resource(self._side_effects_format) | ||
interface = StateInterface(side_effects, _callbacks_allowed) | ||
|
||
yield interface | ||
|
||
await self._resource_manager.save_resource(interface.side_effects) | ||
self._callback = interface.callback | ||
|
||
|
||
@dataclass | ||
class StateInterface(Generic[_T_SideEffects]): | ||
side_effects: _T_SideEffects | ||
_callbacks_allowed: bool | ||
_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = None | ||
|
||
@property | ||
def callback(self) -> Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]]: | ||
return self._callback | ||
|
||
def set_callback(self, new_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]]) -> None: | ||
if not self._callbacks_allowed: | ||
raise RuntimeError("Callback cannot be edited from inside itself") | ||
|
||
self._callback = new_callback |