From b9782e9b384ec7e8cab611b17d72c7a6626f2392 Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Mon, 3 Feb 2025 09:14:46 +0100 Subject: [PATCH 1/2] tests: add tests for config, smart dataframe and smart datalake --- pandasai/config.py | 1 - .../smart_dataframe/test_smart_dataframe.py | 122 ++++++++++++++++ .../smart_datalake/test_smart_datalake.py | 45 ++++++ tests/unit_tests/test_api_key_manager.py | 42 ++++++ tests/unit_tests/test_config.py | 132 ++++++++++++++++++ 5 files changed, 341 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/smart_dataframe/test_smart_dataframe.py create mode 100644 tests/unit_tests/smart_datalake/test_smart_datalake.py create mode 100644 tests/unit_tests/test_api_key_manager.py create mode 100644 tests/unit_tests/test_config.py diff --git a/pandasai/config.py b/pandasai/config.py index 22b420fe5..926c08d4c 100644 --- a/pandasai/config.py +++ b/pandasai/config.py @@ -1,5 +1,4 @@ import os -from abc import ABC, abstractmethod from importlib.util import find_spec from typing import Any, Dict, Optional diff --git a/tests/unit_tests/smart_dataframe/test_smart_dataframe.py b/tests/unit_tests/smart_dataframe/test_smart_dataframe.py new file mode 100644 index 000000000..bdd2b750f --- /dev/null +++ b/tests/unit_tests/smart_dataframe/test_smart_dataframe.py @@ -0,0 +1,122 @@ +import warnings + +import pandas as pd +import pytest + +from pandasai.config import Config +from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes + + +def test_smart_dataframe_init_basic(): + # Create a sample dataframe + df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + + # Test initialization with minimal parameters + with pytest.warns(DeprecationWarning): + smart_df = SmartDataframe(df) + + assert smart_df._original_import is df + assert isinstance(smart_df.dataframe, pd.DataFrame) + assert smart_df._table_name is None + assert smart_df._table_description is None + assert smart_df._custom_head is None + + +def test_smart_dataframe_init_with_all_params(): + # Create sample dataframes + df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + custom_head = pd.DataFrame({"A": [1], "B": ["x"]}) + config = Config() + + # Test initialization with all parameters + with pytest.warns(DeprecationWarning): + smart_df = SmartDataframe( + df, + name="test_df", + description="Test dataframe", + custom_head=custom_head, + config=config, + ) + + assert smart_df._original_import is df + assert isinstance(smart_df.dataframe, pd.DataFrame) + assert smart_df._table_name == "test_df" + assert smart_df._table_description == "Test dataframe" + assert smart_df._custom_head == custom_head.to_csv(index=False) + assert smart_df._agent._state._config == config + + +def test_smart_dataframe_deprecation_warning(): + df = pd.DataFrame({"A": [1, 2, 3]}) + + with warnings.catch_warnings(record=True) as warning_info: + warnings.simplefilter("always") + SmartDataframe(df) + + deprecation_warnings = [ + w for w in warning_info if issubclass(w.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 1 + assert "SmartDataframe will soon be deprecated" in str( + deprecation_warnings[0].message + ) + + +def test_load_df_success(): + # Create sample dataframes + original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + with pytest.warns(DeprecationWarning): + smart_df = SmartDataframe(original_df) + + # Test loading a new dataframe + new_df = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) + loaded_df = smart_df.load_df( + new_df, + name="new_df", + description="New test dataframe", + custom_head=pd.DataFrame({"C": [4], "D": ["a"]}), + ) + + assert isinstance(loaded_df, pd.DataFrame) + assert loaded_df.equals(new_df) + + +def test_load_df_invalid_input(): + # Create a sample dataframe + original_df = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + with pytest.warns(DeprecationWarning): + smart_df = SmartDataframe(original_df) + + # Test loading invalid data + with pytest.raises( + ValueError, match="Invalid input data. We cannot convert it to a dataframe." + ): + smart_df.load_df( + "not a dataframe", + name="invalid_df", + description="Invalid test data", + custom_head=None, + ) + + +def test_load_smartdataframes(): + # Create sample dataframes + df1 = pd.DataFrame({"A": [1, 2, 3], "B": ["x", "y", "z"]}) + df2 = pd.DataFrame({"C": [4, 5, 6], "D": ["a", "b", "c"]}) + + # Create a config + config = Config() + + # Test loading regular pandas DataFrames + smart_dfs = load_smartdataframes([df1, df2], config) + assert len(smart_dfs) == 2 + assert all(isinstance(df, SmartDataframe) for df in smart_dfs) + assert all(hasattr(df, "config") for df in smart_dfs) + + # Test loading mixed pandas DataFrames and SmartDataframes + existing_smart_df = SmartDataframe(df1, config=config) + mixed_dfs = load_smartdataframes([existing_smart_df, df2], config) + assert len(mixed_dfs) == 2 + assert mixed_dfs[0] is existing_smart_df # Should return the same instance + assert isinstance(mixed_dfs[1], SmartDataframe) + assert hasattr(mixed_dfs[1], "config") diff --git a/tests/unit_tests/smart_datalake/test_smart_datalake.py b/tests/unit_tests/smart_datalake/test_smart_datalake.py new file mode 100644 index 000000000..b4e16a26a --- /dev/null +++ b/tests/unit_tests/smart_datalake/test_smart_datalake.py @@ -0,0 +1,45 @@ +from unittest.mock import Mock + +import pandas as pd +import pytest + +from pandasai.config import Config +from pandasai.smart_datalake import SmartDatalake + + +@pytest.fixture +def sample_dataframes(): + df1 = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + df2 = pd.DataFrame({"C": [7, 8, 9], "D": [10, 11, 12]}) + return [df1, df2] + + +def test_dfs_property(sample_dataframes): + # Create a mock agent with context + mock_agent = Mock() + mock_agent.context.dfs = sample_dataframes + + # Create SmartDatalake instance + smart_datalake = SmartDatalake(sample_dataframes) + smart_datalake._agent = mock_agent # Inject mock agent + + # Test that dfs property returns the correct dataframes + assert smart_datalake.dfs == sample_dataframes + + +def test_enable_cache(sample_dataframes): + # Create a mock agent with context and config + mock_config = Config(enable_cache=True) + mock_agent = Mock() + mock_agent.context.config = mock_config + + # Create SmartDatalake instance + smart_datalake = SmartDatalake(sample_dataframes) + smart_datalake._agent = mock_agent # Inject mock agent + + # Test that enable_cache property returns the correct value + assert smart_datalake.enable_cache is True + + # Test with cache disabled + mock_config.enable_cache = False + assert smart_datalake.enable_cache is False diff --git a/tests/unit_tests/test_api_key_manager.py b/tests/unit_tests/test_api_key_manager.py new file mode 100644 index 000000000..8508cd112 --- /dev/null +++ b/tests/unit_tests/test_api_key_manager.py @@ -0,0 +1,42 @@ +import os +from unittest.mock import patch + +import pytest + +from pandasai.config import APIKeyManager + + +def test_set_api_key(): + # Setup + test_api_key = "test-api-key-123" + + # Execute + with patch.dict(os.environ, {}, clear=True): + APIKeyManager.set(test_api_key) + + # Assert + assert os.environ.get("PANDABI_API_KEY") == test_api_key + assert APIKeyManager._api_key == test_api_key + + +def test_get_api_key(): + # Setup + test_api_key = "test-api-key-123" + APIKeyManager._api_key = test_api_key + + # Execute + result = APIKeyManager.get() + + # Assert + assert result == test_api_key + + +def test_get_api_key_when_none(): + # Setup + APIKeyManager._api_key = None + + # Execute + result = APIKeyManager.get() + + # Assert + assert result is None diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py new file mode 100644 index 000000000..ca0b47b05 --- /dev/null +++ b/tests/unit_tests/test_config.py @@ -0,0 +1,132 @@ +import os +from unittest.mock import MagicMock, patch + +from pandasai.config import Config, ConfigManager +from pandasai.helpers.filemanager import DefaultFileManager +from pandasai.llm.bamboo_llm import BambooLLM + + +def test_validate_llm_with_pandabi_api_key(): + # Setup + ConfigManager._config = MagicMock() + ConfigManager._config.llm = None + + with patch.dict(os.environ, {"PANDABI_API_KEY": "test-key"}): + # Execute + ConfigManager.validate_llm() + + # Assert + assert isinstance(ConfigManager._config.llm, BambooLLM) + + +def test_validate_llm_with_langchain(): + # Setup + ConfigManager._config = MagicMock() + mock_llm = MagicMock() + ConfigManager._config.llm = mock_llm + mock_langchain_llm = MagicMock() + + # Create mock module + mock_langchain_module = MagicMock() + mock_langchain_module.__spec__ = MagicMock() + mock_langchain_module.langchain = MagicMock( + LangchainLLM=mock_langchain_llm, is_langchain_llm=lambda x: True + ) + + with patch.dict( + "sys.modules", + { + "pandasai_langchain": mock_langchain_module, + "pandasai_langchain.langchain": mock_langchain_module.langchain, + }, + ): + # Execute + ConfigManager.validate_llm() + + # Assert + assert mock_langchain_llm.call_count == 1 + + +def test_validate_llm_no_action_needed(): + # Setup + ConfigManager._config = MagicMock() + mock_llm = MagicMock() + ConfigManager._config.llm = mock_llm + + # Case where no PANDABI_API_KEY and not a langchain LLM + with patch.dict(os.environ, {}, clear=True): + with patch("importlib.util.find_spec") as mock_find_spec: + mock_find_spec.return_value = None + + # Execute + ConfigManager.validate_llm() + + # Assert - llm should remain unchanged + assert ConfigManager._config.llm == mock_llm + + +def test_config_update(): + # Setup + mock_config = MagicMock() + initial_config = {"key1": "value1", "key2": "value2"} + mock_config.model_dump = MagicMock(return_value=initial_config.copy()) + ConfigManager._config = mock_config + + # Create a mock for Config.from_dict + original_from_dict = Config.from_dict + Config.from_dict = MagicMock() + + try: + # Execute + new_config = {"key2": "new_value2", "key3": "value3"} + ConfigManager.update(new_config) + + # Assert + expected_config = {"key1": "value1", "key2": "new_value2", "key3": "value3"} + assert mock_config.model_dump.call_count == 1 + Config.from_dict.assert_called_once_with(expected_config) + finally: + # Restore original from_dict method + Config.from_dict = original_from_dict + + +def test_config_set(): + # Setup + test_config = {"key": "value"} + with patch.object(Config, "from_dict") as mock_from_dict, patch.object( + ConfigManager, "validate_llm" + ) as mock_validate_llm: + # Execute + ConfigManager.set(test_config) + + # Assert + mock_from_dict.assert_called_once_with(test_config) + mock_validate_llm.assert_called_once() + + +def test_config_from_dict(): + # Test with default overrides + config_dict = { + "save_logs": False, + "verbose": True, + "enable_cache": False, + "max_retries": 5, + } + + config = Config.from_dict(config_dict) + + assert isinstance(config, Config) + assert config.save_logs == False + assert config.verbose == True + assert config.enable_cache == False + assert config.max_retries == 5 + assert config.llm is None + assert isinstance(config.file_manager, DefaultFileManager) + + # Test with minimal dict + minimal_config = Config.from_dict({}) + assert isinstance(minimal_config, Config) + assert minimal_config.save_logs == True # default value + assert minimal_config.verbose == False # default value + assert minimal_config.enable_cache == True # default value + assert minimal_config.max_retries == 3 # default value From d6446b6c7658caa1b7723c1bf81134e71ffada24 Mon Sep 17 00:00:00 2001 From: ArslanSaleem Date: Tue, 4 Feb 2025 17:52:33 +0100 Subject: [PATCH 2/2] fix(dataset): update exception message --- pandasai/__init__.py | 21 +++++++++++++--- .../smart_datalake/test_smart_datalake.py | 3 --- tests/unit_tests/test_pandasai_init.py | 25 ++++++++++++++++--- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index a65ce95e0..ca978a737 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -212,11 +212,17 @@ def load(dataset_path: str) -> DataFrame: raise ValueError("The path must be in the format 'organization/dataset'.") dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path) - if not os.path.exists(dataset_full_path): + + local_dataset_exists = os.path.exists(dataset_full_path) + + if not local_dataset_exists: api_key = os.environ.get("PANDABI_API_KEY", None) api_url = os.environ.get("PANDABI_API_URL", DEFAULT_API_URL) + if not api_url or not api_key: - raise PandaAIApiKeyError() + raise PandaAIApiKeyError( + f'The dataset "{dataset_path}" does not exist in your local datasets directory. In addition, no API Key has been provided. Set an API key with valid permits if you want to fetch the dataset from the remote server.' + ) request_session = get_pandaai_session() @@ -232,7 +238,16 @@ def load(dataset_path: str) -> DataFrame: zip_file.extractall(dataset_full_path) loader = DatasetLoader.create_loader_from_path(dataset_path) - return loader.load() + df = loader.load() + + message = ( + "Dataset loaded successfully." + if local_dataset_exists + else "Dataset fetched successfully from the remote server." + ) + print(message) + + return df def read_csv(filepath: str) -> DataFrame: diff --git a/tests/unit_tests/smart_datalake/test_smart_datalake.py b/tests/unit_tests/smart_datalake/test_smart_datalake.py index 20e78626c..b9c22cb1a 100644 --- a/tests/unit_tests/smart_datalake/test_smart_datalake.py +++ b/tests/unit_tests/smart_datalake/test_smart_datalake.py @@ -44,8 +44,6 @@ def test_enable_cache(sample_dataframes): # Test with cache disabled mock_config.enable_cache = False assert smart_datalake.enable_cache is False -<<<<<<< HEAD -======= def test_enable_cache_setter(sample_dataframes): @@ -71,4 +69,3 @@ def test_enable_cache_setter(sample_dataframes): smart_datalake.enable_cache = False assert mock_agent.context.config.enable_cache is False assert smart_datalake._cache is None ->>>>>>> 98ea589882e8fa26e8972c28bfac7b282a675b75 diff --git a/tests/unit_tests/test_pandasai_init.py b/tests/unit_tests/test_pandasai_init.py index 3f1cce487..b8e783cbc 100644 --- a/tests/unit_tests/test_pandasai_init.py +++ b/tests/unit_tests/test_pandasai_init.py @@ -137,6 +137,25 @@ def test_load_dataset_not_found(self, mockenviron, mock_bytes_io, mock_zip_file) with pytest.raises(DatasetNotFound): pandasai.load(dataset_path) + @patch("pandasai.os.path.exists") + @patch("pandasai.os.environ", {}) + @patch("pandasai.get_pandaai_session") + def test_load_missing_not_found_locally_and_no_remote_key( + self, mock_session, mock_exists + ): + """Test loading when API URL is missing.""" + mock_exists.return_value = False + mock_response = MagicMock() + mock_response.status_code = 404 + mock_session.return_value.get.return_value = mock_response + dataset_path = "org/dataset_name" + + with pytest.raises( + PandaAIApiKeyError, + match='The dataset "org/dataset_name" does not exist in your local datasets directory. In addition, no API Key has been provided. Set an API key with valid permits if you want to fetch the dataset from the remote server.', + ): + pandasai.load(dataset_path) + @patch("pandasai.os.path.exists") @patch("pandasai.os.environ", {"PANDABI_API_KEY": "key"}) def test_load_missing_api_url(self, mock_exists): @@ -144,13 +163,13 @@ def test_load_missing_api_url(self, mock_exists): mock_exists.return_value = False dataset_path = "org/dataset_name" - with pytest.raises(PandaAIApiKeyError): + with pytest.raises(DatasetNotFound): pandasai.load(dataset_path) @patch("pandasai.os.path.exists") @patch("pandasai.os.environ", {"PANDABI_API_KEY": "key"}) @patch("pandasai.get_pandaai_session") - def test_load_missing_api_url(self, mock_session, mock_exists): + def test_load_missing_not_found(self, mock_session, mock_exists): """Test loading when API URL is missing.""" mock_exists.return_value = False mock_response = MagicMock() @@ -202,7 +221,7 @@ def test_load_without_api_credentials( pandasai.load("test/dataset") assert ( str(exc_info.value) - == "PandaAI API key not found. Please set your API key using PandaAI.set_api_key() or by setting the PANDASAI_API_KEY environment variable." + == 'The dataset "test/dataset" does not exist in your local datasets directory. In addition, no API Key has been provided. Set an API key with valid permits if you want to fetch the dataset from the remote server.' ) def test_clear_cache(self):