diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py index d8fe84e69..a7b2fb3f6 100644 --- a/src/smolagents/local_python_executor.py +++ b/src/smolagents/local_python_executor.py @@ -1179,126 +1179,113 @@ def evaluate_ast( f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations." ) state["_operations_count"] += 1 + common_params = (state, static_tools, custom_tools, authorized_imports) if isinstance(expression, ast.Assign): # Assignment -> we evaluate the assignment which should update the state # We return the variable assigned as it may be used to determine the final result. - return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_assign(expression, *common_params) elif isinstance(expression, ast.AugAssign): - return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_augassign(expression, *common_params) elif isinstance(expression, ast.Call): # Function call -> we return the value of the function call - return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_call(expression, *common_params) elif isinstance(expression, ast.Constant): # Constant -> just return the value return expression.value elif isinstance(expression, ast.Tuple): - return tuple( - evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts - ) + return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts)) elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)): - return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_listcomp(expression, *common_params) elif isinstance(expression, ast.UnaryOp): - return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_unaryop(expression, *common_params) elif isinstance(expression, ast.Starred): - return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.value, *common_params) elif isinstance(expression, ast.BoolOp): # Boolean operation -> evaluate the operation - return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_boolop(expression, *common_params) elif isinstance(expression, ast.Break): raise BreakException() elif isinstance(expression, ast.Continue): raise ContinueException() elif isinstance(expression, ast.BinOp): # Binary operation -> execute operation - return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_binop(expression, *common_params) elif isinstance(expression, ast.Compare): # Comparison -> evaluate the comparison - return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_condition(expression, *common_params) elif isinstance(expression, ast.Lambda): - return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_lambda(expression, *common_params) elif isinstance(expression, ast.FunctionDef): - return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_function_def(expression, *common_params) elif isinstance(expression, ast.Dict): # Dict -> evaluate all keys and values - keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys] - values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values] + keys = (evaluate_ast(k, *common_params) for k in expression.keys) + values = (evaluate_ast(v, *common_params) for v in expression.values) return dict(zip(keys, values)) elif isinstance(expression, ast.Expr): # Expression -> evaluate the content - return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.value, *common_params) elif isinstance(expression, ast.For): # For loop -> execute the loop - return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_for(expression, *common_params) elif isinstance(expression, ast.FormattedValue): # Formatted value (part of f-string) -> evaluate the content and return - return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.value, *common_params) elif isinstance(expression, ast.If): # If -> execute the right branch - return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_if(expression, *common_params) elif hasattr(ast, "Index") and isinstance(expression, ast.Index): - return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.value, *common_params) elif isinstance(expression, ast.JoinedStr): - return "".join( - [str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values] - ) + return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values]) elif isinstance(expression, ast.List): # List -> evaluate all elements - return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts] + return [evaluate_ast(elt, *common_params) for elt in expression.elts] elif isinstance(expression, ast.Name): # Name -> pick up the value in the state - return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_name(expression, *common_params) elif isinstance(expression, ast.Subscript): # Subscript -> return the value of the indexing - return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_subscript(expression, *common_params) elif isinstance(expression, ast.IfExp): - test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports) + test_val = evaluate_ast(expression.test, *common_params) if test_val: - return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.body, *common_params) else: - return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports) + return evaluate_ast(expression.orelse, *common_params) elif isinstance(expression, ast.Attribute): - value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) + value = evaluate_ast(expression.value, *common_params) return getattr(value, expression.attr) elif isinstance(expression, ast.Slice): return slice( - evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports) - if expression.lower is not None - else None, - evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports) - if expression.upper is not None - else None, - evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports) - if expression.step is not None - else None, + evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None, + evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None, + evaluate_ast(expression.step, *common_params) if expression.step is not None else None, ) elif isinstance(expression, ast.DictComp): - return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_dictcomp(expression, *common_params) elif isinstance(expression, ast.While): - return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_while(expression, *common_params) elif isinstance(expression, (ast.Import, ast.ImportFrom)): return import_modules(expression, state, authorized_imports) elif isinstance(expression, ast.ClassDef): - return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_class_def(expression, *common_params) elif isinstance(expression, ast.Try): - return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_try(expression, *common_params) elif isinstance(expression, ast.Raise): - return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_raise(expression, *common_params) elif isinstance(expression, ast.Assert): - return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_assert(expression, *common_params) elif isinstance(expression, ast.With): - return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_with(expression, *common_params) elif isinstance(expression, ast.Set): - return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts} + return set((evaluate_ast(elt, *common_params) for elt in expression.elts)) elif isinstance(expression, ast.Return): - raise ReturnException( - evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports) - if expression.value - else None - ) + raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None) elif isinstance(expression, ast.Pass): return None elif isinstance(expression, ast.Delete): - return evaluate_delete(expression, state, static_tools, custom_tools, authorized_imports) + return evaluate_delete(expression, *common_params) else: # For now we refuse anything else. Let's add things as we need them. raise InterpreterError(f"{expression.__class__.__name__} is not supported.")