Skip to content

Commit

Permalink
Use match statement
Browse files Browse the repository at this point in the history
  • Loading branch information
CalOmnie committed Feb 12, 2025
1 parent 15253e0 commit 3ce7c97
Showing 1 changed file with 122 additions and 121 deletions.
243 changes: 122 additions & 121 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,127 +1192,128 @@ def evaluate_ast(
state["_operations_count"] += 1
common_params = (state, static_tools, custom_tools, authorized_imports)
evaluate_ast_partial = partial(evaluate_ast, state=state, static_tools=static_tools, custom_tools=custom_tools, authorized_imports=authorized_imports)
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, *common_params)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, *common_params)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, *common_params)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(map(evaluate_ast_partial, expression.elts))
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, *common_params)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, *common_params)
elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, *common_params)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(expression, *common_params)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(expression, *common_params)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, *common_params)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, *common_params)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = map(evaluate_ast_partial, expression.keys)
values = map(evaluate_ast_partial, expression.values)
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.For):
# For loop -> execute the loop
return evaluate_for(expression, *common_params)
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.If):
# If -> execute the right branch
return evaluate_if(expression, *common_params)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, *common_params)
elif isinstance(expression, ast.JoinedStr):
return "".join(
[str(evaluate_ast_partial(v)) for v in expression.values]
)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [evaluate_ast_partial(elt) for elt in expression.elts]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
return evaluate_name(expression, *common_params)
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, *common_params)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, *common_params)
if test_val:
return evaluate_ast(expression.body, *common_params)
else:
return evaluate_ast(expression.orelse, *common_params)
elif isinstance(expression, ast.Attribute):
value = evaluate_ast(expression.value, *common_params)
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(expression.lower, *common_params)
if expression.lower is not None
else None,
evaluate_ast(expression.upper, *common_params)
if expression.upper is not None
else None,
evaluate_ast(expression.step, *common_params)
if expression.step is not None
else None,
)
elif isinstance(expression, ast.DictComp):
return evaluate_dictcomp(expression, *common_params)
elif isinstance(expression, ast.While):
return evaluate_while(expression, *common_params)
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, *common_params)
elif isinstance(expression, ast.Try):
return evaluate_try(expression, *common_params)
elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, *common_params)
elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, *common_params)
elif isinstance(expression, ast.With):
return evaluate_with(expression, *common_params)
elif isinstance(expression, ast.Set):
return set(map(evaluate_ast_partial, expression.elts))
elif isinstance(expression, ast.Return):
raise ReturnException(
evaluate_ast(expression.value, *common_params)
if expression.value
else None
)
elif isinstance(expression, ast.Pass):
return None
elif isinstance(expression, ast.Delete):
return evaluate_delete(expression, *common_params)
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
match expression:
case ast.Assign():
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, *common_params)
case ast.AugAssign():
return evaluate_augassign(expression, *common_params)
case ast.Call():
# Function call -> we return the value of the function call
return evaluate_call(expression, *common_params)
case ast.Constant():
# Constant -> just return the value
return expression.value
case ast.Tuple():
return tuple(map(evaluate_ast_partial, expression.elts))
case ast.ListComp() | ast.GeneratorExp():
return evaluate_listcomp(expression, *common_params)
case ast.UnaryOp():
return evaluate_unaryop(expression, *common_params)
case ast.Starred():
return evaluate_ast(expression.value, *common_params)
case ast.BoolOp():
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, *common_params)
case ast.Break():
raise BreakException()
case ast.Continue():
raise ContinueException()
case ast.BinOp():
# Binary operation -> execute operation
return evaluate_binop(expression, *common_params)
case ast.Compare():
# Comparison -> evaluate the comparison
return evaluate_condition(expression, *common_params)
case ast.Lambda():
return evaluate_lambda(expression, *common_params)
case ast.FunctionDef():
return evaluate_function_def(expression, *common_params)
case ast.Dict():
# Dict -> evaluate all keys and values
keys = map(evaluate_ast_partial, expression.keys)
values = map(evaluate_ast_partial, expression.values)
return dict(zip(keys, values))
case ast.Expr():
# Expression -> evaluate the content
return evaluate_ast(expression.value, *common_params)
case ast.For():
# For loop -> execute the loop
return evaluate_for(expression, *common_params)
case ast.FormattedValue():
# Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, *common_params)
case ast.If():
# If -> execute the right branch
return evaluate_if(expression, *common_params)
case ast.Index() if hasattr(expression, "Index"):
return evaluate_ast(expression.value, *common_params)
case ast.JoinedStr():
return "".join(
(str(evaluate_ast_partial(v)) for v in expression.values)
)
case ast.List():
# List -> evaluate all elements
return [evaluate_ast_partial(elt) for elt in expression.elts]
case ast.Name():
# Name -> pick up the value in the state
return evaluate_name(expression, *common_params)
case ast.Subscript():
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, *common_params)
case ast.IfExp():
test_val = evaluate_ast(expression.test, *common_params)
if test_val:
return evaluate_ast(expression.body, *common_params)
else:
return evaluate_ast(expression.orelse, *common_params)
case ast.Attribute():
value = evaluate_ast(expression.value, *common_params)
return getattr(value, expression.attr)
case ast.Slice():
return slice(
evaluate_ast(expression.lower, *common_params)
if expression.lower is not None
else None,
evaluate_ast(expression.upper, *common_params)
if expression.upper is not None
else None,
evaluate_ast(expression.step, *common_params)
if expression.step is not None
else None,
)
case ast.DictComp():
return evaluate_dictcomp(expression, *common_params)
case ast.While():
return evaluate_while(expression, *common_params)
case ast.Import() | ast.ImportFrom():
return import_modules(expression, state, authorized_imports)
case ast.ClassDef():
return evaluate_class_def(expression, *common_params)
case ast.Try():
return evaluate_try(expression, *common_params)
case ast.Raise():
return evaluate_raise(expression, *common_params)
case ast.Assert():
return evaluate_assert(expression, *common_params)
case ast.With():
return evaluate_with(expression, *common_params)
case ast.Set():
return set(map(evaluate_ast_partial, expression.elts))
case ast.Return():
raise ReturnException(
evaluate_ast(expression.value, *common_params)
if expression.value
else None
)
case ast.Pass():
return None
case ast.Delete():
return evaluate_delete(expression, *common_params)
case _:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")


class FinalAnswerException(Exception):
Expand Down

0 comments on commit 3ce7c97

Please sign in to comment.