Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(Pipelines) : Smart Data Frame Pipeline #735

Merged
merged 15 commits into from
Nov 12, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions pandasai/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pandasai.helpers.logger import Logger
from pandasai.pipelines.pipeline_context import PipelineContext
from pandasai.pipelines.base_logic_unit import BaseLogicUnit
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes
from ..schemas.df_config import Config
from typing import Any, Optional, List, Union
from .abstract_pipeline import AbstractPipeline
Expand All @@ -22,9 +21,7 @@ class Pipeline(AbstractPipeline):

def __init__(
self,
context: Union[
List[Union[DataFrameType, SmartDataframe]], PipelineContext
] = None,
context: Union[List[Union[DataFrameType, Any]], PipelineContext] = None,
config: Optional[Union[Config, dict]] = None,
steps: Optional[List] = None,
logger: Optional[Logger] = None,
Expand All @@ -40,6 +37,8 @@ def __init__(
"""

if not isinstance(context, PipelineContext):
from pandasai.smart_dataframe import load_smartdataframes

config = Config(**load_config(config))
smart_dfs = load_smartdataframes(context, config)
context = PipelineContext(smart_dfs, config)
Expand Down
27 changes: 22 additions & 5 deletions pandasai/pipelines/pipeline_context.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
from typing import List, Optional, Union
from typing import List, Optional, Union, Any
from pandasai.helpers.cache import Cache

from pandasai.helpers.df_info import DataFrameType
from pandasai.helpers.memory import Memory
from pandasai.helpers.query_exec_tracker import QueryExecTracker
from pandasai.helpers.skills_manager import SkillsManager
from pandasai.schemas.df_config import Config
from pandasai.smart_dataframe import SmartDataframe, load_smartdataframes


class PipelineContext:
"""
Pass Context to the pipeline which is accessible to each step via kwargs
"""

_dfs: List[Union[DataFrameType, SmartDataframe]]
_dfs: List[Union[DataFrameType, Any]]
_memory: Memory
_skills: SkillsManager
_cache: Cache
_config: Config
_query_exec_tracker: QueryExecTracker
_intermediate_values: dict

def __init__(
self,
dfs: List[Union[DataFrameType, SmartDataframe]],
dfs: List[Union[DataFrameType, Any]],
config: Optional[Union[Config, dict]] = None,
memory: Memory = None,
skills: SkillsManager = None,
cache: Cache = None,
query_exec_tracker: QueryExecTracker = None,
) -> None:
from pandasai.smart_dataframe import load_smartdataframes

if isinstance(config, dict):
config = Config(**config)

Expand All @@ -35,9 +40,11 @@ def __init__(
self._skills = skills if skills is not None else SkillsManager()
self._cache = cache if cache is not None else Cache()
self._config = config
self._query_exec_tracker = query_exec_tracker
self._intermediate_values = {}

@property
def dfs(self) -> List[Union[DataFrameType, SmartDataframe]]:
def dfs(self) -> List[Union[DataFrameType, Any]]:
return self._dfs

@property
Expand All @@ -55,3 +62,13 @@ def cache(self):
@property
def config(self):
return self._config

@property
def query_exec_tracker(self):
return self._query_exec_tracker

def add_intermediate_value(self, key: str, value: Any):
self._intermediate_values[key] = value

def get_intermediate_value(self, key: str):
return self._intermediate_values[key]
5 changes: 2 additions & 3 deletions pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pydantic import BaseModel, validator, Field
from typing import Optional, List, Any, Dict, Type, TypedDict
from typing import Optional, List, Any, Dict, TypedDict
from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.responses import ResponseParser
from ..middlewares.base import Middleware
from ..callbacks.base import BaseCallback
from ..llm import LLM, LangchainLLM
Expand Down Expand Up @@ -31,7 +30,7 @@ class Config(BaseModel):
middlewares: List[Middleware] = Field(default_factory=list)
callback: Optional[BaseCallback] = None
lazy_load_connector: bool = True
response_parser: Type[ResponseParser] = None
response_parser: Any = None
llm: Any = None
data_viz_library: Optional[VisualizationLibrary] = None
log_server: LogServerConfig = None
Expand Down
2 changes: 0 additions & 2 deletions pandasai/smart_dataframe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,6 @@ def load_smartdataframes(
dfs (List[Union[DataFrameType, Any]]): List of dataframes to be used
"""

from ..smart_dataframe import SmartDataframe

smart_dfs = []
for df in dfs:
if not isinstance(df, SmartDataframe):
Expand Down
Loading