-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
zhilingc
committed
Oct 17, 2019
1 parent
dfaa13b
commit 0a7e767
Showing
8 changed files
with
167 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,49 @@ | ||
package feast | ||
|
||
import ( | ||
"fmt" | ||
"github.com/gojek/feast/sdk/go/protos/feast/serving" | ||
) | ||
|
||
var ( | ||
ErrLengthMismatch = "Length mismatch; number of na values (%d) not equal to number of features requested (%d)." | ||
ErrFeatureNotFound = "Feature %s not found in response." | ||
) | ||
|
||
// OnlineFeaturesResponse is a wrapper around serving.GetOnlineFeaturesResponse. | ||
type OnlineFeaturesResponse struct { | ||
Features []string | ||
RawResponse *serving.GetOnlineFeaturesResponse | ||
} | ||
|
||
func (res OnlineFeaturesResponse) ToInt64Array(missingVals map[string]int64) { | ||
// convert the values here | ||
//for _, fields := range res.FieldOrder { | ||
// | ||
//} | ||
} | ||
// Rows retrieves the result of the request as a list of Rows. | ||
func (r OnlineFeaturesResponse) Rows() []Row { | ||
rows := make([]Row, len(r.RawResponse.FieldValues)) | ||
for i, val := range r.RawResponse.FieldValues { | ||
rows[i] = val.Fields | ||
} | ||
return rows | ||
} | ||
|
||
// Int64Arrays retrieves the result of the request as a list of int64 slices. Any missing values will be filled | ||
// with the missing values provided. | ||
func (r OnlineFeaturesResponse) Int64Arrays(order []string, fillNa []int64) ([][]int64, error) { | ||
rows := make([][]int64, len(r.RawResponse.FieldValues)) | ||
if len(fillNa) != len(order) { | ||
return nil, fmt.Errorf(ErrLengthMismatch, len(fillNa), len(order)) | ||
} | ||
for i, val := range r.RawResponse.FieldValues { | ||
rows[i] = make([]int64, len(order)) | ||
for j, fname := range order { | ||
fValue, exists := val.Fields[fname] | ||
if !exists { | ||
return nil, fmt.Errorf(ErrFeatureNotFound, fname) | ||
} | ||
if fValue.GetVal() == nil { | ||
rows[i][j] = fillNa[j] | ||
} else { | ||
rows[i][j] = fValue.GetInt64Val() | ||
} | ||
} | ||
} | ||
return rows, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package feast | ||
|
||
import ( | ||
"fmt" | ||
"github.com/gojek/feast/sdk/go/protos/feast/serving" | ||
"github.com/gojek/feast/sdk/go/protos/feast/types" | ||
"github.com/google/go-cmp/cmp" | ||
"testing" | ||
) | ||
|
||
var response = OnlineFeaturesResponse{ | ||
RawResponse: &serving.GetOnlineFeaturesResponse{ | ||
FieldValues: []*serving.GetOnlineFeaturesResponse_FieldValues{ | ||
{ | ||
Fields: map[string]*types.Value{ | ||
"fs:1:feature1": Int64Val(1), | ||
"fs:1:feature2": &types.Value{}, | ||
}, | ||
}, | ||
{ | ||
Fields: map[string]*types.Value{ | ||
"fs:1:feature1": Int64Val(2), | ||
"fs:1:feature2": Int64Val(2), | ||
}, | ||
}, | ||
}, | ||
}, | ||
} | ||
|
||
func TestOnlineFeaturesResponseToRow(t *testing.T) { | ||
actual := response.Rows() | ||
expected := []Row{ | ||
{"fs:1:feature1": Int64Val(1), "fs:1:feature2": &types.Value{}}, | ||
{"fs:1:feature1": Int64Val(2), "fs:1:feature2": Int64Val(2)}, | ||
} | ||
if !cmp.Equal(actual, expected) { | ||
t.Errorf("expected: %v, got: %v", expected, actual) | ||
} | ||
} | ||
|
||
func TestOnlineFeaturesResponseToInt64Array(t *testing.T) { | ||
type args struct { | ||
order []string | ||
fillNa []int64 | ||
} | ||
tt := []struct { | ||
name string | ||
args args | ||
want [][]int64 | ||
wantErr bool | ||
err error | ||
}{ | ||
{ | ||
name: "valid", | ||
args: args{ | ||
order: []string{"fs:1:feature2", "fs:1:feature1"}, | ||
fillNa: []int64{-1, -1}, | ||
}, | ||
want: [][]int64{{-1, 1}, {2, 2}}, | ||
wantErr: false, | ||
}, | ||
{ | ||
name: "length mismatch", | ||
args: args{ | ||
order: []string{"fs:1:feature2", "fs:1:feature1"}, | ||
fillNa: []int64{-1}, | ||
}, | ||
want: nil, | ||
wantErr: true, | ||
err: fmt.Errorf(ErrLengthMismatch, 1, 2), | ||
}, | ||
{ | ||
name: "length mismatch", | ||
args: args{ | ||
order: []string{"fs:1:feature2", "fs:1:feature3"}, | ||
fillNa: []int64{-1, -1}, | ||
}, | ||
want: nil, | ||
wantErr: true, | ||
err: fmt.Errorf(ErrFeatureNotFound, "fs:1:feature3"), | ||
}, | ||
} | ||
for _, tc := range tt { | ||
t.Run(tc.name, func(t *testing.T) { | ||
got, err := response.Int64Arrays(tc.args.order, tc.args.fillNa) | ||
if (err != nil) != tc.wantErr { | ||
t.Errorf("error = %v, wantErr %v", err, tc.wantErr) | ||
return | ||
} | ||
if tc.wantErr && err.Error() != tc.err.Error() { | ||
t.Errorf("error = %v, expected err = %v", err, tc.err) | ||
return | ||
} | ||
if !cmp.Equal(got, tc.want) { | ||
t.Errorf("got: \n%v\nwant:\n%v", got, tc.want) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters