Skip to content

Commit

Permalink
Refactor evaluate ast to improve readability (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
CalOmnie authored Feb 18, 2025
1 parent f631c75 commit 4e05fab
Showing 1 changed file with 41 additions and 54 deletions.
95 changes: 41 additions & 54 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,126 +1179,113 @@ def evaluate_ast(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
)
state["_operations_count"] += 1
common_params = (state, static_tools, custom_tools, authorized_imports)
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_assign(expression, *common_params)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_augassign(expression, *common_params)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_call(expression, *common_params)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(
evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts
)
return tuple((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_listcomp(expression, *common_params)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_unaryop(expression, *common_params)
elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_boolop(expression, *common_params)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_binop(expression, *common_params)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_condition(expression, *common_params)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_lambda(expression, *common_params)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_function_def(expression, *common_params)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, static_tools, custom_tools, authorized_imports) for k in expression.keys]
values = [evaluate_ast(v, state, static_tools, custom_tools, authorized_imports) for v in expression.values]
keys = (evaluate_ast(k, *common_params) for k in expression.keys)
values = (evaluate_ast(v, *common_params) for v in expression.values)
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_for(expression, *common_params)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_if(expression, *common_params)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.JoinedStr):
return "".join(
[str(evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)) for v in expression.values]
)
return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts]
return [evaluate_ast(elt, *common_params) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_name(expression, *common_params)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_subscript(expression, *common_params)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, static_tools, custom_tools, authorized_imports)
test_val = evaluate_ast(expression.test, *common_params)
if test_val:
return evaluate_ast(expression.body, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.body, *common_params)
else:
return evaluate_ast(expression.orelse, state, static_tools, custom_tools, authorized_imports)
return evaluate_ast(expression.orelse, *common_params)
elif isinstance(expression, ast.Attribute):
value = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
value = evaluate_ast(expression.value, *common_params)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(expression.lower, state, static_tools, custom_tools, authorized_imports)
if expression.lower is not None
else None,
evaluate_ast(expression.upper, state, static_tools, custom_tools, authorized_imports)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools, authorized_imports)
if expression.step is not None
else None,
evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.While):
return evaluate_while(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_while(expression, *common_params)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_class_def(expression, *common_params)
elif isinstance(expression, ast.Try):
return evaluate_try(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_try(expression, *common_params)
elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_raise(expression, *common_params)
elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_assert(expression, *common_params)
elif isinstance(expression, ast.With):
return evaluate_with(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_with(expression, *common_params)
elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports) for elt in expression.elts}
return set((evaluate_ast(elt, *common_params) for elt in expression.elts))
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)
if expression.value
else None
)
raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
elif isinstance(expression, ast.Pass):
return None
elif isinstance(expression, ast.Delete):
return evaluate_delete(expression, state, static_tools, custom_tools, authorized_imports)
return evaluate_delete(expression, *common_params)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
Expand Down

0 comments on commit 4e05fab

Please sign in to comment.