From c491e5726d792f731f961b66fdf6c0b772165e86 Mon Sep 17 00:00:00 2001 From: "bdodla@expedia.com" <13788369+EXPEbdodla@users.noreply.github.com> Date: Thu, 13 Jun 2024 06:28:39 -0700 Subject: [PATCH] fix: Handles null values in data during GO Feature retrieval (#4274) * fix: Handles null values in data during GO Feature retrieval Signed-off-by: Bhargav Dodla * fix: Fixed formatting issues Signed-off-by: Bhargav Dodla * fix: Fixed linting issues Signed-off-by: Bhargav Dodla --------- Signed-off-by: Bhargav Dodla Co-authored-by: Bhargav Dodla --- go/types/typeconversion.go | 146 ++++++++++++++----------- go/types/typeconversion_test.go | 30 ++++- sdk/python/feast/type_map.py | 71 ++++++++---- sdk/python/tests/unit/test_type_map.py | 8 ++ 4 files changed, 166 insertions(+), 89 deletions(-) diff --git a/go/types/typeconversion.go b/go/types/typeconversion.go index 45eeac52c6..18b4769b4d 100644 --- a/go/types/typeconversion.go +++ b/go/types/typeconversion.go @@ -11,6 +11,9 @@ import ( ) func ProtoTypeToArrowType(sample *types.Value) (arrow.DataType, error) { + if sample.Val == nil { + return nil, nil + } switch sample.Val.(type) { case *types.Value_BytesVal: return arrow.BinaryTypes.Binary, nil @@ -91,81 +94,71 @@ func ValueTypeEnumToArrowType(t types.ValueType_Enum) (arrow.DataType, error) { } func CopyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) error { - switch fieldBuilder := builder.(type) { - case *array.BooleanBuilder: - for _, v := range values { - fieldBuilder.Append(v.GetBoolVal()) - } - case *array.BinaryBuilder: - for _, v := range values { - fieldBuilder.Append(v.GetBytesVal()) - } - case *array.StringBuilder: - for _, v := range values { - fieldBuilder.Append(v.GetStringVal()) - } - case *array.Int32Builder: - for _, v := range values { - fieldBuilder.Append(v.GetInt32Val()) - } - case *array.Int64Builder: - for _, v := range values { - fieldBuilder.Append(v.GetInt64Val()) - } - case *array.Float32Builder: - for _, v := range values { - fieldBuilder.Append(v.GetFloatVal()) + for _, value := range values { + if value == nil || value.Val == nil { + builder.AppendNull() + continue } - case *array.Float64Builder: - for _, v := range values { - fieldBuilder.Append(v.GetDoubleVal()) - } - case *array.TimestampBuilder: - for _, v := range values { - fieldBuilder.Append(arrow.Timestamp(v.GetUnixTimestampVal())) - } - case *array.ListBuilder: - for _, list := range values { + + switch fieldBuilder := builder.(type) { + + case *array.BooleanBuilder: + fieldBuilder.Append(value.GetBoolVal()) + case *array.BinaryBuilder: + fieldBuilder.Append(value.GetBytesVal()) + case *array.StringBuilder: + fieldBuilder.Append(value.GetStringVal()) + case *array.Int32Builder: + fieldBuilder.Append(value.GetInt32Val()) + case *array.Int64Builder: + fieldBuilder.Append(value.GetInt64Val()) + case *array.Float32Builder: + fieldBuilder.Append(value.GetFloatVal()) + case *array.Float64Builder: + fieldBuilder.Append(value.GetDoubleVal()) + case *array.TimestampBuilder: + fieldBuilder.Append(arrow.Timestamp(value.GetUnixTimestampVal())) + case *array.ListBuilder: fieldBuilder.Append(true) switch valueBuilder := fieldBuilder.ValueBuilder().(type) { case *array.BooleanBuilder: - for _, v := range list.GetBoolListVal().GetVal() { + for _, v := range value.GetBoolListVal().GetVal() { valueBuilder.Append(v) } case *array.BinaryBuilder: - for _, v := range list.GetBytesListVal().GetVal() { + for _, v := range value.GetBytesListVal().GetVal() { valueBuilder.Append(v) } case *array.StringBuilder: - for _, v := range list.GetStringListVal().GetVal() { + for _, v := range value.GetStringListVal().GetVal() { valueBuilder.Append(v) } case *array.Int32Builder: - for _, v := range list.GetInt32ListVal().GetVal() { + for _, v := range value.GetInt32ListVal().GetVal() { valueBuilder.Append(v) } case *array.Int64Builder: - for _, v := range list.GetInt64ListVal().GetVal() { + for _, v := range value.GetInt64ListVal().GetVal() { valueBuilder.Append(v) } case *array.Float32Builder: - for _, v := range list.GetFloatListVal().GetVal() { + for _, v := range value.GetFloatListVal().GetVal() { valueBuilder.Append(v) } case *array.Float64Builder: - for _, v := range list.GetDoubleListVal().GetVal() { + for _, v := range value.GetDoubleListVal().GetVal() { valueBuilder.Append(v) } case *array.TimestampBuilder: - for _, v := range list.GetUnixTimestampListVal().GetVal() { + for _, v := range value.GetUnixTimestampListVal().GetVal() { valueBuilder.Append(arrow.Timestamp(v)) } } + default: + return fmt.Errorf("unsupported array builder: %s", builder) } - default: - return fmt.Errorf("unsupported array builder: %s", builder) } return nil } @@ -249,41 +242,68 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) { switch arr.DataType() { case arrow.PrimitiveTypes.Int32: - for _, v := range arr.(*array.Int32).Int32Values() { - values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: v}}) + for idx := 0; idx < arr.Len(); idx++ { + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: arr.(*array.Int32).Value(idx)}}) + } } case arrow.PrimitiveTypes.Int64: - for _, v := range arr.(*array.Int64).Int64Values() { - values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: v}}) + for idx := 0; idx < arr.Len(); idx++ { + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: arr.(*array.Int64).Value(idx)}}) + } } case arrow.PrimitiveTypes.Float32: - for _, v := range arr.(*array.Float32).Float32Values() { - values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: v}}) + for idx := 0; idx < arr.Len(); idx++ { + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: arr.(*array.Float32).Value(idx)}}) + } } case arrow.PrimitiveTypes.Float64: - for _, v := range arr.(*array.Float64).Float64Values() { - values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: v}}) + for idx := 0; idx < arr.Len(); idx++ { + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: arr.(*array.Float64).Value(idx)}}) + } } case arrow.FixedWidthTypes.Boolean: for idx := 0; idx < arr.Len(); idx++ { - values = append(values, - &types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}}) + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}}) + } } case arrow.BinaryTypes.Binary: for idx := 0; idx < arr.Len(); idx++ { - values = append(values, - &types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}}) + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}}) + } } case arrow.BinaryTypes.String: for idx := 0; idx < arr.Len(); idx++ { - values = append(values, - &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}}) + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}}) + } } case arrow.FixedWidthTypes.Timestamp_s: for idx := 0; idx < arr.Len(); idx++ { - values = append(values, - &types.Value{Val: &types.Value_UnixTimestampVal{ - UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}}) + if arr.IsNull(idx) { + values = append(values, &types.Value{}) + } else { + values = append(values, &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}}) + } } case arrow.Null: for idx := 0; idx < arr.Len(); idx++ { @@ -306,7 +326,9 @@ func ProtoValuesToArrowArray(protoValues []*types.Value, arrowAllocator memory.A if err != nil { return nil, err } - break + if fieldType != nil { + break + } } } diff --git a/go/types/typeconversion_test.go b/go/types/typeconversion_test.go index 1f89593ea0..4869369c18 100644 --- a/go/types/typeconversion_test.go +++ b/go/types/typeconversion_test.go @@ -1,27 +1,46 @@ package types import ( + "math" "testing" "time" "github.com/apache/arrow/go/v8/arrow/memory" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/feast-dev/feast/go/protos/feast/types" ) +var nil_or_null_val = &types.Value{} + var ( PROTO_VALUES = [][]*types.Value{ + {{Val: nil}}, + {{Val: nil}, {Val: nil}}, + {nil_or_null_val, nil_or_null_val}, + {nil_or_null_val, {Val: nil}}, + {{Val: &types.Value_Int32Val{10}}, {Val: nil}, nil_or_null_val, {Val: &types.Value_Int32Val{20}}}, + {{Val: &types.Value_Int32Val{10}}, nil_or_null_val}, + {nil_or_null_val, {Val: &types.Value_Int32Val{20}}}, {{Val: &types.Value_Int32Val{10}}, {Val: &types.Value_Int32Val{20}}}, + {{Val: &types.Value_Int64Val{10}}, nil_or_null_val}, {{Val: &types.Value_Int64Val{10}}, {Val: &types.Value_Int64Val{20}}}, + {nil_or_null_val, {Val: &types.Value_FloatVal{2.0}}}, {{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}}, + {{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}, {Val: &types.Value_FloatVal{float32(math.NaN())}}}, {{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}}, + {{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}, {Val: &types.Value_DoubleVal{math.NaN()}}}, + {{Val: &types.Value_DoubleVal{1.0}}, nil_or_null_val}, + {nil_or_null_val, {Val: &types.Value_StringVal{"bbb"}}}, {{Val: &types.Value_StringVal{"aaa"}}, {Val: &types.Value_StringVal{"bbb"}}}, + {{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, nil_or_null_val}, {{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, {Val: &types.Value_BytesVal{[]byte{4, 5, 6}}}}, + {nil_or_null_val, {Val: &types.Value_BoolVal{false}}}, {{Val: &types.Value_BoolVal{true}}, {Val: &types.Value_BoolVal{false}}}, - {{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, - {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}}, + {{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, nil_or_null_val}, + {{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}}, + {{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, {Val: &types.Value_UnixTimestampVal{-9223372036854775808}}}, { {Val: &types.Value_Int32ListVal{&types.Int32List{Val: []int32{0, 1, 2}}}}, @@ -55,6 +74,11 @@ var ( {Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}}, {Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix()}}}}, }, + { + {Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}}, + {Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{time.Now().Unix(), time.Now().Unix()}}}}, + {Val: &types.Value_UnixTimestampListVal{&types.Int64List{Val: []int64{-9223372036854775808, time.Now().Unix()}}}}, + }, } ) diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index e7fdf97120..a0859f2f7a 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import logging from collections import defaultdict from datetime import datetime, timezone from typing import ( @@ -53,6 +54,8 @@ # null timestamps get converted to -9223372036854775808 NULL_TIMESTAMP_INT_VALUE: int = np.datetime64("NaT").astype(int) +logger = logging.getLogger(__name__) + def feast_value_type_to_python_type(field_value_proto: ProtoValue) -> Any: """ @@ -77,9 +80,11 @@ def feast_value_type_to_python_type(field_value_proto: ProtoValue) -> Any: # Convert UNIX_TIMESTAMP values to `datetime` if val_attr == "unix_timestamp_list_val": val = [ - datetime.fromtimestamp(v, tz=timezone.utc) - if v != NULL_TIMESTAMP_INT_VALUE - else None + ( + datetime.fromtimestamp(v, tz=timezone.utc) + if v != NULL_TIMESTAMP_INT_VALUE + else None + ) for v in val ] elif val_attr == "unix_timestamp_val": @@ -295,9 +300,11 @@ def _type_err(item, dtype): ValueType.INT32: ("int32_val", lambda x: int(x), None), ValueType.INT64: ( "int64_val", - lambda x: int(x.timestamp()) - if isinstance(x, pd._libs.tslibs.timestamps.Timestamp) - else int(x), + lambda x: ( + int(x.timestamp()) + if isinstance(x, pd._libs.tslibs.timestamps.Timestamp) + else int(x) + ), None, ), ValueType.FLOAT: ("float_val", lambda x: float(x), None), @@ -373,10 +380,18 @@ def _python_value_to_proto_value( if sample is not None and not all( type(item) in valid_types for item in sample ): - first_invalid = next( - item for item in sample if type(item) not in valid_types - ) - raise _type_err(first_invalid, valid_types[0]) + # to_numpy() in utils._convert_arrow_to_proto() upcasts values of type Array of INT32 or INT64 with NULL values to Float64 automatically. + for item in sample: + if type(item) not in valid_types: + if feast_value_type in [ + ValueType.INT32_LIST, + ValueType.INT64_LIST, + ]: + if not any(np.isnan(item) for item in sample): + logger.error( + "Array of Int32 or Int64 type has NULL values. to_numpy() upcasts to Float64 automatically." + ) + raise _type_err(item, valid_types[0]) if feast_value_type == ValueType.UNIX_TIMESTAMP_LIST: int_timestamps_lists = ( @@ -390,15 +405,21 @@ def _python_value_to_proto_value( if feast_value_type == ValueType.BOOL_LIST: # ProtoValue does not support conversion of np.bool_ so we need to convert it to support np.bool_. return [ - ProtoValue(**{field_name: proto_type(val=[bool(e) for e in value])}) # type: ignore - if value is not None - else ProtoValue() + ( + ProtoValue( + **{field_name: proto_type(val=[bool(e) for e in value])} # type: ignore + ) + if value is not None + else ProtoValue() + ) for value in values ] return [ - ProtoValue(**{field_name: proto_type(val=value)}) # type: ignore - if value is not None - else ProtoValue() + ( + ProtoValue(**{field_name: proto_type(val=value)}) # type: ignore + if value is not None + else ProtoValue() + ) for value in values ] @@ -433,15 +454,17 @@ def _python_value_to_proto_value( if feast_value_type == ValueType.BOOL: # ProtoValue does not support conversion of np.bool_ so we need to convert it to support np.bool_. return [ - ProtoValue( - **{ - field_name: func( - bool(value) if type(value) is np.bool_ else value # type: ignore - ) - } + ( + ProtoValue( + **{ + field_name: func( + bool(value) if type(value) is np.bool_ else value # type: ignore + ) + } + ) + if not pd.isnull(value) + else ProtoValue() ) - if not pd.isnull(value) - else ProtoValue() for value in values ] if feast_value_type in PYTHON_SCALAR_VALUE_TYPE_TO_PROTO_VALUE: diff --git a/sdk/python/tests/unit/test_type_map.py b/sdk/python/tests/unit/test_type_map.py index 87e5ef0548..39e3e7dafa 100644 --- a/sdk/python/tests/unit/test_type_map.py +++ b/sdk/python/tests/unit/test_type_map.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pytest from feast.type_map import ( @@ -79,3 +80,10 @@ def test_python_values_to_proto_values_bytes_to_list(values, value_type, expecte def test_python_values_to_proto_values_bytes_to_list_not_supported(): with pytest.raises(TypeError): _ = python_values_to_proto_values([b"[]"], ValueType.BYTES_LIST) + + +def test_python_values_to_proto_values_int_list_with_null_not_supported(): + df = pd.DataFrame({"column": [1, 2, None]}) + arr = df["column"].to_numpy() + with pytest.raises(TypeError): + _ = python_values_to_proto_values(arr, ValueType.INT32_LIST)