diff --git a/pandasai/constants.py b/pandasai/constants.py index bf0734547..43dccbd4e 100644 --- a/pandasai/constants.py +++ b/pandasai/constants.py @@ -85,21 +85,35 @@ # List of Python packages that are whitelisted for import in generated code WHITELISTED_LIBRARIES = [ - "sklearn", - "statsmodels", "seaborn", - "plotly", - "ggplot", "matplotlib", "numpy", "datetime", "json", - "io", "base64", - "scipy", - "streamlit", - "modin", - "scikit-learn", + "pandas", +] + +# List of restricted libs +RESTRICTED_LIBS = [ + "os", # OS-level operations (file handling, environment variables) + "sys", # System-level access + "subprocess", # Run system commands + "shutil", # File operations, including delete + "multiprocessing", # Spawn new processes + "threading", # Thread-level operations + "socket", # Network connections + "http", # HTTP requests + "ftplib", # FTP connections + "paramiko", # SSH operations + "tempfile", # Create temporary files + "pathlib", # Filesystem path handling + "resource", # Access resource usage limits (system-related) + "ssl", # SSL socket connections + "pickle", # Unsafe object serialization + "ctypes", # C-level interaction with memory + "psutil", # System and process utilities + "io", ] PANDASBI_SETUP_MESSAGE = ( diff --git a/pandasai/helpers/optional.py b/pandasai/helpers/optional.py index 37ccd87ee..f1d7a830a 100644 --- a/pandasai/helpers/optional.py +++ b/pandasai/helpers/optional.py @@ -10,12 +10,16 @@ import warnings from typing import TYPE_CHECKING, List -import matplotlib.pyplot as plt -import numpy as np from pandas.util.version import Version -import pandasai.pandas as pd from pandasai.constants import WHITELISTED_BUILTINS +from pandasai.safe_libs.restricted_base64 import RestrictedBase64 +from pandasai.safe_libs.restricted_datetime import RestrictedDatetime +from pandasai.safe_libs.restricted_json import RestrictedJson +from pandasai.safe_libs.restricted_matplotlib import RestrictedMatplotlib +from pandasai.safe_libs.restricted_numpy import RestrictedNumpy +from pandasai.safe_libs.restricted_pandas import RestrictedPandas +from pandasai.safe_libs.restricted_seaborn import RestrictedSeaborn if TYPE_CHECKING: import types @@ -54,10 +58,7 @@ def get_environment(additional_deps: List[dict]) -> dict: Returns (dict): A dictionary of environment variables """ - return { - "pd": pd, - "plt": plt, - "np": np, + env = { **{ lib["alias"]: ( getattr(import_dependency(lib["module"]), lib["name"]) @@ -73,6 +74,25 @@ def get_environment(additional_deps: List[dict]) -> dict: }, } + env["pd"] = RestrictedPandas() + env["plt"] = RestrictedMatplotlib() + env["np"] = RestrictedNumpy() + + for lib in additional_deps: + if lib["name"] == "seaborn": + env["sns"] = RestrictedSeaborn() + + if lib["name"] == "datetime": + env["datetime"] = RestrictedDatetime() + + if lib["name"] == "json": + env["json"] = RestrictedJson() + + if lib["name"] == "base64": + env["base64"] = RestrictedBase64() + + return env + def import_dependency( name: str, diff --git a/pandasai/pipelines/chat/code_cleaning.py b/pandasai/pipelines/chat/code_cleaning.py index 3effcaed0..5e86975df 100644 --- a/pandasai/pipelines/chat/code_cleaning.py +++ b/pandasai/pipelines/chat/code_cleaning.py @@ -15,7 +15,7 @@ from ...connectors import BaseConnector from ...connectors.sql import SQLConnector -from ...constants import WHITELISTED_BUILTINS, WHITELISTED_LIBRARIES +from ...constants import RESTRICTED_LIBS, WHITELISTED_LIBRARIES from ...exceptions import ( BadImportError, ExecuteSQLQueryNotUsed, @@ -161,6 +161,58 @@ def get_code_to_run(self, code: str, context: CodeExecutionContext) -> Any: return code_to_run def _is_malicious_code(self, code) -> bool: + tree = ast.parse(code) + + # Check for private attributes and access of restricted libs + def check_restricted_access(node): + """Check if the node accesses restricted modules or private attributes.""" + if isinstance(node, ast.Attribute): + attr_chain = [] + while isinstance(node, ast.Attribute): + if node.attr.startswith("_"): + raise MaliciousQueryError( + f"Access to private attribute '{node.attr}' is not allowed." + ) + attr_chain.insert(0, node.attr) + node = node.value + if isinstance(node, ast.Name): + attr_chain.insert(0, node.id) + if any(module in RESTRICTED_LIBS for module in attr_chain): + raise MaliciousQueryError( + f"Restricted access detected in attribute chain: {'.'.join(attr_chain)}" + ) + + elif isinstance(node, ast.Subscript) and isinstance( + node.value, ast.Attribute + ): + check_restricted_access(node.value) + + for node in ast.walk(tree): + # Check 'import ...' statements + if isinstance(node, ast.Import): + for alias in node.names: + sub_module_names = alias.name.split(".") + if any(module in RESTRICTED_LIBS for module in sub_module_names): + raise MaliciousQueryError( + f"Restricted library import detected: {alias.name}" + ) + + # Check 'from ... import ...' statements + elif isinstance(node, ast.ImportFrom): + sub_module_names = node.module.split(".") + if any(module in RESTRICTED_LIBS for module in sub_module_names): + raise MaliciousQueryError( + f"Restricted library import detected: {node.module}" + ) + if any(alias.name in RESTRICTED_LIBS for alias in node.names): + raise MaliciousQueryError( + "Restricted library import detected in 'from ... import ...'" + ) + + # Check attribute access for restricted libraries + elif isinstance(node, (ast.Attribute, ast.Subscript)): + check_restricted_access(node) + dangerous_modules = [ " os", " io", @@ -176,6 +228,7 @@ def _is_malicious_code(self, code) -> bool: "(chr", "b64decode", ] + return any( re.search(r"\b" + re.escape(module) + r"\b", code) for module in dangerous_modules @@ -584,5 +637,9 @@ def _check_imports(self, node: Union[ast.Import, ast.ImportFrom]): ) return - if library not in WHITELISTED_BUILTINS: - raise BadImportError(library) + if library not in WHITELISTED_LIBRARIES: + raise BadImportError( + f"The library '{library}' is not in the list of whitelisted libraries. " + "To learn how to whitelist custom dependencies, visit: " + "https://docs.pandas-ai.com/custom-whitelisted-dependencies#custom-whitelisted-dependencies" + ) diff --git a/pandasai/safe_libs/base_restricted_module.py b/pandasai/safe_libs/base_restricted_module.py new file mode 100644 index 000000000..3067a3aab --- /dev/null +++ b/pandasai/safe_libs/base_restricted_module.py @@ -0,0 +1,27 @@ +class BaseRestrictedModule: + def _wrap_function(self, func): + def wrapper(*args, **kwargs): + # Check for any suspicious arguments that might be used for importing + for arg in args + tuple(kwargs.values()): + if isinstance(arg, str) and any( + module in arg.lower() + for module in ["io", "os", "subprocess", "sys", "importlib"] + ): + raise SecurityError( + f"Potential security risk: '{arg}' is not allowed" + ) + return func(*args, **kwargs) + + return wrapper + + def _wrap_class(self, cls): + class WrappedClass(cls): + def __getattribute__(self, name): + attr = super().__getattribute__(name) + return self._wrap_function(self, attr) if callable(attr) else attr + + return WrappedClass + + +class SecurityError(Exception): + pass diff --git a/pandasai/safe_libs/restricted_base64.py b/pandasai/safe_libs/restricted_base64.py new file mode 100644 index 000000000..eb305885e --- /dev/null +++ b/pandasai/safe_libs/restricted_base64.py @@ -0,0 +1,21 @@ +import base64 + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedBase64(BaseRestrictedModule): + def __init__(self): + self.allowed_functions = [ + "b64encode", # Safe function to encode data into base64 + "b64decode", # Safe function to decode base64 encoded data + ] + + # Bind the allowed functions to the object + for func in self.allowed_functions: + if hasattr(base64, func): + setattr(self, func, self._wrap_function(getattr(base64, func))) + + def __getattr__(self, name): + if name not in self.allowed_functions: + raise AttributeError(f"'{name}' is not allowed in RestrictedBase64") + return getattr(base64, name) diff --git a/pandasai/safe_libs/restricted_datetime.py b/pandasai/safe_libs/restricted_datetime.py new file mode 100644 index 000000000..0fc48290a --- /dev/null +++ b/pandasai/safe_libs/restricted_datetime.py @@ -0,0 +1,64 @@ +import datetime + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedDatetime(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Classes + "date", + "time", + "datetime", + "timedelta", + "tzinfo", + "timezone", + # Constants + "MINYEAR", + "MAXYEAR", + # Time zone constants + "UTC", + # Functions + "now", + "utcnow", + "today", + "fromtimestamp", + "utcfromtimestamp", + "fromordinal", + "combine", + "strptime", + # Timedelta operations + "timedelta", + # Date operations + "weekday", + "isoweekday", + "isocalendar", + "isoformat", + "ctime", + "strftime", + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + # Time operations + "replace", + "tzname", + "dst", + "utcoffset", + # Comparison methods + "min", + "max", + ] + + for attr in self.allowed_attributes: + if hasattr(datetime, attr): + setattr(self, attr, self._wrap_function(getattr(datetime, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedDatetime") + + return getattr(datetime, name) diff --git a/pandasai/safe_libs/restricted_json.py b/pandasai/safe_libs/restricted_json.py new file mode 100644 index 000000000..7f13b6112 --- /dev/null +++ b/pandasai/safe_libs/restricted_json.py @@ -0,0 +1,23 @@ +import json + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedJson(BaseRestrictedModule): + def __init__(self): + self.allowed_functions = [ + "load", + "loads", + "dump", + "dumps", + ] + + # Bind the allowed functions to the object + for func in self.allowed_functions: + if hasattr(json, func): + setattr(self, func, self._wrap_function(getattr(json, func))) + + def __getattr__(self, name): + if name not in self.allowed_functions: + raise AttributeError(f"'{name}' is not allowed in RestrictedJson") + return getattr(json, name) diff --git a/pandasai/safe_libs/restricted_matplotlib.py b/pandasai/safe_libs/restricted_matplotlib.py new file mode 100644 index 000000000..82635bfda --- /dev/null +++ b/pandasai/safe_libs/restricted_matplotlib.py @@ -0,0 +1,76 @@ +import matplotlib.axes as axes +import matplotlib.figure as figure +import matplotlib.pyplot as plt + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedMatplotlib(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Figure and Axes creation + "figure", + "subplots", + "subplot", + # Plotting functions + "plot", + "scatter", + "bar", + "barh", + "hist", + "boxplot", + "violinplot", + "pie", + "errorbar", + "contour", + "contourf", + "imshow", + "pcolor", + "pcolormesh", + # Axis manipulation + "xlabel", + "ylabel", + "title", + "legend", + "xlim", + "ylim", + "axis", + "xticks", + "yticks", + "grid", + "axhline", + "axvline", + # Colorbar + "colorbar", + # Text and annotations + "text", + "annotate", + # Styling + "style", + # Save and show + "show", + "savefig", + # Color maps + "get_cmap", + # 3D plotting + "axes3d", + # Utility functions + "close", + "clf", + "cla", + # Constants + "rcParams", + ] + + for attr in self.allowed_attributes: + if hasattr(plt, attr): + setattr(self, attr, self._wrap_function(getattr(plt, attr))) + + # Special handling for figure and axes + self.Figure = self._wrap_class(figure.Figure) + self.Axes = self._wrap_class(axes.Axes) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedMatplotlib") + return getattr(plt, name) diff --git a/pandasai/safe_libs/restricted_numpy.py b/pandasai/safe_libs/restricted_numpy.py new file mode 100644 index 000000000..855fb70d6 --- /dev/null +++ b/pandasai/safe_libs/restricted_numpy.py @@ -0,0 +1,182 @@ +import numpy as np + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedNumpy(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Array creation + "array", + "zeros", + "ones", + "empty", + "full", + "zeros_like", + "ones_like", + "empty_like", + "full_like", + "eye", + "identity", + "diag", + "arange", + "linspace", + "logspace", + "geomspace", + "fromfunction", + "fromiter", + # Array manipulation + "reshape", + "ravel", + "flatten", + "moveaxis", + "rollaxis", + "swapaxes", + "transpose", + "split", + "hsplit", + "vsplit", + "dsplit", + "stack", + "column_stack", + "dstack", + "row_stack", + "concatenate", + "vstack", + "hstack", + "tile", + "repeat", + # Mathematical operations + "add", + "subtract", + "multiply", + "divide", + "power", + "mod", + "remainder", + "divmod", + "negative", + "positive", + "absolute", + "fabs", + "rint", + "floor", + "ceil", + "trunc", + "exp", + "expm1", + "exp2", + "log", + "log10", + "log2", + "log1p", + "sqrt", + "square", + "cbrt", + "reciprocal", + # Trigonometric functions + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "arctan2", + "hypot", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "deg2rad", + "rad2deg", + # Statistical functions + "mean", + "average", + "median", + "std", + "var", + "min", + "max", + "argmin", + "argmax", + "sum", + "prod", + "percentile", + "quantile", + "histogram", + "histogram2d", + "histogramdd", + "bincount", + "digitize", + # Linear algebra + "dot", + "vdot", + "inner", + "outer", + "matmul", + "tensordot", + "einsum", + "trace", + "diagonal", + # Sorting and searching + "sort", + "argsort", + "partition", + "argpartition", + "searchsorted", + "nonzero", + "where", + "extract", + # Logic functions + "all", + "any", + "greater", + "greater_equal", + "less", + "less_equal", + "equal", + "not_equal", + "logical_and", + "logical_or", + "logical_not", + "logical_xor", + "isfinite", + "isinf", + "isnan", + "isneginf", + "isposinf", + # Set operations + "unique", + "intersect1d", + "union1d", + "setdiff1d", + "setxor1d", + # Basic array information + "shape", + "size", + "ndim", + "dtype", + # Utility functions + "clip", + "round", + "sign", + "conj", + "real", + "imag", + "copy", + "asarray", + "asanyarray", + "ascontiguousarray", + "asfortranarray", + ] + + for attr in self.allowed_attributes: + if hasattr(np, attr): + setattr(self, attr, self._wrap_function(getattr(np, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedNumPy") + return getattr(np, name) diff --git a/pandasai/safe_libs/restricted_pandas.py b/pandasai/safe_libs/restricted_pandas.py new file mode 100644 index 000000000..75e5a083c --- /dev/null +++ b/pandasai/safe_libs/restricted_pandas.py @@ -0,0 +1,110 @@ +import pandas as pd + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedPandas(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # DataFrame creation and basic operations + "DataFrame", + "Series", + "concat", + "merge", + "join", + # Data manipulation + "groupby", + "pivot", + "pivot_table", + "melt", + "crosstab", + "cut", + "qcut", + "get_dummies", + "factorize", + # Indexing and selection + "loc", + "iloc", + "at", + "iat", + # Function application + "apply", + "applymap", + "pipe", + # Reshaping and sorting + "sort_values", + "sort_index", + "nlargest", + "nsmallest", + "rank", + "reindex", + "reset_index", + "set_index", + # Computations / descriptive stats + "sum", + "prod", + "min", + "max", + "mean", + "median", + "var", + "std", + "sem", + "skew", + "kurt", + "quantile", + "count", + "nunique", + "value_counts", + "describe", + "cov", + "corr", + # Date functionality + "to_datetime", + "date_range", + # String methods + "str", + # Categorical methods + "Categorical", + "cut", + "qcut", + # Plotting (if visualization is allowed) + "plot", + # Utility functions + "isnull", + "notnull", + "isna", + "notna", + "fillna", + "dropna", + "replace", + "astype", + "copy", + "drop_duplicates", + # Window functions + "rolling", + "expanding", + "ewm", + # Time series functionality + "resample", + "shift", + "diff", + "pct_change", + # Aggregation + "agg", + "aggregate", + ] + + for attr in self.allowed_attributes: + if hasattr(pd, attr): + setattr(self, attr, self._wrap_function(getattr(pd, attr))) + elif attr in ["loc", "iloc", "at", "iat"]: + # These are properties, not functions + setattr( + self, attr, property(lambda self, a=attr: getattr(pd.DataFrame, a)) + ) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedPandas") + return getattr(pd, name) diff --git a/pandasai/safe_libs/restricted_seaborn.py b/pandasai/safe_libs/restricted_seaborn.py new file mode 100644 index 000000000..a5ef4c6e8 --- /dev/null +++ b/pandasai/safe_libs/restricted_seaborn.py @@ -0,0 +1,74 @@ +import seaborn as sns + +from .base_restricted_module import BaseRestrictedModule + + +class RestrictedSeaborn(BaseRestrictedModule): + def __init__(self): + self.allowed_attributes = [ + # Plot functions + "scatterplot", + "lineplot", + "relplot", + "displot", + "histplot", + "kdeplot", + "ecdfplot", + "rugplot", + "distplot", + "boxplot", + "violinplot", + "boxenplot", + "stripplot", + "swarmplot", + "barplot", + "countplot", + "heatmap", + "clustermap", + "regplot", + "lmplot", + "residplot", + "jointplot", + "pairplot", + "catplot", + # Axis styling + "set_style", + "set_context", + "set_palette", + "despine", + "move_legend", + "axes_style", + "plotting_context", + # Color palette functions + "color_palette", + "palplot", + "cubehelix_palette", + "light_palette", + "dark_palette", + "diverging_palette", + # Utility functions + "load_dataset", + # Figure-level interface + "FacetGrid", + "PairGrid", + "JointGrid", + # Regression and statistical estimation + "lmplot", + "regplot", + "residplot", + # Matrix plots + "heatmap", + "clustermap", + # Miscellaneous + "kdeplot", + "rugplot", + ] + + for attr in self.allowed_attributes: + if hasattr(sns, attr): + setattr(self, attr, self._wrap_function(getattr(sns, attr))) + + def __getattr__(self, name): + if name not in self.allowed_attributes: + raise AttributeError(f"'{name}' is not allowed in RestrictedSeaborn") + return getattr(sns, name) diff --git a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py index 8ccad9efd..49169f373 100644 --- a/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py +++ b/tests/unit_tests/pipelines/smart_datalake/test_code_cleaning.py @@ -188,23 +188,17 @@ def test_run_code_invalid_code( with pytest.raises(Exception): code_cleaning.execute("1 +", context=context, logger=logger) - def test_clean_code_remove_builtins( + def test_clean_code_raise_not_whitelisted_lib( self, code_cleaning: CodeCleaning, context: PipelineContext, logger: Logger, ): - builtins_code = """import set + builtins_code = """import scipy result = {'type': 'number', 'value': set([1, 2, 3])}""" - output = code_cleaning.execute(builtins_code, context=context, logger=logger) - - assert ( - output.output == """result = {'type': 'number', 'value': set([1, 2, 3])}""" - ) - assert isinstance(output, LogicUnitOutput) - assert output.success - assert output.message == "Code Cleaned Successfully" + with pytest.raises(BadImportError): + code_cleaning.execute(builtins_code, context=context, logger=logger) def test_clean_code_removes_jailbreak_code( self, @@ -215,12 +209,8 @@ def test_clean_code_removes_jailbreak_code( malicious_code = """__builtins__['str'].__class__.__mro__[-1].__subclasses__()[140].__init__.__globals__['system']('ls') print('hello world')""" - output = code_cleaning.execute(malicious_code, context=context, logger=logger) - - assert output.output == """print('hello world')""" - assert isinstance(output, LogicUnitOutput) - assert output.success - assert output.message == "Code Cleaned Successfully" + with pytest.raises(MaliciousQueryError): + code_cleaning.execute(malicious_code, context=context, logger=logger) def test_clean_code_remove_environment_defaults( self, @@ -900,3 +890,41 @@ def cs_table_name(self): node = code_cleaning._validate_and_make_table_name_case_sensitive(mock_node) assert node.value.args[0].value == 'SELECT COUNT(*) AS user_count FROM "Users"' + + def test_clean_code_raise_private_variable_access_error( + self, + code_cleaning: CodeCleaning, + context: PipelineContext, + logger: Logger, + ): + malicious_code = """ +import scipy +result = {"type": "string", "value": f"{scipy.sparse._sputils.sys.modules['subprocess'].run(['cmd', '/c', 'dir'], text=True, capture_output=True).stdout}"} +print(result) +""" + with pytest.raises(MaliciousQueryError): + code_cleaning.execute(malicious_code, context=context, logger=logger) + + def test_clean_code_raise_import_with_restricted_modules( + self, + code_cleaning: CodeCleaning, + context: PipelineContext, + logger: Logger, + ): + malicious_code = """ +from datetime import sys +""" + with pytest.raises(MaliciousQueryError): + code_cleaning.execute(malicious_code, context=context, logger=logger) + + def test_clean_code_raise_import_with_restricted_using_import_statement( + self, + code_cleaning: CodeCleaning, + context: PipelineContext, + logger: Logger, + ): + malicious_code = """ +import datetime.sys as spy +""" + with pytest.raises(MaliciousQueryError): + code_cleaning.execute(malicious_code, context=context, logger=logger)