fix(Dataframe): adding default dataframe name to enable sql query on it, simplified dataframe serialization
scaliseraoul committed Jan 15, 2025
1 parent 26fc930 commit fd22f8c
Showing 12 changed files with 86 additions and 173 deletions.
4 changes: 3 additions & 1 deletion pandasai/__init__.py
@@ -19,6 +19,7 @@
from .core.cache import Cache
from .data_loader.loader import DatasetLoader
from .dataframe import DataFrame, VirtualDataFrame
from .helpers.sql_sanitizer import sanitize_sql_table_name
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake

@@ -120,7 +121,8 @@ def load(dataset_path: str) -> DataFrame:

def read_csv(filepath: str) -> DataFrame:
data = pd.read_csv(filepath)
return DataFrame(data._data)
name = f"table_{sanitize_sql_table_name(filepath)}"
return DataFrame(data._data, name=name)


__all__ = [
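
Usage sketch (not part of the diff): the hunks above give every DataFrame produced by read_csv a SQL-friendly default name derived from the file name. A minimal illustration of the naming rule, using a made-up path:

from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

# Illustrative path only; any character outside [a-zA-Z0-9_] in the file name
# is replaced with an underscore before the "table_" prefix is added.
filepath = "data/heart disease-2024.csv"
default_name = f"table_{sanitize_sql_table_name(filepath)}"
print(default_name)  # table_heart_disease_2024

read_csv(filepath) passes this name straight into DataFrame(data._data, name=name), so SQL generated against the dataframe has a stable table to reference.
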
@@ -1,4 +1,4 @@
{% for df in context.dfs %}{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}
{% for df in context.dfs %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}

The user asked the following question:
{{context.memory.get_conversation()}}
@@ -1,6 +1,6 @@
<tables>
{% for df in context.dfs %}
{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}
{% include 'shared/dataframe.tmpl' with context %}
{% endfor %}
</tables>

2 changes: 1 addition & 1 deletion pandasai/core/prompts/templates/shared/dataframe.tmpl
@@ -1 +1 @@
{{ df.serialize_dataframe(index-1) }}
{{ df.serialize_dataframe() }}
30 changes: 8 additions & 22 deletions pandasai/dataframe/base.py
@@ -21,10 +21,7 @@
Source,
)
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
from pandasai.helpers.dataframe_serializer import (
DataframeSerializer,
DataframeSerializerType,
)
from pandasai.helpers.dataframe_serializer import DataframeSerializer
from pandasai.helpers.path import find_project_root
from pandasai.helpers.session import get_pandaai_session

@@ -67,14 +64,18 @@ def __init__(
)

self.name: Optional[str] = kwargs.pop("name", None)
self._column_hash = self._calculate_column_hash()

if not self.name:
self.name = f"table_{self._column_hash}"

self.description: Optional[str] = kwargs.pop("description", None)
self.path: Optional[str] = kwargs.pop("path", None)
schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None)

self.schema = schema
self.config = pai.config.get()
self._agent: Optional[Agent] = None
self._column_hash = self._calculate_column_hash()

def __repr__(self) -> str:
"""Return a string representation of the DataFrame."""
@@ -136,29 +137,14 @@ def rows_count(self) -> int:
def columns_count(self) -> int:
return len(self.columns)

def serialize_dataframe(
self,
index: int,
) -> str:
def serialize_dataframe(self) -> str:
"""
Serialize DataFrame to string representation.
Args:
index (int): Index of the dataframe
serializer_type (DataframeSerializerType): Type of serializer to use
**kwargs: Additional parameters to pass to pandas to_string method
Returns:
str: Serialized string representation of the DataFrame
"""
return DataframeSerializer().serialize(
self,
extras={
"index": index,
"type": "pd.DataFrame",
},
type_=DataframeSerializerType.CSV,
)
return DataframeSerializer().serialize(self)

def get_head(self):
return self.head()
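
Illustrative example (not part of the diff): with the constructor change above, a DataFrame built without an explicit name falls back to a hash-based default, and serialize_dataframe() no longer takes a positional index. A small sketch, assuming a default config:

import pandasai as pai

# Column names here are arbitrary; the point is that no name is passed.
df = pai.DataFrame({"a": [1, 2], "b": [3, 4]})
print(df.name)                   # e.g. "table_<column_hash>", derived from the columns
print(df.serialize_dataframe())  # the index argument is gone; the serializer handles the rest
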
3 changes: 2 additions & 1 deletion pandasai/helpers/__init__.py
@@ -1,9 +1,10 @@
from . import path
from . import path, sql_sanitizer
from .env import load_dotenv
from .logger import Logger

__all__ = [
"path",
"sql_sanitizer",
"load_dotenv",
"Logger",
]
114 changes: 6 additions & 108 deletions pandasai/helpers/dataframe_serializer.py
@@ -1,137 +1,35 @@
import json
from enum import Enum

import pandas as pd


class DataframeSerializerType(Enum):
JSON = 1
YML = 2
CSV = 3
SQL = 4


class DataframeSerializer:
def __init__(self) -> None:
pass

def serialize(
self,
df: pd.DataFrame,
extras: dict = None,
type_: DataframeSerializerType = DataframeSerializerType.YML,
) -> str:
if type_ == DataframeSerializerType.YML:
return self.convert_df_to_yml(df, extras)
elif type_ == DataframeSerializerType.JSON:
return self.convert_df_to_json_str(df, extras)
elif type_ == DataframeSerializerType.SQL:
return self.convert_df_sql_connector_to_str(df, extras)
else:
return self.convert_df_to_csv(df, extras)

def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str:
def serialize(self, df: pd.DataFrame) -> str:
"""
Convert df to csv like format where csv is wrapped inside <dataframe></dataframe>
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists
Returns:
str: dataframe stringify
"""
dataframe_info = "<dataframe"
dataframe_info = "<table"

# Add name attribute if available
if df.name is not None:
dataframe_info += f' name="{df.name}"'
dataframe_info += f' table_name="{df.name}"'

# Add description attribute if available
if df.description is not None:
dataframe_info += f' description="{df.description}"'

dataframe_info += ">"
dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">'

# Add dataframe details
dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}"
dataframe_info += f"\n{df.head().to_csv(index=False)}"

# Close the dataframe tag
dataframe_info += "</dataframe>\n"
dataframe_info += "</table>\n"

return dataframe_info

def convert_df_sql_connector_to_str(
self, df: pd.DataFrame, extras: dict = None
) -> str:
"""
Convert df to csv like format where csv is wrapped inside <table></table>
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists
Returns:
str: dataframe stringify
"""
table_description_tag = (
f' description="{df.description}"' if df.description is not None else ""
)
table_head_tag = f'<table name="{df.name}"{table_description_tag}>'
return f"{table_head_tag}\n{df.get_head().to_csv()}\n</table>"

def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
"""
Convert df to json dictionary and return json
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists
Returns:
str: dataframe json
"""

# Create a dictionary representing the data structure
df_info = {
"name": df.name,
"description": None,
"type": df.type,
}
# Add DataFrame details to the result
data = {
"rows": df.rows_count,
"columns": df.columns_count,
"schema": {"fields": []},
}

# Iterate over DataFrame columns
df_head = df.get_head()
for col_name, col_dtype in df_head.dtypes.items():
col_info = {
"name": col_name,
"type": str(col_dtype),
}

data["schema"]["fields"].append(col_info)

result = df_info | data

return result

def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str:
"""
Convert df to json and return it as string
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
extras (dict, optional): expect index to exists
Returns:
str: dataframe stringify
"""
return json.dumps(self.convert_df_to_json(df, extras))

def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str:
json_df = self.convert_df_to_json(df, extras)

import yaml

yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True)
return f"<table>\n{yml_str}\n</table>\n"
16 changes: 16 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
@@ -0,0 +1,16 @@
import os
import re


def sanitize_sql_table_name(filepath: str) -> str:
# Extract the file name without extension
file_name = os.path.splitext(os.path.basename(filepath))[0]

# Replace invalid characters with underscores
sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", file_name)

# Truncate to a reasonable length (e.g., 64 characters)
max_length = 64
sanitized_name = sanitized_name[:max_length]

return sanitized_name
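
Quick usage sketch of the new helper (not part of the diff), matching the test cases added below:

from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

print(sanitize_sql_table_name("/path/to/valid_table.csv"))             # valid_table
print(sanitize_sql_table_name("/path/to/invalid!@#.csv"))              # invalid___
print(len(sanitize_sql_table_name("/path/to/" + "a" * 100 + ".csv")))  # 64
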
3 changes: 1 addition & 2 deletions tests/unit_tests/agent/test_agent.py
@@ -9,7 +9,6 @@
from pandasai.config import Config, ConfigManager
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.helpers.dataframe_serializer import DataframeSerializerType
from pandasai.llm.fake import FakeLLM


@@ -38,7 +37,7 @@ def llm(self, output: Optional[str] = None) -> FakeLLM:

@pytest.fixture
def config(self, llm: FakeLLM) -> dict:
return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}
return {"llm": llm}

@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
57 changes: 25 additions & 32 deletions tests/unit_tests/helpers/test_dataframe_serializer.py
@@ -1,39 +1,32 @@
import unittest
import pytest

from pandasai.dataframe.base import DataFrame
from pandasai.helpers.dataframe_serializer import (
DataframeSerializer,
DataframeSerializerType,
)
from pandasai import DataFrame
from pandasai.helpers.dataframe_serializer import DataframeSerializer


class TestDataframeSerializer(unittest.TestCase):
def setUp(self):
self.serializer = DataframeSerializer()
class TestDataframeSerializer:
@pytest.fixture
def sample_df(self):
df = DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]})
df.name = "test_table"
df.description = "This is a test table"
return df

def test_convert_df_to_yml(self):
# Test convert df to yml
data = {"name": ["en_name", "中文_名称"]}
connector = DataFrame(data, name="en_table_name", description="中文_描述")
result = self.serializer.serialize(
connector,
type_=DataframeSerializerType.YML,
extras={"index": 0, "type": "pd.Dataframe"},
)
@pytest.fixture
def sample_dataframe_serializer(self):
return DataframeSerializer()

self.assertIn(
"""<table>
name: en_table_name
description: null
type: pd.DataFrame
rows: 2
columns: 1
schema:
fields:
- name: name
type: object
def test_serialize_with_name_and_description(
self, sample_dataframe_serializer, sample_df
):
"""Test serialization with name and description attributes."""

</table>
""",
result,
result = sample_dataframe_serializer.serialize(sample_df)
expected = (
'<table table_name="test_table" description="This is a test table" dimensions="2x2">\n'
"Name,Age\n"
"Alice,25\n"
"Bob,30\n"
"</table>\n"
)
assert result == expected
19 changes: 19 additions & 0 deletions tests/unit_tests/helpers/test_sql_sanitizer.py
@@ -0,0 +1,19 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name


class TestSqlSanitizer:
def test_valid_filename(self):
filepath = "/path/to/valid_table.csv"
expected = "valid_table"
assert sanitize_sql_table_name(filepath) == expected

def test_filename_with_special_characters(self):
filepath = "/path/to/invalid!@#.csv"
expected = "invalid___"
assert sanitize_sql_table_name(filepath) == expected

def test_filename_with_long_name(self):
"""Test with a filename exceeding the length limit."""
filepath = "/path/to/" + "a" * 100 + ".csv"
expected = "a" * 64
assert sanitize_sql_table_name(filepath) == expected
7 changes: 3 additions & 4 deletions tests/unit_tests/prompts/test_sql_prompt.py
@@ -51,7 +51,7 @@ def test_str_with_args(self, output_type, output_type_template):

llm = FakeLLM()
agent = Agent(
pai.DataFrame(),
pai.DataFrame(name="test"),
config={"llm": llm},
)
prompt = GeneratePythonCodeWithSQLPrompt(
@@ -68,10 +68,9 @@ def test_str_with_args(self, output_type, output_type_template):
prompt_content
== f'''<tables>
<dataframe>
dfs[0]:0x0
<table table_name="test" dimensions="0x0">
</dataframe>
</table>
</tables>
