diff --git a/extensions/connectors/sql/pyproject.toml b/extensions/connectors/sql/pyproject.toml index 2e103c732..455cb5f59 100644 --- a/extensions/connectors/sql/pyproject.toml +++ b/extensions/connectors/sql/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai-sql" -version = "0.1.5" +version = "0.1.6" description = "SQL integration for PandaAI" authors = ["Gabriele Venturi"] license = "MIT" diff --git a/pandasai/data_loader/sql_loader.py b/pandasai/data_loader/sql_loader.py index 4b2a47189..a116f36b8 100644 --- a/pandasai/data_loader/sql_loader.py +++ b/pandasai/data_loader/sql_loader.py @@ -4,7 +4,8 @@ import pandas as pd from pandasai.dataframe.virtual_dataframe import VirtualDataFrame -from pandasai.exceptions import InvalidDataSourceType +from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError +from pandasai.helpers.sql_sanitizer import is_sql_query_safe from ..constants import ( SUPPORTED_SOURCE_CONNECTORS, @@ -36,6 +37,11 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra formatted_query = self.query_builder.format_query(query) load_function = self._get_loader_function(source_type) + + if not is_sql_query_safe(formatted_query): + raise MaliciousQueryError( + "The SQL query is deemed unsafe and will not be executed." + ) try: dataframe: pd.DataFrame = load_function( connection_info, formatted_query, params diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py index 82b4306eb..fb908d063 100644 --- a/pandasai/helpers/sql_sanitizer.py +++ b/pandasai/helpers/sql_sanitizer.py @@ -1,6 +1,8 @@ import os import re +import sqlglot + def sanitize_sql_table_name(filepath: str) -> str: # Extract the file name without extension @@ -14,3 +16,71 @@ def sanitize_sql_table_name(filepath: str) -> str: sanitized_name = sanitized_name[:max_length] return sanitized_name + + +def is_sql_query_safe(query: str) -> bool: + try: + # List of infected keywords to block (you can add more) + infected_keywords = [ + r"\bINSERT\b", + r"\bUPDATE\b", + r"\bDELETE\b", + r"\bDROP\b", + r"\bEXEC\b", + r"\bALTER\b", + r"\bCREATE\b", + r"\bMERGE\b", + r"\bREPLACE\b", + r"\bTRUNCATE\b", + r"\bLOAD\b", + r"\bGRANT\b", + r"\bREVOKE\b", + r"\bCALL\b", + r"\bEXECUTE\b", + r"\bSHOW\b", + r"\bDESCRIBE\b", + r"\bEXPLAIN\b", + r"\bUSE\b", + r"\bSET\b", + r"\bDECLARE\b", + r"\bOPEN\b", + r"\bFETCH\b", + r"\bCLOSE\b", + r"\bSLEEP\b", + r"\bBENCHMARK\b", + r"\bDATABASE\b", + r"\bUSER\b", + r"\bCURRENT_USER\b", + r"\bSESSION_USER\b", + r"\bSYSTEM_USER\b", + r"\bVERSION\b", + r"\b@@VERSION\b", + r"--", + r"/\*.*\*/", # Block comments and inline comments + ] + # Parse the query to extract its structure + parsed = sqlglot.parse_one(query) + + # Ensure the main query is SELECT + if parsed.key.upper() != "SELECT": + return False + + # Check for infected keywords in the main query + if any( + re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords + ): + return False + + # Check for infected keywords in subqueries + for subquery in parsed.find_all(sqlglot.exp.Subquery): + subquery_sql = subquery.sql() # Get the SQL of the subquery + if any( + re.search(keyword, subquery_sql, re.IGNORECASE) + for keyword in infected_keywords + ): + return False + + return True + + except sqlglot.errors.ParseError: + return False diff --git a/poetry.lock b/poetry.lock index f7ec5dc78..f4b30706e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1691,7 +1691,7 @@ version = "6.0.2" description = "YAML parser and emitter for Python" optional = false python-versions = ">=3.8" -groups = ["dev"] +groups = ["main", "dev"] files = [ {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, @@ -2063,4 +2063,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.8,<3.12" -content-hash = "d5c0c6c3d5e5d0317ce3d2785d285cee3dc837555086f84ec7ccfcc0c1cc3a33" +content-hash = "9fd5b43cddb627731e144406299386cec5b121e558207a6d4057c68e04f748a5" diff --git a/pyproject.toml b/pyproject.toml index 1502f3560..de23986ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pandasai" -version = "3.0.0-beta.6" +version = "3.0.0-beta.7" description = "Chat with your database (SQL, CSV, pandas, mongodb, noSQL, etc). PandaAI makes data analysis conversational using LLMs (GPT 3.5 / 4, Anthropic, VertexAI) and RAG." authors = ["Gabriele Venturi"] license = "MIT" @@ -23,6 +23,7 @@ numpy = "^1.17" seaborn = "^0.12.2" sqlglot = "^25.0.3" pyarrow = "^14.0.1" +pyyaml = "^6.0.2" [tool.poetry.group.dev] optional = true diff --git a/tests/unit_tests/data_loader/test_loader.py b/tests/unit_tests/data_loader/test_loader.py index 7ff5f4b89..db0521245 100644 --- a/tests/unit_tests/data_loader/test_loader.py +++ b/tests/unit_tests/data_loader/test_loader.py @@ -1,4 +1,3 @@ -import logging from unittest.mock import mock_open, patch import pandas as pd diff --git a/tests/unit_tests/data_loader/test_sql_loader.py b/tests/unit_tests/data_loader/test_sql_loader.py index f433eddca..776163bed 100644 --- a/tests/unit_tests/data_loader/test_sql_loader.py +++ b/tests/unit_tests/data_loader/test_sql_loader.py @@ -1,4 +1,5 @@ import logging + from unittest.mock import MagicMock, mock_open, patch import pandas as pd @@ -10,7 +11,7 @@ from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema from pandasai.data_loader.sql_loader import SQLDatasetLoader from pandasai.dataframe.base import DataFrame -from pandasai.exceptions import InvalidDataSourceType +from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError class TestSqlDatasetLoader: @@ -138,3 +139,63 @@ def test_load_with_transformation(self, mysql_schema): loader_function.call_args[0][1] == "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5" ) + + + def test_mysql_malicious_query(self, mysql_schema): + """Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly.""" + with patch( + "pandasai.data_loader.sql_loader.is_sql_query_safe" + ) as mock_sql_query, patch( + "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function" + ) as mock_loader_function: + mocked_exec_function = MagicMock() + mock_df = DataFrame( + pd.DataFrame( + { + "email": ["test@example.com"], + "first_name": ["John"], + "timestamp": [pd.Timestamp.now()], + } + ) + ) + mocked_exec_function.return_value = mock_df + mock_loader_function.return_value = mocked_exec_function + loader = SQLDatasetLoader(mysql_schema, "test/users") + mock_sql_query.return_value = False + logging.debug("Loading schema from dataset path: %s", loader) + + with pytest.raises(MaliciousQueryError): + loader.execute_query("DROP TABLE users") + + mock_sql_query.assert_called_once_with("DROP TABLE users") + + def test_mysql_safe_query(self, mysql_schema): + """Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly.""" + with patch( + "pandasai.data_loader.sql_loader.is_sql_query_safe" + ) as mock_sql_query, patch( + "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function" + ) as mock_loader_function, patch( + "pandasai.data_loader.sql_loader.SQLDatasetLoader._apply_transformations" + ) as mock_apply_transformations: + mocked_exec_function = MagicMock() + mock_df = DataFrame( + pd.DataFrame( + { + "email": ["test@example.com"], + "first_name": ["John"], + "timestamp": [pd.Timestamp.now()], + } + ) + ) + mocked_exec_function.return_value = mock_df + mock_apply_transformations.return_value = mock_df + mock_loader_function.return_value = mocked_exec_function + loader = SQLDatasetLoader(mysql_schema, "test/users") + mock_sql_query.return_value = True + logging.debug("Loading schema from dataset path: %s", loader) + + result = loader.execute_query("select * from users") + + assert isinstance(result, DataFrame) + mock_sql_query.assert_called_once_with("select * from users") diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py index 5f4ab40fc..a572cc5f4 100644 --- a/tests/unit_tests/helpers/test_sql_sanitizer.py +++ b/tests/unit_tests/helpers/test_sql_sanitizer.py @@ -1,4 +1,4 @@ -from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name +from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name class TestSqlSanitizer: @@ -17,3 +17,71 @@ def test_filename_with_long_name(self): filepath = "/path/to/" + "a" * 100 + ".csv" expected = "a" * 64 assert sanitize_sql_table_name(filepath) == expected + + def test_safe_select_query(self): + query = "SELECT * FROM users WHERE username = 'admin';" + assert is_sql_query_safe(query) + + def test_safe_with_query(self): + query = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;" + assert is_sql_query_safe(query) + + def test_unsafe_insert_query(self): + query = "INSERT INTO users (username, password) VALUES ('admin', 'password');" + assert not is_sql_query_safe(query) + + def test_unsafe_update_query(self): + query = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';" + assert not is_sql_query_safe(query) + + def test_unsafe_delete_query(self): + query = "DELETE FROM users WHERE username = 'admin';" + assert not is_sql_query_safe(query) + + def test_unsafe_drop_query(self): + query = "DROP TABLE users;" + assert not is_sql_query_safe(query) + + def test_unsafe_alter_query(self): + query = "ALTER TABLE users ADD COLUMN age INT;" + assert not is_sql_query_safe(query) + + def test_unsafe_create_query(self): + query = "CREATE TABLE users (id INT, username VARCHAR(50));" + assert not is_sql_query_safe(query) + + def test_safe_select_with_comment(self): + query = "SELECT * FROM users WHERE username = 'admin' -- comment" + assert not is_sql_query_safe(query) # Blocked by comment detection + + def test_safe_select_with_inline_comment(self): + query = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';" + assert not is_sql_query_safe(query) # Blocked by comment detection + + def test_unsafe_query_with_subquery(self): + query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);" + assert is_sql_query_safe(query) # No dangerous keyword in main or subquery + + def test_unsafe_query_with_subquery_insert(self): + query = ( + "SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));" + ) + assert not is_sql_query_safe(query) # Subquery contains INSERT, blocked + + def test_invalid_sql(self): + query = "INVALID SQL QUERY" + assert not is_sql_query_safe(query) # Invalid query should return False + + def test_safe_query_with_multiple_keywords(self): + query = "SELECT name FROM users WHERE username = 'admin' AND age > 30;" + assert is_sql_query_safe(query) # Safe query with no dangerous keyword + + def test_safe_query_with_subquery(self): + query = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);" + assert is_sql_query_safe( + query + ) # Safe query with subquery, no dangerous keyword + + +if __name__ == "__main__": + unittest.main()