diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index cbdbee2c6..51c878a7c 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -98,17 +98,17 @@ def create(
     org_name, dataset_name = get_validated_dataset_path(path)
 
-    dataset_directory = os.path.join(
-        find_project_root(), "datasets", org_name, dataset_name
-    )
+    dataset_directory = str(os.path.join(org_name, dataset_name))
 
-    schema_path = os.path.join(str(dataset_directory), "schema.yaml")
-    parquet_file_path = os.path.join(str(dataset_directory), "data.parquet")
+    schema_path = os.path.join(dataset_directory, "schema.yaml")
+    parquet_file_path = os.path.join(dataset_directory, "data.parquet")
+
+    file_manager = config.get().file_manager
 
     # Check if dataset already exists
-    if os.path.exists(dataset_directory) and os.path.exists(schema_path):
+    if file_manager.exists(dataset_directory) and file_manager.exists(schema_path):
         raise ValueError(f"Dataset already exists at path: {path}")
 
-    os.makedirs(dataset_directory, exist_ok=True)
+    file_manager.mkdir(dataset_directory)
 
     if df is None and source is None and not view:
         raise InvalidConfigError(
@@ -135,8 +135,7 @@ def create(
     if columns:
         schema.columns = [Column(**column) for column in columns]
 
-    with open(schema_path, "w") as yml_file:
-        yml_file.write(schema.to_yaml())
+    file_manager.write(schema_path, schema.to_yaml())
 
     print(f"Dataset saved successfully to path: {dataset_directory}")
 
diff --git a/pandasai/config.py b/pandasai/config.py
index fb13c3148..22b420fe5 100644
--- a/pandasai/config.py
+++ b/pandasai/config.py
@@ -1,9 +1,11 @@
 import os
+from abc import ABC, abstractmethod
 from importlib.util import find_spec
 from typing import Any, Dict, Optional
 
 from pydantic import BaseModel, ConfigDict
 
+from pandasai.helpers.filemanager import DefaultFileManager, FileManager
 from pandasai.llm.base import LLM
 
 
@@ -13,6 +15,7 @@ class Config(BaseModel):
     enable_cache: bool = True
     max_retries: int = 3
     llm: Optional[LLM] = None
+    file_manager: FileManager = DefaultFileManager()
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
diff --git a/pandasai/core/prompts/file_based_prompt.py b/pandasai/core/prompts/file_based_prompt.py
deleted file mode 100644
index c3c3804b5..000000000
--- a/pandasai/core/prompts/file_based_prompt.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import os
-from pathlib import Path
-
-from ..exceptions import TemplateFileNotFoundError
-from .base import AbstractPrompt
-
-
-class FileBasedPrompt(AbstractPrompt):
-    """Base class for prompts supposed to read template content from a file.
-
-    `_path_to_template` attribute has to be specified.
-    """
-
-    _path_to_template: str
-
-    def __init__(self, **kwargs):
-        if (template_path := kwargs.pop("path_to_template", None)) is not None:
-            self._path_to_template = template_path
-        else:
-            current_dir_path = Path(__file__).parent
-            self._path_to_template = os.path.join(
-                current_dir_path, "..", self._path_to_template
-            )
-
-        self.conversation_text = self.template
-        super().__init__(**kwargs)
-
-    @property
-    def template(self) -> str:
-        try:
-            with open(self._path_to_template, encoding="utf-8") as fp:
-                return fp.read()
-        except FileNotFoundError as e:
-            raise TemplateFileNotFoundError(
-                self._path_to_template, self.__class__.__name__
-            ) from e
-        except IOError as exc:
-            raise RuntimeError(
-                f"Failed to read template file '{self._path_to_template}': {exc}"
-            ) from exc
diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index cabc76e1e..fc50a580b 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -5,9 +5,9 @@
 
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
-from pandasai.helpers.path import find_project_root
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
 
+from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
@@ -48,21 +48,22 @@ def create_loader_from_path(cls, dataset_path: str) -> "DatasetLoader":
        """
        Factory method to create the appropriate loader based on the dataset type.
        """
-        schema = cls._read_local_schema(dataset_path)
+        schema = cls._read_schema_file(dataset_path)
         return DatasetLoader.create_loader_from_schema(schema, dataset_path)
 
     @staticmethod
-    def _read_local_schema(dataset_path: str) -> SemanticLayerSchema:
-        schema_path = os.path.join(
-            find_project_root(), "datasets", dataset_path, "schema.yaml"
-        )
-        if not os.path.exists(schema_path):
+    def _read_schema_file(dataset_path: str) -> SemanticLayerSchema:
+        schema_path = os.path.join(dataset_path, "schema.yaml")
+
+        file_manager = ConfigManager.get().file_manager
+
+        if not file_manager.exists(schema_path):
             raise FileNotFoundError(f"Schema file not found: {schema_path}")
 
-        with open(schema_path, "r") as file:
-            raw_schema = yaml.safe_load(file)
-            raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
-            return SemanticLayerSchema(**raw_schema)
+        schema_file = file_manager.load(schema_path)
+        raw_schema = yaml.safe_load(schema_file)
+        raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
+        return SemanticLayerSchema(**raw_schema)
 
     def load(self) -> DataFrame:
        """
@@ -80,6 +81,3 @@ def load(self) -> DataFrame:
     def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
         transformation_manager = TransformationManager(df)
         return transformation_manager.apply_transformations(self.schema.transformations)
-
-    def _get_abs_dataset_path(self):
-        return os.path.join(find_project_root(), "datasets", self.dataset_path)
diff --git a/pandasai/data_loader/local_loader.py b/pandasai/data_loader/local_loader.py
index e58c514cf..69dc298a5 100644
--- a/pandasai/data_loader/local_loader.py
+++ b/pandasai/data_loader/local_loader.py
@@ -37,7 +37,7 @@ def _load_from_local_source(self) -> pd.DataFrame:
             )
 
         filepath = os.path.join(
-            str(self._get_abs_dataset_path()),
+            self.dataset_path,
             self.schema.source.path,
         )
 
diff --git a/pandasai/data_loader/sql_loader.py b/pandasai/data_loader/sql_loader.py
index 9ab127302..a116f36b8 100644
--- a/pandasai/data_loader/sql_loader.py
+++ b/pandasai/data_loader/sql_loader.py
@@ -42,7 +42,6 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
             raise MaliciousQueryError(
                 "The SQL query is deemed unsafe and will not be executed."
             )
-
         try:
             dataframe: pd.DataFrame = load_function(
                 connection_info, formatted_query, params
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index a45d0587d..702b3a5cd 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -10,7 +10,7 @@
 from pandas._typing import Axes, Dtype
 
 import pandasai as pai
-from pandasai.config import Config
+from pandasai.config import Config, ConfigManager
 from pandasai.core.response import BaseResponse
 from pandasai.data_loader.semantic_layer_schema import (
     Column,
@@ -19,7 +19,6 @@
 )
 from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
 from pandasai.helpers.dataframe_serializer import DataframeSerializer
-from pandasai.helpers.path import find_project_root
 from pandasai.helpers.session import get_pandaai_session
 
 if TYPE_CHECKING:
@@ -164,38 +163,32 @@ def push(self):
             "name": self.schema.name,
         }
 
-        dataset_directory = os.path.join(find_project_root(), "datasets", self.path)
-
+        dataset_directory = os.path.join("datasets", self.path)
+        file_manager = ConfigManager.get().file_manager
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
 
         files = []
         schema_file_path = os.path.join(dataset_directory, "schema.yaml")
         data_file_path = os.path.join(dataset_directory, "data.parquet")
 
-        try:
-            # Open schema.yaml
-            schema_file = open(schema_file_path, "rb")
-            files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
-
-            # Check if data.parquet exists and open it
-            if os.path.exists(data_file_path):
-                data_file = open(data_file_path, "rb")
-                files.append(
-                    ("files", ("data.parquet", data_file, "application/octet-stream"))
-                )
-
-            # Send the POST request
-            request_session.post(
-                "/datasets/push",
-                files=files,
-                params=params,
-                headers=headers,
+        # Open schema.yaml
+        schema_file = file_manager.load_binary(schema_file_path)
+        files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
+
+        # Check if data.parquet exists and open it
+        if file_manager.exists(data_file_path):
+            data_file = file_manager.load_binary(data_file_path)
+            files.append(
+                ("files", ("data.parquet", data_file, "application/octet-stream"))
             )
-        finally:
-            # Ensure files are closed after the request
-            for _, (name, file, _) in files:
-                file.close()
+
+        # Send the POST request
+        request_session.post(
+            "/datasets/push",
+            files=files,
+            params=params,
+            headers=headers,
+        )
 
         print("Your dataset was successfully pushed to the remote server!")
         print(f"🔗 URL: https://app.pandabi.ai/datasets/{self.path}")
@@ -218,20 +211,18 @@ def pull(self):
 
         with ZipFile(BytesIO(file_data.content)) as zip_file:
             for file_name in zip_file.namelist():
-                target_path = os.path.join(
-                    find_project_root(), "datasets", self.path, file_name
-                )
+                target_path = os.path.join(self.path, file_name)
+                file_manager = ConfigManager.get().file_manager
 
                 # Check if the file already exists
-                if os.path.exists(target_path):
+                if file_manager.exists(target_path):
                     print(f"Replacing existing file: {target_path}")
 
                 # Ensure target directory exists
-                os.makedirs(os.path.dirname(target_path), exist_ok=True)
+                file_manager.mkdir(os.path.dirname(target_path))
 
                 # Extract the file
-                with open(target_path, "wb") as f:
-                    f.write(zip_file.read(file_name))
+                file_manager.write_binary(target_path, zip_file.read(file_name))
 
         # Reloads the Dataframe
         from pandasai import DatasetLoader
diff --git a/pandasai/helpers/filemanager.py b/pandasai/helpers/filemanager.py
new file mode 100644
index 000000000..bf06fe7da
--- /dev/null
+++ b/pandasai/helpers/filemanager.py
@@ -0,0 +1,73 @@
+import os
+from abc import ABC, abstractmethod
+
+from pandasai.helpers.path import find_project_root
+
+
+class FileManager(ABC):
+    """Abstract base class for file loaders, supporting local and remote backends."""
+
+    @abstractmethod
+    def load(self, file_path: str) -> str:
+        """Reads the content of a file."""
+        pass
+
+    @abstractmethod
+    def load_binary(self, file_path: str) -> bytes:
+        """Reads the content of a file as bytes."""
+        pass
+
+    @abstractmethod
+    def write(self, file_path: str, content: str) -> None:
+        """Writes content to a file."""
+        pass
+
+    @abstractmethod
+    def write_binary(self, file_path: str, content: bytes) -> None:
+        """Writes binary content to a file."""
+        pass
+
+    @abstractmethod
+    def exists(self, file_path: str) -> bool:
+        """Checks if a file or directory exists."""
+        pass
+
+    @abstractmethod
+    def mkdir(self, dir_path: str) -> None:
+        """Creates a directory if it doesn't exist."""
+        pass
+
+
+class DefaultFileManager(FileManager):
+    """Local file system implementation of FileLoader."""
+
+    def __init__(self):
+        self.base_path = os.path.join(find_project_root(), "datasets")
+
+    def load(self, file_path: str) -> str:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "r", encoding="utf-8") as f:
+            return f.read()
+
+    def load_binary(self, file_path: str) -> bytes:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "rb") as f:
+            return f.read()
+
+    def write(self, file_path: str, content: str) -> None:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "w", encoding="utf-8") as f:
+            f.write(content)
+
+    def write_binary(self, file_path: str, content: bytes) -> None:
+        full_path = os.path.join(self.base_path, file_path)
+        with open(full_path, "wb") as f:
+            f.write(content)
+
+    def exists(self, file_path: str) -> bool:
+        full_path = os.path.join(self.base_path, file_path)
+        return os.path.exists(full_path)
+
+    def mkdir(self, dir_path: str) -> None:
+        full_path = os.path.join(self.base_path, dir_path)
+        os.makedirs(full_path, exist_ok=True)
diff --git a/pandasai/helpers/path.py b/pandasai/helpers/path.py
index 612787ef5..58708d4c5 100644
--- a/pandasai/helpers/path.py
+++ b/pandasai/helpers/path.py
@@ -10,6 +10,7 @@ def find_project_root(filename=None):
 
     # Get the path of the file that is be
     # ing executed
+    current_file_path = os.path.abspath(os.getcwd())
 
     # Navigate back until we either find a $filename file or there is no parent
diff --git a/tests/unit_tests/agent/test_agent_chat.py b/tests/unit_tests/agent/test_agent_chat.py
index f7a6f85c2..5b0c961d8 100644
--- a/tests/unit_tests/agent/test_agent_chat.py
+++ b/tests/unit_tests/agent/test_agent_chat.py
@@ -7,13 +7,14 @@
 import pytest
 
 import pandasai as pai
-from pandasai import DataFrame, find_project_root
+from pandasai import DataFrame
 from pandasai.core.response import (
     ChartResponse,
     DataFrameResponse,
     NumberResponse,
     StringResponse,
 )
+from pandasai.helpers.filemanager import find_project_root
 
 # Read the API key from an environment variable
 API_KEY = os.getenv("PANDABI_API_KEY_TEST_CHAT", None)
diff --git a/tests/unit_tests/agent/test_agent_llm_judge.py b/tests/unit_tests/agent/test_agent_llm_judge.py
index e7bdaf53f..c2d5e0c7e 100644
--- a/tests/unit_tests/agent/test_agent_llm_judge.py
+++ b/tests/unit_tests/agent/test_agent_llm_judge.py
@@ -7,7 +7,8 @@
 from pydantic import BaseModel
 
 import pandasai as pai
-from pandasai import DataFrame, find_project_root
+from pandasai import DataFrame
+from pandasai.helpers.path import find_project_root
 
 # Read the API key from an environment variable
 JUDGE_OPENAI_API_KEY = os.getenv("JUDGE_OPENAI_API_KEY", None)
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index c754ecb3a..b56456aa8 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -5,9 +5,11 @@
 
 import pytest
 
+from pandasai import ConfigManager
 from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
 from pandasai.dataframe.base import DataFrame
+from pandasai.helpers.filemanager import DefaultFileManager
 from pandasai.helpers.path import find_project_root
 
 
@@ -171,3 +173,14 @@ def mock_loader_instance(sample_df):
         mock_create_loader_from_schema.return_value = mock_loader_instance
 
         yield mock_loader_instance
+
+
+@pytest.fixture
+def mock_file_manager():
+    """Fixture to mock FileManager and its methods."""
+    with patch.object(ConfigManager, "get") as mock_config_get:
+        # Create a mock FileManager
+        mock_file_manager = MagicMock()
+        mock_file_manager.exists.return_value = False
+        mock_config_get.return_value.file_manager = mock_file_manager
+        yield mock_file_manager
diff --git a/tests/unit_tests/data_loader/test_loader.py b/tests/unit_tests/data_loader/test_loader.py
index bb3f15af7..db0521245 100644
--- a/tests/unit_tests/data_loader/test_loader.py
+++ b/tests/unit_tests/data_loader/test_loader.py
@@ -40,14 +40,14 @@ def test_load_schema(self, sample_schema):
         with patch("os.path.exists", return_value=True), patch(
             "builtins.open", mock_open(read_data=str(sample_schema.to_yaml()))
         ):
-            schema = DatasetLoader._read_local_schema("test/users")
+            schema = DatasetLoader._read_schema_file("test/users")
             assert schema == sample_schema
 
     def test_load_schema_mysql(self, mysql_schema):
         with patch("os.path.exists", return_value=True), patch(
             "builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
         ):
-            schema = DatasetLoader._read_local_schema("test/users")
+            schema = DatasetLoader._read_schema_file("test/users")
             assert schema == mysql_schema
 
     def test_load_schema_mysql_sanitized_name(self, mysql_schema):
@@ -56,13 +56,13 @@ def test_load_schema_mysql_sanitized_name(self, mysql_schema):
         with patch("os.path.exists", return_value=True), patch(
             "builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
         ):
-            schema = DatasetLoader._read_local_schema("test/users")
+            schema = DatasetLoader._read_schema_file("test/users")
             assert schema.name == "non_sanitized_name"
 
     def test_load_schema_file_not_found(self):
         with patch("os.path.exists", return_value=False):
             with pytest.raises(FileNotFoundError):
-                DatasetLoader._read_local_schema("test/users")
+                DatasetLoader._read_schema_file("test/users")
 
     def test_read_parquet(self, sample_schema):
         loader = LocalDatasetLoader(sample_schema, "test")
diff --git a/tests/unit_tests/data_loader/test_sql_loader.py b/tests/unit_tests/data_loader/test_sql_loader.py
index 106c0d458..b46354eae 100644
--- a/tests/unit_tests/data_loader/test_sql_loader.py
+++ b/tests/unit_tests/data_loader/test_sql_loader.py
@@ -1,13 +1,16 @@
 import logging
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, mock_open, patch
 
 import pandas as pd
 import pytest
 
 from pandasai import VirtualDataFrame
+from pandasai.data_loader.loader import DatasetLoader
+from pandasai.data_loader.local_loader import LocalDatasetLoader
+from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
 from pandasai.data_loader.sql_loader import SQLDatasetLoader
 from pandasai.dataframe.base import DataFrame
-from pandasai.exceptions import MaliciousQueryError
+from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
 
 
 class TestSqlDatasetLoader:
diff --git a/tests/unit_tests/dataframe/test_dataframe.py b/tests/unit_tests/dataframe/test_dataframe.py
index 7c4ae0da6..3d64f6173 100644
--- a/tests/unit_tests/dataframe/test_dataframe.py
+++ b/tests/unit_tests/dataframe/test_dataframe.py
@@ -2,7 +2,6 @@
 
 import pandas as pd
 import pytest
-from numpy import False_
 
 import pandasai
 from pandasai.agent import Agent
@@ -79,10 +78,10 @@ def test_column_hash(self, sample_df):
         assert len(sample_df.column_hash) == 32  # MD5 hash length
 
     @patch("pandasai.dataframe.base.get_pandaai_session")
-    @patch("pandasai.dataframe.base.os.path.exists")
-    @patch("pandasai.dataframe.base.open", new_callable=mock_open)
+    @patch("pandasai.helpers.filemanager.os.path.exists")
+    @patch("pandasai.helpers.filemanager.open", new_callable=mock_open)
     @patch("pandasai.dataframe.base.os.environ")
-    @patch("pandasai.dataframe.base.find_project_root")
+    @patch("pandasai.helpers.path.find_project_root")
     def test_push_successful(
         self,
         mock_find_project_root,
@@ -115,13 +114,13 @@ def test_push_successful(
             files=[
                 (
                     "files",
-                    ("schema.yaml", mock_open.return_value, "application/x-yaml"),
+                    ("schema.yaml", "", "application/x-yaml"),
                 ),
                 (
                     "files",
                     (
                         "data.parquet",
-                        mock_open.return_value,
+                        "",
                         "application/octet-stream",
                     ),
                 ),
@@ -155,8 +154,8 @@ def test_push_raises_error_if_api_key_is_missing(self, mock_environ, sample_df):
         sample_df.path = "test/test"
         sample_df.push()
 
-    @patch("pandasai.dataframe.base.os.path.exists")
-    @patch("pandasai.dataframe.base.open", new_callable=mock_open)
+    @patch("pandasai.helpers.filemanager.os.path.exists")
+    @patch("pandasai.helpers.filemanager.open", new_callable=mock_open)
     @patch("pandasai.dataframe.base.get_pandaai_session")
     @patch("pandasai.dataframe.base.os.environ")
     def test_push_closes_files_on_completion(
@@ -179,6 +178,3 @@ def test_push_closes_files_on_completion(
         # Call the method
         sample_df.path = "test/test"
         sample_df.push()
-
-        # Assert that files were closed after the request
-        mock_open.return_value.close.assert_called()
diff --git a/tests/unit_tests/test_pandasai_init.py b/tests/unit_tests/test_pandasai_init.py
index a03a8e2ef..46ff84bd5 100644
--- a/tests/unit_tests/test_pandasai_init.py
+++ b/tests/unit_tests/test_pandasai_init.py
@@ -275,27 +275,16 @@ def test_load_with_custom_api_url(
             params={"path": "org/dataset"},
         )
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_dataset_no_params(
-        self, mock_makedirs, mock_find_project_root, sample_df, mock_loader_instance
+        self, sample_df, mock_loader_instance, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
-
-        with patch("builtins.open", mock_open()) as mock_file, patch.object(
-            sample_df, "to_parquet"
-        ) as mock_to_parquet, patch(
-            "pandasai.find_project_root", return_value=os.path.join("mock", "root")
-        ):
+        with patch.object(sample_df, "to_parquet") as mock_to_parquet:
             result = pandasai.create("test-org/test-dataset", sample_df)
 
             # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join(
-                    os.path.join("mock", "root", "datasets", "test-org", "test-dataset")
-                ),
-                exist_ok=True,
+            mock_file_manager.mkdir.assert_called_once_with(
+                os.path.join("test-org", "test-dataset")
             )
 
             # Check if DataFrame was saved
@@ -304,17 +293,7 @@ def test_create_valid_dataset_no_params(
             assert mock_to_parquet.call_args[1]["index"] is False
 
             # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
-            )
+            mock_file_manager.write.assert_called_once()
 
             # Check returned DataFrame
             assert isinstance(result, DataFrame)
@@ -396,30 +375,22 @@ def mock_exists_side_effect(path):
             mock_file.assert_called_once()
             mock_loader_instance.load.assert_called_once()
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_dataset_with_description(
-        self, mock_makedirs, mock_find_project_root, sample_df, mock_loader_instance
+        self, sample_df, mock_loader_instance, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
         mock_schema = MagicMock()
         sample_df.schema = mock_schema
 
-        with patch("builtins.open", mock_open()) as mock_file, patch.object(
-            sample_df, "to_parquet"
-        ) as mock_to_parquet, patch(
-            "pandasai.find_project_root", return_value=os.path.join("mock", "root")
-        ):
+        with patch.object(sample_df, "to_parquet") as mock_to_parquet:
             result = pandasai.create(
                 "test-org/test-dataset", sample_df, description="test_description"
             )
 
             # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join("mock", "root", "datasets", "test-org", "test-dataset"),
-                exist_ok=True,
+            mock_file_manager.mkdir.assert_called_once_with(
+                os.path.join("test-org", "test-dataset")
             )
 
             # Check if DataFrame was saved
@@ -428,17 +399,7 @@ def test_create_valid_dataset_with_description(
             assert mock_to_parquet.call_args[1]["index"] is False
 
             # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
-            )
+            mock_file_manager.write.assert_called_once()
 
             # Check returned DataFrame
             assert isinstance(result, DataFrame)
@@ -446,28 +407,20 @@ def test_create_valid_dataset_with_description(
         assert mock_schema.description == "test_description"
         mock_loader_instance.load.assert_called_once()
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_dataset_with_columns(
-        self, mock_makedirs, mock_find_project_root, sample_df, mock_loader_instance
+        self, sample_df, mock_loader_instance, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
 
-        with patch("builtins.open", mock_open()) as mock_file, patch.object(
-            sample_df, "to_parquet"
-        ) as mock_to_parquet, patch(
-            "pandasai.find_project_root", return_value=os.path.join("mock", "root")
-        ):
+        with patch.object(sample_df, "to_parquet") as mock_to_parquet:
             columns_dict = [{"name": "a"}, {"name": "b"}]
             result = pandasai.create(
                 "test-org/test-dataset", sample_df, columns=columns_dict
             )
 
             # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join("mock", "root", "datasets", "test-org", "test-dataset"),
-                exist_ok=True,
+            mock_file_manager.mkdir.assert_called_once_with(
+                os.path.join("test-org", "test-dataset")
             )
 
             # Check if DataFrame was saved
@@ -476,17 +429,7 @@ def test_create_valid_dataset_with_columns(
             assert mock_to_parquet.call_args[1]["index"] is False
 
             # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
-            )
+            mock_file_manager.write.assert_called_once()
 
             # Check returned DataFrame
             assert isinstance(result, DataFrame)
@@ -500,7 +443,7 @@ def test_create_valid_dataset_with_columns(
     @patch("pandasai.helpers.path.find_project_root")
     @patch("os.makedirs")
     def test_create_dataset_wrong_columns(
-        self, mock_makedirs, mock_find_project_root, sample_df
+        self, mock_makedirs, mock_find_project_root, sample_df, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
         mock_find_project_root.return_value = os.path.join("mock", "root")
@@ -517,18 +460,10 @@ def test_create_dataset_wrong_columns(
                 "test-org/test-dataset", sample_df, columns=columns_dict
             )
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_dataset_with_mysql(
-        self,
-        mock_makedirs,
-        mock_find_project_root,
-        sample_df,
-        mysql_connection_json,
-        mock_loader_instance,
+        self, sample_df, mysql_connection_json, mock_loader_instance, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
 
         with patch("builtins.open", mock_open()) as mock_file, patch.object(
             sample_df, "to_parquet"
@@ -543,22 +478,8 @@ def test_create_valid_dataset_with_mysql(
             )
 
             # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join("mock", "root", "datasets", "test-org", "test-dataset"),
-                exist_ok=True,
-            )
-
-            # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
+            mock_file_manager.mkdir.assert_called_once_with(
+                os.path.join("test-org", "test-dataset")
             )
 
             # Check returned DataFrame
@@ -567,19 +488,9 @@ def test_create_valid_dataset_with_mysql(
         assert result.schema.description is None
         assert mock_loader_instance.load.call_count == 1
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_dataset_with_postgres(
-        self,
-        mock_makedirs,
-        mock_find_project_root,
-        sample_df,
-        mysql_connection_json,
-        mock_loader_instance,
+        self, sample_df, mysql_connection_json, mock_loader_instance, mock_file_manager
     ):
-        """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
-
         with patch("builtins.open", mock_open()) as mock_file, patch.object(
             sample_df, "to_parquet"
         ) as mock_to_parquet, patch(
@@ -592,25 +503,6 @@ def test_create_valid_dataset_with_postgres(
             "pandasai.find_project_root", return_value=os.path.join("mock", "root")
         ):
             result = pandasai.create(
                 "test-org/test-dataset",
                 columns=columns_dict,
             )
 
-            # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join("mock", "root", "datasets", "test-org", "test-dataset"),
-                exist_ok=True,
-            )
-
-            # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
-            )
-
             # Check returned DataFrame
             assert isinstance(result, DataFrame)
             assert result.schema.name == sample_df.schema.name
@@ -640,13 +532,10 @@ def test_create_with_no_dataframe_with_incorrect_type(
         with pytest.raises(ValueError, match="df must be a PandaAI DataFrame"):
             pandasai.create("test-org/test-dataset", df={"test": "test"})
 
-    @patch("pandasai.helpers.path.find_project_root")
-    @patch("os.makedirs")
     def test_create_valid_view(
-        self, mock_makedirs, mock_find_project_root, sample_df, mock_loader_instance
+        self, sample_df, mock_loader_instance, mock_file_manager
     ):
         """Test creating a dataset with valid inputs."""
-        mock_find_project_root.return_value = os.path.join("mock", "root")
 
         with patch("builtins.open", mock_open()) as mock_file, patch(
             "pandasai.find_project_root", return_value=os.path.join("mock", "root")
@@ -669,25 +558,6 @@ def test_create_valid_view(
         ):
             result = pandasai.create(
                 "test-org/test-dataset", columns=columns, relations=relations, view=True
             )
 
-            # Check if directories were created
-            mock_makedirs.assert_called_once_with(
-                os.path.join("mock", "root", "datasets", "test-org", "test-dataset"),
-                exist_ok=True,
-            )
-
-            # Check if schema was saved
-            mock_file.assert_called_once_with(
-                os.path.join(
-                    "mock",
-                    "root",
-                    "datasets",
-                    "test-org",
-                    "test-dataset",
-                    "schema.yaml",
-                ),
-                "w",
-            )
-
             # Check returned DataFrame
             assert isinstance(result, DataFrame)
             assert result.schema.name == sample_df.schema.name
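
Reviewer note on extending the new abstraction: any backend that implements the six abstract methods of FileManager (pandasai/helpers/filemanager.py above) can stand in for DefaultFileManager, since create(), push(), pull(), and the dataset loaders now route all file I/O through the config's file_manager. Below is a minimal sketch of a custom backend. The InMemoryFileManager class is hypothetical, written here only to illustrate the interface, and the pai.config.set(...) wiring assumes the ConfigManager setter accepts a file_manager key, which the Config model above declares.

from typing import Dict

import pandasai as pai
from pandasai.helpers.filemanager import FileManager


class InMemoryFileManager(FileManager):
    """Hypothetical FileManager that keeps dataset files in a dict (e.g. for tests)."""

    def __init__(self) -> None:
        self._files: Dict[str, bytes] = {}

    def load(self, file_path: str) -> str:
        return self._files[file_path].decode("utf-8")

    def load_binary(self, file_path: str) -> bytes:
        return self._files[file_path]

    def write(self, file_path: str, content: str) -> None:
        self._files[file_path] = content.encode("utf-8")

    def write_binary(self, file_path: str, content: bytes) -> None:
        self._files[file_path] = content

    def exists(self, file_path: str) -> bool:
        # Treat a key prefix as an existing "directory": a flat key space
        # has no real directories to check.
        prefix = file_path.rstrip("/") + "/"
        return file_path in self._files or any(
            name.startswith(prefix) for name in self._files
        )

    def mkdir(self, dir_path: str) -> None:
        # Directories are implicit in a flat key space; nothing to create.
        pass


# Assumed wiring (file_manager is a declared Config field above):
pai.config.set({"file_manager": InMemoryFileManager()})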