-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(Dataframe): adding default dataframe name to enable sql query on it, simplified dataframe serialization
- Loading branch information
1 parent
26fc930
commit fd22f8c
Showing 12 changed files with 86 additions and 173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{{ df.serialize_dataframe(index-1) }} | ||
{{ df.serialize_dataframe() }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
from . import path | ||
from . import path, sql_sanitizer | ||
from .env import load_dotenv | ||
from .logger import Logger | ||
|
||
__all__ = [ | ||
"path", | ||
"sql_sanitizer", | ||
"load_dotenv", | ||
"Logger", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,137 +1,35 @@ | ||
import json | ||
from enum import Enum | ||
|
||
import pandas as pd | ||
|
||
|
||
class DataframeSerializerType(Enum): | ||
JSON = 1 | ||
YML = 2 | ||
CSV = 3 | ||
SQL = 4 | ||
|
||
|
||
class DataframeSerializer: | ||
def __init__(self) -> None: | ||
pass | ||
|
||
def serialize( | ||
self, | ||
df: pd.DataFrame, | ||
extras: dict = None, | ||
type_: DataframeSerializerType = DataframeSerializerType.YML, | ||
) -> str: | ||
if type_ == DataframeSerializerType.YML: | ||
return self.convert_df_to_yml(df, extras) | ||
elif type_ == DataframeSerializerType.JSON: | ||
return self.convert_df_to_json_str(df, extras) | ||
elif type_ == DataframeSerializerType.SQL: | ||
return self.convert_df_sql_connector_to_str(df, extras) | ||
else: | ||
return self.convert_df_to_csv(df, extras) | ||
|
||
def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str: | ||
def serialize(self, df: pd.DataFrame) -> str: | ||
""" | ||
Convert df to csv like format where csv is wrapped inside <dataframe></dataframe> | ||
Args: | ||
df (pd.DataFrame): PandaAI dataframe or dataframe | ||
extras (dict, optional): expect index to exists | ||
Returns: | ||
str: dataframe stringify | ||
""" | ||
dataframe_info = "<dataframe" | ||
dataframe_info = "<table" | ||
|
||
# Add name attribute if available | ||
if df.name is not None: | ||
dataframe_info += f' name="{df.name}"' | ||
dataframe_info += f' table_name="{df.name}"' | ||
|
||
# Add description attribute if available | ||
if df.description is not None: | ||
dataframe_info += f' description="{df.description}"' | ||
|
||
dataframe_info += ">" | ||
dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">' | ||
|
||
# Add dataframe details | ||
dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}" | ||
dataframe_info += f"\n{df.head().to_csv(index=False)}" | ||
|
||
# Close the dataframe tag | ||
dataframe_info += "</dataframe>\n" | ||
dataframe_info += "</table>\n" | ||
|
||
return dataframe_info | ||
|
||
def convert_df_sql_connector_to_str( | ||
self, df: pd.DataFrame, extras: dict = None | ||
) -> str: | ||
""" | ||
Convert df to csv like format where csv is wrapped inside <table></table> | ||
Args: | ||
df (pd.DataFrame): PandaAI dataframe or dataframe | ||
extras (dict, optional): expect index to exists | ||
Returns: | ||
str: dataframe stringify | ||
""" | ||
table_description_tag = ( | ||
f' description="{df.description}"' if df.description is not None else "" | ||
) | ||
table_head_tag = f'<table name="{df.name}"{table_description_tag}>' | ||
return f"{table_head_tag}\n{df.get_head().to_csv()}\n</table>" | ||
|
||
def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict: | ||
""" | ||
Convert df to json dictionary and return json | ||
Args: | ||
df (pd.DataFrame): PandaAI dataframe or dataframe | ||
extras (dict, optional): expect index to exists | ||
Returns: | ||
str: dataframe json | ||
""" | ||
|
||
# Create a dictionary representing the data structure | ||
df_info = { | ||
"name": df.name, | ||
"description": None, | ||
"type": df.type, | ||
} | ||
# Add DataFrame details to the result | ||
data = { | ||
"rows": df.rows_count, | ||
"columns": df.columns_count, | ||
"schema": {"fields": []}, | ||
} | ||
|
||
# Iterate over DataFrame columns | ||
df_head = df.get_head() | ||
for col_name, col_dtype in df_head.dtypes.items(): | ||
col_info = { | ||
"name": col_name, | ||
"type": str(col_dtype), | ||
} | ||
|
||
data["schema"]["fields"].append(col_info) | ||
|
||
result = df_info | data | ||
|
||
return result | ||
|
||
def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str: | ||
""" | ||
Convert df to json and return it as string | ||
Args: | ||
df (pd.DataFrame): PandaAI dataframe or dataframe | ||
extras (dict, optional): expect index to exists | ||
Returns: | ||
str: dataframe stringify | ||
""" | ||
return json.dumps(self.convert_df_to_json(df, extras)) | ||
|
||
def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str: | ||
json_df = self.convert_df_to_json(df, extras) | ||
|
||
import yaml | ||
|
||
yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True) | ||
return f"<table>\n{yml_str}\n</table>\n" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import os | ||
import re | ||
|
||
|
||
def sanitize_sql_table_name(filepath: str) -> str:
    """Derive a safe SQL table name from a file path.

    Takes the file's base name (extension stripped), replaces every
    character that is not alphanumeric or an underscore with "_",
    prefixes an underscore when the name would otherwise start with a
    digit (a digit-leading identifier is not a valid unquoted SQL table
    name), and truncates the result to 64 characters.

    Args:
        filepath (str): Path to the file the table is derived from.

    Returns:
        str: A sanitized identifier usable as an unquoted SQL table name.
    """
    # Extract the file name without extension
    file_name = os.path.splitext(os.path.basename(filepath))[0]

    # Replace invalid characters with underscores
    sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", file_name)

    # SQL identifiers may not begin with a digit; prefix an underscore
    # so e.g. "2024_sales.csv" becomes "_2024_sales" instead of an
    # identifier that would need quoting.
    if sanitized_name and sanitized_name[0].isdigit():
        sanitized_name = f"_{sanitized_name}"

    # Truncate to a reasonable length (64 chars matches common database
    # identifier limits, e.g. MySQL's)
    max_length = 64
    sanitized_name = sanitized_name[:max_length]

    return sanitized_name
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,32 @@ | ||
import unittest | ||
import pytest | ||
|
||
from pandasai.dataframe.base import DataFrame | ||
from pandasai.helpers.dataframe_serializer import ( | ||
DataframeSerializer, | ||
DataframeSerializerType, | ||
) | ||
from pandasai import DataFrame | ||
from pandasai.helpers.dataframe_serializer import DataframeSerializer | ||
|
||
|
||
class TestDataframeSerializer(unittest.TestCase): | ||
def setUp(self): | ||
self.serializer = DataframeSerializer() | ||
class TestDataframeSerializer: | ||
@pytest.fixture | ||
def sample_df(self): | ||
df = DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]}) | ||
df.name = "test_table" | ||
df.description = "This is a test table" | ||
return df | ||
|
||
def test_convert_df_to_yml(self): | ||
# Test convert df to yml | ||
data = {"name": ["en_name", "中文_名称"]} | ||
connector = DataFrame(data, name="en_table_name", description="中文_描述") | ||
result = self.serializer.serialize( | ||
connector, | ||
type_=DataframeSerializerType.YML, | ||
extras={"index": 0, "type": "pd.Dataframe"}, | ||
) | ||
@pytest.fixture | ||
def sample_dataframe_serializer(self): | ||
return DataframeSerializer() | ||
|
||
self.assertIn( | ||
"""<table> | ||
name: en_table_name | ||
description: null | ||
type: pd.DataFrame | ||
rows: 2 | ||
columns: 1 | ||
schema: | ||
fields: | ||
- name: name | ||
type: object | ||
def test_serialize_with_name_and_description( | ||
self, sample_dataframe_serializer, sample_df | ||
): | ||
"""Test serialization with name and description attributes.""" | ||
|
||
</table> | ||
""", | ||
result, | ||
result = sample_dataframe_serializer.serialize(sample_df) | ||
expected = ( | ||
'<table table_name="test_table" description="This is a test table" dimensions="2x2">\n' | ||
"Name,Age\n" | ||
"Alice,25\n" | ||
"Bob,30\n" | ||
"</table>\n" | ||
) | ||
assert result == expected |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name | ||
|
||
|
||
class TestSqlSanitizer:
    """Tests for sanitize_sql_table_name."""

    def test_valid_filename(self):
        """A clean file name passes through with only the extension removed."""
        result = sanitize_sql_table_name("/path/to/valid_table.csv")
        assert result == "valid_table"

    def test_filename_with_special_characters(self):
        """Each non-alphanumeric character is replaced by an underscore."""
        result = sanitize_sql_table_name("/path/to/invalid!@#.csv")
        assert result == "invalid___"

    def test_filename_with_long_name(self):
        """A name longer than 64 characters is truncated to 64."""
        long_path = "/path/to/{}.csv".format("a" * 100)
        result = sanitize_sql_table_name(long_path)
        assert result == "a" * 64
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters