From 259f55d44fccdbba55b5a5a607a4c2838f949cbe Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 28 Feb 2025 11:40:19 +0100 Subject: [PATCH] Fix and refactor tests of Docker executors (#827) --- tests/test_remote_executors.py | 67 ++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/test_remote_executors.py b/tests/test_remote_executors.py index 78c69688a..6c1fc01b1 100644 --- a/tests/test_remote_executors.py +++ b/tests/test_remote_executors.py @@ -1,10 +1,11 @@ -import logging -from unittest import TestCase +from textwrap import dedent from unittest.mock import MagicMock, patch import docker +import pytest from PIL import Image +from smolagents.monitoring import AgentLogger, LogLevel from smolagents.remote_executors import DockerExecutor, E2BExecutor from .utils.markers import require_run_all @@ -23,19 +24,23 @@ def test_e2b_executor_instantiation(self): assert executor.sandbox == mock_sandbox.return_value -class TestDockerExecutor(TestCase): - def setUp(self): - self.logger = logging.getLogger("DockerExecutorTest") - self.executor = DockerExecutor( - additional_imports=["pillow", "numpy"], tools=[], logger=self.logger, initial_state={} - ) +@pytest.fixture +def docker_executor(): + executor = DockerExecutor(additional_imports=["pillow", "numpy"], logger=AgentLogger(level=LogLevel.OFF)) + yield executor + executor.delete() + + +@require_run_all +class TestDockerExecutor: + @pytest.fixture(autouse=True) + def set_executor(self, docker_executor): + self.executor = docker_executor - @require_run_all def test_initialization(self): """Check if DockerExecutor initializes without errors""" - self.assertIsNotNone(self.executor.container, "Container should be initialized") + assert self.executor.container is not None, "Container should be initialized" - @require_run_all def test_state_persistence(self): """Test that variables and imports form one snippet persist in the next""" code_action = "import numpy as np; a = 2" @@ -45,31 +50,37 @@ def test_state_persistence(self): result, logs, final_answer = self.executor(code_action) assert "1.41421" in logs - @require_run_all - def test_execute_image_output(self): - """Test execution that returns a base64 image""" - code_action = """ -import base64 -from PIL import Image -from io import BytesIO + def test_execute_output(self): + """Test execution that returns a string""" + code_action = 'final_answer("This is the final answer")' + result, logs, final_answer = self.executor(code_action) + assert result == "This is the final answer", "Result should be 'This is the final answer'" -image = Image.new("RGB", (10, 10), (255, 0, 0)) -final_answer(image) -""" + def test_execute_multiline_output(self): + """Test execution that returns a string""" + code_action = 'result = "This is the final answer"\nfinal_answer(result)' result, logs, final_answer = self.executor(code_action) + assert result == "This is the final answer", "Result should be 'This is the final answer'" - self.assertIsInstance(result, Image.Image, "Result should be a PIL Image") + def test_execute_image_output(self): + """Test execution that returns a base64 image""" + code_action = dedent(""" + import base64 + from PIL import Image + from io import BytesIO + image = Image.new("RGB", (10, 10), (255, 0, 0)) + final_answer(image) + """) + result, logs, final_answer = self.executor(code_action) + assert isinstance(result, Image.Image), "Result should be a PIL Image" - @require_run_all def test_syntax_error_handling(self): """Test handling of syntax errors""" code_action = 'print("Missing Parenthesis' # Syntax error - with self.assertRaises(ValueError) as context: + with pytest.raises(RuntimeError) as exception_info: self.executor(code_action) + assert "SyntaxError" in str(exception_info.value), "Should raise a syntax error" - self.assertIn("SyntaxError", str(context.exception), "Should raise a syntax error") - - @require_run_all def test_cleanup_on_deletion(self): """Test if Docker container stops and removes on deletion""" container_id = self.executor.container.id @@ -77,4 +88,4 @@ def test_cleanup_on_deletion(self): client = docker.from_env() containers = [c.id for c in client.containers.list(all=True)] - self.assertNotIn(container_id, containers, "Container should be removed") + assert container_id not in containers, "Container should be removed"