Skip to content

Commit

Permalink
Chore(SemanticAgent): improve time based aggregation and filtering (#…
Browse files Browse the repository at this point in the history
…1230)

* fix: timedimension for the case of line plot

* chore(semantic_agent): improve semantic agent

* fix: ruff errors
  • Loading branch information
ArslanSaleem authored Jun 13, 2024
1 parent 954f5f3 commit 668e6fe
Show file tree
Hide file tree
Showing 18 changed files with 980 additions and 204 deletions.
5 changes: 4 additions & 1 deletion pandasai/connectors/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Union

import duckdb
import sqlglot
from pydantic import BaseModel

import pandasai.pandas as pd
Expand Down Expand Up @@ -165,6 +166,7 @@ def enable_sql_query(self, table_name=None):
raise PandasConnectorTableNotFound("Table name not found!")

table = table_name or self.name

duckdb_relation = duckdb.from_df(self.pandas_df)
duckdb_relation.create(table)
self.sql_enabled = True
Expand All @@ -173,7 +175,8 @@ def enable_sql_query(self, table_name=None):
def execute_direct_sql_query(self, sql_query):
    """Run a SQL query against the DuckDB table backing this connector.

    SQL querying is enabled on first use. The query is parsed as MySQL
    dialect and transpiled to DuckDB syntax before it is executed.

    Args:
        sql_query: SQL text, expected in MySQL dialect.

    Returns:
        The query result converted to a dataframe via ``.df()``.
    """
    if not self.sql_enabled:
        self.enable_sql_query()

    # Pass the query to sqlglot untouched. Backticks are MySQL identifier
    # quotes, whereas double quotes delimit string literals in MySQL's
    # default mode, so rewriting "`" to '"' beforehand would make the
    # parser treat quoted column names as string constants.
    sql_query = sqlglot.transpile(sql_query, read="mysql", write="duckdb")[0]
    return duckdb.query(sql_query).df()

@property
Expand Down
3 changes: 2 additions & 1 deletion pandasai/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import cache, cached_property
from typing import Optional, Union

import sqlglot
from sqlalchemy import asc, create_engine, select, text
from sqlalchemy.engine import Connection

Expand Down Expand Up @@ -651,7 +652,7 @@ def cs_table_name(self):
return f'"{self.config.table}"'

def execute_direct_sql_query(self, sql_query):
    """Transpile a MySQL-dialect query to PostgreSQL and run it.

    Args:
        sql_query: SQL text, expected in MySQL dialect.

    Returns:
        Whatever the parent class' ``execute_direct_sql_query`` returns for
        the transpiled query.
    """
    # Pass the query to sqlglot untouched. Backticks are MySQL identifier
    # quotes, whereas double quotes delimit string literals in MySQL's
    # default mode, so rewriting "`" to '"' beforehand would make the
    # parser treat quoted column names as string constants.
    sql_query = sqlglot.transpile(sql_query, read="mysql", write="postgres")[0]
    return super().execute_direct_sql_query(sql_query)


Expand Down
33 changes: 32 additions & 1 deletion pandasai/ee/agents/semantic_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pandasai.ee.agents.semantic_agent.prompts.generate_df_schema import (
GenerateDFSchemaPrompt,
)
from pandasai.exceptions import InvalidConfigError, InvalidTrainJson
from pandasai.exceptions import InvalidConfigError, InvalidSchemaJson, InvalidTrainJson
from pandasai.helpers.cache import Cache
from pandasai.helpers.memory import Memory
from pandasai.llm.bamboo_llm import BambooLLM
Expand Down Expand Up @@ -51,6 +51,8 @@ def __init__(

self._create_schema()

self._sort_dfs_according_to_schema()

self.init_duckdb_instance()

# semantic agent works only with direct sql true
Expand Down Expand Up @@ -125,8 +127,37 @@ def query(self, query):
def init_duckdb_instance(self):
    """Prepare every pandas-backed dataframe for SQL querying.

    Schema entries and ``self.dfs`` are paired by position. For each pair
    whose dataframe is a ``PandasConnector``, the dataframe is first synced
    with the schema and then SQL querying is enabled under the table name
    taken from the schema entry. Other connector types are left alone.
    """
    for position, table_schema in enumerate(self._schema):
        if not isinstance(self.dfs[position], PandasConnector):
            continue

        self._sync_pandas_dataframe_schema(self.dfs[position], table_schema)
        self.dfs[position].enable_sql_query(table_schema["table"])

def _sync_pandas_dataframe_schema(self, df: PandasConnector, schema: dict):
    """Convert the schema's date dimensions to datetime columns, in place.

    Args:
        df: Connector whose ``pandas_df`` columns are rewritten.
        schema: Table schema; each entry of ``schema["dimensions"]`` must
            provide ``"type"`` and ``"sql"`` (the column name).
    """
    # Lazy generator: each dimension is inspected right before its column
    # is converted, one at a time.
    date_columns = (
        entry["sql"] for entry in schema["dimensions"] if entry["type"] == "date"
    )
    for column_name in date_columns:
        df.pandas_df[column_name] = pd.to_datetime(df.pandas_df[column_name])

def _sort_dfs_according_to_schema(self):
    """Reorder ``self.dfs`` so that ``self.dfs[i]`` backs ``self._schema[i]``.

    A dataframe matches a table when every dimension's ``sql`` column of
    that table is present in the columns of the dataframe's head. Exactly
    one dataframe is selected per table, so the resulting list always has
    the same length and order as the schema; code that pairs the two lists
    by index depends on that.

    When several dataframes match, one that has not been assigned to an
    earlier table is preferred; if all matching dataframes are already
    taken, the first matching one is reused.

    Raises:
        InvalidSchemaJson: If no dataframe contains all dimension columns
            of some table.
    """
    # Read each head once instead of once per (table, dataframe) pair.
    columns_per_df = [df.get_head().columns for df in self.dfs]
    assigned_positions = set()
    sorted_dfs = []

    for table in self._schema:
        required_columns = [dim["sql"] for dim in table["dimensions"]]
        candidates = [
            position
            for position, df_columns in enumerate(columns_per_df)
            if all(column in df_columns for column in required_columns)
        ]

        if not candidates:
            raise InvalidSchemaJson(
                f"Some sql column of table {table['table']} doesn't match with any dataframe"
            )

        # Appending every candidate would desynchronise the list from the
        # schema, so pick a single dataframe for this table.
        chosen = next(
            (
                position
                for position in candidates
                if position not in assigned_positions
            ),
            candidates[0],
        )
        assigned_positions.add(chosen)
        sorted_dfs.append(self.dfs[chosen])

    self.dfs = sorted_dfs

def _create_schema(self):
"""
Generate schema on the initialization of Agent class
Expand Down
85 changes: 60 additions & 25 deletions pandasai/ee/agents/semantic_agent/pipeline/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import Any, Callable

from pandasai.ee.helpers.query_builder import QueryBuilder
Expand All @@ -13,17 +14,21 @@ class CodeGenerator(BaseLogicUnit):
"""

def __init__(
    self,
    on_code_generation: Callable[[str, Exception], None] = None,
    on_failure=None,
    **kwargs,
):
    """Create the code generation step.

    Args:
        on_code_generation: Optional callback; stored on the instance.
        on_failure: Optional callback; stored on the instance. ``execute``
            calls it with the failing input and the formatted traceback
            and retries with the value it returns.
        **kwargs: Forwarded unchanged to the base class initializer.
    """
    super().__init__(**kwargs)
    self.on_code_generation = on_code_generation
    self.on_failure = on_failure

def execute(self, input: Any, **kwargs) -> Any:
def execute(self, input_data: Any, **kwargs) -> Any:
"""
This method will return output according to
Implementation.
:param input: Your input data.
:param input_data: Your input data.
:param kwargs: A dictionary of keyword arguments.
- 'logger' (any): The logger for logging.
- 'config' (Config): Global configurations for the test
Expand All @@ -36,13 +41,16 @@ def execute(self, input: Any, **kwargs) -> Any:
schema = pipeline_context.get("df_schema")
query_builder = QueryBuilder(schema)

sql_query = query_builder.generate_sql(input)
retry_count = 0
while retry_count <= pipeline_context.config.max_retries:
try:
sql_query = query_builder.generate_sql(input_data)

response_type = self._get_type(input)
response_type = self._get_type(input_data)

gen_code = self._generate_code(response_type, input)
gen_code = self._generate_code(response_type, input_data)

code = f"""
code = f"""
{"import matplotlib.pyplot as plt" if response_type == "plot" else ""}
import pandas as pd
Expand All @@ -52,27 +60,46 @@ def execute(self, input: Any, **kwargs) -> Any:
{gen_code}
"""

logger.log(f"""Code Generated: {code}""")
logger.log(f"""Code Generated: {code}""")

# Implement error handling pipeline here...
# Implement error handling pipeline here...

return LogicUnitOutput(
code,
True,
"Code Generated Successfully",
{"content_type": "string", "value": code},
)
return LogicUnitOutput(
code,
True,
"Code Generated Successfully",
{"content_type": "string", "value": code},
)
except Exception:
if (
retry_count == pipeline_context.config.max_retries
or not self.on_failure
):
raise

traceback_errors = traceback.format_exc()

input_data = self.on_failure(input, traceback_errors)

retry_count += 1

def _get_type(self, input: dict) -> str:
    """Map a query's ``type`` to the kind of result the generated code builds.

    Chart types collapse to ``"plot"``; every other type (for example
    ``"number"`` or ``"dataframe"``) is returned unchanged.

    Args:
        input: Query dictionary; must contain the key ``"type"``.

    Returns:
        ``"plot"`` for chart queries, otherwise the query's own type.
    """
    # NOTE(review): "box" is not listed here although a _generate_box_code
    # helper exists in this class - confirm whether box plots should also
    # map to "plot".
    chart_types = ("bar", "line", "histogram", "pie", "scatter")
    query_type = input["type"]
    return "plot" if query_type in chart_types else query_type

def _generate_code(self, type, query):
if type == "number":
code = self._generate_code_for_number(query)

# Format code final output
return f"""
result = {{"type": "number","value": {code}}}
{code}
result = {{"type": "number","value": total_value}}
"""
elif type == "dataframe":
return """
result = {{"type": "dataframe","value": data}}
"""
else:
code = self.generate_matplotlib_code(query)
Expand All @@ -88,7 +115,7 @@ def _generate_code_for_number(self, query: dict) -> str:
else:
value = query["dimensions"][0].split(".")[1]

return f'data["{value}"].iloc[0]'
return f'total_value = data["{value}"].sum()\n'

def generate_matplotlib_code(self, query: dict) -> str:
chart_type = query["type"]
Expand Down Expand Up @@ -135,11 +162,11 @@ def generate_matplotlib_code(self, query: dict) -> str:
code += code_generator(query)

if x_label:
code += f"plt.xlabel('{x_label}')\n"
code += f"plt.xlabel('''{x_label}''')\n"
if y_label:
code += f"plt.ylabel('{y_label}')\n"
code += f"plt.ylabel('''{y_label}''')\n"
if title:
code += f"plt.title('{title}')\n"
code += f"plt.title('''{title}''')\n"

if legend_display:
code += f"plt.legend(loc='{legend_position}')\n"
Expand All @@ -151,7 +178,7 @@ def generate_matplotlib_code(self, query: dict) -> str:
return code

def _generate_bar_code(self, query):
x_key = query["dimensions"][0].split(".")[1]
x_key = self._get_dimensions_key(query)
plots = ""
for measure in query["measures"]:
if isinstance(measure, str):
Expand All @@ -173,7 +200,7 @@ def _generate_pie_code(self, query):
return f"""plt.pie(data["{measure}"], labels=data["{dimension}"], autopct='%1.1f%%')\n"""

def _generate_line_code(self, query):
x_key = query["dimensions"][0].split(".")[1]
x_key = self._get_dimensions_key(query)
plots = ""
for measure in query["measures"]:
field_name = measure.split(".")[1]
Expand All @@ -193,3 +220,11 @@ def _generate_hist_code(self, query):
def _generate_box_code(self, query):
    """Return the matplotlib statement that draws a box plot.

    The plotted column is the second dot-separated part of the first entry
    in ``query["measures"]``.
    """
    first_measure = query["measures"][0]
    name_parts = first_measure.split(".")
    column = name_parts[1]
    return "plt.boxplot(data['{}'])\n".format(column)

def _get_dimensions_key(self, query):
    """Return the name of the data column used as the dimension key.

    With at least one entry in ``query["dimensions"]``, the key is the
    second dot-separated part of the first entry. Otherwise it is built
    from the first entry of ``query["timeDimensions"]`` as
    ``<column>_by_<granularity>``.
    """
    if "dimensions" not in query or len(query["dimensions"]) == 0:
        first_time_dimension = query["timeDimensions"][0]
        column = first_time_dimension["dimension"].split(".")[1]
        granularity = first_time_dimension["granularity"]
        return f"{column}_by_{granularity}"

    return query["dimensions"][0].split(".")[1]
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Optional

from pandasai.ee.agents.semantic_agent.pipeline.code_generator import CodeGenerator
from pandasai.ee.agents.semantic_agent.pipeline.error_correction_pipeline.fix_semantic_json_pipeline import (
FixSemanticJsonPipeline,
)
from pandasai.ee.agents.semantic_agent.pipeline.llm_call import LLMCall
from pandasai.ee.agents.semantic_agent.pipeline.Semantic_prompt_generation import (
SemanticPromptGeneration,
Expand Down Expand Up @@ -40,14 +43,42 @@ def __init__(
on_execution=on_prompt_generation,
),
LLMCall(),
CodeGenerator(on_execution=on_code_generation),
CodeGenerator(
on_execution=on_code_generation,
on_failure=self.on_wrong_semantic_json,
),
CodeCleaning(),
],
)

self.fix_semantic_json_pipeline = FixSemanticJsonPipeline(
context=context,
logger=logger,
query_exec_tracker=query_exec_tracker,
on_code_generation=on_code_generation,
on_prompt_generation=on_prompt_generation,
)

self._context = context
self._logger = logger

def run(self, input: ErrorCorrectionPipelineInput):
    """Log the pipeline's class name, then run the wrapped pipeline.

    Args:
        input: Payload handed unchanged to ``self.pipeline.run``.

    Returns:
        The value returned by ``self.pipeline.run``.
    """
    pipeline_name = self.__class__.__name__
    self._logger.log(f"Executing Pipeline: {pipeline_name}")

    outcome = self.pipeline.run(input)
    return outcome

def on_wrong_semantic_json(self, code, errors):
    """Record a failed CodeGenerator step and run the JSON-fixing pipeline.

    Args:
        code: Value stored under ``"value"`` in the tracked step and passed
            on to the correction pipeline.
        errors: Error details stored under ``"exception"`` in the tracked
            step and passed on to the correction pipeline.

    Returns:
        The result of ``self.fix_semantic_json_pipeline.run``.
    """
    failure_details = {
        "content_type": "code",
        "value": code,
        "exception": errors,
    }
    failed_step = {
        "type": "CodeGenerator",
        "success": False,
        "message": "Failed to validate json",
        "execution_time": None,
        "data": failure_details,
    }
    self.query_exec_tracker.add_step(failed_step)

    return self.fix_semantic_json_pipeline.run(
        ErrorCorrectionPipelineInput(code, errors)
    )
Loading

0 comments on commit 668e6fe

Please sign in to comment.