diff --git a/README.md b/README.md index 00e5b4fc..8d70d7d8 100644 --- a/README.md +++ b/README.md @@ -499,7 +499,7 @@ class GreetingWorkflow: self._complete.set() @workflow.query - async def current_greeting(self) -> str: + def current_greeting(self) -> str: return self._current_greeting ``` @@ -566,7 +566,8 @@ Here are the decorators that can be applied: * Return value is ignored * `@workflow.query` - Defines a method as a query * All the same constraints as `@workflow.signal` but should return a value - * Temporal queries should never mutate anything in the workflow + * Should not be `async` + * Temporal queries should never mutate anything in the workflow or call any calls that would mutate the workflow #### Running diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 8d155581..8b13e389 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -12,6 +12,7 @@ import traceback import warnings from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta from typing import ( @@ -21,6 +22,7 @@ Deque, Dict, Generator, + Iterator, List, Mapping, MutableMapping, @@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: self._object: Any = None self._is_replaying: bool = False self._random = random.Random(det.randomness_seed) + self._read_only = False # Patches we have been notified of and memoized patch responses self._patches_notified: Set[str] = set() @@ -421,36 +424,39 @@ async def run_query() -> None: command = self._add_command() command.respond_to_query.query_id = job.query_id try: - # Named query or dynamic - defn = self._queries.get(job.query_type) or self._queries.get(None) - if not defn: - known_queries = sorted([k for k in self._queries.keys() if k]) - raise RuntimeError( - f"Query handler for '{job.query_type}' expected but not found, " - f"known queries: [{' '.join(known_queries)}]" + with self._as_read_only(): + # Named query or dynamic + defn = self._queries.get(job.query_type) or self._queries.get(None) + if not defn: + known_queries = sorted([k for k in self._queries.keys() if k]) + raise RuntimeError( + f"Query handler for '{job.query_type}' expected but not found, " + f"known queries: [{' '.join(known_queries)}]" + ) + + # Create input + args = self._process_handler_args( + job.query_type, + job.arguments, + defn.name, + defn.arg_types, + defn.dynamic_vararg, ) - - # Create input - args = self._process_handler_args( - job.query_type, - job.arguments, - defn.name, - defn.arg_types, - defn.dynamic_vararg, - ) - input = HandleQueryInput( - id=job.query_id, - query=job.query_type, - args=args, - headers=job.headers, - ) - success = await self._inbound.handle_query(input) - result_payloads = self._payload_converter.to_payloads([success]) - if len(result_payloads) != 1: - raise ValueError( - f"Expected 1 result payload, got {len(result_payloads)}" + input = HandleQueryInput( + id=job.query_id, + query=job.query_type, + args=args, + headers=job.headers, + ) + success = await self._inbound.handle_query(input) + result_payloads = self._payload_converter.to_payloads([success]) + if len(result_payloads) != 1: + raise ValueError( + f"Expected 1 result payload, got {len(result_payloads)}" + ) + command.respond_to_query.succeeded.response.CopyFrom( + result_payloads[0] ) - command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0]) except Exception as err: try: self._failure_converter.to_failure( @@ -695,6 +701,7 @@ def workflow_continue_as_new( search_attributes: Optional[temporalio.common.SearchAttributes], versioning_intent: Optional[temporalio.workflow.VersioningIntent], ) -> NoReturn: + self._assert_not_read_only("continue as new") # Use definition if callable name: Optional[str] = None arg_types: Optional[List[Type]] = None @@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter: return self._payload_converter def workflow_random(self) -> random.Random: + self._assert_not_read_only("random") return self._random def workflow_set_query_handler( self, name: Optional[str], handler: Optional[Callable] ) -> None: + self._assert_not_read_only("set query handler") if handler: + if inspect.iscoroutinefunction(handler): + warnings.warn( + "Queries as async def functions are deprecated", + DeprecationWarning, + stacklevel=3, + ) defn = temporalio.workflow._QueryDefinition( name=name, fn=handler, is_method=False ) @@ -817,6 +832,7 @@ def workflow_set_query_handler( def workflow_set_signal_handler( self, name: Optional[str], handler: Optional[Callable] ) -> None: + self._assert_not_read_only("set signal handler") if handler: defn = temporalio.workflow._SignalDefinition( name=name, fn=handler, is_method=False @@ -855,6 +871,7 @@ def workflow_start_activity( activity_id: Optional[str], versioning_intent: Optional[temporalio.workflow.VersioningIntent], ) -> temporalio.workflow.ActivityHandle[Any]: + self._assert_not_read_only("start activity") # Get activity definition if it's callable name: str arg_types: Optional[List[Type]] = None @@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes( async def workflow_wait_condition( self, fn: Callable[[], bool], *, timeout: Optional[float] = None ) -> None: + self._assert_not_read_only("wait condition") fut = self.create_future() self._conditions.append((fn, fut)) await asyncio.wait_for(fut, timeout) @@ -1153,8 +1171,24 @@ async def run_child() -> Any: # These are in alphabetical order. def _add_command(self) -> temporalio.bridge.proto.workflow_commands.WorkflowCommand: + self._assert_not_read_only("add command") return self._current_completion.successful.commands.add() + @contextmanager + def _as_read_only(self) -> Iterator[None]: + prev_val = self._read_only + self._read_only = True + try: + yield None + finally: + self._read_only = prev_val + + def _assert_not_read_only(self, action_attempted: str) -> None: + if self._read_only: + raise temporalio.workflow.ReadOnlyContextError( + f"While in read-only function, action attempted: {action_attempted}" + ) + async def _cancel_external_workflow( self, # Should not have seq set @@ -1258,6 +1292,7 @@ def _register_task( *, name: Optional[str], ) -> None: + self._assert_not_read_only("create task") # Name not supported on older Python versions if sys.version_info >= (3, 8): # Put the workflow info at the end of the task name @@ -1423,6 +1458,7 @@ def call_soon( *args: Any, context: Optional[contextvars.Context] = None, ) -> asyncio.Handle: + self._assert_not_read_only("schedule task") handle = asyncio.Handle(callback, args, self, context) self._ready.append(handle) return handle @@ -1434,6 +1470,7 @@ def call_later( *args: Any, context: Optional[contextvars.Context] = None, ) -> asyncio.TimerHandle: + self._assert_not_read_only("schedule timer") # Delay must be positive if delay < 0: raise RuntimeError("Attempting to schedule timer with negative delay") @@ -1675,6 +1712,7 @@ def __init__( instance._register_task(self, name=f"activity: {input.activity}") def cancel(self, msg: Optional[Any] = None) -> bool: + self._instance._assert_not_read_only("cancel activity handle") # We override this because if it's not yet started and not done, we need # to send a cancel command because the async function won't run to trap # the cancel (i.e. cancelled before started) @@ -1821,6 +1859,7 @@ async def signal( *, args: Sequence[Any] = [], ) -> None: + self._instance._assert_not_read_only("signal child handle") await self._instance._outbound.signal_child_workflow( SignalChildWorkflowInput( signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( @@ -1935,6 +1974,7 @@ async def signal( *, args: Sequence[Any] = [], ) -> None: + self._instance._assert_not_read_only("signal external handle") await self._instance._outbound.signal_external_workflow( SignalExternalWorkflowInput( signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str( @@ -1949,6 +1989,7 @@ async def signal( ) async def cancel(self) -> None: + self._instance._assert_not_read_only("cancel external handle") command = self._instance._add_command() v = command.request_cancel_external_workflow_execution v.workflow_execution.namespace = self._instance._info.namespace diff --git a/temporalio/workflow.py b/temporalio/workflow.py index a5d02444..7f4b0d07 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -242,9 +242,8 @@ def query( ): """Decorator for a workflow query method. - This is set on any async or non-async method that expects to handle a - query. If a function overrides one with this decorator, it too must be - decorated. + This is set on any non-async method that expects to handle a query. If a + function overrides one with this decorator, it too must be decorated. Query methods can only have positional parameters. Best practice for non-dynamic query methods is to only take a single object/dataclass @@ -262,7 +261,15 @@ def query( present. """ - def with_name(name: Optional[str], fn: CallableType) -> CallableType: + def with_name( + name: Optional[str], fn: CallableType, *, bypass_async_check: bool = False + ) -> CallableType: + if not bypass_async_check and inspect.iscoroutinefunction(fn): + warnings.warn( + "Queries as async def functions are deprecated", + DeprecationWarning, + stacklevel=2, + ) defn = _QueryDefinition(name=name, fn=fn, is_method=True) setattr(fn, "__temporal_query_definition", defn) if defn.dynamic_vararg: @@ -279,7 +286,13 @@ def with_name(name: Optional[str], fn: CallableType) -> CallableType: return partial(with_name, name) if fn is None: raise RuntimeError("Cannot create query without function or name or dynamic") - return with_name(fn.__name__, fn) + if inspect.iscoroutinefunction(fn): + warnings.warn( + "Queries as async def functions are deprecated", + DeprecationWarning, + stacklevel=2, + ) + return with_name(fn.__name__, fn, bypass_async_check=True) @dataclass(frozen=True) @@ -3919,6 +3932,17 @@ def __init__(self, message: str) -> None: self.message = message +class ReadOnlyContextError(temporalio.exceptions.TemporalError): + """Error thrown when trying to do mutable workflow calls in a read-only + context like a query or update validator. + """ + + def __init__(self, message: str) -> None: + """Initialize a read-only context error.""" + super().__init__(message) + self.message = message + + class _NotInWorkflowEventLoopError(temporalio.exceptions.TemporalError): def __init__(self, *args: object) -> None: super().__init__("Not in workflow event loop") diff --git a/tests/testing/test_workflow.py b/tests/testing/test_workflow.py index 2c6b8025..2dc0eb8c 100644 --- a/tests/testing/test_workflow.py +++ b/tests/testing/test_workflow.py @@ -28,7 +28,7 @@ async def run(self) -> str: return "all done" @workflow.query - async def current_time(self) -> float: + def current_time(self) -> float: return workflow.now().timestamp() @workflow.signal diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 6b6f0a27..1fdd5396 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -3014,7 +3014,7 @@ async def signal(self) -> None: self._signal_count += 1 @workflow.query - async def signal_count(self) -> int: + def signal_count(self) -> int: return self._signal_count @@ -3097,6 +3097,63 @@ async def test_workflow_dynamic(client: Client): assert result == DynamicWorkflowValue("some-workflow - val1 - val2") +@workflow.defn +class QueriesDoingBadThingsWorkflow: + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: False) + + @workflow.query + async def bad_query(self, bad_thing: str) -> str: + if bad_thing == "wait_condition": + await workflow.wait_condition(lambda: True) + elif bad_thing == "continue_as_new": + workflow.continue_as_new() + elif bad_thing == "upsert_search_attribute": + workflow.upsert_search_attributes({"foo": ["bar"]}) + elif bad_thing == "start_activity": + workflow.start_activity( + "some-activity", start_to_close_timeout=timedelta(minutes=10) + ) + elif bad_thing == "start_child_workflow": + await workflow.start_child_workflow("some-workflow") + elif bad_thing == "random": + workflow.random().random() + elif bad_thing == "set_query_handler": + workflow.set_query_handler("some-handler", lambda: "whatever") + elif bad_thing == "patch": + workflow.patched("some-patch") + elif bad_thing == "signal_external_handle": + await workflow.get_external_workflow_handle("some-id").signal("some-signal") + return "should never get here" + + +async def test_workflow_queries_doing_bad_things(client: Client): + async with new_worker(client, QueriesDoingBadThingsWorkflow) as worker: + handle = await client.start_workflow( + QueriesDoingBadThingsWorkflow.run, + id=f"wf-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + async def assert_bad_query(bad_thing: str) -> None: + with pytest.raises(WorkflowQueryFailedError) as err: + _ = await handle.query( + QueriesDoingBadThingsWorkflow.bad_query, bad_thing + ) + assert "While in read-only function, action attempted" in str(err) + + await assert_bad_query("wait_condition") + await assert_bad_query("continue_as_new") + await assert_bad_query("upsert_search_attribute") + await assert_bad_query("start_activity") + await assert_bad_query("start_child_workflow") + await assert_bad_query("random") + await assert_bad_query("set_query_handler") + await assert_bad_query("patch") + await assert_bad_query("signal_external_handle") + + # typing.Self only in 3.11+ if sys.version_info >= (3, 11):