diff --git a/README.md b/README.md index c43e680f6..2f8851ba0 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,8 @@ We support and actively test against certain third-party clients to ensure compa |`SUBSTR(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.| |`SUBSTRING(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.| |`SUM(expr)`|Returns the sum of expr in all rows.| +|`TO_BASE64(str)`|Encodes the string str in base64 format.| +|`FROM_BASE64(str)`|Decodes the base64-encoded string str.| |`TRIM(str)`|Returns the string str with all spaces removed.| |`UPPER(str)`|Returns the string str with all characters in upper case.| |`WEEKDAY(date)`|Returns the weekday of the given date.| diff --git a/SUPPORTED.md b/SUPPORTED.md index a1fcd3f9b..1a5922042 100644 --- a/SUPPORTED.md +++ b/SUPPORTED.md @@ -110,6 +110,8 @@ - LOG2 - LOG10 - SLEEP +- TO_BASE64 +- FROM_BASE64 ## Time functions - DAY diff --git a/engine_test.go b/engine_test.go index c2f1e01bd..91ee7932a 100644 --- a/engine_test.go +++ b/engine_test.go @@ -881,6 +881,14 @@ var queries = []struct { "SELECT SLEEP(0.5)", []sql.Row{{int(0)}}, }, + { + "SELECT TO_BASE64('foo')", + []sql.Row{{string("Zm9v")}}, + }, + { + "SELECT FROM_BASE64('YmFy')", + []sql.Row{{string("bar")}}, + }, } func TestQueries(t *testing.T) { diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index b2e1732cd..850ed30d7 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -75,4 +75,6 @@ var Defaults = []sql.Function{ sql.Function2{Name: "nullif", Fn: NewNullIf}, sql.Function0{Name: "now", Fn: NewNow}, sql.Function1{Name: "sleep", Fn: NewSleep}, + sql.Function1{Name: "to_base64", Fn: NewToBase64}, + sql.Function1{Name: "from_base64", Fn: NewFromBase64}, } diff --git a/sql/expression/function/tobase64_frombase64.go b/sql/expression/function/tobase64_frombase64.go new file mode 100644 index 000000000..c6d484ef2 --- /dev/null +++ b/sql/expression/function/tobase64_frombase64.go @@ -0,0 +1,148 @@ +package function + +import ( + "encoding/base64" + "fmt" + "reflect" + "strings" + + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +// ToBase64 is a function to encode a string to the Base64 format +// using the same dialect that MySQL's TO_BASE64 uses +type ToBase64 struct { + expression.UnaryExpression +} + +// NewToBase64 creates a new ToBase64 expression. +func NewToBase64(e sql.Expression) sql.Expression { + return &ToBase64{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := t.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + encoded := base64.StdEncoding.EncodeToString([]byte(str.(string))) + + lenEncoded := len(encoded) + if lenEncoded <= 76 { + return encoded, nil + } + + // Split into max 76 chars lines + var out strings.Builder + start := 0 + end := 76 + for { + out.WriteString(encoded[start:end] + "\n") + start += 76 + end += 76 + if end >= lenEncoded { + out.WriteString(encoded[start:lenEncoded]) + break + } + } + + return out.String(), nil +} + +// String implements the Stringer interface. +func (t *ToBase64) String() string { + return fmt.Sprintf("TO_BASE64(%s)", t.Child) +} + +// IsNullable implements the Expression interface. +func (t *ToBase64) IsNullable() bool { + return t.Child.IsNullable() +} + +// TransformUp implements the Expression interface. +func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + child, err := t.Child.TransformUp(f) + if err != nil { + return nil, err + } + return f(NewToBase64(child)) +} + +// Type implements the Expression interface. +func (t *ToBase64) Type() sql.Type { + return sql.Text +} + + +// FromBase64 is a function to decode a Base64-formatted string +// using the same dialect that MySQL's FROM_BASE64 uses +type FromBase64 struct { + expression.UnaryExpression +} + +// NewFromBase64 creates a new FromBase64 expression. +func NewFromBase64(e sql.Expression) sql.Expression { + return &FromBase64{expression.UnaryExpression{Child: e}} +} + +// Eval implements the Expression interface. +func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := t.Child.Eval(ctx, row) + + if err != nil { + return nil, err + } + + if str == nil { + return nil, nil + } + + str, err = sql.Text.Convert(str) + if err != nil { + return nil, sql.ErrInvalidType.New(reflect.TypeOf(str)) + } + + decoded, err := base64.StdEncoding.DecodeString(str.(string)) + if err != nil { + return nil, err + } + + return string(decoded), nil +} + +// String implements the Stringer interface. +func (t *FromBase64) String() string { + return fmt.Sprintf("FROM_BASE64(%s)", t.Child) +} + +// IsNullable implements the Expression interface. +func (t *FromBase64) IsNullable() bool { + return t.Child.IsNullable() +} + +// TransformUp implements the Expression interface. +func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) { + child, err := t.Child.TransformUp(f) + if err != nil { + return nil, err + } + return f(NewFromBase64(child)) +} + +// Type implements the Expression interface. +func (t *FromBase64) Type() sql.Type { + return sql.Text +} diff --git a/sql/expression/function/tobase64_frombase64_test.go b/sql/expression/function/tobase64_frombase64_test.go new file mode 100644 index 000000000..f1dd461f2 --- /dev/null +++ b/sql/expression/function/tobase64_frombase64_test.go @@ -0,0 +1,56 @@ +package function + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-mysql-server.v0/sql" + "gopkg.in/src-d/go-mysql-server.v0/sql/expression" +) + +func TestBase64(t *testing.T) { + fTo := NewToBase64(expression.NewGetField(0, sql.Text, "", false)) + fFrom := NewFromBase64(expression.NewGetField(0, sql.Text, "", false)) + + testCases := []struct { + name string + row sql.Row + expected interface{} + err bool + }{ + // Use a MySQL server to get expected values if updating/adding to this! + {"null input", sql.NewRow(nil), nil, false}, + {"single_line", sql.NewRow("foo"), string("Zm9v"), false}, + {"multi_line", sql.NewRow( + "Gallia est omnis divisa in partes tres, quarum unam " + + "incolunt Belgae, aliam Aquitani, tertiam qui ipsorum lingua Celtae, " + + "nostra Galli appellantur"), + "R2FsbGlhIGVzdCBvbW5pcyBkaXZpc2EgaW4gcGFydGVzIHRyZXMsIHF1YXJ1bSB1bmFtIGluY29s\n" + + "dW50IEJlbGdhZSwgYWxpYW0gQXF1aXRhbmksIHRlcnRpYW0gcXVpIGlwc29ydW0gbGluZ3VhIENl\n" + + "bHRhZSwgbm9zdHJhIEdhbGxpIGFwcGVsbGFudHVy", false}, + {"empty_input", sql.NewRow(""), string(""), false}, + {"symbols", sql.NewRow("!@#$% %^&*()_+\r\n\t{};"), string("IUAjJCUgJV4mKigpXysNCgl7fTs="), + false}, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + require := require.New(t) + ctx := sql.NewEmptyContext() + v, err := fTo.Eval(ctx, tt.row) + + if tt.err { + require.Error(err) + } else { + require.NoError(err) + require.Equal(tt.expected, v) + + ctx = sql.NewEmptyContext() + v2, err := fFrom.Eval(ctx, sql.NewRow(v)) + require.NoError(err) + require.Equal(sql.NewRow(v2), tt.row) + } + }) + } +}