Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix evaluate_condition for non-bool result #638

Merged
Merged
50 changes: 20 additions & 30 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,49 +714,39 @@ def evaluate_condition(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> bool | object:
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
comparators = [
evaluate_ast(c, state, static_tools, custom_tools, authorized_imports) for c in condition.comparators
]
Comment on lines -718 to -720
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By moving the evaluate_ast call inside the loop, we avoid evaluating every comparator in condition.comparators up front.

  • Now, as soon as a current_result is False, we return False without evaluating the remaining comparators.

ops = [type(op) for op in condition.ops]

result = True
current_left = left

for op, comparator in zip(ops, comparators):
left = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
for i, (op, comparator) in enumerate(zip(condition.ops, condition.comparators)):
op = type(op)
right = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
if op == ast.Eq:
current_result = current_left == comparator
current_result = left == right
elif op == ast.NotEq:
current_result = current_left != comparator
current_result = left != right
elif op == ast.Lt:
current_result = current_left < comparator
current_result = left < right
elif op == ast.LtE:
current_result = current_left <= comparator
current_result = left <= right
elif op == ast.Gt:
current_result = current_left > comparator
current_result = left > right
elif op == ast.GtE:
current_result = current_left >= comparator
current_result = left >= right
elif op == ast.Is:
current_result = current_left is comparator
current_result = left is right
elif op == ast.IsNot:
current_result = current_left is not comparator
current_result = left is not right
elif op == ast.In:
current_result = current_left in comparator
current_result = left in right
elif op == ast.NotIn:
current_result = current_left not in comparator
current_result = left not in right
else:
raise InterpreterError(f"Operator not supported: {op}")

if not isinstance(current_result, bool):
return current_result
raise InterpreterError(f"Unsupported comparison operator: {op}")

result = result & current_result
current_left = comparator

if isinstance(result, bool) and not result:
break

return result if isinstance(result, (bool, pd.Series)) else result.all()
if current_result is False:
return False
result = current_result if i == 0 else (result and current_result)
left = right
return result


def evaluate_if(
Expand Down
132 changes: 109 additions & 23 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from textwrap import dedent

import numpy as np
import pandas as pd
import pytest

from smolagents.default_tools import BASE_PYTHON_TOOLS
Expand Down Expand Up @@ -1188,11 +1189,11 @@ def test_evaluate_delete(code, state, expectation):
("a in b", {"a": 4, "b": [1, 2, 3]}, False),
("a not in b", {"a": 1, "b": [1, 2, 3]}, False),
("a not in b", {"a": 4, "b": [1, 2, 3]}, True),
# Composite conditions:
# Chained conditions:
("a == b == c", {"a": 1, "b": 1, "c": 1}, True),
("a == b == c", {"a": 1, "b": 2, "c": 1}, False),
("a == b < c", {"a": 1, "b": 1, "c": 1}, False),
("a == b < c", {"a": 1, "b": 1, "c": 2}, True),
("a == b < c", {"a": 2, "b": 2, "c": 2}, False),
("a == b < c", {"a": 0, "b": 0, "c": 1}, True),
],
)
def test_evaluate_condition(condition, state, expected_result):
Expand All @@ -1201,6 +1202,91 @@ def test_evaluate_condition(condition, state, expected_result):
assert result == expected_result


# Single (non-chained) comparisons between pandas objects. Each expected value is
# itself a pandas object of booleans, not a plain bool.
@pytest.mark.parametrize(
    "condition, state, expected_result",
    [
        # Series operands: one case per ordering/equality operator.
        ("a == b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, False])),
        ("a != b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, True])),
        ("a < b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, False, False])),
        ("a <= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([True, True, False])),
        ("a > b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, False, True])),
        ("a >= b", {"a": pd.Series([1, 2, 3]), "b": pd.Series([2, 2, 2])}, pd.Series([False, True, True])),
        # DataFrame operands: the same operators, compared cell by cell.
        (
            "a == b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
            pd.DataFrame({"x": [True, True], "y": [True, False]}),
        ),
        (
            "a != b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [1, 2], "y": [3, 5]})},
            pd.DataFrame({"x": [False, False], "y": [False, True]}),
        ),
        (
            "a < b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
            pd.DataFrame({"x": [True, False], "y": [False, False]}),
        ),
        (
            "a <= b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
            pd.DataFrame({"x": [True, True], "y": [False, False]}),
        ),
        (
            "a > b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
            pd.DataFrame({"x": [False, False], "y": [True, True]}),
        ),
        (
            "a >= b",
            {"a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}), "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]})},
            pd.DataFrame({"x": [False, True], "y": [True, True]}),
        ),
    ],
)
def test_evaluate_condition_with_pandas(condition, state, expected_result):
    """Check that a single comparison of pandas objects yields the element-wise result.

    Args:
        condition: Source text of one comparison expression over the names ``a`` and ``b``.
        state: Mapping from those names to the pandas operands.
        expected_result: The Series or DataFrame of booleans the comparison should produce.
    """
    # mode="eval" parses a bare expression; `.body` is the comparison node itself.
    condition_ast = ast.parse(condition, mode="eval").body
    result = evaluate_condition(condition_ast, state, {}, {}, [])
    # A plain `==` on pandas objects is element-wise, so use the pandas testing
    # helpers. Anything that is not a Series is checked as a DataFrame.
    if isinstance(result, pd.Series):
        pd.testing.assert_series_equal(result, expected_result)
    else:
        pd.testing.assert_frame_equal(result, expected_result)


# Each case pairs a chained comparison over pandas objects with an exception
# instance; the instance carries both the expected type and the expected message.
@pytest.mark.parametrize(
    "condition, state, expected_exception",
    [
        # Chained conditions:
        (
            "a == b == c",
            {
                "a": pd.Series([1, 2, 3]),
                "b": pd.Series([2, 2, 2]),
                "c": pd.Series([3, 3, 3]),
            },
            ValueError(
                "The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
            ),
        ),
        (
            "a == b == c",
            {
                "a": pd.DataFrame({"x": [1, 2], "y": [3, 4]}),
                "b": pd.DataFrame({"x": [2, 2], "y": [2, 2]}),
                "c": pd.DataFrame({"x": [3, 3], "y": [3, 3]}),
            },
            ValueError(
                "The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all()."
            ),
        ),
    ],
)
def test_evaluate_condition_with_pandas_exceptions(condition, state, expected_exception):
    """Check that chained comparisons of pandas objects raise the expected error.

    Args:
        condition: Source text of a chained comparison over the names ``a``, ``b`` and ``c``.
        state: Mapping from those names to the pandas operands.
        expected_exception: Exception instance whose type must be raised and whose
            message must appear in the raised exception's message.
    """
    # mode="eval" parses a bare expression; `.body` is the comparison node itself.
    condition_ast = ast.parse(condition, mode="eval").body
    with pytest.raises(type(expected_exception)) as exception_info:
        _ = evaluate_condition(condition_ast, state, {}, {}, [])
    # Substring match rather than equality, so extra text in the raised message is tolerated.
    assert str(expected_exception) in str(exception_info.value)


def test_get_safe_module_handle_lazy_imports():
class FakeModule(types.ModuleType):
def __init__(self, name):
Expand All @@ -1222,28 +1308,28 @@ def __dir__(self):


def test_non_standard_comparisons():
code = """
class NonStdEqualsResult:
def __init__(self, left:object, right:object):
self._left = left
self._right = right
def __str__(self) -> str:
return f'{self._left}=={self._right}'

class NonStdComparisonClass:
def __init__(self, value: str ):
self._value = value
def __str__(self):
return self._value
def __eq__(self, other):
return NonStdEqualsResult(self, other)
a = NonStdComparisonClass("a")
b = NonStdComparisonClass("b")
result = a == b
"""
code = dedent("""\
class NonStdEqualsResult:
def __init__(self, left:object, right:object):
self._left = left
self._right = right
def __str__(self) -> str:
return f'{self._left} == {self._right}'

class NonStdComparisonClass:
def __init__(self, value: str ):
self._value = value
def __str__(self):
return self._value
def __eq__(self, other):
return NonStdEqualsResult(self, other)
a = NonStdComparisonClass("a")
b = NonStdComparisonClass("b")
result = a == b
""")
result, _ = evaluate_python_code(code, state={})
assert not isinstance(result, bool)
assert str(result) == "a==b"
assert str(result) == "a == b"


class TestPrintContainer:
Expand Down