diff --git a/pandasai/helpers/cache.py b/pandasai/helpers/cache.py
index 7e341d0c3..bd87805df 100644
--- a/pandasai/helpers/cache.py
+++ b/pandasai/helpers/cache.py
@@ -1,5 +1,6 @@
 import os
 import glob
+from typing import Any
 import duckdb
 from .path import find_project_root
@@ -72,3 +73,19 @@ def destroy(self) -> None:
         self.connection.close()
         for cache_file in glob.glob(f"{self.filepath}.*"):
             os.remove(cache_file)
+
+    def get_cache_key(self, context: Any) -> str:
+        """
+        Return the cache key for the current conversation.
+
+        Returns:
+            str: The cache key for the current conversation
+        """
+        cache_key = context.memory.get_conversation()
+
+        # make the cache key unique for each combination of dfs
+        for df in context.dfs:
+            hash = df.column_hash()
+            cache_key += str(hash)
+
+        return cache_key
diff --git a/pandasai/pipelines/base_logic_unit.py b/pandasai/pipelines/base_logic_unit.py
index 5d41a2402..221cabe15 100644
--- a/pandasai/pipelines/base_logic_unit.py
+++ b/pandasai/pipelines/base_logic_unit.py
@@ -7,6 +7,12 @@ class BaseLogicUnit(ABC):
     Logic units for pipeline each logic unit should be inherited from this Logic unit
     """

+    _skip_if: callable
+
+    def __init__(self, skip_if=None):
+        super().__init__()
+        self._skip_if = skip_if
+
     @abstractmethod
     def execute(self, input: Any, **kwargs) -> Any:
         """
@@ -22,3 +28,7 @@ def execute(self, input: Any, **kwargs) -> Any:
         :return: The result of the execution.
         """
         raise NotImplementedError("execute method is not implemented.")
+
+    @property
+    def skip_if(self):
+        return self._skip_if
diff --git a/pandasai/pipelines/pipeline.py b/pandasai/pipelines/pipeline.py
index 6353cc977..6a1bda9ac 100644
--- a/pandasai/pipelines/pipeline.py
+++ b/pandasai/pipelines/pipeline.py
@@ -5,7 +5,6 @@
 from pandasai.helpers.logger import Logger
 from pandasai.pipelines.pipeline_context import PipelineContext
 from pandasai.pipelines.base_logic_unit import BaseLogicUnit
-from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes
 from ..schemas.df_config import Config
 from typing import Any, Optional, List, Union
 from .abstract_pipeline import AbstractPipeline
@@ -22,9 +21,7 @@ class Pipeline(AbstractPipeline):

     def __init__(
         self,
-        context: Union[
-            List[Union[DataFrameType, SmartDataframe]], PipelineContext
-        ] = None,
+        context: Union[List[Union[DataFrameType, Any]], PipelineContext] = None,
         config: Optional[Union[Config, dict]] = None,
         steps: Optional[List] = None,
         logger: Optional[Logger] = None,
@@ -40,6 +37,8 @@ def __init__(
         """
         if not isinstance(context, PipelineContext):
+            from pandasai.smart_dataframe import load_smartdataframes
+
             config = Config(**load_config(config))
             smart_dfs = load_smartdataframes(context, config)
             context = PipelineContext(smart_dfs, config)
@@ -79,6 +78,10 @@ def run(self, data: Any = None) -> Any:
         try:
             for index, logic in enumerate(self._steps):
                 self._logger.log(f"Executing Step {index}: {logic.__class__.__name__}")
+
+                if logic.skip_if is not None and logic.skip_if(self._context):
+                    continue
+
                 data = logic.execute(
                     data,
                     logger=self._logger,
diff --git a/pandasai/pipelines/pipeline_context.py b/pandasai/pipelines/pipeline_context.py
index f38e0738c..8b972c2c5 100644
--- a/pandasai/pipelines/pipeline_context.py
+++ b/pandasai/pipelines/pipeline_context.py
@@ -1,11 +1,11 @@
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Any
 from pandasai.helpers.cache import Cache
 from pandasai.helpers.df_info import DataFrameType
 from pandasai.helpers.memory import Memory
+from pandasai.helpers.query_exec_tracker import QueryExecTracker
 from pandasai.helpers.skills_manager import SkillsManager
 from pandasai.schemas.df_config import Config
-from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes


 class PipelineContext:
@@ -13,20 +13,25 @@ class PipelineContext:
     Pass Context to the pipeline which is accessible to each step via kwargs
     """

-    _dfs: List[Union[DataFrameType, SmartDataframe]]
+    _dfs: List[Union[DataFrameType, Any]]
     _memory: Memory
     _skills: SkillsManager
     _cache: Cache
     _config: Config
+    _query_exec_tracker: QueryExecTracker
+    _intermediate_values: dict

     def __init__(
         self,
-        dfs: List[Union[DataFrameType, SmartDataframe]],
+        dfs: List[Union[DataFrameType, Any]],
         config: Optional[Union[Config, dict]] = None,
         memory: Memory = None,
         skills: SkillsManager = None,
         cache: Cache = None,
+        query_exec_tracker: QueryExecTracker = None,
     ) -> None:
+        from pandasai.smart_dataframe import load_smartdataframes
+
         if isinstance(config, dict):
             config = Config(**config)
@@ -35,9 +40,11 @@ def __init__(
         self._skills = skills if skills is not None else SkillsManager()
         self._cache = cache if cache is not None else Cache()
         self._config = config
+        self._query_exec_tracker = query_exec_tracker
+        self._intermediate_values = {}

     @property
-    def dfs(self) -> List[Union[DataFrameType, SmartDataframe]]:
+    def dfs(self) -> List[Union[DataFrameType, Any]]:
         return self._dfs

     @property
@@ -55,3 +62,13 @@ def cache(self):
     @property
     def config(self):
         return self._config
+
+    @property
+    def query_exec_tracker(self):
+        return self._query_exec_tracker
+
+    def add_intermediate_value(self, key: str, value: Any):
+        self._intermediate_values[key] = value
+
+    def get_intermediate_value(self, key: str):
+        return self._intermediate_values.get(key, "")
diff --git a/pandasai/pipelines/smart_datalake_chat/cache_lookup.py b/pandasai/pipelines/smart_datalake_chat/cache_lookup.py
new file mode 100644
index 000000000..46142b0a7
--- /dev/null
+++ b/pandasai/pipelines/smart_datalake_chat/cache_lookup.py
@@ -0,0 +1,43 @@
+from typing import Any
+from ...helpers.logger import Logger
+from ..base_logic_unit import BaseLogicUnit
+from ..pipeline_context import PipelineContext
+
+
+class CacheLookup(BaseLogicUnit):
+    """
+    Cache Lookup of Code Stage
+    """
+
+    pass
+
+    def execute(self, input: Any, **kwargs) -> Any:
+        """
+        This method will return output according to
+        Implementation.
+
+        :param input: Your input data.
+        :param kwargs: A dictionary of keyword arguments.
+         - 'logger' (any): The logger for logging.
+         - 'config' (Config): Global configurations for the test
+         - 'context' (any): The execution context.
+
+        :return: The result of the execution.
+ """ + pipeline_context: PipelineContext = kwargs.get("context") + logger: Logger = kwargs.get("logger") + if ( + pipeline_context.config.enable_cache + and pipeline_context.cache + and pipeline_context.cache.get( + pipeline_context.cache.get_cache_key(pipeline_context) + ) + ): + logger.log("Using cached response") + code = pipeline_context.query_exec_tracker.execute_func( + pipeline_context.cache.get, + pipeline_context.cache.get_cache_key(pipeline_context), + tag="cache_hit", + ) + pipeline_context.add_intermediate_value("is_present_in_cache", True) + return code diff --git a/pandasai/pipelines/smart_datalake_chat/cache_population.py b/pandasai/pipelines/smart_datalake_chat/cache_population.py new file mode 100644 index 000000000..8d2791a07 --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/cache_population.py @@ -0,0 +1,38 @@ +from typing import Any +from ..base_logic_unit import BaseLogicUnit +from ..pipeline_context import PipelineContext + + +class CachePopulation(BaseLogicUnit): + """ + Cache Population Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. + """ + pipeline_context: PipelineContext = kwargs.get("context") + + code = input + + if pipeline_context.config.enable_cache and pipeline_context.cache: + pipeline_context.cache.set( + pipeline_context.cache.get_cache_key(pipeline_context), code + ) + + if pipeline_context.config.callback is not None: + pipeline_context.config.callback.on_code(code) + + return code diff --git a/pandasai/pipelines/smart_datalake_chat/code_execution.py b/pandasai/pipelines/smart_datalake_chat/code_execution.py new file mode 100644 index 000000000..a9d05f59e --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/code_execution.py @@ -0,0 +1,120 @@ +import logging +import traceback +from typing import Any, List +from ...helpers.code_manager import CodeExecutionContext +from ...helpers.logger import Logger +from ..base_logic_unit import BaseLogicUnit +from ..pipeline_context import PipelineContext +from ...prompts.correct_error_prompt import CorrectErrorPrompt + + +class CodeExecution(BaseLogicUnit): + """ + Code Execution Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. 
+ """ + pipeline_context: PipelineContext = kwargs.get("context") + logger: Logger = kwargs.get("logger") + + code = input + retry_count = 0 + code_to_run = code + result = None + while retry_count < pipeline_context.config.max_retries: + try: + # Execute the code + code_context = CodeExecutionContext( + pipeline_context.get_intermediate_value("last_prompt_id"), + pipeline_context.get_intermediate_value("skills"), + ) + result = pipeline_context.get_intermediate_value( + "code_manager" + ).execute_code( + code=code_to_run, + context=code_context, + ) + + break + + except Exception as e: + if ( + not pipeline_context.config.use_error_correction_framework + or retry_count >= pipeline_context.config.max_retries - 1 + ): + raise e + + retry_count += 1 + + logger.log( + f"Failed to execute code with a correction framework " + f"[retry number: {retry_count}]", + level=logging.WARNING, + ) + + traceback_error = traceback.format_exc() + [ + code_to_run, + reasoning, + answer, + ] = pipeline_context.query_exec_tracker.execute_func( + self._retry_run_code, + code, + pipeline_context, + logger, + traceback_error, + ) + + pipeline_context.add_intermediate_value("reasoning", reasoning) + pipeline_context.add_intermediate_value("answer", answer) + + return result + + def _retry_run_code( + self, code: str, context: PipelineContext, logger: Logger, e: Exception + ) -> List: + """ + A method to retry the code execution with error correction framework. + + Args: + code (str): A python code + context (PipelineContext) : Pipeline Context + logger (Logger) : Logger + e (Exception): An exception + dataframes + + Returns (str): A python code + """ + + logger.log(f"Failed with error: {e}. Retrying", logging.ERROR) + + default_values = { + "engine": context.dfs[0].engine, + "code": code, + "error_returned": e, + } + error_correcting_instruction = context.get_intermediate_value("get_prompt")( + "correct_error", + default_prompt=CorrectErrorPrompt(), + default_values=default_values, + ) + + result = context.config.llm.generate_code(error_correcting_instruction) + if context.config.callback is not None: + context.config.callback.on_code(result[0]) + + return result diff --git a/pandasai/pipelines/smart_datalake_chat/code_generator.py b/pandasai/pipelines/smart_datalake_chat/code_generator.py new file mode 100644 index 000000000..148ec9b2f --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/code_generator.py @@ -0,0 +1,54 @@ +from typing import Any +from ...helpers.logger import Logger +from ..pipeline_context import PipelineContext +from ..base_logic_unit import BaseLogicUnit + + +class CodeGenerator(BaseLogicUnit): + """ + LLM Code Generation Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. 
+ """ + pipeline_context: PipelineContext = kwargs.get("context") + logger: Logger = kwargs.get("logger") + + if self.skip_if is not None and self.skip_if(pipeline_context): + return input + + generate_python_code_instruction = input + + [ + code, + reasoning, + answer, + ] = pipeline_context.query_exec_tracker.execute_func( + pipeline_context.config.llm.generate_code, + generate_python_code_instruction, + ) + pipeline_context.add_intermediate_value("last_code_generated", code) + logger.log( + f"""Code generated: + ``` + {code} + ``` + """ + ) + pipeline_context.add_intermediate_value("last_reasoning", reasoning) + pipeline_context.add_intermediate_value("last_answer", answer) + + return code diff --git a/pandasai/pipelines/smart_datalake_chat/generate_smart_datalake_pipeline.py b/pandasai/pipelines/smart_datalake_chat/generate_smart_datalake_pipeline.py new file mode 100644 index 000000000..feeb2b816 --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/generate_smart_datalake_pipeline.py @@ -0,0 +1,49 @@ +from typing import Optional +from ...helpers.logger import Logger +from ..pipeline import Pipeline +from ..pipeline_context import PipelineContext +from .cache_lookup import CacheLookup +from .cache_population import CachePopulation +from .code_execution import CodeExecution +from .code_generator import CodeGenerator +from .prompt_generation import PromptGeneration +from .result_parsing import ResultParsing +from .result_validation import ResultValidation + + +class GenerateSmartDatalakePipeline: + _pipeline: Pipeline + + def __init__( + self, + context: Optional[PipelineContext] = None, + logger: Optional[Logger] = None, + ): + self._pipeline = Pipeline( + context=context, + logger=logger, + steps=[ + CacheLookup(), + PromptGeneration( + lambda pipeline_context: pipeline_context.get_intermediate_value( + "is_present_in_cache" + ) + ), + CodeGenerator( + lambda pipeline_context: pipeline_context.get_intermediate_value( + "is_present_in_cache" + ) + ), + CachePopulation( + lambda pipeline_context: pipeline_context.get_intermediate_value( + "is_present_in_cache" + ) + ), + CodeExecution(), + ResultValidation(), + ResultParsing(), + ], + ) + + def run(self): + return self._pipeline.run() diff --git a/pandasai/pipelines/smart_datalake_chat/prompt_generation.py b/pandasai/pipelines/smart_datalake_chat/prompt_generation.py new file mode 100644 index 000000000..b81b65797 --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/prompt_generation.py @@ -0,0 +1,54 @@ +from typing import Any +from ..base_logic_unit import BaseLogicUnit +from ..pipeline_context import PipelineContext +from ...prompts.generate_python_code import GeneratePythonCodePrompt + + +class PromptGeneration(BaseLogicUnit): + """ + Code Prompt Generation Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. 
+ """ + pipeline_context: PipelineContext = kwargs.get("context") + + default_values = { + # TODO: find a better way to determine the engine, + "engine": pipeline_context.dfs[0].engine, + "output_type_hint": pipeline_context.get_intermediate_value( + "output_type_helper" + ).template_hint, + "viz_library_type": pipeline_context.get_intermediate_value( + "viz_lib_helper" + ).template_hint, + } + + if ( + pipeline_context.memory.size > 1 + and pipeline_context.memory.count() > 1 + and pipeline_context.get_intermediate_value("last_code_generated") + ): + default_values["current_code"] = pipeline_context.get_intermediate_value( + "last_code_generated" + ) + + return pipeline_context.query_exec_tracker.execute_func( + pipeline_context.get_intermediate_value("get_prompt"), + "generate_python_code", + default_prompt=GeneratePythonCodePrompt(), + default_values=default_values, + ) diff --git a/pandasai/pipelines/smart_datalake_chat/result_parsing.py b/pandasai/pipelines/smart_datalake_chat/result_parsing.py new file mode 100644 index 000000000..8fc6df6c2 --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/result_parsing.py @@ -0,0 +1,52 @@ +from typing import Any +from ..base_logic_unit import BaseLogicUnit +from ..pipeline_context import PipelineContext + + +class ResultParsing(BaseLogicUnit): + + """ + Result Parsing Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. + """ + pipeline_context: PipelineContext = kwargs.get("context") + + result = input + + self._add_result_to_memory(result=result, context=pipeline_context) + + result = pipeline_context.query_exec_tracker.execute_func( + pipeline_context.get_intermediate_value("response_parser").parse, result + ) + return result + + def _add_result_to_memory(self, result: dict, context: PipelineContext): + """ + Add the result to the memory. + + Args: + result (dict): The result to add to the memory + context (PipelineContext) : Pipleline Context + """ + if result is None: + return + + if result["type"] in ["string", "number"]: + context.memory.add(result["value"], False) + elif result["type"] in ["dataframe", "plot"]: + context.memory.add("Ok here it is", False) diff --git a/pandasai/pipelines/smart_datalake_chat/result_validation.py b/pandasai/pipelines/smart_datalake_chat/result_validation.py new file mode 100644 index 000000000..ce5e37e11 --- /dev/null +++ b/pandasai/pipelines/smart_datalake_chat/result_validation.py @@ -0,0 +1,65 @@ +import logging +from typing import Any +from pandasai.helpers.logger import Logger +from ..base_logic_unit import BaseLogicUnit +from ..pipeline_context import PipelineContext + + +class ResultValidation(BaseLogicUnit): + """ + Result Validation Stage + """ + + pass + + def execute(self, input: Any, **kwargs) -> Any: + """ + This method will return output according to + Implementation. + + :param input: Your input data. + :param kwargs: A dictionary of keyword arguments. + - 'logger' (any): The logger for logging. + - 'config' (Config): Global configurations for the test + - 'context' (any): The execution context. + + :return: The result of the execution. 
+ """ + pipeline_context: PipelineContext = kwargs.get("context") + logger: Logger = kwargs.get("logger") + + result = input + if result is not None: + if isinstance(result, dict): + ( + validation_ok, + validation_logs, + ) = pipeline_context.get_intermediate_value( + "output_type_helper" + ).validate(result) + if not validation_ok: + logger.log("\n".join(validation_logs), level=logging.WARNING) + pipeline_context.query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": False, + "message": "Output Validation Failed", + } + ) + else: + pipeline_context.query_exec_tracker.add_step( + { + "type": "Validating Output", + "success": True, + "message": "Output Validation Successful", + } + ) + + pipeline_context.add_intermediate_value("last_result", result) + logger.log(f"Answer: {result}") + + logger.log( + f"Executed in: {pipeline_context.query_exec_tracker.get_execution_time()}s" + ) + + return result diff --git a/pandasai/schemas/df_config.py b/pandasai/schemas/df_config.py index fa100a90e..ca112103d 100644 --- a/pandasai/schemas/df_config.py +++ b/pandasai/schemas/df_config.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, validator, Field -from typing import Optional, List, Any, Dict, Type, TypedDict +from typing import Optional, List, Any, Dict, TypedDict from pandasai.constants import DEFAULT_CHART_DIRECTORY -from pandasai.responses import ResponseParser from ..middlewares.base import Middleware from ..callbacks.base import BaseCallback from ..llm import LLM, LangchainLLM @@ -30,7 +29,8 @@ class Config(BaseModel): max_retries: int = 3 middlewares: List[Middleware] = Field(default_factory=list) callback: Optional[BaseCallback] = None - response_parser: Type[ResponseParser] = None + lazy_load_connector: bool = True + response_parser: Any = None llm: Any = None data_viz_library: Optional[VisualizationLibrary] = None log_server: LogServerConfig = None diff --git a/pandasai/smart_dataframe/__init__.py b/pandasai/smart_dataframe/__init__.py index 30711faa5..aa62f39fb 100644 --- a/pandasai/smart_dataframe/__init__.py +++ b/pandasai/smart_dataframe/__init__.py @@ -755,8 +755,6 @@ def load_smartdataframes( dfs (List[Union[DataFrameType, Any]]): List of dataframes to be used """ - from ..smart_dataframe import SmartDataframe - smart_dfs = [] for df in dfs: if not isinstance(df, SmartDataframe): diff --git a/pandasai/smart_datalake/__init__.py b/pandasai/smart_datalake/__init__.py index 76b7e196f..ca54f73e7 100644 --- a/pandasai/smart_datalake/__init__.py +++ b/pandasai/smart_datalake/__init__.py @@ -20,14 +20,18 @@ import uuid import logging import os -import traceback from pandasai.constants import DEFAULT_CHART_DIRECTORY from pandasai.helpers.skills_manager import SkillsManager +from pandasai.pipelines.pipeline_context import PipelineContext from pandasai.prompts.direct_sql_prompt import DirectSQLPrompt from pandasai.skills import skill from pandasai.helpers.query_exec_tracker import QueryExecTracker -from ..helpers.output_types import output_type_factory -from ..helpers.viz_library_types import viz_lib_type_factory +from ..pipelines.smart_datalake_chat.generate_smart_datalake_pipeline import ( + GenerateSmartDatalakePipeline, +) + +from pandasai.helpers.output_types import output_type_factory +from pandasai.helpers.viz_library_types import viz_lib_type_factory from pandasai.responses.context import Context from pandasai.responses.response_parser import ResponseParser from ..llm.base import LLM @@ -39,9 +43,9 @@ from ..config import load_config from ..prompts.base import 
AbstractPrompt from ..prompts.correct_error_prompt import CorrectErrorPrompt -from ..prompts.generate_python_code import GeneratePythonCodePrompt from typing import Union, List, Any, Optional -from ..helpers.code_manager import CodeExecutionContext, CodeManager +from ..prompts.generate_python_code import GeneratePythonCodePrompt +from ..helpers.code_manager import CodeManager from ..middlewares.base import Middleware from ..helpers.df_info import DataFrameType from ..helpers.path import find_project_root @@ -344,27 +348,7 @@ def _get_prompt( self.logger.log(f"Using prompt: {prompt}") return prompt - def _get_cache_key(self) -> str: - """ - Return the cache key for the current conversation. - - Returns: - str: The cache key for the current conversation - """ - cache_key = self._memory.get_conversation() - - # make the cache key unique for each combination of dfs - for df in self._dfs: - hash = df.column_hash() - cache_key += str(hash) - - # direct flag to separate out caching for different codegen - if self._config.direct_sql: - cache_key += "direct_sql" - - return cache_key - - def chat(self, query: str, output_type: Optional[str] = None) -> str: + def chat(self, query: str, output_type: Optional[str] = None): """ Run a query on the dataframe. @@ -388,29 +372,12 @@ def chat(self, query: str, output_type: Optional[str] = None) -> str: ValueError: If the query is empty """ - if not query: - raise ValueError("Query cannot be empty") - - self._query_exec_tracker.start_new_track() - - self.logger.log(f"Question: {query}") - self.logger.log(f"Running PandasAI with {self._llm.type} LLM...") - - self._assign_prompt_id() - - self._query_exec_tracker.add_query_info( - self._conversation_id, self._instance, query, output_type + pipeline_context = self.prepare_context_for_smart_datalake_pipeline( + query=query, output_type=output_type ) - self._query_exec_tracker.add_dataframes(self._dfs) - - self._memory.add(query, True) - - result_is_valid = False - try: - code = self._generate_code(output_type) - result = self._execute_code(code, output_type) + result = GenerateSmartDatalakePipeline(pipeline_context, self.logger).run() except Exception as exception: self.last_error = str(exception) self._query_exec_tracker.success = False @@ -422,163 +389,7 @@ def chat(self, query: str, output_type: Optional[str] = None) -> str: f"\n{exception}\n" ) - self.logger.log( - f"Executed in: {self._query_exec_tracker.get_execution_time()}s" - ) - - if result_is_valid: - self._add_result_to_memory(result) - else: - self.logger.log( - "The result will not be memorized since it has failed the " - "corresponding validation" - ) - - result = self._query_exec_tracker.execute_func( - self._response_parser.parse, result - ) - - self._query_exec_tracker.success = True - - self._query_exec_tracker.publish() - - return result - - def _generate_code(self, output_type: Optional[str] = None) -> str: - """ - Generate Python code from the query. - - Args: - output_type (Optional[str]): Add a hint for LLM which - type should be returned by `analyze_data()` in generated - code. Possible values: "number", "dataframe", "plot", "string": - * number - specifies that user expects to get a number - as a response object - * dataframe - specifies that user expects to get - pandas/polars dataframe as a response object - * plot - specifies that user expects LLM to build - a plot - * string - specifies that user expects to get text - as a response object - If none `output_type` is specified, the type can be any - of the above or "text". 
- - Returns: - (str): Generated Python code - """ - if ( - self._config.enable_cache - and self._cache - and self._cache.get(self._get_cache_key()) - ): - self.logger.log("Using cached response") - code = self._query_exec_tracker.execute_func( - self._cache.get, self._get_cache_key(), tag="cache_hit" - ) - - else: - default_values = { - # TODO: find a better way to determine the engine, - "engine": self._dfs[0].engine, - "output_type_hint": self._get_output_type_hint(output_type), - "viz_library_type": self._get_viz_library_type(), - } - - if ( - self.memory.size > 1 - and self.memory.count() > 1 - and self._last_code_generated - ): - default_values["current_code"] = self._last_code_generated - - prompt_key, prompt = self._get_chat_prompt() - generate_python_code_instruction = self._query_exec_tracker.execute_func( - self._get_prompt, - key=prompt_key, - default_prompt=prompt, - default_values=default_values, - ) - - [code, reasoning, answer] = self._query_exec_tracker.execute_func( - self._llm.generate_code, generate_python_code_instruction - ) - - self.last_reasoning = reasoning - self.last_answer = answer - - if self._config.enable_cache and self._cache: - self._cache.set(self._get_cache_key(), code) - - if self._config.callback is not None: - self._config.callback.on_code(code) - - self.last_code_generated = code - self.logger.log( - f"""Code generated: -``` -{code} -``` -""" - ) - return code - - def _execute_code(self, code: str, output_type: Optional[str] = None) -> Any: - """ - Execute the generated Python code. - - Args: - code (str): Generated Python code - - Returns: - (Any): Result of executing the code - """ - - retry_count = 0 - code_to_run = code - result = None - while retry_count < self._config.max_retries: - try: - # Execute the code - context = CodeExecutionContext( - self._last_prompt_id, self._skills, self._can_direct_sql - ) - result = self._code_manager.execute_code( - code=code_to_run, - context=context, - ) - - break - - except Exception as e: - if ( - not self._config.use_error_correction_framework - or retry_count >= self._config.max_retries - 1 - ): - raise e - - retry_count += 1 - - self._logger.log( - f"Failed to execute code with a correction framework " - f"[retry number: {retry_count}]", - level=logging.WARNING, - ) - - traceback_error = traceback.format_exc() - [ - code_to_run, - reasoning, - answer, - ] = self._query_exec_tracker.execute_func( - self._retry_run_code, code, traceback_error - ) - - if isinstance(result, dict): - self._validate_output(result, output_type) - - if result is not None: - self.last_result = result - self.logger.log(f"Answer: {result}") + self.update_intermediate_value_post_pipeline_execution(pipeline_context) return result @@ -628,11 +439,25 @@ def _validate_output(self, result: dict, output_type: Optional[str] = None): ) raise ValueError("Output validation failed") - def _get_output_type_hint(self, output_type: Optional[str]) -> str: + def _get_viz_library_type(self) -> str: + """ + Get the visualization library type based on the configured library. + + Returns: + (str): Visualization library type """ - Get the output type hint based on the specified output type. + + viz_lib_helper = viz_lib_type_factory(self._viz_lib, logger=self.logger) + return viz_lib_helper.template_hint + + def prepare_context_for_smart_datalake_pipeline( + self, query: str, output_type: Optional[str] = None + ) -> PipelineContext: + """ + Prepare Pipeline Context to intiate Smart Data Lake Pipeline. 
Args: + query (str): Query to run on the dataframe output_type (Optional[str]): Add a hint for LLM which type should be returned by `analyze_data()` in generated code. Possible values: "number", "dataframe", "plot", "string": @@ -648,34 +473,68 @@ def _get_output_type_hint(self, output_type: Optional[str]) -> str: of the above or "text". Returns: - (str): Output type hint + PipelineContext: The Pipeline Context to be used by Smart Data Lake Pipeline. """ - output_type_helper = output_type_factory(output_type, logger=self.logger) - return output_type_helper.template_hint + self._query_exec_tracker.start_new_track() - def _get_viz_library_type(self) -> str: - """ - Get the visualization library type based on the configured library. + self.logger.log(f"Question: {query}") + self.logger.log(f"Running PandasAI with {self._llm.type} LLM...") - Returns: - (str): Visualization library type - """ + self._assign_prompt_id() + + self._query_exec_tracker.add_query_info( + self._conversation_id, self._instance, query, output_type + ) + + self._query_exec_tracker.add_dataframes(self._dfs) + + self._memory.add(query, True) + output_type_helper = output_type_factory(output_type, logger=self.logger) viz_lib_helper = viz_lib_type_factory(self._viz_lib, logger=self.logger) - return viz_lib_helper.template_hint - def _add_result_to_memory(self, result: dict): + pipeline_context = PipelineContext( + dfs=self.dfs, + config=self.config, + memory=self.memory, + cache=self.cache, + query_exec_tracker=self._query_exec_tracker, + ) + pipeline_context.add_intermediate_value("is_present_in_cache", False) + pipeline_context.add_intermediate_value( + "output_type_helper", output_type_helper + ) + pipeline_context.add_intermediate_value("viz_lib_helper", viz_lib_helper) + pipeline_context.add_intermediate_value( + "last_code_generated", self._last_code_generated + ) + pipeline_context.add_intermediate_value("get_prompt", self._get_prompt) + pipeline_context.add_intermediate_value("last_prompt_id", self.last_prompt_id) + pipeline_context.add_intermediate_value("skills", self._skills) + pipeline_context.add_intermediate_value("code_manager", self._code_manager) + pipeline_context.add_intermediate_value( + "response_parser", self._response_parser + ) + + return pipeline_context + + def update_intermediate_value_post_pipeline_execution( + self, pipeline_context: PipelineContext + ): """ - Add the result to the memory. + After the Smart Data Lake Pipeline has executed, update values of Smart Data Lake object. Args: - result (dict): The result to add to the memory + pipeline_context (PipelineContext): Pipeline Context after the Smart Data Lake pipeline execution + """ - if result["type"] in ["string", "number"]: - self._memory.add(result["value"], False) - elif result["type"] in ["dataframe", "plot"]: - self._memory.add("Ok here it is", False) + self._last_reasoning = pipeline_context.get_intermediate_value("last_reasoning") + self._last_answer = pipeline_context.get_intermediate_value("last_answer") + self._last_code_generated = pipeline_context.get_intermediate_value( + "last_code_generated" + ) + self._last_result = pipeline_context.get_intermediate_value("last_result") def _retry_run_code(self, code: str, e: Exception) -> List: """ diff --git a/poetry.lock b/poetry.lock index a46f37b0a..d2dd4c140 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. 
+# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. [[package]] name = "aiohttp" diff --git a/tests/pipelines/smart_datalake/test_code_execution.py b/tests/pipelines/smart_datalake/test_code_execution.py new file mode 100644 index 000000000..7acadd4e1 --- /dev/null +++ b/tests/pipelines/smart_datalake/test_code_execution.py @@ -0,0 +1,187 @@ +from typing import Optional +from unittest.mock import Mock +import pandas as pd +import pytest +from pandasai.helpers.logger import Logger +from pandasai.helpers.skills_manager import SkillsManager + +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.smart_dataframe import SmartDataframe +from pandasai.pipelines.smart_datalake_chat.code_execution import CodeExecution + + +class TestCodeExecution: + "Unit test for Smart Data Lake Code Execution" + + throw_exception = True + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": True}) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the CodeExecution + code_execution = CodeExecution() + assert isinstance(code_execution, CodeExecution) + + def test_code_execution_successful_with_no_exceptions(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + code_execution = CodeExecution() + + mock_code_manager = Mock() + mock_code_manager.execute_code = Mock(return_value="Mocked Result") + + def mock_intermediate_values(key: str): + if key == "last_prompt_id": + return "Mocked Promt ID" + elif key == "skills": + return SkillsManager() + elif key == "code_manager": + return mock_code_manager + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + + result = code_execution.execute( + input="Test Code", context=context, logger=logger + ) + + assert isinstance(code_execution, CodeExecution) + assert result == "Mocked Result" + + def test_code_execution_unsuccessful_after_retries(self, context, logger): + # Test Flow : Code Execution Successful after retry + code_execution = CodeExecution() + + def mock_execute_code(*args, **kwargs): + raise Exception("Unit test exception") + + mock_code_manager = Mock() + mock_code_manager.execute_code = Mock(side_effect=mock_execute_code) + + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock( + return_value=[ + "Interuppted Code", + "Exception Testing", + "Unsuccessful after Retries", + ] + ) + + def mock_intermediate_values(key: str): + if key == "last_prompt_id": + return "Mocked Promt ID" + elif key == 
"skills": + return SkillsManager() + elif key == "code_manager": + return mock_code_manager + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + + assert isinstance(code_execution, CodeExecution) + + result = None + try: + result = code_execution.execute( + input="Test Code", context=context, logger=logger + ) + except Exception: + assert result is None + + def test_code_execution_successful_at_retry(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + code_execution = CodeExecution() + + def mock_execute_code(*args, **kwargs): + if self.throw_exception is True: + self.throw_exception = False + raise Exception("Unit test exception") + return "Mocked Result after retry" + + mock_code_manager = Mock() + mock_code_manager.execute_code = Mock(side_effect=mock_execute_code) + + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock( + return_value=[ + "Interuppted Code", + "Exception Testing", + "Successful after Retry", + ] + ) + + def mock_intermediate_values(key: str): + if key == "last_prompt_id": + return "Mocked Promt ID" + elif key == "skills": + return SkillsManager() + elif key == "code_manager": + return mock_code_manager + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + + result = code_execution.execute( + input="Test Code", context=context, logger=logger + ) + + assert isinstance(code_execution, CodeExecution) + assert result == "Mocked Result after retry" diff --git a/tests/pipelines/smart_datalake/test_code_generator.py b/tests/pipelines/smart_datalake/test_code_generator.py new file mode 100644 index 000000000..32bc83083 --- /dev/null +++ b/tests/pipelines/smart_datalake/test_code_generator.py @@ -0,0 +1,116 @@ +from typing import Optional +from unittest.mock import Mock +import pandas as pd + +import pytest +from pandasai.helpers.logger import Logger +from pandasai.helpers.output_types import output_type_factory +from pandasai.helpers.viz_library_types import viz_lib_type_factory +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.prompts.generate_python_code import GeneratePythonCodePrompt + +from pandasai.smart_dataframe import SmartDataframe +from pandasai.pipelines.smart_datalake_chat.code_generator import CodeGenerator + + +class TestCodeGenerator: + "Unit test for Smart Data Lake Code Generator" + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": True}) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of 
the CodeGenerator + code_generator = CodeGenerator() + assert isinstance(code_generator, CodeGenerator) + + def test_code_not_found_in_cache(self, context, logger): + # Test Flow : Code Not found in the cache + code_generator = CodeGenerator() + + mock_get_promt = Mock(return_value=GeneratePythonCodePrompt) + + def mock_intermediate_values(key: str): + if key == "output_type_helper": + return output_type_factory("DefaultOutputType") + elif key == "viz_lib_helper": + return viz_lib_type_factory("DefaultVizLibraryType") + elif key == "get_prompt": + return mock_get_promt + + def mock_execute_func(function, *args, **kwargs): + if function == mock_get_promt: + return mock_get_promt() + return ["Mocked LLM Generated Code", "Mocked Reasoning", "Mocked Answer"] + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + context._cache = Mock() + context.cache.get = Mock(return_value=None) + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock(side_effect=mock_execute_func) + + code = code_generator.execute(input=None, context=context, logger=logger) + + assert isinstance(code_generator, CodeGenerator) + assert code == "Mocked LLM Generated Code" diff --git a/tests/pipelines/smart_datalake/test_result_parsing.py b/tests/pipelines/smart_datalake/test_result_parsing.py new file mode 100644 index 000000000..08bf1e1fb --- /dev/null +++ b/tests/pipelines/smart_datalake/test_result_parsing.py @@ -0,0 +1,134 @@ +from typing import Optional +from unittest.mock import Mock +import pandas as pd +import pytest +from pandasai.helpers.logger import Logger + +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.smart_dataframe import SmartDataframe +from pandasai.pipelines.smart_datalake_chat.result_parsing import ResultParsing + + +class TestResultParsing: + "Unit test for Smart Data Lake Result Parsing" + + throw_exception = True + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": True}) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the CodeExecution + result_parsing = ResultParsing() + assert isinstance(result_parsing, ResultParsing) + + def test_result_parsing_successful_with_no_exceptions(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + result_parsing = ResultParsing() + result_parsing._add_result_to_memory = Mock() + mock_response_parser = Mock() + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock( + return_value="Mocked Parsed 
Result" + ) + + def mock_intermediate_values(key: str): + if key == "response_parser": + return mock_response_parser + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + + result = result_parsing.execute( + input="Test Result", context=context, logger=logger + ) + + assert isinstance(result_parsing, ResultParsing) + assert result == "Mocked Parsed Result" + + def test_result_parsing_unsuccessful_with_exceptions(self, context, logger): + # Test Flow : Code Execution Unsuccessful with exceptions + result_parsing = ResultParsing() + result_parsing._add_result_to_memory = Mock() + mock_response_parser = Mock() + + def mock_result_parsing(*args, **kwargs): + raise Exception("Unit test exception") + + context._query_exec_tracker = Mock() + context.query_exec_tracker.execute_func = Mock(side_effect=mock_result_parsing) + + def mock_intermediate_values(key: str): + if key == "response_parser": + return mock_response_parser + + context.get_intermediate_value = Mock(side_effect=mock_intermediate_values) + + result = None + try: + result = result_parsing.execute( + input="Test Result", context=context, logger=logger + ) + except Exception: + assert result is None + assert isinstance(result_parsing, ResultParsing) diff --git a/tests/pipelines/smart_datalake/test_result_validation.py b/tests/pipelines/smart_datalake/test_result_validation.py new file mode 100644 index 000000000..b150a7188 --- /dev/null +++ b/tests/pipelines/smart_datalake/test_result_validation.py @@ -0,0 +1,162 @@ +from typing import Optional +from unittest.mock import Mock +import pandas as pd +import pytest +from pandasai.helpers.logger import Logger + +from pandasai.llm.fake import FakeLLM +from pandasai.pipelines.pipeline_context import PipelineContext +from pandasai.smart_dataframe import SmartDataframe +from pandasai.pipelines.smart_datalake_chat.result_validation import ResultValidation + + +class TestResultValidation: + "Unit test for Smart Data Lake Result Validation" + + throw_exception = True + + @pytest.fixture + def llm(self, output: Optional[str] = None): + return FakeLLM(output=output) + + @pytest.fixture + def sample_df(self): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + + @pytest.fixture + def smart_dataframe(self, llm, sample_df): + return SmartDataframe(sample_df, config={"llm": llm, "enable_cache": True}) + + @pytest.fixture + def config(self, llm): + return {"llm": llm, "enable_cache": True} + + @pytest.fixture + def context(self, sample_df, config): + return PipelineContext([sample_df], config) + + @pytest.fixture + def logger(self): + return Logger(True, False) + + def test_init(self, context, config): + # Test the initialization of the CodeExecution + result_validation = ResultValidation() + assert isinstance(result_validation, ResultValidation) + + def test_result_is_none(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + result_validation = ResultValidation() + + context._query_exec_tracker = Mock() + context.query_exec_tracker.get_execution_time = Mock() + context.query_exec_tracker.add_step = Mock() + + result 
= result_validation.execute(input=None, context=context, logger=logger) + + assert not context.query_exec_tracker.add_step.called + assert isinstance(result_validation, ResultValidation) + assert result is None + + def test_result_is_not_of_dict_type(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + result_validation = ResultValidation() + + context._query_exec_tracker = Mock() + context.query_exec_tracker.get_execution_time = Mock() + context.query_exec_tracker.add_step = Mock() + + result = result_validation.execute( + input="Not Dict Type Result", context=context, logger=logger + ) + + assert not context.query_exec_tracker.add_step.called + assert isinstance(result_validation, ResultValidation) + assert result == "Not Dict Type Result" + + def test_result_is_of_dict_type_and_valid(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + result_validation = ResultValidation() + output_type_helper = Mock() + + context._query_exec_tracker = Mock() + context.query_exec_tracker.get_execution_time = Mock() + context.get_intermediate_value = Mock(return_value=output_type_helper) + output_type_helper.validate = Mock(return_value=(True, "Mocked Logs")) + + result = result_validation.execute( + input={"Mocked": "Result"}, context=context, logger=logger + ) + + context.query_exec_tracker.add_step.assert_called_with( + { + "type": "Validating Output", + "success": True, + "message": "Output Validation Successful", + } + ) + assert isinstance(result_validation, ResultValidation) + assert result == {"Mocked": "Result"} + + def test_result_is_of_dict_type_and_not_valid(self, context, logger): + # Test Flow : Code Execution Successful with no exceptions + result_validation = ResultValidation() + output_type_helper = Mock() + + context._query_exec_tracker = Mock() + context.query_exec_tracker.get_execution_time = Mock() + context.get_intermediate_value = Mock(return_value=output_type_helper) + output_type_helper.validate = Mock(return_value=(False, "Mocked Logs")) + + result = result_validation.execute( + input={"Mocked": "Result"}, context=context, logger=logger + ) + + context.query_exec_tracker.add_step.assert_called_with( + { + "type": "Validating Output", + "success": False, + "message": "Output Validation Failed", + } + ) + assert isinstance(result_validation, ResultValidation) + assert result == {"Mocked": "Result"} diff --git a/tests/test_smartdatalake.py b/tests/test_smartdatalake.py index 3195a53de..83527a49f 100644 --- a/tests/test_smartdatalake.py +++ b/tests/test_smartdatalake.py @@ -266,6 +266,7 @@ def analyze_data(dfs): ```""" ) smart_datalake._llm = llm + smart_datalake._config.llm = llm smart_datalake.config.use_advanced_reasoning_framework = True assert smart_datalake.last_answer is None assert smart_datalake.last_reasoning is None
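
Illustrative sketch (not part of the patch): how the relocated `Cache.get_cache_key` builds its key from the conversation text plus `column_hash()` of every dataframe in the context. The `Fake*` classes below are stand-ins, not pandasai types.

```python
# Stand-ins modelling only the attributes get_cache_key actually touches.
class FakeMemory:
    def get_conversation(self) -> str:
        return "Q: what is the average gdp?"


class FakeDF:
    def __init__(self, column_hash_value: str) -> None:
        self._column_hash_value = column_hash_value

    def column_hash(self) -> str:
        return self._column_hash_value


class FakeContext:
    memory = FakeMemory()
    dfs = [FakeDF("a1b2"), FakeDF("c3d4")]


def get_cache_key(context) -> str:
    # Mirrors the body of Cache.get_cache_key added above: conversation text
    # followed by the column hash of every dataframe in context.dfs.
    cache_key = context.memory.get_conversation()
    for df in context.dfs:
        cache_key += str(df.column_hash())
    return cache_key


print(get_cache_key(FakeContext()))  # "Q: what is the average gdp?a1b2c3d4"
```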
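Illustrative sketch (not part of the patch): the `skip_if` hook added to `BaseLogicUnit` and honoured in `Pipeline.run`. A unit built with a predicate is skipped whenever the predicate is truthy for the pipeline context, which is how `GenerateSmartDatalakePipeline` bypasses `PromptGeneration`, `CodeGenerator` and `CachePopulation` on a cache hit. The `Reverse`/`UpperCase` units and the dict standing in for `PipelineContext` are hypothetical.

```python
from typing import Any

from pandasai.pipelines.base_logic_unit import BaseLogicUnit


class Reverse(BaseLogicUnit):
    def execute(self, input: Any, **kwargs) -> Any:
        return str(input)[::-1]


class UpperCase(BaseLogicUnit):
    def execute(self, input: Any, **kwargs) -> Any:
        return str(input).upper()


context = {"is_present_in_cache": True}  # stand-in for a PipelineContext
steps = [
    # Skipped whenever the result is already cached, like the generation
    # stages in GenerateSmartDatalakePipeline.
    Reverse(skip_if=lambda ctx: ctx["is_present_in_cache"]),
    UpperCase(),
]

data = "hello"
for step in steps:
    # Same guard Pipeline.run applies before each logic unit.
    if step.skip_if is not None and step.skip_if(context):
        continue
    data = step.execute(data)

print(data)  # "HELLO": Reverse was skipped, UpperCase ran
```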
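Illustrative sketch (not part of the patch): the intermediate-value store on `PipelineContext` that the new logic units use to share state instead of reading `SmartDatalake` attributes. It assumes pandasai's `FakeLLM` test helper (also used by the new tests above) so a context can be built without a real LLM.

```python
import pandas as pd

from pandasai.llm.fake import FakeLLM
from pandasai.pipelines.pipeline_context import PipelineContext

df = pd.DataFrame({"country": ["Spain", "Japan"], "gdp": [1181205135360, 4380756541440]})
context = PipelineContext([df], config={"llm": FakeLLM(), "enable_cache": False})

# Values written by one stage (e.g. CodeGenerator writing "last_code_generated")
# are read by later stages and by SmartDatalake after the run.
context.add_intermediate_value("is_present_in_cache", False)
context.add_intermediate_value("last_code_generated", "result = dfs[0]['gdp'].mean()")

assert context.get_intermediate_value("last_code_generated").startswith("result")
# Missing keys fall back to an empty string rather than raising.
assert context.get_intermediate_value("never_set") == ""
```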
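The retry loop in the new `CodeExecution` unit keeps the shape of the removed `SmartDatalake._execute_code`. Below is a stripped-down, generic sketch of that shape, with plain callables standing in for the code manager, tracker and correction prompt:

```python
import logging
import traceback
from typing import Any, Callable


def run_with_retries(
    code: str,
    execute: Callable[[str], Any],
    correct: Callable[[str, str], str],
    max_retries: int = 3,
    use_error_correction: bool = True,
) -> Any:
    """Run `code`; on failure, ask `correct` (the LLM round-trip) for a fixed
    version and try again, up to `max_retries` attempts in total."""
    retry_count = 0
    code_to_run = code
    while retry_count < max_retries:
        try:
            return execute(code_to_run)
        except Exception:
            # Give up when correction is disabled or the retry budget is spent,
            # mirroring CodeExecution.execute above.
            if not use_error_correction or retry_count >= max_retries - 1:
                raise
            retry_count += 1
            logging.warning("Code failed, retry %d", retry_count)
            code_to_run = correct(code, traceback.format_exc())
```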