Skip to content

Commit

Permalink
Fixes bug in which inputs were not working for asynchronous applications
Browse files Browse the repository at this point in the history
See #439. We were filing with required inputs twice -- this ensures it
only happens once. We were calling _step from _astep, both of which did
this.
  • Loading branch information
elijahbenizzy committed Nov 28, 2024
1 parent 07ae56b commit 5f20105
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
10 changes: 6 additions & 4 deletions burr/core/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def _process_inputs(self, inputs: Dict[str, Any], action: Action) -> Dict[str, A
raise ValueError(
BASE_ERROR_MESSAGE
+ f"Inputs starting with a double underscore ({starting_with_double_underscore}) "
f"are reserved for internal use/injected inputs."
f"are reserved for internal use/injected inputs. "
"Please do not directly pass keys starting with a double underscore."
)
inputs = inputs.copy()
Expand Down Expand Up @@ -945,13 +945,12 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
return None
if inputs is None:
inputs = {}
action_inputs = self._process_inputs(inputs, next_action)
if _run_hooks:
await self._adapter_set.call_all_lifecycle_hooks_sync_and_async(
"pre_run_step",
action=next_action,
state=self._state,
inputs=action_inputs,
inputs=inputs,
sequence_id=self.sequence_id,
app_id=self._uid,
partition_key=self._partition_key,
Expand All @@ -966,9 +965,12 @@ async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True
# TODO -- add an option/configuration to launch a thread (yikes, not super safe, but for a pure function
# which this is supposed to be its OK).
# this delegates hooks to the synchronous version, so we'll call all of them as well
# In this case we allow the self._step to do input processing
return self._step(
inputs=action_inputs, _run_hooks=False
inputs=inputs, _run_hooks=False
) # Skip hooks as we already ran all of them/will run all of them in this function's finally
# In this case we want to process inputs because we run the function directly
action_inputs = self._process_inputs(inputs, next_action)
if next_action.single_step:
result, new_state = await _arun_single_step_action(
next_action, self._state, inputs=action_inputs
Expand Down
52 changes: 52 additions & 0 deletions tests/core/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import logging
import typing
import uuid
from typing import Any, Awaitable, Callable, Dict, Generator, Literal, Optional, Tuple

import pytest
Expand Down Expand Up @@ -1341,6 +1342,57 @@ async def test_app_astep():
assert state[PRIOR_STEP] == "counter_async" # internal contract, not part of the public API


def test_app_step_context():
APP_ID = str(uuid.uuid4())
PARTITION_KEY = str(uuid.uuid4())

@action(reads=[], writes=[])
def test_action(state: State, __context: ApplicationContext) -> State:
assert __context.sequence_id == 0
assert __context.partition_key == PARTITION_KEY
assert __context.app_id == APP_ID
return state

app = (
ApplicationBuilder()
.with_actions(test_action)
.with_entrypoint("test_action")
.with_transitions()
.with_identifiers(
app_id=APP_ID,
partition_key=PARTITION_KEY,
)
.build()
)
app.step()


async def test_app_astep_context():
"""Tests that app.astep correctly passes context."""
APP_ID = str(uuid.uuid4())
PARTITION_KEY = str(uuid.uuid4())

@action(reads=[], writes=[])
def test_action(state: State, __context: ApplicationContext) -> State:
assert __context.sequence_id == 0
assert __context.partition_key == PARTITION_KEY
assert __context.app_id == APP_ID
return state

app = (
ApplicationBuilder()
.with_actions(test_action)
.with_entrypoint("test_action")
.with_transitions()
.with_identifiers(
app_id=APP_ID,
partition_key=PARTITION_KEY,
)
.build()
)
await app.astep()


async def test_app_astep_with_inputs():
"""Tests that we can run an async step in an app"""
counter_action = base_single_step_counter_with_inputs_async.with_name("counter_async")
Expand Down

0 comments on commit 5f20105

Please sign in to comment.