Skip to content

Commit

Permalink
Merge branch 'main' into renovate/python-nonmajor
Browse files Browse the repository at this point in the history
  • Loading branch information
averikitsch authored Jan 16, 2025
2 parents 1bb36ee + 5f1405e commit 40457ec
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/llama_index_alloydb_pg/async_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,10 +543,10 @@ async def __query_columns(
query_stmt = f'SELECT * {scoring_stmt} FROM "{self._schema_name}"."{self._table_name}" {filters_stmt} {order_stmt} {limit_stmt}'
async with self._engine.connect() as conn:
if self._index_query_options:
query_options_stmt = (
f"SET LOCAL {self._index_query_options.to_string()};"
)
await conn.execute(text(query_options_stmt))
# Set each query option individually
for query_option in self._index_query_options.to_parameter():
query_options_stmt = f"SET LOCAL {query_option};"
await conn.execute(text(query_options_stmt))
result = await conn.execute(text(query_stmt))
result_map = result.mappings()
results = result_map.fetchall()
Expand Down
41 changes: 41 additions & 0 deletions src/llama_index_alloydb_pg/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import enum
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
Expand Down Expand Up @@ -62,6 +63,11 @@ class ExactNearestNeighbor(BaseIndex):

@dataclass
class QueryOptions(ABC):
@abstractmethod
def to_parameter(self) -> list[str]:
    """List the query options as individual configuration strings.

    Subclasses are expected to override this method.

    Raises:
        NotImplementedError: Whenever this base implementation is invoked.
    """
    message = "to_parameter method must be implemented by subclass"
    raise NotImplementedError(message)

@abstractmethod
def to_string(self) -> str:
"""Convert index attributes to string."""
Expand All @@ -83,8 +89,16 @@ def index_options(self) -> str:
class HNSWQueryOptions(QueryOptions):
ef_search: int = 40

def to_parameter(self) -> list[str]:
    """Return the HNSW query options, one ``name = value`` string per entry."""
    ef_search_setting = f"hnsw.ef_search = {self.ef_search}"
    return [ef_search_setting]

def to_string(self) -> str:
    """Convert index attributes to a single string.

    Deprecated: emits a ``DeprecationWarning`` pointing to ``to_parameter``.

    Returns:
        The option rendered as ``hnsw.ef_search = <ef_search>``.
    """
    warnings.warn(
        "to_string is deprecated, use to_parameter instead.",
        DeprecationWarning,
        # Attribute the warning to the caller so it is reported at the
        # call site that needs migrating, not at this line.
        stacklevel=2,
    )
    return f"hnsw.ef_search = {self.ef_search}"


Expand All @@ -102,8 +116,16 @@ def index_options(self) -> str:
class IVFFlatQueryOptions(QueryOptions):
probes: int = 1

def to_parameter(self) -> list[str]:
    """Return the IVFFlat query options, one ``name = value`` string per entry."""
    probes_setting = f"ivfflat.probes = {self.probes}"
    return [probes_setting]

def to_string(self) -> str:
    """Convert index attributes to a single string.

    Deprecated: emits a ``DeprecationWarning`` pointing to ``to_parameter``.

    Returns:
        The option rendered as ``ivfflat.probes = <probes>``.
    """
    warnings.warn(
        "to_string is deprecated, use to_parameter instead.",
        DeprecationWarning,
        # Attribute the warning to the caller so it is reported at the
        # call site that needs migrating, not at this line.
        stacklevel=2,
    )
    return f"ivfflat.probes = {self.probes}"


Expand All @@ -124,8 +146,16 @@ def index_options(self) -> str:
class IVFQueryOptions(QueryOptions):
probes: int = 1

def to_parameter(self) -> list[str]:
    """Return the IVF query options, one ``name = value`` string per entry."""
    probes_setting = f"ivf.probes = {self.probes}"
    return [probes_setting]

def to_string(self) -> str:
    """Convert index attributes to a single string.

    Deprecated: emits a ``DeprecationWarning`` pointing to ``to_parameter``.

    Returns:
        The option rendered as ``ivf.probes = <probes>``.
    """
    warnings.warn(
        "to_string is deprecated, use to_parameter instead.",
        DeprecationWarning,
        # Attribute the warning to the caller so it is reported at the
        # call site that needs migrating, not at this line.
        stacklevel=2,
    )
    return f"ivf.probes = {self.probes}"


Expand All @@ -147,6 +177,17 @@ class ScaNNQueryOptions(QueryOptions):
num_leaves_to_search: int = 1
pre_reordering_num_neighbors: int = -1

def to_parameter(self) -> list[str]:
    """Return the ScaNN query options, one ``name = value`` string per entry.

    The leaves-to-search setting comes first, followed by the
    pre-reordering neighbor count.
    """
    leaves_setting = f"scann.num_leaves_to_search = {self.num_leaves_to_search}"
    reordering_setting = (
        f"scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}"
    )
    return [leaves_setting, reordering_setting]

def to_string(self) -> str:
    """Convert index attributes to a single comma-separated string.

    Deprecated: emits a ``DeprecationWarning`` pointing to ``to_parameter``.

    Returns:
        Both options rendered as ``name = value`` and joined with ``", "``.
    """
    warnings.warn(
        "to_string is deprecated, use to_parameter instead.",
        DeprecationWarning,
        # Attribute the warning to the caller so it is reported at the
        # call site that needs migrating, not at this line.
        stacklevel=2,
    )
    return f"scann.num_leaves_to_search = {self.num_leaves_to_search}, scann.pre_reordering_num_neighbors = {self.pre_reordering_num_neighbors}"
37 changes: 37 additions & 0 deletions tests/test_async_vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from llama_index_alloydb_pg import AlloyDBEngine, Column
from llama_index_alloydb_pg.async_vector_store import AsyncAlloyDBVectorStore
from llama_index_alloydb_pg.indexes import HNSWQueryOptions, ScaNNQueryOptions

DEFAULT_TABLE = "test_table" + str(uuid.uuid4())
DEFAULT_TABLE_CUSTOM_VS = "test_table" + str(uuid.uuid4())
Expand Down Expand Up @@ -155,6 +156,23 @@ async def custom_vs(self, engine):
"nullable_int_field",
"nullable_str_field",
],
index_query_options=HNSWQueryOptions(ef_search=1),
)
yield vs

@pytest_asyncio.fixture(scope="class")
async def custom_vs_scann(self, engine, custom_vs):
    """Yield a vector store on the custom table configured with ScaNN query options."""
    # `custom_vs` is requested as a fixture but is not used in this body.
    metadata_columns = ["len", "nullable_int_field", "nullable_str_field"]
    scann_options = ScaNNQueryOptions(
        num_leaves_to_search=1, pre_reordering_num_neighbors=2
    )
    store = await AsyncAlloyDBVectorStore.create(
        engine,
        table_name=DEFAULT_TABLE_CUSTOM_VS,
        metadata_columns=metadata_columns,
        index_query_options=scann_options,
    )
    yield store

Expand Down Expand Up @@ -320,6 +338,25 @@ async def test_aquery(self, engine, vs):
assert len(results.nodes) == 3
assert results.nodes[0].get_content(metadata_mode=MetadataMode.NONE) == "foo"

async def test_aquery_scann(self, engine, custom_vs_scann):
    """Add nodes to the ScaNN-configured store and check a top-3 query returns three nodes."""
    # Note: To be migrated to a pytest dependency on test_async_add
    # Blocked due to unexpected fixture reloads while running the integration test suite
    truncate_stmt = f'TRUNCATE TABLE "{DEFAULT_TABLE_CUSTOM_VS}"'
    await aexecute(engine, truncate_stmt)
    # Set extra metadata to be indexed in a separate column
    for item in nodes:
        item.metadata["len"] = len(item.text)

    await custom_vs_scann.async_add(nodes)
    top_three_query = VectorStoreQuery(
        query_embedding=[1.0] * VECTOR_SIZE, similarity_top_k=3
    )
    results = await custom_vs_scann.aquery(top_three_query)

    assert results.nodes is not None
    assert results.ids is not None
    assert results.similarities is not None
    assert len(results.nodes) == 3

async def test_aquery_filters(self, engine, custom_vs):
# Note: To be migrated to a pytest dependency on test_async_add
# Blocked due to unexpected fixtures reloads while running integration test suite
Expand Down
122 changes: 122 additions & 0 deletions tests/test_indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from llama_index_alloydb_pg.indexes import (
DistanceStrategy,
HNSWIndex,
HNSWQueryOptions,
IVFFlatIndex,
IVFFlatQueryOptions,
IVFIndex,
IVFQueryOptions,
ScaNNIndex,
ScaNNQueryOptions,
)


class TestAlloyDBIndex:
    """Unit tests for distance strategies, index definitions and query options."""

    @staticmethod
    def _assert_to_string_deprecated(options) -> None:
        """Check that ``options.to_string()`` emits exactly one deprecation warning."""
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning regardless of the filters active in the
            # surrounding test run; otherwise the warning may be suppressed.
            warnings.simplefilter("always")
            options.to_string()

        assert len(caught) == 1
        assert issubclass(caught[-1].category, DeprecationWarning)
        assert "to_string is deprecated, use to_parameter instead." in str(
            caught[-1].message
        )

    def test_distance_strategy(self):
        """Each distance strategy exposes its operator and function names."""
        assert DistanceStrategy.EUCLIDEAN.operator == "<->"
        assert DistanceStrategy.EUCLIDEAN.search_function == "l2_distance"
        assert DistanceStrategy.EUCLIDEAN.index_function == "vector_l2_ops"
        assert DistanceStrategy.EUCLIDEAN.scann_index_function == "l2"

        assert DistanceStrategy.COSINE_DISTANCE.operator == "<=>"
        assert DistanceStrategy.COSINE_DISTANCE.search_function == "cosine_distance"
        assert DistanceStrategy.COSINE_DISTANCE.index_function == "vector_cosine_ops"
        assert DistanceStrategy.COSINE_DISTANCE.scann_index_function == "cosine"

        assert DistanceStrategy.INNER_PRODUCT.operator == "<#>"
        assert DistanceStrategy.INNER_PRODUCT.search_function == "inner_product"
        assert DistanceStrategy.INNER_PRODUCT.index_function == "vector_ip_ops"
        assert DistanceStrategy.INNER_PRODUCT.scann_index_function == "dot_product"

    def test_hnsw_index(self):
        """HNSWIndex keeps its attributes and renders its index options."""
        index = HNSWIndex(name="test_index", m=32, ef_construction=128)
        assert index.index_type == "hnsw"
        assert index.m == 32
        assert index.ef_construction == 128
        assert index.index_options() == "(m = 32, ef_construction = 128)"

    def test_hnsw_query_options(self):
        """HNSWQueryOptions renders one parameter and deprecates to_string."""
        options = HNSWQueryOptions(ef_search=80)
        assert options.to_parameter() == ["hnsw.ef_search = 80"]

        self._assert_to_string_deprecated(options)

    def test_ivfflat_index(self):
        """IVFFlatIndex keeps its attributes and renders its index options."""
        index = IVFFlatIndex(name="test_index", lists=200)
        assert index.index_type == "ivfflat"
        assert index.lists == 200
        assert index.index_options() == "(lists = 200)"

    def test_ivfflat_query_options(self):
        """IVFFlatQueryOptions renders one parameter and deprecates to_string."""
        options = IVFFlatQueryOptions(probes=2)
        assert options.to_parameter() == ["ivfflat.probes = 2"]

        self._assert_to_string_deprecated(options)

    def test_ivf_index(self):
        """IVFIndex keeps its attributes and renders its index options."""
        index = IVFIndex(name="test_index", lists=200)
        assert index.index_type == "ivf"
        assert index.lists == 200
        assert index.quantizer == "sq8"  # Check default value
        assert index.index_options() == "(lists = 200, quantizer = sq8)"

    def test_ivf_query_options(self):
        """IVFQueryOptions renders one parameter and deprecates to_string."""
        options = IVFQueryOptions(probes=2)
        assert options.to_parameter() == ["ivf.probes = 2"]

        self._assert_to_string_deprecated(options)

    def test_scann_index(self):
        """ScaNNIndex keeps its attributes and renders its index options."""
        index = ScaNNIndex(name="test_index", num_leaves=10)
        assert index.index_type == "ScaNN"
        assert index.num_leaves == 10
        assert index.quantizer == "sq8"  # Check default value
        assert index.index_options() == "(num_leaves = 10, quantizer = sq8)"

    def test_scann_query_options(self):
        """ScaNNQueryOptions renders two parameters and deprecates to_string."""
        options = ScaNNQueryOptions(
            num_leaves_to_search=2, pre_reordering_num_neighbors=10
        )
        assert options.to_parameter() == [
            "scann.num_leaves_to_search = 2",
            "scann.pre_reordering_num_neighbors = 10",
        ]

        self._assert_to_string_deprecated(options)

0 comments on commit 40457ec

Please sign in to comment.