From f59cb540d2f6b88981665acc922e5ec6b64efcc6 Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Mon, 3 Feb 2025 11:55:27 +0100 Subject: [PATCH] tests: add more tests for prompts and smart datalake --- tests/unit_tests/core/prompts/test_base.py | 84 +++++++++++++++++++ ...ct_execute_sql_query_usage_error_prompt.py | 51 +++++++++++ .../test_correct_output_type_error_prompt.py | 58 +++++++++++++ ...st_generate_python_code_with_sql_prompt.py | 48 +++++++++++ .../smart_datalake/test_smart_datalake.py | 26 ++++++ 5 files changed, 267 insertions(+) create mode 100644 tests/unit_tests/core/prompts/test_base.py create mode 100644 tests/unit_tests/core/prompts/test_correct_execute_sql_query_usage_error_prompt.py create mode 100644 tests/unit_tests/core/prompts/test_correct_output_type_error_prompt.py create mode 100644 tests/unit_tests/core/prompts/test_generate_python_code_with_sql_prompt.py diff --git a/tests/unit_tests/core/prompts/test_base.py b/tests/unit_tests/core/prompts/test_base.py new file mode 100644 index 000000000..a8a1ea08b --- /dev/null +++ b/tests/unit_tests/core/prompts/test_base.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock, patch + +import pytest +from jinja2 import Environment + +from pandasai.core.prompts.base import BasePrompt + + +class TestBasePrompt: + def test_to_json_without_context(self): + # Given a BasePrompt instance without context + class TestPrompt(BasePrompt): + template = "Test template {{ var }}" + + prompt = TestPrompt(var="value") + + # When calling to_json + result = prompt.to_json() + + # Then it should return a dict with only the prompt + assert isinstance(result, dict) + assert list(result.keys()) == ["prompt"] + assert result["prompt"] == "Test template value" + + def test_to_json_with_context(self): + # Given a BasePrompt instance with context + class TestPrompt(BasePrompt): + template = "Test template {{ var }}" + + memory = MagicMock() + memory.to_json.return_value = ["conversation1", "conversation2"] + memory.agent_description = "test agent" + + context = MagicMock() + context.memory = memory + + prompt = TestPrompt(var="value", context=context) + + # When calling to_json + result = prompt.to_json() + + # Then it should return a dict with conversation, system_prompt and prompt + assert isinstance(result, dict) + assert set(result.keys()) == {"conversation", "system_prompt", "prompt"} + assert result["conversation"] == ["conversation1", "conversation2"] + assert result["system_prompt"] == "test agent" + assert result["prompt"] == "Test template value" + + def test_render_with_variables(self): + # Given a BasePrompt instance with a template containing variables + class TestPrompt(BasePrompt): + template = "Hello {{ name }}!\nHow are you?\n\n\n\nGoodbye {{ name }}!" + + prompt = TestPrompt(name="World") + + # When calling render + result = prompt.render() + + # Then it should: + # 1. Replace variables correctly + # 2. Remove extra newlines (more than 2) + expected = "Hello World!\nHow are you?\n\nGoodbye World!" + assert result == expected + + def test_render_with_template_path(self): + # Given a BasePrompt instance with a template path + class TestPrompt(BasePrompt): + template_path = "test_template.txt" + + with patch.object(Environment, "get_template") as mock_get_template: + mock_template = MagicMock() + mock_template.render.return_value = "Hello\n\n\n\nWorld!" + mock_get_template.return_value = mock_template + + prompt = TestPrompt(name="Test") + + # When calling render + result = prompt.render() + + # Then it should: + # 1. Use the template from file + # 2. Remove extra newlines + assert result == "Hello\n\nWorld!" + mock_template.render.assert_called_once_with(name="Test") diff --git a/tests/unit_tests/core/prompts/test_correct_execute_sql_query_usage_error_prompt.py b/tests/unit_tests/core/prompts/test_correct_execute_sql_query_usage_error_prompt.py new file mode 100644 index 000000000..364726c18 --- /dev/null +++ b/tests/unit_tests/core/prompts/test_correct_execute_sql_query_usage_error_prompt.py @@ -0,0 +1,51 @@ +from unittest.mock import Mock, patch + +import pytest + +from pandasai.core.prompts.correct_execute_sql_query_usage_error_prompt import ( + CorrectExecuteSQLQueryUsageErrorPrompt, +) + + +def test_to_json(): + # Mock the dependencies + mock_dataset = Mock() + mock_dataset.to_json.return_value = {"mock_dataset": "data"} + + mock_memory = Mock() + mock_memory.to_json.return_value = {"mock_conversation": "data"} + mock_memory.agent_description = "Mock agent description" + + mock_context = Mock() + mock_context.memory = mock_memory + mock_context.dfs = [mock_dataset] + + # Create test data + test_code = "SELECT * FROM table" + test_error = Exception("Test error") + + # Create instance of the prompt class + prompt = CorrectExecuteSQLQueryUsageErrorPrompt( + context=mock_context, + code=test_code, + error=test_error, + ) + + # Call the method + result = prompt.to_json() + + # Assertions + assert result == { + "datasets": [{"mock_dataset": "data"}], + "conversation": {"mock_conversation": "data"}, + "system_prompt": "Mock agent description", + "error": { + "code": test_code, + "error_trace": str(test_error), + "exception_type": "ExecuteSQLQueryNotUsed", + }, + } + + # Verify the mocks were called + mock_dataset.to_json.assert_called_once() + mock_memory.to_json.assert_called_once() diff --git a/tests/unit_tests/core/prompts/test_correct_output_type_error_prompt.py b/tests/unit_tests/core/prompts/test_correct_output_type_error_prompt.py new file mode 100644 index 000000000..119daa000 --- /dev/null +++ b/tests/unit_tests/core/prompts/test_correct_output_type_error_prompt.py @@ -0,0 +1,58 @@ +from unittest.mock import Mock, patch + +import pytest + +from pandasai.core.prompts.correct_output_type_error_prompt import ( + CorrectOutputTypeErrorPrompt, +) + + +def test_to_json(): + # Mock the necessary dependencies + mock_memory = Mock() + mock_memory.to_json.return_value = {"conversations": "test"} + mock_memory.agent_description = "test agent" + + mock_dataset = Mock() + mock_dataset.to_json.return_value = {"data": "test data"} + + mock_context = Mock() + mock_context.memory = mock_memory + mock_context.dfs = [mock_dataset] + + # Create test data + props = { + "context": mock_context, + "code": "test code", + "error": Exception("test error"), + "output_type": "test_type", + } + + # Create instance of prompt + prompt = CorrectOutputTypeErrorPrompt(**props) + + # Call to_json method + result = prompt.to_json() + + # Verify the structure and content of the result + assert isinstance(result, dict) + assert "datasets" in result + assert "conversation" in result + assert "system_prompt" in result + assert "error" in result + assert "config" in result + + # Verify specific values + assert result["datasets"] == [{"data": "test data"}] + assert result["conversation"] == {"conversations": "test"} + assert result["system_prompt"] == "test agent" + assert result["error"] == { + "code": "test code", + "error_trace": "test error", + "exception_type": "InvalidLLMOutputType", + } + assert result["config"] == {"output_type": "test_type"} + + # Verify that the mock methods were called + mock_memory.to_json.assert_called_once() + mock_dataset.to_json.assert_called_once() diff --git a/tests/unit_tests/core/prompts/test_generate_python_code_with_sql_prompt.py b/tests/unit_tests/core/prompts/test_generate_python_code_with_sql_prompt.py new file mode 100644 index 000000000..d4a0efe31 --- /dev/null +++ b/tests/unit_tests/core/prompts/test_generate_python_code_with_sql_prompt.py @@ -0,0 +1,48 @@ +from unittest.mock import Mock, patch + +import pytest + +from pandasai.core.prompts import GeneratePythonCodeWithSQLPrompt + + +@pytest.fixture +def mock_context(): + context = Mock() + context.memory = Mock() + context.memory.to_json.return_value = {"history": []} + context.memory.agent_description = "Test Agent Description" + context.dfs = [Mock()] + context.dfs[0].to_json.return_value = {"name": "test_df", "data": []} + context.config.direct_sql = True + return context + + +def test_to_json(mock_context): + """Test that to_json returns the expected structure with all required fields""" + prompt = GeneratePythonCodeWithSQLPrompt(context=mock_context, output_type="code") + + # Mock the to_string method + with patch.object(prompt, "to_string", return_value="test prompt"): + result = prompt.to_json() + + assert isinstance(result, dict) + assert "datasets" in result + assert isinstance(result["datasets"], list) + assert len(result["datasets"]) == 1 + assert result["datasets"][0] == {"name": "test_df", "data": []} + + assert "conversation" in result + assert result["conversation"] == {"history": []} + + assert "system_prompt" in result + assert result["system_prompt"] == "Test Agent Description" + + assert "prompt" in result + assert result["prompt"] == "test prompt" + + assert "config" in result + assert isinstance(result["config"], dict) + assert "direct_sql" in result["config"] + assert result["config"]["direct_sql"] is True + assert "output_type" in result["config"] + assert result["config"]["output_type"] == "code" diff --git a/tests/unit_tests/smart_datalake/test_smart_datalake.py b/tests/unit_tests/smart_datalake/test_smart_datalake.py index b4e16a26a..b9c22cb1a 100644 --- a/tests/unit_tests/smart_datalake/test_smart_datalake.py +++ b/tests/unit_tests/smart_datalake/test_smart_datalake.py @@ -4,6 +4,7 @@ import pytest from pandasai.config import Config +from pandasai.core.cache import Cache from pandasai.smart_datalake import SmartDatalake @@ -43,3 +44,28 @@ def test_enable_cache(sample_dataframes): # Test with cache disabled mock_config.enable_cache = False assert smart_datalake.enable_cache is False + + +def test_enable_cache_setter(sample_dataframes): + # Create a mock agent with context and config + mock_config = Config(enable_cache=False) + mock_agent = Mock() + mock_agent.context = Mock() + mock_agent.context.config = mock_config + mock_agent.context.cache = None + + # Create SmartDatalake instance + smart_datalake = SmartDatalake(sample_dataframes) + smart_datalake._agent = mock_agent # Inject mock agent + + # Enable cache + smart_datalake.enable_cache = True + assert mock_agent.context.config.enable_cache is True + # Cache should be created and set in agent context + assert smart_datalake._cache is not None + assert isinstance(smart_datalake._cache, Cache) + + # Disable cache + smart_datalake.enable_cache = False + assert mock_agent.context.config.enable_cache is False + assert smart_datalake._cache is None