Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #650 from juanjux/base64-functions
Browse files Browse the repository at this point in the history
Added from_base64 and to_base64
  • Loading branch information
ajnavarro authored Apr 11, 2019
2 parents 89c0734 + 689788f commit d9a12ff
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ We support and actively test against certain third-party clients to ensure compa
|`SUBSTR(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
|`SUBSTRING(str, pos, [len])`|Return a substring from the provided string starting at `pos` with a length of `len` characters. If no `len` is provided, all characters from `pos` until the end will be taken.|
|`SUM(expr)`|Returns the sum of expr in all rows.|
|`TO_BASE64(str)`|Encodes the string str in base64 format.|
|`FROM_BASE64(str)`|Decodes the base64-encoded string str.|
|`TRIM(str)`|Returns the string str with all spaces removed.|
|`UPPER(str)`|Returns the string str with all characters in upper case.|
|`WEEKDAY(date)`|Returns the weekday of the given date.|
Expand Down
2 changes: 2 additions & 0 deletions SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
- LOG2
- LOG10
- SLEEP
- TO_BASE64
- FROM_BASE64

## Time functions
- DAY
Expand Down
8 changes: 8 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,14 @@ var queries = []struct {
"SELECT SLEEP(0.5)",
[]sql.Row{{int(0)}},
},
{
"SELECT TO_BASE64('foo')",
[]sql.Row{{string("Zm9v")}},
},
{
"SELECT FROM_BASE64('YmFy')",
[]sql.Row{{string("bar")}},
},
}

func TestQueries(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,6 @@ var Defaults = []sql.Function{
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function0{Name: "now", Fn: NewNow},
sql.Function1{Name: "sleep", Fn: NewSleep},
sql.Function1{Name: "to_base64", Fn: NewToBase64},
sql.Function1{Name: "from_base64", Fn: NewFromBase64},
}
148 changes: 148 additions & 0 deletions sql/expression/function/tobase64_frombase64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package function

import (
"encoding/base64"
"fmt"
"reflect"
"strings"

"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

// ToBase64 is a function to encode a string to the Base64 format
// using the same dialect that MySQL's TO_BASE64 uses
type ToBase64 struct {
expression.UnaryExpression
}

// NewToBase64 creates a new ToBase64 expression.
func NewToBase64(e sql.Expression) sql.Expression {
return &ToBase64{expression.UnaryExpression{Child: e}}
}

// Eval implements the Expression interface.
func (t *ToBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
str, err := t.Child.Eval(ctx, row)

if err != nil {
return nil, err
}

if str == nil {
return nil, nil
}

str, err = sql.Text.Convert(str)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
}

encoded := base64.StdEncoding.EncodeToString([]byte(str.(string)))

lenEncoded := len(encoded)
if lenEncoded <= 76 {
return encoded, nil
}

// Split into max 76 chars lines
var out strings.Builder
start := 0
end := 76
for {
out.WriteString(encoded[start:end] + "\n")
start += 76
end += 76
if end >= lenEncoded {
out.WriteString(encoded[start:lenEncoded])
break
}
}

return out.String(), nil
}

// String implements the Stringer interface.
func (t *ToBase64) String() string {
return fmt.Sprintf("TO_BASE64(%s)", t.Child)
}

// IsNullable implements the Expression interface.
func (t *ToBase64) IsNullable() bool {
return t.Child.IsNullable()
}

// TransformUp implements the Expression interface.
func (t *ToBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
child, err := t.Child.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewToBase64(child))
}

// Type implements the Expression interface.
func (t *ToBase64) Type() sql.Type {
return sql.Text
}


// FromBase64 is a function to decode a Base64-formatted string
// using the same dialect that MySQL's FROM_BASE64 uses
type FromBase64 struct {
expression.UnaryExpression
}

// NewFromBase64 creates a new FromBase64 expression.
func NewFromBase64(e sql.Expression) sql.Expression {
return &FromBase64{expression.UnaryExpression{Child: e}}
}

// Eval implements the Expression interface.
func (t *FromBase64) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
str, err := t.Child.Eval(ctx, row)

if err != nil {
return nil, err
}

if str == nil {
return nil, nil
}

str, err = sql.Text.Convert(str)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str))
}

decoded, err := base64.StdEncoding.DecodeString(str.(string))
if err != nil {
return nil, err
}

return string(decoded), nil
}

// String implements the Stringer interface.
func (t *FromBase64) String() string {
return fmt.Sprintf("FROM_BASE64(%s)", t.Child)
}

// IsNullable implements the Expression interface.
func (t *FromBase64) IsNullable() bool {
return t.Child.IsNullable()
}

// TransformUp implements the Expression interface.
func (t *FromBase64) TransformUp(f sql.TransformExprFunc) (sql.Expression, error) {
child, err := t.Child.TransformUp(f)
if err != nil {
return nil, err
}
return f(NewFromBase64(child))
}

// Type implements the Expression interface.
func (t *FromBase64) Type() sql.Type {
return sql.Text
}
56 changes: 56 additions & 0 deletions sql/expression/function/tobase64_frombase64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package function

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

func TestBase64(t *testing.T) {
fTo := NewToBase64(expression.NewGetField(0, sql.Text, "", false))
fFrom := NewFromBase64(expression.NewGetField(0, sql.Text, "", false))

testCases := []struct {
name string
row sql.Row
expected interface{}
err bool
}{
// Use a MySQL server to get expected values if updating/adding to this!
{"null input", sql.NewRow(nil), nil, false},
{"single_line", sql.NewRow("foo"), string("Zm9v"), false},
{"multi_line", sql.NewRow(
"Gallia est omnis divisa in partes tres, quarum unam " +
"incolunt Belgae, aliam Aquitani, tertiam qui ipsorum lingua Celtae, " +
"nostra Galli appellantur"),
"R2FsbGlhIGVzdCBvbW5pcyBkaXZpc2EgaW4gcGFydGVzIHRyZXMsIHF1YXJ1bSB1bmFtIGluY29s\n" +
"dW50IEJlbGdhZSwgYWxpYW0gQXF1aXRhbmksIHRlcnRpYW0gcXVpIGlwc29ydW0gbGluZ3VhIENl\n" +
"bHRhZSwgbm9zdHJhIEdhbGxpIGFwcGVsbGFudHVy", false},
{"empty_input", sql.NewRow(""), string(""), false},
{"symbols", sql.NewRow("!@#$% %^&*()_+\r\n\t{};"), string("IUAjJCUgJV4mKigpXysNCgl7fTs="),
false},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
require := require.New(t)
ctx := sql.NewEmptyContext()
v, err := fTo.Eval(ctx, tt.row)

if tt.err {
require.Error(err)
} else {
require.NoError(err)
require.Equal(tt.expected, v)

ctx = sql.NewEmptyContext()
v2, err := fFrom.Eval(ctx, sql.NewRow(v))
require.NoError(err)
require.Equal(sql.NewRow(v2), tt.row)
}
})
}
}

0 comments on commit d9a12ff

Please sign in to comment.