
feature(FileManager): add FileManager to make it feasible to work with the library in other environments #1573

Merged · 6 commits · Jan 31, 2025
Changes from 3 commits
17 changes: 8 additions & 9 deletions pandasai/__init__.py
```diff
@@ -98,17 +98,17 @@ def create(

     org_name, dataset_name = get_validated_dataset_path(path)

-    dataset_directory = os.path.join(
-        find_project_root(), "datasets", org_name, dataset_name
-    )
+    dataset_directory = str(os.path.join(org_name, dataset_name))

-    schema_path = os.path.join(str(dataset_directory), "schema.yaml")
-    parquet_file_path = os.path.join(str(dataset_directory), "data.parquet")
+    schema_path = os.path.join(dataset_directory, "schema.yaml")
+    parquet_file_path = os.path.join(dataset_directory, "data.parquet")

+    file_manager = config.get().file_manager
     # Check if dataset already exists
-    if os.path.exists(dataset_directory) and os.path.exists(schema_path):
+    if file_manager.exists(dataset_directory) and file_manager.exists(schema_path):
         raise ValueError(f"Dataset already exists at path: {path}")

-    os.makedirs(dataset_directory, exist_ok=True)
+    file_manager.mkdir(dataset_directory)

     if df is None and source is None and not view:
         raise InvalidConfigError(
@@ -135,8 +135,7 @@ def create(
     if columns:
         schema.columns = [Column(**column) for column in columns]

-    with open(schema_path, "w") as yml_file:
-        yml_file.write(schema.to_yaml())
+    file_manager.write(schema_path, schema.to_yaml())

     print(f"Dataset saved successfully to path: {dataset_directory}")
```
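The `FileManager` abstraction is what makes the dataset I/O above portable to other environments. `pandasai/helpers/filemanager.py` itself is not included in this diff, so the following is only a minimal sketch of the contract implied by the calls in this PR (`exists`, `mkdir`, and `write` here; `load` in `loader.py` below):

```python
# Minimal sketch, inferred from usage in this PR; not the actual
# pandasai/helpers/filemanager.py, which this diff does not include.
from abc import ABC, abstractmethod


class FileManager(ABC):
    """Abstracts file I/O so pandasai can run in non-filesystem environments."""

    @abstractmethod
    def exists(self, path: str) -> bool:
        """Return True if a file or directory exists at the given path."""

    @abstractmethod
    def mkdir(self, path: str) -> None:
        """Create a directory at the given path."""

    @abstractmethod
    def load(self, path: str) -> str:
        """Read and return the text content stored at the given path."""

    @abstractmethod
    def write(self, path: str, content: str) -> None:
        """Write text content to the given path."""
```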
3 changes: 3 additions & 0 deletions pandasai/config.py
```diff
@@ -1,9 +1,11 @@
 import os
 from abc import ABC, abstractmethod
 from importlib.util import find_spec
 from typing import Any, Dict, Optional

 from pydantic import BaseModel, ConfigDict

+from pandasai.helpers.filemanager import DefaultFileManager, FileManager
 from pandasai.llm.base import LLM


@@ -13,6 +15,7 @@ class Config(BaseModel):
     enable_cache: bool = True
     max_retries: int = 3
     llm: Optional[LLM] = None
+    file_manager: FileManager = DefaultFileManager()

     model_config = ConfigDict(arbitrary_types_allowed=True)
```
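Since `file_manager` is an ordinary `Config` field and `arbitrary_types_allowed=True` permits custom instances, a host without a writable disk can supply its own implementation. A sketch under assumptions: `InMemoryFileManager` is hypothetical, and how the resulting `Config` gets registered globally (e.g., via `ConfigManager`) is not shown in this diff.

```python
# Hypothetical in-memory FileManager for sandboxed or serverless hosts.
from pandasai.config import Config
from pandasai.helpers.filemanager import FileManager


class InMemoryFileManager(FileManager):
    """Keeps all 'files' in a dict instead of touching the filesystem."""

    def __init__(self) -> None:
        self.files: dict[str, str] = {}
        self.dirs: set[str] = set()

    def exists(self, path: str) -> bool:
        return path in self.files or path in self.dirs

    def mkdir(self, path: str) -> None:
        self.dirs.add(path)

    def load(self, path: str) -> str:
        return self.files[path]

    def write(self, path: str, content: str) -> None:
        self.files[path] = content


config = Config(file_manager=InMemoryFileManager())
```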
40 changes: 0 additions & 40 deletions pandasai/core/prompts/file_based_prompt.py

This file was deleted.

42 changes: 19 additions & 23 deletions pandasai/data_loader/loader.py
```diff
@@ -1,17 +1,19 @@
 import os

+import pandas as pd
 import yaml

 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
-from pandasai.helpers.path import find_project_root
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

+from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
 from .query_builder import QueryBuilder
 from .semantic_layer_schema import SemanticLayerSchema
+from .transformation_manager import TransformationManager
 from .view_query_builder import ViewQueryBuilder
@@ -46,21 +48,22 @@ def create_loader_from_path(cls, dataset_path: str) -> "DatasetLoader":
         """
         Factory method to create the appropriate loader based on the dataset type.
         """
-        schema = cls._read_local_schema(dataset_path)
+        schema = cls._read_schema_file(dataset_path)
         return DatasetLoader.create_loader_from_schema(schema, dataset_path)

     @staticmethod
-    def _read_local_schema(dataset_path: str) -> SemanticLayerSchema:
-        schema_path = os.path.join(
-            find_project_root(), "datasets", dataset_path, "schema.yaml"
-        )
-        if not os.path.exists(schema_path):
+    def _read_schema_file(dataset_path: str) -> SemanticLayerSchema:
+        schema_path = os.path.join(dataset_path, "schema.yaml")
+
+        file_manager = ConfigManager.get().file_manager
+
+        if not file_manager.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")

-        with open(schema_path, "r") as file:
-            raw_schema = yaml.safe_load(file)
-            raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
-            return SemanticLayerSchema(**raw_schema)
+        schema_file = file_manager.load(schema_path)
+        raw_schema = yaml.safe_load(schema_file)
+        raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
+        return SemanticLayerSchema(**raw_schema)

     def load(self) -> DataFrame:
         """
@@ -72,16 +75,9 @@ def load(self) -> DataFrame:
         """
         raise MethodNotImplementedError("Loader not instantiated")

-    def _build_dataset(
-        self, schema: SemanticLayerSchema, dataset_path: str
-    ) -> DataFrame:
-        self.schema = schema
-        self.dataset_path = dataset_path
-        is_view = schema.view
-
-        self.query_builder = (
-            ViewQueryBuilder(schema) if is_view else QueryBuilder(schema)
-        )
+    def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
+        if not self.schema.transformations:
+            return df

-    def _get_abs_dataset_path(self):
-        return os.path.join(find_project_root(), "datasets", self.dataset_path)
+        transformation_manager = TransformationManager(df)
+        return transformation_manager.apply_transformations(self.schema.transformations)
```
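With `find_project_root()` gone from the loader, resolving relative dataset paths becomes the file manager's job. A plausible sketch of `DefaultFileManager` — hypothetical, since the actual implementation is not part of this diff — assuming it anchors relative paths at the old `<project root>/datasets` location:

```python
# Hypothetical DefaultFileManager sketch: the local-filesystem behaviour the
# removed find_project_root() calls suggest, now behind the abstraction.
import os

from pandasai.helpers.filemanager import FileManager
from pandasai.helpers.path import find_project_root


class DefaultFileManager(FileManager):
    def _abs(self, path: str) -> str:
        # Anchor relative dataset paths where the old code looked for them.
        return os.path.join(find_project_root(), "datasets", path)

    def exists(self, path: str) -> bool:
        return os.path.exists(self._abs(path))

    def mkdir(self, path: str) -> None:
        os.makedirs(self._abs(path), exist_ok=True)  # matches the removed os.makedirs call

    def load(self, path: str) -> str:
        with open(self._abs(path), "r") as f:
            return f.read()

    def write(self, path: str, content: str) -> None:
        with open(self._abs(path), "w") as f:
            f.write(content)
```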
9 changes: 1 addition & 8 deletions pandasai/data_loader/local_loader.py
```diff
@@ -37,7 +37,7 @@ def _load_from_local_source(self) -> pd.DataFrame:
             )

         filepath = os.path.join(
-            str(self._get_abs_dataset_path()),
+            self.dataset_path,
             self.schema.source.path,
         )

@@ -69,10 +69,3 @@ def _filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
         df_columns = df.columns.tolist()
         columns_to_keep = [col for col in df_columns if col in schema_columns]
         return df[columns_to_keep]
-
-    def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
-        if not self.schema.transformations:
-            return df
-
-        transformation_manager = TransformationManager(df)
-        return transformation_manager.apply_transformations(self.schema.transformations)
```
7 changes: 4 additions & 3 deletions pandasai/data_loader/sql_loader.py
```diff
@@ -24,7 +24,6 @@ def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         self.query_builder: QueryBuilder = QueryBuilder(schema)

     def load(self) -> VirtualDataFrame:
-        self.query_builder = QueryBuilder(self.schema)
         return VirtualDataFrame(
             schema=self.schema,
             data_loader=SQLDatasetLoader(self.schema, self.dataset_path),
@@ -37,9 +36,11 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFrame:

         formatted_query = self.query_builder.format_query(query)
         load_function = self._get_loader_function(source_type)
+
         try:
-            return load_function(connection_info, formatted_query, params)
+            dataframe: pd.DataFrame = load_function(
+                connection_info, formatted_query, params
+            )
+            return self._apply_transformations(dataframe)
         except Exception as e:
             raise RuntimeError(
                 f"Failed to execute query for '{source_type}' with: {formatted_query}"
```
31 changes: 12 additions & 19 deletions pandasai/data_loader/transformation_manager.py
```diff
@@ -1,9 +1,9 @@
 from typing import Any, List, Optional, Union

 import numpy as np
 import pandas as pd

 from ..exceptions import UnsupportedTransformation
+from .semantic_layer_schema import Transformation


 class TransformationManager:
@@ -268,12 +268,12 @@ def format_date(self, column: str, date_format: str) -> "TransformationManager":
             TransformationManager: Self for method chaining

         Example:
-            >>> df = pd.DataFrame({"date": ["2024-01-01 12:30:45"]})
+            >>> df = pd.DataFrame({"date": ["2025-01-01 12:30:45"]})
             >>> manager = TransformationManager(df)
             >>> result = manager.format_date("date", "%Y-%m-%d").df
             >>> print(result)
                     date
-            0 2024-01-01
+            0 2025-01-01
         """
         self.df[column] = self.df[column].dt.strftime(date_format)
         return self
@@ -307,28 +307,28 @@ def to_numeric(
         return self

     def to_datetime(
-        self, column: str, format: Optional[str] = None, errors: str = "coerce"
+        self, column: str, _format: Optional[str] = None, errors: str = "coerce"
     ) -> "TransformationManager":
         """Convert values in a column to datetime type.

         Args:
             column (str): The column to transform
-            format (Optional[str]): Expected date format of the input
+            _format (Optional[str]): Expected date format of the input
             errors (str): How to handle parsing errors

         Returns:
             TransformationManager: Self for method chaining

         Example:
-            >>> df = pd.DataFrame({"date": ["2024-01-01", "invalid"]})
+            >>> df = pd.DataFrame({"date": ["2025-01-01", "invalid"]})
             >>> manager = TransformationManager(df)
             >>> result = manager.to_datetime("date", errors="coerce").df
             >>> print(result)
                     date
-            0 2024-01-01
+            0 2025-01-01
             1        NaT
         """
-        self.df[column] = pd.to_datetime(self.df[column], format=format, errors=errors)
+        self.df[column] = pd.to_datetime(self.df[column], format=_format, errors=errors)
         return self

     def fill_na(self, column: str, value: Any) -> "TransformationManager":
@@ -884,27 +884,20 @@ def rename(self, column: str, new_name: str) -> "TransformationManager":
         return self

     def apply_transformations(
-        self, transformations: Optional[List[dict]] = None
+        self, transformations: List[Transformation]
     ) -> pd.DataFrame:
         """Apply a list of transformations to the DataFrame.

         Args:
-            transformations (Optional[List[dict]]): List of transformation configurations
+            transformations (List[Transformation]): List of transformation configurations

         Returns:
             pd.DataFrame: The transformed DataFrame
         """
-        if not transformations:
-            return self.df
-
         for transformation in transformations:
-            # Handle both dict and object transformations
-            if isinstance(transformation, dict):
-                transformation_type = transformation["type"]
-                params = transformation["params"]
-            else:
-                transformation_type = transformation.type
-                params = transformation.params
+            transformation_type = transformation.type
+            params = transformation.params

             handler = self.transformation_handlers.get(transformation_type)
             if not handler:
```
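With the dict fallback removed, callers must now pass `Transformation` objects from the semantic layer schema rather than raw dicts. A minimal sketch of the new calling convention, assuming `Transformation` is a pydantic model accepting `type` and `params` keyword arguments (only its `.type`/`.params` attributes are confirmed by this diff), and that the `params` keys and registered type names mirror each handler's arguments:

```python
# Sketch of the post-PR calling convention; type names and params keys
# ("to_datetime", "format_date", "column", "date_format") are assumptions.
import pandas as pd

from pandasai.data_loader.semantic_layer_schema import Transformation
from pandasai.data_loader.transformation_manager import TransformationManager

df = pd.DataFrame({"date": ["2025-01-01 12:30:45"]})
transformations = [
    Transformation(type="to_datetime", params={"column": "date"}),
    Transformation(type="format_date", params={"column": "date", "date_format": "%Y-%m-%d"}),
]

result = TransformationManager(df).apply_transformations(transformations)
print(result)  # date column rendered as 2025-01-01
```

Note that the empty-list guard also moved out of this method: the new `DatasetLoader._apply_transformations` checks `self.schema.transformations` before calling it.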