Skip to content

Commit

Permalink
Allow client to be set in test env
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Jan 23, 2025
1 parent f7f3978 commit 3164f69
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
9 changes: 7 additions & 2 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class _Context:
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
client: Client
client: Optional[Client]
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
_metric_meter: Optional[temporalio.common.MetricMeter] = None
Expand Down Expand Up @@ -249,7 +249,12 @@ def client() -> Client:
Raises:
RuntimeError: When not in an activity.
"""
return _Context.current().client
client = _Context.current().client
if not client:
raise RuntimeError(
"No client available. In tests you can pass a client when creating ActivityEnvironment."
)
return client


def in_activity() -> bool:
Expand Down
16 changes: 11 additions & 5 deletions temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import temporalio.converter
import temporalio.exceptions
import temporalio.worker._activity
from temporalio.client import Client

_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")
Expand Down Expand Up @@ -62,7 +63,7 @@ class ActivityEnvironment:
take effect. Default is noop.
"""

def __init__(self) -> None:
def __init__(self, client: Optional[Client] = None) -> None:
"""Create an ActivityEnvironment for running activity code."""
self.info = _default_info
self.on_heartbeat: Callable[..., None] = lambda *args: None
Expand All @@ -73,6 +74,7 @@ def __init__(self) -> None:
self._cancelled = False
self._worker_shutdown = False
self._activities: Set[_Activity] = set()
self._client = client

def cancel(self) -> None:
"""Cancel the activity.
Expand Down Expand Up @@ -113,14 +115,15 @@ def run(
The callable's result.
"""
# Create an activity and run it
return _Activity(self, fn).run(*args, **kwargs)
return _Activity(self, fn, self._client).run(*args, **kwargs)


class _Activity:
def __init__(
self,
env: ActivityEnvironment,
fn: Callable,
client: Optional[Client],
) -> None:
self.env = env
self.fn = fn
Expand Down Expand Up @@ -148,11 +151,14 @@ def __init__(
thread_event=threading.Event(),
async_event=asyncio.Event() if self.is_async else None,
),
shield_thread_cancel_exception=None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded,
shield_thread_cancel_exception=(
None
if not self.cancel_thread_raiser
else self.cancel_thread_raiser.shielded
),
payload_converter_class_or_instance=env.payload_converter,
runtime_metric_meter=env.metric_meter,
client=client,
)
self.task: Optional[asyncio.Task] = None

Expand Down
31 changes: 31 additions & 0 deletions tests/testing/test_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
import threading
import time
from contextvars import copy_context
from unittest.mock import Mock

import pytest

from temporalio import activity
from temporalio.client import Client
from temporalio.exceptions import CancelledError
from temporalio.testing import ActivityEnvironment

Expand Down Expand Up @@ -110,3 +114,30 @@ async def assert_equals(a: str, b: str) -> None:

assert type(expected_err) == type(actual_err)
assert str(expected_err) == str(actual_err)


async def test_activity_env_without_client():
saw_error: bool = False

def my_activity() -> None:
with pytest.raises(RuntimeError):
activity.client()
nonlocal saw_error
saw_error = True

env = ActivityEnvironment()
env.run(my_activity)
assert saw_error


async def test_activity_env_with_client():
got_client: bool = False

def my_activity() -> None:
nonlocal got_client
if activity.client():
got_client = True

env = ActivityEnvironment(client=Mock(spec=Client))
env.run(my_activity)
assert got_client

0 comments on commit 3164f69

Please sign in to comment.