From e5527adb1284bdbbafd2b436f872583220cb956d Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Fri, 10 Jan 2025 16:09:02 -0500 Subject: [PATCH] chore: Updating tests to allow for the CLIRunner to use Milvus, also have to handle special case of not running apply and teardown (#4915) * chore: Updating tests to allow for the CLIRunner to use Milvus, also have to handle special case of not running apply and teardown Signed-off-by: Francisco Javier Arceo * Adding cleanup Signed-off-by: Francisco Javier Arceo * adding example repo Signed-off-by: Francisco Javier Arceo * changing defualt to FLAT for local implementation Signed-off-by: Francisco Javier Arceo --------- Signed-off-by: Francisco Javier Arceo --- .../milvus_online_store/milvus.py | 21 +- .../example_repos/example_rag_feature_repo.py | 38 ++++ .../online_store/test_online_retrieval.py | 180 ++++++++++++++++++ sdk/python/tests/utils/cli_repo_creator.py | 87 ++++++--- 4 files changed, 293 insertions(+), 33 deletions(-) create mode 100644 sdk/python/tests/example_repos/example_rag_feature_repo.py diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index f2283387a0..7e840622a8 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -1,4 +1,5 @@ from datetime import datetime +from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from pydantic import StrictStr @@ -84,9 +85,10 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """ type: Literal["milvus"] = "milvus" + path: Optional[StrictStr] = "data/online_store.db" host: Optional[StrictStr] = "localhost" port: Optional[int] = 19530 - index_type: Optional[str] = "IVF_FLAT" + index_type: Optional[str] = "FLAT" metric_type: Optional[str] = "L2" embedding_dim: Optional[int] = 128 vector_enabled: Optional[bool] = True @@ -106,11 +108,24 @@ class MilvusOnlineStore(OnlineStore): client: Optional[MilvusClient] = None _collections: Dict[str, Any] = {} + def _get_db_path(self, config: RepoConfig) -> str: + assert ( + config.online_store.type == "milvus" + or config.online_store.type.endswith("MilvusOnlineStore") + ) + + if config.repo_path and not Path(config.online_store.path).is_absolute(): + db_path = str(config.repo_path / config.online_store.path) + else: + db_path = config.online_store.path + return db_path + def _connect(self, config: RepoConfig) -> MilvusClient: if not self.client: if config.provider == "local": - print("Connecting to Milvus in local mode using ./milvus_demo.db") - self.client = MilvusClient("./milvus_demo.db") + db_path = self._get_db_path(config) + print(f"Connecting to Milvus in local mode using {db_path}") + self.client = MilvusClient(db_path) else: self.client = MilvusClient( url=f"{config.online_store.host}:{config.online_store.port}", diff --git a/sdk/python/tests/example_repos/example_rag_feature_repo.py b/sdk/python/tests/example_repos/example_rag_feature_repo.py new file mode 100644 index 0000000000..2f55095bc6 --- /dev/null +++ b/sdk/python/tests/example_repos/example_rag_feature_repo.py @@ -0,0 +1,38 @@ +from datetime import timedelta + +from feast import Entity, FeatureView, Field, FileSource +from feast.types import Array, Float32, Int64, UnixTimestamp + +# This is for Milvus +# Note that file source paths are not validated, so there doesn't actually need to be any data +# at the paths for these file sources. Since these paths are effectively fake, this example +# feature repo should not be used for historical retrieval. + +rag_documents_source = FileSource( + path="data/embedded_documents.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created_timestamp", +) + +item = Entity( + name="item_id", # The name is derived from this argument, not object name. + join_keys=["item_id"], +) + +document_embeddings = FeatureView( + name="embedded_documents", + entities=[item], + schema=[ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_search_metric="L2", + ), + Field(name="item_id", dtype=Int64), + Field(name="created_timestamp", dtype=UnixTimestamp), + Field(name="event_timestamp", dtype=UnixTimestamp), + ], + source=rag_documents_source, + ttl=timedelta(hours=24), +) diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 83184643f3..5f0796f4ee 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -1,5 +1,6 @@ import os import platform +import random import sqlite3 import sys import time @@ -561,3 +562,182 @@ def test_sqlite_vec_import() -> None: """).fetchall() result = [(rowid, round(distance, 2)) for rowid, distance in result] assert result == [(2, 2.39), (1, 2.39)] + + +def test_local_milvus() -> None: + import random + + from pymilvus import MilvusClient + + random.seed(42) + VECTOR_LENGTH: int = 768 + COLLECTION_NAME: str = "test_demo_collection" + + client = MilvusClient("./milvus_demo.db") + + for collection in client.list_collections(): + client.drop_collection(collection_name=collection) + client.create_collection( + collection_name=COLLECTION_NAME, + dimension=VECTOR_LENGTH, + ) + assert client.list_collections() == [COLLECTION_NAME] + + docs = [ + "Artificial intelligence was founded as an academic discipline in 1956.", + "Alan Turing was the first person to conduct substantial research in AI.", + "Born in Maida Vale, London, Turing was raised in southern England.", + ] + # Use fake representation with random vectors (vector_length dimension). + vectors = [[random.uniform(-1, 1) for _ in range(VECTOR_LENGTH)] for _ in docs] + data = [ + {"id": i, "vector": vectors[i], "text": docs[i], "subject": "history"} + for i in range(len(vectors)) + ] + + print("Data has", len(data), "entities, each with fields: ", data[0].keys()) + print("Vector dim:", len(data[0]["vector"])) + + insert_res = client.insert(collection_name=COLLECTION_NAME, data=data) + assert insert_res == {"insert_count": 3, "ids": [0, 1, 2], "cost": 0} + + query_vectors = [[random.uniform(-1, 1) for _ in range(VECTOR_LENGTH)]] + + search_res = client.search( + collection_name=COLLECTION_NAME, # target collection + data=query_vectors, # query vectors + limit=2, # number of returned entities + output_fields=["text", "subject"], # specifies fields to be returned + ) + assert [j["id"] for j in search_res[0]] == [0, 1] + query_result = client.query( + collection_name=COLLECTION_NAME, + filter="id == 0", + ) + assert list(query_result[0].keys()) == ["id", "text", "subject", "vector"] + + client.drop_collection(collection_name=COLLECTION_NAME) + + +def test_milvus_lite_get_online_documents() -> None: + """ + Test retrieving documents from the online store in local mode. + """ + + random.seed(42) + n = 10 # number of samples - note: we'll actually double it + vector_length = 10 + runner = CliRunner() + with runner.local_repo( + example_repo_py=get_example_repo("example_rag_feature_repo.py"), + offline_store="file", + online_store="milvus", + apply=False, + teardown=False, + ) as store: + from datetime import timedelta + + from feast import Entity, FeatureView, Field, FileSource + from feast.types import Array, Float32, Int64, UnixTimestamp + + # This is for Milvus + # Note that file source paths are not validated, so there doesn't actually need to be any data + # at the paths for these file sources. Since these paths are effectively fake, this example + # feature repo should not be used for historical retrieval. + + rag_documents_source = FileSource( + path="data/embedded_documents.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created_timestamp", + ) + + item = Entity( + name="item_id", # The name is derived from this argument, not object name. + join_keys=["item_id"], + ) + + document_embeddings = FeatureView( + name="embedded_documents", + entities=[item], + schema=[ + Field( + name="vector", + dtype=Array(Float32), + vector_index=True, + vector_search_metric="L2", + ), + Field(name="item_id", dtype=Int64), + Field(name="created_timestamp", dtype=UnixTimestamp), + Field(name="event_timestamp", dtype=UnixTimestamp), + ], + source=rag_documents_source, + ttl=timedelta(hours=24), + ) + + store.apply([rag_documents_source, item, document_embeddings]) + + # Write some data to two tables + document_embeddings_fv = store.get_feature_view(name="embedded_documents") + + provider = store._get_provider() + + item_keys = [ + EntityKeyProto( + join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)] + ) + for i in range(n) + ] + data = [] + for item_key in item_keys: + data.append( + ( + item_key, + { + "vector": ValueProto( + float_list_val=FloatListProto( + val=np.random.random( + vector_length, + ) + ) + ) + }, + _utc_now(), + _utc_now(), + ) + ) + + provider.online_write_batch( + config=store.config, + table=document_embeddings_fv, + data=data, + progress=None, + ) + documents_df = pd.DataFrame( + { + "item_id": [str(i) for i in range(n)], + "vector": [ + np.random.random( + vector_length, + ) + for i in range(n) + ], + "event_timestamp": [_utc_now() for _ in range(n)], + "created_timestamp": [_utc_now() for _ in range(n)], + } + ) + + store.write_to_online_store( + feature_view_name="embedded_documents", + df=documents_df, + ) + + query_embedding = np.random.random( + vector_length, + ) + result = store.retrieve_online_documents( + feature="embedded_documents:vector", query=query_embedding, top_k=3 + ).to_dict() + + assert "vector" in result + assert "distance" in result + assert len(result["distance"]) == 3 diff --git a/sdk/python/tests/utils/cli_repo_creator.py b/sdk/python/tests/utils/cli_repo_creator.py index e00104081a..8bb696f7d4 100644 --- a/sdk/python/tests/utils/cli_repo_creator.py +++ b/sdk/python/tests/utils/cli_repo_creator.py @@ -51,7 +51,14 @@ def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]: return e.returncode, e.output @contextmanager - def local_repo(self, example_repo_py: str, offline_store: str): + def local_repo( + self, + example_repo_py: str, + offline_store: str, + online_store: str = "sqlite", + apply=True, + teardown=True, + ): """ Convenience method to set up all the boilerplate for a local feature repo. """ @@ -67,41 +74,61 @@ def local_repo(self, example_repo_py: str, offline_store: str): data_path = Path(data_dir_name) repo_config = repo_path / "feature_store.yaml" - - repo_config.write_text( - dedent( + if online_store == "sqlite": + yaml_config = dedent( f""" - project: {project_id} - registry: {data_path / "registry.db"} - provider: local - online_store: - path: {data_path / "online_store.db"} - offline_store: - type: {offline_store} - entity_key_serialization_version: 2 - """ + project: {project_id} + registry: {data_path / "registry.db"} + provider: local + online_store: + path: {data_path / "online_store.db"} + offline_store: + type: {offline_store} + entity_key_serialization_version: 2 + """ ) - ) + elif online_store == "milvus": + yaml_config = dedent( + f""" + project: {project_id} + registry: {data_path / "registry.db"} + provider: local + online_store: + path: {data_path / "online_store.db"} + type: milvus + vector_enabled: true + embedding_dim: 10 + offline_store: + type: {offline_store} + entity_key_serialization_version: 3 + """ + ) + else: + pass + + repo_config.write_text(yaml_config) repo_example = repo_path / "example.py" repo_example.write_text(example_repo_py) - result = self.run(["apply"], cwd=repo_path) - stdout = result.stdout.decode("utf-8") - stderr = result.stderr.decode("utf-8") - print(f"Apply stdout:\n{stdout}") - print(f"Apply stderr:\n{stderr}") - assert ( - result.returncode == 0 - ), f"stdout: {result.stdout}\nstderr: {result.stderr}" + if apply: + result = self.run(["apply"], cwd=repo_path) + stdout = result.stdout.decode("utf-8") + stderr = result.stderr.decode("utf-8") + print(f"Apply stdout:\n{stdout}") + print(f"Apply stderr:\n{stderr}") + assert ( + result.returncode == 0 + ), f"stdout: {result.stdout}\nstderr: {result.stderr}" yield FeatureStore(repo_path=str(repo_path), config=None) - result = self.run(["teardown"], cwd=repo_path) - stdout = result.stdout.decode("utf-8") - stderr = result.stderr.decode("utf-8") - print(f"Apply stdout:\n{stdout}") - print(f"Apply stderr:\n{stderr}") - assert ( - result.returncode == 0 - ), f"stdout: {result.stdout}\nstderr: {result.stderr}" + if teardown: + result = self.run(["teardown"], cwd=repo_path) + stdout = result.stdout.decode("utf-8") + stderr = result.stderr.decode("utf-8") + print(f"Apply stdout:\n{stdout}") + print(f"Apply stderr:\n{stderr}") + assert ( + result.returncode == 0 + ), f"stdout: {result.stdout}\nstderr: {result.stderr}"