diff --git a/pandasai/assets/prompt_templates/current_code.tmpl b/pandasai/assets/prompt_templates/current_code.tmpl index 2b6591df7..6ac77c5f5 100644 --- a/pandasai/assets/prompt_templates/current_code.tmpl +++ b/pandasai/assets/prompt_templates/current_code.tmpl @@ -1,10 +1,6 @@ # TODO: import the required dependencies {default_import} -""" -{dfs_declared_message} -{instructions} +# Write code here -Return a "result" variable dict: -{output_type_hint} -""" \ No newline at end of file +# Declare result var: {output_type_hint} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/default_instructions.tmpl b/pandasai/assets/prompt_templates/default_instructions.tmpl deleted file mode 100644 index 1423cd6be..000000000 --- a/pandasai/assets/prompt_templates/default_instructions.tmpl +++ /dev/null @@ -1,4 +0,0 @@ -1. Prep: preprocessing/cleaning -2. Proc: data manipulation (group, filter, aggregate) -3. Analyze data -{viz_library_type} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/generate_python_code.tmpl b/pandasai/assets/prompt_templates/generate_python_code.tmpl index 2f997dad3..ba531b57b 100644 --- a/pandasai/assets/prompt_templates/generate_python_code.tmpl +++ b/pandasai/assets/prompt_templates/generate_python_code.tmpl @@ -9,4 +9,5 @@ ``` {last_message} +Variable `dfs: list[pd.DataFrame]` is already declared. {reasoning} \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/simple_reasoning.tmpl b/pandasai/assets/prompt_templates/simple_reasoning.tmpl index dbb090a52..4be8fbf38 100644 --- a/pandasai/assets/prompt_templates/simple_reasoning.tmpl +++ b/pandasai/assets/prompt_templates/simple_reasoning.tmpl @@ -1 +1,6 @@ -Return the full updated code: \ No newline at end of file + +At the end, declare "result" var dict: {output_type_hint} +{viz_library_type} +{instructions} + +Generate python code and return full updated code: \ No newline at end of file diff --git a/pandasai/assets/prompt_templates/viz_library.tmpl b/pandasai/assets/prompt_templates/viz_library.tmpl index ee3b7ed25..c01306e4c 100644 --- a/pandasai/assets/prompt_templates/viz_library.tmpl +++ b/pandasai/assets/prompt_templates/viz_library.tmpl @@ -1 +1 @@ -If the user requests to create a chart, utilize the Python {library} library to generate high-quality graphics that will be saved directly to a file. \ No newline at end of file +If you are asked to plot a chart, use "{library}" for charts, save as png. \ No newline at end of file diff --git a/pandasai/helpers/output_types/_output_types.py b/pandasai/helpers/output_types/_output_types.py index 266874243..87c811b2f 100644 --- a/pandasai/helpers/output_types/_output_types.py +++ b/pandasai/helpers/output_types/_output_types.py @@ -64,9 +64,7 @@ def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]: class NumberOutputType(BaseOutputType): @property def template_hint(self): - return """- type (must be "number") - - value (must be a number) - Example output: { "type": "number", "value": 125 }""" + return """type (must be "number"), value must int. Example: { "type": "number", "value": 125 }""" # noqa E501 @property def name(self): @@ -79,9 +77,7 @@ def _validate_value(self, actual_value: Any) -> bool: class DataFrameOutputType(BaseOutputType): @property def template_hint(self): - return """- type (must be "dataframe") - - value (must be a pandas dataframe) - Example output: { "type": "dataframe", "value": pd.DataFrame({...}) }""" + return """type (must be "dataframe"), value must be pd.DataFrame or pd.Series. Example: { "type": "dataframe", "value": pd.DataFrame({...}) }""" # noqa E501 @property def name(self): @@ -94,9 +90,7 @@ def _validate_value(self, actual_value: Any) -> bool: class PlotOutputType(BaseOutputType): @property def template_hint(self): - return """- type (must be "plot") - - value (must be a string containing the path of the plot image) - Example output: { "type": "plot", "value": "export/charts/temp_chart.png" }""" + return """type (must be "plot"), value must be string. Example: { "type": "plot", "value": "temp_chart.png" }""" # noqa E501 @property def name(self): @@ -113,9 +107,7 @@ def _validate_value(self, actual_value: Any) -> bool: class StringOutputType(BaseOutputType): @property def template_hint(self): - return """- type (must be "string") - - value (must be a conversational answer, as a string) - Example output: { "type": "string", "value": f"The highest salary is {highest_salary}." }""" # noqa E501 + return """type (must be "string"), value must be string. Example: { "type": "string", "value": f"The highest salary is {highest_salary}." }""" # noqa E501 @property def name(self): @@ -128,16 +120,7 @@ def _validate_value(self, actual_value: Any) -> bool: class DefaultOutputType(BaseOutputType): @property def template_hint(self): - return """- type (possible values "string", "number", "dataframe", "plot") - - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) - Examples: - { "type": "string", "value": f"The highest salary is {highest_salary}." } - or - { "type": "number", "value": 125 } - or - { "type": "dataframe", "value": pd.DataFrame({...}) } - or - { "type": "plot", "value": "temp_chart.png" }""" # noqa E501 + return """type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" }""" # noqa E501 @property def name(self): diff --git a/pandasai/helpers/viz_library_types/__init__.py b/pandasai/helpers/viz_library_types/__init__.py index b47eb8110..c3a66d229 100644 --- a/pandasai/helpers/viz_library_types/__init__.py +++ b/pandasai/helpers/viz_library_types/__init__.py @@ -3,6 +3,7 @@ from .base import VisualizationLibrary from ._viz_library_types import ( + NoVizLibraryType, MatplotlibVizLibraryType, PlotlyVizLibraryType, SeabornVizLibraryType, @@ -55,7 +56,7 @@ def viz_lib_type_factory( level=logging.WARNING, ) - viz_lib_default = MatplotlibVizLibraryType + viz_lib_default = NoVizLibraryType viz_lib_type_helper = viz_lib_map.get(viz_lib_type, viz_lib_default)() if logger: diff --git a/pandasai/helpers/viz_library_types/_viz_library_types.py b/pandasai/helpers/viz_library_types/_viz_library_types.py index 3c9ae66e3..0eb07aa4e 100644 --- a/pandasai/helpers/viz_library_types/_viz_library_types.py +++ b/pandasai/helpers/viz_library_types/_viz_library_types.py @@ -46,6 +46,16 @@ def validate(self, result: dict[str, Any]) -> tuple[bool, Iterable[str]]: return type_ok, validation_logs +class NoVizLibraryType(BaseVizLibraryType): + @property + def template_hint(self) -> str: + return "" + + @property + def name(self): + return "no_viz_library" + + class MatplotlibVizLibraryType(BaseVizLibraryType): @property def name(self): diff --git a/pandasai/llm/google_vertexai.py b/pandasai/llm/google_vertexai.py index 261b65ff1..eb9aba9f9 100644 --- a/pandasai/llm/google_vertexai.py +++ b/pandasai/llm/google_vertexai.py @@ -125,7 +125,7 @@ def _generate_text(self, prompt: str) -> str: else: raise UnsupportedModelError(self.model) - return str(completion) + return completion.text @property def type(self) -> str: diff --git a/pandasai/pipelines/smart_datalake_chat/code_execution.py b/pandasai/pipelines/smart_datalake_chat/code_execution.py index 08e60699a..c08ec5afe 100644 --- a/pandasai/pipelines/smart_datalake_chat/code_execution.py +++ b/pandasai/pipelines/smart_datalake_chat/code_execution.py @@ -101,6 +101,9 @@ def _retry_run_code( "engine": context.dfs[0].engine, "code": code, "error_returned": e, + "output_type_hint": context.get_intermediate_value( + "output_type_helper" + ).template_hint, } error_correcting_instruction = context.get_intermediate_value("get_prompt")( "correct_error", diff --git a/pandasai/pipelines/smart_datalake_chat/prompt_generation.py b/pandasai/pipelines/smart_datalake_chat/prompt_generation.py index 50d068a14..9e1d79b47 100644 --- a/pandasai/pipelines/smart_datalake_chat/prompt_generation.py +++ b/pandasai/pipelines/smart_datalake_chat/prompt_generation.py @@ -60,9 +60,7 @@ def execute(self, input: Any, **kwargs) -> Any: default_values["current_code"] = pipeline_context.get_intermediate_value( "last_code_generated" ) - default_values[ - "code_description" - ] = "This is the code generated to answer the previous question:" # noqa: E501 + default_values["code_description"] = "" [key, default_prompt] = self._get_chat_prompt(pipeline_context) diff --git a/pandasai/prompts/base.py b/pandasai/prompts/base.py index 20b004830..898d885d1 100644 --- a/pandasai/prompts/base.py +++ b/pandasai/prompts/base.py @@ -62,7 +62,7 @@ def _generate_dataframes(self, dfs): dataframes.append(dataframe_info) - return "\n\n".join(dataframes) + return "\n".join(dataframes) @property @abstractmethod diff --git a/pandasai/prompts/direct_sql_prompt.py b/pandasai/prompts/direct_sql_prompt.py index e4dfae89a..37c28d5b4 100644 --- a/pandasai/prompts/direct_sql_prompt.py +++ b/pandasai/prompts/direct_sql_prompt.py @@ -3,7 +3,6 @@ from .generate_python_code import ( CurrentCodePrompt, SimpleReasoningPrompt, - DefaultInstructionsPrompt, ) @@ -31,7 +30,7 @@ def setup(self, tables, **kwargs) -> None: if "custom_instructions" in kwargs: self.set_var("instructions", kwargs["custom_instructions"]) else: - self.set_var("instructions", DefaultInstructionsPrompt()) + self.set_var("instructions", "") if "current_code" in kwargs: self.set_var("current_code", kwargs["current_code"]) diff --git a/pandasai/prompts/generate_python_code.py b/pandasai/prompts/generate_python_code.py index 4fc9c59ab..a41930974 100644 --- a/pandasai/prompts/generate_python_code.py +++ b/pandasai/prompts/generate_python_code.py @@ -35,12 +35,6 @@ def setup(self, **kwargs) -> None: self.set_var("dfs_declared_message", "") -class DefaultInstructionsPrompt(FileBasedPrompt): - """The default instructions""" - - _path_to_template = "assets/prompt_templates/default_instructions.tmpl" - - class SimpleReasoningPrompt(FileBasedPrompt): """The simple reasoning instructions""" @@ -62,7 +56,7 @@ def setup(self, **kwargs) -> None: if "custom_instructions" in kwargs: self.set_var("instructions", kwargs["custom_instructions"]) else: - self.set_var("instructions", DefaultInstructionsPrompt()) + self.set_var("instructions", "") if "current_code" in kwargs: self.set_var("current_code", kwargs["current_code"]) diff --git a/tests/prompts/test_generate_python_code_prompt.py b/tests/prompts/test_generate_python_code_prompt.py index 130602d86..c445cd222 100644 --- a/tests/prompts/test_generate_python_code_prompt.py +++ b/tests/prompts/test_generate_python_code_prompt.py @@ -91,20 +91,19 @@ def test_str_with_args( # TODO: import the required dependencies import pandas as pd -\"\"\" -The variable `dfs: list[pd.DataFrame]` is already declared. -1. Prep: preprocessing/cleaning -2. Proc: data manipulation (group, filter, aggregate) -3. Analyze data -{viz_library_type_hint} +# Write code here -Return a "result" variable dict: -{output_type_hint} -\"\"\" +# Declare result var: {output_type_hint} ``` Q: Question -Return the full updated code:""" # noqa E501 +Variable `dfs: list[pd.DataFrame]` is already declared. + +At the end, declare "result" var dict: {output_type_hint} +{viz_library_type_hint} + + +Generate python code and return full updated code:""" # noqa E501 actual_prompt_content = prompt.to_string() if sys.platform.startswith("win"): actual_prompt_content = actual_prompt_content.replace("\r\n", "\n") diff --git a/tests/prompts/test_sql_prompt.py b/tests/prompts/test_sql_prompt.py index 735d67ad7..6e19d1b18 100644 --- a/tests/prompts/test_sql_prompt.py +++ b/tests/prompts/test_sql_prompt.py @@ -84,7 +84,7 @@ def test_direct_sql_prompt_with_params( assert ( prompt_content - == ''' + == f''' @@ -107,5 +107,10 @@ def execute_sql_query(sql_query: str) -> pd.Dataframe ``` -Return the full updated code:''' # noqa: E501 + +At the end, declare "result" var dict: {output_type_hint} +{viz_library_type_hint} + + +Generate python code and return full updated code:''' # noqa: E501 ) diff --git a/tests/test_smartdataframe.py b/tests/test_smartdataframe.py index 406b86c55..2920f0563 100644 --- a/tests/test_smartdataframe.py +++ b/tests/test_smartdataframe.py @@ -24,7 +24,7 @@ from pandasai.prompts import AbstractPrompt, GeneratePythonCodePrompt from pandasai.helpers.cache import Cache from pandasai.helpers.viz_library_types import ( - MatplotlibVizLibraryType, + NoVizLibraryType, viz_lib_map, viz_lib_type_factory, ) @@ -195,29 +195,19 @@ def test_run_with_privacy_enforcement(self, llm): # TODO: import the required dependencies import pandas as pd -\"\"\" -The variable `dfs: list[pd.DataFrame]` is already declared. -1. Prep: preprocessing/cleaning -2. Proc: data manipulation (group, filter, aggregate) -3. Analyze data -If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. - -Return a "result" variable dict: -- type (possible values "string", "number", "dataframe", "plot") - - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) - Examples: - { "type": "string", "value": f"The highest salary is {highest_salary}." } - or - { "type": "number", "value": 125 } - or - { "type": "dataframe", "value": pd.DataFrame({...}) } - or - { "type": "plot", "value": "temp_chart.png" } -\"\"\" +# Write code here + +# Declare result var: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } ``` Q: How many countries are in the dataframe? -Return the full updated code:""" # noqa: E501 +Variable `dfs: list[pd.DataFrame]` is already declared. + +At the end, declare "result" var dict: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } + + + +Generate python code and return full updated code:""" # noqa: E501 df.chat("How many countries are in the dataframe?") last_prompt = df.last_prompt if sys.platform.startswith("win"): @@ -252,20 +242,19 @@ def test_run_passing_output_type(self, llm, output_type, output_type_hint): # TODO: import the required dependencies import pandas as pd -\"\"\" -The variable `dfs: list[pd.DataFrame]` is already declared. -1. Prep: preprocessing/cleaning -2. Proc: data manipulation (group, filter, aggregate) -3. Analyze data -If the user requests to create a chart, utilize the Python matplotlib library to generate high-quality graphics that will be saved directly to a file. +# Write code here -Return a "result" variable dict: -{output_type_hint} -\"\"\" +# Declare result var: {output_type_hint} ``` Q: How many countries are in the dataframe? -Return the full updated code:""" +Variable `dfs: list[pd.DataFrame]` is already declared. + +At the end, declare "result" var dict: {output_type_hint} + + + +Generate python code and return full updated code:""" df.chat("How many countries are in the dataframe?", output_type=output_type) last_prompt = df.last_prompt @@ -990,7 +979,7 @@ def test_head_csv_with_custom_head( @pytest.mark.parametrize( "viz_library_type,viz_library_type_hint", [ - (None, MatplotlibVizLibraryType().template_hint), + (None, NoVizLibraryType().template_hint), *[ (type_, viz_lib_type_factory(type_).template_hint) for type_ in viz_lib_map @@ -1024,29 +1013,19 @@ def test_run_passing_viz_library_type( # TODO: import the required dependencies import pandas as pd -\"\"\" -The variable `dfs: list[pd.DataFrame]` is already declared. -1. Prep: preprocessing/cleaning -2. Proc: data manipulation (group, filter, aggregate) -3. Analyze data -%s +# Write code here -Return a "result" variable dict: -- type (possible values "string", "number", "dataframe", "plot") - - value (can be a string, a dataframe or the path of the plot, NOT a dictionary) - Examples: - { "type": "string", "value": f"The highest salary is {highest_salary}." } - or - { "type": "number", "value": 125 } - or - { "type": "dataframe", "value": pd.DataFrame({...}) } - or - { "type": "plot", "value": "temp_chart.png" } -\"\"\" +# Declare result var: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } ``` Q: Plot the histogram of countries showing for each the gdp with distinct bar colors -Return the full updated code:""" # noqa: E501 +Variable `dfs: list[pd.DataFrame]` is already declared. + +At the end, declare "result" var dict: type (possible values "string", "number", "dataframe", "plot"). Examples: { "type": "string", "value": f"The highest salary is {highest_salary}." } or { "type": "number", "value": 125 } or { "type": "dataframe", "value": pd.DataFrame({...}) } or { "type": "plot", "value": "temp_chart.png" } +%s + + +Generate python code and return full updated code:""" # noqa: E501 % viz_library_type_hint )