diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 4564d6abf3..98db710d7f 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1757,7 +1757,7 @@ def retrieve_online_documents( query: Union[str, List[float]], top_k: int, features: Optional[List[str]] = None, - distance_metric: Optional[str] = None, + distance_metric: Optional[str] = "L2", ) -> OnlineResponse: """ Retrieves the top k closest document features. Note, embeddings are a subset of features. diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py index a1a4a3a5fe..8d5405c428 100644 --- a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -7,9 +7,8 @@ CollectionSchema, DataType, FieldSchema, - connections, + MilvusClient, ) -from pymilvus.orm.connections import Connections from feast import Entity from feast.feature_view import FeatureView @@ -85,7 +84,6 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): """ type: Literal["milvus"] = "milvus" - host: Optional[StrictStr] = "localhost" port: Optional[int] = 19530 index_type: Optional[str] = "IVF_FLAT" @@ -93,6 +91,8 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): embedding_dim: Optional[int] = 128 vector_enabled: Optional[bool] = True nlist: Optional[int] = 128 + username: Optional[StrictStr] = "" + password: Optional[StrictStr] = "" class MilvusOnlineStore(OnlineStore): @@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore): _collections: Dictionary to cache Milvus collections. """ - _conn: Optional[Connections] = None - _collections: Dict[str, Collection] = {} + client: Optional[MilvusClient] = None + _collections: Dict[str, Any] = {} - def _connect(self, config: RepoConfig) -> connections: - if not self._conn: - if not connections.has_connection("feast"): - self._conn = connections.connect( - alias="feast", - host=config.online_store.host, - port=str(config.online_store.port), - ) - return self._conn + def _connect(self, config: RepoConfig) -> MilvusClient: + if not self.client: + self.client = MilvusClient( + url=f"{config.online_store.host}:{config.online_store.port}", + token=f"{config.online_store.username}:{config.online_store.password}" + if config.online_store.username and config.online_store.password + else "", + ) + return self.client - def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection: + def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, Any]: + self.client = self._connect(config) collection_name = _table_id(config.project, table) if collection_name not in self._collections: - self._connect(config) - # Create a composite key by combining entity fields composite_key_name = ( "_".join([field.name for field in table.entity_columns]) + "_pk" @@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection: schema = CollectionSchema( fields=fields, description="Feast feature view data" ) - collection = Collection(name=collection_name, schema=schema, using="feast") - if not collection.has_index(): - index_params = { - "index_type": config.online_store.index_type, - "metric_type": config.online_store.metric_type, - "params": {"nlist": config.online_store.nlist}, - } - for vector_field in schema.fields: - if vector_field.dtype in [ - DataType.FLOAT_VECTOR, - DataType.BINARY_VECTOR, - ]: - collection.create_index( - field_name=vector_field.name, index_params=index_params - ) - collection.load() - self._collections[collection_name] = collection + collection_exists = self.client.has_collection( + collection_name=collection_name + ) + if not collection_exists: + self.client.create_collection( + collection_name=collection_name, + dimension=config.online_store.embedding_dim, + schema=schema, + ) + index_params = self.client.prepare_index_params() + for vector_field in schema.fields: + if vector_field.dtype in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + ]: + index_params.add_index( + collection_name=collection_name, + field_name=vector_field.name, + metric_type=config.online_store.metric_type, + index_type=config.online_store.index_type, + index_name=f"vector_index_{vector_field.name}", + params={"nlist": config.online_store.nlist}, + ) + self.client.create_index( + collection_name=collection_name, + index_params=index_params, + ) + else: + self.client.load_collection(collection_name) + self._collections[collection_name] = self.client.describe_collection( + collection_name + ) return self._collections[collection_name] def online_write_batch( @@ -199,6 +213,7 @@ def online_write_batch( ], progress: Optional[Callable[[int], Any]], ) -> None: + self.client = self._connect(config) collection = self._get_collection(config, table) entity_batch_to_insert = [] for entity_key, values_dict, timestamp, created_ts in data: @@ -231,8 +246,9 @@ def online_write_batch( if progress: progress(1) - collection.insert(entity_batch_to_insert) - collection.flush() + self.client.insert( + collection_name=collection["collection_name"], data=entity_batch_to_insert + ) def online_read( self, @@ -252,14 +268,14 @@ def update( entities_to_keep: Sequence[Entity], partial: bool, ): - self._connect(config) + self.client = self._connect(config) for table in tables_to_keep: - self._get_collection(config, table) + self._collections = self._get_collection(config, table) + for table in tables_to_delete: collection_name = _table_id(config.project, table) - collection = Collection(name=collection_name) - if collection.exists(): - collection.drop() + if self._collections.get(collection_name, None): + self.client.drop_collection(collection_name) self._collections.pop(collection_name, None) def plan( @@ -273,12 +289,12 @@ def teardown( tables: Sequence[FeatureView], entities: Sequence[Entity], ): - self._connect(config) + self.client = self._connect(config) for table in tables: - collection = self._get_collection(config, table) - if collection: - collection.drop() - self._collections.pop(collection.name, None) + collection_name = _table_id(config.project, table) + if self._collections.get(collection_name, None): + self.client.drop_collection(collection_name) + self._collections.pop(collection_name, None) def retrieve_online_documents( self, @@ -298,6 +314,8 @@ def retrieve_online_documents( Optional[ValueProto], ] ]: + self.client = self._connect(config) + collection_name = _table_id(config.project, table) collection = self._get_collection(config, table) if not config.online_store.vector_enabled: raise ValueError("Vector search is not enabled in the online store config") @@ -321,28 +339,27 @@ def retrieve_online_documents( + ["created_ts", "event_ts"] ) assert all( - field + field in [f["name"] for f in collection["fields"]] for field in output_fields - if field in [f.name for f in collection.schema.fields] - ), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema" - + ), f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema" # Note we choose the first vector field as the field to search on. Not ideal but it's something. ann_search_field = None - for field in collection.schema.fields: + for field in collection["fields"]: if ( - field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] - and field.name in output_fields + field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] + and field["name"] in output_fields ): - ann_search_field = field.name + ann_search_field = field["name"] break - results = collection.search( + self.client.load_collection(collection_name) + results = self.client.search( + collection_name=collection_name, data=[embedding], anns_field=ann_search_field, - param=search_params, + search_params=search_params, limit=top_k, output_fields=output_fields, - consistency_level="Strong", ) result_list = [] @@ -350,13 +367,17 @@ def retrieve_online_documents( for hit in hits: single_record = {} for field in output_fields: - single_record[field] = hit.entity.get(field) + single_record[field] = hit.get("entity", {}).get(field, None) - entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name)) - embedding = hit.entity.get(ann_search_field) + entity_key_bytes = bytes.fromhex( + hit.get("entity", {}).get(composite_key_name, None) + ) + embedding = hit.get("entity", {}).get(ann_search_field) serialized_embedding = _serialize_vector_to_float_list(embedding) - distance = hit.distance - event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6) + distance = hit.get("distance", None) + event_ts = datetime.fromtimestamp( + hit.get("entity", {}).get("event_ts") / 1e6 + ) prepared_result = _build_retrieve_online_document_record( entity_key_bytes, # This may have a bug @@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str): self._connect() def _connect(self): - return connections.connect(alias="default", host=self.host, port=str(self.port)) + raise NotImplementedError def to_infra_object_proto(self) -> InfraObjectProto: # Implement serialization if needed diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py b/sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py index 8ffee04c12..c02bd14401 100644 --- a/sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py @@ -1,6 +1,8 @@ from typing import Any, Dict -from testcontainers.milvus import MilvusContainer +import docker +from testcontainers.core.container import DockerContainer +from testcontainers.core.waiting_utils import wait_for_logs from tests.integration.feature_repos.universal.online_store_creator import ( OnlineStoreCreator, @@ -11,13 +13,19 @@ class MilvusOnlineStoreCreator(OnlineStoreCreator): def __init__(self, project_name: str, **kwargs): super().__init__(project_name) self.fixed_port = 19530 - self.container = MilvusContainer("milvusdb/milvus:v2.4.4").with_exposed_ports( + self.container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports( self.fixed_port ) + self.client = docker.from_env() def create_online_store(self) -> Dict[str, Any]: self.container.start() # Wait for Milvus server to be ready + # log_string_to_wait_for = "Ready to accept connections" + log_string_to_wait_for = "" + wait_for_logs( + container=self.container, predicate=log_string_to_wait_for, timeout=30 + ) host = "localhost" port = self.container.get_exposed_port(self.fixed_port) return { diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index ab665914b5..64122d2c86 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -897,26 +897,26 @@ def test_retrieve_online_documents(environment, fake_document_data): ).to_dict() -# @pytest.mark.integration -# @pytest.mark.universal_online_stores(only=["milvus"]) -# def test_retrieve_online_milvus_documents(environment, fake_document_data): -# fs = environment.feature_store -# df, data_source = fake_document_data -# item_embeddings_feature_view = create_item_embeddings_feature_view(data_source) -# fs.apply([item_embeddings_feature_view, item()]) -# fs.write_to_online_store("item_embeddings", df) -# documents = fs.retrieve_online_documents( -# feature=None, -# features=[ -# "item_embeddings:embedding_float", -# "item_embeddings:item_id", -# "item_embeddings:string_feature", -# ], -# query=[1.0, 2.0], -# top_k=2, -# distance_metric="L2", -# ).to_dict() -# assert len(documents["embedding_float"]) == 2 -# -# assert len(documents["item_id"]) == 2 -# assert documents["item_id"] == [2, 3] +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["milvus"]) +def test_retrieve_online_milvus_documents(environment, fake_document_data): + fs = environment.feature_store + df, data_source = fake_document_data + item_embeddings_feature_view = create_item_embeddings_feature_view(data_source) + fs.apply([item_embeddings_feature_view, item()]) + fs.write_to_online_store("item_embeddings", df) + documents = fs.retrieve_online_documents( + feature=None, + features=[ + "item_embeddings:embedding_float", + "item_embeddings:item_id", + "item_embeddings:string_feature", + ], + query=[1.0, 2.0], + top_k=2, + distance_metric="L2", + ).to_dict() + assert len(documents["embedding_float"]) == 2 + + assert len(documents["item_id"]) == 2 + assert documents["item_id"] == [2, 3]