diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 5f73b0d9c..cb358fd1d 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -19,6 +19,7 @@
from .core.cache import Cache
from .data_loader.loader import DatasetLoader
from .dataframe import DataFrame, VirtualDataFrame
+from .helpers.sql_sanitizer import sanitize_sql_table_name
from .smart_dataframe import SmartDataframe
from .smart_datalake import SmartDatalake
@@ -120,7 +121,8 @@ def load(dataset_path: str) -> DataFrame:
def read_csv(filepath: str) -> DataFrame:
data = pd.read_csv(filepath)
- return DataFrame(data._data)
+ name = f"table_{sanitize_sql_table_name(filepath)}"
+ return DataFrame(data._data, name=name)
__all__ = [
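For context on the hunk above: frames loaded through `read_csv` now carry a SQL-safe table name derived from the file name instead of being anonymous. A minimal sketch of the resulting behaviour, assuming a hypothetical local CSV (the exact name follows the sanitizer added later in this diff):

```python
import pandasai as pai

# Hypothetical file path; per the hunk above, the frame is named
# "table_" + sanitize_sql_table_name("data/sales-2024.csv") -> "table_sales_2024"
df = pai.read_csv("data/sales-2024.csv")
print(df.name)  # "table_sales_2024"
```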
diff --git a/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl b/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl
index 523608f71..029cf26f2 100644
--- a/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl
+++ b/pandasai/core/prompts/templates/correct_execute_sql_query_usage_error_prompt.tmpl
@@ -1,4 +1,4 @@
-{% for df in context.dfs %}{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}
+{% for df in context.dfs %}{% include 'shared/dataframe.tmpl' with context %}{% endfor %}
The user asked the following question:
{{context.memory.get_conversation()}}
diff --git a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl
index 6ce957d66..5406a8352 100644
--- a/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl
+++ b/pandasai/core/prompts/templates/generate_python_code_with_sql.tmpl
@@ -1,6 +1,6 @@
{% for df in context.dfs %}
-{% set index = loop.index %}{% include 'shared/dataframe.tmpl' with context %}
+{% include 'shared/dataframe.tmpl' with context %}
{% endfor %}
diff --git a/pandasai/core/prompts/templates/shared/dataframe.tmpl b/pandasai/core/prompts/templates/shared/dataframe.tmpl
index 931813b6f..fcc6ea3a8 100644
--- a/pandasai/core/prompts/templates/shared/dataframe.tmpl
+++ b/pandasai/core/prompts/templates/shared/dataframe.tmpl
@@ -1 +1 @@
-{{ df.serialize_dataframe(index-1) }}
+{{ df.serialize_dataframe() }}
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index d90a3b786..5cc895453 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -21,10 +21,7 @@
Source,
)
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
-from pandasai.helpers.dataframe_serializer import (
- DataframeSerializer,
- DataframeSerializerType,
-)
+from pandasai.helpers.dataframe_serializer import DataframeSerializer
from pandasai.helpers.path import find_project_root
from pandasai.helpers.session import get_pandaai_session
@@ -67,6 +64,11 @@ def __init__(
)
self.name: Optional[str] = kwargs.pop("name", None)
+ self._column_hash = self._calculate_column_hash()
+
+ if not self.name:
+ self.name = f"table_{self._column_hash}"
+
self.description: Optional[str] = kwargs.pop("description", None)
self.path: Optional[str] = kwargs.pop("path", None)
schema: Optional[SemanticLayerSchema] = kwargs.pop("schema", None)
@@ -74,7 +76,6 @@ def __init__(
self.schema = schema
self.config = pai.config.get()
self._agent: Optional[Agent] = None
- self._column_hash = self._calculate_column_hash()
def __repr__(self) -> str:
"""Return a string representation of the DataFrame."""
@@ -136,29 +137,14 @@ def rows_count(self) -> int:
def columns_count(self) -> int:
return len(self.columns)
- def serialize_dataframe(
- self,
- index: int,
- ) -> str:
+ def serialize_dataframe(self) -> str:
"""
Serialize DataFrame to string representation.
- Args:
- index (int): Index of the dataframe
- serializer_type (DataframeSerializerType): Type of serializer to use
- **kwargs: Additional parameters to pass to pandas to_string method
-
Returns:
str: Serialized string representation of the DataFrame
"""
- return DataframeSerializer().serialize(
- self,
- extras={
- "index": index,
- "type": "pd.DataFrame",
- },
- type_=DataframeSerializerType.CSV,
- )
+ return DataframeSerializer().serialize(self)
def get_head(self):
return self.head()
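The constructor change above means an unnamed DataFrame now names itself from its column hash before any prompt is built. A hedged stand-in for that behaviour (the real `_calculate_column_hash()` is not part of this diff, so a simple digest over the column names is assumed here):

```python
import hashlib

import pandas as pd

df = pd.DataFrame({"Name": ["Alice"], "Age": [25]})

# Assumed stand-in for DataFrame._calculate_column_hash(), which this diff does not show;
# any stable digest of the column names yields the same kind of fallback name.
column_hash = hashlib.md5(",".join(map(str, df.columns)).encode()).hexdigest()

# Mirrors the constructor fallback: if not self.name: self.name = f"table_{self._column_hash}"
print(f"table_{column_hash}")
```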
diff --git a/pandasai/helpers/__init__.py b/pandasai/helpers/__init__.py
index 0e9534765..ea29af987 100644
--- a/pandasai/helpers/__init__.py
+++ b/pandasai/helpers/__init__.py
@@ -1,9 +1,10 @@
-from . import path
+from . import path, sql_sanitizer
from .env import load_dotenv
from .logger import Logger
__all__ = [
"path",
+ "sql_sanitizer",
"load_dotenv",
"Logger",
]
diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py
index ca8e66219..487c2da82 100644
--- a/pandasai/helpers/dataframe_serializer.py
+++ b/pandasai/helpers/dataframe_serializer.py
@@ -1,137 +1,35 @@
-import json
-from enum import Enum
-
import pandas as pd
-class DataframeSerializerType(Enum):
- JSON = 1
- YML = 2
- CSV = 3
- SQL = 4
-
-
class DataframeSerializer:
def __init__(self) -> None:
pass
- def serialize(
- self,
- df: pd.DataFrame,
- extras: dict = None,
- type_: DataframeSerializerType = DataframeSerializerType.YML,
- ) -> str:
- if type_ == DataframeSerializerType.YML:
- return self.convert_df_to_yml(df, extras)
- elif type_ == DataframeSerializerType.JSON:
- return self.convert_df_to_json_str(df, extras)
- elif type_ == DataframeSerializerType.SQL:
- return self.convert_df_sql_connector_to_str(df, extras)
- else:
- return self.convert_df_to_csv(df, extras)
-
- def convert_df_to_csv(self, df: pd.DataFrame, extras: dict) -> str:
+ def serialize(self, df: pd.DataFrame) -> str:
"""
        Convert df to csv like format where csv is wrapped inside <table></table>
Args:
df (pd.DataFrame): PandaAI dataframe or dataframe
- extras (dict, optional): expect index to exists
Returns:
str: dataframe stringify
"""
- dataframe_info = ""
+ dataframe_info += f' dimensions="{df.rows_count}x{df.columns_count}">'
# Add dataframe details
- dataframe_info += f"\ndfs[{extras['index']}]:{df.rows_count}x{df.columns_count}\n{df.head().to_csv(index=False)}"
+ dataframe_info += f"\n{df.head().to_csv(index=False)}"
# Close the dataframe tag
- dataframe_info += "\n"
+ dataframe_info += "\n"
return dataframe_info
-
- def convert_df_sql_connector_to_str(
- self, df: pd.DataFrame, extras: dict = None
- ) -> str:
- """
- Convert df to csv like format where csv is wrapped inside
- Args:
- df (pd.DataFrame): PandaAI dataframe or dataframe
- extras (dict, optional): expect index to exists
-
- Returns:
- str: dataframe stringify
- """
- table_description_tag = (
- f' description="{df.description}"' if df.description is not None else ""
- )
- table_head_tag = f'<table name="{df.name}"{table_description_tag}>'
- return f"{table_head_tag}\n{df.get_head().to_csv()}\n
"
-
- def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
- """
- Convert df to json dictionary and return json
- Args:
- df (pd.DataFrame): PandaAI dataframe or dataframe
- extras (dict, optional): expect index to exists
-
- Returns:
- str: dataframe json
- """
-
- # Create a dictionary representing the data structure
- df_info = {
- "name": df.name,
- "description": None,
- "type": df.type,
- }
- # Add DataFrame details to the result
- data = {
- "rows": df.rows_count,
- "columns": df.columns_count,
- "schema": {"fields": []},
- }
-
- # Iterate over DataFrame columns
- df_head = df.get_head()
- for col_name, col_dtype in df_head.dtypes.items():
- col_info = {
- "name": col_name,
- "type": str(col_dtype),
- }
-
- data["schema"]["fields"].append(col_info)
-
- result = df_info | data
-
- return result
-
- def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str:
- """
- Convert df to json and return it as string
- Args:
- df (pd.DataFrame): PandaAI dataframe or dataframe
- extras (dict, optional): expect index to exists
-
- Returns:
- str: dataframe stringify
- """
- return json.dumps(self.convert_df_to_json(df, extras))
-
- def convert_df_to_yml(self, df: pd.DataFrame, extras: dict) -> str:
- json_df = self.convert_df_to_json(df, extras)
-
- import yaml
-
- yml_str = yaml.dump(json_df, sort_keys=False, allow_unicode=True)
- return f"\n"
diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py
new file mode 100644
index 000000000..82b4306eb
--- /dev/null
+++ b/pandasai/helpers/sql_sanitizer.py
@@ -0,0 +1,16 @@
+import os
+import re
+
+
+def sanitize_sql_table_name(filepath: str) -> str:
+ # Extract the file name without extension
+ file_name = os.path.splitext(os.path.basename(filepath))[0]
+
+ # Replace invalid characters with underscores
+ sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", file_name)
+
+ # Truncate to a reasonable length (e.g., 64 characters)
+ max_length = 64
+ sanitized_name = sanitized_name[:max_length]
+
+ return sanitized_name
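One property worth noting: because every disallowed character maps to an underscore, distinct file names can collide, and long names are silently truncated. A quick illustration, mirroring the tests added below:

```python
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

print(sanitize_sql_table_name("/data/sales report!.csv"))  # sales_report_
print(sanitize_sql_table_name("/data/sales_report_.csv"))  # sales_report_  (collision)
print(len(sanitize_sql_table_name("/data/" + "x" * 100 + ".csv")))  # 64 (truncated)
```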
diff --git a/tests/unit_tests/agent/test_agent.py b/tests/unit_tests/agent/test_agent.py
index c98fa35db..48d0dcae2 100644
--- a/tests/unit_tests/agent/test_agent.py
+++ b/tests/unit_tests/agent/test_agent.py
@@ -9,7 +9,6 @@
from pandasai.config import Config, ConfigManager
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
-from pandasai.helpers.dataframe_serializer import DataframeSerializerType
from pandasai.llm.fake import FakeLLM
@@ -38,7 +37,7 @@ def llm(self, output: Optional[str] = None) -> FakeLLM:
@pytest.fixture
def config(self, llm: FakeLLM) -> dict:
- return {"llm": llm, "dataframe_serializer": DataframeSerializerType.CSV}
+ return {"llm": llm}
@pytest.fixture
def agent(self, sample_df: pd.DataFrame, config: dict) -> Agent:
diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py
index 78cabe165..ca31dba08 100644
--- a/tests/unit_tests/helpers/test_dataframe_serializer.py
+++ b/tests/unit_tests/helpers/test_dataframe_serializer.py
@@ -1,39 +1,31 @@
-import unittest
+import pytest
-from pandasai.dataframe.base import DataFrame
-from pandasai.helpers.dataframe_serializer import (
- DataframeSerializer,
- DataframeSerializerType,
-)
+from pandasai import DataFrame
+from pandasai.helpers.dataframe_serializer import DataframeSerializer
-class TestDataframeSerializer(unittest.TestCase):
- def setUp(self):
- self.serializer = DataframeSerializer()
+class TestDataframeSerializer:
+ @pytest.fixture
+ def sample_df(self):
+ df = DataFrame({"Name": ["Alice", "Bob"], "Age": [25, 30]})
+ df.name = "test_table"
+ df.description = "This is a test table"
+ return df
- def test_convert_df_to_yml(self):
- # Test convert df to yml
- data = {"name": ["en_name", "中文_名称"]}
- connector = DataFrame(data, name="en_table_name", description="中文_描述")
- result = self.serializer.serialize(
- connector,
- type_=DataframeSerializerType.YML,
- extras={"index": 0, "type": "pd.Dataframe"},
- )
+ @pytest.fixture
+ def sample_dataframe_serializer(self):
+ return DataframeSerializer()
- self.assertIn(
- """
-name: en_table_name
-description: null
-type: pd.DataFrame
-rows: 2
-columns: 1
-schema:
- fields:
- - name: name
- type: object
+ def test_serialize_with_name_and_description(
+ self, sample_dataframe_serializer, sample_df
+ ):
+ """Test serialization with name and description attributes."""
+ result = sample_dataframe_serializer.serialize(sample_df)
+ expected = """<table table_name="test_table" description="This is a test table" dimensions="2x2">
+Name,Age
+Alice,25
+Bob,30
-""",
- result,
- )
+"""
+ assert result.replace("\r\n", "\n") == expected.replace("\r\n", "\n")
diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py
new file mode 100644
index 000000000..5f4ab40fc
--- /dev/null
+++ b/tests/unit_tests/helpers/test_sql_sanitizer.py
@@ -0,0 +1,19 @@
+from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
+
+
+class TestSqlSanitizer:
+ def test_valid_filename(self):
+ filepath = "/path/to/valid_table.csv"
+ expected = "valid_table"
+ assert sanitize_sql_table_name(filepath) == expected
+
+ def test_filename_with_special_characters(self):
+ filepath = "/path/to/invalid!@#.csv"
+ expected = "invalid___"
+ assert sanitize_sql_table_name(filepath) == expected
+
+ def test_filename_with_long_name(self):
+ """Test with a filename exceeding the length limit."""
+ filepath = "/path/to/" + "a" * 100 + ".csv"
+ expected = "a" * 64
+ assert sanitize_sql_table_name(filepath) == expected
diff --git a/tests/unit_tests/prompts/test_sql_prompt.py b/tests/unit_tests/prompts/test_sql_prompt.py
index 48bb58335..333b575db 100644
--- a/tests/unit_tests/prompts/test_sql_prompt.py
+++ b/tests/unit_tests/prompts/test_sql_prompt.py
@@ -51,7 +51,7 @@ def test_str_with_args(self, output_type, output_type_template):
llm = FakeLLM()
agent = Agent(
- pai.DataFrame(),
+ pai.DataFrame(name="test"),
config={"llm": llm},
)
prompt = GeneratePythonCodeWithSQLPrompt(
@@ -68,10 +68,9 @@ def test_str_with_args(self, output_type, output_type_template):
prompt_content
== f'''
-<dataframe>
-dfs[0]:0x0
+<table table_name="test" dimensions="0x0">