Skip to content

Commit

Permalink
Merge branch 'main' into release-please--branches--main
Browse files Browse the repository at this point in the history
  • Loading branch information
averikitsch authored Dec 7, 2024
2 parents a9da4f7 + 28c5752 commit 2c8a9de
Show file tree
Hide file tree
Showing 14 changed files with 96 additions and 118 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
google-cloud-alloydb-connector[asyncpg]==1.5.0
llama-index-core==0.12.2
llama-index-core==0.12.3
pgvector==0.3.6
SQLAlchemy[asyncio]==2.0.36
38 changes: 19 additions & 19 deletions src/llama_index_alloydb_pg/async_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import json
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Optional, Sequence

from llama_index.core.constants import DATA_KEY
from llama_index.core.schema import BaseNode
Expand Down Expand Up @@ -119,13 +119,13 @@ async def __afetch_query(self, query):
return results

async def _put_all_doc_hashes_to_table(
self, rows: List[Tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE)
self, rows: list[tuple[str, str]], batch_size: int = int(DEFAULT_BATCH_SIZE)
) -> None:
"""Puts a multiple rows of node ids with their doc_hash into the document table.
Incase a row with the id already exists, it updates the row with the new doc_hash.
Args:
rows (List[Tuple[str, str]]): List of tuples of id and doc_hash
rows (list[tuple[str, str]]): List of tuples of id and doc_hash
batch_size (int): batch_size to insert the rows. Defaults to 1.
Returns:
Expand Down Expand Up @@ -173,7 +173,7 @@ async def async_add_documents(
"""Adds a document to the store.
Args:
docs (List[BaseDocument]): documents
docs (list[BaseDocument]): documents
allow_update (bool): allow update of docstore from document
batch_size (int): batch_size to insert the rows. Defaults to 1.
store_text (bool): allow the text content of the node to be stored.
Expand Down Expand Up @@ -225,7 +225,7 @@ async def async_add_documents(
await self.__aexecute_query(query, batch)

@property
async def adocs(self) -> Dict[str, BaseNode]:
async def adocs(self) -> dict[str, BaseNode]:
"""Get all documents.
Returns:
Expand Down Expand Up @@ -300,12 +300,12 @@ async def aget_ref_doc_info(self, ref_doc_id: str) -> Optional[RefDocInfo]:

return RefDocInfo(node_ids=node_ids, metadata=merged_metadata)

async def aget_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]:
async def aget_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]:
"""Get a mapping of ref_doc_id -> RefDocInfo for all ingested documents.
Returns:
Optional[
Dict[
dict[
str, #Ref_doc_id
RefDocInfo, #Ref_doc_info of the id
]
Expand Down Expand Up @@ -356,14 +356,14 @@ async def adocument_exists(self, doc_id: str) -> bool:

async def _get_ref_doc_child_node_ids(
self, ref_doc_id: str
) -> Optional[Dict[str, List[str]]]:
) -> Optional[dict[str, list[str]]]:
"""Helper function to find the child node mappings of a ref_doc_id.
Returns:
Optional[
Dict[
dict[
str, # Ref_doc_id
List # List of all nodes that refer to ref_doc_id
list # List of all nodes that refer to ref_doc_id
]
]"""
query = f"""select id from "{self._schema_name}"."{self._table_name}" where ref_doc_id = '{ref_doc_id}';"""
Expand Down Expand Up @@ -442,11 +442,11 @@ async def aset_document_hash(self, doc_id: str, doc_hash: str) -> None:

await self._put_all_doc_hashes_to_table(rows=[(doc_id, doc_hash)])

async def aset_document_hashes(self, doc_hashes: Dict[str, str]) -> None:
async def aset_document_hashes(self, doc_hashes: dict[str, str]) -> None:
"""Set the hash for a given doc_id.
Args:
doc_hashes (Dict[str, str]): Dictionary with doc_id as key and doc_hash as value.
doc_hashes (dict[str, str]): Dictionary with doc_id as key and doc_hash as value.
Returns:
None
Expand All @@ -473,11 +473,11 @@ async def aget_document_hash(self, doc_id: str) -> Optional[str]:
else:
return None

async def aget_all_document_hashes(self) -> Dict[str, str]:
async def aget_all_document_hashes(self) -> dict[str, str]:
"""Get the stored hash for all documents.
Returns:
Dict[
dict[
str, # doc_hash
str # doc_id
]
Expand All @@ -498,11 +498,11 @@ async def aget_all_document_hashes(self) -> Dict[str, str]:
return hashes

@property
def docs(self) -> Dict[str, BaseNode]:
def docs(self) -> dict[str, BaseNode]:
"""Get all documents.
Returns:
Dict[str, BaseDocument]: documents
dict[str, BaseDocument]: documents
"""
raise NotImplementedError(
Expand Down Expand Up @@ -547,7 +547,7 @@ def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
"Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead."
)

def set_document_hashes(self, doc_hashes: Dict[str, str]) -> None:
def set_document_hashes(self, doc_hashes: dict[str, str]) -> None:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead."
)
Expand All @@ -557,12 +557,12 @@ def get_document_hash(self, doc_id: str) -> Optional[str]:
"Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead."
)

def get_all_document_hashes(self) -> Dict[str, str]:
def get_all_document_hashes(self) -> dict[str, str]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead."
)

def get_all_ref_doc_info(self) -> Optional[Dict[str, RefDocInfo]]:
def get_all_ref_doc_info(self) -> Optional[dict[str, RefDocInfo]]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBDocumentStore. Use AlloyDBDocumentStore interface instead."
)
Expand Down
9 changes: 4 additions & 5 deletions src/llama_index_alloydb_pg/async_index_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

import json
import warnings
from typing import List, Optional
from typing import Optional

from llama_index.core.constants import DATA_KEY
from llama_index.core.data_structs.data_structs import IndexStruct
from llama_index.core.storage.index_store.types import BaseIndexStore
from llama_index.core.storage.index_store.utils import (
Expand Down Expand Up @@ -113,11 +112,11 @@ async def __afetch_query(self, query):
await conn.commit()
return results

async def aindex_structs(self) -> List[IndexStruct]:
async def aindex_structs(self) -> list[IndexStruct]:
"""Get all index structs.
Returns:
List[IndexStruct]: index structs
list[IndexStruct]: index structs
"""
query = f"""SELECT * from "{self._schema_name}"."{self._table_name}";"""
Expand Down Expand Up @@ -190,7 +189,7 @@ async def aget_index_struct(
return json_to_index_struct(index_data)
return None

def index_structs(self) -> List[IndexStruct]:
def index_structs(self) -> list[IndexStruct]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBIndexStore . Use AlloyDBIndexStore interface instead."
)
Expand Down
37 changes: 16 additions & 21 deletions src/llama_index_alloydb_pg/async_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,10 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

import base64
import json
import re
import uuid
import warnings
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Type
from typing import Any, Optional, Sequence

import numpy as np
from llama_index.core.schema import BaseNode, MetadataMode, NodeRelationship, TextNode
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
Expand All @@ -31,7 +27,6 @@
MetadataFilter,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
Expand Down Expand Up @@ -71,7 +66,7 @@ def __init__(
text_column: str = "text",
embedding_column: str = "embedding",
metadata_json_column: str = "li_metadata",
metadata_columns: List[str] = [],
metadata_columns: list[str] = [],
ref_doc_id_column: str = "ref_doc_id",
node_column: str = "node_data",
stores_text: bool = True,
Expand All @@ -89,7 +84,7 @@ def __init__(
text_column (str): Column that represent text content of a Node. Defaults to "text".
embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding".
metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata".
metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns.
metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns.
ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id".
node_column (str): Column that represents the whole JSON node. Defaults to "node_data".
stores_text (bool): Whether the table stores text. Defaults to "True".
Expand Down Expand Up @@ -121,15 +116,15 @@ def __init__(

@classmethod
async def create(
cls: Type[AsyncAlloyDBVectorStore],
cls: type[AsyncAlloyDBVectorStore],
engine: AlloyDBEngine,
table_name: str,
schema_name: str = "public",
id_column: str = "node_id",
text_column: str = "text",
embedding_column: str = "embedding",
metadata_json_column: str = "li_metadata",
metadata_columns: List[str] = [],
metadata_columns: list[str] = [],
ref_doc_id_column: str = "ref_doc_id",
node_column: str = "node_data",
stores_text: bool = True,
Expand All @@ -147,7 +142,7 @@ async def create(
text_column (str): Column that represent text content of a Node. Defaults to "text".
embedding_column (str): Column for embedding vectors. The embedding is generated from the content of Node. Defaults to "embedding".
metadata_json_column (str): Column to store metadata as JSON. Defaults to "li_metadata".
metadata_columns (List[str]): Column(s) that represent extracted metadata keys in their own columns.
metadata_columns (list[str]): Column(s) that represent extracted metadata keys in their own columns.
ref_doc_id_column (str): Column that represents id of a node's parent document. Defaults to "ref_doc_id".
node_column (str): Column that represents the whole JSON node. Defaults to "node_data".
stores_text (bool): Whether the table stores text. Defaults to "True".
Expand Down Expand Up @@ -234,7 +229,7 @@ def client(self) -> Any:
"""Get client."""
return self._engine

async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[str]:
async def async_add(self, nodes: Sequence[BaseNode], **kwargs: Any) -> list[str]:
"""Asynchronously add nodes to the table."""
ids = []
metadata_col_names = (
Expand Down Expand Up @@ -293,14 +288,14 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:

async def adelete_nodes(
self,
node_ids: Optional[List[str]] = None,
node_ids: Optional[list[str]] = None,
filters: Optional[MetadataFilters] = None,
**delete_kwargs: Any,
) -> None:
"""Asynchronously delete a set of nodes from the table matching the provided nodes and filters."""
if not node_ids and not filters:
return
all_filters: List[MetadataFilter | MetadataFilters] = []
all_filters: list[MetadataFilter | MetadataFilters] = []
if node_ids:
all_filters.append(
MetadataFilter(
Expand Down Expand Up @@ -332,9 +327,9 @@ async def aclear(self) -> None:

async def aget_nodes(
self,
node_ids: Optional[List[str]] = None,
node_ids: Optional[list[str]] = None,
filters: Optional[MetadataFilters] = None,
) -> List[BaseNode]:
) -> list[BaseNode]:
"""Asynchronously get nodes from the table matching the provided nodes and filters."""
query = VectorStoreQuery(
node_ids=node_ids, filters=filters, similarity_top_k=-1
Expand Down Expand Up @@ -366,7 +361,7 @@ async def aquery(
similarities.append(row["distance"])
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> List[str]:
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."
)
Expand All @@ -378,7 +373,7 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:

def delete_nodes(
self,
node_ids: Optional[List[str]] = None,
node_ids: Optional[list[str]] = None,
filters: Optional[MetadataFilters] = None,
**delete_kwargs: Any,
) -> None:
Expand All @@ -393,9 +388,9 @@ def clear(self) -> None:

def get_nodes(
self,
node_ids: Optional[List[str]] = None,
node_ids: Optional[list[str]] = None,
filters: Optional[MetadataFilters] = None,
) -> List[BaseNode]:
) -> list[BaseNode]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."
)
Expand Down Expand Up @@ -495,7 +490,7 @@ async def __query_columns(
**kwargs: Any,
) -> Sequence[RowMapping]:
"""Perform search query on database."""
filters: List[MetadataFilter | MetadataFilters] = []
filters: list[MetadataFilter | MetadataFilters] = []
if query.doc_ids:
filters.append(
MetadataFilter(
Expand Down
Loading

0 comments on commit 2c8a9de

Please sign in to comment.