From 437f9498df042fabd448d3c03b0cb63a8d58a870 Mon Sep 17 00:00:00 2001 From: Arslan Saleem Date: Mon, 18 Nov 2024 14:35:19 +0100 Subject: [PATCH] =?UTF-8?q?fix[output=5Fformat]:=20accept=20dataframe=20di?= =?UTF-8?q?ct=20as=20output=20and=20secure=20sql=20qu=E2=80=A6=20(#1432)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix[output_format]: accept dataframe dict as output and secure sql query execution * fix: ruff errors --- pandasai/connectors/sql.py | 2 +- pandasai/helpers/output_validator.py | 4 ++-- pandasai/responses/response_parser.py | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pandasai/connectors/sql.py b/pandasai/connectors/sql.py index e1494ba59..68638e8a2 100644 --- a/pandasai/connectors/sql.py +++ b/pandasai/connectors/sql.py @@ -441,7 +441,7 @@ def execute_direct_sql_query(self, sql_query): if not self._is_sql_query_safe(sql_query): raise MaliciousQueryError("Malicious query is generated in code") - return pd.read_sql(sql_query, self._connection) + return pd.read_sql(text(sql_query), self._connection) @property def cs_table_name(self): diff --git a/pandasai/helpers/output_validator.py b/pandasai/helpers/output_validator.py index e26bcf2ff..56a3a495d 100644 --- a/pandasai/helpers/output_validator.py +++ b/pandasai/helpers/output_validator.py @@ -56,7 +56,7 @@ def validate_value(self, expected_type: str) -> bool: elif expected_type == "string": return isinstance(self, str) elif expected_type == "dataframe": - return isinstance(self, (pd.DataFrame, pd.Series)) + return isinstance(self, (pd.DataFrame, pd.Series, dict)) elif expected_type == "plot": if not isinstance(self, (str, dict)): return False @@ -82,7 +82,7 @@ def validate_result(result: dict) -> bool: elif result["type"] == "string": return isinstance(result["value"], str) elif result["type"] == "dataframe": - return isinstance(result["value"], (pd.DataFrame, pd.Series)) + return isinstance(result["value"], 
(pd.DataFrame, pd.Series, dict))
         elif result["type"] == "plot":
             if "plotly" in repr(type(result["value"])):
                 return True
diff --git a/pandasai/responses/response_parser.py b/pandasai/responses/response_parser.py
index fd202784d..4254c77ec 100644
--- a/pandasai/responses/response_parser.py
+++ b/pandasai/responses/response_parser.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
+import pandas as pd
 from PIL import Image
 
 from pandasai.exceptions import MethodNotImplementedError
@@ -51,9 +52,20 @@ def parse(self, result: dict) -> Any:
 
         if result["type"] == "plot":
             return self.format_plot(result)
+        elif result["type"] == "dataframe":
+            return self.format_dataframe(result)
         else:
             return result["value"]
 
+    def format_dataframe(self, result: dict) -> Any:
+        """
+        Return the dataframe result, building a DataFrame from a dict value.
+        """
+        if isinstance(result["value"], dict):
+            result["value"] = pd.DataFrame(result["value"])
+
+        return result["value"]
+
     def format_plot(self, result: dict) -> Any:
         """
         Display matplotlib plot against a user query.