Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Add support for CONCAT_WS #500

Merged
merged 3 commits into from
Oct 24, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ We support and actively test against certain third-party clients to ensure compa
## Custom functions

- `IS_BINARY(blob)`: Returns whether a BLOB is a binary file or not.
- `SUBSTRING(str,pos)`, `SUBSTRING(str,pos,len)`: Return a substring from the provided string.
- `SUBSTRING(str, pos)`, `SUBSTRING(str, pos, len)`: Return a substring from the provided string.
- Date and Timestamp functions: `YEAR(date)`, `MONTH(date)`, `DAY(date)`, `HOUR(date)`, `MINUTE(date)`, `SECOND(date)`, `DAYOFYEAR(date)`.
- `ARRAY_LENGTH(json)`: If the json representation is an array, this function returns its size.
- `SPLIT(str,sep)`: Receives a string and a separator and returns the parts of the string split by the separator as a JSON array of strings.
- `CONCAT(...)`: Concatenate any group of fields into a single string.
- `CONCAT_WS(sep, ...)`: Concatenate any group of fields into a single string separated by the first field.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth mentioning that null arguments are skipped and only returns null if sep is null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!
"Returns null if the separator is null. Following null fields are skipped." - Would this be ok?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets copy paste from mysql docs:

... stands for Concatenate With Separator and is a special form of CONCAT(). The first argument is the separator for the rest of the arguments. The separator is added between the strings to be concatenated. The separator can be a string, as can the rest of the arguments. If the separator is NULL, the result is NULL.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I only just saw the comment. The descriptions now come from MySQLs docs.

- `COALESCE(...)`: The function returns the first non-null value in a list.
- `LOWER(str)`, `UPPER(str)`: Receives a string and modify it changing all the chars to upper or lower case.
- `CEILING(number)`, `CEIL(number)`: Return the smallest integer value that is greater than or equal to `number`.
Expand Down
1 change: 1 addition & 0 deletions SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
## Functions
- ARRAY_LENGTH
- CONCAT
- CONCAT_WS
- IS_BINARY
- SPLIT
- SUBSTRING
Expand Down
131 changes: 131 additions & 0 deletions sql/expression/function/concat_ws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package function

import (
"fmt"
"strings"

"gopkg.in/src-d/go-mysql-server.v0/sql"
)

// ConcatWithSeparator joins several strings together.
type ConcatWithSeparator struct {
args []sql.Expression
}

// NewConcatWithSeparator creates a new NewConcatWithSeparator UDF.
func NewConcatWithSeparator(args ...sql.Expression) (sql.Expression, error) {
if len(args) == 0 {
return nil, sql.ErrInvalidArgumentNumber.New("1 or more", 0)
}

for _, arg := range args {
// Don't perform this check until it's resolved. Otherwise we
// can't get the type for sure.
if !arg.Resolved() {
continue
}

if len(args) > 1 && sql.IsArray(arg.Type()) {
return nil, ErrConcatArrayWithOthers.New()
}

if sql.IsTuple(arg.Type()) {
return nil, sql.ErrInvalidType.New("tuple")
}
}

return &ConcatWithSeparator{args}, nil
}

// Type implements the Expression interface.
func (f *ConcatWithSeparator) Type() sql.Type { return sql.Text }

// IsNullable implements the Expression interface.
func (f *ConcatWithSeparator) IsNullable() bool {
for _, arg := range f.args {
if arg.IsNullable() {
return true
}
}
return false
}

func (f *ConcatWithSeparator) String() string {
var args = make([]string, len(f.args))
for i, arg := range f.args {
args[i] = arg.String()
}
return fmt.Sprintf("concat_ws(%s)", strings.Join(args, ", "))
}

// TransformUp implements the Expression interface.
func (f *ConcatWithSeparator) TransformUp(fn sql.TransformExprFunc) (sql.Expression, error) {
var args = make([]sql.Expression, len(f.args))
for i, arg := range f.args {
arg, err := arg.TransformUp(fn)
if err != nil {
return nil, err
}
args[i] = arg
}

expr, err := NewConcatWithSeparator(args...)
if err != nil {
return nil, err
}

return fn(expr)
}

// Resolved implements the Expression interface.
func (f *ConcatWithSeparator) Resolved() bool {
for _, arg := range f.args {
if !arg.Resolved() {
return false
}
}
return true
}

// Children implements the Expression interface.
func (f *ConcatWithSeparator) Children() []sql.Expression { return f.args }

// Eval implements the Expression interface.
func (f *ConcatWithSeparator) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
var parts []string

for i, arg := range f.args {
val, err := arg.Eval(ctx, row)
if err != nil {
return nil, err
}

if val == nil && i == 0 {
return nil, nil
}

if val == nil {
continue
}

if sql.IsArray(arg.Type()) {
val, err = sql.Array(sql.Text).Convert(val)
if err != nil {
return nil, err
}

for _, v := range val.([]interface{}) {
parts = append(parts, v.(string))
}
} else {
val, err = sql.Text.Convert(val)
if err != nil {
return nil, err
}

parts = append(parts, val.(string))
}
}

return strings.Join(parts[1:], parts[0]), nil
}
106 changes: 106 additions & 0 deletions sql/expression/function/concat_ws_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package function

import (
"testing"

"github.com/stretchr/testify/require"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
)

func TestConcatWithSeparator(t *testing.T) {
t.Run("multiple arguments", func(t *testing.T) {
require := require.New(t)
f, err := NewConcatWithSeparator(
expression.NewLiteral(",", sql.Text),
expression.NewLiteral("foo", sql.Text),
expression.NewLiteral(5, sql.Text),
expression.NewLiteral(true, sql.Boolean),
)
require.NoError(err)

v, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(err)
require.Equal("foo,5,true", v)
})

t.Run("some argument is empty", func(t *testing.T) {
require := require.New(t)
f, err := NewConcatWithSeparator(
expression.NewLiteral(",", sql.Text),
expression.NewLiteral("foo", sql.Text),
expression.NewLiteral("", sql.Text),
expression.NewLiteral(true, sql.Boolean),
)
require.NoError(err)

v, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(err)
require.Equal("foo,,true", v)
})

t.Run("some argument is nil", func(t *testing.T) {
require := require.New(t)
f, err := NewConcatWithSeparator(
expression.NewLiteral(",", sql.Text),
expression.NewLiteral("foo", sql.Text),
expression.NewLiteral(nil, sql.Text),
expression.NewLiteral(true, sql.Boolean),
)
require.NoError(err)

v, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(err)
require.Equal("foo,true", v)
})

t.Run("separator is nil", func(t *testing.T) {
require := require.New(t)
f, err := NewConcatWithSeparator(
expression.NewLiteral(nil, sql.Text),
expression.NewLiteral("foo", sql.Text),
expression.NewLiteral(5, sql.Text),
expression.NewLiteral(true, sql.Boolean),
)
require.NoError(err)

v, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(err)
require.Equal(nil, v)
})

t.Run("concat_ws array", func(t *testing.T) {
require := require.New(t)
f, err := NewConcatWithSeparator(
expression.NewLiteral([]interface{}{",",5, "bar", true}, sql.Array(sql.Text)),
)
require.NoError(err)

v, err := f.Eval(sql.NewEmptyContext(), nil)
require.NoError(err)
require.Equal("5,bar,true", v)
})
}

func TestNewConcatWithSeparator(t *testing.T) {
require := require.New(t)

_, err := NewConcatWithSeparator(expression.NewLiteral(nil, sql.Array(sql.Text)))
require.NoError(err)

_, err = NewConcatWithSeparator(expression.NewLiteral(nil, sql.Array(sql.Text)), expression.NewLiteral(nil, sql.Int64))
require.Error(err)
require.True(ErrConcatArrayWithOthers.Is(err))

_, err = NewConcatWithSeparator(expression.NewLiteral(nil, sql.Tuple(sql.Text, sql.Text)))
require.Error(err)
require.True(sql.ErrInvalidType.Is(err))

_, err = NewConcatWithSeparator(
expression.NewLiteral(nil, sql.Text),
expression.NewLiteral(nil, sql.Boolean),
expression.NewLiteral(nil, sql.Int64),
expression.NewLiteral(nil, sql.Text),
)
require.NoError(err)
}
1 change: 1 addition & 0 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var Defaults = sql.Functions{
"array_length": sql.Function1(NewArrayLength),
"split": sql.Function2(NewSplit),
"concat": sql.FunctionN(NewConcat),
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
"lower": sql.Function1(NewLower),
"upper": sql.Function1(NewUpper),
"ceiling": sql.Function1(NewCeil),
Expand Down