Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Python SDK so FeatureSet can import Schema from Tensorflow metadata #450

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions sdk/python/feast/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,26 @@ def to_proto(self) -> EntityProto:
Returns EntitySpec object
"""
value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name)
return EntityProto(name=self.name, value_type=value_type)
return EntityProto(
name=self.name,
value_type=value_type,
presence=self.presence,
group_presence=self.group_presence,
shape=self.shape,
value_count=self.value_count,
domain=self.domain,
int_domain=self.int_domain,
float_domain=self.float_domain,
string_domain=self.string_domain,
bool_domain=self.bool_domain,
struct_domain=self.struct_domain,
natural_language_domain=self.natural_language_domain,
image_domain=self.image_domain,
mid_domain=self.mid_domain,
url_domain=self.url_domain,
time_domain=self.time_domain,
time_of_day_domain=self.time_of_day_domain,
)

@classmethod
def from_proto(cls, entity_proto: EntityProto):
Expand All @@ -42,4 +61,8 @@ def from_proto(cls, entity_proto: EntityProto):
Returns:
Entity object
"""
return cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type))
entity = cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type))
entity.update_presence_constraints(entity_proto)
entity.update_shape_type(entity_proto)
entity.update_domain_info(entity_proto)
return entity
38 changes: 35 additions & 3 deletions sdk/python/feast/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,41 @@ class Feature(Field):
def to_proto(self) -> FeatureProto:
"""Converts Feature object to its Protocol Buffer representation"""
value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name)
return FeatureProto(name=self.name, value_type=value_type)
return FeatureProto(
name=self.name,
value_type=value_type,
presence=self.presence,
group_presence=self.group_presence,
shape=self.shape,
value_count=self.value_count,
domain=self.domain,
int_domain=self.int_domain,
float_domain=self.float_domain,
string_domain=self.string_domain,
bool_domain=self.bool_domain,
struct_domain=self.struct_domain,
natural_language_domain=self.natural_language_domain,
image_domain=self.image_domain,
mid_domain=self.mid_domain,
url_domain=self.url_domain,
time_domain=self.time_domain,
time_of_day_domain=self.time_of_day_domain,
)

@classmethod
def from_proto(cls, feature_proto: FeatureProto):
"""Converts Protobuf Feature to its SDK equivalent"""
return cls(name=feature_proto.name, dtype=ValueType(feature_proto.value_type))
"""

Args:
feature_proto: FeatureSpec protobuf object

Returns:
Feature object
"""
feature = cls(
name=feature_proto.name, dtype=ValueType(feature_proto.value_type)
)
feature.update_presence_constraints(feature_proto)
feature.update_shape_type(feature_proto)
feature.update_domain_info(feature_proto)
return feature
129 changes: 126 additions & 3 deletions sdk/python/feast/feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import warnings
from collections import OrderedDict
from typing import Dict, List, Optional
from typing import Dict
from typing import List, Optional

import pandas as pd
import pyarrow as pa
from google.protobuf import json_format
from google.protobuf.duration_pb2 import Duration
from google.protobuf.json_format import MessageToJson
from google.protobuf.message import Message
from pandas.api.types import is_datetime64_ns_dtype
from pyarrow.lib import TimestampType
from tensorflow_metadata.proto.v0 import schema_pb2

from feast.core.FeatureSet_pb2 import FeatureSet as FeatureSetProto
from feast.core.FeatureSet_pb2 import FeatureSetMeta as FeatureSetMetaProto
Expand Down Expand Up @@ -657,6 +659,93 @@ def is_valid(self):
if len(self.entities) == 0:
raise ValueError(f"No entities found in feature set {self.name}")

def import_tfx_schema(self, schema: schema_pb2.Schema):
"""
Updates presence_constraints, shape_type and domain_info for all fields
(features and entities) in the FeatureSet from schema in the Tensorflow metadata.

Args:
schema: Schema from Tensorflow metadata

Returns:
None

"""
_make_tfx_schema_domain_info_inline(schema)
for feature_from_tfx_schema in schema.feature:
if feature_from_tfx_schema.name in self._fields.keys():
field = self._fields[feature_from_tfx_schema.name]
field.update_presence_constraints(feature_from_tfx_schema)
field.update_shape_type(feature_from_tfx_schema)
field.update_domain_info(feature_from_tfx_schema)
else:
warnings.warn(
f"The provided schema contains feature name '{feature_from_tfx_schema.name}' "
f"that does not exist in the FeatureSet '{self.name}' in Feast"
)

def export_tfx_schema(self) -> schema_pb2.Schema:
"""
Create a Tensorflow metadata schema from a FeatureSet.

Returns:
Tensorflow metadata schema.

"""
schema = schema_pb2.Schema()

# List of attributes to copy from fields in the FeatureSet to feature in
# Tensorflow metadata schema where the attribute name is the same.
attributes_to_copy_from_field_to_feature = [
"name",
"presence",
"group_presence",
"shape",
"value_count",
"domain",
"int_domain",
"float_domain",
"string_domain",
"bool_domain",
"struct_domain",
"_natural_language_domain",
"image_domain",
"mid_domain",
"url_domain",
"time_domain",
"time_of_day_domain",
]

for _, field in self._fields.items():
feature = schema_pb2.Feature()
for attr in attributes_to_copy_from_field_to_feature:
if getattr(field, attr) is None:
# This corresponds to an unset member in the proto Oneof field.
continue
if issubclass(type(getattr(feature, attr)), Message):
# Proto message field to copy is an "embedded" field, so MergeFrom()
# method must be used.
getattr(feature, attr).MergeFrom(getattr(field, attr))
elif issubclass(type(getattr(feature, attr)), (int, str, bool)):
# Proto message field is a simple Python type, so setattr()
# can be used.
setattr(feature, attr, getattr(field, attr))
else:
warnings.warn(
f"Attribute '{attr}' cannot be copied from Field "
f"'{field.name}' in FeatureSet '{self.name}' to a "
f"Feature in the Tensorflow metadata schema, because"
f"the type is neither a Protobuf message or Python "
f"int, str and bool"
)
# "type" attr is handled separately because the attribute name is different
# ("dtype" in field and "type" in Feature) and "type" in Feature is only
# a subset of "dtype".
feature.type = field.dtype.to_tfx_schema_feature_type()
schema.feature.append(feature)

return schema

@classmethod
def from_yaml(cls, yml: str):
"""
Expand Down Expand Up @@ -855,6 +944,40 @@ def __hash__(self):
return hash(repr(self))


def _make_tfx_schema_domain_info_inline(schema: schema_pb2.Schema) -> None:
"""
Copy top level domain info defined at schema level into inline definition.
One use case is when importing domain info from Tensorflow metadata schema
into Feast features. Feast features do not have access to schema level information
so the domain info needs to be inline.

Args:
schema: Tensorflow metadata schema

Returns: None
"""
# Reference to domains defined at schema level
domain_ref_to_string_domain = {d.name: d for d in schema.string_domain}
domain_ref_to_float_domain = {d.name: d for d in schema.float_domain}
domain_ref_to_int_domain = {d.name: d for d in schema.int_domain}

# With the reference, it is safe to remove the domains defined at schema level
del schema.string_domain[:]
del schema.float_domain[:]
del schema.int_domain[:]

for feature in schema.feature:
domain_info_case = feature.WhichOneof("domain_info")
if domain_info_case == "domain":
domain_ref = feature.domain
if domain_ref in domain_ref_to_string_domain:
feature.string_domain.MergeFrom(domain_ref_to_string_domain[domain_ref])
elif domain_ref in domain_ref_to_float_domain:
feature.float_domain.MergeFrom(domain_ref_to_float_domain[domain_ref])
elif domain_ref in domain_ref_to_int_domain:
feature.int_domain.MergeFrom(domain_ref_to_int_domain[domain_ref])


def _infer_pd_column_type(column, series, rows_to_sample):
dtype = None
sample_count = 0
Expand Down
Loading