Skip to content

Commit

Permalink
Disallow most workflow operations in read-only context (#351)
Browse files Browse the repository at this point in the history
Fixes #250
  • Loading branch information
cretz authored Jul 18, 2023
1 parent b9df212 commit a17c0ef
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 37 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ class GreetingWorkflow:
self._complete.set()

@workflow.query
async def current_greeting(self) -> str:
def current_greeting(self) -> str:
return self._current_greeting

```
Expand Down Expand Up @@ -566,7 +566,8 @@ Here are the decorators that can be applied:
* Return value is ignored
* `@workflow.query` - Defines a method as a query
* All the same constraints as `@workflow.signal` but should return a value
* Temporal queries should never mutate anything in the workflow
* Should not be `async`
* Temporal queries should never mutate anything in the workflow or call any calls that would mutate the workflow

#### Running

Expand Down
97 changes: 69 additions & 28 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import traceback
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import (
Expand All @@ -21,6 +22,7 @@
Deque,
Dict,
Generator,
Iterator,
List,
Mapping,
MutableMapping,
Expand Down Expand Up @@ -193,6 +195,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
self._object: Any = None
self._is_replaying: bool = False
self._random = random.Random(det.randomness_seed)
self._read_only = False

# Patches we have been notified of and memoized patch responses
self._patches_notified: Set[str] = set()
Expand Down Expand Up @@ -421,36 +424,39 @@ async def run_query() -> None:
command = self._add_command()
command.respond_to_query.query_id = job.query_id
try:
# Named query or dynamic
defn = self._queries.get(job.query_type) or self._queries.get(None)
if not defn:
known_queries = sorted([k for k in self._queries.keys() if k])
raise RuntimeError(
f"Query handler for '{job.query_type}' expected but not found, "
f"known queries: [{' '.join(known_queries)}]"
with self._as_read_only():
# Named query or dynamic
defn = self._queries.get(job.query_type) or self._queries.get(None)
if not defn:
known_queries = sorted([k for k in self._queries.keys() if k])
raise RuntimeError(
f"Query handler for '{job.query_type}' expected but not found, "
f"known queries: [{' '.join(known_queries)}]"
)

# Create input
args = self._process_handler_args(
job.query_type,
job.arguments,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)

# Create input
args = self._process_handler_args(
job.query_type,
job.arguments,
defn.name,
defn.arg_types,
defn.dynamic_vararg,
)
input = HandleQueryInput(
id=job.query_id,
query=job.query_type,
args=args,
headers=job.headers,
)
success = await self._inbound.handle_query(input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
input = HandleQueryInput(
id=job.query_id,
query=job.query_type,
args=args,
headers=job.headers,
)
success = await self._inbound.handle_query(input)
result_payloads = self._payload_converter.to_payloads([success])
if len(result_payloads) != 1:
raise ValueError(
f"Expected 1 result payload, got {len(result_payloads)}"
)
command.respond_to_query.succeeded.response.CopyFrom(
result_payloads[0]
)
command.respond_to_query.succeeded.response.CopyFrom(result_payloads[0])
except Exception as err:
try:
self._failure_converter.to_failure(
Expand Down Expand Up @@ -695,6 +701,7 @@ def workflow_continue_as_new(
search_attributes: Optional[temporalio.common.SearchAttributes],
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
) -> NoReturn:
self._assert_not_read_only("continue as new")
# Use definition if callable
name: Optional[str] = None
arg_types: Optional[List[Type]] = None
Expand Down Expand Up @@ -795,12 +802,20 @@ def workflow_payload_converter(self) -> temporalio.converter.PayloadConverter:
return self._payload_converter

def workflow_random(self) -> random.Random:
self._assert_not_read_only("random")
return self._random

def workflow_set_query_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
self._assert_not_read_only("set query handler")
if handler:
if inspect.iscoroutinefunction(handler):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=3,
)
defn = temporalio.workflow._QueryDefinition(
name=name, fn=handler, is_method=False
)
Expand All @@ -817,6 +832,7 @@ def workflow_set_query_handler(
def workflow_set_signal_handler(
self, name: Optional[str], handler: Optional[Callable]
) -> None:
self._assert_not_read_only("set signal handler")
if handler:
defn = temporalio.workflow._SignalDefinition(
name=name, fn=handler, is_method=False
Expand Down Expand Up @@ -855,6 +871,7 @@ def workflow_start_activity(
activity_id: Optional[str],
versioning_intent: Optional[temporalio.workflow.VersioningIntent],
) -> temporalio.workflow.ActivityHandle[Any]:
self._assert_not_read_only("start activity")
# Get activity definition if it's callable
name: str
arg_types: Optional[List[Type]] = None
Expand Down Expand Up @@ -1012,6 +1029,7 @@ def workflow_upsert_search_attributes(
async def workflow_wait_condition(
self, fn: Callable[[], bool], *, timeout: Optional[float] = None
) -> None:
self._assert_not_read_only("wait condition")
fut = self.create_future()
self._conditions.append((fn, fut))
await asyncio.wait_for(fut, timeout)
Expand Down Expand Up @@ -1153,8 +1171,24 @@ async def run_child() -> Any:
# These are in alphabetical order.

def _add_command(self) -> temporalio.bridge.proto.workflow_commands.WorkflowCommand:
self._assert_not_read_only("add command")
return self._current_completion.successful.commands.add()

@contextmanager
def _as_read_only(self) -> Iterator[None]:
prev_val = self._read_only
self._read_only = True
try:
yield None
finally:
self._read_only = prev_val

def _assert_not_read_only(self, action_attempted: str) -> None:
if self._read_only:
raise temporalio.workflow.ReadOnlyContextError(
f"While in read-only function, action attempted: {action_attempted}"
)

async def _cancel_external_workflow(
self,
# Should not have seq set
Expand Down Expand Up @@ -1258,6 +1292,7 @@ def _register_task(
*,
name: Optional[str],
) -> None:
self._assert_not_read_only("create task")
# Name not supported on older Python versions
if sys.version_info >= (3, 8):
# Put the workflow info at the end of the task name
Expand Down Expand Up @@ -1423,6 +1458,7 @@ def call_soon(
*args: Any,
context: Optional[contextvars.Context] = None,
) -> asyncio.Handle:
self._assert_not_read_only("schedule task")
handle = asyncio.Handle(callback, args, self, context)
self._ready.append(handle)
return handle
Expand All @@ -1434,6 +1470,7 @@ def call_later(
*args: Any,
context: Optional[contextvars.Context] = None,
) -> asyncio.TimerHandle:
self._assert_not_read_only("schedule timer")
# Delay must be positive
if delay < 0:
raise RuntimeError("Attempting to schedule timer with negative delay")
Expand Down Expand Up @@ -1675,6 +1712,7 @@ def __init__(
instance._register_task(self, name=f"activity: {input.activity}")

def cancel(self, msg: Optional[Any] = None) -> bool:
self._instance._assert_not_read_only("cancel activity handle")
# We override this because if it's not yet started and not done, we need
# to send a cancel command because the async function won't run to trap
# the cancel (i.e. cancelled before started)
Expand Down Expand Up @@ -1821,6 +1859,7 @@ async def signal(
*,
args: Sequence[Any] = [],
) -> None:
self._instance._assert_not_read_only("signal child handle")
await self._instance._outbound.signal_child_workflow(
SignalChildWorkflowInput(
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
Expand Down Expand Up @@ -1935,6 +1974,7 @@ async def signal(
*,
args: Sequence[Any] = [],
) -> None:
self._instance._assert_not_read_only("signal external handle")
await self._instance._outbound.signal_external_workflow(
SignalExternalWorkflowInput(
signal=temporalio.workflow._SignalDefinition.must_name_from_fn_or_str(
Expand All @@ -1949,6 +1989,7 @@ async def signal(
)

async def cancel(self) -> None:
self._instance._assert_not_read_only("cancel external handle")
command = self._instance._add_command()
v = command.request_cancel_external_workflow_execution
v.workflow_execution.namespace = self._instance._info.namespace
Expand Down
34 changes: 29 additions & 5 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ def query(
):
"""Decorator for a workflow query method.
This is set on any async or non-async method that expects to handle a
query. If a function overrides one with this decorator, it too must be
decorated.
This is set on any non-async method that expects to handle a query. If a
function overrides one with this decorator, it too must be decorated.
Query methods can only have positional parameters. Best practice for
non-dynamic query methods is to only take a single object/dataclass
Expand All @@ -262,7 +261,15 @@ def query(
present.
"""

def with_name(name: Optional[str], fn: CallableType) -> CallableType:
def with_name(
name: Optional[str], fn: CallableType, *, bypass_async_check: bool = False
) -> CallableType:
if not bypass_async_check and inspect.iscoroutinefunction(fn):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=2,
)
defn = _QueryDefinition(name=name, fn=fn, is_method=True)
setattr(fn, "__temporal_query_definition", defn)
if defn.dynamic_vararg:
Expand All @@ -279,7 +286,13 @@ def with_name(name: Optional[str], fn: CallableType) -> CallableType:
return partial(with_name, name)
if fn is None:
raise RuntimeError("Cannot create query without function or name or dynamic")
return with_name(fn.__name__, fn)
if inspect.iscoroutinefunction(fn):
warnings.warn(
"Queries as async def functions are deprecated",
DeprecationWarning,
stacklevel=2,
)
return with_name(fn.__name__, fn, bypass_async_check=True)


@dataclass(frozen=True)
Expand Down Expand Up @@ -3919,6 +3932,17 @@ def __init__(self, message: str) -> None:
self.message = message


class ReadOnlyContextError(temporalio.exceptions.TemporalError):
"""Error thrown when trying to do mutable workflow calls in a read-only
context like a query or update validator.
"""

def __init__(self, message: str) -> None:
"""Initialize a read-only context error."""
super().__init__(message)
self.message = message


class _NotInWorkflowEventLoopError(temporalio.exceptions.TemporalError):
def __init__(self, *args: object) -> None:
super().__init__("Not in workflow event loop")
Expand Down
2 changes: 1 addition & 1 deletion tests/testing/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def run(self) -> str:
return "all done"

@workflow.query
async def current_time(self) -> float:
def current_time(self) -> float:
return workflow.now().timestamp()

@workflow.signal
Expand Down
59 changes: 58 additions & 1 deletion tests/worker/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3014,7 +3014,7 @@ async def signal(self) -> None:
self._signal_count += 1

@workflow.query
async def signal_count(self) -> int:
def signal_count(self) -> int:
return self._signal_count


Expand Down Expand Up @@ -3097,6 +3097,63 @@ async def test_workflow_dynamic(client: Client):
assert result == DynamicWorkflowValue("some-workflow - val1 - val2")


@workflow.defn
class QueriesDoingBadThingsWorkflow:
@workflow.run
async def run(self) -> None:
await workflow.wait_condition(lambda: False)

@workflow.query
async def bad_query(self, bad_thing: str) -> str:
if bad_thing == "wait_condition":
await workflow.wait_condition(lambda: True)
elif bad_thing == "continue_as_new":
workflow.continue_as_new()
elif bad_thing == "upsert_search_attribute":
workflow.upsert_search_attributes({"foo": ["bar"]})
elif bad_thing == "start_activity":
workflow.start_activity(
"some-activity", start_to_close_timeout=timedelta(minutes=10)
)
elif bad_thing == "start_child_workflow":
await workflow.start_child_workflow("some-workflow")
elif bad_thing == "random":
workflow.random().random()
elif bad_thing == "set_query_handler":
workflow.set_query_handler("some-handler", lambda: "whatever")
elif bad_thing == "patch":
workflow.patched("some-patch")
elif bad_thing == "signal_external_handle":
await workflow.get_external_workflow_handle("some-id").signal("some-signal")
return "should never get here"


async def test_workflow_queries_doing_bad_things(client: Client):
async with new_worker(client, QueriesDoingBadThingsWorkflow) as worker:
handle = await client.start_workflow(
QueriesDoingBadThingsWorkflow.run,
id=f"wf-{uuid.uuid4()}",
task_queue=worker.task_queue,
)

async def assert_bad_query(bad_thing: str) -> None:
with pytest.raises(WorkflowQueryFailedError) as err:
_ = await handle.query(
QueriesDoingBadThingsWorkflow.bad_query, bad_thing
)
assert "While in read-only function, action attempted" in str(err)

await assert_bad_query("wait_condition")
await assert_bad_query("continue_as_new")
await assert_bad_query("upsert_search_attribute")
await assert_bad_query("start_activity")
await assert_bad_query("start_child_workflow")
await assert_bad_query("random")
await assert_bad_query("set_query_handler")
await assert_bad_query("patch")
await assert_bad_query("signal_external_handle")


# typing.Self only in 3.11+
if sys.version_info >= (3, 11):

Expand Down

0 comments on commit a17c0ef

Please sign in to comment.