From b8981bcf5565346ef88891474006336e3734ff0d Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Wed, 20 Nov 2024 16:35:56 +0000 Subject: [PATCH 01/10] feat: vector store class structure --- .../async_vectorstore.py | 298 ++++++++++++++++++ tests/test_async_vectorstore.py | 137 ++++++++ 2 files changed, 435 insertions(+) create mode 100644 src/llama_index_alloydb_pg/async_vectorstore.py create mode 100644 tests/test_async_vectorstore.py diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py new file mode 100644 index 0000000..2cc1080 --- /dev/null +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -0,0 +1,298 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO: Remove below import when minimum supported Python version is 3.10 +from __future__ import annotations + +import base64 +import json +import re +import uuid +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type + +import numpy as np +import requests +from google.cloud import storage # type: ignore +from llama_index.core.schema import BaseNode, MetadataMode, TextNode +from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, + FilterOperator, + MetadataFilters, + MetadataFilter, + VectorStoreQuery, + VectorStoreQueryMode, + VectorStoreQueryResult, +) +from llama_index.core.vector_stores.utils import ( + metadata_dict_to_node, + node_to_metadata_dict, +) +from sqlalchemy import RowMapping, text +from sqlalchemy.ext.asyncio import AsyncEngine +from .engine import AlloyDBEngine + + +class AsyncAlloyDBVectorStore(BasePydanticVectorStore): + """Google AlloyDB Vector Store class""" + + stores_text: bool = True + is_embedding_query: bool = True + + engine: AsyncEngine + table_name: str + schema_name: str + id_column: str + content_column: str + embedding_column: str + metadata_json_column: str + custom_metadata_columns: List[str] + ref_doc_id_column: str + node_column: str + __create_key = object() + + def __init__( + self, + key: object, + engine: AsyncEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + content_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: Optional[str] = "li_metadata", + custom_metadata_columns: List[str] = [], + ref_doc_id_column: Optional[str] = "ref_doc_id", + node_column: Optional[str] = "node", + ): + """AsyncAlloyDBVectorStore constructor. + Args: + key (object): Prevent direct constructor usage. + engine (AsyncEngine): Connection pool engine for managing connections to AlloyDB database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + content_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. 
The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + custom_metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node". + + + Raises: + Exception: If called directly by user. + """ + if key != AsyncAlloyDBVectorStore.__create_key: + raise Exception( + "Only create class through 'create' or 'create_sync' methods!" + ) + + super().__init__( + engine=engine, + table_name=table_name, + schema_name=schema_name, + id_column=id_column, + content_column=content_column, + embedding_column=embedding_column, + metadata_json_column=metadata_json_column, + custom_metadata_columns=custom_metadata_columns, + ref_doc_id_column=ref_doc_id_column, + node_column=node_column, + ) + + @classmethod + async def create( + cls: Type[AsyncAlloyDBVectorStore], + engine: AlloyDBEngine, + table_name: str, + schema_name: str = "public", + id_column: str = "node_id", + content_column: str = "text", + embedding_column: str = "embedding", + metadata_json_column: Optional[str] = "li_metadata", + custom_metadata_columns: List[str] = [], + ref_doc_id_column: Optional[str] = "ref_doc_id", + node_column: Optional[str] = "node", + perform_validation: bool = True, # TODO: For testing only, remove after engine::init implementation + ) -> AsyncAlloyDBVectorStore: + """Create an AsyncAlloyDBVectorStore instance and validates the table schema. + + Args: + engine (AlloyDBEngine): Alloy DB Engine for managing connections to AlloyDB database. + table_name (str): Name of the existing table or the table to be created. + schema_name (str, optional): Name of the database schema. Defaults to "public". + id_column (str): Column that represents if of a Node. Defaults to "node_id". + content_column (str): Column that represent text content of a Node. Defaults to "text". + embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". + metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". + custom_metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". + node_column (str): Column that represents the whole JSON node. Defaults to "node". + + Raises: + Exception: If table does not exist or follow the provided structure. 
+ + Returns: + AsyncAlloyDBVectorStore + """ + # TODO: Only for testing, remove flag to always do validation after engine::init is implemented + if perform_validation: + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" + async with engine._pool.connect() as conn: + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + columns = {} + for field in results: + columns[field["column_name"]] = field["data_type"] + + # Check columns + if id_column not in columns: + raise ValueError(f"Id column, {id_column}, does not exist.") + if content_column not in columns: + raise ValueError(f"Content column, {content_column}, does not exist.") + content_type = columns[content_column] + if content_type != "text" and "char" not in content_type: + raise ValueError( + f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." + ) + if embedding_column not in columns: + raise ValueError( + f"Embedding column, {embedding_column}, does not exist." + ) + if columns[embedding_column] != "USER-DEFINED": + raise ValueError( + f"Embedding column, {embedding_column}, is not type Vector." + ) + if columns[node_column] != "json": + raise ValueError(f"Node column, {node_column}, is not type JSON.") + if ref_doc_id_column not in columns: + raise ValueError( + f"Reference Document Id column, {ref_doc_id_column}, does not exist." + ) + if columns[metadata_json_column] != "jsonb": + raise ValueError( + f"Metadata column, {metadata_json_column}, does not exist." + ) + # If using metadata_columns check to make sure column exists + for column in custom_metadata_columns: + if column not in columns: + raise ValueError(f"Metadata column, {column}, does not exist.") + + return cls( + cls.__create_key, + engine._pool, + table_name, + schema_name=schema_name, + id_column=id_column, + content_column=content_column, + embedding_column=embedding_column, + metadata_json_column=metadata_json_column, + custom_metadata_columns=custom_metadata_columns, + ref_doc_id_column=ref_doc_id_column, + node_column=node_column, + ) + + @classmethod + def class_name(cls) -> str: + return "AsyncAlloyDBVectorStore" + + @property + def client(self) -> Any: + """Get client.""" + return self._engine + + async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: + """Asynchronously add nodes to the table.""" + pass + + async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + """Asynchronously delete nodes belonging to provided parent document from the table.""" + pass + + async def adelete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + """Asynchronously delete a set of nodes from the table matching the provided nodes and filters. + + Raises: + Exception: If called without any node ids or filters. + """ + pass + + async def aclear(self) -> None: + """Asynchronously delete all nodes from the table.""" + pass + + async def aget_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + """Asynchronously get nodes from the table matching the provided nodes and filters. + + Raises: + Exception: If called without any node ids or filters. 
+ """ + if node_ids is None and filters is None: + raise ValueError(f"Either node_ids or filters must be provided.") + pass + + async def aquery( + self, query: VectorStoreQuery, **kwargs: Any + ) -> VectorStoreQueryResult: + """Asynchronously query vector store.""" + pass + + def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) + + def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) + + def delete_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + **delete_kwargs: Any, + ) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) + + def clear(self) -> None: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) + + def get_nodes( + self, + node_ids: Optional[List[str]] = None, + filters: Optional[MetadataFilters] = None, + ) -> List[BaseNode]: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) + + def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: + raise NotImplementedError( + "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." + ) diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py new file mode 100644 index 0000000..ca8bfbb --- /dev/null +++ b/tests/test_async_vectorstore.py @@ -0,0 +1,137 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +from typing import List, Sequence + +import pytest +import pytest_asyncio + +from sqlalchemy import text +from sqlalchemy.engine.row import RowMapping + +from llama_index_alloydb_pg import AlloyDBEngine +from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore +from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQuery + +DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) + +texts = ["foo", "bar", "baz"] +metadata = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] + +nodes = [TextNode(text=texts[i], metadata=metadata[i]) for i in range(len(texts))] +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." 
+ + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +@pytest.mark.asyncio(loop_scope="class") +class TestVectorStore: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for AlloyDB instance") + + @pytest.fixture(scope="module") + def db_cluster(self) -> str: + return get_env_var("CLUSTER_ID", "cluster for AlloyDB") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for AlloyDB") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DATABASE_ID", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def db_user(self) -> str: + return get_env_var("DB_USER", "database name on AlloyDB instance") + + @pytest.fixture(scope="module") + def db_pwd(self) -> str: + return get_env_var("DB_PASSWORD", "database name on AlloyDB instance") + + @pytest_asyncio.fixture(scope="class") + async def engine( + self, db_project, db_region, db_cluster, db_instance, db_name, db_user, db_pwd + ): + engine = await AlloyDBEngine.afrom_instance( + project_id=db_project, + instance=db_instance, + cluster=db_cluster, + region=db_region, + database=db_name, + user=db_user, + password=db_pwd, + ) + + yield engine + await engine.close() + + @pytest_asyncio.fixture(scope="class") + async def vs(self, engine): + vs = await AsyncAlloyDBVectorStore.create( + engine, table_name=DEFAULT_TABLE, perform_validation=False + ) + yield vs + + async def test_init_with_constructor(self, engine): + with pytest.raises(Exception): + AsyncAlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) + + async def test_validate_columns_create(self, engine): + # TODO: add tests for more columns after engine::init is implemented + # currently, since there's no table first validation condition fails. + test_id_column = "test_id_column" + with pytest.raises( + Exception, match=f"Id column, {test_id_column}, does not exist." 
+ ): + await AsyncAlloyDBVectorStore.create( + engine, table_name="non_existing_table", id_column=test_id_column + ) + + async def test_add(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.add(nodes) + + async def test_get_nodes(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.get_nodes() + + async def test_query(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.query(VectorStoreQuery(query_str="foo")) + + async def test_delete(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.delete("test_ref_doc_id") + + async def test_delete_nodes(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.delete_nodes(["test_node_id"]) + + async def test_clear(self, vs): + with pytest.raises(Exception, match=sync_method_exception_str): + vs.clear() From 1d0a4ffdc6cc1bafaede22abb5c8b355e8f801bc Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Thu, 21 Nov 2024 14:53:07 +0000 Subject: [PATCH 02/10] chore: pr comments --- .../async_vectorstore.py | 52 ++++++++----------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 2cc1080..da9a4ea 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -53,10 +53,10 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore): table_name: str schema_name: str id_column: str - content_column: str + text_column: str embedding_column: str metadata_json_column: str - custom_metadata_columns: List[str] + metadata_columns: List[str] ref_doc_id_column: str node_column: str __create_key = object() @@ -68,10 +68,10 @@ def __init__( table_name: str, schema_name: str = "public", id_column: str = "node_id", - content_column: str = "text", + text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: Optional[str] = "li_metadata", - custom_metadata_columns: List[str] = [], + metadata_columns: List[str] = [], ref_doc_id_column: Optional[str] = "ref_doc_id", node_column: Optional[str] = "node", ): @@ -82,10 +82,10 @@ def __init__( table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Name of the database schema. Defaults to "public". id_column (str): Column that represents if of a Node. Defaults to "node_id". - content_column (str): Column that represent text content of a Node. Defaults to "text". + text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - custom_metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node". 
@@ -103,10 +103,10 @@ def __init__( table_name=table_name, schema_name=schema_name, id_column=id_column, - content_column=content_column, + text_column=text_column, embedding_column=embedding_column, metadata_json_column=metadata_json_column, - custom_metadata_columns=custom_metadata_columns, + metadata_columns=metadata_columns, ref_doc_id_column=ref_doc_id_column, node_column=node_column, ) @@ -118,10 +118,10 @@ async def create( table_name: str, schema_name: str = "public", id_column: str = "node_id", - content_column: str = "text", + text_column: str = "text", embedding_column: str = "embedding", metadata_json_column: Optional[str] = "li_metadata", - custom_metadata_columns: List[str] = [], + metadata_columns: List[str] = [], ref_doc_id_column: Optional[str] = "ref_doc_id", node_column: Optional[str] = "node", perform_validation: bool = True, # TODO: For testing only, remove after engine::init implementation @@ -133,10 +133,10 @@ async def create( table_name (str): Name of the existing table or the table to be created. schema_name (str, optional): Name of the database schema. Defaults to "public". id_column (str): Column that represents if of a Node. Defaults to "node_id". - content_column (str): Column that represent text content of a Node. Defaults to "text". + text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata". - custom_metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. + metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node". @@ -160,12 +160,12 @@ async def create( # Check columns if id_column not in columns: raise ValueError(f"Id column, {id_column}, does not exist.") - if content_column not in columns: - raise ValueError(f"Content column, {content_column}, does not exist.") - content_type = columns[content_column] + if text_column not in columns: + raise ValueError(f"Content column, {text_column}, does not exist.") + content_type = columns[text_column] if content_type != "text" and "char" not in content_type: raise ValueError( - f"Content column, {content_column}, is type, {content_type}. It must be a type of character string." + f"Content column, {text_column}, is type, {content_type}. It must be a type of character string." ) if embedding_column not in columns: raise ValueError( @@ -186,7 +186,7 @@ async def create( f"Metadata column, {metadata_json_column}, does not exist." 
) # If using metadata_columns check to make sure column exists - for column in custom_metadata_columns: + for column in metadata_columns: if column not in columns: raise ValueError(f"Metadata column, {column}, does not exist.") @@ -196,10 +196,10 @@ async def create( table_name, schema_name=schema_name, id_column=id_column, - content_column=content_column, + text_column=text_column, embedding_column=embedding_column, metadata_json_column=metadata_json_column, - custom_metadata_columns=custom_metadata_columns, + metadata_columns=metadata_columns, ref_doc_id_column=ref_doc_id_column, node_column=node_column, ) @@ -227,11 +227,7 @@ async def adelete_nodes( filters: Optional[MetadataFilters] = None, **delete_kwargs: Any, ) -> None: - """Asynchronously delete a set of nodes from the table matching the provided nodes and filters. - - Raises: - Exception: If called without any node ids or filters. - """ + """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" pass async def aclear(self) -> None: @@ -243,13 +239,7 @@ async def aget_nodes( node_ids: Optional[List[str]] = None, filters: Optional[MetadataFilters] = None, ) -> List[BaseNode]: - """Asynchronously get nodes from the table matching the provided nodes and filters. - - Raises: - Exception: If called without any node ids or filters. - """ - if node_ids is None and filters is None: - raise ValueError(f"Either node_ids or filters must be provided.") + """Asynchronously get nodes from the table matching the provided nodes and filters.""" pass async def aquery( From 4b8a4ab021449db638de2fdeec80039caaed3693 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Thu, 21 Nov 2024 14:58:08 +0000 Subject: [PATCH 03/10] chore: fix lint --- src/llama_index_alloydb_pg/async_vectorstore.py | 4 ++-- tests/test_async_vectorstore.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index da9a4ea..3f82dfa 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -22,14 +22,13 @@ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type import numpy as np -import requests from google.cloud import storage # type: ignore from llama_index.core.schema import BaseNode, MetadataMode, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, FilterOperator, - MetadataFilters, MetadataFilter, + MetadataFilters, VectorStoreQuery, VectorStoreQueryMode, VectorStoreQueryResult, @@ -40,6 +39,7 @@ ) from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncEngine + from .engine import AlloyDBEngine diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index ca8bfbb..7776865 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -18,14 +18,13 @@ import pytest import pytest_asyncio - +from llama_index.core.schema import TextNode +from llama_index.core.vector_stores.types import VectorStoreQuery from sqlalchemy import text from sqlalchemy.engine.row import RowMapping from llama_index_alloydb_pg import AlloyDBEngine from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore -from llama_index.core.schema import TextNode -from llama_index.core.vector_stores.types import VectorStoreQuery DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) From 89a370db9ac3aad82fc8e0b34c876ac55a41bcc3 Mon Sep 17 00:00:00 2001 From: 
Vishwaraj Anand Date: Fri, 22 Nov 2024 18:38:11 +0000 Subject: [PATCH 04/10] chore: fix lint --- .../async_vectorstore.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 3f82dfa..37ee0d7 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -33,10 +33,6 @@ VectorStoreQueryMode, VectorStoreQueryResult, ) -from llama_index.core.vector_stores.utils import ( - metadata_dict_to_node, - node_to_metadata_dict, -) from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncEngine @@ -215,11 +211,13 @@ def client(self) -> Any: async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]: """Asynchronously add nodes to the table.""" - pass + # TODO: complete implementation + return [] async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: """Asynchronously delete nodes belonging to provided parent document from the table.""" - pass + # TODO: complete implementation + return async def adelete_nodes( self, @@ -228,11 +226,13 @@ async def adelete_nodes( **delete_kwargs: Any, ) -> None: """Asynchronously delete a set of nodes from the table matching the provided nodes and filters.""" - pass + # TODO: complete implementation + return async def aclear(self) -> None: """Asynchronously delete all nodes from the table.""" - pass + # TODO: complete implementation + return async def aget_nodes( self, @@ -240,13 +240,14 @@ async def aget_nodes( filters: Optional[MetadataFilters] = None, ) -> List[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" - pass + return [] async def aquery( self, query: VectorStoreQuery, **kwargs: Any ) -> VectorStoreQueryResult: """Asynchronously query vector store.""" - pass + # TODO: complete implementation + return VectorStoreQueryResult() def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]: raise NotImplementedError( From 16d5e7c47fb1dedeff1065e34e272185456573a2 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 22 Nov 2024 19:57:17 +0000 Subject: [PATCH 05/10] fix: ci --- src/llama_index_alloydb_pg/async_vectorstore.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 37ee0d7..77dcbd1 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -23,8 +23,8 @@ import numpy as np from google.cloud import storage # type: ignore -from llama_index.core.schema import BaseNode, MetadataMode, TextNode -from llama_index.core.vector_stores.types import ( +from llama_index.core.schema import BaseNode, MetadataMode, TextNode # type: ignore +from llama_index.core.vector_stores.types import ( # type: ignore BasePydanticVectorStore, FilterOperator, MetadataFilter, @@ -45,16 +45,6 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore): stores_text: bool = True is_embedding_query: bool = True - engine: AsyncEngine - table_name: str - schema_name: str - id_column: str - text_column: str - embedding_column: str - metadata_json_column: str - metadata_columns: List[str] - ref_doc_id_column: str - node_column: str __create_key = object() def __init__( @@ -94,6 +84,7 @@ def __init__( "Only create class through 'create' or 'create_sync' methods!" 
) + # Delegate to Pydantic's __init__ super().__init__( engine=engine, table_name=table_name, From 6e29575be624c4c6ffeddf4c8c85614c5dc1a120 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 22 Nov 2024 20:06:42 +0000 Subject: [PATCH 06/10] fix: added type:ignore --- src/llama_index_alloydb_pg/engine.py | 6 +++++- tests/test_async_vectorstore.py | 14 +++++++------- tests/test_engine.py | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/llama_index_alloydb_pg/engine.py b/src/llama_index_alloydb_pg/engine.py index 218956e..3237d26 100644 --- a/src/llama_index_alloydb_pg/engine.py +++ b/src/llama_index_alloydb_pg/engine.py @@ -32,7 +32,11 @@ import aiohttp import google.auth # type: ignore import google.auth.transport.requests # type: ignore -from google.cloud.alloydb.connector import AsyncConnector, IPTypes, RefreshStrategy +from google.cloud.alloydb.connector import ( # type: ignore + AsyncConnector, + IPTypes, + RefreshStrategy, +) from sqlalchemy import MetaData, RowMapping, Table, text from sqlalchemy.engine import URL from sqlalchemy.exc import InvalidRequestError diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 7776865..5350239 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -18,20 +18,20 @@ import pytest import pytest_asyncio -from llama_index.core.schema import TextNode -from llama_index.core.vector_stores.types import VectorStoreQuery +from llama_index.core.schema import TextNode # type: ignore +from llama_index.core.vector_stores.types import VectorStoreQuery # type: ignore from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from llama_index_alloydb_pg import AlloyDBEngine -from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore +from llama_index_alloydb_pg import AlloyDBEngine # type: ignore +from llama_index_alloydb_pg.async_vectorstore import ( + AsyncAlloyDBVectorStore, # type: ignore +) DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) texts = ["foo", "bar", "baz"] -metadata = [{"page": str(i), "source": "google.com"} for i in range(len(texts))] - -nodes = [TextNode(text=texts[i], metadata=metadata[i]) for i in range(len(texts))] +nodes = [TextNode(text=texts[i]) for i in range(len(texts))] sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." 
diff --git a/tests/test_engine.py b/tests/test_engine.py index b79f94a..6b04545 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -19,14 +19,14 @@ import asyncpg # type: ignore import pytest import pytest_asyncio -from google.cloud.alloydb.connector import AsyncConnector, IPTypes +from google.cloud.alloydb.connector import AsyncConnector, IPTypes # type: ignore from sqlalchemy import VARCHAR, text from sqlalchemy.engine import URL from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool -from llama_index_alloydb_pg import AlloyDBEngine, Column +from llama_index_alloydb_pg import AlloyDBEngine, Column # type: ignore def get_env_var(key: str, desc: str) -> str: From f0fbd039aedcdc2d48519cb7b82bd83bdc15e587 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Fri, 22 Nov 2024 20:44:55 +0000 Subject: [PATCH 07/10] fix: pydantic failure --- src/llama_index_alloydb_pg/async_vectorstore.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 77dcbd1..8789ddc 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -47,6 +47,17 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore): __create_key = object() + _engine: AsyncEngine + table_name: str + schema_name: str + id_column: str + text_column: str + embedding_column: str + metadata_json_column: str + metadata_columns: List[str] + ref_doc_id_column: str + node_column: str + def __init__( self, key: object, @@ -86,7 +97,7 @@ def __init__( # Delegate to Pydantic's __init__ super().__init__( - engine=engine, + _engine=engine, table_name=table_name, schema_name=schema_name, id_column=id_column, From b4d1a93f712872ccdf04d0bd8ad76126ac07c8c2 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Sun, 24 Nov 2024 22:13:45 +0000 Subject: [PATCH 08/10] fix: pydantic fix --- .../async_vectorstore.py | 68 +++++++++++-------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 8789ddc..691eedc 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -22,7 +22,6 @@ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type import numpy as np -from google.cloud import storage # type: ignore from llama_index.core.schema import BaseNode, MetadataMode, TextNode # type: ignore from llama_index.core.vector_stores.types import ( # type: ignore BasePydanticVectorStore, @@ -48,15 +47,15 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore): __create_key = object() _engine: AsyncEngine - table_name: str - schema_name: str - id_column: str - text_column: str - embedding_column: str - metadata_json_column: str - metadata_columns: List[str] - ref_doc_id_column: str - node_column: str + _table_name: str + _schema_name: str + _id_column: str + _text_column: str + _embedding_column: str + _metadata_json_column: str + _metadata_columns: List[str] + _ref_doc_id_column: str + _node_column: str def __init__( self, @@ -67,17 +66,19 @@ def __init__( id_column: str = "node_id", text_column: str = "text", embedding_column: str = "embedding", - metadata_json_column: Optional[str] = "li_metadata", + metadata_json_column: str = "li_metadata", metadata_columns: List[str] = [], - ref_doc_id_column: 
Optional[str] = "ref_doc_id", - node_column: Optional[str] = "node", + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node", + stores_text: bool = True, + is_embedding_query: bool = True, ): """AsyncAlloyDBVectorStore constructor. Args: key (object): Prevent direct constructor usage. engine (AsyncEngine): Connection pool engine for managing connections to AlloyDB database. table_name (str): Name of the existing table or the table to be created. - schema_name (str, optional): Name of the database schema. Defaults to "public". + schema_name (str): Name of the database schema. Defaults to "public". id_column (str): Column that represents if of a Node. Defaults to "node_id". text_column (str): Column that represent text content of a Node. Defaults to "text". embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". @@ -85,6 +86,8 @@ def __init__( metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". Raises: @@ -96,18 +99,17 @@ def __init__( ) # Delegate to Pydantic's __init__ - super().__init__( - _engine=engine, - table_name=table_name, - schema_name=schema_name, - id_column=id_column, - text_column=text_column, - embedding_column=embedding_column, - metadata_json_column=metadata_json_column, - metadata_columns=metadata_columns, - ref_doc_id_column=ref_doc_id_column, - node_column=node_column, - ) + super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query) + self._engine = engine + self._table_name = table_name + self._schema_name = schema_name + self._id_column = id_column + self._text_column = text_column + self._embedding_column = embedding_column + self._metadata_json_column = metadata_json_column + self._metadata_columns = metadata_columns + self._ref_doc_id_column = ref_doc_id_column + self._node_column = node_column @classmethod async def create( @@ -118,10 +120,12 @@ async def create( id_column: str = "node_id", text_column: str = "text", embedding_column: str = "embedding", - metadata_json_column: Optional[str] = "li_metadata", + metadata_json_column: str = "li_metadata", metadata_columns: List[str] = [], - ref_doc_id_column: Optional[str] = "ref_doc_id", - node_column: Optional[str] = "node", + ref_doc_id_column: str = "ref_doc_id", + node_column: str = "node", + stores_text: bool = True, + is_embedding_query: bool = True, perform_validation: bool = True, # TODO: For testing only, remove after engine::init implementation ) -> AsyncAlloyDBVectorStore: """Create an AsyncAlloyDBVectorStore instance and validates the table schema. @@ -129,7 +133,7 @@ async def create( Args: engine (AlloyDBEngine): Alloy DB Engine for managing connections to AlloyDB database. table_name (str): Name of the existing table or the table to be created. - schema_name (str, optional): Name of the database schema. Defaults to "public". + schema_name (str): Name of the database schema. Defaults to "public". id_column (str): Column that represents if of a Node. Defaults to "node_id". text_column (str): Column that represent text content of a Node. Defaults to "text". 
embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding". @@ -137,6 +141,8 @@ async def create( metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns. ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id". node_column (str): Column that represents the whole JSON node. Defaults to "node". + stores_text (bool): Whether the table stores text. Defaults to "True". + is_embedding_query (bool): Whether the table query can have embeddings. Defaults to "True". Raises: Exception: If table does not exist or follow the provided structure. @@ -200,6 +206,8 @@ async def create( metadata_columns=metadata_columns, ref_doc_id_column=ref_doc_id_column, node_column=node_column, + stores_text=stores_text, + is_embedding_query=is_embedding_query, ) @classmethod From 1b4d346e20f1520240a7e0d6c148f834f94002d0 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Sun, 24 Nov 2024 22:19:44 +0000 Subject: [PATCH 09/10] chore: remove unnecessary type: ignore --- src/llama_index_alloydb_pg/async_vectorstore.py | 4 ++-- tests/test_async_vectorstore.py | 8 +++----- tests/test_engine.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 691eedc..720d661 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -22,8 +22,8 @@ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type import numpy as np -from llama_index.core.schema import BaseNode, MetadataMode, TextNode # type: ignore -from llama_index.core.vector_stores.types import ( # type: ignore +from llama_index.core.schema import BaseNode, MetadataMode, TextNode +from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, FilterOperator, MetadataFilter, diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 5350239..572b5df 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -19,14 +19,12 @@ import pytest import pytest_asyncio from llama_index.core.schema import TextNode # type: ignore -from llama_index.core.vector_stores.types import VectorStoreQuery # type: ignore +from llama_index.core.vector_stores.types import VectorStoreQuery from sqlalchemy import text from sqlalchemy.engine.row import RowMapping -from llama_index_alloydb_pg import AlloyDBEngine # type: ignore -from llama_index_alloydb_pg.async_vectorstore import ( - AsyncAlloyDBVectorStore, # type: ignore -) +from llama_index_alloydb_pg import AlloyDBEngine +from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) diff --git a/tests/test_engine.py b/tests/test_engine.py index 06a4f8b..ca03f0f 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -19,14 +19,14 @@ import asyncpg # type: ignore import pytest import pytest_asyncio -from google.cloud.alloydb.connector import AsyncConnector, IPTypes # type: ignore +from google.cloud.alloydb.connector import AsyncConnector, IPTypes from sqlalchemy import VARCHAR, text from sqlalchemy.engine import URL from sqlalchemy.engine.row import RowMapping from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.pool import NullPool -from llama_index_alloydb_pg import AlloyDBEngine, Column # type: ignore +from llama_index_alloydb_pg 
import AlloyDBEngine, Column DEFAULT_DS_TABLE = "document_store_" + str(uuid.uuid4()) DEFAULT_DS_TABLE_SYNC = "document_store_" + str(uuid.uuid4()) From acecbd1b67793fda92a08e9dfa3fe4790ca520b2 Mon Sep 17 00:00:00 2001 From: Vishwaraj Anand Date: Tue, 26 Nov 2024 10:46:08 +0000 Subject: [PATCH 10/10] fix: remove test flag perform_validation --- .../async_vectorstore.py | 105 ++++++++---------- tests/test_async_vectorstore.py | 76 +++++++++++-- 2 files changed, 112 insertions(+), 69 deletions(-) diff --git a/src/llama_index_alloydb_pg/async_vectorstore.py b/src/llama_index_alloydb_pg/async_vectorstore.py index 720d661..01373c5 100644 --- a/src/llama_index_alloydb_pg/async_vectorstore.py +++ b/src/llama_index_alloydb_pg/async_vectorstore.py @@ -46,17 +46,6 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore): __create_key = object() - _engine: AsyncEngine - _table_name: str - _schema_name: str - _id_column: str - _text_column: str - _embedding_column: str - _metadata_json_column: str - _metadata_columns: List[str] - _ref_doc_id_column: str - _node_column: str - def __init__( self, key: object, @@ -94,9 +83,7 @@ def __init__( Exception: If called directly by user. """ if key != AsyncAlloyDBVectorStore.__create_key: - raise Exception( - "Only create class through 'create' or 'create_sync' methods!" - ) + raise Exception("Only create class through 'create' method!") # Delegate to Pydantic's __init__ super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query) @@ -126,7 +113,6 @@ async def create( node_column: str = "node", stores_text: bool = True, is_embedding_query: bool = True, - perform_validation: bool = True, # TODO: For testing only, remove after engine::init implementation ) -> AsyncAlloyDBVectorStore: """Create an AsyncAlloyDBVectorStore instance and validates the table schema. @@ -150,49 +136,51 @@ async def create( Returns: AsyncAlloyDBVectorStore """ - # TODO: Only for testing, remove flag to always do validation after engine::init is implemented - if perform_validation: - stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" - async with engine._pool.connect() as conn: - result = await conn.execute(text(stmt)) - result_map = result.mappings() - results = result_map.fetchall() - columns = {} - for field in results: - columns[field["column_name"]] = field["data_type"] - - # Check columns - if id_column not in columns: - raise ValueError(f"Id column, {id_column}, does not exist.") - if text_column not in columns: - raise ValueError(f"Content column, {text_column}, does not exist.") - content_type = columns[text_column] - if content_type != "text" and "char" not in content_type: - raise ValueError( - f"Content column, {text_column}, is type, {content_type}. It must be a type of character string." - ) - if embedding_column not in columns: - raise ValueError( - f"Embedding column, {embedding_column}, does not exist." - ) - if columns[embedding_column] != "USER-DEFINED": - raise ValueError( - f"Embedding column, {embedding_column}, is not type Vector." - ) - if columns[node_column] != "json": - raise ValueError(f"Node column, {node_column}, is not type JSON.") - if ref_doc_id_column not in columns: - raise ValueError( - f"Reference Document Id column, {ref_doc_id_column}, does not exist." - ) - if columns[metadata_json_column] != "jsonb": - raise ValueError( - f"Metadata column, {metadata_json_column}, does not exist." 
- ) - # If using metadata_columns check to make sure column exists - for column in metadata_columns: - if column not in columns: - raise ValueError(f"Metadata column, {column}, does not exist.") + stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'" + async with engine._pool.connect() as conn: + result = await conn.execute(text(stmt)) + result_map = result.mappings() + results = result_map.fetchall() + columns = {} + for field in results: + columns[field["column_name"]] = field["data_type"] + + # Check columns + if id_column not in columns: + raise ValueError(f"Id column, {id_column}, does not exist.") + if text_column not in columns: + raise ValueError(f"Text column, {text_column}, does not exist.") + text_type = columns[text_column] + if text_type != "text" and "char" not in text_type: + raise ValueError( + f"Text column, {text_column}, is type, {text_type}. It must be a type of character string." + ) + if embedding_column not in columns: + raise ValueError(f"Embedding column, {embedding_column}, does not exist.") + if columns[embedding_column] != "USER-DEFINED": + raise ValueError( + f"Embedding column, {embedding_column}, is not type Vector." + ) + if node_column not in columns: + raise ValueError(f"Node column, {node_column}, does not exist.") + if columns[node_column] != "json": + raise ValueError(f"Node column, {node_column}, is not type JSON.") + if ref_doc_id_column not in columns: + raise ValueError( + f"Reference Document Id column, {ref_doc_id_column}, does not exist." + ) + if metadata_json_column not in columns: + raise ValueError( + f"Metadata column, {metadata_json_column}, does not exist." + ) + if columns[metadata_json_column] != "jsonb": + raise ValueError( + f"Metadata column, {metadata_json_column}, is not type JSONB." 
+ ) + # If using metadata_columns check to make sure column exists + for column in metadata_columns: + if column not in columns: + raise ValueError(f"Metadata column, {column}, does not exist.") return cls( cls.__create_key, @@ -250,6 +238,7 @@ async def aget_nodes( filters: Optional[MetadataFilters] = None, ) -> List[BaseNode]: """Asynchronously get nodes from the table matching the provided nodes and filters.""" + # TODO: complete implementation return [] async def aquery( diff --git a/tests/test_async_vectorstore.py b/tests/test_async_vectorstore.py index 572b5df..8ff95ae 100644 --- a/tests/test_async_vectorstore.py +++ b/tests/test_async_vectorstore.py @@ -27,6 +27,7 @@ from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore DEFAULT_TABLE = "test_table" + str(uuid.uuid4()) +VECTOR_SIZE = 768 texts = ["foo", "bar", "baz"] nodes = [TextNode(text=texts[i]) for i in range(len(texts))] @@ -40,6 +41,12 @@ def get_env_var(key: str, desc: str) -> str: return v +async def aexecute(engine: AlloyDBEngine, query: str) -> None: + async with engine._pool.connect() as conn: + await conn.execute(text(query)) + await conn.commit() + + @pytest.mark.asyncio(loop_scope="class") class TestVectorStore: @pytest.fixture(scope="module") @@ -71,42 +78,89 @@ def db_pwd(self) -> str: return get_env_var("DB_PASSWORD", "database name on AlloyDB instance") @pytest_asyncio.fixture(scope="class") - async def engine( - self, db_project, db_region, db_cluster, db_instance, db_name, db_user, db_pwd - ): + async def engine(self, db_project, db_region, db_cluster, db_instance, db_name): engine = await AlloyDBEngine.afrom_instance( project_id=db_project, instance=db_instance, cluster=db_cluster, region=db_region, database=db_name, - user=db_user, - password=db_pwd, ) yield engine + await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"') await engine.close() @pytest_asyncio.fixture(scope="class") async def vs(self, engine): - vs = await AsyncAlloyDBVectorStore.create( - engine, table_name=DEFAULT_TABLE, perform_validation=False + await engine._ainit_vector_store_table( + DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True ) + vs = await AsyncAlloyDBVectorStore.create(engine, table_name=DEFAULT_TABLE) yield vs async def test_init_with_constructor(self, engine): with pytest.raises(Exception): AsyncAlloyDBVectorStore(engine, table_name=DEFAULT_TABLE) - async def test_validate_columns_create(self, engine): - # TODO: add tests for more columns after engine::init is implemented - # currently, since there's no table first validation condition fails. + async def test_validate_id_column_create(self, engine, vs): test_id_column = "test_id_column" with pytest.raises( Exception, match=f"Id column, {test_id_column}, does not exist." ): await AsyncAlloyDBVectorStore.create( - engine, table_name="non_existing_table", id_column=test_id_column + engine, table_name=DEFAULT_TABLE, id_column=test_id_column + ) + + async def test_validate_text_column_create(self, engine, vs): + test_text_column = "test_text_column" + with pytest.raises( + Exception, match=f"Text column, {test_text_column}, does not exist." + ): + await AsyncAlloyDBVectorStore.create( + engine, table_name=DEFAULT_TABLE, text_column=test_text_column + ) + + async def test_validate_embedding_column_create(self, engine, vs): + test_embed_column = "test_embed_column" + with pytest.raises( + Exception, match=f"Embedding column, {test_embed_column}, does not exist." 
+ ): + await AsyncAlloyDBVectorStore.create( + engine, table_name=DEFAULT_TABLE, embedding_column=test_embed_column + ) + + async def test_validate_node_column_create(self, engine, vs): + test_node_column = "test_node_column" + with pytest.raises( + Exception, match=f"Node column, {test_node_column}, does not exist." + ): + await AsyncAlloyDBVectorStore.create( + engine, table_name=DEFAULT_TABLE, node_column=test_node_column + ) + + async def test_validate_ref_doc_id_column_create(self, engine, vs): + test_ref_doc_id_column = "test_ref_doc_id_column" + with pytest.raises( + Exception, + match=f"Reference Document Id column, {test_ref_doc_id_column}, does not exist.", + ): + await AsyncAlloyDBVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + ref_doc_id_column=test_ref_doc_id_column, + ) + + async def test_validate_metadata_json_column_create(self, engine, vs): + test_metadata_json_column = "test_metadata_json_column" + with pytest.raises( + Exception, + match=f"Metadata column, {test_metadata_json_column}, does not exist.", + ): + await AsyncAlloyDBVectorStore.create( + engine, + table_name=DEFAULT_TABLE, + metadata_json_column=test_metadata_json_column, ) async def test_add(self, vs):