diff --git a/src/llama_index_alloydb_pg/__init__.py b/src/llama_index_alloydb_pg/__init__.py index 9f240e7..21ff625 100644 --- a/src/llama_index_alloydb_pg/__init__.py +++ b/src/llama_index_alloydb_pg/__init__.py @@ -13,5 +13,6 @@ # limitations under the License. from .engine import AlloyDBEngine, Column +from .index_store import AlloyDBIndexStore -_all = ["AlloyDBEngine", "Column"] +_all = ["AlloyDBEngine", "Column", "AlloyDBIndexStore"] diff --git a/src/llama_index_alloydb_pg/index_store.py b/src/llama_index_alloydb_pg/index_store.py new file mode 100644 index 0000000..088a091 --- /dev/null +++ b/src/llama_index_alloydb_pg/index_store.py @@ -0,0 +1,153 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from typing import List, Optional + +from llama_index.core.data_structs.data_structs import IndexStruct +from llama_index.core.storage.index_store.types import BaseIndexStore +from llama_index.core.storage.kvstore.types import DEFAULT_BATCH_SIZE + +from .async_index_store import AsyncAlloyDBIndexStore +from .engine import AlloyDBEngine + + +class AlloyDBIndexStore(BaseIndexStore): + """Index Store Table stored in an AlloyDB for PostgreSQL database.""" + + __create_key = object() + + def __init__( + self, key: object, engine: AlloyDBEngine, index_store: AsyncAlloyDBIndexStore + ): + """AlloyDBIndexStore constructor. + + Args: + key (object): Key to prevent direct constructor usage. + engine (AlloyDBEngine): Database connection pool. + table_name (str): Table name that stores the index metadata. + schema_name (str): The schema name where the table is located. Defaults to "public" + batch_size (str): The default batch size for bulk inserts. Defaults to 1. + + Raises: + Exception: If constructor is directly called by the user. + """ + if key != AlloyDBIndexStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + self._engine = engine + self.__index_store = index_store + + @classmethod + async def create( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + batch_size: int = DEFAULT_BATCH_SIZE, + ) -> AlloyDBIndexStore: + """Create a new AlloyDBIndexStore instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the index metadata. + schema_name (str): The schema name where the table is located. Defaults to "public" + batch_size (str): The default batch size for bulk inserts. Defaults to 1. + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AlloyDBIndexStore: A newly created instance of AlloyDBIndexStore. + """ + coro = AsyncAlloyDBIndexStore.create( + engine, table_name, schema_name + ) + index_store = await engine._run_as_async(coro) + return cls(cls.__create_key, engine, index_store) + + @classmethod + def create_sync( + cls, + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + batch_size: int = DEFAULT_BATCH_SIZE, + ) -> AlloyDBIndexStore: + """Create a new AlloyDBIndexStore sync instance. + + Args: + engine (AlloyDBEngine): AlloyDB engine to use. + table_name (str): Table name that stores the index metadata. + schema_name (str): The schema name where the table is located. Defaults to "public" + batch_size (str): The default batch size for bulk inserts. Defaults to 1. + + Raises: + IndexError: If the table provided does not contain required schema. + + Returns: + AlloyDBIndexStore: A newly created instance of AlloyDBIndexStore. + """ + coro = AsyncAlloyDBIndexStore.create( + engine, table_name, schema_name + ) + index_store = engine._run_as_sync(coro) + return cls(cls.__create_key, engine, index_store) + + async def aindex_structs(self) -> List[IndexStruct]: + return await self._engine._run_as_async(self.__index_store.aindex_structs()) + + def index_structs(self) -> List[IndexStruct]: + return self._engine._run_as_sync(self.__index_store.aindex_structs()) + + async def aadd_index_struct(self, index_struct: IndexStruct) -> None: + """Add an index struct. + + Args: + index_struct (IndexStruct): index struct + + """ + return await self._engine._run_as_async( + self.__index_store.aadd_index_struct(index_struct) + ) + + def add_index_struct(self, index_struct: IndexStruct) -> None: + return self._engine._run_as_sync( + self.__index_store.aadd_index_struct(index_struct) + ) + + async def adelete_index_struct(self, key: str) -> None: + return await self._engine._run_as_async( + self.__index_store.adelete_index_struct(key) + ) + + def delete_index_struct(self, key: str) -> None: + return self._engine._run_as_sync(self.__index_store.adelete_index_struct(key)) + + async def aget_index_struct( + self, struct_id: Optional[str] = None + ) -> Optional[IndexStruct]: + return await self._engine._run_as_async( + self.__index_store.aget_index_struct(struct_id) + ) + + def get_index_struct( + self, struct_id: Optional[str] = None + ) -> Optional[IndexStruct]: + return self._engine._run_as_sync( + self.__index_store.aget_index_struct(struct_id) + ) diff --git a/tests/test_index_store.py b/tests/test_index_store.py new file mode 100644 index 0000000..1f81bfe --- /dev/null +++ b/tests/test_index_store.py @@ -0,0 +1,305 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +from typing import Sequence +import warnings + +import pytest +import pytest_asyncio +from llama_index.core.data_structs.data_structs import IndexDict, IndexGraph, IndexList +from sqlalchemy import RowMapping, text + +from llama_index_alloydb_pg import AlloyDBEngine, AlloyDBIndexStore +from llama_index_alloydb_pg.async_index_store import AsyncAlloyDBIndexStore + +default_table_name_async = "document_store_" + str(uuid.uuid4()) +default_table_name_sync = "document_store_" + str(uuid.uuid4()) + + +async def aexecute( + engine: AlloyDBEngine, + query: str, +) -> None: + async def run(engine, query): + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + await engine._run_as_async(run(engine, query)) + + +async def afetch(engine: AlloyDBEngine, query: str) -> Sequence[RowMapping]: + async def run(engine, query): + async with engine._pool.connect() as conn: + result = await conn.execute(text(query)) + result_map = result.mappings() + result_fetch = result_map.fetchall() + return result_fetch + + return await engine._run_as_async(run(engine, query)) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBIndexStoreAsync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name, user, password + ): + async_engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + user=user, + password=password, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def index_store(self, async_engine): + await async_engine.ainit_index_store_table(table_name=default_table_name_async) + + index_store = await AlloyDBIndexStore.create( + engine=async_engine, table_name=default_table_name_async + ) + + yield index_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_async}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncAlloyDBIndexStore( + engine=async_engine, table_name=default_table_name_async + ) + + async def test_add_and_delete_index(self, index_store, async_engine): + index_struct = IndexGraph() + index_id = index_struct.index_id + index_type = index_struct.get_type() + await index_store.aadd_index_struct(index_struct) + + query = f"""select * from "public"."{default_table_name_async}" where index_id = '{index_id}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result.get("type") == index_type + + await index_store.adelete_index_struct(index_id) + query = f"""select * from "public"."{default_table_name_async}" where index_id = '{index_id}';""" + results = await afetch(async_engine, query) + assert results == [] + + async def test_get_index(self, index_store): + index_struct = IndexGraph() + index_id = index_struct.index_id + index_type = index_struct.get_type() + await index_store.aadd_index_struct(index_struct) + + ind_struct = await index_store.aget_index_struct(index_id) + + assert index_struct == ind_struct + + async def test_aindex_structs(self, index_store): + index_dict_struct = IndexDict() + index_list_struct = IndexList() + index_graph_struct = IndexGraph() + + await index_store.aadd_index_struct(index_dict_struct) + await index_store.aadd_index_struct(index_graph_struct) + await index_store.aadd_index_struct(index_list_struct) + + indexes = await index_store.aindex_structs() + + assert indexes[index_dict_struct.index_id] == index_dict_struct + assert indexes[index_list_struct.index_id] == index_list_struct + assert indexes[index_graph_struct.index_id] == index_graph_struct + + async def test_warning(self, index_store): + index_dict_struct = IndexDict() + index_list_struct = IndexList() + + await index_store.aadd_index_struct(index_dict_struct) + await index_store.aadd_index_struct(index_list_struct) + + with warnings.catch_warnings(record=True) as w: + index_struct = await index_store.aget_index_struct() + + assert len(w) == 1 + assert "No struct_id specified and more than one struct exists." in str( + w[-1].message + ) + + +@pytest.mark.asyncio(loop_scope="class") +class TestAlloyDBIndexStoreSync: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def user(self) -> str: + return get_env_var("DB_USER", "database user for AlloyDB") + + @pytest.fixture(scope="module") + def password(self) -> str: + return get_env_var("DB_PASSWORD", "database password for AlloyDB") + + @pytest_asyncio.fixture(scope="class") + async def async_engine( + self, db_project, db_region, db_cluster, db_instance, db_name, user, password + ): + async_engine = AlloyDBEngine.from_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + user=user, + password=password, + ) + + yield async_engine + + await async_engine.close() + + @pytest_asyncio.fixture(scope="class") + async def index_store(self, async_engine): + async_engine.init_index_store_table(table_name=default_table_name_sync) + + index_store = AlloyDBIndexStore.create_sync( + engine=async_engine, table_name=default_table_name_sync + ) + + yield index_store + + query = f'DROP TABLE IF EXISTS "{default_table_name_sync}"' + await aexecute(async_engine, query) + + async def test_init_with_constructor(self, async_engine): + with pytest.raises(Exception): + AsyncAlloyDBIndexStore( + engine=async_engine, table_name=default_table_name_sync + ) + + async def test_add_and_delete_index(self, index_store, async_engine): + index_struct = IndexGraph() + index_id = index_struct.index_id + index_type = index_struct.get_type() + index_store.add_index_struct(index_struct) + + query = f"""select * from "public"."{default_table_name_sync}" where index_id = '{index_id}';""" + results = await afetch(async_engine, query) + result = results[0] + assert result.get("type") == index_type + + index_store.delete_index_struct(index_id) + query = f"""select * from "public"."{default_table_name_sync}" where index_id = '{index_id}';""" + results = await afetch(async_engine, query) + assert results == [] + + async def test_get_index(self, index_store): + index_struct = IndexGraph() + index_id = index_struct.index_id + index_type = index_struct.get_type() + index_store.add_index_struct(index_struct) + + ind_struct = index_store.get_index_struct(index_id) + + assert index_struct == ind_struct + + async def test_aindex_structs(self, index_store): + index_dict_struct = IndexDict() + index_list_struct = IndexList() + index_graph_struct = IndexGraph() + + index_store.add_index_struct(index_dict_struct) + index_store.add_index_struct(index_graph_struct) + index_store.add_index_struct(index_list_struct) + + indexes = index_store.index_structs() + + assert indexes[index_dict_struct.index_id] == index_dict_struct + assert indexes[index_list_struct.index_id] == index_list_struct + assert indexes[index_graph_struct.index_id] == index_graph_struct + + async def test_warning(self, index_store): + index_dict_struct = IndexDict() + index_list_struct = IndexList() + + index_store.add_index_struct(index_dict_struct) + index_store.add_index_struct(index_list_struct) + + with warnings.catch_warnings(record=True) as w: + index_struct = index_store.get_index_struct() + + assert len(w) == 1 + assert "No struct_id specified and more than one struct exists." in str(w[-1].message)