Skip to content

Commit

Permalink
refactor: simplify prompt for VertexAI and Claude
Browse files (browse the repository at this point in the history)
  • Loading branch information
gventuri committed Nov 24, 2023
1 parent 90f193e commit ca85093
Show file tree
Hide file tree
Showing 17 changed files with 81 additions and 112 deletions.
8 changes: 2 additions & 6 deletions pandasai/assets/prompt_templates/current_code.tmpl
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# TODO: import the required dependencies
{default_import}

"""
{dfs_declared_message}
{instructions}
# Write code here

Return a "result" variable dict:
{output_type_hint}
"""
# Declare result var: {output_type_hint}
4 changes: 0 additions & 4 deletions pandasai/assets/prompt_templates/default_instructions.tmpl

This file was deleted.

1 change: 1 addition & 0 deletions pandasai/assets/prompt_templates/generate_python_code.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
```

{last_message}
Variable `dfs: list[pd.DataFrame]` is already declared.
{reasoning}
7 changes: 6 additions & 1 deletion pandasai/assets/prompt_templates/simple_reasoning.tmpl
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
Return the full updated code:

At the end, declare "result" var dict: {output_type_hint}
{viz_library_type}
{instructions}

Generate python code and return full updated code:
2 changes: 1 addition & 1 deletion pandasai/assets/prompt_templates/viz_library.tmpl
Original file line number Diff line number Diff line change
@@ -1 +1 @@
If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file.
If you are asked to plot a chart, use "{library}" for charts, save as png.
27 changes: 5 additions & 22 deletions pandasai/helpers/output_types/_output_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
class NumberOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "number")
- value (must be a number)
Example output: { "type": "number", "value": 125 }"""
return """type (must be "number"), value must int. Example: { "type": "number", "value": 125 }""" # noqa E501

@property
def name(self):
Expand All @@ -79,9 +77,7 @@ def _validate_value(self, actual_value: Any) -> bool:
class DataFrameOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "dataframe")
- value (must be a pandas dataframe)
Example output: { "type": "dataframe", "value": pd.DataFrame({...}) }"""
return """type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) }""" # noqa E501

@property
def name(self):
Expand All @@ -94,9 +90,7 @@ def _validate_value(self, actual_value: Any) -> bool:
class PlotOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "plot")
- value (must be a string containing the path of the plot image)
Example output: { "type": "plot", "value": "export/charts/temp_chart.png" }"""
return """type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }""" # noqa E501

@property
def name(self):
Expand All @@ -113,9 +107,7 @@ def _validate_value(self, actual_value: Any) -> bool:
class StringOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (must be "string")
- value (must be a conversational answer, as a string)
Example output: { "type": "string", "value": f"The highest salary is {highest_salary}." }""" # noqa E501
return """type (must be "string"), value must be string. Example: { "type": "string", "value": f"The highest salary is {highest_salary}." }""" # noqa E501

@property
def name(self):
Expand All @@ -128,16 +120,7 @@ def _validate_value(self, actual_value: Any) -> bool:
class DefaultOutputType(BaseOutputType):
@property
def template_hint(self):
return """- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Examples:
{ "type": "string", "value": f"The highest salary is {highest_salary}." }
or
{ "type": "number", "value": 125 }
or
{ "type": "dataframe", "value": pd.DataFrame({...}) }
or
{ "type": "plot", "value": "temp_chart.png" }""" # noqa E501
return """type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }""" # noqa E501

@property
def name(self):
Expand Down
3 changes: 2 additions & 1 deletion pandasai/helpers/viz_library_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .base import VisualizationLibrary

from ._viz_library_types import (
NoVizLibraryType,
MatplotlibVizLibraryType,
PlotlyVizLibraryType,
SeabornVizLibraryType,
Expand Down Expand Up @@ -55,7 +56,7 @@ def viz_lib_type_factory(
level=logging.WARNING,
)

viz_lib_default = MatplotlibVizLibraryType
viz_lib_default = NoVizLibraryType
viz_lib_type_helper = viz_lib_map.get(viz_lib_type, viz_lib_default)()

if logger:
Expand Down
10 changes: 10 additions & 0 deletions pandasai/helpers/viz_library_types/_viz_library_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]:
return type_ok, validation_logs


class NoVizLibraryType(BaseVizLibraryType):
    """Visualization-library type that carries no library-specific hint.

    Both properties return fixed values: the identifier
    ``"no_viz_library"`` and an empty template hint.
    """

    @property
    def name(self):
        """Return the identifier of this type: ``"no_viz_library"``."""
        return "no_viz_library"

    @property
    def template_hint(self) -> str:
        """Return the hint text for this type, which is always empty."""
        return ""


class MatplotlibVizLibraryType(BaseVizLibraryType):
@property
def name(self):
Expand Down
2 changes: 1 addition & 1 deletion pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _generate_text(self, prompt: str) -> str:
else:
raise UnsupportedModelError(self.model)

return str(completion)
return completion.text

@property
def type(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions pandasai/pipelines/smart_datalake_chat/code_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def _retry_run_code(
"engine": context.dfs[0].engine,
"code": code,
"error_returned": e,
"output_type_hint": context.get_intermediate_value(
"output_type_helper"
).template_hint,
}
error_correcting_instruction = context.get_intermediate_value("get_prompt")(
"correct_error",
Expand Down
4 changes: 1 addition & 3 deletions pandasai/pipelines/smart_datalake_chat/prompt_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def execute(self, input: Any, **kwargs) -> Any:
default_values["current_code"] = pipeline_context.get_intermediate_value(
"last_code_generated"
)
default_values[
"code_description"
] = "This is the code generated to answer the previous question:" # noqa: E501
default_values["code_description"] = ""

[key, default_prompt] = self._get_chat_prompt(pipeline_context)

Expand Down
2 changes: 1 addition & 1 deletion pandasai/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _generate_dataframes(self, dfs):

dataframes.append(dataframe_info)

return "\n\n".join(dataframes)
return "\n".join(dataframes)

@property
@abstractmethod
Expand Down
3 changes: 1 addition & 2 deletions pandasai/prompts/direct_sql_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .generate_python_code import (
CurrentCodePrompt,
SimpleReasoningPrompt,
DefaultInstructionsPrompt,
)


Expand Down Expand Up @@ -31,7 +30,7 @@ def setup(self, tables, **kwargs) -> None:
if "custom_instructions" in kwargs:
self.set_var("instructions", kwargs["custom_instructions"])
else:
self.set_var("instructions", DefaultInstructionsPrompt())
self.set_var("instructions", "")

if "current_code" in kwargs:
self.set_var("current_code", kwargs["current_code"])
Expand Down
8 changes: 1 addition & 7 deletions pandasai/prompts/generate_python_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,6 @@ def setup(self, **kwargs) -> None:
self.set_var("dfs_declared_message", "")


class DefaultInstructionsPrompt(FileBasedPrompt):
"""The default instructions"""

_path_to_template = "assets/prompt_templates/default_instructions.tmpl"


class SimpleReasoningPrompt(FileBasedPrompt):
"""The simple reasoning instructions"""

Expand All @@ -62,7 +56,7 @@ def setup(self, **kwargs) -> None:
if "custom_instructions" in kwargs:
self.set_var("instructions", kwargs["custom_instructions"])
else:
self.set_var("instructions", DefaultInstructionsPrompt())
self.set_var("instructions", "")

if "current_code" in kwargs:
self.set_var("current_code", kwargs["current_code"])
Expand Down
19 changes: 9 additions & 10 deletions tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,20 +91,19 @@ def test_str_with_args(
# TODO: import the required dependencies
import pandas as pd
\"\"\"
The variable `dfs: list[pd.DataFrame]` is already declared.
1. Prep: preprocessing/cleaning
2. Proc: data manipulation (group, filter, aggregate)
3. Analyze data
{viz_library_type_hint}
# Write code here
Return a "result" variable dict:
{output_type_hint}
\"\"\"
# Declare result var: {output_type_hint}
```
Q: Question
Return the full updated code:""" # noqa E501
Variable `dfs: list[pd.DataFrame]` is already declared.
At the end, declare "result" var dict: {output_type_hint}
{viz_library_type_hint}
Generate python code and return full updated code:""" # noqa E501
actual_prompt_content = prompt.to_string()
if sys.platform.startswith("win"):
actual_prompt_content = actual_prompt_content.replace("\r\n", "\n")
Expand Down
9 changes: 7 additions & 2 deletions tests/prompts/test_sql_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_direct_sql_prompt_with_params(

assert (
prompt_content
== '''<tables>
== f'''<tables>
<table name="None">
Expand All @@ -107,5 +107,10 @@ def execute_sql_query(sql_query: str) -> pd.Dataframe
```
Return the full updated code:''' # noqa: E501
At the end, declare "result" var dict: {output_type_hint}
{viz_library_type_hint}
Generate python code and return full updated code:''' # noqa: E501
)
81 changes: 30 additions & 51 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pandasai.prompts import AbstractPrompt, GeneratePythonCodePrompt
from pandasai.helpers.cache import Cache
from pandasai.helpers.viz_library_types import (
MatplotlibVizLibraryType,
NoVizLibraryType,
viz_lib_map,
viz_lib_type_factory,
)
Expand Down Expand Up @@ -195,29 +195,19 @@ def test_run_with_privacy_enforcement(self, llm):
# TODO: import the required dependencies
import pandas as pd
\"\"\"
The variable `dfs: list[pd.DataFrame]` is already declared.
1. Prep: preprocessing/cleaning
2. Proc: data manipulation (group, filter, aggregate)
3. Analyze data
If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file.
Return a "result" variable dict:
- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Examples:
{ "type": "string", "value": f"The highest salary is {highest_salary}." }
or
{ "type": "number", "value": 125 }
or
{ "type": "dataframe", "value": pd.DataFrame({...}) }
or
{ "type": "plot", "value": "temp_chart.png" }
\"\"\"
# Write code here
# Declare result var: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
```
Q: How many countries are in the dataframe?
Return the full updated code:""" # noqa: E501
Variable `dfs: list[pd.DataFrame]` is already declared.
At the end, declare "result" var dict: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
Generate python code and return full updated code:""" # noqa: E501
df.chat("How many countries are in the dataframe?")
last_prompt = df.last_prompt
if sys.platform.startswith("win"):
Expand Down Expand Up @@ -252,20 +242,19 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint):
# TODO: import the required dependencies
import pandas as pd
\"\"\"
The variable `dfs: list[pd.DataFrame]` is already declared.
1. Prep: preprocessing/cleaning
2. Proc: data manipulation (group, filter, aggregate)
3. Analyze data
If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file.
# Write code here
Return a "result" variable dict:
{output_type_hint}
\"\"\"
# Declare result var: {output_type_hint}
```
Q: How many countries are in the dataframe?
Return the full updated code:"""
Variable `dfs: list[pd.DataFrame]` is already declared.
At the end, declare "result" var dict: {output_type_hint}
Generate python code and return full updated code:"""

df.chat("How many countries are in the dataframe?", output_type=output_type)
last_prompt = df.last_prompt
Expand Down Expand Up @@ -990,7 +979,7 @@ def test_head_csv_with_custom_head(
@pytest.mark.parametrize(
"viz_library_type,viz_library_type_hint",
[
(None, MatplotlibVizLibraryType().template_hint),
(None, NoVizLibraryType().template_hint),
*[
(type_, viz_lib_type_factory(type_).template_hint)
for type_ in viz_lib_map
Expand Down Expand Up @@ -1024,29 +1013,19 @@ def test_run_passing_viz_library_type(
# TODO: import the required dependencies
import pandas as pd
\"\"\"
The variable `dfs: list[pd.DataFrame]` is already declared.
1. Prep: preprocessing/cleaning
2. Proc: data manipulation (group, filter, aggregate)
3. Analyze data
%s
# Write code here
Return a "result" variable dict:
- type (possible values "string", "number", "dataframe", "plot")
- value (can be a string, a dataframe or the path of the plot, NOT a dictionary)
Examples:
{ "type": "string", "value": f"The highest salary is {highest_salary}." }
or
{ "type": "number", "value": 125 }
or
{ "type": "dataframe", "value": pd.DataFrame({...}) }
or
{ "type": "plot", "value": "temp_chart.png" }
\"\"\"
# Declare result var: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
```
Q: Plot the histogram of countries showing for each the gdp with distinct bar colors
Return the full updated code:""" # noqa: E501
Variable `dfs: list[pd.DataFrame]` is already declared.
At the end, declare "result" var dict: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }
%s
Generate python code and return full updated code:""" # noqa: E501
% viz_library_type_hint
)

Expand Down

0 comments on commit ca85093

Please sign in to comment.