Skip to content

Commit

Permalink
fix: remove test flag perform_validation
Browse files Browse the repository at this point in the history
  • Loading branch information
vishwarajanand committed Nov 26, 2024
1 parent 1121483 commit acecbd1
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 69 deletions.
105 changes: 47 additions & 58 deletions src/llama_index_alloydb_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ class AsyncAlloyDBVectorStore(BasePydanticVectorStore):

__create_key = object()

_engine: AsyncEngine
_table_name: str
_schema_name: str
_id_column: str
_text_column: str
_embedding_column: str
_metadata_json_column: str
_metadata_columns: List[str]
_ref_doc_id_column: str
_node_column: str

def __init__(
self,
key: object,
Expand Down Expand Up @@ -94,9 +83,7 @@ def __init__(
Exception: If called directly by user.
"""
if key != AsyncAlloyDBVectorStore.__create_key:
raise Exception(
"Only create class through 'create' or 'create_sync' methods!"
)
raise Exception("Only create class through 'create' method!")

# Delegate to Pydantic's __init__
super().__init__(stores_text=stores_text, is_embedding_query=is_embedding_query)
Expand Down Expand Up @@ -126,7 +113,6 @@ async def create(
node_column: str = "node",
stores_text: bool = True,
is_embedding_query: bool = True,
perform_validation: bool = True, # TODO: For testing only, remove after engine::init implementation
) -> AsyncAlloyDBVectorStore:
"""Create an AsyncAlloyDBVectorStore instance and validates the table schema.
Expand All @@ -150,49 +136,51 @@ async def create(
Returns:
AsyncAlloyDBVectorStore
"""
# TODO: Only for testing, remove flag to always do validation after engine::init is implemented
if perform_validation:
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'"
async with engine._pool.connect() as conn:
result = await conn.execute(text(stmt))
result_map = result.mappings()
results = result_map.fetchall()
columns = {}
for field in results:
columns[field["column_name"]] = field["data_type"]

# Check columns
if id_column not in columns:
raise ValueError(f"Id column, {id_column}, does not exist.")
if text_column not in columns:
raise ValueError(f"Content column, {text_column}, does not exist.")
content_type = columns[text_column]
if content_type != "text" and "char" not in content_type:
raise ValueError(
f"Content column, {text_column}, is type, {content_type}. It must be a type of character string."
)
if embedding_column not in columns:
raise ValueError(
f"Embedding column, {embedding_column}, does not exist."
)
if columns[embedding_column] != "USER-DEFINED":
raise ValueError(
f"Embedding column, {embedding_column}, is not type Vector."
)
if columns[node_column] != "json":
raise ValueError(f"Node column, {node_column}, is not type JSON.")
if ref_doc_id_column not in columns:
raise ValueError(
f"Reference Document Id column, {ref_doc_id_column}, does not exist."
)
if columns[metadata_json_column] != "jsonb":
raise ValueError(
f"Metadata column, {metadata_json_column}, does not exist."
)
# If using metadata_columns check to make sure column exists
for column in metadata_columns:
if column not in columns:
raise ValueError(f"Metadata column, {column}, does not exist.")
stmt = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}' AND table_schema = '{schema_name}'"
async with engine._pool.connect() as conn:
result = await conn.execute(text(stmt))
result_map = result.mappings()
results = result_map.fetchall()
columns = {}
for field in results:
columns[field["column_name"]] = field["data_type"]

# Check columns
if id_column not in columns:
raise ValueError(f"Id column, {id_column}, does not exist.")
if text_column not in columns:
raise ValueError(f"Text column, {text_column}, does not exist.")
text_type = columns[text_column]
if text_type != "text" and "char" not in text_type:
raise ValueError(
f"Text column, {text_column}, is type, {text_type}. It must be a type of character string."
)
if embedding_column not in columns:
raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
if columns[embedding_column] != "USER-DEFINED":
raise ValueError(
f"Embedding column, {embedding_column}, is not type Vector."
)
if node_column not in columns:
raise ValueError(f"Node column, {node_column}, does not exist.")
if columns[node_column] != "json":
raise ValueError(f"Node column, {node_column}, is not type JSON.")
if ref_doc_id_column not in columns:
raise ValueError(
f"Reference Document Id column, {ref_doc_id_column}, does not exist."
)
if metadata_json_column not in columns:
raise ValueError(
f"Metadata column, {metadata_json_column}, does not exist."
)
if columns[metadata_json_column] != "jsonb":
raise ValueError(
f"Metadata column, {metadata_json_column}, is not type JSONB."
)
# If using metadata_columns check to make sure column exists
for column in metadata_columns:
if column not in columns:
raise ValueError(f"Metadata column, {column}, does not exist.")

return cls(
cls.__create_key,
Expand Down Expand Up @@ -250,6 +238,7 @@ async def aget_nodes(
filters: Optional[MetadataFilters] = None,
) -> List[BaseNode]:
"""Asynchronously get nodes from the table matching the provided nodes and filters."""
# TODO: complete implementation
return []

async def aquery(
Expand Down
76 changes: 65 additions & 11 deletions tests/test_async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from llama_index_alloydb_pg.async_vectorstore import AsyncAlloyDBVectorStore

DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
VECTOR_SIZE = 768

texts = ["foo", "bar", "baz"]
nodes = [TextNode(text=texts[i]) for i in range(len(texts))]
Expand All @@ -40,6 +41,12 @@ def get_env_var(key: str, desc: str) -> str:
return v


async def aexecute(engine: AlloyDBEngine, query: str) -> None:
    """Execute a single SQL statement against the engine's pool and commit it."""
    async with engine._pool.connect() as connection:
        await connection.execute(text(query))
        await connection.commit()


@pytest.mark.asyncio(loop_scope="class")
class TestVectorStore:
@pytest.fixture(scope="module")
Expand Down Expand Up @@ -71,42 +78,89 @@ def db_pwd(self) -> str:
return get_env_var("DB_PASSWORD", "database name on AlloyDB instance")

@pytest_asyncio.fixture(scope="class")
async def engine(
self, db_project, db_region, db_cluster, db_instance, db_name, db_user, db_pwd
):
async def engine(self, db_project, db_region, db_cluster, db_instance, db_name):
engine = await AlloyDBEngine.afrom_instance(
project_id=db_project,
instance=db_instance,
cluster=db_cluster,
region=db_region,
database=db_name,
user=db_user,
password=db_pwd,
)

yield engine
await aexecute(engine, f'DROP TABLE "{DEFAULT_TABLE}"')
await engine.close()

@pytest_asyncio.fixture(scope="class")
async def vs(self, engine):
vs = await AsyncAlloyDBVectorStore.create(
engine, table_name=DEFAULT_TABLE, perform_validation=False
await engine._ainit_vector_store_table(
DEFAULT_TABLE, VECTOR_SIZE, overwrite_existing=True
)
vs = await AsyncAlloyDBVectorStore.create(engine, table_name=DEFAULT_TABLE)
yield vs

async def test_init_with_constructor(self, engine):
    # Direct instantiation must fail: the constructor checks a private
    # create-key, so only the create() factory may build instances.
    with pytest.raises(Exception):
        AsyncAlloyDBVectorStore(engine, table_name=DEFAULT_TABLE)

async def test_validate_columns_create(self, engine):
# TODO: add tests for more columns after engine::init is implemented
# currently, since there's no table first validation condition fails.
async def test_validate_id_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested id column."""
    missing = "test_id_column"
    expected = f"Id column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine, table_name=DEFAULT_TABLE, id_column=missing
        )

async def test_validate_text_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested text column."""
    missing = "test_text_column"
    expected = f"Text column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine, table_name=DEFAULT_TABLE, text_column=missing
        )

async def test_validate_embedding_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested embedding column."""
    missing = "test_embed_column"
    expected = f"Embedding column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine, table_name=DEFAULT_TABLE, embedding_column=missing
        )

async def test_validate_node_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested node column."""
    missing = "test_node_column"
    expected = f"Node column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine, table_name=DEFAULT_TABLE, node_column=missing
        )

async def test_validate_ref_doc_id_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested ref-doc-id column."""
    missing = "test_ref_doc_id_column"
    expected = f"Reference Document Id column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine,
            table_name=DEFAULT_TABLE,
            ref_doc_id_column=missing,
        )

async def test_validate_metadata_json_column_create(self, engine, vs):
    """create() must reject a table that lacks the requested metadata JSON column."""
    missing = "test_metadata_json_column"
    expected = f"Metadata column, {missing}, does not exist."
    with pytest.raises(Exception, match=expected):
        await AsyncAlloyDBVectorStore.create(
            engine,
            table_name=DEFAULT_TABLE,
            metadata_json_column=missing,
        )

async def test_add(self, vs):
Expand Down

0 comments on commit acecbd1

Please sign in to comment.