diff --git a/bigquery/params.go b/bigquery/params.go index 7d6fe362cc81..046b874301d4 100644 --- a/bigquery/params.go +++ b/bigquery/params.go @@ -17,6 +17,7 @@ package bigquery import ( "encoding/base64" "fmt" + "reflect" "time" bq "google.golang.org/api/bigquery/v2" @@ -34,23 +35,32 @@ var ( timestampParamType = &bq.QueryParameterType{Type: "TIMESTAMP"} ) -func paramType(x interface{}) (*bq.QueryParameterType, error) { - switch x.(type) { - case int, int8, int16, int32, int64, uint8, uint16, uint32: +var timeType = reflect.TypeOf(time.Time{}) + +func paramType(t reflect.Type) (*bq.QueryParameterType, error) { + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32: return int64ParamType, nil - case float32, float64: + case reflect.Float32, reflect.Float64: return float64ParamType, nil - case bool: + case reflect.Bool: return boolParamType, nil - case string: + case reflect.String: return stringParamType, nil - case time.Time: + case reflect.Slice, reflect.Array: + if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { + return bytesParamType, nil + } + et, err := paramType(t.Elem()) + if err != nil { + return nil, err + } + return &bq.QueryParameterType{Type: "ARRAY", ArrayType: et}, nil + } + if t == timeType { return timestampParamType, nil - case []byte: - return bytesParamType, nil - default: - return nil, fmt.Errorf("Go type %T cannot be represented as a parameter type", x) } + return nil, fmt.Errorf("Go type %s cannot be represented as a parameter type", t) } func paramValue(x interface{}) (bq.QueryParameterValue, error) { @@ -63,7 +73,20 @@ func paramValue(x interface{}) (bq.QueryParameterValue, error) { return sval(base64.StdEncoding.EncodeToString(x)), nil case time.Time: return sval(x.Format(timestampFormat)), nil - default: - return sval(fmt.Sprint(x)), nil } + t := reflect.TypeOf(x) + switch t.Kind() { + case reflect.Slice, reflect.Array: + var vals []*bq.QueryParameterValue + v := reflect.ValueOf(x) + for i := 0; i < v.Len(); i++ { + val, err := paramValue(v.Index(i).Interface()) + if err != nil { + return bq.QueryParameterValue{}, err + } + vals = append(vals, &val) + } + return bq.QueryParameterValue{ArrayValues: vals}, nil + } + return sval(fmt.Sprint(x)), nil } diff --git a/bigquery/params_test.go b/bigquery/params_test.go index bd545bd995d9..97c75cc23d7b 100644 --- a/bigquery/params_test.go +++ b/bigquery/params_test.go @@ -15,8 +15,8 @@ package bigquery import ( - "bytes" "context" + "errors" "math" "reflect" "testing" @@ -60,7 +60,31 @@ func TestParamValueScalar(t *testing.T) { } } -func TestParamTypeScalar(t *testing.T) { +func TestParamValueArray(t *testing.T) { + for _, test := range []struct { + val interface{} + want []string + }{ + {[]int(nil), []string{}}, + {[]int{}, []string{}}, + {[]int{1, 2}, []string{"1", "2"}}, + {[3]int{1, 2, 3}, []string{"1", "2", "3"}}, + } { + got, err := paramValue(test.val) + if err != nil { + t.Fatal(err) + } + var want bq.QueryParameterValue + for _, s := range test.want { + want.ArrayValues = append(want.ArrayValues, &bq.QueryParameterValue{Value: s}) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("%#v:\ngot %+v\nwant %+v", test.val, got, want) + } + } +} + +func TestParamType(t *testing.T) { for _, test := range []struct { val interface{} want *bq.QueryParameterType @@ -75,42 +99,71 @@ func TestParamTypeScalar(t *testing.T) { {"string", stringParamType}, {time.Now(), timestampParamType}, {[]byte("foo"), bytesParamType}, + {[]int{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: int64ParamType}}, + {[3]bool{}, &bq.QueryParameterType{Type: "ARRAY", ArrayType: boolParamType}}, } { - got, err := paramType(test.val) + got, err := paramType(reflect.TypeOf(test.val)) if err != nil { t.Fatal(err) } - if got != test.want { + if !reflect.DeepEqual(got, test.want) { t.Errorf("%v (%T): got %v, want %v", test.val, test.val, got, test.want) } } } func TestIntegration_ScalarParam(t *testing.T) { - ctx := context.Background() c := getClient(t) for _, test := range scalarTests { - q := c.Query("select ?") - q.Parameters = []QueryParameter{{Value: test.val}} - it, err := q.Read(ctx) + got, err := paramRoundTrip(c, test.val) if err != nil { t.Fatal(err) } - var val []Value - err = it.Next(&val) + if !equal(got, test.val) { + t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val) + } + } +} + +func TestIntegration_ArrayParam(t *testing.T) { + c := getClient(t) + for _, test := range []struct { + val interface{} + want interface{} + }{ + {[]int(nil), []Value(nil)}, + {[]int{}, []Value(nil)}, + {[]int{1, 2}, []Value{int64(1), int64(2)}}, + {[3]int{1, 2, 3}, []Value{int64(1), int64(2), int64(3)}}, + } { + got, err := paramRoundTrip(c, test.val) if err != nil { t.Fatal(err) } - if len(val) != 1 { - t.Fatalf("got %d values, want 1", len(val)) - } - got := val[0] - if !equal(got, test.val) { - t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.val, test.val) + if !equal(got, test.want) { + t.Errorf("\ngot %#v (%T)\nwant %#v (%T)", got, got, test.want, test.want) } } } +func paramRoundTrip(c *Client, x interface{}) (Value, error) { + q := c.Query("select ?") + q.Parameters = []QueryParameter{{Value: x}} + it, err := q.Read(context.Background()) + if err != nil { + return nil, err + } + var val []Value + err = it.Next(&val) + if err != nil { + return nil, err + } + if len(val) != 1 { + return nil, errors.New("wrong number of values") + } + return val[0], nil +} + func equal(x1, x2 interface{}) bool { if reflect.TypeOf(x1) != reflect.TypeOf(x2) { return false @@ -124,9 +177,7 @@ func equal(x1, x2 interface{}) bool { case time.Time: // BigQuery is only accurate to the microsecond. return x1.Round(time.Microsecond).Equal(x2.(time.Time).Round(time.Microsecond)) - case []byte: - return bytes.Equal(x1, x2.([]byte)) default: - return x1 == x2 + return reflect.DeepEqual(x1, x2) } } diff --git a/bigquery/query.go b/bigquery/query.go index 8ad24442f0b4..af40d239adeb 100644 --- a/bigquery/query.go +++ b/bigquery/query.go @@ -15,6 +15,8 @@ package bigquery import ( + "reflect" + "golang.org/x/net/context" bq "google.golang.org/api/bigquery/v2" ) @@ -209,7 +211,7 @@ func (q *QueryConfig) populateJobQueryConfig(conf *bq.JobConfigurationQuery) err if err != nil { return err } - pt, err := paramType(p.Value) + pt, err := paramType(reflect.TypeOf(p.Value)) if err != nil { return err }