Skip to content

Commit

Permalink
Make local Python interpreter safer by checking if returns dangerous …
Browse files Browse the repository at this point in the history
…modules (#861)
  • Loading branch information
albertvillanova authored Mar 4, 2025
1 parent 43d0d59 commit 8849b95
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
27 changes: 25 additions & 2 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def custom_print(*args):
"multiprocessing",
)

DANGEROUS_MODULES = [
"os",
"subprocess",
"pty",
"shutil",
"sys",
"pathlib",
"io",
"socket",
"multiprocessing",
]


class PrintContainer:
def __init__(self):
Expand Down Expand Up @@ -225,12 +237,23 @@ def safer_eval(func: Callable):
"""

@wraps(func)
def _check_return(*args, **kwargs):
result = func(*args, **kwargs)
def _check_return(
expression,
state,
static_tools,
custom_tools,
authorized_imports=BASE_BUILTIN_MODULES,
):
result = func(expression, state, static_tools, custom_tools, authorized_imports=authorized_imports)
if (isinstance(result, ModuleType) and result is builtins) or (
isinstance(result, dict) and result == vars(builtins)
):
raise InterpreterError("Forbidden return value: builtins")
if isinstance(result, ModuleType):
if "*" not in authorized_imports:
for module in DANGEROUS_MODULES:
if module not in authorized_imports and result is import_module(module):
raise InterpreterError(f"Forbidden return value: {module}")
return result

return _check_return
Expand Down
38 changes: 37 additions & 1 deletion tests/test_local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ def test_vulnerability_builtins(self):

@pytest.mark.parametrize(
"additional_authorized_imports, expectation",
[([], pytest.raises(InterpreterError)), (["sys"], does_not_raise())],
[([], pytest.raises(InterpreterError)), (["os", "sys"], does_not_raise())],
)
def test_vulnerability_via_sys(self, additional_authorized_imports, expectation):
executor = LocalPythonExecutor(additional_authorized_imports)
Expand All @@ -1450,6 +1450,22 @@ def test_vulnerability_via_sys(self, additional_authorized_imports, expectation)
)
)

@pytest.mark.parametrize(
"additional_authorized_imports, expectation",
[(["importlib"], pytest.raises(InterpreterError)), (["importlib", "os"], does_not_raise())],
)
def test_vulnerability_via_importlib(self, additional_authorized_imports, expectation):
executor = LocalPythonExecutor(additional_authorized_imports)
with expectation:
executor(
dedent(
"""
import importlib
importlib.import_module("os").system(":")
"""
)
)

@pytest.mark.parametrize(
"code",
[
Expand Down Expand Up @@ -1500,3 +1516,23 @@ def test_vulnerability_builtins_via_class_catch_warnings(self):
"""
)
)

@pytest.mark.parametrize(
"additional_authorized_imports, expectation",
[([], pytest.raises(InterpreterError)), (["os"], does_not_raise())],
)
def test_vulnerability_load_module_via_builtin_importer(self, additional_authorized_imports, expectation):
executor = LocalPythonExecutor(additional_authorized_imports)
with expectation:
executor(
dedent(
"""
classes = {}.__class__.__base__.__subclasses__()
for cls in classes:
if cls.__name__ == "BuiltinImporter":
break
os_module = cls().load_module("os")
os_module.system(":")
"""
)
)

0 comments on commit 8849b95

Please sign in to comment.