# File: src/llama_index_alloydb_pg/async_vectorstore.py (new file)

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import base64
import json
import re
import uuid
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type

import numpy as np
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.vector_stores.types import (
    BasePydanticVectorStore,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
    VectorStoreQuery,
    VectorStoreQueryMode,
    VectorStoreQueryResult,
)
from sqlalchemy import RowMapping, text
from sqlalchemy.ext.asyncio import AsyncEngine

from .engine import AlloyDBEngine

# Single source of truth for the error raised by every synchronous method.
_SYNC_METHODS_NOT_IMPLEMENTED = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."


class AsyncAlloyDBVectorStore(BasePydanticVectorStore):
    """Google AlloyDB Vector Store class.

    Instances must be built with the async ``create`` classmethod, which
    validates the table schema first; calling the constructor directly raises.
    Only the asynchronous methods are usable: every synchronous method raises
    ``NotImplementedError``.
    """

    stores_text: bool = True
    is_embedding_query: bool = True

    # Private sentinel handed to ``__init__`` by ``create`` to prove the
    # instance was built through the validating factory.
    __create_key = object()

    def __init__(
        self,
        key: object,
        engine: AsyncEngine,
        table_name: str,
        schema_name: str = "public",
        id_column: str = "node_id",
        text_column: str = "text",
        embedding_column: str = "embedding",
        metadata_json_column: str = "li_metadata",
        metadata_columns: List[str] = [],
        ref_doc_id_column: str = "ref_doc_id",
        node_column: str = "node",
        stores_text: bool = True,
        is_embedding_query: bool = True,
    ):
        """AsyncAlloyDBVectorStore constructor.

        Args:
            key (object): Prevent direct constructor usage.
            engine (AsyncEngine): Connection pool engine for managing connections to AlloyDB database.
            table_name (str): Name of the existing table or the table to be created.
            schema_name (str): Name of the database schema. Defaults to "public".
            id_column (str): Column that represents id of a Node. Defaults to "node_id".
            text_column (str): Column that represent text content of a Node. Defaults to "text".
            embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding".
            metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata".
            metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns.
            ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id".
            node_column (str): Column that represents the whole JSON node. Defaults to "node".
            stores_text (bool): Whether the table stores text. Defaults to True.
            is_embedding_query (bool): Whether the table query can have embeddings. Defaults to True.

        Raises:
            Exception: If called directly by user.
        """
        if key != AsyncAlloyDBVectorStore.__create_key:
            raise Exception("Only create class through 'create' method!")

        # Delegate to Pydantic's __init__
        super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query)
        self._engine = engine
        self._table_name = table_name
        self._schema_name = schema_name
        self._id_column = id_column
        self._text_column = text_column
        self._embedding_column = embedding_column
        self._metadata_json_column = metadata_json_column
        # Store a copy: the signature's default list is shared across calls, so
        # keeping a reference would alias it (and any caller-owned list).
        self._metadata_columns = list(metadata_columns)
        self._ref_doc_id_column = ref_doc_id_column
        self._node_column = node_column

    @classmethod
    async def create(
        cls: Type[AsyncAlloyDBVectorStore],
        engine: AlloyDBEngine,
        table_name: str,
        schema_name: str = "public",
        id_column: str = "node_id",
        text_column: str = "text",
        embedding_column: str = "embedding",
        metadata_json_column: str = "li_metadata",
        metadata_columns: List[str] = [],
        ref_doc_id_column: str = "ref_doc_id",
        node_column: str = "node",
        stores_text: bool = True,
        is_embedding_query: bool = True,
    ) -> AsyncAlloyDBVectorStore:
        """Create an AsyncAlloyDBVectorStore instance and validates the table schema.

        Args:
            engine (AlloyDBEngine): Alloy DB Engine for managing connections to AlloyDB database.
            table_name (str): Name of the existing table or the table to be created.
            schema_name (str): Name of the database schema. Defaults to "public".
            id_column (str): Column that represents id of a Node. Defaults to "node_id".
            text_column (str): Column that represent text content of a Node. Defaults to "text".
            embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding".
            metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata".
            metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns.
            ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id".
            node_column (str): Column that represents the whole JSON node. Defaults to "node".
            stores_text (bool): Whether the table stores text. Defaults to True.
            is_embedding_query (bool): Whether the table query can have embeddings. Defaults to True.

        Raises:
            ValueError: If a required column is missing from the table or has
                an unexpected data type.

        Returns:
            AsyncAlloyDBVectorStore
        """
        # Bind the identifiers as parameters rather than interpolating them
        # into the SQL text, so names containing quotes can neither break nor
        # inject into the lookup query.
        stmt = (
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_name = :table_name AND table_schema = :schema_name"
        )
        async with engine._pool.connect() as conn:
            result = await conn.execute(
                text(stmt),
                {"table_name": table_name, "schema_name": schema_name},
            )
            results = result.mappings().fetchall()
        # Map each column name to the data type reported by the catalog.
        columns = {field["column_name"]: field["data_type"] for field in results}

        # Check columns
        if id_column not in columns:
            raise ValueError(f"Id column, {id_column}, does not exist.")
        if text_column not in columns:
            raise ValueError(f"Text column, {text_column}, does not exist.")
        text_type = columns[text_column]
        if text_type != "text" and "char" not in text_type:
            raise ValueError(
                f"Text column, {text_column}, is type, {text_type}. It must be a type of character string."
            )
        if embedding_column not in columns:
            raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
        # The embedding column is expected to be reported as "USER-DEFINED".
        if columns[embedding_column] != "USER-DEFINED":
            raise ValueError(
                f"Embedding column, {embedding_column}, is not type Vector."
            )
        if node_column not in columns:
            raise ValueError(f"Node column, {node_column}, does not exist.")
        if columns[node_column] != "json":
            raise ValueError(f"Node column, {node_column}, is not type JSON.")
        if ref_doc_id_column not in columns:
            raise ValueError(
                f"Reference Document Id column, {ref_doc_id_column}, does not exist."
            )
        if metadata_json_column not in columns:
            raise ValueError(
                f"Metadata column, {metadata_json_column}, does not exist."
            )
        if columns[metadata_json_column] != "jsonb":
            raise ValueError(
                f"Metadata column, {metadata_json_column}, is not type JSONB."
            )
        # If using metadata_columns check to make sure column exists
        for column in metadata_columns:
            if column not in columns:
                raise ValueError(f"Metadata column, {column}, does not exist.")

        return cls(
            cls.__create_key,
            engine._pool,
            table_name,
            schema_name=schema_name,
            id_column=id_column,
            text_column=text_column,
            embedding_column=embedding_column,
            metadata_json_column=metadata_json_column,
            metadata_columns=metadata_columns,
            ref_doc_id_column=ref_doc_id_column,
            node_column=node_column,
            stores_text=stores_text,
            is_embedding_query=is_embedding_query,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this vector store class."""
        return "AsyncAlloyDBVectorStore"

    @property
    def client(self) -> Any:
        """Get client."""
        return self._engine

    async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]:
        """Asynchronously add nodes to the table."""
        # TODO: complete implementation
        return []

    async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Asynchronously delete nodes belonging to provided parent document from the table."""
        # TODO: complete implementation
        return

    async def adelete_nodes(
        self,
        node_ids: Optional[List[str]] = None,
        filters: Optional[MetadataFilters] = None,
        **delete_kwargs: Any,
    ) -> None:
        """Asynchronously delete a set of nodes from the table matching the provided nodes and filters."""
        # TODO: complete implementation
        return

    async def aclear(self) -> None:
        """Asynchronously delete all nodes from the table."""
        # TODO: complete implementation
        return

    async def aget_nodes(
        self,
        node_ids: Optional[List[str]] = None,
        filters: Optional[MetadataFilters] = None,
    ) -> List[BaseNode]:
        """Asynchronously get nodes from the table matching the provided nodes and filters."""
        # TODO: complete implementation
        return []

    async def aquery(
        self, query: VectorStoreQuery, **kwargs: Any
    ) -> VectorStoreQueryResult:
        """Asynchronously query vector store."""
        # TODO: complete implementation
        return VectorStoreQueryResult()

    def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)

    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)

    def delete_nodes(
        self,
        node_ids: Optional[List[str]] = None,
        filters: Optional[MetadataFilters] = None,
        **delete_kwargs: Any,
    ) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)

    def clear(self) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)

    def get_nodes(
        self,
        node_ids: Optional[List[str]] = None,
        filters: Optional[MetadataFilters] = None,
    ) -> List[BaseNode]:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)

    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(_SYNC_METHODS_NOT_IMPLEMENTED)


# File: tests/test_async_vectorstore.py (new file)

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
from typing import List, Sequence

import pytest
import pytest_asyncio
from llama_index.core.schema import TextNode  # type: ignore
from llama_index.core.vector_stores.types import VectorStoreQuery
from sqlalchemy import text
from sqlalchemy.engine.row import RowMapping

from llama_index_alloydb_pg import AlloyDBEngine
from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore

# A unique table name per run keeps concurrent test runs from colliding.
DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
VECTOR_SIZE = 768

texts = ["foo", "bar", "baz"]
nodes = [TextNode(text=t) for t in texts]
sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."


def get_env_var(key: str, desc: str) -> str:
    """Return the value of environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        desc: Human-readable description used in the error message.

    Raises:
        ValueError: If the variable is not set.
    """
    v = os.environ.get(key)
    if v is None:
        raise ValueError(f"Must set env var {key} to: {desc}")
    return v


async def aexecute(engine: AlloyDBEngine, query: str) -> None:
    """Run a single SQL statement on the engine's pool and commit it."""
    async with engine._pool.connect() as conn:
        await conn.execute(text(query))
        await conn.commit()


@pytest.mark.asyncio(loop_scope="class")
class TestVectorStore:
    """Tests for AsyncAlloyDBVectorStore creation, validation and sync stubs."""

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        return get_env_var("REGION", "region for AlloyDB instance")

    @pytest.fixture(scope="module")
    def db_cluster(self) -> str:
        return get_env_var("CLUSTER_ID", "cluster for AlloyDB")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        return get_env_var("INSTANCE_ID", "instance for AlloyDB")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        return get_env_var("DATABASE_ID", "database name on AlloyDB instance")

    @pytest.fixture(scope="module")
    def db_user(self) -> str:
        return get_env_var("DB_USER", "database user on AlloyDB instance")

    @pytest.fixture(scope="module")
    def db_pwd(self) -> str:
        return get_env_var("DB_PASSWORD", "database password on AlloyDB instance")

    @pytest_asyncio.fixture(scope="class")
    async def engine(self, db_project, db_region, db_cluster, db_instance, db_name):
        engine = await AlloyDBEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            cluster=db_cluster,
            region=db_region,
            database=db_name,
        )

        yield engine
        # IF EXISTS: the table is only created by the `vs` fixture, so teardown
        # must not fail when that fixture was never requested or itself failed.
        await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
        await engine.close()

    @pytest_asyncio.fixture(scope="class")
    async def vs(self, engine):
        await engine._ainit_vector_store_table(
            DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True
        )
        vs = await AsyncAlloyDBVectorStore.create(engine, table_name=DEFAULT_TABLE)
        yield vs

    async def test_init_with_constructor(self, engine):
        # Supply every required argument so the failure comes from the
        # create-key guard, not from a TypeError for a missing `engine`.
        with pytest.raises(
            Exception, match="Only create class through 'create' method!"
        ):
            AsyncAlloyDBVectorStore(
                key={}, engine=engine._pool, table_name=DEFAULT_TABLE
            )

    async def test_validate_id_column_create(self, engine, vs):
        test_id_column = "test_id_column"
        with pytest.raises(
            Exception, match=f"Id column, {test_id_column}, does not exist."
        ):
            await AsyncAlloyDBVectorStore.create(
                engine, table_name=DEFAULT_TABLE, id_column=test_id_column
            )

    async def test_validate_text_column_create(self, engine, vs):
        test_text_column = "test_text_column"
        with pytest.raises(
            Exception, match=f"Text column, {test_text_column}, does not exist."
        ):
            await AsyncAlloyDBVectorStore.create(
                engine, table_name=DEFAULT_TABLE, text_column=test_text_column
            )

    async def test_validate_embedding_column_create(self, engine, vs):
        test_embed_column = "test_embed_column"
        with pytest.raises(
            Exception, match=f"Embedding column, {test_embed_column}, does not exist."
        ):
            await AsyncAlloyDBVectorStore.create(
                engine, table_name=DEFAULT_TABLE, embedding_column=test_embed_column
            )

    async def test_validate_node_column_create(self, engine, vs):
        test_node_column = "test_node_column"
        with pytest.raises(
            Exception, match=f"Node column, {test_node_column}, does not exist."
        ):
            await AsyncAlloyDBVectorStore.create(
                engine, table_name=DEFAULT_TABLE, node_column=test_node_column
            )

    async def test_validate_ref_doc_id_column_create(self, engine, vs):
        test_ref_doc_id_column = "test_ref_doc_id_column"
        with pytest.raises(
            Exception,
            match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.",
        ):
            await AsyncAlloyDBVectorStore.create(
                engine,
                table_name=DEFAULT_TABLE,
                ref_doc_id_column=test_ref_doc_id_column,
            )

    async def test_validate_metadata_json_column_create(self, engine, vs):
        test_metadata_json_column = "test_metadata_json_column"
        with pytest.raises(
            Exception,
            match=f"Metadata column, {test_metadata_json_column}, does not exist.",
        ):
            await AsyncAlloyDBVectorStore.create(
                engine,
                table_name=DEFAULT_TABLE,
                metadata_json_column=test_metadata_json_column,
            )

    # The sync API is intentionally unimplemented on the async store; each
    # test below checks that the corresponding method raises with the shared
    # message.
    async def test_add(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.add(nodes)

    async def test_get_nodes(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.get_nodes()

    async def test_query(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.query(VectorStoreQuery(query_str="foo"))

    async def test_delete(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.delete("test_ref_doc_id")

    async def test_delete_nodes(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.delete_nodes(["test_node_id"])

    async def test_clear(self, vs):
        with pytest.raises(Exception, match=sync_method_exception_str):
            vs.clear()