Skip to content

Commit

Permalink
Fix evaluate_condition for non-bool result (#638)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertvillanova authored Feb 14, 2025
1 parent d02093d commit 9437133
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 53 deletions.
50 changes: 20 additions & 30 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,49 +714,39 @@ def evaluate_condition(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> bool | object:
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
]
ops = [type(op) for op in condition.ops]

result = True
current_left = left

for op, comparator in zip(ops, comparators):
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
op = type(op)
right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
if op == ast.Eq:
current_result = current_left == comparator
current_result = left == right
elif op == ast.NotEq:
current_result = current_left != comparator
current_result = left != right
elif op == ast.Lt:
current_result = current_left < comparator
current_result = left < right
elif op == ast.LtE:
current_result = current_left <= comparator
current_result = left <= right
elif op == ast.Gt:
current_result = current_left > comparator
current_result = left > right
elif op == ast.GtE:
current_result = current_left >= comparator
current_result = left >= right
elif op == ast.Is:
current_result = current_left is comparator
current_result = left is right
elif op == ast.IsNot:
current_result = current_left is not comparator
current_result = left is not right
elif op == ast.In:
current_result = current_left in comparator
current_result = left in right
elif op == ast.NotIn:
current_result = current_left not in comparator
current_result = left not in right
else:
raise InterpreterError(f"Operator not supported: {op}")

if not isinstance(current_result, bool):
return current_result
raise InterpreterError(f"Unsupported comparison operator: {op}")

result = result & current_result
current_left = comparator

if isinstance(result, bool) and not result:
break

return result if isinstance(result, (bool, pd.Series)) else result.all()
if current_result is False:
return False
result = current_result if i == 0 else (result and current_result)
left = right
return result


def evaluate_if(
Expand Down
132 changes: 109 additions & 23 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from textwrap import dedent

import numpy as np
import pandas as pd
import pytest

from smolagents.default_tools import BASE_PYTHON_TOOLS
Expand Down Expand Up @@ -1188,11 +1189,11 @@ def test_evaluate_delete(code, state, expectation):
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
# Composite conditions:
# Chained conditions:
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
],
)
def test_evaluate_condition(condition, state, expected_result):
Expand All @@ -1201,6 +1202,91 @@ def test_evaluate_condition(condition, state, expected_result):
assert result == expected_result


@pytest.mark.parametrize(
"condition, state, expected_result",
[
("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
(
"a == b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
pd.DataFrame({"x": [True, True], "y": [True, False]}),
),
(
"a != b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
pd.DataFrame({"x": [False, False], "y": [False, True]}),
),
(
"a < b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
pd.DataFrame({"x": [True, False], "y": [False, False]}),
),
(
"a <= b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
pd.DataFrame({"x": [True, True], "y": [False, False]}),
),
(
"a > b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
pd.DataFrame({"x": [False, False], "y": [True, True]}),
),
(
"a >= b",
{"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
pd.DataFrame({"x": [False, True], "y": [True, True]}),
),
],
)
def test_evaluate_condition_with_pandas(condition, state, expected_result):
condition_ast = ast.parse(condition, mode="eval").body
result = evaluate_condition(condition_ast, state, {}, {}, [])
if isinstance(result, pd.Series):
pd.testing.assert_series_equal(result, expected_result)
else:
pd.testing.assert_frame_equal(result, expected_result)


@pytest.mark.parametrize(
"condition, state, expected_exception",
[
# Chained conditions:
(
"a == b == c",
{
"a": pd.Series([1, 2, 3]),
"b": pd.Series([2, 2, 2]),
"c": pd.Series([3, 3, 3]),
},
ValueError(
"The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
),
),
(
"a == b == c",
{
"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
"b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
"c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
},
ValueError(
"The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
),
),
],
)
def test_evaluate_condition_with_pandas_exceptions(condition, state, expected_exception):
condition_ast = ast.parse(condition, mode="eval").body
with pytest.raises(type(expected_exception)) as exception_info:
_ = evaluate_condition(condition_ast, state, {}, {}, [])
assert str(expected_exception) in str(exception_info.value)


def test_get_safe_module_handle_lazy_imports():
class FakeModule(types.ModuleType):
def __init__(self, name):
Expand All @@ -1222,28 +1308,28 @@ def __dir__(self):


def test_non_standard_comparisons():
code = """
class NonStdEqualsResult:
def __init__(self, left:object, right:object):
self._left = left
self._right = right
def __str__(self) -> str:
return f'{self._left}=={self._right}'
class NonStdComparisonClass:
def __init__(self, value: str ):
self._value = value
def __str__(self):
return self._value
def __eq__(self, other):
return NonStdEqualsResult(self, other)
a = NonStdComparisonClass("a")
b = NonStdComparisonClass("b")
result = a == b
"""
code = dedent("""\
class NonStdEqualsResult:
def __init__(self, left:object, right:object):
self._left = left
self._right = right
def __str__(self) -> str:
return f'{self._left} == {self._right}'
class NonStdComparisonClass:
def __init__(self, value: str ):
self._value = value
def __str__(self):
return self._value
def __eq__(self, other):
return NonStdEqualsResult(self, other)
a = NonStdComparisonClass("a")
b = NonStdComparisonClass("b")
result = a == b
""")
result, _ = evaluate_python_code(code, state={})
assert not isinstance(result, bool)
assert str(result) == "a==b"
assert str(result) == "a == b"


class TestPrintContainer:
Expand Down

0 comments on commit 9437133

Please sign in to comment.