Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add more tests in the agent #1572

Merged
merged 6 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import pytest
from pandasai.helpers.memory import Memory


def test_to_json_empty_memory():
    """Check that a Memory with no added messages serializes to an empty list."""
    empty_memory = Memory()
    serialized = empty_memory.to_json()
    assert serialized == []


def test_to_json_with_messages():
    """Check that to_json() returns one role/message dict per added message.

    Messages added with ``is_user=True`` are expected under the ``"user"``
    role, the others under ``"assistant"``.
    """
    conversation = Memory()
    conversation.add("Hello", is_user=True)
    conversation.add("Hi there!", is_user=False)
    conversation.add("How are you?", is_user=True)

    serialized = conversation.to_json()

    assert serialized == [
        {"role": "user", "message": "Hello"},
        {"role": "assistant", "message": "Hi there!"},
        {"role": "user", "message": "How are you?"},
    ]


def test_to_json_message_order():
    """Check that to_json() returns messages in the order they were added."""
    history = Memory()

    # Alternate user / non-user messages in a known order.
    texts = ["Message 1", "Message 2", "Message 3"]
    from_user_flags = [True, False, True]
    for text, from_user in zip(texts, from_user_flags):
        history.add(text, is_user=from_user)

    serialized = history.to_json()

    # Same count, and each position holds the message added at that position.
    assert len(serialized) == 3
    for position, text in enumerate(texts):
        assert serialized[position]["message"] == text


def test_to_openai_messages_empty():
    """Check that an empty Memory produces an empty OpenAI message list."""
    empty_memory = Memory()
    openai_messages = empty_memory.to_openai_messages()
    assert openai_messages == []


def test_to_openai_messages_with_agent_description():
    """Check that the agent description is emitted first as a system message."""
    description = "I am a helpful assistant"
    conversation = Memory(agent_description=description)
    conversation.add("Hello", is_user=True)
    conversation.add("Hi there!", is_user=False)

    openai_messages = conversation.to_openai_messages()

    assert openai_messages == [
        {"role": "system", "content": "I am a helpful assistant"},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
    ]


def test_to_openai_messages_without_agent_description():
    """Check that no system message is emitted when no description is given."""
    conversation = Memory()
    conversation.add("Hello", is_user=True)
    conversation.add("Hi there!", is_user=False)
    conversation.add("How are you?", is_user=True)

    openai_messages = conversation.to_openai_messages()

    assert openai_messages == [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
109 changes: 107 additions & 2 deletions tests/unit_tests/agent/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
from typing import Optional
from unittest.mock import MagicMock, Mock, mock_open, patch
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch

import pandas as pd
import pytest

from pandasai import DatasetLoader, VirtualDataFrame
from pandasai.agent.base import Agent
from pandasai.config import Config, ConfigManager
from pandasai.core.response.error import ErrorResponse
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import CodeExecutionError
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
from pandasai.llm.fake import FakeLLM


Expand Down Expand Up @@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):

with pytest.raises(ValueError, match="No DataFrames available"):
agent._execute_sql_query(query)

def test_process_query(self, agent, config):
    """Check the success path of _process_query with the cache enabled.

    Code generation and execution are stubbed; the test verifies that the
    executed result is returned, that the generated code is what gets
    executed, and that the cache's ``set`` is called once.
    """
    generated_code = "result = df['age'].mean()"

    # Stub code generation/execution and swap in a mock cache.
    agent.generate_code = Mock(return_value=generated_code)
    agent.execute_with_retries = Mock(return_value=30.5)
    agent._state.config.enable_cache = True
    agent._state.cache = Mock()

    outcome = agent._process_query("What is the average age?", "number")

    assert outcome == 30.5
    agent.generate_code.assert_called_once()
    agent.execute_with_retries.assert_called_once_with(generated_code)
    agent._state.cache.set.assert_called_once()

def test_process_query_execution_error(self, agent, config):
    """Check that _process_query delegates a CodeExecutionError to _handle_exception.

    The stubbed executor raises; the test expects the handler's return value
    to be returned and the handler to receive the generated code.
    """
    failing_code = "invalid_code"

    # Generation succeeds, execution raises, and the handler is stubbed.
    agent.generate_code = Mock(return_value=failing_code)
    agent.execute_with_retries = Mock(
        side_effect=CodeExecutionError("Execution failed")
    )
    agent._handle_exception = Mock(return_value="Error handled")

    outcome = agent._process_query("What is the invalid operation?")

    assert outcome == "Error handled"
    agent._handle_exception.assert_called_once_with(failing_code)

def test_regenerate_code_after_invalid_llm_output_error(self, agent):
    """Check regeneration when the error is an InvalidLLMOutputType.

    The output-type correction prompt builder is patched; the test expects it
    to be called with the agent state and the failing code, and its result to
    be passed to the code generator.
    """
    from pandasai.exceptions import InvalidLLMOutputType

    failing_code = "test code"
    raised = InvalidLLMOutputType("Invalid output type")
    generator_stub = MagicMock(return_value="new code")

    with patch(
        "pandasai.agent.base.get_correct_output_type_error_prompt",
        return_value="corrected prompt",
    ) as prompt_builder:
        agent._code_generator.generate_code = generator_stub
        regenerated = agent._regenerate_code_after_error(failing_code, raised)

    # The third positional argument is not pinned by this test (ANY).
    prompt_builder.assert_called_once_with(agent._state, failing_code, ANY)
    agent._code_generator.generate_code.assert_called_once_with("corrected prompt")
    assert regenerated == "new code"

def test_regenerate_code_after_other_error(self, agent):
    """Check regeneration when the error is not an InvalidLLMOutputType.

    With a plain ValueError the test expects the SQL error-correction prompt
    builder to be used and its result to be passed to the code generator.
    """
    failing_code = "test code"
    raised = ValueError("Some other error")
    generator_stub = MagicMock(return_value="new code")

    with patch(
        "pandasai.agent.base.get_correct_error_prompt_for_sql",
        return_value="sql error prompt",
    ) as prompt_builder:
        agent._code_generator.generate_code = generator_stub
        regenerated = agent._regenerate_code_after_error(failing_code, raised)

    # The third positional argument is not pinned by this test (ANY).
    prompt_builder.assert_called_once_with(agent._state, failing_code, ANY)
    agent._code_generator.generate_code.assert_called_once_with("sql error prompt")
    assert regenerated == "new code"

def test_handle_exception(self, agent):
    """Test that _handle_exception properly formats and logs exceptions.

    A real ZeroDivisionError is raised via ``exec`` and ``_handle_exception``
    is called from inside the ``except`` handler. The test then checks that
    the returned ErrorResponse carries the executed code and mentions the
    exception type, and that the logger was called once.
    """
    test_code = "print(1/0)"  # Code that will raise a ZeroDivisionError

    # Mock the logger to verify it's called
    mock_logger = MagicMock()
    agent._state.logger = mock_logger

    # Create an actual exception to handle. Only the expected exception type
    # is caught, so an unrelated failure surfaces as a test error instead of
    # being swallowed by a bare ``except``.
    try:
        exec(test_code)
    except ZeroDivisionError:
        # Called while the exception is being handled.
        # NOTE(review): presumably the method reads the active exception
        # (the "ZeroDivisionError" assertion below suggests so) - confirm.
        result = agent._handle_exception(test_code)
    else:
        # Without this branch the test would have no failure path if the
        # snippet ever stopped raising.
        pytest.fail("Expected ZeroDivisionError was not raised")

    # Verify the result is an ErrorResponse
    assert isinstance(result, ErrorResponse)
    assert result.last_code_executed == test_code
    assert "ZeroDivisionError" in result.error

    # Verify the error was logged
    mock_logger.log.assert_called_once()
    assert "Processing failed with error" in mock_logger.log.call_args[0][0]
119 changes: 119 additions & 0 deletions tests/unit_tests/dataframe/test_pull.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import pytest
from unittest.mock import patch, Mock, mock_open
from io import BytesIO
from zipfile import ZipFile

import pandas as pd
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import PandaAIApiKeyError, DatasetNotFound
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema, Column, Source


@pytest.fixture
def mock_env(monkeypatch):
    """Set the PANDABI_API_KEY environment variable to a dummy value."""
    api_key_variable = "PANDABI_API_KEY"
    monkeypatch.setenv(api_key_variable, "test_api_key")


@pytest.fixture
def sample_df():
    """Return a three-row pandas DataFrame with columns ``col1`` and ``col2``."""
    column_data = {"col1": [1, 2, 3], "col2": ["a", "b", "c"]}
    return pd.DataFrame(column_data)


@pytest.fixture
def mock_zip_content():
    """Return the bytes of an in-memory zip archive holding one ``test.csv``."""
    csv_text = "col1,col2\n1,a\n2,b\n3,c"
    buffer = BytesIO()
    with ZipFile(buffer, "w") as archive:
        archive.writestr("test.csv", csv_text)
    # Read the buffer only after the archive is closed, so the bytes include
    # everything ZipFile writes on close.
    archive_bytes = buffer.getvalue()
    return archive_bytes


@pytest.fixture
def mock_schema():
    """Return a SemanticLayerSchema with a parquet source and two columns."""
    parquet_source = Source(type="parquet", path="data.parquet", table="test_table")
    schema_columns = [
        Column(name="col1", type="integer"),
        Column(name="col2", type="string"),
    ]
    return SemanticLayerSchema(
        name="test_schema",
        source=parquet_source,
        columns=schema_columns,
    )


def test_pull_success(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
    """Check that pull() requests the dataset from the API and opens files.

    The session, project root, dataset loader and ``builtins.open`` are all
    patched, so nothing touches the network or the real filesystem.
    """
    with patch("pandasai.dataframe.base.get_pandaai_session") as session_factory, \
         patch("pandasai.dataframe.base.find_project_root") as project_root, \
         patch("pandasai.DatasetLoader.create_loader_from_path") as loader_factory, \
         patch("builtins.open", mock_open()) as opened:

        # The API answers 200 with the zipped dataset.
        session_factory.return_value.get.return_value = Mock(
            status_code=200, content=mock_zip_content
        )
        project_root.return_value = str(tmp_path)

        # The loader hands back a ready DataFrame.
        fake_loader = Mock()
        fake_loader.load.return_value = DataFrame(sample_df, schema=mock_schema)
        loader_factory.return_value = fake_loader

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        frame.pull()

    # Mocks keep their call records after the patches are undone.
    session_factory.return_value.get.assert_called_once_with(
        "/datasets/pull",
        headers={"accept": "application/json", "x-authorization": "Bearer test_api_key"},
        params={"path": "test/path"},
    )

    # At least one file was opened through the patched ``open``.
    assert opened.call_count > 0


def test_pull_missing_api_key(sample_df, mock_schema):
    """Check that PandaAIApiKeyError is raised when no environment value is found.

    ``os.environ.get`` is patched to return None for every lookup.
    """
    with patch("os.environ.get", return_value=None):
        with pytest.raises(PandaAIApiKeyError):
            frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
            frame.pull()


def test_pull_api_error(mock_env, sample_df, mock_schema):
    """Check that a 404 API response makes pull() raise DatasetNotFound."""
    with patch("pandasai.dataframe.base.get_pandaai_session") as session_factory:
        # The API reports that the dataset does not exist.
        session_factory.return_value.get.return_value = Mock(status_code=404)

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        with pytest.raises(DatasetNotFound, match="Remote dataset not found to pull!"):
            frame.pull()


def test_pull_file_exists(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
    """Check that pull() calls makedirs for the target dir when paths exist.

    ``os.path.exists`` is patched to always return True and ``os.makedirs``
    is patched so the call can be inspected.
    """
    with patch("pandasai.dataframe.base.get_pandaai_session") as session_factory, \
         patch("pandasai.dataframe.base.find_project_root") as project_root, \
         patch("pandasai.DatasetLoader.create_loader_from_path") as loader_factory, \
         patch("builtins.open", mock_open()) as opened, \
         patch("os.path.exists") as path_exists, \
         patch("os.makedirs") as make_dirs:

        # The API answers 200 with the zipped dataset; every path "exists".
        session_factory.return_value.get.return_value = Mock(
            status_code=200, content=mock_zip_content
        )
        project_root.return_value = str(tmp_path)
        path_exists.return_value = True

        # The loader hands back a ready DataFrame.
        fake_loader = Mock()
        fake_loader.load.return_value = DataFrame(sample_df, schema=mock_schema)
        loader_factory.return_value = fake_loader

        frame = DataFrame(sample_df, path="test/path", schema=mock_schema)
        frame.pull()

    # Directory that should hold the extracted ``test.csv``.
    extracted_file = os.path.join(str(tmp_path), "datasets", "test/path", "test.csv")
    expected_directory = os.path.dirname(extracted_file)
    make_dirs.assert_called_with(expected_directory, exist_ok=True)
56 changes: 56 additions & 0 deletions tests/unit_tests/helpers/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import logging
from pandasai.helpers.logger import Logger


def test_verbose_setter():
    """Check that toggling ``verbose`` adds and removes a StreamHandler.

    The flag is flipped on, off and on again to confirm that repeated toggles
    leave exactly one handler (or none) on the underlying logger.
    """
    logger = Logger(verbose=False)

    def stream_handler_attached():
        # Re-read the handler list on every call; the setter may replace it.
        return any(
            isinstance(handler, logging.StreamHandler)
            for handler in logger._logger.handlers
        )

    # Starts non-verbose with no stream handler.
    assert logger._verbose is False
    assert not stream_handler_attached()

    # Turning verbose on attaches exactly one handler.
    logger.verbose = True
    assert logger._verbose is True
    assert stream_handler_attached()
    assert len(logger._logger.handlers) == 1

    # Turning it off removes the handler again.
    logger.verbose = False
    assert logger._verbose is False
    assert not stream_handler_attached()
    assert len(logger._logger.handlers) == 0

    # A second switch-on must not accumulate handlers.
    logger.verbose = True
    assert logger._verbose is True
    assert stream_handler_attached()
    assert len(logger._logger.handlers) == 1

def test_save_logs_property():
    """Check that setting ``save_logs`` adds and removes a FileHandler."""
    logger = Logger(save_logs=False, verbose=False)

    def file_handler_attached():
        # Re-read the handler list on every call; the setter may replace it.
        return any(
            isinstance(handler, logging.FileHandler)
            for handler in logger._logger.handlers
        )

    # Starts with log saving disabled.
    assert logger.save_logs is False

    # Enabling it attaches a file handler.
    logger.save_logs = True
    assert logger.save_logs is True
    assert file_handler_attached()

    # Disabling it detaches the file handler again.
    logger.save_logs = False
    assert logger.save_logs is False
    assert not file_handler_attached()

def test_save_logs_property_reflects_handlers():
    """Check the value of ``save_logs`` right after construction.

    This test has its own name because an earlier test in this module is
    called ``test_save_logs_property``; two module-level functions with the
    same name leave only the last one bound, so pytest would collect just one
    of them.
    """
    # When logger is initialized with save_logs=True (default), it should have handlers
    logger = Logger(save_logs=True)
    assert logger.save_logs is True

    # When logger is initialized with save_logs=False, it should still have handlers if verbose=True
    logger = Logger(save_logs=False, verbose=True)
    assert logger.save_logs is True

    # When both save_logs and verbose are False, there should be no handlers
    logger = Logger(save_logs=False, verbose=False)
    logger._logger.handlers = []  # Reset handlers to match the property's expected behavior
    assert logger.save_logs is False
Loading
Loading