diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index edbd060e10..4564d6abf3 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1753,9 +1753,10 @@ async def get_online_features_async( def retrieve_online_documents( self, - feature: str, + feature: Optional[str], query: Union[str, List[float]], top_k: int, + features: Optional[List[str]] = None, distance_metric: Optional[str] = None, ) -> OnlineResponse: """ @@ -1765,6 +1766,7 @@ def retrieve_online_documents( feature: The list of document features that should be retrieved from the online document store. These features can be specified either as a list of string document feature references or as a feature service. String feature references must have format "feature_view:feature", e.g, "document_fv:document_embeddings". + features: The list of features that should be retrieved from the online store. query: The query to retrieve the closest document features for. top_k: The number of closest document features to retrieve. distance_metric: The distance metric to use for retrieval. @@ -1773,18 +1775,44 @@ def retrieve_online_documents( raise ValueError( "Using embedding functionality is not supported for document retrieval. Please embed the query before calling retrieve_online_documents." ) + feature_list: List[str] = ( + features + if features is not None + else ([feature] if feature is not None else []) + ) + ( available_feature_views, _, ) = utils._get_feature_views_to_use( registry=self._registry, project=self.project, - features=[feature], + features=feature_list, allow_cache=True, hide_dummy_entity=False, ) + if features: + feature_view_set = set() + for feature in features: + feature_view_name = feature.split(":")[0] + feature_view = self.get_feature_view(feature_view_name) + feature_view_set.add(feature_view.name) + if len(feature_view_set) > 1: + raise ValueError( + "Document retrieval only supports a single feature view." + ) + requested_feature = None + requested_features = [ + f.split(":")[1] for f in features if isinstance(f, str) and ":" in f + ] + else: + requested_feature = ( + feature.split(":")[1] if isinstance(feature, str) else feature + ) + requested_features = [requested_feature] if requested_feature else [] + requested_feature_view_name = ( - feature.split(":")[0] if isinstance(feature, str) else feature + feature.split(":")[0] if feature else list(feature_view_set)[0] ) for feature_view in available_feature_views: if feature_view.name == requested_feature_view_name: @@ -1793,14 +1821,15 @@ def retrieve_online_documents( raise ValueError( f"Feature view {requested_feature_view} not found in the registry." ) - requested_feature = ( - feature.split(":")[1] if isinstance(feature, str) else feature - ) + + requested_feature_view = available_feature_views[0] + provider = self._get_provider() document_features = self._retrieve_from_online_store( provider, requested_feature_view, requested_feature, + requested_features, query, top_k, distance_metric, @@ -1822,6 +1851,7 @@ def retrieve_online_documents( document_feature_vals = [feature[4] for feature in document_features] document_feature_distance_vals = [feature[5] for feature in document_features] online_features_response = GetOnlineFeaturesResponse(results=[]) + requested_feature = requested_feature or requested_features[0] utils._populate_result_rows_from_columnar( online_features_response=online_features_response, data={ @@ -1836,7 +1866,8 @@ def _retrieve_from_online_store( self, provider: Provider, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], query: List[float], top_k: int, distance_metric: Optional[str], @@ -1852,6 +1883,7 @@ def _retrieve_from_online_store( config=self.config, table=table, requested_feature=requested_feature, + requested_features=requested_features, query=query, top_k=top_k, distance_metric=distance_metric, @@ -1952,19 +1984,13 @@ def serve_ui( ) def serve_registry( - self, - port: int, - tls_key_path: str = "", - tls_cert_path: str = "", + self, port: int, tls_key_path: str = "", tls_cert_path: str = "" ) -> None: """Start registry server locally on a given port.""" from feast import registry_server registry_server.start_server( - self, - port=port, - tls_key_path=tls_key_path, - tls_cert_path=tls_cert_path, + self, port=port, tls_key_path=tls_key_path, tls_cert_path=tls_cert_path ) def serve_offline( diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index 1f9ffeef14..18127896bd 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -1,5 +1,7 @@ import struct -from typing import List, Tuple +from typing import List, Tuple, Union + +from google.protobuf.internal.containers import RepeatedScalarFieldContainer from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto @@ -163,3 +165,16 @@ def get_list_val_str(val): if val.HasField(accept_type): return str(getattr(val, accept_type).val) return None + + +def serialize_f32( + vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int +) -> bytes: + """serializes a list of floats into a compact "raw bytes" format""" + return struct.pack(f"{vector_length}f", *vector) + + +def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]: + """deserializes a list of floats from a compact "raw bytes" format""" + num_floats = vector_length // 4 # 4 bytes per float + return list(struct.unpack(f"{num_floats}f", byte_vector)) diff --git a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py index 0152ca330c..af32814152 100644 --- a/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py +++ b/sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py @@ -213,7 +213,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], embedding: List[float], top_k: int, *args, diff --git a/sdk/python/feast/infra/online_stores/faiss_online_store.py b/sdk/python/feast/infra/online_stores/faiss_online_store.py index cc2e75800e..fd4d6768ab 100644 --- a/sdk/python/feast/infra/online_stores/faiss_online_store.py +++ b/sdk/python/feast/infra/online_stores/faiss_online_store.py @@ -176,7 +176,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_featres: Optional[List[str]], embedding: List[float], top_k: int, distance_metric: Optional[str] = None, diff --git a/sdk/python/feast/infra/online_stores/online_store.py b/sdk/python/feast/infra/online_stores/online_store.py index 789885f82b..be3128562d 100644 --- a/sdk/python/feast/infra/online_stores/online_store.py +++ b/sdk/python/feast/infra/online_stores/online_store.py @@ -390,7 +390,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], embedding: List[float], top_k: int, distance_metric: Optional[str] = None, @@ -411,6 +412,7 @@ def retrieve_online_documents( config: The config for the current feature store. table: The feature view whose feature values should be read. requested_feature: The name of the feature whose embeddings should be used for retrieval. + requested_features: The list of features whose embeddings should be used for retrieval. embedding: The embeddings to use for retrieval. top_k: The number of documents to retrieve. @@ -419,6 +421,10 @@ def retrieve_online_documents( where the first item is the event timestamp for the row, and the second item is a dict of feature name to embeddings. """ + if not requested_feature and not requested_features: + raise ValueError( + "Either requested_feature or requested_features must be specified" + ) raise NotImplementedError( f"Online store {self.__class__.__name__} does not support online retrieval" ) diff --git a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py index 7c099c80ec..f43247a545 100644 --- a/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py +++ b/sdk/python/feast/infra/online_stores/postgres_online_store/postgres.py @@ -347,7 +347,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], embedding: List[float], top_k: int, distance_metric: Optional[str] = "L2", @@ -366,6 +367,7 @@ def retrieve_online_documents( config: Feast configuration object table: FeatureView object as the table to search requested_feature: The requested feature as the column to search + requested_features: The list of features whose embeddings should be used for retrieval. embedding: The query embedding to search for top_k: The number of items to return distance_metric: The distance metric to use for the search.G diff --git a/sdk/python/feast/infra/online_stores/qdrant_online_store/qdrant.py b/sdk/python/feast/infra/online_stores/qdrant_online_store/qdrant.py index 074c52ba5e..cdbef95348 100644 --- a/sdk/python/feast/infra/online_stores/qdrant_online_store/qdrant.py +++ b/sdk/python/feast/infra/online_stores/qdrant_online_store/qdrant.py @@ -248,7 +248,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], embedding: List[float], top_k: int, distance_metric: Optional[str] = "cosine", diff --git a/sdk/python/feast/infra/online_stores/sqlite.py b/sdk/python/feast/infra/online_stores/sqlite.py index e2eeb038d0..23b4f6db3a 100644 --- a/sdk/python/feast/infra/online_stores/sqlite.py +++ b/sdk/python/feast/infra/online_stores/sqlite.py @@ -15,19 +15,20 @@ import logging import os import sqlite3 -import struct import sys from datetime import date, datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple -from google.protobuf.internal.containers import RepeatedScalarFieldContainer from pydantic import StrictStr from feast import Entity from feast.feature_view import FeatureView from feast.infra.infra_object import SQLITE_INFRA_OBJECT_CLASS_TYPE, InfraObject -from feast.infra.key_encoding_utils import serialize_entity_key +from feast.infra.key_encoding_utils import ( + serialize_entity_key, + serialize_f32, +) from feast.infra.online_stores.online_store import OnlineStore from feast.infra.online_stores.vector_store import VectorStoreConfig from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto @@ -330,7 +331,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_featuers: Optional[List[str]], embedding: List[float], top_k: int, distance_metric: Optional[str] = None, @@ -432,6 +434,7 @@ def retrieve_online_documents( _build_retrieve_online_document_record( entity_key, string_value if string_value else b"", + # This may be a bug embedding, distance, event_ts, @@ -459,19 +462,6 @@ def _table_id(project: str, table: FeatureView) -> str: return f"{project}_{table.name}" -def serialize_f32( - vector: Union[RepeatedScalarFieldContainer[float], List[float]], vector_length: int -) -> bytes: - """serializes a list of floats into a compact "raw bytes" format""" - return struct.pack(f"{vector_length}f", *vector) - - -def deserialize_f32(byte_vector: bytes, vector_length: int) -> List[float]: - """deserializes a list of floats from a compact "raw bytes" format""" - num_floats = vector_length // 4 # 4 bytes per float - return list(struct.unpack(f"{num_floats}f", byte_vector)) - - class SqliteTable(InfraObject): """ A Sqlite table managed by Feast. diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py index 215b175eb2..57aa122ae8 100644 --- a/sdk/python/feast/infra/passthrough_provider.py +++ b/sdk/python/feast/infra/passthrough_provider.py @@ -294,7 +294,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], query: List[float], top_k: int, distance_metric: Optional[str] = None, @@ -305,6 +306,7 @@ def retrieve_online_documents( config, table, requested_feature, + requested_features, query, top_k, distance_metric, diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 8351f389ad..efc806ba2f 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -419,7 +419,8 @@ def retrieve_online_documents( self, config: RepoConfig, table: FeatureView, - requested_feature: str, + requested_feature: Optional[str], + requested_features: Optional[List[str]], query: List[float], top_k: int, distance_metric: Optional[str] = None, @@ -440,6 +441,7 @@ def retrieve_online_documents( config: The config for the current feature store. table: The feature view whose embeddings should be searched. requested_feature: the requested document feature name. + requested_features: the requested document feature names. query: The query embedding to search for. top_k: The number of documents to return. diff --git a/sdk/python/feast/utils.py b/sdk/python/feast/utils.py index 51d4bf4f2c..cfc19e37ca 100644 --- a/sdk/python/feast/utils.py +++ b/sdk/python/feast/utils.py @@ -1192,6 +1192,10 @@ def _utc_now() -> datetime: return datetime.now(tz=timezone.utc) +def _serialize_vector_to_float_list(vector: List[float]) -> ValueProto: + return ValueProto(float_list_val=FloatListProto(val=vector)) + + def _build_retrieve_online_document_record( entity_key: Union[str, bytes], feature_value: Union[str, bytes],