Skip to content

Commit 5dd26ba

Browse files
Authored by: milind-sinaptik (Milind Lalwani), sourcery-ai[bot], gventuri
authored
refactor(Pipelines) : Smart Data Frame Pipeline (#735)
* refactor(Pipelines) : Smart Data Frame Pipeline * 'Refactored by Sourcery' (#736) Co-authored-by: Sourcery AI <> * refactor(Pipelines) : made changes according to PR review * refactor(Pipelines) : Unit test cases added * refactor(Pipelines) : Unit test cases added * refactor(Pipelines) : Broken test cases fixed * refactor(Pipelines) : Skip logic added and more steps created for Smart Data Lake pipeline * 'Refactored by Sourcery' (#740) Co-authored-by: Sourcery AI <> * refactor: move pipeline logic unit from sdf to pipelines folder * refactor(Pipelines) : Merge conflicts fixed * build: fix .lock file --------- Co-authored-by: Milind Lalwani <milindlalwani@Milinds-MacBook-Air.local> Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> Co-authored-by: Gabriele Venturi <lele.venturi@gmail.com>
1 parent 98d39d3 commit 5dd26ba

21 files changed

+1216
-237
lines changed

pandasai/helpers/cache.py

+17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import glob
3+
from typing import Any
34
import duckdb
45
from .path import find_project_root
56

@@ -72,3 +73,19 @@ def destroy(self) -> None:
7273
self.connection.close()
7374
for cache_file in glob.glob(f"{self.filepath}.*"):
7475
os.remove(cache_file)
76+
77+
def get_cache_key(self, context: Any) -> str:
    """
    Return the cache key for the current conversation.

    The key is the conversation text extended with a hash of each
    dataframe's columns, so the same question asked against a different
    set of dataframes maps to a distinct cache entry.

    Args:
        context (Any): Execution context exposing ``memory`` (with
            ``get_conversation()``) and ``dfs`` (each with ``column_hash()``).

    Returns:
        str: The cache key for the current conversation.
    """
    cache_key = context.memory.get_conversation()

    # Make the cache key unique for each combination of dfs.
    # Build the suffix with join (avoids quadratic += and does not
    # shadow the builtin `hash` as the previous loop did).
    return cache_key + "".join(str(df.column_hash()) for df in context.dfs)

pandasai/pipelines/base_logic_unit.py

+10
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ class BaseLogicUnit(ABC):
77
Logic units for pipeline each logic unit should be inherited from this Logic unit
88
"""
99

10+
_skip_if: callable
11+
12+
def __init__(self, skip_if=None):
    """
    Args:
        skip_if (callable, optional): Predicate called with the pipeline
            context; when it returns True this logic unit is skipped.
            Defaults to None (never skip).
    """
    super().__init__()
    self._skip_if = skip_if
15+
1016
@abstractmethod
1117
def execute(self, input: Any, **kwargs) -> Any:
1218
"""
@@ -22,3 +28,7 @@ def execute(self, input: Any, **kwargs) -> Any:
2228
:return: The result of the execution.
2329
"""
2430
raise NotImplementedError("execute method is not implemented.")
31+
32+
@property
def skip_if(self):
    """Optional skip predicate for this unit; None means the unit always runs."""
    return self._skip_if

pandasai/pipelines/pipeline.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from pandasai.helpers.logger import Logger
66
from pandasai.pipelines.pipeline_context import PipelineContext
77
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
8-
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes
98
from ..schemas.df_config import Config
109
from typing import Any, Optional, List, Union
1110
from .abstract_pipeline import AbstractPipeline
@@ -22,9 +21,7 @@ class Pipeline(AbstractPipeline):
2221

2322
def __init__(
2423
self,
25-
context: Union[
26-
List[Union[DataFrameType, SmartDataframe]], PipelineContext
27-
] = None,
24+
context: Union[List[Union[DataFrameType, Any]], PipelineContext] = None,
2825
config: Optional[Union[Config, dict]] = None,
2926
steps: Optional[List] = None,
3027
logger: Optional[Logger] = None,
@@ -40,6 +37,8 @@ def __init__(
4037
"""
4138

4239
if not isinstance(context, PipelineContext):
40+
from pandasai.smart_dataframe import load_smartdataframes
41+
4342
config = Config(**load_config(config))
4443
smart_dfs = load_smartdataframes(context, config)
4544
context = PipelineContext(smart_dfs, config)
@@ -79,6 +78,10 @@ def run(self, data: Any = None) -> Any:
7978
try:
8079
for index, logic in enumerate(self._steps):
8180
self._logger.log(f"Executing Step {index}: {logic.__class__.__name__}")
81+
82+
if logic.skip_if is not None and logic.skip_if(self._context):
83+
continue
84+
8285
data = logic.execute(
8386
data,
8487
logger=self._logger,

pandasai/pipelines/pipeline_context.py

+22-5
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Optional, Union, Any
22
from pandasai.helpers.cache import Cache
33

44
from pandasai.helpers.df_info import DataFrameType
55
from pandasai.helpers.memory import Memory
6+
from pandasai.helpers.query_exec_tracker import QueryExecTracker
67
from pandasai.helpers.skills_manager import SkillsManager
78
from pandasai.schemas.df_config import Config
8-
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes
99

1010

1111
class PipelineContext:
1212
"""
1313
Pass Context to the pipeline which is accessible to each step via kwargs
1414
"""
1515

16-
_dfs: List[Union[DataFrameType, SmartDataframe]]
16+
_dfs: List[Union[DataFrameType, Any]]
1717
_memory: Memory
1818
_skills: SkillsManager
1919
_cache: Cache
2020
_config: Config
21+
_query_exec_tracker: QueryExecTracker
22+
_intermediate_values: dict
2123

2224
def __init__(
2325
self,
24-
dfs: List[Union[DataFrameType, SmartDataframe]],
26+
dfs: List[Union[DataFrameType, Any]],
2527
config: Optional[Union[Config, dict]] = None,
2628
memory: Memory = None,
2729
skills: SkillsManager = None,
2830
cache: Cache = None,
31+
query_exec_tracker: QueryExecTracker = None,
2932
) -> None:
33+
from pandasai.smart_dataframe import load_smartdataframes
34+
3035
if isinstance(config, dict):
3136
config = Config(**config)
3237

@@ -35,9 +40,11 @@ def __init__(
3540
self._skills = skills if skills is not None else SkillsManager()
3641
self._cache = cache if cache is not None else Cache()
3742
self._config = config
43+
self._query_exec_tracker = query_exec_tracker
44+
self._intermediate_values = {}
3845

3946
@property
40-
def dfs(self) -> List[Union[DataFrameType, SmartDataframe]]:
47+
def dfs(self) -> List[Union[DataFrameType, Any]]:
4148
return self._dfs
4249

4350
@property
@@ -55,3 +62,13 @@ def cache(self):
5562
@property
5663
def config(self):
5764
return self._config
65+
66+
@property
def query_exec_tracker(self):
    """Tracker used to record executed pipeline functions (may be None — it defaults to None in the constructor)."""
    return self._query_exec_tracker
69+
70+
def add_intermediate_value(self, key: str, value: Any):
    """Store *value* under *key* so later pipeline steps can read it."""
    self._intermediate_values[key] = value
72+
73+
def get_intermediate_value(self, key: str):
    """Return the value stored under *key*; an empty string (not None) when the key is absent."""
    return self._intermediate_values.get(key, "")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import Any
2+
from ...helpers.logger import Logger
3+
from ..base_logic_unit import BaseLogicUnit
4+
from ..pipeline_context import PipelineContext
5+
6+
7+
class CacheLookup(BaseLogicUnit):
    """
    Cache Lookup of Code Stage

    Short-circuits code generation by returning a previously cached
    response for the current conversation when one exists.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Look the current conversation up in the cache.

        :param input: Your input data (not used by this unit).
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The cached code on a cache hit; implicitly None on a
            miss, letting subsequent pipeline units generate the code.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")
        if (
            pipeline_context.config.enable_cache
            and pipeline_context.cache
            and pipeline_context.cache.get(
                pipeline_context.cache.get_cache_key(pipeline_context)
            )
        ):
            logger.log("Using cached response")
            # Fetch again through the tracker (second lookup) so the hit
            # is recorded with the "cache_hit" tag.
            code = pipeline_context.query_exec_tracker.execute_func(
                pipeline_context.cache.get,
                pipeline_context.cache.get_cache_key(pipeline_context),
                tag="cache_hit",
            )
            pipeline_context.add_intermediate_value("is_present_in_cache", True)
            return code
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Any
2+
from ..base_logic_unit import BaseLogicUnit
3+
from ..pipeline_context import PipelineContext
4+
5+
6+
class CachePopulation(BaseLogicUnit):
    """
    Cache Population Stage

    Stores the generated code in the cache (when caching is enabled),
    notifies the configured callback, and forwards the code unchanged.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Populate the cache with the generated code.

        :param input: The generated python code.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The same code that was passed in.
        """
        pipeline_context: PipelineContext = kwargs.get("context")

        code = input

        if pipeline_context.config.enable_cache and pipeline_context.cache:
            pipeline_context.cache.set(
                pipeline_context.cache.get_cache_key(pipeline_context), code
            )

        # Surface the generated code to the user-supplied callback, if any.
        if pipeline_context.config.callback is not None:
            pipeline_context.config.callback.on_code(code)

        return code
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import logging
2+
import traceback
3+
from typing import Any, List
4+
from ...helpers.code_manager import CodeExecutionContext
5+
from ...helpers.logger import Logger
6+
from ..base_logic_unit import BaseLogicUnit
7+
from ..pipeline_context import PipelineContext
8+
from ...prompts.correct_error_prompt import CorrectErrorPrompt
9+
10+
11+
class CodeExecution(BaseLogicUnit):
    """
    Code Execution Stage

    Runs the generated python code via the context's code manager,
    retrying through the error-correction framework on failure.
    """

    def execute(self, input: Any, **kwargs) -> Any:
        """
        Execute the generated code, retrying with LLM-based error
        correction up to ``config.max_retries`` times.

        :param input: The python code to execute.
        :param kwargs: A dictionary of keyword arguments.
            - 'logger' (any): The logger for logging.
            - 'config' (Config): Global configurations for the test
            - 'context' (any): The execution context.

        :return: The result produced by the code manager, or None if the
            loop exits without a successful run.
        :raises Exception: Re-raises the execution error when correction
            is disabled or the retry budget is exhausted.
        """
        pipeline_context: PipelineContext = kwargs.get("context")
        logger: Logger = kwargs.get("logger")

        code = input
        retry_count = 0
        code_to_run = code
        result = None
        while retry_count < pipeline_context.config.max_retries:
            try:
                # Execute the code
                code_context = CodeExecutionContext(
                    pipeline_context.get_intermediate_value("last_prompt_id"),
                    pipeline_context.get_intermediate_value("skills"),
                )
                result = pipeline_context.get_intermediate_value(
                    "code_manager"
                ).execute_code(
                    code=code_to_run,
                    context=code_context,
                )

                break

            except Exception as e:
                # Re-raise when correction is disabled or this was the
                # last allowed attempt.
                if (
                    not pipeline_context.config.use_error_correction_framework
                    or retry_count >= pipeline_context.config.max_retries - 1
                ):
                    raise e

                retry_count += 1

                logger.log(
                    f"Failed to execute code with a correction framework "
                    f"[retry number: {retry_count}]",
                    level=logging.WARNING,
                )

                traceback_error = traceback.format_exc()
                # NOTE(review): the retry is seeded with the original
                # `code`, not the last attempted `code_to_run` — confirm
                # this is intentional.
                [
                    code_to_run,
                    reasoning,
                    answer,
                ] = pipeline_context.query_exec_tracker.execute_func(
                    self._retry_run_code,
                    code,
                    pipeline_context,
                    logger,
                    traceback_error,
                )

                # Record inside the except branch: `reasoning`/`answer`
                # only exist after a retry, so writing them here avoids a
                # NameError when the first attempt succeeds.
                pipeline_context.add_intermediate_value("reasoning", reasoning)
                pipeline_context.add_intermediate_value("answer", answer)

        return result

    def _retry_run_code(
        self, code: str, context: PipelineContext, logger: Logger, e: Exception
    ) -> List:
        """
        A method to retry the code execution with error correction framework.

        Args:
            code (str): A python code
            context (PipelineContext) : Pipeline Context
            logger (Logger) : Logger
            e (Exception): An exception

        Returns (List): [corrected code, reasoning, answer] as produced
            by the LLM's ``generate_code``.
        """

        logger.log(f"Failed with error: {e}. Retrying", logging.ERROR)

        default_values = {
            "engine": context.dfs[0].engine,
            "code": code,
            "error_returned": e,
        }
        error_correcting_instruction = context.get_intermediate_value("get_prompt")(
            "correct_error",
            default_prompt=CorrectErrorPrompt(),
            default_values=default_values,
        )

        result = context.config.llm.generate_code(error_correcting_instruction)
        # Surface the corrected code to the user-supplied callback, if any.
        if context.config.callback is not None:
            context.config.callback.on_code(result[0])

        return result

0 commit comments

Comments
 (0)