Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor evaluate ast to improve readability #625

Merged
merged 3 commits into from
Feb 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 41 additions & 54 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,126 +1189,113 @@ def evaluate_ast(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
)
state["_operations_count"] += 1
common_params = (state, static_tools, custom_tools, authorized_imports)
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_assign(expression, *common_params)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_augassign(expression, *common_params)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_call(expression, *common_params)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
)
return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_listcomp(expression, *common_params)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_unaryop(expression, *common_params)
elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_boolop(expression, *common_params)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_binop(expression, *common_params)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_condition(expression, *common_params)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_lambda(expression, *common_params)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_function_def(expression, *common_params)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
keys = (evaluate_ast(k, *common_params) for k in expression.keys)
values = (evaluate_ast(v, *common_params) for v in expression.values)
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_for(expression, *common_params)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_if(expression, *common_params)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.JoinedStr):
return "".join(
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
)
return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
return [evaluate_ast(elt, *common_params) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_name(expression, *common_params)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_subscript(expression, *common_params)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
test_val = evaluate_ast(expression.test, *common_params)
if test_val:
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.body, *common_params)
else:
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.orelse, *common_params)
elif isinstance(expression, ast.Attribute):
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
value = evaluate_ast(expression.value, *common_params)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
if expression.lower is not None
else None,
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
if expression.step is not None
else None,
evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.While):
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_while(expression, *common_params)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_class_def(expression, *common_params)
elif isinstance(expression, ast.Try):
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_try(expression, *common_params)
elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_raise(expression, *common_params)
elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_assert(expression, *common_params)
elif isinstance(expression, ast.With):
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_with(expression, *common_params)
elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
return set((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if expression.value
else None
)
raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
elif isinstance(expression, ast.Pass):
return None
elif isinstance(expression, ast.Delete):
return evaluate_delete(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_delete(expression, *common_params)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
Expand Down
Loading