Skip to content

Commit

Permalink
Fix chained assignment in local executor (#843)
Browse files Browse the repository at this point in the history
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
  • Loading branch information
sysradium and albertvillanova authored Mar 6, 2025
1 parent bf3686e commit b67cc94
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 22 deletions.
42 changes: 20 additions & 22 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,14 +590,13 @@ def evaluate_assign(
target = assign.targets[0]
set_value(target, result, state, static_tools, custom_tools, authorized_imports)
else:
if len(assign.targets) != len(result):
raise InterpreterError(f"Assign failed: expected {len(result)} values but got {len(assign.targets)}.")
expanded_values = []
for tgt in assign.targets:
if isinstance(tgt, ast.Starred):
expanded_values.extend(result)
else:
expanded_values.append(result)

for tgt, val in zip(assign.targets, expanded_values):
set_value(tgt, val, state, static_tools, custom_tools, authorized_imports)
return result
Expand Down Expand Up @@ -641,17 +640,21 @@ def evaluate_call(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
if not (
isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name) or isinstance(call.func, ast.Subscript)
):
if not isinstance(call.func, (ast.Call, ast.Lambda, ast.Attribute, ast.Name, ast.Subscript)):
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):

func, func_name = None, None

if isinstance(call.func, ast.Call):
func = evaluate_call(call.func, state, static_tools, custom_tools, authorized_imports)
elif isinstance(call.func, ast.Lambda):
func = evaluate_lambda(call.func, state, static_tools, custom_tools, authorized_imports)
elif isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, static_tools, custom_tools, authorized_imports)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)

elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
Expand All @@ -666,12 +669,12 @@ def evaluate_call(
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
)

elif isinstance(call.func, ast.Subscript):
func = evaluate_subscript(call.func, state, static_tools, custom_tools, authorized_imports)
if not callable(func):
raise InterpreterError(f"This is not a correct function: {call.func}).")
func_name = None

args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
Expand Down Expand Up @@ -700,20 +703,15 @@ def evaluate_call(
return super(cls, instance)
else:
raise InterpreterError("super() takes at most 2 arguments")
else:
if func_name == "print":
state["_print_outputs"] += " ".join(map(str, args)) + "\n"
return None
else: # Assume it's a callable object
if (
(inspect.getmodule(func) == builtins)
and inspect.isbuiltin(func)
and (func not in static_tools.values())
):
raise InterpreterError(
f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
)
return func(*args, **kwargs)
elif func_name == "print":
state["_print_outputs"] += " ".join(map(str, args)) + "\n"
return None
else: # Assume it's a callable object
if (inspect.getmodule(func) == builtins) and inspect.isbuiltin(func) and (func not in static_tools.values()):
raise InterpreterError(
f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
)
return func(*args, **kwargs)


def evaluate_subscript(
Expand Down
38 changes: 38 additions & 0 deletions tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,44 @@ def test_call_from_dict(self, code):
result, _, _ = executor(code)
assert result == 11

@pytest.mark.parametrize(
"code",
[
"a = b = 1; a",
"a = b = 1; b",
"a, b = c, d = 1, 1; a",
"a, b = c, d = 1, 1; b",
"a, b = c, d = 1, 1; c",
"a, b = c, d = {1, 2}; a",
"a, b = c, d = {1, 2}; c",
"a, b = c, d = {1: 10, 2: 20}; a",
"a, b = c, d = {1: 10, 2: 20}; c",
"a = b = (lambda: 1)(); b",
"a = b = (lambda: 1)(); lambda x: 10; b",
"a = b = (lambda x: lambda y: x + y)(0)(1); b",
dedent("""
def foo():
return 1;
a = b = foo(); b"""),
dedent("""
def foo(*args, **kwargs):
return sum(args)
a = b = foo(1,-1,1); b"""),
"a, b = 1, 2; a, b = b, a; b",
],
)
def test_chained_assignments(self, code):
executor = LocalPythonExecutor([])
executor.send_tools({})
result, _, _ = executor(code)
assert result == 1

def test_evaluate_assign_error(self):
code = "a, b = 1, 2, 3; a"
executor = LocalPythonExecutor([])
with pytest.raises(InterpreterError, match=".*Cannot unpack tuple of wrong size"):
executor(code)


class TestLocalPythonExecutorSecurity:
@pytest.mark.parametrize(
Expand Down

0 comments on commit b67cc94

Please sign in to comment.