From 036e203d2319465f88a10916e33294c161ed6dc2 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Tue, 17 Oct 2023 11:09:49 -0400 Subject: [PATCH 01/12] feat: function_metadata supports boolean and float (#1296) Fixes #1288 --- .../catalog/models/function_metadata_catalog.py | 4 ++-- evadb/parser/evadb.lark | 2 +- evadb/parser/lark_visitor/_expressions.py | 6 ++++++ .../long/test_function_executor.py | 17 +++++++++++++---- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/evadb/catalog/models/function_metadata_catalog.py b/evadb/catalog/models/function_metadata_catalog.py index 40bb00553..4b398a8c6 100644 --- a/evadb/catalog/models/function_metadata_catalog.py +++ b/evadb/catalog/models/function_metadata_catalog.py @@ -17,7 +17,7 @@ from sqlalchemy.orm import relationship from evadb.catalog.models.base_model import BaseModel -from evadb.catalog.models.utils import FunctionMetadataCatalogEntry +from evadb.catalog.models.utils import FunctionMetadataCatalogEntry, TextPickleType class FunctionMetadataCatalog(BaseModel): @@ -34,7 +34,7 @@ class FunctionMetadataCatalog(BaseModel): __tablename__ = "function_metadata_catalog" _key = Column("key", String(100)) - _value = Column("value", String(100)) + _value = Column("value", TextPickleType()) _function_id = Column( "function_id", Integer, ForeignKey("function_catalog._row_id") ) diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index c158d8e25..ab2b5c2ad 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -51,7 +51,7 @@ function_metadata: function_metadata_key function_metadata_value function_metadata_key: uid -function_metadata_value: string_literal | decimal_literal +function_metadata_value: constant vector_store_type: USING (FAISS | QDRANT | PINECONE | PGVECTOR | CHROMADB) diff --git a/evadb/parser/lark_visitor/_expressions.py b/evadb/parser/lark_visitor/_expressions.py index c5cf5a0bf..91b5be77c 100644 --- a/evadb/parser/lark_visitor/_expressions.py +++ 
b/evadb/parser/lark_visitor/_expressions.py @@ -41,6 +41,12 @@ def array_literal(self, tree): res = ConstantValueExpression(np.array(array_elements), ColumnType.NDARRAY) return res + def boolean_literal(self, tree): + text = tree.children[0] + if text == "TRUE": + return ConstantValueExpression(True, ColumnType.BOOLEAN) + return ConstantValueExpression(False, ColumnType.BOOLEAN) + def constant(self, tree): for child in tree.children: if isinstance(child, Tree): diff --git a/test/integration_tests/long/test_function_executor.py b/test/integration_tests/long/test_function_executor.py index 529882795..2b21f2016 100644 --- a/test/integration_tests/long/test_function_executor.py +++ b/test/integration_tests/long/test_function_executor.py @@ -199,8 +199,11 @@ def test_should_create_function_with_metadata(self): OUTPUT (label NDARRAY STR(10)) TYPE Classification IMPL 'test/util.py' - CACHE 'TRUE' - BATCH 'FALSE'; + CACHE TRUE + BATCH FALSE + INT_VAL 1 + FLOAT_VAL 1.5 + STR_VAL "gg"; """ execute_query_fetch_all(self.evadb, create_function_query.format(function_name)) @@ -208,11 +211,17 @@ def test_should_create_function_with_metadata(self): entries = self.evadb.catalog().get_function_metadata_entries_by_function_name( function_name ) - self.assertEqual(len(entries), 2) + self.assertEqual(len(entries), 5) metadata = [(entry.key, entry.value) for entry in entries] # metadata ultimately stored as lowercase string literals in metadata - expected_metadata = [("cache", "TRUE"), ("batch", "FALSE")] + expected_metadata = [ + ("cache", True), + ("batch", False), + ("int_val", 1), + ("float_val", 1.5), + ("str_val", "gg"), + ] self.assertEqual(set(metadata), set(expected_metadata)) def test_should_return_empty_metadata_list_for_missing_function(self): From e21092ca2f435d0450482ff5e00a64bcd59829b9 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Tue, 17 Oct 2023 12:15:29 -0400 Subject: [PATCH 02/12] feat: add support for show databases (#1295) SHOW DATABASES #1252 --- 
evadb/catalog/catalog_manager.py | 3 ++ evadb/executor/show_info_executor.py | 5 +++ evadb/parser/evadb.lark | 3 +- evadb/parser/lark_visitor/_show_statements.py | 2 + evadb/parser/show_statement.py | 3 +- evadb/parser/types.py | 1 + evadb/plan_nodes/show_info_plan.py | 2 + .../short/test_show_info_executor.py | 43 ++++++++++++++++++- .../parser/test_parser_statements.py | 1 + 9 files changed, 60 insertions(+), 3 deletions(-) diff --git a/evadb/catalog/catalog_manager.py b/evadb/catalog/catalog_manager.py index b7c55c9bf..333d5074e 100644 --- a/evadb/catalog/catalog_manager.py +++ b/evadb/catalog/catalog_manager.py @@ -161,6 +161,9 @@ def get_database_catalog_entry(self, database_name: str) -> DatabaseCatalogEntry return table_entry + def get_all_database_catalog_entries(self): + return self._db_catalog_service.get_all_entries() + def drop_database_catalog_entry(self, database_entry: DatabaseCatalogEntry) -> bool: """ This method deletes the database from catalog. diff --git a/evadb/executor/show_info_executor.py b/evadb/executor/show_info_executor.py index e4894aacf..96dc0a537 100644 --- a/evadb/executor/show_info_executor.py +++ b/evadb/executor/show_info_executor.py @@ -32,6 +32,7 @@ def exec(self, *args, **kwargs): assert ( self.node.show_type is ShowType.FUNCTIONS or ShowType.TABLES + or ShowType.DATABASES or ShowType.CONFIG ), f"Show command does not support type {self.node.show_type}" @@ -45,6 +46,10 @@ def exec(self, *args, **kwargs): if table.table_type != TableType.SYSTEM_STRUCTURED_DATA: show_entries.append(table.name) show_entries = {"name": show_entries} + elif self.node.show_type is ShowType.DATABASES: + databases = self.catalog().get_all_database_catalog_entries() + for db in databases: + show_entries.append(db.display_format()) elif self.node.show_type is ShowType.CONFIG: value = self._config.get_value( category="default", diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index ab2b5c2ad..e834d1a7d 100644 --- a/evadb/parser/evadb.lark 
+++ b/evadb/parser/evadb.lark @@ -177,7 +177,7 @@ describe_statement: DESCRIBE table_name help_statement: HELP STRING_LITERAL -show_statement: SHOW (FUNCTIONS | TABLES | uid) +show_statement: SHOW (FUNCTIONS | TABLES | uid | DATABASES) explain_statement: EXPLAIN explainable_statement @@ -341,6 +341,7 @@ CHUNK_OVERLAP: "CHUNK_OVERLAP"i COLUMN: "COLUMN"i CREATE: "CREATE"i DATABASE: "DATABASE"i +DATABASES: "DATABASES"i DEFAULT: "DEFAULT"i DELETE: "DELETE"i DESC: "DESC"i diff --git a/evadb/parser/lark_visitor/_show_statements.py b/evadb/parser/lark_visitor/_show_statements.py index b278191fa..ca9581aca 100644 --- a/evadb/parser/lark_visitor/_show_statements.py +++ b/evadb/parser/lark_visitor/_show_statements.py @@ -27,5 +27,7 @@ def show_statement(self, tree): return ShowStatement(show_type=ShowType.FUNCTIONS) elif isinstance(token, str) and str.upper(token) == "TABLES": return ShowStatement(show_type=ShowType.TABLES) + elif isinstance(token, str) and str.upper(token) == "DATABASES": + return ShowStatement(show_type=ShowType.DATABASES) elif token is not None: return ShowStatement(show_type=ShowType.CONFIG, show_val=self.visit(token)) diff --git a/evadb/parser/show_statement.py b/evadb/parser/show_statement.py index ae9255fe0..d7eca052f 100644 --- a/evadb/parser/show_statement.py +++ b/evadb/parser/show_statement.py @@ -42,7 +42,8 @@ def __str__(self): show_str = "TABLES" elif self.show_type == ShowType.CONFIG: show_str = self.show_val - + elif self.show_type == ShowType.DATABASES: + show_str = "DATABASES" return f"SHOW {show_str}" def __eq__(self, other: object) -> bool: diff --git a/evadb/parser/types.py b/evadb/parser/types.py index 7cc449c29..751d2b5f3 100644 --- a/evadb/parser/types.py +++ b/evadb/parser/types.py @@ -71,6 +71,7 @@ class ShowType(EvaDBEnum): FUNCTIONS # noqa: F821 TABLES # noqa: F821 CONFIG # noqa: F821 + DATABASES # noqa: F821 class FunctionType(EvaDBEnum): diff --git a/evadb/plan_nodes/show_info_plan.py b/evadb/plan_nodes/show_info_plan.py index 
a6c3da5a7..733cc0401 100644 --- a/evadb/plan_nodes/show_info_plan.py +++ b/evadb/plan_nodes/show_info_plan.py @@ -36,6 +36,8 @@ def show_val(self): def __str__(self): if self._show_type == ShowType.FUNCTIONS: return "ShowFunctionPlan" + if self._show_type == ShowType.DATABASES: + return "ShowDatabasePlan" elif self._show_type == ShowType.TABLES: return "ShowTablePlan" elif self._show_type == ShowType.CONFIG: diff --git a/test/integration_tests/short/test_show_info_executor.py b/test/integration_tests/short/test_show_info_executor.py index 3f911bd4c..f875d266e 100644 --- a/test/integration_tests/short/test_show_info_executor.py +++ b/test/integration_tests/short/test_show_info_executor.py @@ -28,7 +28,7 @@ from evadb.models.storage.batch import Batch from evadb.server.command_handler import execute_query_fetch_all -NUM_FRAMES = 10 +NUM_DATABASES = 6 @pytest.mark.notparallel @@ -48,12 +48,39 @@ def setUpClass(cls): execute_query_fetch_all(cls.evadb, f"LOAD VIDEO '{mnist}' INTO MNIST;") execute_query_fetch_all(cls.evadb, f"LOAD VIDEO '{actions}' INTO Actions;") + # create databases + import os + + cls.current_file_dir = os.path.dirname(os.path.abspath(__file__)) + for i in range(NUM_DATABASES): + database_path = f"{cls.current_file_dir}/testing_{i}.db" + params = { + "database": database_path, + } + query = """CREATE DATABASE test_data_source_{} + WITH ENGINE = "sqlite", + PARAMETERS = {};""".format( + i, params + ) + execute_query_fetch_all(cls.evadb, query) + @classmethod def tearDownClass(cls): execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS Actions;") execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS MNIST;") execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS MyVideo;") + # remove all the DATABASES + for i in range(NUM_DATABASES): + execute_query_fetch_all( + cls.evadb, f"DROP DATABASE IF EXISTS test_data_source_{i};" + ) + database_path = f"{cls.current_file_dir}/testing_{i}.db" + import contextlib + + with 
contextlib.suppress(FileNotFoundError): + os.remove(database_path) + # integration test def test_show_functions(self): result = execute_query_fetch_all(self.evadb, "SHOW FUNCTIONS;") @@ -100,3 +127,17 @@ def test_show_config_execution(self): # Ensure an Exception is raised if config is not present with self.assertRaises(Exception): execute_query_fetch_all(self.evadb, "SHOW BADCONFIG") + + # integration test + def test_show_databases(self): + result = execute_query_fetch_all(self.evadb, "SHOW DATABASES;") + self.assertEqual(len(result.columns), 3) + self.assertEqual(len(result), 6) + + expected = { + "name": [f"test_data_source_{i}" for i in range(NUM_DATABASES)], + "engine": ["sqlite" for _ in range(NUM_DATABASES)], + } + expected_df = pd.DataFrame(expected) + self.assertTrue(all(expected_df.name == result.frames.name)) + self.assertTrue(all(expected_df.engine == result.frames.engine)) diff --git a/test/unit_tests/parser/test_parser_statements.py b/test/unit_tests/parser/test_parser_statements.py index f45f94bd1..eba32480e 100644 --- a/test/unit_tests/parser/test_parser_statements.py +++ b/test/unit_tests/parser/test_parser_statements.py @@ -80,6 +80,7 @@ def test_parser_statement_types(self): """, "SHOW TABLES;", "SHOW FUNCTIONS;", + "SHOW DATABASES;", "EXPLAIN SELECT a FROM foo;", "SELECT HomeRentalForecast(12);", """SELECT data FROM MyVideo WHERE id < 5 From 7d51925614ac8a1c744518592ad1eaa524f3e7b9 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Tue, 17 Oct 2023 15:12:02 -0400 Subject: [PATCH 03/12] fix: make the table/function catalog insert operation atomic (#1293) Fixes: #1282 --- evadb/binder/statement_binder.py | 7 +++ evadb/catalog/catalog_manager.py | 13 +++-- .../services/column_catalog_service.py | 7 +-- .../services/function_catalog_service.py | 47 ++++++++++++++++++- .../services/function_io_catalog_service.py | 12 ++--- .../function_metadata_catalog_service.py | 9 ++-- .../catalog/services/table_catalog_service.py | 14 +++++- 
evadb/catalog/sql_config.py | 3 ++ .../long/test_create_table_executor.py | 9 ++++ .../catalog/test_catalog_manager.py | 11 +++-- 10 files changed, 99 insertions(+), 33 deletions(-) diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index 199c53518..f9087b5be 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -32,6 +32,7 @@ from evadb.binder.statement_binder_context import StatementBinderContext from evadb.catalog.catalog_type import ColumnType, TableType from evadb.catalog.catalog_utils import get_metadata_properties, is_document_table +from evadb.catalog.sql_config import RESTRICTED_COL_NAMES from evadb.configuration.constants import EvaDB_INSTALLATION_DIR from evadb.expression.abstract_expression import AbstractExpression, ExpressionType from evadb.expression.function_expression import FunctionExpression @@ -201,6 +202,12 @@ def _bind_delete_statement(self, node: DeleteTableStatement): @bind.register(CreateTableStatement) def _bind_create_statement(self, node: CreateTableStatement): + # we don't allow certain keywords in the column_names + for col in node.column_list: + assert ( + col.name.lower() not in RESTRICTED_COL_NAMES + ), f"EvaDB does not allow to create a table with column name {col.name}" + if node.query is not None: self.bind(node.query) diff --git a/evadb/catalog/catalog_manager.py b/evadb/catalog/catalog_manager.py index 333d5074e..7f63be108 100644 --- a/evadb/catalog/catalog_manager.py +++ b/evadb/catalog/catalog_manager.py @@ -343,14 +343,13 @@ def insert_function_catalog_entry( checksum = get_file_checksum(impl_file_path) function_entry = self._function_service.insert_entry( - name, impl_file_path, type, checksum + name, + impl_file_path, + type, + checksum, + function_io_list, + function_metadata_list, ) - for function_io in function_io_list: - function_io.function_id = function_entry.row_id - self._function_io_service.insert_entries(function_io_list) - for function_metadata in 
function_metadata_list: - function_metadata.function_id = function_entry.row_id - self._function_metadata_service.insert_entries(function_metadata_list) return function_entry def get_function_catalog_entry_by_name(self, name: str) -> FunctionCatalogEntry: diff --git a/evadb/catalog/services/column_catalog_service.py b/evadb/catalog/services/column_catalog_service.py index 8a3a7ba66..517bdfb19 100644 --- a/evadb/catalog/services/column_catalog_service.py +++ b/evadb/catalog/services/column_catalog_service.py @@ -63,7 +63,7 @@ def get_entry_by_id( return entry if return_alchemy else entry.as_dataclass() return entry - def insert_entries(self, column_list: List[ColumnCatalogEntry]): + def create_entries(self, column_list: List[ColumnCatalogEntry]): catalog_column_objs = [ self.model( name=col.name, @@ -75,10 +75,7 @@ def insert_entries(self, column_list: List[ColumnCatalogEntry]): ) for col in column_list ] - saved_column_objs = [] - for column in catalog_column_objs: - saved_column_objs.append(column.save(self.session)) - return [obj.as_dataclass() for obj in saved_column_objs] + return catalog_column_objs def filter_entries_by_table( self, table: TableCatalogEntry diff --git a/evadb/catalog/services/function_catalog_service.py b/evadb/catalog/services/function_catalog_service.py index d8c449d1c..0c6c272d3 100644 --- a/evadb/catalog/services/function_catalog_service.py +++ b/evadb/catalog/services/function_catalog_service.py @@ -12,20 +12,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import List + from sqlalchemy.orm import Session from sqlalchemy.sql.expression import select from evadb.catalog.models.function_catalog import FunctionCatalog, FunctionCatalogEntry +from evadb.catalog.models.utils import ( + FunctionIOCatalogEntry, + FunctionMetadataCatalogEntry, +) from evadb.catalog.services.base_service import BaseService +from evadb.catalog.services.function_io_catalog_service import FunctionIOCatalogService +from evadb.catalog.services.function_metadata_catalog_service import ( + FunctionMetadataCatalogService, +) +from evadb.utils.errors import CatalogError from evadb.utils.logging_manager import logger class FunctionCatalogService(BaseService): def __init__(self, db_session: Session): super().__init__(FunctionCatalog, db_session) + self._function_io_service = FunctionIOCatalogService(db_session) + self._function_metadata_service = FunctionMetadataCatalogService(db_session) def insert_entry( - self, name: str, impl_path: str, type: str, checksum: str + self, + name: str, + impl_path: str, + type: str, + checksum: str, + function_io_list: List[FunctionIOCatalogEntry], + function_metadata_list: List[FunctionMetadataCatalogEntry], ) -> FunctionCatalogEntry: """Insert a new function entry @@ -40,7 +59,31 @@ def insert_entry( """ function_obj = self.model(name, impl_path, type, checksum) function_obj = function_obj.save(self.session) - return function_obj.as_dataclass() + + for function_io in function_io_list: + function_io.function_id = function_obj._row_id + io_objs = self._function_io_service.create_entries(function_io_list) + for function_metadata in function_metadata_list: + function_metadata.function_id = function_obj._row_id + metadata_objs = self._function_metadata_service.create_entries( + function_metadata_list + ) + + # atomic operation for adding table and its corresponding columns. 
+ try: + self.session.add_all(io_objs) + self.session.add_all(metadata_objs) + self.session.commit() + except Exception as e: + self.session.rollback() + self.session.delete(function_obj) + self.session.commit() + logger.exception( + f"Failed to insert entry into function catalog with exception {str(e)}" + ) + raise CatalogError(e) + else: + return function_obj.as_dataclass() def get_entry_by_name(self, name: str) -> FunctionCatalogEntry: """return the function entry that matches the name provided. diff --git a/evadb/catalog/services/function_io_catalog_service.py b/evadb/catalog/services/function_io_catalog_service.py index 290f3d20e..caf339098 100644 --- a/evadb/catalog/services/function_io_catalog_service.py +++ b/evadb/catalog/services/function_io_catalog_service.py @@ -69,13 +69,8 @@ def get_output_entries_by_function_id( logger.error(error) raise RuntimeError(error) - def insert_entries(self, io_list: List[FunctionIOCatalogEntry]): - """Commit entries to the function_io table - - Arguments: - io_list (List[FunctionIOCatalogEntry]): List of io info io be added - """ - + def create_entries(self, io_list: List[FunctionIOCatalogEntry]): + io_objs = [] for io in io_list: io_obj = FunctionIOCatalog( name=io.name, @@ -86,4 +81,5 @@ def insert_entries(self, io_list: List[FunctionIOCatalogEntry]): is_input=io.is_input, function_id=io.function_id, ) - io_obj.save(self.session) + io_objs.append(io_obj) + return io_objs diff --git a/evadb/catalog/services/function_metadata_catalog_service.py b/evadb/catalog/services/function_metadata_catalog_service.py index e302ea41e..2629b8040 100644 --- a/evadb/catalog/services/function_metadata_catalog_service.py +++ b/evadb/catalog/services/function_metadata_catalog_service.py @@ -30,17 +30,16 @@ class FunctionMetadataCatalogService(BaseService): def __init__(self, db_session: Session): super().__init__(FunctionMetadataCatalog, db_session) - def insert_entries(self, entries: List[FunctionMetadataCatalogEntry]): + def 
create_entries(self, entries: List[FunctionMetadataCatalogEntry]): + metadata_objs = [] try: for entry in entries: metadata_obj = FunctionMetadataCatalog( key=entry.key, value=entry.value, function_id=entry.function_id ) - metadata_obj.save(self.session) + metadata_objs.append(metadata_obj) + return metadata_objs except Exception as e: - logger.exception( - f"Failed to insert entry {entry} into function metadata catalog with exception {str(e)}" - ) raise CatalogError(e) def get_entries_by_function_id( diff --git a/evadb/catalog/services/table_catalog_service.py b/evadb/catalog/services/table_catalog_service.py index dafd6dc2a..2ca1e2e9b 100644 --- a/evadb/catalog/services/table_catalog_service.py +++ b/evadb/catalog/services/table_catalog_service.py @@ -57,8 +57,20 @@ def insert_entry( # populate the table_id for all the columns for column in column_list: column.table_id = table_catalog_obj._row_id - column_list = self._column_service.insert_entries(column_list) + column_list = self._column_service.create_entries(column_list) + # atomic operation for adding table and its corresponding columns. 
+ try: + self.session.add_all(column_list) + self.session.commit() + except Exception as e: + self.session.rollback() + self.session.delete(table_catalog_obj) + self.session.commit() + logger.exception( + f"Failed to insert entry into table catalog with exception {str(e)}" + ) + raise CatalogError(e) except Exception as e: logger.exception( f"Failed to insert entry into table catalog with exception {str(e)}" diff --git a/evadb/catalog/sql_config.py b/evadb/catalog/sql_config.py index f4893ba99..778d2cb24 100644 --- a/evadb/catalog/sql_config.py +++ b/evadb/catalog/sql_config.py @@ -41,6 +41,9 @@ "function_cost_catalog", "function_metadata_catalog", ] +# Add all keywords that are restricted by EvaDB + +RESTRICTED_COL_NAMES = [IDENTIFIER_COLUMN] class SingletonMeta(type): diff --git a/test/integration_tests/long/test_create_table_executor.py b/test/integration_tests/long/test_create_table_executor.py index b4e4856b5..7f1bf38de 100644 --- a/test/integration_tests/long/test_create_table_executor.py +++ b/test/integration_tests/long/test_create_table_executor.py @@ -122,6 +122,15 @@ def test_create_table_with_incorrect_info(self): execute_query_fetch_all(self.evadb, create_table) execute_query_fetch_all(self.evadb, "DROP TABLE SlackCSV;") + def test_create_table_with_restricted_keywords(self): + create_table = "CREATE TABLE hello (_row_id INTEGER, price TEXT);" + with self.assertRaises(AssertionError): + execute_query_fetch_all(self.evadb, create_table) + + create_table = "CREATE TABLE hello2 (_ROW_id INTEGER, price TEXT);" + with self.assertRaises(AssertionError): + execute_query_fetch_all(self.evadb, create_table) + if __name__ == "__main__": unittest.main() diff --git a/test/unit_tests/catalog/test_catalog_manager.py b/test/unit_tests/catalog/test_catalog_manager.py index 3c0e23a3c..6149be34c 100644 --- a/test/unit_tests/catalog/test_catalog_manager.py +++ b/test/unit_tests/catalog/test_catalog_manager.py @@ -142,12 +142,13 @@ def test_insert_function( 
function_io_list, function_metadata_list, ) - functionio_mock.return_value.insert_entries.assert_called_with(function_io_list) - functionmetadata_mock.return_value.insert_entries.assert_called_with( - function_metadata_list - ) function_mock.return_value.insert_entry.assert_called_with( - "function", "sample.py", "classification", checksum_mock.return_value + "function", + "sample.py", + "classification", + checksum_mock.return_value, + function_io_list, + function_metadata_list, ) checksum_mock.assert_called_with("sample.py") self.assertEqual(actual, function_mock.return_value.insert_entry.return_value) From b1143047dc4495c5cb17f5387d365f6a4dbb2325 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Tue, 17 Oct 2023 17:15:05 -0400 Subject: [PATCH 04/12] fix: improve testcase (#1294) Test default values of `chunk_size` and `chunk_overlap` --- evadb/configuration/constants.py | 2 + evadb/optimizer/statement_to_opr_converter.py | 7 +++- evadb/parser/lark_visitor/_expressions.py | 8 ++-- evadb/parser/utils.py | 4 +- evadb/readers/document/document_reader.py | 10 ++++- .../relational/test_relational_api.py | 40 ++++++++++++++++--- .../test_statement_to_opr_converter.py | 9 ++++- 7 files changed, 64 insertions(+), 16 deletions(-) diff --git a/evadb/configuration/constants.py b/evadb/configuration/constants.py index 51513462d..8a6f95b5c 100644 --- a/evadb/configuration/constants.py +++ b/evadb/configuration/constants.py @@ -32,3 +32,5 @@ S3_DOWNLOAD_DIR = "s3_downloads" TMP_DIR = "tmp" DEFAULT_TRAIN_TIME_LIMIT = 120 +DEFAULT_DOCUMENT_CHUNK_SIZE = 4000 +DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200 diff --git a/evadb/optimizer/statement_to_opr_converter.py b/evadb/optimizer/statement_to_opr_converter.py index fa0b08605..c60d0b258 100644 --- a/evadb/optimizer/statement_to_opr_converter.py +++ b/evadb/optimizer/statement_to_opr_converter.py @@ -74,7 +74,12 @@ def visit_table_ref(self, table_ref: TableRef): if table_ref.is_table_atom(): # Table catalog_entry = 
table_ref.table.table_obj - self._plan = LogicalGet(table_ref, catalog_entry, table_ref.alias) + self._plan = LogicalGet( + table_ref, + catalog_entry, + table_ref.alias, + chunk_params=table_ref.chunk_params, + ) elif table_ref.is_table_valued_expr(): tve = table_ref.table_valued_expr diff --git a/evadb/parser/lark_visitor/_expressions.py b/evadb/parser/lark_visitor/_expressions.py index 91b5be77c..6ec01cf99 100644 --- a/evadb/parser/lark_visitor/_expressions.py +++ b/evadb/parser/lark_visitor/_expressions.py @@ -145,18 +145,18 @@ def chunk_params(self, tree): assert len(chunk_params) == 2 or len(chunk_params) == 4 if len(chunk_params) == 4: return { - "chunk_size": ConstantValueExpression(chunk_params[1]), - "chunk_overlap": ConstantValueExpression(chunk_params[3]), + "chunk_size": chunk_params[1], + "chunk_overlap": chunk_params[3], } elif len(chunk_params) == 2: if chunk_params[0] == "CHUNK_SIZE": return { - "chunk_size": ConstantValueExpression(chunk_params[1]), + "chunk_size": chunk_params[1], } elif chunk_params[0] == "CHUNK_OVERLAP": return { - "chunk_overlap": ConstantValueExpression(chunk_params[1]), + "chunk_overlap": chunk_params[1], } else: assert f"incorrect keyword found {chunk_params[0]}" diff --git a/evadb/parser/utils.py b/evadb/parser/utils.py index 28f13e0a9..a2be06ec1 100644 --- a/evadb/parser/utils.py +++ b/evadb/parser/utils.py @@ -51,9 +51,9 @@ def parse_predicate_expression(expr: str): def parse_table_clause(expr: str, chunk_size: int = None, chunk_overlap: int = None): mock_query_parts = [f"SELECT * FROM {expr}"] - if chunk_size: + if chunk_size is not None: mock_query_parts.append(f"CHUNK_SIZE {chunk_size}") - if chunk_overlap: + if chunk_overlap is not None: mock_query_parts.append(f"CHUNK_OVERLAP {chunk_overlap}") mock_query_parts.append(";") mock_query = " ".join(mock_query_parts) diff --git a/evadb/readers/document/document_reader.py b/evadb/readers/document/document_reader.py index a11d0c748..4e90fa8b6 100644 --- 
a/evadb/readers/document/document_reader.py +++ b/evadb/readers/document/document_reader.py @@ -16,6 +16,10 @@ from typing import Dict, Iterator from evadb.catalog.sql_config import ROW_NUM_COLUMN +from evadb.configuration.constants import ( + DEFAULT_DOCUMENT_CHUNK_OVERLAP, + DEFAULT_DOCUMENT_CHUNK_SIZE, +) from evadb.readers.abstract_reader import AbstractReader from evadb.readers.document.registry import ( _lazy_import_loader, @@ -31,8 +35,10 @@ def __init__(self, *args, chunk_params, **kwargs): # https://github.com/hwchase17/langchain/blob/5b6bbf4ab2a33ed0d33ff5d3cb3979a7edc15682/langchain/text_splitter.py#L570 # by default we use chunk_size 4000 and overlap 200 - self._chunk_size = chunk_params.get("chunk_size", 4000) - self._chunk_overlap = chunk_params.get("chunk_overlap", 200) + self._chunk_size = chunk_params.get("chunk_size", DEFAULT_DOCUMENT_CHUNK_SIZE) + self._chunk_overlap = chunk_params.get( + "chunk_overlap", DEFAULT_DOCUMENT_CHUNK_OVERLAP + ) def _read(self) -> Iterator[Dict]: ext = Path(self.file_url).suffix diff --git a/test/integration_tests/long/interfaces/relational/test_relational_api.py b/test/integration_tests/long/interfaces/relational/test_relational_api.py index 6df6ef920..bfacb9315 100644 --- a/test/integration_tests/long/interfaces/relational/test_relational_api.py +++ b/test/integration_tests/long/interfaces/relational/test_relational_api.py @@ -27,7 +27,12 @@ from pandas.testing import assert_frame_equal from evadb.binder.binder_utils import BinderError -from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_ROOT_DIR +from evadb.configuration.constants import ( + DEFAULT_DOCUMENT_CHUNK_OVERLAP, + DEFAULT_DOCUMENT_CHUNK_SIZE, + EvaDB_DATABASE_DIR, + EvaDB_ROOT_DIR, +) from evadb.executor.executor_utils import ExecutorError from evadb.interfaces.relational.db import connect from evadb.models.storage.batch import Batch @@ -392,22 +397,47 @@ def test_langchain_split_doc(self): load_pdf.execute() result1 = ( - 
cursor.table("docs", chunk_size=2000, chunk_overlap=0).select("data").df() + cursor.table( + "docs", chunk_size=2000, chunk_overlap=DEFAULT_DOCUMENT_CHUNK_OVERLAP + ) + .select("data") + .df() ) result2 = ( - cursor.table("docs", chunk_size=4000, chunk_overlap=2000) + cursor.table( + "docs", chunk_size=DEFAULT_DOCUMENT_CHUNK_SIZE, chunk_overlap=2000 + ) .select("data") .df() ) - self.assertEqual(len(result1), len(result2)) + result3 = ( + cursor.table( + "docs", chunk_size=DEFAULT_DOCUMENT_CHUNK_SIZE, chunk_overlap=0 + ) + .select("data") + .df() + ) + + self.assertGreater(len(result1), len(result2)) + self.assertGreater(len(result2), len(result3)) + # should use default value of chunk_overlap and respect chunk_size + result5 = cursor.table("docs", chunk_size=2000).select("data").df() + self.assertEqual(len(result5), len(result1)) + + # should use the default value of chunk_size and should respect chunk_overlap + result4 = cursor.table("docs", chunk_overlap=0).select("data").df() + self.assertEqual(len(result3), len(result4)) + + # should use the default values result1 = cursor.table("docs").select("data").df() result2 = cursor.query( - "SELECT data from docs chunk_size 4000 chunk_overlap 200" + f"SELECT data from docs chunk_size {DEFAULT_DOCUMENT_CHUNK_SIZE} chunk_overlap {DEFAULT_DOCUMENT_CHUNK_OVERLAP}" ).df() + self.assertEqual(len(result1), len(result2)) def test_show_relational(self): diff --git a/test/unit_tests/optimizer/test_statement_to_opr_converter.py b/test/unit_tests/optimizer/test_statement_to_opr_converter.py index beeeac94b..12d5d4f08 100644 --- a/test/unit_tests/optimizer/test_statement_to_opr_converter.py +++ b/test/unit_tests/optimizer/test_statement_to_opr_converter.py @@ -63,11 +63,16 @@ class StatementToOprTest(unittest.TestCase): @patch("evadb.optimizer.statement_to_opr_converter.LogicalGet") def test_visit_table_ref_should_create_logical_get_opr(self, mock_lget): converter = StatementToPlanConverter() - table_ref = MagicMock(spec=TableRef, 
alias="alias") + table_ref = MagicMock(spec=TableRef, alias="alias", chunk_params={}) table_ref.is_select.return_value = False table_ref.sample_freq = None converter.visit_table_ref(table_ref) - mock_lget.assert_called_with(table_ref, table_ref.table.table_obj, "alias") + mock_lget.assert_called_with( + table_ref, + table_ref.table.table_obj, + "alias", + chunk_params=table_ref.chunk_params, + ) self.assertEqual(mock_lget.return_value, converter._plan) @patch("evadb.optimizer.statement_to_opr_converter.LogicalFilter") From 201f901bee6af81e1790e7c4c8f2a5a05093ca6d Mon Sep 17 00:00:00 2001 From: jineetd <35962652+jineetd@users.noreply.github.com> Date: Wed, 18 Oct 2023 01:52:04 -0400 Subject: [PATCH 05/12] Starting the change for XGBoost integration into EVADb. (#1232) Co-authored-by: Jineet Desai Co-authored-by: Andy Xu --- docs/_toc.yml | 2 + .../reference/ai/model-train-xgboost.rst | 26 +++++++ evadb/binder/statement_binder.py | 4 +- evadb/configuration/constants.py | 1 + evadb/executor/create_function_executor.py | 69 +++++++++++++++++++ evadb/functions/sklearn.py | 16 ++--- evadb/functions/xgboost.py | 48 +++++++++++++ evadb/utils/generic_utils.py | 19 +++++ setup.py | 5 +- .../long/test_model_train.py | 21 +++++- test/markers.py | 5 ++ 11 files changed, 205 insertions(+), 11 deletions(-) create mode 100644 docs/source/reference/ai/model-train-xgboost.rst create mode 100644 evadb/functions/xgboost.py diff --git a/docs/_toc.yml b/docs/_toc.yml index a8639dec3..3b3eeda5e 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -88,6 +88,8 @@ parts: title: Model Training with Ludwig - file: source/reference/ai/model-train-sklearn title: Model Training with Sklearn + - file: source/reference/ai/model-train-xgboost + title: Model Training with XGBoost - file: source/reference/ai/model-forecasting title: Time Series Forecasting - file: source/reference/ai/hf diff --git a/docs/source/reference/ai/model-train-xgboost.rst b/docs/source/reference/ai/model-train-xgboost.rst new 
file mode 100644 index 000000000..b53c87d48 --- /dev/null +++ b/docs/source/reference/ai/model-train-xgboost.rst @@ -0,0 +1,26 @@ +.. _xgboost: + +Model Training with XGBoost +============================ + +1. Installation +--------------- + +To use the `Flaml XGBoost AutoML framework `_, we need to install the extra Flaml dependency in your EvaDB virtual environment. + +.. code-block:: bash + + pip install "flaml[automl]" + +2. Example Query +---------------- + +.. code-block:: sql + + CREATE FUNCTION IF NOT EXISTS PredictRent FROM + ( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals ) + TYPE XGBoost + PREDICT 'rental_price'; + +In the above query, you are creating a new customized function by training a model from the ``HomeRentals`` table using the ``Flaml XGBoost`` framework. +The ``rental_price`` column will be the target column for prediction, while the remaining columns from the ``SELECT`` query are the inputs. diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index f9087b5be..f1e949941 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -102,7 +102,9 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement): outputs.append(column) else: inputs.append(column) - elif string_comparison_case_insensitive(node.function_type, "sklearn"): + elif string_comparison_case_insensitive( + node.function_type, "sklearn" + ) or string_comparison_case_insensitive(node.function_type, "XGBoost"): assert ( "predict" in arg_map ), f"Creating {node.function_type} functions expects 'predict' metadata." 
diff --git a/evadb/configuration/constants.py b/evadb/configuration/constants.py index 8a6f95b5c..395f898be 100644 --- a/evadb/configuration/constants.py +++ b/evadb/configuration/constants.py @@ -34,3 +34,4 @@ DEFAULT_TRAIN_TIME_LIMIT = 120 DEFAULT_DOCUMENT_CHUNK_SIZE = 4000 DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200 +DEFAULT_TRAIN_REGRESSION_METRIC = "rmse" diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 0b4ddbf7c..379157563 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -25,6 +25,7 @@ from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry from evadb.configuration.constants import ( + DEFAULT_TRAIN_REGRESSION_METRIC, DEFAULT_TRAIN_TIME_LIMIT, EvaDB_INSTALLATION_DIR, ) @@ -44,6 +45,7 @@ try_to_import_statsforecast, try_to_import_torch, try_to_import_ultralytics, + try_to_import_xgboost, ) from evadb.utils.logging_manager import logger @@ -152,6 +154,10 @@ def handle_sklearn_function(self): self.node.metadata.append( FunctionMetadataCatalogEntry("model_path", model_path) ) + # Pass the prediction column name to sklearn.py + self.node.metadata.append( + FunctionMetadataCatalogEntry("predict_col", arg_map["predict"]) + ) impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix() io_list = self._resolve_function_io(None) @@ -163,6 +169,61 @@ def handle_sklearn_function(self): self.node.metadata, ) + def handle_xgboost_function(self): + """Handle xgboost functions + + We use the Flaml AutoML model for training xgboost models. 
+ """ + try_to_import_xgboost() + + assert ( + len(self.children) == 1 + ), "Create sklearn function expects 1 child, finds {}.".format( + len(self.children) + ) + + aggregated_batch_list = [] + child = self.children[0] + for batch in child.exec(): + aggregated_batch_list.append(batch) + aggregated_batch = Batch.concat(aggregated_batch_list, copy=False) + aggregated_batch.drop_column_alias() + + arg_map = {arg.key: arg.value for arg in self.node.metadata} + from flaml import AutoML + + model = AutoML() + settings = { + "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT), + "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC), + "estimator_list": ["xgboost"], + "task": "regression", + } + model.fit( + dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings + ) + model_path = os.path.join( + self.db.config.get_value("storage", "model_dir"), self.node.name + ) + pickle.dump(model, open(model_path, "wb")) + self.node.metadata.append( + FunctionMetadataCatalogEntry("model_path", model_path) + ) + # Pass the prediction column to xgboost.py. 
+ self.node.metadata.append( + FunctionMetadataCatalogEntry("predict_col", arg_map["predict"]) + ) + + impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix() + io_list = self._resolve_function_io(None) + return ( + self.node.name, + impl_path, + self.node.function_type, + io_list, + self.node.metadata, + ) + def handle_ultralytics_function(self): """Handle Ultralytics functions""" try_to_import_ultralytics() @@ -516,6 +577,14 @@ def exec(self, *args, **kwargs): io_list, metadata, ) = self.handle_sklearn_function() + elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"): + ( + name, + impl_path, + function_type, + io_list, + metadata, + ) = self.handle_xgboost_function() elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"): ( name, diff --git a/evadb/functions/sklearn.py b/evadb/functions/sklearn.py index ca3676f14..4ab2b0abf 100644 --- a/evadb/functions/sklearn.py +++ b/evadb/functions/sklearn.py @@ -25,21 +25,21 @@ class GenericSklearnModel(AbstractFunction): def name(self) -> str: return "GenericSklearnModel" - def setup(self, model_path: str, **kwargs): + def setup(self, model_path: str, predict_col: str, **kwargs): try_to_import_sklearn() self.model = pickle.load(open(model_path, "rb")) + self.predict_col = predict_col def forward(self, frames: pd.DataFrame) -> pd.DataFrame: - # The last column is the predictor variable column. Hence we do not - # pass that column in the predict method for sklearn. - predictions = self.model.predict(frames.iloc[:, :-1]) + # Do not pass the prediction column in the predict method for sklearn. + frames.drop([self.predict_col], axis=1, inplace=True) + predictions = self.model.predict(frames) predict_df = pd.DataFrame(predictions) # We need to rename the column of the output dataframe. For this we - # shall rename it to the column name same as that of the last column of - # frames. 
This is because the last column of frames corresponds to the - # variable we want to predict. - predict_df.rename(columns={0: frames.columns[-1]}, inplace=True) + # shall rename it to the column name same as that of the predict column + # passed in the training frames in EVA query. + predict_df.rename(columns={0: self.predict_col}, inplace=True) return predict_df def to_device(self, device: str): diff --git a/evadb/functions/xgboost.py b/evadb/functions/xgboost.py new file mode 100644 index 000000000..063529411 --- /dev/null +++ b/evadb/functions/xgboost.py @@ -0,0 +1,48 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pickle + +import pandas as pd + +from evadb.functions.abstract.abstract_function import AbstractFunction +from evadb.utils.generic_utils import try_to_import_xgboost + + +class GenericXGBoostModel(AbstractFunction): + @property + def name(self) -> str: + return "GenericXGBoostModel" + + def setup(self, model_path: str, predict_col: str, **kwargs): + try_to_import_xgboost() + + self.model = pickle.load(open(model_path, "rb")) + self.predict_col = predict_col + + def forward(self, frames: pd.DataFrame) -> pd.DataFrame: + # We do not pass the prediction column to the predict method of XGBoost + # AutoML. + frames.drop([self.predict_col], axis=1, inplace=True) + predictions = self.model.predict(frames) + predict_df = pd.DataFrame(predictions) + # We need to rename the column of the output dataframe. 
For this we + # shall rename it to the column name same as that of the predict column + # passed to EVA query. + predict_df.rename(columns={0: self.predict_col}, inplace=True) + return predict_df + + def to_device(self, device: str): + # TODO figure out how to control the GPU for ludwig models + return self diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py index a444fb983..fb6bd9986 100644 --- a/evadb/utils/generic_utils.py +++ b/evadb/utils/generic_utils.py @@ -377,6 +377,25 @@ def is_sklearn_available() -> bool: return False +def try_to_import_xgboost(): + try: + import flaml # noqa: F401 + from flaml import AutoML # noqa: F401 + except ImportError: + raise ValueError( + """Could not import Flaml AutoML. + Please install it with `pip install "flaml[automl]"`.""" + ) + + +def is_xgboost_available() -> bool: + try: + try_to_import_xgboost() + return True + except ValueError: # noqa: E722 + return False + + ############################## ## VISION ############################## diff --git a/setup.py b/setup.py index 9c488c939..a18796d84 100644 --- a/setup.py +++ b/setup.py @@ -120,6 +120,8 @@ def read(path, encoding="utf-8"): sklearn_libs = ["scikit-learn"] +xgboost_libs = ["flaml[automl]"] + forecasting_libs = [ "statsforecast", # MODEL TRAIN AND FINE TUNING "neuralforecast" # MODEL TRAIN AND FINE TUNING @@ -169,9 +171,10 @@ def read(path, encoding="utf-8"): "postgres": postgres_libs, "ludwig": ludwig_libs, "sklearn": sklearn_libs, + "xgboost": xgboost_libs, "forecasting": forecasting_libs, # everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11. 
- "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs, + "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs } setup( diff --git a/test/integration_tests/long/test_model_train.py b/test/integration_tests/long/test_model_train.py index 7424ba424..85e508f4d 100644 --- a/test/integration_tests/long/test_model_train.py +++ b/test/integration_tests/long/test_model_train.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest -from test.markers import ludwig_skip_marker, sklearn_skip_marker +from test.markers import ludwig_skip_marker, sklearn_skip_marker, xgboost_skip_marker from test.util import get_evadb_for_testing, shutdown_ray import pytest @@ -95,6 +95,25 @@ def test_sklearn_regression(self): self.assertEqual(len(result.columns), 1) self.assertEqual(len(result), 10) + @xgboost_skip_marker + def test_xgboost_regression(self): + create_predict_function = """ + CREATE FUNCTION IF NOT EXISTS PredictRent FROM + ( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals ) + TYPE XGBoost + PREDICT 'rental_price' + TIME_LIMIT 180 + METRIC 'r2'; + """ + execute_query_fetch_all(self.evadb, create_predict_function) + + predict_query = """ + SELECT PredictRent(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10; + """ + result = execute_query_fetch_all(self.evadb, predict_query) + self.assertEqual(len(result.columns), 1) + self.assertEqual(len(result), 10) + if __name__ == "__main__": unittest.main() diff --git a/test/markers.py b/test/markers.py index 7d98e5534..3e95c1cff 100644 --- a/test/markers.py +++ b/test/markers.py @@ -27,6 +27,7 @@ is_qdrant_available, is_replicate_available, is_sklearn_available, + is_xgboost_available, ) asyncio_skip_marker = 
pytest.mark.skipif( @@ -89,6 +90,10 @@ is_sklearn_available() is False, reason="Run only if sklearn is available" ) +xgboost_skip_marker = pytest.mark.skipif( + is_xgboost_available() is False, reason="Run only if xgboost is available" +) + chatgpt_skip_marker = pytest.mark.skip( reason="requires chatgpt", ) From b8dd206d0550ff5b0ad03e1026702f236caacb33 Mon Sep 17 00:00:00 2001 From: Abhijith S Raj <63101280+sudoboi@users.noreply.github.com> Date: Wed, 18 Oct 2023 11:48:33 -0400 Subject: [PATCH 06/12] Add Documentation for UDF Unit Testing and Mocking (and minor Stable Diffusion Fix) (#1301) --- docs/_toc.yml | 2 + .../source/dev-guide/contribute/unit-test.rst | 80 +++++++++++++++++++ evadb/functions/dalle.py | 2 +- evadb/functions/stable_diffusion.py | 2 +- 4 files changed, 84 insertions(+), 2 deletions(-) create mode 100644 docs/source/dev-guide/contribute/unit-test.rst diff --git a/docs/_toc.yml b/docs/_toc.yml index 3b3eeda5e..6a37f9826 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -136,6 +136,8 @@ parts: title: Code Style - file: source/dev-guide/contribute/troubleshoot title: Troubleshooting + - file: source/dev-guide/contribute/unit-test + title: Unit Testing UDFs in EvaDB - file: source/dev-guide/debugging title: Debugging EvaDB diff --git a/docs/source/dev-guide/contribute/unit-test.rst b/docs/source/dev-guide/contribute/unit-test.rst new file mode 100644 index 000000000..e5c957f87 --- /dev/null +++ b/docs/source/dev-guide/contribute/unit-test.rst @@ -0,0 +1,80 @@ +Unit Testing UDFs in EvaDB +=========================== + +Introduction +------------ + +Unit testing is a crucial aspect of software development. When working with User Defined Functions (UDFs) in EvaDB, it's essential to ensure that they work correctly. This guide will walk you through the process of writing unit tests for UDFs and using mocking to simulate external dependencies. + +Setting Up Test Environment +--------------------------- + +Before writing tests, set up a test environment. 
This often involves creating a test database or table and populating it with sample data. + +.. code-block:: python + + def setUp(self) -> None: + self.evadb = get_evadb_for_testing() + self.evadb.catalog().reset() + create_table_query = """CREATE TABLE IF NOT EXISTS TestTable ( + prompt TEXT(100)); + """ + execute_query_fetch_all(self.evadb, create_table_query) + test_prompts = ["sample prompt"] + for prompt in test_prompts: + insert_query = f"""INSERT INTO TestTable (prompt) VALUES ('{prompt}')""" + execute_query_fetch_all(self.evadb, insert_query) + +Mocking External Dependencies +----------------------------- + +When testing UDFs that rely on external services or APIs, use mocking to simulate these dependencies. + +.. code-block:: python + + @patch("requests.get") + @patch("external_library.Method", return_value={"data": [{"url": "mocked_url"}]}) + def test_udf(self, mock_method, mock_requests_get): + # Mock the response from the external service + mock_response = MagicMock() + mock_response.content = "mocked content" + mock_requests_get.return_value = mock_response + + # Rest of the test code... + +Writing the Test +---------------- + +After setting up the environment and mocking dependencies, write the test for the UDF. + +.. code-block:: python + + function_name = "ImageDownloadUDF" + query = f"SELECT {function_name}(prompt) FROM TestTable;" + output = execute_query_fetch_all(self.evadb, query) + expected_output = [...] # Expected output + self.assertEqual(output, expected_output) + +Cleaning Up After Tests +----------------------- + +Clean up any resources used during testing, such as database tables. + +.. code-block:: python + + def tearDown(self) -> None: + execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS TestTable;") + +Running the Tests +----------------- + +Run the tests using a test runner like `unittest`. + +.. 
code-block:: bash + + python -m unittest path_to_your_test_module.py + +Conclusion +---------- + +Unit testing UDFs in EvaDB ensures their correctness and robustness. Mocking allows for simulating external dependencies, making tests faster and more deterministic. diff --git a/evadb/functions/dalle.py b/evadb/functions/dalle.py index efc075d73..d373fda38 100644 --- a/evadb/functions/dalle.py +++ b/evadb/functions/dalle.py @@ -62,7 +62,7 @@ def forward(self, text_df): # Register API key, try configuration manager first openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY") # If not found, try OS Environment Variable - if len(openai.api_key) == 0: + if openai.api_key is None or len(openai.api_key) == 0: openai.api_key = os.environ.get("OPENAI_KEY", "") assert ( len(openai.api_key) != 0 diff --git a/evadb/functions/stable_diffusion.py b/evadb/functions/stable_diffusion.py index 6e84d687f..044195547 100644 --- a/evadb/functions/stable_diffusion.py +++ b/evadb/functions/stable_diffusion.py @@ -69,7 +69,7 @@ def forward(self, text_df): "third_party", "REPLICATE_API_TOKEN" ) # If not found, try OS Environment Variable - if len(replicate_api_key) == 0: + if replicate_api_key is None: replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "") assert ( len(replicate_api_key) != 0 From f192a10e8c5cc0712931480f4d62915710250c4b Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Wed, 18 Oct 2023 12:50:32 -0700 Subject: [PATCH 07/12] Reenable batch for release (#1302) --- evadb/storage/native_storage_engine.py | 6 +++--- evadb/storage/sqlite_storage_engine.py | 14 +++++++------- .../long/test_github_datasource.py | 3 +++ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/evadb/storage/native_storage_engine.py b/evadb/storage/native_storage_engine.py index 66fede5bb..ed8f5e8be 100644 --- a/evadb/storage/native_storage_engine.py +++ b/evadb/storage/native_storage_engine.py @@ -28,7 +28,7 @@ from evadb.models.storage.batch import Batch from 
evadb.storage.abstract_storage_engine import AbstractStorageEngine from evadb.third_party.databases.interface import get_database_handler -from evadb.utils.generic_utils import PickleSerializer +from evadb.utils.generic_utils import PickleSerializer, rebatch from evadb.utils.logging_manager import logger @@ -190,8 +190,8 @@ def read( _deserialize_sql_row(row, ordered_columns) for row in result ) - for data_batch in result: - yield Batch(pd.DataFrame([data_batch])) + for df in rebatch(result, batch_mem_size): + yield Batch(pd.DataFrame(df)) except Exception as e: err_msg = f"Failed to read the table {table.name} in data source {table.database_name} with exception {str(e)}" diff --git a/evadb/storage/sqlite_storage_engine.py b/evadb/storage/sqlite_storage_engine.py index 91b72bb44..2c3335f56 100644 --- a/evadb/storage/sqlite_storage_engine.py +++ b/evadb/storage/sqlite_storage_engine.py @@ -29,7 +29,7 @@ from evadb.models.storage.batch import Batch from evadb.parser.table_ref import TableInfo from evadb.storage.abstract_storage_engine import AbstractStorageEngine -from evadb.utils.generic_utils import PickleSerializer +from evadb.utils.generic_utils import PickleSerializer, rebatch from evadb.utils.logging_manager import logger # Leveraging Dynamic schema in SQLAlchemy @@ -189,12 +189,12 @@ def read( try: table_to_read = self._try_loading_table_via_reflection(table.name) result = self._sql_session.execute(table_to_read.select()).fetchall() - for row in result: - yield Batch( - pd.DataFrame( - [self._deserialize_sql_row(row._asdict(), table.columns)] - ) - ) + result_iter = ( + self._deserialize_sql_row(row._asdict(), table.columns) + for row in result + ) + for df in rebatch(result_iter, batch_mem_size): + yield Batch(pd.DataFrame(df)) except Exception as e: err_msg = f"Failed to read the table {table.name} with exception {str(e)}" logger.exception(err_msg) diff --git a/test/integration_tests/long/test_github_datasource.py 
b/test/integration_tests/long/test_github_datasource.py index 0c02dff19..1d00728b2 100644 --- a/test/integration_tests/long/test_github_datasource.py +++ b/test/integration_tests/long/test_github_datasource.py @@ -31,6 +31,9 @@ def setUp(self): def tearDown(self): execute_query_fetch_all(self.evadb, "DROP DATABASE IF EXISTS github_data;") + @pytest.mark.skip( + reason="Need https://github.com/georgia-tech-db/evadb/pull/1280 for a cost-based rebatch optimization" + ) @pytest.mark.xfail(reason="Flaky testcase due to `bad request` error message") def test_should_run_select_query_in_github(self): # Create database. From c3b45b61100b1b343d96f748cc8dd97ca4f52cdf Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Wed, 18 Oct 2023 17:42:06 -0400 Subject: [PATCH 08/12] v0.3.8 - new release (#1303) --- evadb/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evadb/version.py b/evadb/version.py index 2b48fce4d..e58316190 100644 --- a/evadb/version.py +++ b/evadb/version.py @@ -1,6 +1,6 @@ _MAJOR = "0" _MINOR = "3" -_REVISION = "8+dev" +_REVISION = "8" VERSION_SHORT = f"{_MAJOR}.{_MINOR}" VERSION = f"{_MAJOR}.{_MINOR}.{_REVISION}" From 89e48889186d919fdade64961d1a8277bdc9e946 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 18 Oct 2023 22:18:37 -0700 Subject: [PATCH 09/12] Bump Version to v0.3.9+dev (#1304) Bump Version to v0.3.9+dev --------- Co-authored-by: Jiashen Cao Co-authored-by: Gaurav Tarlok Kakkar --- CHANGELOG.md | 31 +++++++++++++++++++ README.md | 2 +- .../reference/ai/custom-ai-function.rst | 2 +- evadb/version.py | 2 +- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5557e9fad..f8ef0b95d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,37 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### [Deprecated] ### [Removed] +## [0.3.8] - 2023-10-18 + +* PR #1303: 
v0.3.8 - new release +* PR #1302: Reenable batch for release +* PR #1301: Add Documentation for UDF Unit Testing and Mocking +* PR #1232: Starting the change for XGBoost integration into EVADb. +* PR #1294: fix: improve testcase +* PR #1293: fix: make the table/function catalog insert operation atomic +* PR #1295: feat: add support for show databases +* PR #1296: feat: function_metadata supports boolean and float +* PR #1290: fix: text_summarization uses drop udf +* PR #1240: Add stable diffusion integration +* PR #1285: Update custom-ai-function.rst +* PR #1234: Added basic functionalities of REST apis +* PR #1281: Clickhouse integration +* PR #1273: Update custom-ai-function.rst +* PR #1274: Fix Notebook and Ray testcases at staging +* PR #1264: SHOW command for retrieveing configurations +* PR #1270: fix: Catalog init introduces significant overhead +* PR #1267: Improve the error message when there is a typo in the column name in the query. +* PR #1261: Remove dimensions from `TEXT` and `FLOAT` +* PR #1256: Remove table names from column names for `df +* PR #1253: Collection of fixes for the staging branch +* PR #1246: feat: insertion update index +* PR #1245: Documentation on vector stores + vector benchmark +* PR #1244: feat: create index from projection +* PR #1233: GitHub Data Source Integration +* PR #1115: Add support for Neuralforecast +* PR #1241: Bump Version to v0.3.8+dev +* PR #1239: release 0.3.7 + ## [0.3.7] - 2023-09-30 * PR #1239: release 0.3.7 diff --git a/README.md b/README.md index 05fd08772..92b3a5e68 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Our target audience is software developers who may not necessarily have a backgr
    -
  • Connect EvaDB to your SQL and vector database systems with the `CREATE DATABASE` and `CREATE INDEX` statements.
  • +
  • Connect EvaDB to your SQL and vector database systems with the `CREATE DATABASE` and `CREATE INDEX` statements.
  • Write SQL queries with AI functions to get inference results:
    • Pick a pre-trained AI model from Hugging Face, Open AI, Ultralytics, PyTorch, and built-in AI frameworks for generative AI, NLP, and vision applications;
    • diff --git a/docs/source/reference/ai/custom-ai-function.rst b/docs/source/reference/ai/custom-ai-function.rst index 71c19bb91..68e1d6461 100644 --- a/docs/source/reference/ai/custom-ai-function.rst +++ b/docs/source/reference/ai/custom-ai-function.rst @@ -44,7 +44,7 @@ The abstract method `setup` must be implemented in your function. The setup func Any additional arguments needed for creating the function must be passed as arguments to the setup function. (Please refer to the `ChatGPT `__ function example). -The additional arguments are passed with the CREATE command. Please refer to `CREATE `_ command documentation. +The additional arguments are passed with the CREATE command. Please refer to `CREATE `_ command documentation. The custom setup operations for the function can be written inside the function in the child class. If there is no need for any custom logic, then you can just simply write "pass" in the function definition. diff --git a/evadb/version.py b/evadb/version.py index e58316190..c6ac2fe6f 100644 --- a/evadb/version.py +++ b/evadb/version.py @@ -1,6 +1,6 @@ _MAJOR = "0" _MINOR = "3" -_REVISION = "8" +_REVISION = "9+dev" VERSION_SHORT = f"{_MAJOR}.{_MINOR}" VERSION = f"{_MAJOR}.{_MINOR}.{_REVISION}" From e19f13da3624bc8bb19292b121a12fa8c5ea6b36 Mon Sep 17 00:00:00 2001 From: Sayan Sinha Date: Thu, 19 Oct 2023 09:54:42 -0400 Subject: [PATCH 10/12] Fix current issues with forecasting (#1283) This PR aims to solve the following issues: - [x] Throwing error when non-numeric characters are in the data (partially fixes #1243) - [x] Math domain error with `statsforecast`. - [x] Fix GPU support for `neuralforecast`. - ~Neuralforecast support for directly using batched data.~ - ~Auto frequency determination ( #1279).~ Will create separate PRs for the last two points. 
--------- Co-authored-by: Andy Xu --- evadb/executor/create_function_executor.py | 76 +++++++++++++++++++--- evadb/functions/forecast.py | 15 +++-- 2 files changed, 77 insertions(+), 14 deletions(-) diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 379157563..367110b1d 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -12,9 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import hashlib +import locale import os import pickle +import re from pathlib import Path from typing import Dict, List @@ -50,6 +53,31 @@ from evadb.utils.logging_manager import logger +# From https://stackoverflow.com/a/34333710 +@contextlib.contextmanager +def set_env(**environ): + """ + Temporarily set the process environment variables. + + >>> with set_env(PLUGINS_DIR='test/plugins'): + ... 
"PLUGINS_DIR" in os.environ + True + + >>> "PLUGINS_DIR" in os.environ + False + + :type environ: dict[str, unicode] + :param environ: Environment variables to set + """ + old_environ = dict(os.environ) + os.environ.update(environ) + try: + yield + finally: + os.environ.clear() + os.environ.update(old_environ) + + class CreateFunctionExecutor(AbstractExecutor): def __init__(self, db: EvaDBDatabase, node: CreateFunctionPlan): super().__init__(db, node) @@ -169,6 +197,15 @@ def handle_sklearn_function(self): self.node.metadata, ) + def convert_to_numeric(self, x): + x = re.sub("[^0-9.,]", "", str(x)) + locale.setlocale(locale.LC_ALL, "") + x = float(locale.atof(x)) + if x.is_integer(): + return int(x) + else: + return x + def handle_xgboost_function(self): """Handle xgboost functions @@ -245,7 +282,6 @@ def handle_ultralytics_function(self): def handle_forecasting_function(self): """Handle forecasting functions""" - os.environ["CUDA_VISIBLE_DEVICES"] = "" aggregated_batch_list = [] child = self.children[0] for batch in child.exec(): @@ -369,7 +405,7 @@ def handle_forecasting_function(self): model_args["input_size"] = 2 * horizon model_args["early_stop_patience_steps"] = 20 else: - model_args["config"] = { + model_args_config = { "input_size": 2 * horizon, "early_stop_patience_steps": 20, } @@ -381,7 +417,13 @@ def handle_forecasting_function(self): if "auto" not in arg_map["model"].lower(): model_args["hist_exog_list"] = exogenous_columns else: - model_args["config"]["hist_exog_list"] = exogenous_columns + model_args_config["hist_exog_list"] = exogenous_columns + + def get_optuna_config(trial): + return model_args_config + + model_args["config"] = get_optuna_config + model_args["backend"] = "optuna" model_args["h"] = horizon @@ -455,13 +497,31 @@ def handle_forecasting_function(self): ] if len(existing_model_files) == 0: logger.info("Training, please wait...") + for column in data.columns: + if column != "ds" and column != "unique_id": + data[column] = data.apply( + 
lambda x: self.convert_to_numeric(x[column]), axis=1 + ) if library == "neuralforecast": - model.fit(df=data, val_size=horizon) + cuda_devices_here = "0" + if "CUDA_VISIBLE_DEVICES" in os.environ: + cuda_devices_here = os.environ["CUDA_VISIBLE_DEVICES"].split(",")[0] + + with set_env(CUDA_VISIBLE_DEVICES=cuda_devices_here): + model.fit(df=data, val_size=horizon) + model.save(model_path, overwrite=True) else: + # The following lines of code helps eliminate the math error encountered in statsforecast when only one datapoint is available in a time series + for col in data["unique_id"].unique(): + if len(data[data["unique_id"] == col]) == 1: + data = data._append( + [data[data["unique_id"] == col]], ignore_index=True + ) + model.fit(df=data[["ds", "y", "unique_id"]]) - f = open(model_path, "wb") - pickle.dump(model, f) - f.close() + f = open(model_path, "wb") + pickle.dump(model, f) + f.close() elif not Path(model_path).exists(): model_path = os.path.join(model_dir, existing_model_files[-1]) @@ -483,8 +543,6 @@ def handle_forecasting_function(self): FunctionMetadataCatalogEntry("library", library), ] - os.environ.pop("CUDA_VISIBLE_DEVICES", None) - return ( self.node.name, impl_path, diff --git a/evadb/functions/forecast.py b/evadb/functions/forecast.py index 1571f6c4f..46376852d 100644 --- a/evadb/functions/forecast.py +++ b/evadb/functions/forecast.py @@ -38,16 +38,21 @@ def setup( horizon: int, library: str, ): - f = open(model_path, "rb") - loaded_model = pickle.load(f) - f.close() + self.library = library + if "neuralforecast" in self.library: + from neuralforecast import NeuralForecast + + loaded_model = NeuralForecast.load(path=model_path) + self.model_name = model_name[4:] if "Auto" in model_name else model_name + else: + with open(model_path, "rb") as f: + loaded_model = pickle.load(f) + self.model_name = model_name self.model = loaded_model - self.model_name = model_name self.predict_column_rename = predict_column_rename self.time_column_rename = 
time_column_rename self.id_column_rename = id_column_rename self.horizon = int(horizon) - self.library = library def forward(self, data) -> pd.DataFrame: if self.library == "statsforecast": From 4640d8fc6579ad2f5f8425814a23c1850d3bb8b4 Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Thu, 19 Oct 2023 14:03:51 -0400 Subject: [PATCH 11/12] Configuration Manager Redesign (#1272) 1. Removed `config.yml` file. Users can directly use `SET` command. 2. Moved `OPENAI_KEY` to `OPENAI_API_KEY` --------- Co-authored-by: hershd23 Co-authored-by: Andy Xu --- .circleci/config.yml | 2 +- apps/pandas_qa/pandas_qa.py | 4 +- apps/youtube_qa/youtube_qa.py | 4 +- docs/source/overview/concepts.rst | 2 +- .../reference/ai/custom-ai-function.rst | 3 - docs/source/usecases/question-answering.rst | 2 +- evadb/binder/statement_binder.py | 21 ++++- evadb/catalog/catalog_manager.py | 67 ++++++++++----- evadb/catalog/catalog_utils.py | 26 +++--- evadb/catalog/models/base_model.py | 36 -------- evadb/catalog/models/configuration_catalog.py | 43 ++++++++++ evadb/catalog/models/utils.py | 24 +++++- .../services/configuration_catalog_service.py | 76 +++++++++++++++++ .../services/function_cost_catalog_service.py | 2 +- evadb/catalog/sql_config.py | 34 ++------ evadb/configuration/bootstrap_environment.py | 82 ++++++------------- evadb/configuration/configuration_manager.py | 71 ---------------- evadb/configuration/constants.py | 2 +- evadb/database.py | 33 ++++---- evadb/evadb.yml | 30 ------- evadb/evadb_cmd_client.py | 10 +-- evadb/evadb_config.py | 39 +++++++++ evadb/executor/abstract_executor.py | 6 -- evadb/executor/create_function_executor.py | 12 ++- evadb/executor/create_index_executor.py | 6 +- evadb/executor/drop_object_executor.py | 4 +- evadb/executor/execution_context.py | 14 +--- evadb/executor/executor_utils.py | 10 ++- evadb/executor/load_multimedia_executor.py | 4 +- evadb/executor/set_executor.py | 6 +- evadb/executor/show_info_executor.py | 3 +- 
evadb/executor/vector_index_scan_executor.py | 4 +- .../abstract/pytorch_abstract_function.py | 4 +- evadb/functions/chatgpt.py | 11 ++- evadb/functions/dalle.py | 14 ++-- evadb/functions/stable_diffusion.py | 14 +--- evadb/optimizer/optimizer_context.py | 4 +- evadb/optimizer/plan_generator.py | 4 +- evadb/optimizer/rules/rules.py | 9 +- evadb/optimizer/rules/rules_manager.py | 7 +- evadb/parser/evadb.lark | 2 +- evadb/server/server.py | 2 +- evadb/third_party/vector_stores/pinecone.py | 15 ++-- evadb/third_party/vector_stores/utils.py | 1 + test/app_tests/test_pandas_qa.py | 4 +- test/app_tests/test_privategpt.py | 4 +- test/app_tests/test_youtube_channel_qa.py | 6 +- test/app_tests/test_youtube_qa.py | 6 +- .../long/test_create_index_executor.py | 2 +- .../long/test_error_handling_with_ray.py | 4 +- .../long/test_explain_executor.py | 4 +- .../long/test_load_executor.py | 4 +- .../long/test_optimizer_rules.py | 17 +--- test/integration_tests/long/test_pytorch.py | 8 +- test/integration_tests/long/test_reuse.py | 2 +- .../long/test_s3_load_executor.py | 4 +- .../integration_tests/long/test_similarity.py | 2 +- .../short/test_set_executor.py | 4 +- .../catalog/test_catalog_manager.py | 26 +++--- .../executor/test_execution_context.py | 34 ++------ .../optimizer/rules/test_batch_mem_size.py | 4 +- test/unit_tests/optimizer/rules/test_rules.py | 19 ++--- .../optimizer/test_optimizer_task.py | 10 +-- test/unit_tests/test_dalle.py | 2 +- test/unit_tests/test_eva_cmd_client.py | 28 +++---- test/util.py | 16 ++-- 66 files changed, 492 insertions(+), 487 deletions(-) create mode 100644 evadb/catalog/models/configuration_catalog.py create mode 100644 evadb/catalog/services/configuration_catalog_service.py delete mode 100644 evadb/configuration/configuration_manager.py delete mode 100644 evadb/evadb.yml create mode 100644 evadb/evadb_config.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 01984e154..cb7ad985d 100644 --- a/.circleci/config.yml +++ 
b/.circleci/config.yml @@ -234,7 +234,7 @@ jobs: else pip install ".[dev,pinecone,chromadb]" # ray < 2.5.0 does not work with python 3.11 ray-project/ray#33864 fi - python -c "import yaml;f = open('evadb/evadb.yml', 'r+');config_obj = yaml.load(f, Loader=yaml.FullLoader);config_obj['experimental']['ray'] = True;f.seek(0);f.write(yaml.dump(config_obj));f.truncate();" + python -c "import evadb;cur=evadb.connect().cursor();cur.query('SET ray=True';)" else if [ $PY_VERSION != "3.11" ]; then pip install ".[dev,ludwig,qdrant,pinecone,chromadb]" diff --git a/apps/pandas_qa/pandas_qa.py b/apps/pandas_qa/pandas_qa.py index 8e4b6868a..b81ce27d1 100644 --- a/apps/pandas_qa/pandas_qa.py +++ b/apps/pandas_qa/pandas_qa.py @@ -53,10 +53,10 @@ def receive_user_input() -> Dict: # get OpenAI key if needed try: - api_key = os.environ["OPENAI_KEY"] + api_key = os.environ["OPENAI_API_KEY"] except KeyError: api_key = str(input("🔑 Enter your OpenAI key: ")) - os.environ["OPENAI_KEY"] = api_key + os.environ["OPENAI_API_KEY"] = api_key return user_input diff --git a/apps/youtube_qa/youtube_qa.py b/apps/youtube_qa/youtube_qa.py index 5a56bbe29..ee4626473 100644 --- a/apps/youtube_qa/youtube_qa.py +++ b/apps/youtube_qa/youtube_qa.py @@ -93,10 +93,10 @@ def receive_user_input() -> Dict: # get OpenAI key if needed try: - api_key = os.environ["OPENAI_KEY"] + api_key = os.environ["OPENAI_API_KEY"] except KeyError: api_key = str(input("🔑 Enter your OpenAI key: ")) - os.environ["OPENAI_KEY"] = api_key + os.environ["OPENAI_API_KEY"] = api_key return user_input diff --git a/docs/source/overview/concepts.rst b/docs/source/overview/concepts.rst index c9dba3215..f2905b640 100644 --- a/docs/source/overview/concepts.rst +++ b/docs/source/overview/concepts.rst @@ -46,7 +46,7 @@ Here are some illustrative **AI queries** for a ChatGPT-based video question ans --- The 'transcripts' table has a column called 'text' with the transcript text --- Since ChatGPT is a built-in function in EvaDB, we don't have to 
define it --- We can directly use ChatGPT() in any query - --- We will only need to set the OPENAI_KEY as an environment variable + --- We will only need to set the OPENAI_API_KEY as an environment variable SELECT ChatGPT('Is this video summary related to Ukraine russia war', text) FROM TEXT_SUMMARY; diff --git a/docs/source/reference/ai/custom-ai-function.rst b/docs/source/reference/ai/custom-ai-function.rst index 68e1d6461..3db8457be 100644 --- a/docs/source/reference/ai/custom-ai-function.rst +++ b/docs/source/reference/ai/custom-ai-function.rst @@ -258,9 +258,6 @@ The following code can be used to create an Object Detection function using Yolo try_to_import_openai() import openai - #setting up the key - openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY") - #getting the data content = text_df[text_df.columns[0]] responses = [] diff --git a/docs/source/usecases/question-answering.rst b/docs/source/usecases/question-answering.rst index 7a1235da0..15f548d9c 100644 --- a/docs/source/usecases/question-answering.rst +++ b/docs/source/usecases/question-answering.rst @@ -57,7 +57,7 @@ EvaDB has built-in support for ``ChatGPT`` function from ``OpenAI``. You will ne # Set OpenAI key import os - os.environ["OPENAI_KEY"] = "sk-..." + os.environ["OPENAI_API_KEY"] = "sk-..." .. 
note:: diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py index f1e949941..128e6e7ee 100644 --- a/evadb/binder/statement_binder.py +++ b/evadb/binder/statement_binder.py @@ -34,6 +34,7 @@ from evadb.catalog.catalog_utils import get_metadata_properties, is_document_table from evadb.catalog.sql_config import RESTRICTED_COL_NAMES from evadb.configuration.constants import EvaDB_INSTALLATION_DIR +from evadb.executor.execution_context import Context from evadb.expression.abstract_expression import AbstractExpression, ExpressionType from evadb.expression.function_expression import FunctionExpression from evadb.expression.tuple_value_expression import TupleValueExpression @@ -273,6 +274,11 @@ def _bind_tuple_expr(self, node: TupleValueExpression): @bind.register(FunctionExpression) def _bind_func_expr(self, node: FunctionExpression): + # setup the context + # we read the GPUs from the catalog and populate in the context + gpus_ids = self._catalog().get_configuration_catalog_value("gpu_ids") + node._context = Context(gpus_ids) + # handle the special case of "extract_object" if node.name.upper() == str(FunctionType.EXTRACT_OBJECT): handle_bind_extract_object_function(node, self) @@ -340,9 +346,18 @@ def _bind_func_expr(self, node: FunctionExpression): ) # certain functions take additional inputs like yolo needs the model_name # these arguments are passed by the user as part of metadata - node.function = lambda: function_class( - **get_metadata_properties(function_obj) - ) + # we also handle the special case of ChatGPT where we need to send the + # OpenAPI key as part of the parameter if not provided by the user + properties = get_metadata_properties(function_obj) + if string_comparison_case_insensitive(node.name, "CHATGPT"): + # if the user didn't provide any API_KEY, check if we have one in the catalog + if "OPENAI_API_KEY" not in properties.keys(): + openapi_key = self._catalog().get_configuration_catalog_value( + "OPENAI_API_KEY" + ) + 
properties["openai_api_key"] = openapi_key + + node.function = lambda: function_class(**properties) except Exception as e: err_msg = ( f"{str(e)}. Please verify that the function class name in the " diff --git a/evadb/catalog/catalog_manager.py b/evadb/catalog/catalog_manager.py index 7f63be108..20c50c9df 100644 --- a/evadb/catalog/catalog_manager.py +++ b/evadb/catalog/catalog_manager.py @@ -14,7 +14,7 @@ # limitations under the License. import shutil from pathlib import Path -from typing import List +from typing import Any, List from evadb.catalog.catalog_type import ( ColumnType, @@ -23,7 +23,6 @@ VideoColumnName, ) from evadb.catalog.catalog_utils import ( - cleanup_storage, construct_function_cache_catalog_entry, get_document_table_column_definitions, get_image_table_column_definitions, @@ -46,6 +45,9 @@ truncate_catalog_tables, ) from evadb.catalog.services.column_catalog_service import ColumnCatalogService +from evadb.catalog.services.configuration_catalog_service import ( + ConfigurationCatalogService, +) from evadb.catalog.services.database_catalog_service import DatabaseCatalogService from evadb.catalog.services.function_cache_catalog_service import ( FunctionCacheCatalogService, @@ -61,23 +63,28 @@ from evadb.catalog.services.index_catalog_service import IndexCatalogService from evadb.catalog.services.table_catalog_service import TableCatalogService from evadb.catalog.sql_config import IDENTIFIER_COLUMN, SQLConfig -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.expression.function_expression import FunctionExpression from evadb.parser.create_statement import ColumnDefinition from evadb.parser.table_ref import TableInfo from evadb.parser.types import FileFormatType from evadb.third_party.databases.interface import get_database_handler -from evadb.utils.generic_utils import generate_file_path, get_file_checksum +from evadb.utils.generic_utils import ( + generate_file_path, + get_file_checksum, + 
remove_directory_contents, +) from evadb.utils.logging_manager import logger class CatalogManager(object): - def __init__(self, db_uri: str, config: ConfigurationManager): + def __init__(self, db_uri: str): self._db_uri = db_uri self._sql_config = SQLConfig(db_uri) - self._config = config self._bootstrap_catalog() self._db_catalog_service = DatabaseCatalogService(self._sql_config.session) + self._config_catalog_service = ConfigurationCatalogService( + self._sql_config.session + ) self._table_catalog_service = TableCatalogService(self._sql_config.session) self._column_service = ColumnCatalogService(self._sql_config.session) self._function_service = FunctionCatalogService(self._sql_config.session) @@ -130,10 +137,14 @@ def _clear_catalog_contents(self): logger.info("Clearing catalog") # drop tables which are not part of catalog drop_all_tables_except_catalog(self._sql_config.engine) - # truncate the catalog tables - truncate_catalog_tables(self._sql_config.engine) + # truncate the catalog tables except configuration_catalog + # We do not remove the configuration entries + truncate_catalog_tables( + self._sql_config.engine, tables_not_to_truncate=["configuration_catalog"] + ) # clean up the dataset, index, and cache directories - cleanup_storage(self._config) + for folder in ["cache_dir", "index_dir", "datasets_dir"]: + remove_directory_contents(self.get_configuration_catalog_value(folder)) "Database catalog services" @@ -447,7 +458,7 @@ def get_all_index_catalog_entries(self): """ Function Cache related""" def insert_function_cache_catalog_entry(self, func_expr: FunctionExpression): - cache_dir = self._config.get_value("storage", "cache_dir") + cache_dir = self.get_configuration_catalog_value("cache_dir") entry = construct_function_cache_catalog_entry(func_expr, cache_dir=cache_dir) return self._function_cache_service.insert_entry(entry) @@ -510,7 +521,7 @@ def create_and_insert_table_catalog_entry( table_name = table_info.table_name column_catalog_entries = 
xform_column_definitions_to_catalog_entries(columns) - dataset_location = self._config.get_value("core", "datasets_dir") + dataset_location = self.get_configuration_catalog_value("datasets_dir") file_url = str(generate_file_path(dataset_location, table_name)) table_catalog_entry = self.insert_table_catalog_entry( table_name, @@ -610,14 +621,28 @@ def create_and_insert_multimedia_metadata_table_catalog_entry( ) return obj + "Configuration catalog services" + + def upsert_configuration_catalog_entry(self, key: str, value: any): + """Upserts configuration catalog entry" + + Args: + key: key name + value: value name + """ + self._config_catalog_service.upsert_entry(key, value) + + def get_configuration_catalog_value(self, key: str, default: Any = None) -> Any: + """ + Returns the value entry for the given key + Arguments: + key (str): key name + + Returns: + ConfigurationCatalogEntry + """ -#### get catalog instance -# This function plays a crucial role in ensuring that different threads do -# not share the same catalog object, as it can result in serialization issues and -# incorrect behavior with SQLAlchemy. Therefore, whenever a catalog instance is -# required, we create a new one. One possible optimization is to share the catalog -# instance across all objects within the same thread. It is worth investigating whether -# SQLAlchemy already handles this optimization for us, which will be explored at a -# later time. 
-def get_catalog_instance(db_uri: str, config: ConfigurationManager): - return CatalogManager(db_uri, config) + table_entry = self._config_catalog_service.get_entry_by_name(key) + if table_entry: + return table_entry.value + return default diff --git a/evadb/catalog/catalog_utils.py b/evadb/catalog/catalog_utils.py index d2187978d..35fb7f6a8 100644 --- a/evadb/catalog/catalog_utils.py +++ b/evadb/catalog/catalog_utils.py @@ -32,11 +32,10 @@ TableCatalogEntry, ) from evadb.catalog.sql_config import IDENTIFIER_COLUMN -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.expression.function_expression import FunctionExpression from evadb.expression.tuple_value_expression import TupleValueExpression from evadb.parser.create_statement import ColConstraintInfo, ColumnDefinition -from evadb.utils.generic_utils import get_str_hash, remove_directory_contents +from evadb.utils.generic_utils import get_str_hash def is_video_table(table: TableCatalogEntry): @@ -256,12 +255,6 @@ def construct_function_cache_catalog_entry( return entry -def cleanup_storage(config): - remove_directory_contents(config.get_value("storage", "index_dir")) - remove_directory_contents(config.get_value("storage", "cache_dir")) - remove_directory_contents(config.get_value("core", "datasets_dir")) - - def get_metadata_entry_or_val( function_obj: FunctionCatalogEntry, key: str, default_val: Any = None ) -> str: @@ -300,6 +293,19 @@ def get_metadata_properties(function_obj: FunctionCatalogEntry) -> Dict: return properties +def bootstrap_configs(catalog, configs: dict): + """ + load all the configuration values into the catalog table configuration_catalog + """ + for key, value in configs.items(): + catalog.upsert_configuration_catalog_entry(key, value) + + +def get_configuration_value(key: str): + catalog = get_catalog_instance() + return catalog.get_configuration_catalog_value(key) + + #### get catalog instance # This function plays a crucial role in ensuring that different 
threads do # not share the same catalog object, as it can result in serialization issues and @@ -308,7 +314,7 @@ def get_metadata_properties(function_obj: FunctionCatalogEntry) -> Dict: # instance across all objects within the same thread. It is worth investigating whether # SQLAlchemy already handles this optimization for us, which will be explored at a # later time. -def get_catalog_instance(db_uri: str, config: ConfigurationManager): +def get_catalog_instance(db_uri: str): from evadb.catalog.catalog_manager import CatalogManager - return CatalogManager(db_uri, config) + return CatalogManager(db_uri) diff --git a/evadb/catalog/models/base_model.py b/evadb/catalog/models/base_model.py index 8e915af24..f31fc10b4 100644 --- a/evadb/catalog/models/base_model.py +++ b/evadb/catalog/models/base_model.py @@ -12,16 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib - -import sqlalchemy from sqlalchemy import Column, Integer -from sqlalchemy.engine import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_utils import database_exists -from evadb.catalog.sql_config import CATALOG_TABLES from evadb.utils.logging_manager import logger @@ -100,33 +94,3 @@ def _commit(self, db_session): # Custom Base Model to be inherited by all models BaseModel = declarative_base(cls=CustomModel, constructor=None) - - -def truncate_catalog_tables(engine: Engine): - """Truncate all the catalog tables""" - # https://stackoverflow.com/questions/4763472/sqlalchemy-clear-database-content-but-dont-drop-the-schema/5003705#5003705 #noqa - # reflect to refresh the metadata - BaseModel.metadata.reflect(bind=engine) - insp = sqlalchemy.inspect(engine) - if database_exists(engine.url): - with contextlib.closing(engine.connect()) as con: - trans = con.begin() - for table in reversed(BaseModel.metadata.sorted_tables): - if insp.has_table(table.name): - con.execute(table.delete()) - trans.commit() - - -def drop_all_tables_except_catalog(engine: Engine): - """drop all the tables except the catalog""" - # reflect to refresh the metadata - BaseModel.metadata.reflect(bind=engine) - insp = sqlalchemy.inspect(engine) - if database_exists(engine.url): - with contextlib.closing(engine.connect()) as con: - trans = con.begin() - for table in reversed(BaseModel.metadata.sorted_tables): - if table.name not in CATALOG_TABLES: - if insp.has_table(table.name): - table.drop(con) - trans.commit() diff --git a/evadb/catalog/models/configuration_catalog.py b/evadb/catalog/models/configuration_catalog.py new file mode 100644 index 000000000..157b07eaf --- /dev/null +++ b/evadb/catalog/models/configuration_catalog.py @@ -0,0 +1,43 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import Column, String + +from evadb.catalog.models.base_model import BaseModel +from evadb.catalog.models.utils import ConfigurationCatalogEntry, TextPickleType + + +class ConfigurationCatalog(BaseModel): + """The `ConfigurationCatalog` catalog stores all the configuration params. + `_row_id:` an autogenerated unique identifier. + `_key:` the key for the config. + `_value:` the value for the config + """ + + __tablename__ = "configuration_catalog" + + _key = Column("key", String(100), unique=True) + _value = Column("value", TextPickleType()) + + def __init__(self, key: str, value: any): + self._key = key + self._value = value + + def as_dataclass(self) -> "ConfigurationCatalogEntry": + return ConfigurationCatalogEntry( + row_id=self._row_id, + key=self._key, + value=self._value, + ) diff --git a/evadb/catalog/models/utils.py b/evadb/catalog/models/utils.py index b1c067aa0..5da3a2eef 100644 --- a/evadb/catalog/models/utils.py +++ b/evadb/catalog/models/utils.py @@ -61,7 +61,7 @@ def init_db(engine: Engine): BaseModel.metadata.create_all(bind=engine) -def truncate_catalog_tables(engine: Engine): +def truncate_catalog_tables(engine: Engine, tables_not_to_truncate: List[str] = []): """Truncate all the catalog tables""" # https://stackoverflow.com/questions/4763472/sqlalchemy-clear-database-content-but-dont-drop-the-schema/5003705#5003705 #noqa # reflect to refresh the metadata @@ -71,8 +71,9 @@ def truncate_catalog_tables(engine: Engine): with contextlib.closing(engine.connect()) as con: trans = con.begin() for 
table in reversed(BaseModel.metadata.sorted_tables): - if insp.has_table(table.name): - con.execute(table.delete()) + if table.name not in tables_not_to_truncate: + if insp.has_table(table.name): + con.execute(table.delete()) trans.commit() @@ -257,3 +258,20 @@ def display_format(self): "engine": self.engine, "params": self.params, } + + +@dataclass(unsafe_hash=True) +class ConfigurationCatalogEntry: + """Dataclass representing an entry in the `ConfigurationCatalog`. + This is done to ensure we don't expose the sqlalchemy dependencies beyond catalog service. Further, sqlalchemy does not allow sharing of objects across threads. + """ + + key: str + value: str + row_id: int = None + + def display_format(self): + return { + "key": self.key, + "value": self.value, + } diff --git a/evadb/catalog/services/configuration_catalog_service.py b/evadb/catalog/services/configuration_catalog_service.py new file mode 100644 index 000000000..bda22253c --- /dev/null +++ b/evadb/catalog/services/configuration_catalog_service.py @@ -0,0 +1,76 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import select + +from evadb.catalog.models.configuration_catalog import ConfigurationCatalog +from evadb.catalog.models.utils import ConfigurationCatalogEntry +from evadb.catalog.services.base_service import BaseService +from evadb.utils.errors import CatalogError +from evadb.utils.logging_manager import logger + + +class ConfigurationCatalogService(BaseService): + def __init__(self, db_session: Session): + super().__init__(ConfigurationCatalog, db_session) + + def insert_entry( + self, + key: str, + value: any, + ): + try: + config_catalog_obj = self.model(key=key, value=value) + config_catalog_obj = config_catalog_obj.save(self.session) + + except Exception as e: + logger.exception( + f"Failed to insert entry into database catalog with exception {str(e)}" + ) + raise CatalogError(e) + + def get_entry_by_name(self, key: str) -> ConfigurationCatalogEntry: + """ + Get the table catalog entry with given table name. + Arguments: + key (str): key name + Returns: + Configuration Catalog Entry - catalog entry for given key name + """ + entry = self.session.execute( + select(self.model).filter(self.model._key == key) + ).scalar_one_or_none() + if entry: + return entry.as_dataclass() + return entry + + def upsert_entry( + self, + key: str, + value: any, + ): + try: + entry = self.session.execute( + select(self.model).filter(self.model._key == key) + ).scalar_one_or_none() + if entry: + entry.update(self.session, _value=value) + else: + self.insert_entry(key, value) + except Exception as e: + raise CatalogError( + f"Error while upserting entry to ConfigurationCatalog: {str(e)}" + ) diff --git a/evadb/catalog/services/function_cost_catalog_service.py b/evadb/catalog/services/function_cost_catalog_service.py index ac84e8c94..5e20ac303 100644 --- a/evadb/catalog/services/function_cost_catalog_service.py +++ b/evadb/catalog/services/function_cost_catalog_service.py @@ -62,7 +62,7 @@ def upsert_entry(self, 
function_id: int, name: str, new_cost: int): select(self.model).filter(self.model._function_id == function_id) ).scalar_one_or_none() if function_obj: - function_obj.update(self.session, cost=new_cost) + function_obj.update(self.session, _cost=new_cost) else: self.insert_entry(function_id, name, new_cost) except Exception as e: diff --git a/evadb/catalog/sql_config.py b/evadb/catalog/sql_config.py index 778d2cb24..0a460a899 100644 --- a/evadb/catalog/sql_config.py +++ b/evadb/catalog/sql_config.py @@ -17,9 +17,6 @@ from sqlalchemy import create_engine, event from sqlalchemy.orm import scoped_session, sessionmaker -from sqlalchemy.pool import NullPool - -from evadb.utils.generic_utils import is_postgres_uri, parse_config_yml # Permanent identifier column. IDENTIFIER_COLUMN = "_row_id" @@ -32,6 +29,7 @@ "column_catalog", "table_catalog", "database_catalog", + "configuration_catalog", "depend_column_and_function_cache", "function_cache", "function_catalog", @@ -65,33 +63,17 @@ def __call__(cls, uri): class SQLConfig(metaclass=SingletonMeta): def __init__(self, uri): - """Initializes the engine and session for database operations - - Retrieves the database uri for connection from ConfigurationManager. - """ + """Initializes the engine and session for database operations""" self.worker_uri = str(uri) # set echo=True to log SQL - connect_args = {} - config_obj = parse_config_yml() - if is_postgres_uri(config_obj["core"]["catalog_database_uri"]): - # Set the arguments for postgres backend. - connect_args = {"connect_timeout": 1000} - # https://www.oddbird.net/2014/06/14/sqlalchemy-postgres-autocommit/ - self.engine = create_engine( - self.worker_uri, - poolclass=NullPool, - isolation_level="AUTOCOMMIT", - connect_args=connect_args, - ) - else: - # Default to SQLite. - connect_args = {"timeout": 1000} - self.engine = create_engine( - self.worker_uri, - connect_args=connect_args, - ) + # Default to SQLite. 
+ connect_args = {"timeout": 1000} + self.engine = create_engine( + self.worker_uri, + connect_args=connect_args, + ) if self.engine.url.get_backend_name() == "sqlite": # enforce foreign key constraint and wal logging for sqlite diff --git a/evadb/configuration/bootstrap_environment.py b/evadb/configuration/bootstrap_environment.py index 55e2f0f22..bb108692a 100644 --- a/evadb/configuration/bootstrap_environment.py +++ b/evadb/configuration/bootstrap_environment.py @@ -12,13 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib.resources as importlib_resources import logging from pathlib import Path from typing import Union -import yaml - from evadb.configuration.constants import ( CACHE_DIR, DB_DEFAULT_NAME, @@ -27,26 +24,12 @@ MODEL_DIR, S3_DOWNLOAD_DIR, TMP_DIR, - EvaDB_CONFIG_FILE, EvaDB_DATASET_DIR, ) -from evadb.utils.generic_utils import parse_config_yml +from evadb.evadb_config import BASE_EVADB_CONFIG from evadb.utils.logging_manager import logger as evadb_logger -def get_base_config(evadb_installation_dir: Path) -> Path: - """ - Get path to .evadb.yml source path. - """ - # if evadb package is installed in environment - if importlib_resources.is_resource("evadb", EvaDB_CONFIG_FILE): - with importlib_resources.path("evadb", EvaDB_CONFIG_FILE) as yml_path: - return yml_path - else: - # For local dev environments without package installed - return evadb_installation_dir / EvaDB_CONFIG_FILE - - def get_default_db_uri(evadb_dir: Path): """ Get the default database uri. @@ -54,12 +37,8 @@ def get_default_db_uri(evadb_dir: Path): Arguments: evadb_dir: path to evadb database directory """ - config_obj = parse_config_yml() - if config_obj["core"]["catalog_database_uri"]: - return config_obj["core"]["catalog_database_uri"] - else: - # Default to sqlite. 
- return f"sqlite:///{evadb_dir.resolve()}/{DB_DEFAULT_NAME}" + # Default to sqlite. + return f"sqlite:///{evadb_dir.resolve()}/{DB_DEFAULT_NAME}" def bootstrap_environment(evadb_dir: Path, evadb_installation_dir: Path): @@ -71,7 +50,7 @@ def bootstrap_environment(evadb_dir: Path, evadb_installation_dir: Path): evadb_installation_dir: path to evadb package """ - default_config_path = get_base_config(evadb_installation_dir).resolve() + config_obj = BASE_EVADB_CONFIG # creates necessary directories config_default_dict = create_directories_and_get_default_config_values( @@ -80,22 +59,22 @@ def bootstrap_environment(evadb_dir: Path, evadb_installation_dir: Path): assert evadb_dir.exists(), f"{evadb_dir} does not exist" assert evadb_installation_dir.exists(), f"{evadb_installation_dir} does not exist" - config_obj = {} - with default_config_path.open("r") as yml_file: - config_obj = yaml.load(yml_file, Loader=yaml.FullLoader) config_obj = merge_dict_of_dicts(config_default_dict, config_obj) - mode = config_obj["core"]["mode"] + mode = config_obj["mode"] # set logger to appropriate level (debug or release) level = logging.WARN if mode == "release" else logging.DEBUG evadb_logger.setLevel(level) evadb_logger.debug(f"Setting logging level to: {str(level)}") + # Mainly want to add all the configs to sqlite + return config_obj +# TODO : Change def create_directories_and_get_default_config_values( - evadb_dir: Path, evadb_installation_dir: Path, category: str = None, key: str = None + evadb_dir: Path, evadb_installation_dir: Path ) -> Union[dict, str]: default_install_dir = evadb_installation_dir dataset_location = evadb_dir / EvaDB_DATASET_DIR @@ -124,21 +103,15 @@ def create_directories_and_get_default_config_values( model_dir.mkdir(parents=True, exist_ok=True) config_obj = {} - config_obj["core"] = {} - config_obj["storage"] = {} - config_obj["core"]["evadb_installation_dir"] = str(default_install_dir.resolve()) - config_obj["core"]["datasets_dir"] = 
str(dataset_location.resolve()) - config_obj["core"]["catalog_database_uri"] = get_default_db_uri(evadb_dir) - config_obj["storage"]["index_dir"] = str(index_dir.resolve()) - config_obj["storage"]["cache_dir"] = str(cache_dir.resolve()) - config_obj["storage"]["s3_download_dir"] = str(s3_dir.resolve()) - config_obj["storage"]["tmp_dir"] = str(tmp_dir.resolve()) - config_obj["storage"]["function_dir"] = str(function_dir.resolve()) - config_obj["storage"]["model_dir"] = str(model_dir.resolve()) - if category and key: - return config_obj.get(category, {}).get(key, None) - elif category: - return config_obj.get(category, {}) + config_obj["evadb_installation_dir"] = str(default_install_dir.resolve()) + config_obj["datasets_dir"] = str(dataset_location.resolve()) + config_obj["catalog_database_uri"] = get_default_db_uri(evadb_dir) + config_obj["index_dir"] = str(index_dir.resolve()) + config_obj["cache_dir"] = str(cache_dir.resolve()) + config_obj["s3_download_dir"] = str(s3_dir.resolve()) + config_obj["tmp_dir"] = str(tmp_dir.resolve()) + config_obj["function_dir"] = str(function_dir.resolve()) + config_obj["model_dir"] = str(model_dir.resolve()) return config_obj @@ -147,15 +120,14 @@ def merge_dict_of_dicts(dict1, dict2): merged_dict = dict1.copy() for key, value in dict2.items(): - # Overwrite only if some value is specified. - if value: - if ( - key in merged_dict - and isinstance(merged_dict[key], dict) - and isinstance(value, dict) - ): - merged_dict[key] = merge_dict_of_dicts(merged_dict[key], value) - else: - merged_dict[key] = value + if key in merged_dict.keys(): + # Overwrite only if some value is specified. 
+ if value is not None: + if isinstance(merged_dict[key], dict) and isinstance(value, dict): + merged_dict[key] = merge_dict_of_dicts(merged_dict[key], value) + else: + merged_dict[key] = value + else: + merged_dict[key] = value return merged_dict diff --git a/evadb/configuration/configuration_manager.py b/evadb/configuration/configuration_manager.py deleted file mode 100644 index 8d209e2fc..000000000 --- a/evadb/configuration/configuration_manager.py +++ /dev/null @@ -1,71 +0,0 @@ -# coding=utf-8 -# Copyright 2018-2023 EvaDB -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pathlib import Path -from typing import Any - -from evadb.configuration.bootstrap_environment import bootstrap_environment -from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_INSTALLATION_DIR -from evadb.utils.logging_manager import logger - - -class ConfigurationManager(object): - def __init__(self, evadb_dir: str = None) -> None: - self._evadb_dir = evadb_dir or EvaDB_DATABASE_DIR - self._config_obj = self._create_if_not_exists() - - def _create_if_not_exists(self): - config_obj = bootstrap_environment( - evadb_dir=Path(self._evadb_dir), - evadb_installation_dir=Path(EvaDB_INSTALLATION_DIR), - ) - return config_obj - - def _get(self, category: str, key: str) -> Any: - """Retrieve a configuration value based on the category and key. - - Args: - category (str): The category of the configuration. - key (str): The key of the configuration within the category. 
- - Returns: - Any: The retrieved configuration value. - - Raises: - ValueError: If the YAML file is invalid or cannot be loaded. - """ - config_obj = self._config_obj - - # Get value from the user-provided config file - value = config_obj.get(category, {}).get(key) - - # cannot find the value, report invalid category, key combination - if value is None: - logger.exception(f"Invalid category and key combination {category}:{key}") - - return value - - def _update(self, category: str, key: str, value: str): - config_obj = self._config_obj - - if category not in config_obj: - config_obj[category] = {} - - config_obj[category][key] = value - - def get_value(self, category: str, key: str) -> Any: - return self._get(category, key) - - def update_value(self, category, key, value) -> None: - self._update(category, key, value) diff --git a/evadb/configuration/constants.py b/evadb/configuration/constants.py index 395f898be..18a1331f8 100644 --- a/evadb/configuration/constants.py +++ b/evadb/configuration/constants.py @@ -21,7 +21,7 @@ EvaDB_DATABASE_DIR = "evadb_data" EvaDB_APPS_DIR = "apps" EvaDB_DATASET_DIR = "evadb_datasets" -EvaDB_CONFIG_FILE = "evadb.yml" +EvaDB_CONFIG_FILE = "evadb_config.py" FUNCTION_DIR = "functions" MODEL_DIR = "models" CATALOG_DIR = "catalog" diff --git a/evadb/database.py b/evadb/database.py index ed6ad2f11..9c22d5b9f 100644 --- a/evadb/database.py +++ b/evadb/database.py @@ -16,10 +16,13 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable -from evadb.catalog.catalog_utils import get_catalog_instance -from evadb.configuration.configuration_manager import ConfigurationManager -from evadb.configuration.constants import DB_DEFAULT_NAME, EvaDB_DATABASE_DIR -from evadb.utils.generic_utils import parse_config_yml +from evadb.catalog.catalog_utils import bootstrap_configs, get_catalog_instance +from evadb.configuration.bootstrap_environment import bootstrap_environment +from evadb.configuration.constants import ( + DB_DEFAULT_NAME, + 
EvaDB_DATABASE_DIR, + EvaDB_INSTALLATION_DIR, +) if TYPE_CHECKING: from evadb.catalog.catalog_manager import CatalogManager @@ -28,7 +31,6 @@ @dataclass class EvaDBDatabase: db_uri: str - config: ConfigurationManager catalog_uri: str catalog_func: Callable @@ -36,16 +38,12 @@ def catalog(self) -> "CatalogManager": """ Note: Generating an object on demand plays a crucial role in ensuring that different threads do not share the same catalog object, as it can result in serialization issues and incorrect behavior with SQLAlchemy. Refer to get_catalog_instance() """ - return self.catalog_func(self.catalog_uri, self.config) + return self.catalog_func(self.catalog_uri) def get_default_db_uri(evadb_dir: Path): - config_obj = parse_config_yml() - if config_obj["core"]["catalog_database_uri"]: - return config_obj["core"]["catalog_database_uri"] - else: - # Default to sqlite. - return f"sqlite:///{evadb_dir.resolve()}/{DB_DEFAULT_NAME}" + # Default to sqlite. + return f"sqlite:///{evadb_dir.resolve()}/{DB_DEFAULT_NAME}" def init_evadb_instance( @@ -53,8 +51,15 @@ def init_evadb_instance( ): if db_dir is None: db_dir = EvaDB_DATABASE_DIR - config = ConfigurationManager(db_dir) + + config_obj = bootstrap_environment( + Path(db_dir), + evadb_installation_dir=Path(EvaDB_INSTALLATION_DIR), + ) catalog_uri = custom_db_uri or get_default_db_uri(Path(db_dir)) - return EvaDBDatabase(db_dir, config, catalog_uri, get_catalog_instance) + # load all the config into the configuration_catalog table + bootstrap_configs(get_catalog_instance(catalog_uri), config_obj) + + return EvaDBDatabase(db_dir, catalog_uri, get_catalog_instance) diff --git a/evadb/evadb.yml b/evadb/evadb.yml deleted file mode 100644 index 04ce32549..000000000 --- a/evadb/evadb.yml +++ /dev/null @@ -1,30 +0,0 @@ -core: - evadb_installation_dir: "" - datasets_dir: "" - catalog_database_uri: "" - application: "evadb" - mode: "release" #release or debug - -executor: - # batch_mem_size configures the number of rows processed 
by the execution engine in one iteration - # rows = max(1, row_mem_size / batch_mem_size) - batch_mem_size: 30000000 - - # batch size used for gpu_operations - gpu_batch_size: 1 - - gpu_ids: [0] - -server: - host: "0.0.0.0" - port: 8803 - socket_timeout: 60 - -experimental: - ray: False - -third_party: - OPENAI_KEY: "" - PINECONE_API_KEY: "" - PINECONE_ENV: "" - REPLICATE_API_TOKEN: "" \ No newline at end of file diff --git a/evadb/evadb_cmd_client.py b/evadb/evadb_cmd_client.py index 3286b68fa..2f15f075d 100644 --- a/evadb/evadb_cmd_client.py +++ b/evadb/evadb_cmd_client.py @@ -26,7 +26,7 @@ EvaDB_CODE_DIR = abspath(join(THIS_DIR, "..")) sys.path.append(EvaDB_CODE_DIR) -from evadb.configuration.configuration_manager import ConfigurationManager # noqa: E402 +from evadb.evadb_config import BASE_EVADB_CONFIG # noqa: E402 from evadb.server.interpreter import start_cmd_client # noqa: E402 @@ -61,13 +61,9 @@ def main(): # PARSE ARGS args, unknown = parser.parse_known_args() - host = ( - args.host if args.host else ConfigurationManager().get_value("server", "host") - ) + host = args.host if args.host else BASE_EVADB_CONFIG["host"] - port = ( - args.port if args.port else ConfigurationManager().get_value("server", "port") - ) + port = args.port if args.port else BASE_EVADB_CONFIG["port"] asyncio.run(evadb_client(host, port)) diff --git a/evadb/evadb_config.py b/evadb/evadb_config.py new file mode 100644 index 000000000..800112a60 --- /dev/null +++ b/evadb/evadb_config.py @@ -0,0 +1,39 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +EvaDB configuration dict + +batch_mem_size configures the number of rows processed by the execution engine in one iteration +rows = max(1, row_mem_size / batch_mem_size) +""" + +BASE_EVADB_CONFIG = { + "evadb_installation_dir": "", + "datasets_dir": "", + "catalog_database_uri": "", + "application": "evadb", + "mode": "release", + "batch_mem_size": 30000000, + "gpu_batch_size": 1, # batch size used for gpu_operations + "gpu_ids": [0], + "host": "0.0.0.0", + "port": 8803, + "socket_timeout": 60, + "ray": False, + "OPENAI_API_KEY": "", + "PINECONE_API_KEY": "", + "PINECONE_ENV": "", +} diff --git a/evadb/executor/abstract_executor.py b/evadb/executor/abstract_executor.py index 66a453f71..864790146 100644 --- a/evadb/executor/abstract_executor.py +++ b/evadb/executor/abstract_executor.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: from evadb.catalog.catalog_manager import CatalogManager -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.database import EvaDBDatabase from evadb.models.storage.batch import Batch from evadb.plan_nodes.abstract_plan import AbstractPlan @@ -36,7 +35,6 @@ class AbstractExecutor(ABC): def __init__(self, db: EvaDBDatabase, node: AbstractPlan): self._db = db self._node = node - self._config: ConfigurationManager = db.config if db else None self._children = [] # @lru_cache(maxsize=None) @@ -74,10 +72,6 @@ def node(self) -> AbstractPlan: def db(self) -> EvaDBDatabase: return self._db - @property - def config(self) -> ConfigurationManager: - return self._config - @abstractmethod def exec(self, *args, **kwargs) -> Iterable[Batch]: """ diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 367110b1d..d045205a6 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -129,10 +129,13 @@ def 
handle_ludwig_function(self): target=arg_map["predict"], tune_for_memory=arg_map.get("tune_for_memory", False), time_limit_s=arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT), - output_directory=self.db.config.get_value("storage", "tmp_dir"), + output_directory=self.db.catalog().get_configuration_catalog_value( + "tmp_dir" + ), ) model_path = os.path.join( - self.db.config.get_value("storage", "model_dir"), self.node.name + self.db.catalog().get_configuration_catalog_value("model_dir"), + self.node.name, ) auto_train_results.best_model.save(model_path) self.node.metadata.append( @@ -176,7 +179,8 @@ def handle_sklearn_function(self): aggregated_batch.frames.drop([arg_map["predict"]], axis=1, inplace=True) model.fit(X=aggregated_batch.frames, y=Y) model_path = os.path.join( - self.db.config.get_value("storage", "model_dir"), self.node.name + self.db.catalog().get_configuration_catalog_value("model_dir"), + self.node.name, ) pickle.dump(model, open(model_path, "wb")) self.node.metadata.append( @@ -475,7 +479,7 @@ def get_optuna_config(trial): model_save_dir_name += "_exogenous_" + str(sorted(exogenous_columns)) model_dir = os.path.join( - self.db.config.get_value("storage", "model_dir"), + self.db.catalog().get_configuration_catalog_value("model_dir"), "tsforecasting", model_save_dir_name, str(hashlib.sha256(data.to_string().encode()).hexdigest()), diff --git a/evadb/executor/create_index_executor.py b/evadb/executor/create_index_executor.py index 407cfef3c..4b5c16474 100644 --- a/evadb/executor/create_index_executor.py +++ b/evadb/executor/create_index_executor.py @@ -75,7 +75,7 @@ def _create_native_index(self): # On-disk saving path for EvaDB index. 
def _get_evadb_index_save_path(self) -> Path: - index_dir = Path(self.config.get_value("storage", "index_dir")) + index_dir = Path(self.db.catalog().get_configuration_catalog_value("index_dir")) if not index_dir.exists(): index_dir.mkdir(parents=True, exist_ok=True) return str( @@ -121,7 +121,7 @@ def _create_evadb_index(self): self.vector_store_type, self.name, **handle_vector_store_params( - self.vector_store_type, index_path + self.vector_store_type, index_path, self.catalog ), ) else: @@ -151,7 +151,7 @@ def _create_evadb_index(self): self.vector_store_type, self.name, **handle_vector_store_params( - self.vector_store_type, index_path + self.vector_store_type, index_path, self.catalog ), ) index.create(input_dim) diff --git a/evadb/executor/drop_object_executor.py b/evadb/executor/drop_object_executor.py index 38d5419dc..a857f15ea 100644 --- a/evadb/executor/drop_object_executor.py +++ b/evadb/executor/drop_object_executor.py @@ -118,7 +118,9 @@ def _handle_drop_index(self, index_name: str, if_exists: bool): index = VectorStoreFactory.init_vector_store( index_obj.type, index_obj.name, - **handle_vector_store_params(index_obj.type, index_obj.save_file_path), + **handle_vector_store_params( + index_obj.type, index_obj.save_file_path, self.catalog + ), ) assert ( index is not None diff --git a/evadb/executor/execution_context.py b/evadb/executor/execution_context.py index d5feea464..b0b5e1a69 100644 --- a/evadb/executor/execution_context.py +++ b/evadb/executor/execution_context.py @@ -16,7 +16,6 @@ import random from typing import List -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.constants import NO_GPU from evadb.utils.generic_utils import get_gpu_count, is_gpu_available @@ -28,13 +27,8 @@ class Context: if using horovod: current rank etc. 
""" - def __new__(cls): - if not hasattr(cls, "_instance"): - cls._instance = super(Context, cls).__new__(cls) - return cls._instance - - def __init__(self): - self._config_manager = ConfigurationManager() + def __init__(self, user_provided_gpu_conf=[]): + self._user_provided_gpu_conf = user_provided_gpu_conf self._gpus = self._populate_gpu_ids() @property @@ -42,10 +36,8 @@ def gpus(self): return self._gpus def _populate_gpu_from_config(self) -> List: - # Populate GPU IDs from yaml config file. - gpu_conf = self._config_manager.get_value("executor", "gpu_ids") available_gpus = [i for i in range(get_gpu_count())] - return list(set(available_gpus) & set(gpu_conf)) + return list(set(available_gpus) & set(self._user_provided_gpu_conf)) def _populate_gpu_from_env(self) -> List: # Populate GPU IDs from env variable. diff --git a/evadb/executor/executor_utils.py b/evadb/executor/executor_utils.py index 23987ea15..9c2c3d5db 100644 --- a/evadb/executor/executor_utils.py +++ b/evadb/executor/executor_utils.py @@ -156,7 +156,7 @@ def validate_media(file_path: Path, media_type: FileFormatType) -> bool: def handle_vector_store_params( - vector_store_type: VectorStoreType, index_path: str + vector_store_type: VectorStoreType, index_path: str, catalog ) -> dict: """Handle vector store parameters based on the vector store type and index path. 
@@ -178,7 +178,13 @@ def handle_vector_store_params( elif vector_store_type == VectorStoreType.CHROMADB: return {"index_path": str(Path(index_path).parent)} elif vector_store_type == VectorStoreType.PINECONE: - return {} + # add the required API_KEYS + return { + "PINECONE_API_KEY": catalog().get_configuration_catalog_value( + "PINECONE_API_KEY" + ), + "PINECONE_ENV": catalog().get_configuration_catalog_value("PINECONE_ENV"), + } else: raise ValueError("Unsupported vector store type: {}".format(vector_store_type)) diff --git a/evadb/executor/load_multimedia_executor.py b/evadb/executor/load_multimedia_executor.py index 90bc7edba..03f1f7fb6 100644 --- a/evadb/executor/load_multimedia_executor.py +++ b/evadb/executor/load_multimedia_executor.py @@ -53,7 +53,9 @@ def exec(self, *args, **kwargs): # If it is a s3 path, download the file to local if self.node.file_path.as_posix().startswith("s3:/"): - s3_dir = Path(self.config.get_value("storage", "s3_download_dir")) + s3_dir = Path( + self.catalog().get_configuration_catalog_value("s3_download_dir") + ) dst_path = s3_dir / self.node.table_info.table_name dst_path.mkdir(parents=True, exist_ok=True) video_files = download_from_s3(self.node.file_path, dst_path) diff --git a/evadb/executor/set_executor.py b/evadb/executor/set_executor.py index f399e3839..309fe2747 100644 --- a/evadb/executor/set_executor.py +++ b/evadb/executor/set_executor.py @@ -36,8 +36,8 @@ def exec(self, *args, **kwargs): as a separate PR for the issue #1140, where all instances of config use will be replaced """ - self._config.update_value( - category="default", - key=self.node.config_name.upper(), + + self.catalog().upsert_configuration_catalog_entry( + key=self.node.config_name, value=self.node.config_value.value, ) diff --git a/evadb/executor/show_info_executor.py b/evadb/executor/show_info_executor.py index 96dc0a537..16871b843 100644 --- a/evadb/executor/show_info_executor.py +++ b/evadb/executor/show_info_executor.py @@ -51,8 +51,7 @@ def 
exec(self, *args, **kwargs): for db in databases: show_entries.append(db.display_format()) elif self.node.show_type is ShowType.CONFIG: - value = self._config.get_value( - category="default", + value = self.catalog().get_configuration_catalog_value( key=self.node.show_val.upper(), ) show_entries = {} diff --git a/evadb/executor/vector_index_scan_executor.py b/evadb/executor/vector_index_scan_executor.py index 2b58f5c33..b2e1bb219 100644 --- a/evadb/executor/vector_index_scan_executor.py +++ b/evadb/executor/vector_index_scan_executor.py @@ -102,7 +102,9 @@ def _evadb_vector_index_scan(self, *args, **kwargs): self.index = VectorStoreFactory.init_vector_store( self.vector_store_type, self.index_name, - **handle_vector_store_params(self.vector_store_type, self.index_path), + **handle_vector_store_params( + self.vector_store_type, self.index_path, self.db.catalog + ), ) search_feat = self._get_search_query_results() diff --git a/evadb/functions/abstract/pytorch_abstract_function.py b/evadb/functions/abstract/pytorch_abstract_function.py index 763f3658f..49e531655 100644 --- a/evadb/functions/abstract/pytorch_abstract_function.py +++ b/evadb/functions/abstract/pytorch_abstract_function.py @@ -17,7 +17,6 @@ import pandas as pd from numpy.typing import ArrayLike -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.functions.abstract.abstract_function import ( AbstractClassifierFunction, AbstractTransformationFunction, @@ -74,7 +73,8 @@ def __call__(self, *args, **kwargs) -> pd.DataFrame: if isinstance(frames, pd.DataFrame): frames = frames.transpose().values.tolist()[0] - gpu_batch_size = ConfigurationManager().get_value("executor", "gpu_batch_size") + # hardcoding it for now, need to be fixed @xzdandy + gpu_batch_size = 1 import torch tens_batch = torch.cat([self.transform(x) for x in frames]).to( diff --git a/evadb/functions/chatgpt.py b/evadb/functions/chatgpt.py index 61253116f..fadc61191 100644 --- a/evadb/functions/chatgpt.py +++ 
b/evadb/functions/chatgpt.py @@ -20,7 +20,6 @@ from retry import retry from evadb.catalog.catalog_type import NdArrayType -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.functions.abstract.abstract_function import AbstractFunction from evadb.functions.decorators.decorators import forward, setup from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe @@ -85,10 +84,12 @@ def setup( self, model="gpt-3.5-turbo", temperature: float = 0, + openai_api_key="", ) -> None: assert model in _VALID_CHAT_COMPLETION_MODEL, f"Unsupported ChatGPT {model}" self.model = model self.temperature = temperature + self.openai_api_key = openai_api_key @forward( input_signatures=[ @@ -120,14 +121,12 @@ def forward(self, text_df): def completion_with_backoff(**kwargs): return openai.ChatCompletion.create(**kwargs) - # Register API key, try configuration manager first - openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY") - # If not found, try OS Environment Variable + openai.api_key = self.openai_api_key if len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_KEY", "") + openai.api_key = os.environ.get("OPENAI_API_KEY", "") assert ( len(openai.api_key) != 0 - ), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)" + ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" queries = text_df[text_df.columns[0]] content = text_df[text_df.columns[0]] diff --git a/evadb/functions/dalle.py b/evadb/functions/dalle.py index d373fda38..7c1dc39dd 100644 --- a/evadb/functions/dalle.py +++ b/evadb/functions/dalle.py @@ -22,7 +22,6 @@ from PIL import Image from evadb.catalog.catalog_type import NdArrayType -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.functions.abstract.abstract_function import AbstractFunction from 
evadb.functions.decorators.decorators import forward from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe @@ -34,8 +33,8 @@ class DallEFunction(AbstractFunction): def name(self) -> str: return "DallE" - def setup(self) -> None: - pass + def setup(self, openai_api_key="") -> None: + self.openai_api_key = openai_api_key @forward( input_signatures=[ @@ -59,14 +58,13 @@ def forward(self, text_df): try_to_import_openai() import openai - # Register API key, try configuration manager first - openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY") + openai.api_key = self.openai_api_key # If not found, try OS Environment Variable - if openai.api_key is None or len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_KEY", "") + if len(openai.api_key) == 0: + openai.api_key = os.environ.get("OPENAI_API_KEY", "") assert ( len(openai.api_key) != 0 - ), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)" + ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" def generate_image(text_df: PandasDataframe): results = [] diff --git a/evadb/functions/stable_diffusion.py b/evadb/functions/stable_diffusion.py index 044195547..1ceaefe7e 100644 --- a/evadb/functions/stable_diffusion.py +++ b/evadb/functions/stable_diffusion.py @@ -22,7 +22,6 @@ from PIL import Image from evadb.catalog.catalog_type import NdArrayType -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.functions.abstract.abstract_function import AbstractFunction from evadb.functions.decorators.decorators import forward from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe @@ -34,10 +33,8 @@ class StableDiffusion(AbstractFunction): def name(self) -> str: return "StableDiffusion" - def setup( - self, - ) -> None: - pass + def setup(self, replicate_api_token="") -> None: + 
self.replicate_api_token = replicate_api_token @forward( input_signatures=[ @@ -64,16 +61,13 @@ def forward(self, text_df): try_to_import_replicate() import replicate - # Register API key, try configuration manager first - replicate_api_key = ConfigurationManager().get_value( - "third_party", "REPLICATE_API_TOKEN" - ) + replicate_api_key = self.replicate_api_token # If not found, try OS Environment Variable if replicate_api_key is None: replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "") assert ( len(replicate_api_key) != 0 - ), "Please set your Replicate API key in evadb.yml file (third_party, replicate_api_token) or environment variable (REPLICATE_API_TOKEN)" + ), "Please set your Replicate API key using SET REPLICATE_API_TOKEN = '' or set the environment variable (REPLICATE_API_TOKEN)" os.environ["REPLICATE_API_TOKEN"] = replicate_api_key model_id = ( diff --git a/evadb/optimizer/optimizer_context.py b/evadb/optimizer/optimizer_context.py index 6cb72dfd8..721cb5b92 100644 --- a/evadb/optimizer/optimizer_context.py +++ b/evadb/optimizer/optimizer_context.py @@ -43,7 +43,9 @@ def __init__( self._task_stack = OptimizerTaskStack() self._memo = Memo() self._cost_model = cost_model - self._rules_manager = rules_manager or RulesManager(db.config) + # check if ray is enabled + is_ray_enabled = self.db.catalog().get_configuration_catalog_value("ray") + self._rules_manager = rules_manager or RulesManager({"ray": is_ray_enabled}) @property def db(self): diff --git a/evadb/optimizer/plan_generator.py b/evadb/optimizer/plan_generator.py index 5b8c6d2a2..5396f4b93 100644 --- a/evadb/optimizer/plan_generator.py +++ b/evadb/optimizer/plan_generator.py @@ -39,7 +39,9 @@ def __init__( cost_model: CostModel = None, ) -> None: self.db = db - self.rules_manager = rules_manager or RulesManager(db.config) + # check if ray is enabled + is_ray_enabled = self.db.catalog().get_configuration_catalog_value("ray") + self.rules_manager = rules_manager or RulesManager({"ray": 
is_ray_enabled}) self.cost_model = cost_model or CostModel() def execute_task_stack(self, task_stack: OptimizerTaskStack): diff --git a/evadb/optimizer/rules/rules.py b/evadb/optimizer/rules/rules.py index 8e18e4d70..cb9ff3274 100644 --- a/evadb/optimizer/rules/rules.py +++ b/evadb/optimizer/rules/rules.py @@ -836,7 +836,10 @@ def apply(self, before: LogicalCreateIndex, context: OptimizerContext): before.index_def, ) child = SeqScanPlan(None, before.project_expr_list, before.table_ref.alias) - batch_mem_size = context.db.config.get_value("executor", "batch_mem_size") + + batch_mem_size = context.db.catalog().get_configuration_catalog_value( + "batch_mem_size" + ) child.append_child( StoragePlan( before.table_ref.table.table_obj, @@ -933,7 +936,9 @@ def apply(self, before: LogicalGet, context: OptimizerContext): # read in a batch from storage engine. # Todo: Experiment heuristics. after = SeqScanPlan(None, before.target_list, before.alias) - batch_mem_size = context.db.config.get_value("executor", "batch_mem_size") + batch_mem_size = context.db.catalog().get_configuration_catalog_value( + "batch_mem_size" + ) after.append_child( StoragePlan( before.table_obj, diff --git a/evadb/optimizer/rules/rules_manager.py b/evadb/optimizer/rules/rules_manager.py index e9720b78d..cc88a9575 100644 --- a/evadb/optimizer/rules/rules_manager.py +++ b/evadb/optimizer/rules/rules_manager.py @@ -17,7 +17,6 @@ from contextlib import contextmanager from typing import List -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.optimizer.rules.rules import ( CacheFunctionExpressionInApply, CacheFunctionExpressionInFilter, @@ -67,7 +66,7 @@ class RulesManager: - def __init__(self, config: ConfigurationManager): + def __init__(self, configs: dict = {}): self._logical_rules = [ LogicalInnerJoinCommutativity(), CacheFunctionExpressionInApply(), @@ -121,9 +120,9 @@ def __init__(self, config: ConfigurationManager): # These rules are enabled only if # (1) ray is 
installed and (2) ray is enabled # Ray must be installed using pip - # It must also be enabled in "evadb.yml" + # It must also be enabled using the SET command # NOTE: By default, it is not enabled - ray_enabled = config.get_value("experimental", "ray") + ray_enabled = configs.get("ray", False) if is_ray_enabled_and_installed(ray_enabled): self._implementation_rules.extend( [ diff --git a/evadb/parser/evadb.lark b/evadb/parser/evadb.lark index e834d1a7d..d3f455fbb 100644 --- a/evadb/parser/evadb.lark +++ b/evadb/parser/evadb.lark @@ -84,7 +84,7 @@ set_statement: SET config_name (EQUAL_SYMBOL | TO) config_value config_name: uid -config_value: (string_literal | decimal_literal | boolean_literal | real_literal) +config_value: constant // Data Manipulation Language diff --git a/evadb/server/server.py b/evadb/server/server.py index 0105279a3..9f33dced0 100644 --- a/evadb/server/server.py +++ b/evadb/server/server.py @@ -49,7 +49,7 @@ async def start_evadb_server( self._server = await asyncio.start_server(self.accept_client, host, port) # load built-in functions - mode = self._evadb.config.get_value("core", "mode") + mode = self._evadb.catalog().get_configuration_catalog_value("mode") init_builtin_functions(self._evadb, mode=mode) async with self._server: diff --git a/evadb/third_party/vector_stores/pinecone.py b/evadb/third_party/vector_stores/pinecone.py index 837c95e57..3bead1a69 100644 --- a/evadb/third_party/vector_stores/pinecone.py +++ b/evadb/third_party/vector_stores/pinecone.py @@ -15,7 +15,6 @@ import os from typing import List -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.third_party.vector_stores.types import ( FeaturePayload, VectorIndexQuery, @@ -30,34 +29,30 @@ class PineconeVectorStore(VectorStore): - def __init__(self, index_name: str) -> None: + def __init__(self, index_name: str, **kwargs) -> None: try_to_import_pinecone_client() global _pinecone_init_done # pinecone only allows index names with lower 
alpha-numeric characters and '-' self._index_name = index_name.strip().lower() # Get the API key. - self._api_key = ConfigurationManager().get_value( - "third_party", "PINECONE_API_KEY" - ) + self._api_key = kwargs.get("PINECONE_API_KEY") if not self._api_key: self._api_key = os.environ.get("PINECONE_API_KEY") assert ( self._api_key - ), "Please set your Pinecone API key in evadb.yml file (third_party, pinecone_api_key) or environment variable (PINECONE_KEY). It can be found at Pinecone Dashboard > API Keys > Value" + ), "Please set your `PINECONE_API_KEY` using set command or environment variable (PINECONE_KEY). It can be found at Pinecone Dashboard > API Keys > Value" # Get the environment name. - self._environment = ConfigurationManager().get_value( - "third_party", "PINECONE_ENV" - ) + self._environment = kwargs.get("PINECONE_ENV") if not self._environment: self._environment = os.environ.get("PINECONE_ENV") assert ( self._environment - ), "Please set the Pinecone environment key in evadb.yml file (third_party, pinecone_env) or environment variable (PINECONE_ENV). It can be found Pinecone Dashboard > API Keys > Environment." + ), "Please set your `PINECONE_ENV` or environment variable (PINECONE_ENV). It can be found Pinecone Dashboard > API Keys > Environment." if not _pinecone_init_done: # Initialize pinecone. 
diff --git a/evadb/third_party/vector_stores/utils.py b/evadb/third_party/vector_stores/utils.py index e47d24f1f..c7c5cb75b 100644 --- a/evadb/third_party/vector_stores/utils.py +++ b/evadb/third_party/vector_stores/utils.py @@ -41,6 +41,7 @@ def init_vector_store( from evadb.third_party.vector_stores.pinecone import required_params validate_kwargs(kwargs, required_params, required_params) + return PineconeVectorStore(index_name, **kwargs) elif vector_store_type == VectorStoreType.CHROMADB: diff --git a/test/app_tests/test_pandas_qa.py b/test/app_tests/test_pandas_qa.py index 6976a4699..e45c92762 100644 --- a/test/app_tests/test_pandas_qa.py +++ b/test/app_tests/test_pandas_qa.py @@ -25,7 +25,9 @@ class PandasQATest(unittest.TestCase): def setUpClass(cls): cls.evadb = get_evadb_for_testing() cls.evadb.catalog().reset() - os.environ["ray"] = str(cls.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + cls.evadb.catalog().get_configuration_catalog_value("ray") + ) @classmethod def tearDownClass(cls): diff --git a/test/app_tests/test_privategpt.py b/test/app_tests/test_privategpt.py index 8a0b46f34..ff534dec2 100644 --- a/test/app_tests/test_privategpt.py +++ b/test/app_tests/test_privategpt.py @@ -29,7 +29,9 @@ class PrivateGPTTest(unittest.TestCase): def setUpClass(cls): cls.evadb = get_evadb_for_testing() cls.evadb.catalog().reset() - os.environ["ray"] = str(cls.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + cls.evadb.catalog().get_configuration_catalog_value("ray") + ) @classmethod def tearDownClass(cls): diff --git a/test/app_tests/test_youtube_channel_qa.py b/test/app_tests/test_youtube_channel_qa.py index 099a1c73c..edb138588 100644 --- a/test/app_tests/test_youtube_channel_qa.py +++ b/test/app_tests/test_youtube_channel_qa.py @@ -24,7 +24,9 @@ class YoutubeChannelQATest(unittest.TestCase): def setUpClass(cls): cls.evadb = get_evadb_for_testing() cls.evadb.catalog().reset() - os.environ["ray"] = 
str(cls.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + cls.evadb.catalog().get_configuration_catalog_value("ray") + ) @classmethod def tearDownClass(cls): @@ -39,7 +41,7 @@ def tearDown(self) -> None: def test_should_run_youtube_channel_qa_app(self): app_path = Path("apps", "youtube_channel_qa", "youtube_channel_qa.py") input1 = "\n\n\n" # Download just one video from the default channel in the default order. - # Assuming that OPENAI_KEY is already set as an environment variable + # Assuming that OPENAI_API_KEY is already set as an environment variable input2 = "What is this video about?\n" # Question input3 = "exit\n" # Exit inputs = input1 + input2 + input3 diff --git a/test/app_tests/test_youtube_qa.py b/test/app_tests/test_youtube_qa.py index 47d062405..9402654bd 100644 --- a/test/app_tests/test_youtube_qa.py +++ b/test/app_tests/test_youtube_qa.py @@ -25,7 +25,9 @@ class YoutubeQATest(unittest.TestCase): def setUpClass(cls): cls.evadb = get_evadb_for_testing() cls.evadb.catalog().reset() - os.environ["ray"] = str(cls.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + cls.evadb.catalog().get_configuration_catalog_value("ray") + ) @classmethod def tearDownClass(cls): @@ -41,7 +43,7 @@ def tearDown(self) -> None: def test_should_run_youtube_qa_app(self): app_path = Path("apps", "youtube_qa", "youtube_qa.py") input1 = "yes\n\n" # Go with online video and default URL - # Assuming that OPENAI_KEY is already set as an environment variable + # Assuming that OPENAI_API_KEY is already set as an environment variable input2 = "What is this video on?\n" # Question input3 = "exit\nexit\n" # Exit inputs = input1 + input2 + input3 diff --git a/test/integration_tests/long/test_create_index_executor.py b/test/integration_tests/long/test_create_index_executor.py index feabb5bff..f44ef8f82 100644 --- a/test/integration_tests/long/test_create_index_executor.py +++ b/test/integration_tests/long/test_create_index_executor.py @@ 
-33,7 +33,7 @@ class CreateIndexTest(unittest.TestCase): def _index_save_path(self): return str( - Path(self.evadb.config.get_value("storage", "index_dir")) + Path(self.evadb.catalog().get_configuration_catalog_value("index_dir")) / Path("{}_{}.index".format("FAISS", "testCreateIndexName")) ) diff --git a/test/integration_tests/long/test_error_handling_with_ray.py b/test/integration_tests/long/test_error_handling_with_ray.py index da134b7ed..c2d71e7fe 100644 --- a/test/integration_tests/long/test_error_handling_with_ray.py +++ b/test/integration_tests/long/test_error_handling_with_ray.py @@ -32,7 +32,9 @@ class ErrorHandlingRayTests(unittest.TestCase): def setUp(self): self.evadb = get_evadb_for_testing() - os.environ["ray"] = str(self.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + self.evadb.catalog().get_configuration_catalog_value("ray") + ) self.evadb.catalog().reset() # Load built-in Functions. load_functions_for_testing(self.evadb, mode="debug") diff --git a/test/integration_tests/long/test_explain_executor.py b/test/integration_tests/long/test_explain_executor.py index 9d2c8e8ca..cad1718c3 100644 --- a/test/integration_tests/long/test_explain_executor.py +++ b/test/integration_tests/long/test_explain_executor.py @@ -57,7 +57,7 @@ def test_explain_simple_select(self): """|__ ProjectPlan\n |__ SeqScanPlan\n |__ StoragePlan\n""" ) self.assertEqual(batch.frames[0][0], expected_output) - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules(rules_manager, [XformLateralJoinToLinearFlow()]): custom_plan_generator = PlanGenerator(self.evadb, rules_manager) select_query = "EXPLAIN SELECT id, data FROM MyVideo JOIN LATERAL DummyObjectDetector(data) AS T ;" @@ -68,7 +68,7 @@ def test_explain_simple_select(self): self.assertEqual(batch.frames[0][0], expected_output) # Disable more rules - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules( 
rules_manager, [ diff --git a/test/integration_tests/long/test_load_executor.py b/test/integration_tests/long/test_load_executor.py index 05a479a8f..0ca421cf1 100644 --- a/test/integration_tests/long/test_load_executor.py +++ b/test/integration_tests/long/test_load_executor.py @@ -89,7 +89,9 @@ def test_should_form_symlink_to_individual_video(self): # check that the file is a symlink to self.video_file_path video_file_path = os.path.join(video_dir, video_file) self.assertTrue(os.path.islink(video_file_path)) - self.assertEqual(os.readlink(video_file_path), self.video_file_path) + self.assertEqual( + os.readlink(video_file_path), str(Path(self.video_file_path).resolve()) + ) execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyVideo;") diff --git a/test/integration_tests/long/test_optimizer_rules.py b/test/integration_tests/long/test_optimizer_rules.py index 27f5189a4..eb515a19a 100644 --- a/test/integration_tests/long/test_optimizer_rules.py +++ b/test/integration_tests/long/test_optimizer_rules.py @@ -34,7 +34,6 @@ PushDownFilterThroughApplyAndMerge, PushDownFilterThroughJoin, ReorderPredicates, - XformLateralJoinToLinearFlow, ) from evadb.optimizer.rules.rules_manager import RulesManager, disable_rules from evadb.plan_nodes.predicate_plan import PredicatePlan @@ -88,7 +87,7 @@ def test_should_benefit_from_pushdown(self, merge_mock, evaluate_mock): result_without_pushdown_rules = None with time_without_rule: - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules( rules_manager, [PushDownFilterThroughApplyAndMerge(), PushDownFilterThroughJoin()], @@ -108,16 +107,6 @@ def test_should_benefit_from_pushdown(self, merge_mock, evaluate_mock): # on all the frames self.assertGreater(evaluate_count_without_rule, 3 * evaluate_count_with_rule) - result_without_xform_rule = None - rules_manager = RulesManager(self.evadb.config) - with disable_rules(rules_manager, [XformLateralJoinToLinearFlow()]): - custom_plan_generator = 
PlanGenerator(self.evadb, rules_manager) - result_without_xform_rule = execute_query_fetch_all( - self.evadb, query, plan_generator=custom_plan_generator - ) - - self.assertEqual(result_without_xform_rule, result_with_rule) - def test_should_pushdown_without_pushdown_join_rule(self): query = """SELECT id, obj.labels FROM MyVideo JOIN LATERAL @@ -132,7 +121,7 @@ def test_should_pushdown_without_pushdown_join_rule(self): time_without_rule = Timer() result_without_pushdown_join_rule = None with time_without_rule: - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules(rules_manager, [PushDownFilterThroughJoin()]): # should use PushDownFilterThroughApplyAndMerge() custom_plan_generator = PlanGenerator(self.evadb, rules_manager) @@ -264,7 +253,7 @@ def test_reorder_rule_should_not_have_side_effects(self): query = "SELECT id FROM MyVideo WHERE id < 20 AND id > 10;" result = execute_query_fetch_all(self.evadb, query) - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules(rules_manager, [ReorderPredicates()]): custom_plan_generator = PlanGenerator(self.evadb, rules_manager) expected = execute_query_fetch_all( diff --git a/test/integration_tests/long/test_pytorch.py b/test/integration_tests/long/test_pytorch.py index f2bd66ba0..b0d1cf632 100644 --- a/test/integration_tests/long/test_pytorch.py +++ b/test/integration_tests/long/test_pytorch.py @@ -49,7 +49,9 @@ class PytorchTest(unittest.TestCase): def setUpClass(cls): cls.evadb = get_evadb_for_testing() cls.evadb.catalog().reset() - os.environ["ray"] = str(cls.evadb.config.get_value("experimental", "ray")) + os.environ["ray"] = str( + cls.evadb.catalog().get_configuration_catalog_value("ray") + ) ua_detrac = f"{EvaDB_ROOT_DIR}/data/ua_detrac/ua_detrac.mp4" mnist = f"{EvaDB_ROOT_DIR}/data/mnist/mnist.mp4" @@ -295,7 +297,9 @@ def test_should_run_pytorch_and_similarity(self): batch_res = execute_query_fetch_all(self.evadb, 
select_query) img = batch_res.frames["myvideo.data"][0] - tmp_dir_from_config = self.evadb.config.get_value("storage", "tmp_dir") + tmp_dir_from_config = self.evadb.catalog().get_configuration_catalog_value( + "tmp_dir" + ) img_save_path = os.path.join(tmp_dir_from_config, "dummy.jpg") try: diff --git a/test/integration_tests/long/test_reuse.py b/test/integration_tests/long/test_reuse.py index 3f41e5fbb..b35751fbd 100644 --- a/test/integration_tests/long/test_reuse.py +++ b/test/integration_tests/long/test_reuse.py @@ -77,7 +77,7 @@ def _verify_reuse_correctness(self, query, reuse_batch): # surfaces when the system is running on low memory. Explicitly calling garbage # collection to reduce the memory usage. gc.collect() - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules( rules_manager, [ diff --git a/test/integration_tests/long/test_s3_load_executor.py b/test/integration_tests/long/test_s3_load_executor.py index bdd648724..e6f961de6 100644 --- a/test/integration_tests/long/test_s3_load_executor.py +++ b/test/integration_tests/long/test_s3_load_executor.py @@ -41,7 +41,9 @@ def setUp(self): self.evadb.catalog().reset() self.video_file_path = create_sample_video() self.multiple_video_file_path = f"{EvaDB_ROOT_DIR}/data/sample_videos/1" - self.s3_download_dir = self.evadb.config.get_value("storage", "s3_download_dir") + self.s3_download_dir = self.evadb.catalog().get_configuration_catalog_value( + "s3_download_dir" + ) """Mocked AWS Credentials for moto.""" os.environ["AWS_ACCESS_KEY_ID"] = "testing" diff --git a/test/integration_tests/long/test_similarity.py b/test/integration_tests/long/test_similarity.py index 8dc7a2e3c..35f70948f 100644 --- a/test/integration_tests/long/test_similarity.py +++ b/test/integration_tests/long/test_similarity.py @@ -107,7 +107,7 @@ def setUp(self): # Create an actual image dataset. 
img_save_path = os.path.join( - self.evadb.config.get_value("storage", "tmp_dir"), + self.evadb.catalog().get_configuration_catalog_value("tmp_dir"), f"test_similar_img{i}.jpg", ) try_to_import_cv2() diff --git a/test/integration_tests/short/test_set_executor.py b/test/integration_tests/short/test_set_executor.py index d73268143..2a92ac2f9 100644 --- a/test/integration_tests/short/test_set_executor.py +++ b/test/integration_tests/short/test_set_executor.py @@ -36,6 +36,8 @@ def tearDownClass(cls): # integration test def test_set_execution(self): execute_query_fetch_all(self.evadb, "SET OPENAIKEY = 'ABCD';") - current_config_value = self.evadb.config.get_value("default", "OPENAIKEY") + current_config_value = self.evadb.catalog().get_configuration_catalog_value( + "OPENAIKEY" + ) self.assertEqual("ABCD", current_config_value) diff --git a/test/unit_tests/catalog/test_catalog_manager.py b/test/unit_tests/catalog/test_catalog_manager.py index 6149be34c..ef7c38247 100644 --- a/test/unit_tests/catalog/test_catalog_manager.py +++ b/test/unit_tests/catalog/test_catalog_manager.py @@ -44,7 +44,7 @@ def setUpClass(cls) -> None: @mock.patch("evadb.catalog.catalog_manager.init_db") def test_catalog_bootstrap(self, mocked_db): - x = CatalogManager(MagicMock(), MagicMock()) + x = CatalogManager(MagicMock()) x._bootstrap_catalog() mocked_db.assert_called() @@ -52,7 +52,7 @@ def test_catalog_bootstrap(self, mocked_db): "evadb.catalog.catalog_manager.CatalogManager.create_and_insert_table_catalog_entry" ) def test_create_multimedia_table_catalog_entry(self, mock): - x = CatalogManager(MagicMock(), MagicMock()) + x = CatalogManager(MagicMock()) name = "myvideo" x.create_and_insert_multimedia_table_catalog_entry( name=name, format_type=FileFormatType.VIDEO @@ -71,7 +71,7 @@ def test_create_multimedia_table_catalog_entry(self, mock): def test_insert_table_catalog_entry_should_create_table_and_columns( self, ds_mock, initdb_mock ): - catalog = CatalogManager(MagicMock(), MagicMock()) + 
catalog = CatalogManager(MagicMock()) file_url = "file1" table_name = "name" @@ -88,7 +88,7 @@ def test_insert_table_catalog_entry_should_create_table_and_columns( @mock.patch("evadb.catalog.catalog_manager.init_db") @mock.patch("evadb.catalog.catalog_manager.TableCatalogService") def test_get_table_catalog_entry_when_table_exists(self, ds_mock, initdb_mock): - catalog = CatalogManager(MagicMock(), MagicMock()) + catalog = CatalogManager(MagicMock()) table_name = "name" database_name = "database" row_id = 1 @@ -110,7 +110,7 @@ def test_get_table_catalog_entry_when_table_exists(self, ds_mock, initdb_mock): def test_get_table_catalog_entry_when_table_doesnot_exists( self, dcs_mock, ds_mock, initdb_mock ): - catalog = CatalogManager(MagicMock(), MagicMock()) + catalog = CatalogManager(MagicMock()) table_name = "name" database_name = "database" @@ -132,7 +132,7 @@ def test_get_table_catalog_entry_when_table_doesnot_exists( def test_insert_function( self, checksum_mock, functionmetadata_mock, functionio_mock, function_mock ): - catalog = CatalogManager(MagicMock(), MagicMock()) + catalog = CatalogManager(MagicMock()) function_io_list = [MagicMock()] function_metadata_list = [MagicMock()] actual = catalog.insert_function_catalog_entry( @@ -155,7 +155,7 @@ def test_insert_function( @mock.patch("evadb.catalog.catalog_manager.FunctionCatalogService") def test_get_function_catalog_entry_by_name(self, function_mock): - catalog = CatalogManager(MagicMock(), MagicMock()) + catalog = CatalogManager(MagicMock()) actual = catalog.get_function_catalog_entry_by_name("name") function_mock.return_value.get_entry_by_name.assert_called_with("name") self.assertEqual( @@ -164,25 +164,19 @@ def test_get_function_catalog_entry_by_name(self, function_mock): @mock.patch("evadb.catalog.catalog_manager.FunctionCatalogService") def test_delete_function(self, function_mock): - CatalogManager(MagicMock(), MagicMock()).delete_function_catalog_entry_by_name( - "name" - ) + 
CatalogManager(MagicMock()).delete_function_catalog_entry_by_name("name") function_mock.return_value.delete_entry_by_name.assert_called_with("name") @mock.patch("evadb.catalog.catalog_manager.FunctionIOCatalogService") def test_get_function_outputs(self, function_mock): mock_func = function_mock.return_value.get_output_entries_by_function_id function_obj = MagicMock(spec=FunctionCatalogEntry) - CatalogManager(MagicMock(), MagicMock()).get_function_io_catalog_output_entries( - function_obj - ) + CatalogManager(MagicMock()).get_function_io_catalog_output_entries(function_obj) mock_func.assert_called_once_with(function_obj.row_id) @mock.patch("evadb.catalog.catalog_manager.FunctionIOCatalogService") def test_get_function_inputs(self, function_mock): mock_func = function_mock.return_value.get_input_entries_by_function_id function_obj = MagicMock(spec=FunctionCatalogEntry) - CatalogManager(MagicMock(), MagicMock()).get_function_io_catalog_input_entries( - function_obj - ) + CatalogManager(MagicMock()).get_function_io_catalog_input_entries(function_obj) mock_func.assert_called_once_with(function_obj.row_id) diff --git a/test/unit_tests/executor/test_execution_context.py b/test/unit_tests/executor/test_execution_context.py index 28383949e..8d2a4caae 100644 --- a/test/unit_tests/executor/test_execution_context.py +++ b/test/unit_tests/executor/test_execution_context.py @@ -21,28 +21,24 @@ class ExecutionContextTest(unittest.TestCase): - @patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.get_gpu_count") @patch("evadb.executor.execution_context.is_gpu_available") def test_CUDA_VISIBLE_DEVICES_gets_populated_from_config( - self, gpu_check, get_gpu_count, cfm + self, gpu_check, get_gpu_count ): gpu_check.return_value = True get_gpu_count.return_value = 3 - cfm.return_value.get_value.return_value = [0, 1] - context = Context() + context = Context([0, 1]) self.assertEqual(context.gpus, [0, 1]) - 
@patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.os") @patch("evadb.executor.execution_context.get_gpu_count") @patch("evadb.executor.execution_context.is_gpu_available") def test_CUDA_VISIBLE_DEVICES_gets_populated_from_environment_if_no_config( - self, is_gpu, get_gpu_count, os, cfm + self, is_gpu, get_gpu_count, os ): is_gpu.return_value = True - cfm.return_value.get_value.return_value = [] get_gpu_count.return_value = 3 os.environ.get.return_value = "0,1" context = Context() @@ -50,57 +46,45 @@ def test_CUDA_VISIBLE_DEVICES_gets_populated_from_environment_if_no_config( self.assertEqual(context.gpus, [0, 1]) - @patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.os") @patch("evadb.executor.execution_context.get_gpu_count") @patch("evadb.executor.execution_context.is_gpu_available") def test_CUDA_VISIBLE_DEVICES_should_be_empty_if_nothing_provided( - self, gpu_check, get_gpu_count, os, cfm + self, gpu_check, get_gpu_count, os ): gpu_check.return_value = True get_gpu_count.return_value = 3 - cfm.return_value.get_value.return_value = [] os.environ.get.return_value = "" context = Context() os.environ.get.assert_called_with("CUDA_VISIBLE_DEVICES", "") self.assertEqual(context.gpus, []) - @patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.os") @patch("evadb.executor.execution_context.is_gpu_available") - def test_gpus_ignores_config_if_no_gpu_available(self, gpu_check, os, cfm): + def test_gpus_ignores_config_if_no_gpu_available(self, gpu_check, os): gpu_check.return_value = False - cfm.return_value.get_value.return_value = [0, 1, 2] os.environ.get.return_value = "0,1,2" - context = Context() + context = Context([0, 1, 2]) self.assertEqual(context.gpus, []) - @patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.os") 
@patch("evadb.executor.execution_context.is_gpu_available") - def test_gpu_device_should_return_NO_GPU_if_GPU_not_available( - self, gpu_check, os, cfm - ): + def test_gpu_device_should_return_NO_GPU_if_GPU_not_available(self, gpu_check, os): gpu_check.return_value = True - cfm.return_value.get_value.return_value = [] os.environ.get.return_value = "" context = Context() os.environ.get.assert_called_with("CUDA_VISIBLE_DEVICES", "") self.assertEqual(context.gpu_device(), NO_GPU) - @patch("evadb.executor.execution_context.ConfigurationManager") @patch("evadb.executor.execution_context.get_gpu_count") @patch("evadb.executor.execution_context.is_gpu_available") - def test_should_return_random_gpu_ID_if_available( - self, gpu_check, get_gpu_count, cfm - ): + def test_should_return_random_gpu_ID_if_available(self, gpu_check, get_gpu_count): gpu_check.return_value = True get_gpu_count.return_value = 1 - cfm.return_value.get_value.return_value = [0, 1, 2] - context = Context() + context = Context([0, 1, 2]) selected_device = context.gpu_device() self.assertEqual(selected_device, 0) diff --git a/test/unit_tests/optimizer/rules/test_batch_mem_size.py b/test/unit_tests/optimizer/rules/test_batch_mem_size.py index 70033b014..68f84db80 100644 --- a/test/unit_tests/optimizer/rules/test_batch_mem_size.py +++ b/test/unit_tests/optimizer/rules/test_batch_mem_size.py @@ -38,9 +38,7 @@ def test_batch_mem_size_for_sqlite_storage_engine(self): the storage engine. 
""" test_batch_mem_size = 100 - self.evadb.config.update_value( - "executor", "batch_mem_size", test_batch_mem_size - ) + execute_query_fetch_all(self.evadb, f"SET batch_mem_size={test_batch_mem_size}") create_table_query = """ CREATE TABLE IF NOT EXISTS MyCSV ( id INTEGER UNIQUE, diff --git a/test/unit_tests/optimizer/rules/test_rules.py b/test/unit_tests/optimizer/rules/test_rules.py index e71dbef7c..18f8dc51d 100644 --- a/test/unit_tests/optimizer/rules/test_rules.py +++ b/test/unit_tests/optimizer/rules/test_rules.py @@ -168,8 +168,8 @@ def test_supported_rules(self): XformExtractObjectToLinearFlow(), ] rewrite_rules = ( - RulesManager(self.evadb.config).stage_one_rewrite_rules - + RulesManager(self.evadb.config).stage_two_rewrite_rules + RulesManager().stage_one_rewrite_rules + + RulesManager().stage_two_rewrite_rules ) self.assertEqual( len(supported_rewrite_rules), @@ -187,18 +187,15 @@ def test_supported_rules(self): ] self.assertEqual( len(supported_logical_rules), - len(RulesManager(self.evadb.config).logical_rules), + len(RulesManager().logical_rules), ) for rule in supported_logical_rules: self.assertTrue( - any( - isinstance(rule, type(x)) - for x in RulesManager(self.evadb.config).logical_rules - ) + any(isinstance(rule, type(x)) for x in RulesManager().logical_rules) ) - ray_enabled = self.evadb.config.get_value("experimental", "ray") + ray_enabled = self.evadb.catalog().get_configuration_catalog_value("ray") ray_enabled_and_installed = is_ray_enabled_and_installed(ray_enabled) # For the current version, we choose either the distributed or the @@ -244,14 +241,14 @@ def test_supported_rules(self): supported_implementation_rules.append(LogicalExchangeToPhysical()) self.assertEqual( len(supported_implementation_rules), - len(RulesManager(self.evadb.config).implementation_rules), + len(RulesManager().implementation_rules), ) for rule in supported_implementation_rules: self.assertTrue( any( isinstance(rule, type(x)) - for x in 
RulesManager(self.evadb.config).implementation_rules + for x in RulesManager().implementation_rules ) ) @@ -280,7 +277,7 @@ def test_embed_sample_into_get_does_not_work_with_structured_data(self): self.assertFalse(rule.check(logi_sample, MagicMock())) def test_disable_rules(self): - rules_manager = RulesManager(self.evadb.config) + rules_manager = RulesManager() with disable_rules(rules_manager, [PushDownFilterThroughApplyAndMerge()]): self.assertFalse( any( diff --git a/test/unit_tests/optimizer/test_optimizer_task.py b/test/unit_tests/optimizer/test_optimizer_task.py index 6230dfb16..c9100d6ef 100644 --- a/test/unit_tests/optimizer/test_optimizer_task.py +++ b/test/unit_tests/optimizer/test_optimizer_task.py @@ -51,13 +51,11 @@ def test_abstract_optimizer_task(self): task.execute() def top_down_rewrite(self, opr): - opt_cxt = OptimizerContext(MagicMock(), CostModel(), RulesManager(MagicMock())) + opt_cxt = OptimizerContext(MagicMock(), CostModel(), RulesManager()) grp_expr = opt_cxt.add_opr_to_group(opr) root_grp_id = grp_expr.group_id opt_cxt.task_stack.push( - TopDownRewrite( - grp_expr, RulesManager(MagicMock()).stage_one_rewrite_rules, opt_cxt - ) + TopDownRewrite(grp_expr, RulesManager().stage_one_rewrite_rules, opt_cxt) ) self.execute_task_stack(opt_cxt.task_stack) return opt_cxt, root_grp_id @@ -65,9 +63,7 @@ def top_down_rewrite(self, opr): def bottom_up_rewrite(self, root_grp_id, opt_cxt): grp_expr = opt_cxt.memo.groups[root_grp_id].logical_exprs[0] opt_cxt.task_stack.push( - BottomUpRewrite( - grp_expr, RulesManager(MagicMock()).stage_two_rewrite_rules, opt_cxt - ) + BottomUpRewrite(grp_expr, RulesManager().stage_two_rewrite_rules, opt_cxt) ) self.execute_task_stack(opt_cxt.task_stack) return opt_cxt, root_grp_id diff --git a/test/unit_tests/test_dalle.py b/test/unit_tests/test_dalle.py index a7a9536fa..c434a4db4 100644 --- a/test/unit_tests/test_dalle.py +++ b/test/unit_tests/test_dalle.py @@ -41,7 +41,7 @@ def setUp(self) -> None: def tearDown(self) 
-> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;") - @patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"}) + @patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"}) @patch("requests.get") @patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]}) def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): diff --git a/test/unit_tests/test_eva_cmd_client.py b/test/unit_tests/test_eva_cmd_client.py index 0e1f67ee9..90a7f4a15 100644 --- a/test/unit_tests/test_eva_cmd_client.py +++ b/test/unit_tests/test_eva_cmd_client.py @@ -16,14 +16,13 @@ import asyncio import unittest -import pytest -from mock import call, patch +from mock import patch -from evadb.configuration.configuration_manager import ConfigurationManager from evadb.evadb_cmd_client import evadb_client, main +from evadb.evadb_config import BASE_EVADB_CONFIG -@pytest.mark.skip +# @pytest.mark.skip class CMDClientTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -81,17 +80,10 @@ def test_main_without_cmd_arguments( [], ) - # Mock the ConfigurationManager's get_value method - with patch.object( - ConfigurationManager, "get_value", return_value="default_value" - ) as mock_get_value: - # Call the function under test - main() - - # Assert that the mocked functions were called correctly - mock_start_cmd_client.assert_called_once_with( - "default_value", "default_value" - ) - mock_get_value.assert_has_calls( - [call("server", "host"), call("server", "port")] - ) + # Call the function under test + main() + + # Assert that the mocked functions were called correctly + mock_start_cmd_client.assert_called_once_with( + BASE_EVADB_CONFIG["host"], BASE_EVADB_CONFIG["port"] + ) diff --git a/test/util.py b/test/util.py index 23eaeb35e..3a23a6ff5 100644 --- a/test/util.py +++ b/test/util.py @@ -30,8 +30,12 @@ from evadb.binder.statement_binder import StatementBinder from 
evadb.binder.statement_binder_context import StatementBinderContext from evadb.catalog.catalog_type import NdArrayType -from evadb.configuration.configuration_manager import ConfigurationManager -from evadb.configuration.constants import EvaDB_DATABASE_DIR, EvaDB_INSTALLATION_DIR +from evadb.configuration.constants import ( + S3_DOWNLOAD_DIR, + TMP_DIR, + EvaDB_DATABASE_DIR, + EvaDB_INSTALLATION_DIR, +) from evadb.database import init_evadb_instance from evadb.expression.function_expression import FunctionExpression from evadb.functions.abstract.abstract_function import ( @@ -67,7 +71,7 @@ def suffix_pytest_xdist_worker_id_to_dir(path: str): path = Path(str(worker_id) + "_" + path) except KeyError: pass - return path + return Path(path) def get_evadb_for_testing(uri: str = None): @@ -79,14 +83,12 @@ def get_evadb_for_testing(uri: str = None): def get_tmp_dir(): db_dir = suffix_pytest_xdist_worker_id_to_dir(EvaDB_DATABASE_DIR) - config = ConfigurationManager(Path(db_dir)) - return config.get_value("storage", "tmp_dir") + return db_dir / TMP_DIR def s3_dir(): db_dir = suffix_pytest_xdist_worker_id_to_dir(EvaDB_DATABASE_DIR) - config = ConfigurationManager(Path(db_dir)) - return config.get_value("storage", "s3_download_dir") + return db_dir / S3_DOWNLOAD_DIR EvaDB_TEST_DATA_DIR = Path(EvaDB_INSTALLATION_DIR).parent From a6fdd6a68def142ed227236d46845154b93a338c Mon Sep 17 00:00:00 2001 From: Gaurav Tarlok Kakkar Date: Thu, 19 Oct 2023 14:16:50 -0400 Subject: [PATCH 12/12] Fix: minor typo (#1307) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cb7ad985d..af2c5ea07 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -234,7 +234,7 @@ jobs: else pip install ".[dev,pinecone,chromadb]" # ray < 2.5.0 does not work with python 3.11 ray-project/ray#33864 fi - python -c "import evadb;cur=evadb.connect().cursor();cur.query('SET ray=True';)" + python -c "import 
evadb;cur=evadb.connect().cursor();cur.query('SET ray=True;')" else if [ $PY_VERSION != "3.11" ]; then pip install ".[dev,ludwig,qdrant,pinecone,chromadb]"