Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(Pipelines) : Smart Data Frame Pipeline #735

Merged
merged 15 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pandasai/helpers/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import glob
from typing import Any
import duckdb
from .path import find_project_root

Expand Down Expand Up @@ -72,3 +73,19 @@ def destroy(self) -> None:
self.connection.close()
for cache_file in glob.glob(f"{self.filepath}.*"):
os.remove(cache_file)

def get_cache_key(self, context: Any) -> str:
"""
Return the cache key for the current conversation.
Returns:
str: The cache key for the current conversation
"""
cache_key = context.memory.get_conversation()

# make the cache key unique for each combination of dfs
for df in context.dfs:
hash = df.column_hash()
cache_key += str(hash)

return cache_key
10 changes: 10 additions & 0 deletions pandasai/pipelines/base_logic_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ class BaseLogicUnit(ABC):
Logic units for pipeline each logic unit should be inherited from this Logic unit
"""

_skip_if: callable

    def __init__(self, skip_if=None):
        """
        Args:
            skip_if (callable, optional): Predicate called by the pipeline
                with its context before this unit runs; when it returns a
                truthy value the unit is skipped. Defaults to None (never
                skip).
        """
        super().__init__()
        self._skip_if = skip_if

@abstractmethod
def execute(self, input: Any, **kwargs) -> Any:
"""
Expand All @@ -22,3 +28,7 @@ def execute(self, input: Any, **kwargs) -> Any:
:return: The result of the execution.
"""
raise NotImplementedError("execute method is not implemented.")

    @property
    def skip_if(self):
        """Return the optional skip predicate supplied at construction (may be None)."""
        return self._skip_if
11 changes: 7 additions & 4 deletions pandasai/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pandasai.helpers.logger import Logger
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes
from ..schemas.df_config import Config
from typing import Any, Optional, List, Union
from .abstract_pipeline import AbstractPipeline
Expand All @@ -22,9 +21,7 @@ class Pipeline(AbstractPipeline):

def __init__(
self,
context: Union[
List[Union[DataFrameType, SmartDataframe]], PipelineContext
] = None,
context: Union[List[Union[DataFrameType, Any]], PipelineContext] = None,
config: Optional[Union[Config, dict]] = None,
steps: Optional[List] = None,
logger: Optional[Logger] = None,
Expand All @@ -40,6 +37,8 @@ def __init__(
"""

if not isinstance(context, PipelineContext):
from pandasai.smart_dataframe import load_smartdataframes

config = Config(**load_config(config))
smart_dfs = load_smartdataframes(context, config)
context = PipelineContext(smart_dfs, config)
Expand Down Expand Up @@ -79,6 +78,10 @@ def run(self, data: Any = None) -> Any:
try:
for index, logic in enumerate(self._steps):
self._logger.log(f"Executing Step {index}: {logic.__class__.__name__}")

if logic.skip_if is not None and logic.skip_if(self._context):
continue
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@milind-sinaptik add log if skip for the debugging purpose

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


data = logic.execute(
data,
logger=self._logger,
Expand Down
27 changes: 22 additions & 5 deletions pandasai/pipelines/pipeline_context.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Any
from pandasai.helpers.cache import Cache

from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.memory import Memory
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.helpers.skills_manager import SkillsManager
from pandasai.schemas.df_config import Config
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes


class PipelineContext:
"""
Pass Context to the pipeline which is accessible to each step via kwargs
"""

_dfs: List[Union[DataFrameType, SmartDataframe]]
_dfs: List[Union[DataFrameType, Any]]
_memory: Memory
_skills: SkillsManager
_cache: Cache
_config: Config
_query_exec_tracker: QueryExecTracker
_intermediate_values: dict

def __init__(
self,
dfs: List[Union[DataFrameType, SmartDataframe]],
dfs: List[Union[DataFrameType, Any]],
config: Optional[Union[Config, dict]] = None,
memory: Memory = None,
skills: SkillsManager = None,
cache: Cache = None,
query_exec_tracker: QueryExecTracker = None,
) -> None:
from pandasai.smart_dataframe import load_smartdataframes

if isinstance(config, dict):
config = Config(**config)

Expand All @@ -35,9 +40,11 @@ def __init__(
self._skills = skills if skills is not None else SkillsManager()
self._cache = cache if cache is not None else Cache()
self._config = config
self._query_exec_tracker = query_exec_tracker
self._intermediate_values = {}

    @property
    def dfs(self) -> List[Union[DataFrameType, Any]]:
        """Dataframes (raw or smart-dataframe wrappers) the pipeline operates on."""
        return self._dfs

@property
Expand All @@ -55,3 +62,13 @@ def cache(self):
@property
def config(self):
return self._config

    @property
    def query_exec_tracker(self):
        """Tracker used by pipeline steps to run and record tagged function calls."""
        return self._query_exec_tracker

    def add_intermediate_value(self, key: str, value: Any):
        """Store ``value`` under ``key`` for sharing between pipeline steps."""
        self._intermediate_values[key] = value

    def get_intermediate_value(self, key: str):
        """Return the value stored under ``key``, or the empty string if absent."""
        return self._intermediate_values.get(key, "")
43 changes: 43 additions & 0 deletions pandasai/pipelines/smart_datalake_chat/cache_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Any
from ...helpers.logger import Logger
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext


class CacheLookup(BaseLogicUnit):
    """
    Cache Lookup of Code Stage

    When caching is enabled and a cached entry exists for the current
    conversation, returns the cached code (recording a ``cache_hit``)
    and flags ``is_present_in_cache``; otherwise returns None so the
    pipeline generates code from scratch.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        This method will return output according to
        Implementation.

        :param input: Your input data.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The cached code on a hit, otherwise None.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        if pipeline_context.config.enable_cache and pipeline_context.cache:
            # Derive the key once instead of recomputing it for the
            # existence check and the tracked retrieval.
            cache_key = pipeline_context.cache.get_cache_key(pipeline_context)

            if pipeline_context.cache.get(cache_key):
                logger.log("Using cached response")
                # Fetch through the tracker so the hit is recorded/timed.
                code = pipeline_context.query_exec_tracker.execute_func(
                    pipeline_context.cache.get,
                    cache_key,
                    tag="cache_hit",
                )
                pipeline_context.add_intermediate_value("is_present_in_cache", True)
                return code
38 changes: 38 additions & 0 deletions pandasai/pipelines/smart_datalake_chat/cache_population.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext


class CachePopulation(BaseLogicUnit):
    """
    Cache Population Stage

    Persists the generated code in the cache (when caching is enabled),
    notifies the configured callback, and forwards the code unchanged to
    the next pipeline step.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        This method will return output according to
        Implementation.

        :param input: The generated code to cache and pass through.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The same code that was received as input.
        """
        pipeline_context: PipelineContext = kwargs.get("context")

        code = input

        if pipeline_context.config.enable_cache and pipeline_context.cache:
            pipeline_context.cache.set(
                pipeline_context.cache.get_cache_key(pipeline_context), code
            )

        if pipeline_context.config.callback is not None:
            # Let consumers observe the final code before it moves on.
            pipeline_context.config.callback.on_code(code)

        return code
120 changes: 120 additions & 0 deletions pandasai/pipelines/smart_datalake_chat/code_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
import traceback
from typing import Any, List
from ...helpers.code_manager import CodeExecutionContext
from ...helpers.logger import Logger
from ..base_logic_unit import BaseLogicUnit
from ..pipeline_context import PipelineContext
from ...prompts.correct_error_prompt import CorrectErrorPrompt


class CodeExecution(BaseLogicUnit):
    """
    Code Execution Stage

    Runs the generated code via the context's code manager; on failure,
    retries up to ``config.max_retries`` times using the error-correction
    framework (LLM-generated fixes), re-raising when correction is
    disabled or retries are exhausted.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        This method will return output according to
        Implementation.

        :param input: The python code to execute.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The result of executing the code, or None if the retry
            budget is consumed without a successful run.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        code = input
        retry_count = 0
        code_to_run = code
        result = None
        while retry_count < pipeline_context.config.max_retries:
            try:
                # Execute the code
                code_context = CodeExecutionContext(
                    pipeline_context.get_intermediate_value("last_prompt_id"),
                    pipeline_context.get_intermediate_value("skills"),
                )
                result = pipeline_context.get_intermediate_value(
                    "code_manager"
                ).execute_code(
                    code=code_to_run,
                    context=code_context,
                )

                break

            except Exception as e:
                # Give up when correction is disabled or this was the
                # last allowed attempt; surface the original failure.
                if (
                    not pipeline_context.config.use_error_correction_framework
                    or retry_count >= pipeline_context.config.max_retries - 1
                ):
                    raise e

                retry_count += 1

                logger.log(
                    f"Failed to execute code with a correction framework "
                    f"[retry number: {retry_count}]",
                    level=logging.WARNING,
                )

                traceback_error = traceback.format_exc()
                # NOTE(review): the correction prompt is built from the
                # original `code`, not the last attempted `code_to_run` —
                # confirm this is intentional.
                [
                    code_to_run,
                    reasoning,
                    answer,
                ] = pipeline_context.query_exec_tracker.execute_func(
                    self._retry_run_code,
                    code,
                    pipeline_context,
                    logger,
                    traceback_error,
                )

                pipeline_context.add_intermediate_value("reasoning", reasoning)
                pipeline_context.add_intermediate_value("answer", answer)

        return result

    def _retry_run_code(
        self, code: str, context: PipelineContext, logger: Logger, e: Exception
    ) -> List:
        """
        A method to retry the code execution with error correction framework.

        Args:
            code (str): A python code
            context (PipelineContext) : Pipeline Context
            logger (Logger) : Logger
            e (Exception): The failure (traceback text) from the last run

        Returns (List): [corrected_code, reasoning, answer] as produced by
            the LLM's ``generate_code``.
        """

        logger.log(f"Failed with error: {e}. Retrying", logging.ERROR)

        default_values = {
            "engine": context.dfs[0].engine,
            "code": code,
            "error_returned": e,
        }
        error_correcting_instruction = context.get_intermediate_value("get_prompt")(
            "correct_error",
            default_prompt=CorrectErrorPrompt(),
            default_values=default_values,
        )

        result = context.config.llm.generate_code(error_correcting_instruction)
        if context.config.callback is not None:
            context.config.callback.on_code(result[0])

        return result
Loading