fix(output_type): handle errors for wrong output type #866

Merged
merged 4 commits on Jan 11, 2024
Changes from 1 commit

@@ -0,0 +1,9 @@
{dataframes}

The user asked the following question:
{conversation}

You generated this python code:
{code}

Fix the python code above and return the new python code but the result type should be: {output_type_hint}
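
As a side note, here is a minimal, standalone sketch of how the placeholders in the template above (which appears to be the new assets/prompt_templates/correct_output_type_error_prompt.tmpl referenced later in this diff) might be rendered before being sent to the LLM; all example values below are made up:

```python
# Standalone sketch (not the library's own rendering code) of filling in the
# template placeholders; every value here is hypothetical.
template = """{dataframes}

The user asked the following question:
{conversation}

You generated this python code:
{code}

Fix the python code above and return the new python code but the result type should be: {output_type_hint}"""

rendered = template.format(
    dataframes="dfs[0]: 3x2 dataframe with columns country, gdp",
    conversation="How many countries are in the dataframe?",
    code="result = {'type': 'text', 'value': 3}",
    output_type_hint='{ "type": "number", "value": 3 }',
)
print(rendered)
```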
8 changes: 8 additions & 0 deletions pandasai/exceptions.py
@@ -178,3 +178,11 @@ class MaliciousQueryError(Exception):
Args:
Exception (Exception): MaliciousQueryError
"""


class InvalidLLMOutputType(Exception):
"""
Raise error if malicious query is generated
Args:
Exception (Exception): MaliciousQueryError
"""
Contributor
@coderabbitai coderabbitai bot Jan 10, 2024

The docstring for the InvalidLLMOutputType exception class incorrectly mentions malicious query generation, which seems unrelated to output type validation. This should be corrected to accurately reflect the purpose of the exception.

- Raise error if malicious query is generated
+ Raise error if the output type is invalid

Committable suggestion

IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
class InvalidLLMOutputType(Exception):
"""
Raise error if malicious query is generated
Args:
Exception (Exception): MaliciousQueryError
"""
class InvalidLLMOutputType(Exception):
"""
Raise error if the output type is invalid
Args:
Exception (Exception): MaliciousQueryError
"""

Collaborator

@ArslanSaleem leftover here!

Contributor

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

4 changes: 3 additions & 1 deletion pandasai/helpers/output_types/_output_types.py
@@ -118,6 +118,8 @@ def _validate_value(self, actual_value: Any) -> bool:


class DefaultOutputType(BaseOutputType):
default_types = ["string", "number", "dataframe", "plot"]

@property
def template_hint(self):
return """type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }""" # noqa E501
@@ -140,4 +142,4 @@ def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable]:
(bool): True since the `DefaultOutputType`
is supposed to have no validation
"""
return True, ()
return result["type"] in self.default_types, ()
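
A small self-contained illustration (not importing pandasai) of the behavioural change here: validate() used to return True unconditionally, and now only accepts the four default result types, which is presumably also why the FakeLLM default output in the next file changes from 'text' to 'string':

```python
from typing import Any, Dict, Iterable, Tuple

DEFAULT_TYPES = ["string", "number", "dataframe", "plot"]


def validate(result: Dict[str, Any]) -> Tuple[bool, Iterable]:
    # Mirrors the new DefaultOutputType.validate: only known result types
    # pass, and no per-field errors are reported.
    return result["type"] in DEFAULT_TYPES, ()


print(validate({"type": "string", "value": "Hello World"}))  # (True, ())
print(validate({"type": "text", "value": "Hello World"}))    # (False, ())
```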
2 changes: 1 addition & 1 deletion pandasai/llm/fake.py
@@ -9,7 +9,7 @@
class FakeLLM(LLM):
"""Fake LLM"""

_output: str = """result = { 'type': 'text', 'value': "Hello World" }"""
_output: str = """result = { 'type': 'string', 'value': "Hello World" }"""

def __init__(self, output: Optional[str] = None):
if output is not None:
37 changes: 34 additions & 3 deletions pandasai/pipelines/smart_datalake_chat/code_execution.py
@@ -1,6 +1,12 @@
import logging
import traceback
from typing import Any, List

from pandasai.exceptions import InvalidLLMOutputType
from pandasai.prompts.base import AbstractPrompt
from pandasai.prompts.correct_output_type_error_prompt import (
CorrectOutputTypeErrorPrompt,
)
from ...helpers.code_manager import CodeExecutionContext
from ...helpers.logger import Logger
from ..base_logic_unit import BaseLogicUnit
@@ -51,6 +57,17 @@ def execute(self, input: Any, **kwargs) -> Any:
context=code_context,
)

output_helper = pipeline_context.get_intermediate_value(
"output_type_helper"
)
if output_helper := pipeline_context.get_intermediate_value(
"output_type_helper"
):
(validation_ok, validation_errors) = output_helper.validate(result)

if not validation_ok:
raise InvalidLLMOutputType(validation_errors)

break

except Exception as e:
@@ -69,18 +86,33 @@ def execute(self, input: Any, **kwargs) -> Any:
)

traceback_error = traceback.format_exc()

# Get Error Prompt for retry
error_prompt = self._get_error_prompt(e)
code_to_run = pipeline_context.query_exec_tracker.execute_func(
self._retry_run_code,
code,
pipeline_context,
logger,
traceback_error,
error_prompt,
)

return result

def _get_error_prompt(self, e: Exception) -> AbstractPrompt:
if isinstance(e, InvalidLLMOutputType):
return CorrectOutputTypeErrorPrompt()
else:
return CorrectErrorPrompt()

def _retry_run_code(
self, code: str, context: PipelineContext, logger: Logger, e: Exception
self,
code: str,
context: PipelineContext,
logger: Logger,
e: Exception,
error_prompt=CorrectErrorPrompt(),
) -> List:
"""
A method to retry the code execution with error correction framework.
@@ -94,7 +126,6 @@ def _retry_run_code(

Returns (str): A python code
"""

logger.log(f"Failed with error: {e}. Retrying", logging.ERROR)

default_values = {
@@ -107,7 +138,7 @@
}
error_correcting_instruction = context.get_intermediate_value("get_prompt")(
"correct_error",
default_prompt=CorrectErrorPrompt(),
default_prompt=error_prompt,
default_values=default_values,
)
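
Taken together, a hedged, standalone sketch of the retry flow this file now implements (the real pipeline uses prompt classes, the query execution tracker and intermediate values; plain strings and injected callables stand in for them here):

```python
class InvalidLLMOutputType(Exception):
    """Raised when the executed code returns a result of the wrong type."""


def pick_error_prompt(exc: Exception) -> str:
    # Stand-in for CodeExecution._get_error_prompt: a wrong output type
    # selects the output-type correction prompt, anything else keeps the
    # generic error-correction prompt.
    if isinstance(exc, InvalidLLMOutputType):
        return "CorrectOutputTypeErrorPrompt"
    return "CorrectErrorPrompt"


def run_with_retry(code, execute, validate, regenerate, max_retries=3):
    result = None
    for _ in range(max_retries):
        try:
            result = execute(code)
            ok, errors = validate(result)
            if not ok:
                raise InvalidLLMOutputType(errors)
            break
        except Exception as exc:
            # On failure, ask the LLM for corrected code using whichever
            # prompt matches the error, then try again.
            code = regenerate(code, pick_error_prompt(exc), exc)
    return result
```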

22 changes: 22 additions & 0 deletions pandasai/prompts/correct_output_type_error_prompt.py
@@ -0,0 +1,22 @@
""" Prompt to correct Output Type Python Code on Error
```
{dataframes}

{conversation}

You generated this python code:
{code}

It fails with the following error:
{error_returned}

Fix the python code above and return the new python code but the result type should be:
""" # noqa: E501

from .file_based_prompt import FileBasedPrompt


class CorrectOutputTypeErrorPrompt(FileBasedPrompt):
"""Prompt to Correct Python code on Error"""

_path_to_template = "assets/prompt_templates/correct_output_type_error_prompt.tmpl"
31 changes: 30 additions & 1 deletion tests/pipelines/smart_datalake/test_code_execution.py
@@ -1,12 +1,17 @@
from typing import Optional
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock
import pandas as pd
import pytest
from pandasai.exceptions import InvalidLLMOutputType
from pandasai.helpers.logger import Logger
from pandasai.helpers.skills_manager import SkillsManager

from pandasai.llm.fake import FakeLLM
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.prompts.correct_error_prompt import CorrectErrorPrompt
from pandasai.prompts.correct_output_type_error_prompt import (
CorrectOutputTypeErrorPrompt,
)
from pandasai.smart_dataframe import SmartDataframe
from pandasai.pipelines.smart_datalake_chat.code_execution import CodeExecution

@@ -194,3 +199,27 @@ def mock_intermediate_values(key: str):

assert isinstance(code_execution, CodeExecution)
assert result == "Mocked Result after retry"

def test_get_error_prompt_invalid_llm_output_type(self):
code_execution = CodeExecution()

# Mock the InvalidLLMOutputType exception
mock_exception = MagicMock(spec=InvalidLLMOutputType)

# Call the method with the mock exception
result = code_execution._get_error_prompt(mock_exception)

# Assert that the CorrectOutputTypeErrorPrompt is returned
assert isinstance(result, CorrectOutputTypeErrorPrompt)

def test_get_error_prompt_other_exception(self):
code_execution = CodeExecution()

# Mock a generic exception
mock_exception = MagicMock(spec=Exception)

# Call the method with the mock exception
result = code_execution._get_error_prompt(mock_exception)

# Assert that the CorrectErrorPrompt is returned
assert isinstance(result, CorrectErrorPrompt)
20 changes: 18 additions & 2 deletions tests/test_smartdataframe.py
@@ -133,6 +133,17 @@ def smart_dataframe(self, llm, sample_df, custom_head):
custom_head=custom_head,
)

@pytest.fixture
def llm_result_mocks(self, custom_head):
result_template = "result = {{ 'type': '{type}', 'value': {value} }}"

return {
"number": result_template.format(type="number", value=1),
"string": result_template.format(type="string", value="'Test'"),
"plot": result_template.format(type="plot", value="'temp_plot.png'"),
"dataframe": result_template.format(type="dataframe", value=custom_head),
}

@pytest.fixture
def smart_dataframe_mocked_df(self, llm, sample_df, custom_head):
smart_df = SmartDataframe(
@@ -225,7 +236,10 @@ def test_run_with_privacy_enforcement(self, llm):
],
],
)
def test_run_passing_output_type(self, llm, output_type, output_type_hint):
@patch("pandasai.responses.response_parser.ResponseParser.parse", autospec=True)
def test_run_passing_output_type(
self, parser_mock, llm, llm_result_mocks, output_type, output_type_hint
):
df = pd.DataFrame({"country": []})
df = SmartDataframe(df, config={"llm": llm, "enable_cache": False})

@@ -255,12 +269,14 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint):


Generate python code and return full updated code:"""
parser_mock.return_value = Mock()
type_ = output_type if output_type is not None else "string"
llm._output = llm_result_mocks[type_]

df.chat("How many countries are in the dataframe?", output_type=output_type)
last_prompt = df.last_prompt
if sys.platform.startswith("win"):
last_prompt = df.last_prompt.replace("\r\n", "\n")

assert last_prompt == expected_prompt

@pytest.mark.parametrize(
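
Finally, a quick illustration of what the llm_result_mocks fixture above produces; the patched test feeds one of these strings to the FakeLLM so the stricter output-type validation passes for each parametrized case (output shown in comments):

```python
result_template = "result = {{ 'type': '{type}', 'value': {value} }}"

print(result_template.format(type="number", value=1))
# result = { 'type': 'number', 'value': 1 }
print(result_template.format(type="string", value="'Test'"))
# result = { 'type': 'string', 'value': 'Test' }
print(result_template.format(type="plot", value="'temp_plot.png'"))
# result = { 'type': 'plot', 'value': 'temp_plot.png' }
```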