diff --git a/sdk/python/feast/entity.py b/sdk/python/feast/entity.py
index e1c6f601f2..832b9e4db8 100644
--- a/sdk/python/feast/entity.py
+++ b/sdk/python/feast/entity.py
@@ -36,7 +36,7 @@ class Entity:
     def __init__(
         self,
         name: str,
-        value_type: ValueType,
+        value_type: ValueType = ValueType.UNKNOWN,
         description: str = "",
         join_key: Optional[str] = None,
         labels: Optional[MutableMapping[str, str]] = None,
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 49b2afed74..8198246197 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -26,6 +26,7 @@
 from feast.entity import Entity
 from feast.errors import FeastProviderLoginError, FeatureViewNotFoundException
 from feast.feature_view import FeatureView
+from feast.inference import infer_entity_value_type_from_feature_views
 from feast.infra.provider import Provider, RetrievalJob, get_provider
 from feast.online_response import OnlineResponse, _infer_online_entity_rows
 from feast.protos.feast.serving.ServingService_pb2 import (
@@ -219,19 +220,19 @@ def apply(
             objects = [objects]
         assert isinstance(objects, list)
 
-        views_to_update = []
-        entities_to_update = []
-        for ob in objects:
-            if isinstance(ob, FeatureView):
-                self._registry.apply_feature_view(ob, project=self.project)
-                views_to_update.append(ob)
-            elif isinstance(ob, Entity):
-                self._registry.apply_entity(ob, project=self.project)
-                entities_to_update.append(ob)
-            else:
-                raise ValueError(
-                    f"Unknown object type ({type(ob)}) provided as part of apply() call"
-                )
+        views_to_update = [ob for ob in objects if isinstance(ob, FeatureView)]
+        entities_to_update = infer_entity_value_type_from_feature_views(
+            [ob for ob in objects if isinstance(ob, Entity)], views_to_update
+        )
+
+        if len(views_to_update) + len(entities_to_update) != len(objects):
+            raise ValueError("Unknown object type provided as part of apply() call")
+
+        for view in views_to_update:
+            self._registry.apply_feature_view(view, project=self.project)
+        for ent in entities_to_update:
+            self._registry.apply_entity(ent, project=self.project)
+
         self._get_provider().update_infra(
             project=self.project,
             tables_to_delete=[],
diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py
new file mode 100644
index 0000000000..54105a9bc2
--- /dev/null
+++ b/sdk/python/feast/inference.py
@@ -0,0 +1,56 @@
+from typing import List
+
+from feast import Entity
+from feast.feature_view import FeatureView
+from feast.value_type import ValueType
+
+
+def infer_entity_value_type_from_feature_views(
+    entities: List[Entity], feature_views: List[FeatureView]
+) -> List[Entity]:
+    """
+    Infer entity value type by examining schema of feature view input sources
+    """
+    incomplete_entities = {
+        entity.name: entity
+        for entity in entities
+        if entity.value_type == ValueType.UNKNOWN
+    }
+    incomplete_entities_keys = incomplete_entities.keys()
+
+    for view in feature_views:
+        if not (incomplete_entities_keys & set(view.entities)):
+            continue  # skip if view doesn't contain any entities that need inference
+
+        col_names_and_types = view.input.get_table_column_names_and_types()
+        for entity_name in view.entities:
+            if entity_name in incomplete_entities:
+                # get entity information from information extracted from the view input source
+                extracted_entity_name_type_pairs = list(
+                    filter(lambda tup: tup[0] == entity_name, col_names_and_types)
+                )
+                if len(extracted_entity_name_type_pairs) == 0:
+                    # Doesn't mention inference error because would also be an error without inferencing
+                    raise ValueError(
+                        f"""No column in the input source for the {view.name} feature view matches
+                        its entity's name."""
+                    )
+
+                entity = incomplete_entities[entity_name]
+                inferred_value_type = view.input.source_datatype_to_feast_value_type()(
+                    extracted_entity_name_type_pairs[0][1]
+                )
+
+                if (
+                    entity.value_type != ValueType.UNKNOWN
+                    and entity.value_type != inferred_value_type
+                ) or (len(extracted_entity_name_type_pairs) > 1):
+                    raise ValueError(
+                        f"""Entity value_type inference failed for {entity_name} entity.
+                        Multiple viable matches. Please explicitly specify the entity value_type
+                        for this entity."""
+                    )
+
+                entity.value_type = inferred_value_type
+
+    return entities
diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py
index d49aaaca89..a16a789f5f 100644
--- a/sdk/python/feast/repo_operations.py
+++ b/sdk/python/feast/repo_operations.py
@@ -13,6 +13,7 @@
 
 from feast import Entity, FeatureTable
 from feast.feature_view import FeatureView
+from feast.inference import infer_entity_value_type_from_feature_views
 from feast.infra.offline_stores.helpers import assert_offline_store_supports_data_source
 from feast.infra.provider import get_provider
 from feast.names import adjectives, animals
@@ -129,6 +130,13 @@ def apply_total(repo_config: RepoConfig, repo_path: Path):
     registry._initialize_registry()
     sys.dont_write_bytecode = True
     repo = parse_repo(repo_path)
+    repo = ParsedRepo(
+        feature_tables=repo.feature_tables,
+        entities=infer_entity_value_type_from_feature_views(
+            repo.entities, repo.feature_views
+        ),
+        feature_views=repo.feature_views,
+    )
     sys.dont_write_bytecode = False
     for entity in repo.entities:
         registry.apply_entity(entity, project=project)
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 1c038a9b78..0c94f4d57a 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import multiprocessing
+from datetime import datetime, timedelta
 from sys import platform
 
+import pandas as pd
 import pytest
 
 
@@ -45,3 +47,43 @@ def pytest_collection_modifyitems(config, items):
     for item in items:
         if "integration" in item.keywords:
             item.add_marker(skip_integration)
+
+
+@pytest.fixture
+def simple_dataset_1() -> pd.DataFrame:
+    now = datetime.utcnow()
+    ts = pd.Timestamp(now).round("ms")
+    data = {
+        "id": [1, 2, 1, 3, 3],
+        "float_col": [0.1, 0.2, 0.3, 4, 5],
+        "int64_col": [1, 2, 3, 4, 5],
+        "string_col": ["a", "b", "c", "d", "e"],
+        "ts_1": [
+            ts,
+            ts - timedelta(hours=4),
+            ts - timedelta(hours=3),
+            ts - timedelta(hours=2),
+            ts - timedelta(hours=1),
+        ],
+    }
+    return pd.DataFrame.from_dict(data)
+
+
+@pytest.fixture
+def simple_dataset_2() -> pd.DataFrame:
+    now = datetime.utcnow()
+    ts = pd.Timestamp(now).round("ms")
+    data = {
+        "id": ["a", "b", "c", "d", "e"],
+        "float_col": [0.1, 0.2, 0.3, 4, 5],
+        "int64_col": [1, 2, 3, 4, 5],
+        "string_col": ["a", "b", "c", "d", "e"],
+        "ts_1": [
+            ts,
+            ts - timedelta(hours=4),
+            ts - timedelta(hours=3),
+            ts - timedelta(hours=2),
+            ts - timedelta(hours=1),
+        ],
+    }
+    return pd.DataFrame.from_dict(data)
diff --git a/sdk/python/tests/example_feature_repo_with_inference.py b/sdk/python/tests/example_feature_repo_with_inference.py
index e0ccab5015..a427f0cea4 100644
--- a/sdk/python/tests/example_feature_repo_with_inference.py
+++ b/sdk/python/tests/example_feature_repo_with_inference.py
@@ -1,6 +1,6 @@
 from google.protobuf.duration_pb2 import Duration
 
-from feast import Entity, FeatureView, ValueType
+from feast import Entity, FeatureView
 from feast.data_source import FileSource
 
 driver_hourly_stats = FileSource(
@@ -8,7 +8,7 @@
     created_timestamp_column="created",
 )
 
-driver = Entity(name="driver_id", value_type=ValueType.INT64, description="driver id",)
+driver = Entity(name="driver_id", description="driver id",)
 
 # features are inferred from columns of data source
 driver_hourly_stats_view = FeatureView(
diff --git a/sdk/python/tests/test_feature_store.py b/sdk/python/tests/test_feature_store.py
index 7486fc4ec9..7b5363e636 100644
--- a/sdk/python/tests/test_feature_store.py
+++ b/sdk/python/tests/test_feature_store.py
@@ -16,13 +16,12 @@
 from tempfile import mkstemp
 
 import pytest
-from fixtures.data_source_fixtures import simple_dataset_1  # noqa: F401
-from fixtures.data_source_fixtures import (
+from pytest_lazyfixture import lazy_fixture
+from utils.data_source_utils import (
     prep_file_source,
     simple_bq_source_using_query_arg,
     simple_bq_source_using_table_ref_arg,
 )
-from pytest_lazyfixture import lazy_fixture
 
 from feast.data_format import ParquetFormat
 from feast.data_source import FileSource
@@ -315,22 +314,6 @@ def test_apply_feature_view_integration(test_feature_store):
     assert len(feature_views) == 0
 
 
-@pytest.mark.integration
-@pytest.mark.parametrize("dataframe_source", [lazy_fixture("simple_dataset_1")])
-def test_data_source_ts_col_inference_success(dataframe_source):
-    with prep_file_source(df=dataframe_source) as file_source:
-        actual_file_source = file_source.event_timestamp_column
-        actual_bq_1 = simple_bq_source_using_table_ref_arg(
-            dataframe_source
-        ).event_timestamp_column
-        actual_bq_2 = simple_bq_source_using_query_arg(
-            dataframe_source
-        ).event_timestamp_column
-        expected = "ts_1"
-
-        assert expected == actual_file_source == actual_bq_1 == actual_bq_2
-
-
 @pytest.mark.parametrize(
     "test_feature_store", [lazy_fixture("feature_store_with_local_registry")],
 )
diff --git a/sdk/python/tests/test_inference.py b/sdk/python/tests/test_inference.py
new file mode 100644
index 0000000000..886aca8ab2
--- /dev/null
+++ b/sdk/python/tests/test_inference.py
@@ -0,0 +1,49 @@
+import pytest
+from utils.data_source_utils import (
+    prep_file_source,
+    simple_bq_source_using_query_arg,
+    simple_bq_source_using_table_ref_arg,
+)
+
+from feast import Entity, ValueType
+from feast.feature_view import FeatureView
+from feast.inference import infer_entity_value_type_from_feature_views
+
+
+@pytest.mark.integration
+def test_data_source_ts_col_inference_success(simple_dataset_1):
+    with prep_file_source(df=simple_dataset_1) as file_source:
+        actual_file_source = file_source.event_timestamp_column
+        actual_bq_1 = simple_bq_source_using_table_ref_arg(
+            simple_dataset_1
+        ).event_timestamp_column
+        actual_bq_2 = simple_bq_source_using_query_arg(
+            simple_dataset_1
+        ).event_timestamp_column
+        expected = "ts_1"
+
+        assert expected == actual_file_source == actual_bq_1 == actual_bq_2
+
+
+def test_infer_entity_value_type_from_feature_views(simple_dataset_1, simple_dataset_2):
+    with prep_file_source(
+        df=simple_dataset_1, event_timestamp_column="ts_1"
+    ) as file_source, prep_file_source(
+        df=simple_dataset_2, event_timestamp_column="ts_1"
+    ) as file_source_2:
+
+        fv1 = FeatureView(name="fv1", entities=["id"], input=file_source, ttl=None,)
+        fv2 = FeatureView(name="fv2", entities=["id"], input=file_source_2, ttl=None,)
+
+        actual_1 = infer_entity_value_type_from_feature_views(
+            [Entity(name="id")], [fv1]
+        )
+        actual_2 = infer_entity_value_type_from_feature_views(
+            [Entity(name="id")], [fv2]
+        )
+        assert actual_1 == [Entity(name="id", value_type=ValueType.INT64)]
+        assert actual_2 == [Entity(name="id", value_type=ValueType.STRING)]
+
+        with pytest.raises(ValueError):
+            # two viable data types
+            infer_entity_value_type_from_feature_views([Entity(name="id")], [fv1, fv2])
diff --git a/sdk/python/tests/fixtures/data_source_fixtures.py b/sdk/python/tests/utils/data_source_utils.py
similarity index 74%
rename from sdk/python/tests/fixtures/data_source_fixtures.py
rename to sdk/python/tests/utils/data_source_utils.py
index 587457b49b..0aec0c6f1a 100644
--- a/sdk/python/tests/fixtures/data_source_fixtures.py
+++ b/sdk/python/tests/utils/data_source_utils.py
@@ -1,35 +1,12 @@
 import contextlib
 import tempfile
-from datetime import datetime, timedelta
 
-import pandas as pd
-import pytest
 from google.cloud import bigquery
 
 from feast.data_format import ParquetFormat
 from feast.data_source import BigQuerySource, FileSource
 
 
-@pytest.fixture
-def simple_dataset_1() -> pd.DataFrame:
-    now = datetime.utcnow()
-    ts = pd.Timestamp(now).round("ms")
-    data = {
-        "id": [1, 2, 1, 3, 3],
-        "float_col": [0.1, 0.2, 0.3, 4, 5],
-        "int64_col": [1, 2, 3, 4, 5],
-        "string_col": ["a", "b", "c", "d", "e"],
-        "ts_1": [
-            ts,
-            ts - timedelta(hours=4),
-            ts - timedelta(hours=3),
-            ts - timedelta(hours=2),
-            ts - timedelta(hours=1),
-        ],
-    }
-    return pd.DataFrame.from_dict(data)
-
-
 @contextlib.contextmanager
 def prep_file_source(df, event_timestamp_column="") -> FileSource:
     with tempfile.NamedTemporaryFile(suffix=".parquet") as f:
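
Usage sketch for reviewers: how the entity value-type inference introduced in this diff is exercised end to end. Only Entity, FeatureView, and infer_entity_value_type_from_feature_views mirror code in this PR; the FileSource arguments and the parquet path below are illustrative assumptions.

# Sketch only; the FileSource path/columns are hypothetical, not taken from this PR.
from feast import Entity, FeatureView
from feast.data_source import FileSource
from feast.inference import infer_entity_value_type_from_feature_views

# A file source whose schema contains a "driver_id" column.
driver_stats = FileSource(
    path="driver_stats.parquet",  # hypothetical local parquet file
    event_timestamp_column="event_timestamp",
)

# No value_type given: it now defaults to ValueType.UNKNOWN and is inferred later.
driver = Entity(name="driver_id", description="driver id")

driver_stats_view = FeatureView(
    name="driver_hourly_stats", entities=["driver_id"], input=driver_stats, ttl=None,
)

# FeatureStore.apply() and `feast apply` run this step internally; it can also be
# called directly and returns the same Entity objects with value_type filled in.
[driver] = infer_entity_value_type_from_feature_views([driver], [driver_stats_view])
print(driver.value_type)  # e.g. ValueType.INT64 when the driver_id column is int64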