Skip to content

Commit

Permalink
Merge branch 'main' into fix/SIN-340
Browse files Browse the repository at this point in the history
  • Loading branch information
scaliseraoul authored Jan 31, 2025
2 parents bcc04a9 + 4ca228f commit 124d4c3
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 8 deletions.
2 changes: 1 addition & 1 deletion extensions/connectors/sql/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pandasai-sql"
version = "0.1.5"
version = "0.1.6"
description = "SQL integration for PandaAI"
authors = ["Gabriele Venturi"]
license = "MIT"
Expand Down
8 changes: 7 additions & 1 deletion pandasai/data_loader/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pandas as pd

from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
from pandasai.helpers.sql_sanitizer import is_sql_query_safe

from ..constants import (
SUPPORTED_SOURCE_CONNECTORS,
Expand Down Expand Up @@ -36,6 +37,11 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra

formatted_query = self.query_builder.format_query(query)
load_function = self._get_loader_function(source_type)

if not is_sql_query_safe(formatted_query):
raise MaliciousQueryError(
"The SQL query is deemed unsafe and will not be executed."
)
try:
dataframe: pd.DataFrame = load_function(
connection_info, formatted_query, params
Expand Down
70 changes: 70 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import re

import sqlglot


def sanitize_sql_table_name(filepath: str) -> str:
# Extract the file name without extension
Expand All @@ -14,3 +16,71 @@ def sanitize_sql_table_name(filepath: str) -> str:
sanitized_name = sanitized_name[:max_length]

return sanitized_name


# Textual deny-list applied to the raw query (and to every rendered subquery).
# Compiled once at import time as a single case-insensitive alternation.
_UNSAFE_SQL_PATTERN = re.compile(
    "|".join(
        [
            # Statements / keywords that write, change state or leak metadata.
            r"\b(?:INSERT|UPDATE|DELETE|DROP|EXEC|ALTER|CREATE|MERGE|REPLACE"
            r"|TRUNCATE|LOAD|GRANT|REVOKE|CALL|EXECUTE|SHOW|DESCRIBE|EXPLAIN"
            r"|USE|SET|DECLARE|OPEN|FETCH|CLOSE|SLEEP|BENCHMARK|DATABASE|USER"
            r"|CURRENT_USER|SESSION_USER|SYSTEM_USER|VERSION)\b",
            # No leading \b here: there is no word boundary in front of "@",
            # so r"\b@@VERSION\b" could never match after whitespace.
            r"@@VERSION\b",
            # Line comments.
            r"--",
            # Any block-comment opener. Matching only the opener also rejects
            # comments that span several lines or are never terminated.
            r"/\*",
        ]
    ),
    re.IGNORECASE,
)


def is_sql_query_safe(query: str) -> bool:
    """Return True only if *query* is a single read-only SELECT statement.

    A query is rejected (False) when any of the following holds:

    * its raw text contains a deny-listed keyword or a SQL comment marker
      (``--`` or ``/*``); the match is purely textual, so a keyword inside a
      string literal is rejected as well;
    * sqlglot cannot tokenize or parse it;
    * the parsed top-level expression is not a SELECT;
    * the SQL rendered for any subquery contains a deny-listed keyword.

    Args:
        query: The SQL text to validate.

    Returns:
        True if the query passed every check, False otherwise.
    """
    # Cheap textual screen first, so hostile input is rejected before parsing.
    if _UNSAFE_SQL_PATTERN.search(query):
        return False

    try:
        # Parse the query to inspect its structure.
        parsed = sqlglot.parse_one(query)

        # Ensure the main statement is a SELECT (a WITH ... SELECT qualifies).
        if parsed.key.upper() != "SELECT":
            return False

        # Re-check the SQL that sqlglot renders for each subquery.
        return not any(
            _UNSAFE_SQL_PATTERN.search(subquery.sql())
            for subquery in parsed.find_all(sqlglot.exp.Subquery)
        )
    except sqlglot.errors.SqlglotError:
        # Covers ParseError and TokenError (e.g. an unterminated string):
        # anything sqlglot cannot understand is treated as unsafe.
        return False
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pandasai"
version = "3.0.0-beta.6"
version = "3.0.0-beta.7"
description = "Chat with your database (SQL, CSV, pandas, mongodb, noSQL, etc). PandaAI makes data analysis conversational using LLMs (GPT 3.5 / 4, Anthropic, VertexAI) and RAG."
authors = ["Gabriele Venturi"]
license = "MIT"
Expand All @@ -23,6 +23,7 @@ numpy = "^1.17"
seaborn = "^0.12.2"
sqlglot = "^25.0.3"
pyarrow = "^14.0.1"
pyyaml = "^6.0.2"

[tool.poetry.group.dev]
optional = true
Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/data_loader/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from unittest.mock import mock_open, patch

import pandas as pd
Expand Down
63 changes: 62 additions & 1 deletion tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging

from unittest.mock import MagicMock, mock_open, patch

import pandas as pd
Expand All @@ -10,7 +11,7 @@
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.data_loader.sql_loader import SQLDatasetLoader
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError


class TestSqlDatasetLoader:
Expand Down Expand Up @@ -138,3 +139,63 @@ def test_load_with_transformation(self, mysql_schema):
loader_function.call_args[0][1]
== "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
)


def test_mysql_malicious_query(self, mysql_schema):
    """A query flagged as unsafe raises MaliciousQueryError and never reaches the loader."""
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function:
        # Stand-in for the connector's load function; it must stay untouched.
        mocked_exec_function = MagicMock()
        mock_loader_function.return_value = mocked_exec_function
        # Force the safety check to reject whatever query is passed in.
        mock_sql_query.return_value = False
        loader = SQLDatasetLoader(mysql_schema, "test/users")

        with pytest.raises(MaliciousQueryError):
            loader.execute_query("DROP TABLE users")

        mock_sql_query.assert_called_once_with("DROP TABLE users")
        # The rejection happens before the query is sent to the database.
        mocked_exec_function.assert_not_called()

def test_mysql_safe_query(self, mysql_schema):
    """A query accepted by the safety check is executed and its result returned as a DataFrame."""
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._apply_transformations"
    ) as mock_apply_transformations:
        # Canned result returned by both the fake loader and the fake
        # transformation step.
        mock_df = DataFrame(
            pd.DataFrame(
                {
                    "email": ["test@example.com"],
                    "first_name": ["John"],
                    "timestamp": [pd.Timestamp.now()],
                }
            )
        )
        mocked_exec_function = MagicMock(return_value=mock_df)
        mock_loader_function.return_value = mocked_exec_function
        mock_apply_transformations.return_value = mock_df
        # Force the safety check to accept whatever query is passed in.
        mock_sql_query.return_value = True
        loader = SQLDatasetLoader(mysql_schema, "test/users")
        logging.debug("Loading schema from dataset path: %s", loader)

        result = loader.execute_query("select * from users")

        assert isinstance(result, DataFrame)
        mock_sql_query.assert_called_once_with("select * from users")
        # An accepted query must actually be handed to the loader function.
        mocked_exec_function.assert_called_once()
70 changes: 69 additions & 1 deletion tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name


class TestSqlSanitizer:
Expand All @@ -17,3 +17,71 @@ def test_filename_with_long_name(self):
filepath = "/path/to/" + "a" * 100 + ".csv"
expected = "a" * 64
assert sanitize_sql_table_name(filepath) == expected

def test_safe_select_query(self):
    """A plain SELECT with a WHERE filter is accepted."""
    sql = "SELECT * FROM users WHERE username = 'admin';"
    is_safe = is_sql_query_safe(sql)
    assert is_safe

def test_safe_with_query(self):
    """A SELECT introduced by a WITH clause is accepted."""
    sql = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
    is_safe = is_sql_query_safe(sql)
    assert is_safe

def test_unsafe_insert_query(self):
    """An INSERT statement is rejected."""
    sql = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_unsafe_update_query(self):
    """An UPDATE statement is rejected."""
    sql = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_unsafe_delete_query(self):
    """A DELETE statement is rejected."""
    sql = "DELETE FROM users WHERE username = 'admin';"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_unsafe_drop_query(self):
    """A DROP TABLE statement is rejected."""
    sql = "DROP TABLE users;"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_unsafe_alter_query(self):
    """An ALTER TABLE statement is rejected."""
    sql = "ALTER TABLE users ADD COLUMN age INT;"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_unsafe_create_query(self):
    """A CREATE TABLE statement is rejected."""
    sql = "CREATE TABLE users (id INT, username VARCHAR(50));"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_safe_select_with_comment(self):
    """Despite the test name, a SELECT carrying a ``--`` comment is rejected."""
    sql = "SELECT * FROM users WHERE username = 'admin' -- comment"
    is_safe = is_sql_query_safe(sql)
    # Comment markers are on the deny-list, so the query must not pass.
    assert not is_safe

def test_safe_select_with_inline_comment(self):
    """Despite the test name, a SELECT carrying a ``/* */`` comment is rejected."""
    sql = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
    is_safe = is_sql_query_safe(sql)
    # Comment markers are on the deny-list, so the query must not pass.
    assert not is_safe

def test_unsafe_query_with_subquery(self):
    """Despite the test name, a SELECT with a harmless SELECT subquery is accepted."""
    sql = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
    is_safe = is_sql_query_safe(sql)
    # Neither the outer query nor the subquery contains a deny-listed keyword.
    assert is_safe

def test_unsafe_query_with_subquery_insert(self):
    """A SELECT whose parenthesised subquery is an INSERT is rejected."""
    sql = "SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
    is_safe = is_sql_query_safe(sql)
    # The INSERT keyword inside the subquery makes the whole query unsafe.
    assert not is_safe

def test_invalid_sql(self):
    """Text that is not a valid SELECT statement is reported as unsafe."""
    sql = "INVALID SQL QUERY"
    is_safe = is_sql_query_safe(sql)
    assert not is_safe

def test_safe_query_with_multiple_keywords(self):
    """A SELECT combining several WHERE conditions with AND is accepted."""
    sql = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
    is_safe = is_sql_query_safe(sql)
    # None of the words in the query are on the deny-list.
    assert is_safe

def test_safe_query_with_subquery(self):
    """A SELECT filtered by a SELECT subquery with its own WHERE clause is accepted."""
    sql = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
    is_safe = is_sql_query_safe(sql)
    # Neither the outer query nor the subquery contains a deny-listed keyword.
    assert is_safe


if __name__ == "__main__":
    # unittest is not imported at the top of this module, so import it here;
    # without this, running the file as a script fails with a NameError.
    # NOTE(review): the visible test class does not derive from
    # unittest.TestCase, so unittest.main() may collect nothing - these tests
    # are normally run through pytest.
    import unittest

    unittest.main()

0 comments on commit 124d4c3

Please sign in to comment.