-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy path pipeline_context.py
74 lines (59 loc) · 2.08 KB
/
pipeline_context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import List, Optional, Union, Any
from pandasai.helpers.cache import Cache
from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.memory import Memory
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.helpers.skills_manager import SkillsManager
from pandasai.schemas.df_config import Config
class PipelineContext:
    """
    Pass Context to the pipeline which is accessible to each step via kwargs.

    Bundles the dataframes, configuration, memory, skills manager, cache and
    (optional) query execution tracker for a pipeline run, together with a
    scratch dictionary of intermediate values that steps can write and read.
    """

    # Sentinel that distinguishes "no default supplied" from an explicit
    # ``None`` default in ``get_intermediate_value``.
    _MISSING = object()

    _dfs: List[Union[DataFrameType, Any]]
    _memory: Memory
    _skills: SkillsManager
    _cache: Cache
    # ``None`` when no config was passed to the constructor.
    _config: Optional[Config]
    # ``None`` when no tracker was passed to the constructor.
    _query_exec_tracker: Optional[QueryExecTracker]
    # Values stored by pipeline steps, keyed by name.
    _intermediate_values: dict

    def __init__(
        self,
        dfs: List[Union[DataFrameType, Any]],
        config: Optional[Union[Config, dict]] = None,
        memory: Optional[Memory] = None,
        skills: Optional[SkillsManager] = None,
        cache: Optional[Cache] = None,
        query_exec_tracker: Optional[QueryExecTracker] = None,
    ) -> None:
        """
        Build the context for a pipeline run.

        Args:
            dfs: Dataframes to expose to the pipeline. They are passed through
                ``load_smartdataframes`` together with ``config``.
            config: A ``Config`` instance, a dict of ``Config`` fields
                (converted with ``Config(**config)``), or ``None``.
            memory: Memory instance; a new ``Memory()`` is created if omitted.
            skills: Skills manager; a new ``SkillsManager()`` is created if
                omitted.
            cache: Cache instance; a new ``Cache()`` is created if omitted.
            query_exec_tracker: Tracker stored as-is; may be ``None``.
        """
        # Imported inside the function, presumably to avoid a circular import
        # with pandasai.smart_dataframe -- TODO confirm.
        from pandasai.smart_dataframe import load_smartdataframes

        if isinstance(config, dict):
            config = Config(**config)

        self._dfs = load_smartdataframes(dfs, config)
        self._memory = memory if memory is not None else Memory()
        self._skills = skills if skills is not None else SkillsManager()
        self._cache = cache if cache is not None else Cache()
        self._config = config
        self._query_exec_tracker = query_exec_tracker
        self._intermediate_values = {}

    @property
    def dfs(self) -> List[Union[DataFrameType, Any]]:
        """Dataframes returned by ``load_smartdataframes`` at construction."""
        return self._dfs

    @property
    def memory(self) -> Memory:
        """Memory instance (supplied or created by default)."""
        return self._memory

    @property
    def skills(self) -> SkillsManager:
        """Skills manager (supplied or created by default)."""
        return self._skills

    @property
    def cache(self) -> Cache:
        """Cache instance (supplied or created by default)."""
        return self._cache

    @property
    def config(self) -> Optional[Config]:
        """Configuration, converted from a dict if needed; may be ``None``."""
        return self._config

    @property
    def query_exec_tracker(self) -> Optional[QueryExecTracker]:
        """Query execution tracker passed at construction; may be ``None``."""
        return self._query_exec_tracker

    def add_intermediate_value(self, key: str, value: Any) -> None:
        """
        Store ``value`` under ``key``, overwriting any previous value.

        Args:
            key: Name to store the value under.
            value: Arbitrary value to make available to later steps.
        """
        self._intermediate_values[key] = value

    def get_intermediate_value(self, key: str, default: Any = _MISSING) -> Any:
        """
        Return the intermediate value stored under ``key``.

        Args:
            key: Name the value was stored under.
            default: Value returned when ``key`` is absent. If omitted, a
                missing key raises ``KeyError`` (backward-compatible behavior).

        Returns:
            The stored value, or ``default`` when the key is absent and a
            default was supplied.

        Raises:
            KeyError: If ``key`` is absent and no ``default`` was supplied.
        """
        if default is self._MISSING:
            return self._intermediate_values[key]
        return self._intermediate_values.get(key, default)