Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Added from_base64 and to_base64 #650

Merged
merged 3 commits into from
Apr 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ We support and actively test against certain third-party clients to ensure compa
|`SUBSTR(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
|`SUBSTRING(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
|`SUM(expr)`|Returns the sum of expr in all rows.|
|`TO_BASE64(str)`|Encodes the string str in base64 format.|
|`FROM_BASE64(str)`|Decodes the base64-encoded string str.|
|`TRIM(str)`|Returns the string str with all spaces removed.|
|`UPPER(str)`|Returns the string str with all characters in upper case.|
|`WEEKDAY(date)`|Returns the weekday of the given date.|
Expand Down
2 changes: 2 additions & 0 deletions SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
- LOG2
- LOG10
- SLEEP
- TO_BASE64
- FROM_BASE64

## Time functions
- DAY
Expand Down
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,14 @@ var queries = []struct {
"SELECT SLEEP(0.5)",
[]sql.Row{{int(0)}},
},
{
"SELECT TO_BASE64('foo')",
[]sql.Row{{string("Zm9v")}},
},
{
"SELECT FROM_BASE64('YmFy')",
[]sql.Row{{string("bar")}},
},
}

func TestQueries(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,6 @@ var Defaults = []sql.Function{
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function0{Name: "now", Fn: NewNow},
sql.Function1{Name: "sleep", Fn: NewSleep},
sql.Function1{Name: "to_base64", Fn: NewToBase64},
sql.Function1{Name: "from_base64", Fn: NewFromBase64},
}
148 changes: 148 additions & 0 deletions sql/expression/function/tobase64_frombase64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package function

import (
"encoding/base64"
"fmt"
"reflect"
"strings"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

// ToBase64 is a function to encode a string to the Base64 format
// using the same dialect that MySQL's TO_BASE64 uses
type ToBase64 struct {
expression.UnaryExpression
}

// NewToBase64 creates a new ToBase64 expression.
func NewToBase64(e sql.Expression) sql.Expression {
return &ToBase64{expression.UnaryExpression{Child: e}}
}

// Eval implements the Expression interface.
func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
str, err := t.Child.Eval(ctx, row)

if err != nil {
return nil, err
}

if str == nil {
return nil, nil
}

str, err = sql.Text.Convert(str)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
}

encoded := base64.StdEncoding.EncodeToString([]byte(str.(string)))

lenEncoded := len(encoded)
if lenEncoded <= 76 {
return encoded, nil
}

// Split into max 76 chars lines
var out strings.Builder
start := 0
end := 76
for {
out.WriteString(encoded[start:end] + "\n")
start += 76
end += 76
if end >= lenEncoded {
out.WriteString(encoded[start:lenEncoded])
break
}
}

return out.String(), nil
}

// String implements the Stringer interface.
func (t *ToBase64) String() string {
return fmt.Sprintf("TO_BASE64(%s)", t.Child)
}

// IsNullable implements the Expression interface.
func (t *ToBase64) IsNullable() bool {
return t.Child.IsNullable()
}

// TransformUp implements the Expression interface.
func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
child, err := t.Child.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewToBase64(child))
}

// Type implements the Expression interface.
func (t *ToBase64) Type() sql.Type {
return sql.Text
}


// FromBase64 is a function to decode a Base64-formatted string
// using the same dialect that MySQL's FROM_BASE64 uses
type FromBase64 struct {
expression.UnaryExpression
}

// NewFromBase64 creates a new FromBase64 expression.
func NewFromBase64(e sql.Expression) sql.Expression {
return &FromBase64{expression.UnaryExpression{Child: e}}
}

// Eval implements the Expression interface.
func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
str, err := t.Child.Eval(ctx, row)

if err != nil {
return nil, err
}

if str == nil {
return nil, nil
}

str, err = sql.Text.Convert(str)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
}

decoded, err := base64.StdEncoding.DecodeString(str.(string))
if err != nil {
return nil, err
}

return string(decoded), nil
}

// String implements the Stringer interface.
func (t *FromBase64) String() string {
return fmt.Sprintf("FROM_BASE64(%s)", t.Child)
}

// IsNullable implements the Expression interface.
func (t *FromBase64) IsNullable() bool {
return t.Child.IsNullable()
}

// TransformUp implements the Expression interface.
func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
child, err := t.Child.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewFromBase64(child))
}

// Type implements the Expression interface.
func (t *FromBase64) Type() sql.Type {
return sql.Text
}
56 changes: 56 additions & 0 deletions sql/expression/function/tobase64_frombase64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package function

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

func TestBase64(t *testing.T) {
fTo := NewToBase64(expression.NewGetField(0, sql.Text, "", false))
fFrom := NewFromBase64(expression.NewGetField(0, sql.Text, "", false))

testCases := []struct {
name string
row sql.Row
expected interface{}
err bool
}{
// Use a MySQL server to get expected values if updating/adding to this!
{"null input", sql.NewRow(nil), nil, false},
{"single_line", sql.NewRow("foo"), string("Zm9v"), false},
{"multi_line", sql.NewRow(
"Gallia est omnis divisa in partes tres, quarum unam " +
"incolunt Belgae, aliam Aquitani, tertiam qui ipsorum lingua Celtae, " +
"nostra Galli appellantur"),
"R2FsbGlhIGVzdCBvbW5pcyBkaXZpc2EgaW4gcGFydGVzIHRyZXMsIHF1YXJ1bSB1bmFtIGluY29s\n" +
"dW50IEJlbGdhZSwgYWxpYW0gQXF1aXRhbmksIHRlcnRpYW0gcXVpIGlwc29ydW0gbGluZ3VhIENl\n" +
"bHRhZSwgbm9zdHJhIEdhbGxpIGFwcGVsbGFudHVy", false},
{"empty_input", sql.NewRow(""), string(""), false},
{"symbols", sql.NewRow("!@#$% %^&*()_+\r\n\t{};"), string("IUAjJCUgJV4mKigpXysNCgl7fTs="),
false},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
require := require.New(t)
ctx := sql.NewEmptyContext()
v, err := fTo.Eval(ctx, tt.row)

if tt.err {
require.Error(err)
} else {
require.NoError(err)
require.Equal(tt.expected, v)

ctx = sql.NewEmptyContext()
v2, err := fFrom.Eval(ctx, sql.NewRow(v))
require.NoError(err)
require.Equal(sql.NewRow(v2), tt.row)
}
})
}
}