Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add more tests in the agent #1572

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pandasai.helpers.memory import Memory


def test_to_json_empty_memory():
memory = Memory()
assert memory.to_json() == []


def test_to_json_with_messages():
memory = Memory()

# Add test messages
memory.add("Hello", is_user=True)
memory.add("Hi there!", is_user=False)
memory.add("How are you?", is_user=True)

expected_json = [
{"role": "user", "message": "Hello"},
{"role": "assistant", "message": "Hi there!"},
{"role": "user", "message": "How are you?"},
]

assert memory.to_json() == expected_json


def test_to_json_message_order():
memory = Memory()

# Add messages in specific order
messages = [("Message 1", True), ("Message 2", False), ("Message 3", True)]

for msg, is_user in messages:
memory.add(msg, is_user=is_user)

result = memory.to_json()

# Verify order is preserved
assert len(result) == 3
assert result[0]["message"] == "Message 1"
assert result[1]["message"] == "Message 2"
assert result[2]["message"] == "Message 3"


def test_to_openai_messages_empty():
memory = Memory()
assert memory.to_openai_messages() == []


def test_to_openai_messages_with_agent_description():
memory = Memory(agent_description="I am a helpful assistant")
memory.add("Hello", is_user=True)
memory.add("Hi there!", is_user=False)

expected_messages = [
{"role": "system", "content": "I am a helpful assistant"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]

assert memory.to_openai_messages() == expected_messages


def test_to_openai_messages_without_agent_description():
memory = Memory()
memory.add("Hello", is_user=True)
memory.add("Hi there!", is_user=False)
memory.add("How are you?", is_user=True)

expected_messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
]

assert memory.to_openai_messages() == expected_messages
109 changes: 107 additions & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
from typing import Optional
from unittest.mock import MagicMock, Mock, mock_open, patch
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch

import pandas as pd
import pytest

from pandasai import DatasetLoader, VirtualDataFrame
from pandasai.agent.base import Agent
from pandasai.config import Config, ConfigManager
from pandasai.core.response.error import ErrorResponse
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):

with pytest.raises(ValueError, match="No DataFrames available"):
agent._execute_sql_query(query)

def test_process_query(self, agent, config):
"""Test the _process_query method with successful execution"""
query = "What is the average age?"
output_type = "number"

# Mock the necessary methods
agent.generate_code = Mock(return_value="result = df['age'].mean()")
agent.execute_with_retries = Mock(return_value=30.5)
agent._state.config.enable_cache = True
agent._state.cache = Mock()

# Execute the query
result = agent._process_query(query, output_type)

# Verify the result
assert result == 30.5

# Verify method calls
agent.generate_code.assert_called_once()
agent.execute_with_retries.assert_called_once_with("result = df['age'].mean()")
agent._state.cache.set.assert_called_once()

def test_process_query_execution_error(self, agent, config):
"""Test the _process_query method with execution error"""
query = "What is the invalid operation?"

# Mock methods to simulate error
agent.generate_code = Mock(return_value="invalid_code")
agent.execute_with_retries = Mock(
side_effect=CodeExecutionError("Execution failed")
)
agent._handle_exception = Mock(return_value="Error handled")

# Execute the query
result = agent._process_query(query)

# Verify error handling
assert result == "Error handled"
agent._handle_exception.assert_called_once_with("invalid_code")

def test_regenerate_code_after_invalid_llm_output_error(self, agent):
"""Test code regeneration with InvalidLLMOutputType error"""
from pandasai.exceptions import InvalidLLMOutputType

code = "test code"
error = InvalidLLMOutputType("Invalid output type")

with patch(
"pandasai.agent.base.get_correct_output_type_error_prompt"
) as mock_prompt:
mock_prompt.return_value = "corrected prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"corrected prompt"
)
assert result == "new code"

def test_regenerate_code_after_other_error(self, agent):
"""Test code regeneration with non-InvalidLLMOutputType error"""
code = "test code"
error = ValueError("Some other error")

with patch(
"pandasai.agent.base.get_correct_error_prompt_for_sql"
) as mock_prompt:
mock_prompt.return_value = "sql error prompt"
agent._code_generator.generate_code = MagicMock(return_value="new code")

result = agent._regenerate_code_after_error(code, error)

mock_prompt.assert_called_once_with(agent._state, code, ANY)
agent._code_generator.generate_code.assert_called_once_with(
"sql error prompt"
)
assert result == "new code"

def test_handle_exception(self, agent):
"""Test that _handle_exception properly formats and logs exceptions"""
test_code = "print(1/0)" # Code that will raise a ZeroDivisionError

# Mock the logger to verify it's called
mock_logger = MagicMock()
agent._state.logger = mock_logger

# Create an actual exception to handle
try:
exec(test_code)
except:
# Call the method
result = agent._handle_exception(test_code)

# Verify the result is an ErrorResponse
assert isinstance(result, ErrorResponse)
assert result.last_code_executed == test_code
assert "ZeroDivisionError" in result.error

# Verify the error was logged
mock_logger.log.assert_called_once()
assert "Processing failed with error" in mock_logger.log.call_args[0][0]
136 changes: 136 additions & 0 deletions tests/unit_tests/dataframe/test_pull.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import os
from io import BytesIO
from unittest.mock import Mock, mock_open, patch
from zipfile import ZipFile

import pandas as pd
import pytest

from pandasai.data_loader.semantic_layer_schema import (
Column,
SemanticLayerSchema,
Source,
)
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError


@pytest.fixture
def mock_env(monkeypatch):
monkeypatch.setenv("PANDABI_API_KEY", "test_api_key")


@pytest.fixture
def sample_df():
return pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})


@pytest.fixture
def mock_zip_content():
zip_buffer = BytesIO()
with ZipFile(zip_buffer, "w") as zip_file:
zip_file.writestr("test.csv", "col1,col2\n1,a\n2,b\n3,c")
return zip_buffer.getvalue()


@pytest.fixture
def mock_schema():
return SemanticLayerSchema(
name="test_schema",
source=Source(type="parquet", path="data.parquet", table="test_table"),
columns=[
Column(name="col1", type="integer"),
Column(name="col2", type="string"),
],
)


def test_pull_success(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session, patch(
"pandasai.dataframe.base.find_project_root"
) as mock_root, patch(
"pandasai.DatasetLoader.create_loader_from_path"
) as mock_loader, patch("builtins.open", mock_open()) as mock_file:
# Setup mocks
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = mock_zip_content
mock_session.return_value.get.return_value = mock_response
mock_root.return_value = str(tmp_path)

mock_loader_instance = Mock()
mock_loader_instance.load.return_value = DataFrame(
sample_df, schema=mock_schema
)
mock_loader.return_value = mock_loader_instance

# Create DataFrame instance and call pull
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
df.pull()

# Verify API call
mock_session.return_value.get.assert_called_once_with(
"/datasets/pull",
headers={
"accept": "application/json",
"x-authorization": "Bearer test_api_key",
},
params={"path": "test/path"},
)

# Verify file operations
assert mock_file.call_count > 0


def test_pull_missing_api_key(sample_df, mock_schema):
with patch("os.environ.get") as mock_env_get:
mock_env_get.return_value = None
with pytest.raises(PandaAIApiKeyError):
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
df.pull()


def test_pull_api_error(mock_env, sample_df, mock_schema):
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session:
mock_response = Mock()
mock_response.status_code = 404
mock_session.return_value.get.return_value = mock_response

df = DataFrame(sample_df, path="test/path", schema=mock_schema)
with pytest.raises(DatasetNotFound, match="Remote dataset not found to pull!"):
df.pull()


def test_pull_file_exists(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session, patch(
"pandasai.dataframe.base.find_project_root"
) as mock_root, patch(
"pandasai.DatasetLoader.create_loader_from_path"
) as mock_loader, patch("builtins.open", mock_open()) as mock_file, patch(
"os.path.exists"
) as mock_exists, patch("os.makedirs") as mock_makedirs:
# Setup mocks
mock_response = Mock()
mock_response.status_code = 200
mock_response.content = mock_zip_content
mock_session.return_value.get.return_value = mock_response
mock_root.return_value = str(tmp_path)
mock_exists.return_value = True

mock_loader_instance = Mock()
mock_loader_instance.load.return_value = DataFrame(
sample_df, schema=mock_schema
)
mock_loader.return_value = mock_loader_instance

# Create DataFrame instance and call pull
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
df.pull()

# Verify directory creation
mock_makedirs.assert_called_with(
os.path.dirname(
os.path.join(str(tmp_path), "datasets", "test/path", "test.csv")
),
exist_ok=True,
)
Loading
Loading