Skip to content

Commit

Permalink
fix(QueryTracker): publish query tracker results (#773)
Browse files Browse the repository at this point in the history
* fix(QueryTracker): publish query tracker results

* fix(QueryTracker): execute code is not being tracked

* chore(tests): change comment
  • Loading branch information
ArslanSaleem authored Nov 22, 2023
1 parent dcc660b commit 431c83c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 11 deletions.
8 changes: 5 additions & 3 deletions pandasai/pipelines/smart_datalake_chat/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ def execute(self, input: Any, **kwargs) -> Any:
pipeline_context.get_intermediate_value("last_prompt_id"),
pipeline_context.get_intermediate_value("skills"),
)
result = pipeline_context.get_intermediate_value(
"code_manager"
).execute_code(

result = pipeline_context.query_exec_tracker.execute_func(
pipeline_context.get_intermediate_value(
"code_manager"
).execute_code,
code=code_to_run,
context=code_context,
)
Expand Down
3 changes: 3 additions & 0 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ def chat(self, query: str, output_type: Optional[str] = None):

self.update_intermediate_value_post_pipeline_execution(pipeline_context)

# publish query tracker
self._query_exec_tracker.publish()

return result

def _validate_output(self, result: dict, output_type: Optional[str] = None):
Expand Down
25 changes: 17 additions & 8 deletions tests/pipelines/smart_datalake/test_code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def mock_intermediate_values(key: str):
return mock_code_manager

context.get_intermediate_value = Mock(side_effect=mock_intermediate_values)
context._query_exec_tracker = Mock()
context.query_exec_tracker.execute_func = Mock(return_value="Mocked Result")

result = code_execution.execute(
input="Test Code", context=context, logger=logger
Expand Down Expand Up @@ -157,17 +159,24 @@ def mock_execute_code(*args, **kwargs):
raise Exception("Unit test exception")
return "Mocked Result after retry"

# Conditional return of execute_func method based arguments it is called with
def mock_execute_func(*args, **kwargs):
if isinstance(args[0], Mock) and args[0].name == "execute_code":
return mock_execute_code(*args, **kwargs)
else:
return [
"Interuppted Code",
"Exception Testing",
"Successful after Retry",
]

mock_code_manager = Mock()
mock_code_manager.execute_code = Mock(side_effect=mock_execute_code)
mock_code_manager.execute_code = Mock()
mock_code_manager.execute_code.name = "execute_code"

context._query_exec_tracker = Mock()
context.query_exec_tracker.execute_func = Mock(
return_value=[
"Interuppted Code",
"Exception Testing",
"Successful after Retry",
]
)

context.query_exec_tracker.execute_func = Mock(side_effect=mock_execute_func)

def mock_intermediate_values(key: str):
if key == "last_prompt_id":
Expand Down
19 changes: 19 additions & 0 deletions tests/test_smartdatalake.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,25 @@ def test_last_result_is_saved(self, _mocked_method, smart_datalake: SmartDatalak
"value": "There are 10 countries in the dataframe.",
}

@patch.object(
CodeManager,
"execute_code",
return_value={
"type": "string",
"value": "There are 10 countries in the dataframe.",
},
)
@patch("pandasai.helpers.query_exec_tracker.QueryExecTracker.publish")
def test_query_tracker_publish_called_in_chat_method(
self, mock_query_tracker_publish, _mocked_method, smart_datalake: SmartDatalake
):
assert smart_datalake.last_result is None

_mocked_method.__name__ = "execute_code"

smart_datalake.chat("How many countries are in the dataframe?")
mock_query_tracker_publish.assert_called()

def test_retry_on_error_with_single_df(
self, smart_datalake: SmartDatalake, smart_dataframe: SmartDataframe
):
Expand Down

0 comments on commit 431c83c

Please sign in to comment.