Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #666 from kuba--/fix-665/func-name
Browse files Browse the repository at this point in the history
Pass function name to the registry
  • Loading branch information
ajnavarro authored Apr 11, 2019
2 parents 6e4c51d + 3d98abd commit d03de5f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 111 deletions.
13 changes: 10 additions & 3 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ func New(c *sql.Catalog, a *analyzer.Analyzer, cfg *Config) *Engine {
versionPostfix = cfg.VersionPostfix
}

c.RegisterFunctions(function.Defaults)
c.RegisterFunction("version", sql.FunctionN(function.NewVersion(versionPostfix)))
c.RegisterFunction("database", sql.Function0(function.NewDatabase(c)))
c.MustRegister(
sql.FunctionN{
Name: "version",
Fn: function.NewVersion(versionPostfix),
},
sql.Function0{
Name: "database",
Fn: function.NewDatabase(c),
})
c.MustRegister(function.Defaults...)

// use auth.None if auth is not specified
var au auth.Auth
Expand Down
127 changes: 66 additions & 61 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,65 +8,70 @@ import (
)

// Defaults is the function map with all the default functions.
var Defaults = sql.Functions{
"count": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewCount(e)
}),
"min": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewMin(e)
}),
"max": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewMax(e)
}),
"avg": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewAvg(e)
}),
"sum": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewSum(e)
}),
"is_binary": sql.Function1(NewIsBinary),
"substring": sql.FunctionN(NewSubstring),
"mid": sql.FunctionN(NewSubstring),
"substr": sql.FunctionN(NewSubstring),
"year": sql.Function1(NewYear),
"month": sql.Function1(NewMonth),
"day": sql.Function1(NewDay),
"weekday": sql.Function1(NewWeekday),
"hour": sql.Function1(NewHour),
"minute": sql.Function1(NewMinute),
"second": sql.Function1(NewSecond),
"dayofweek": sql.Function1(NewDayOfWeek),
"dayofyear": sql.Function1(NewDayOfYear),
"array_length": sql.Function1(NewArrayLength),
"split": sql.Function2(NewSplit),
"concat": sql.FunctionN(NewConcat),
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
"coalesce": sql.FunctionN(NewCoalesce),
"lower": sql.Function1(NewLower),
"upper": sql.Function1(NewUpper),
"ceiling": sql.Function1(NewCeil),
"ceil": sql.Function1(NewCeil),
"floor": sql.Function1(NewFloor),
"round": sql.FunctionN(NewRound),
"connection_id": sql.Function0(NewConnectionID),
"soundex": sql.Function1(NewSoundex),
"json_extract": sql.FunctionN(NewJSONExtract),
"ln": sql.Function1(NewLogBaseFunc(float64(math.E))),
"log2": sql.Function1(NewLogBaseFunc(float64(2))),
"log10": sql.Function1(NewLogBaseFunc(float64(10))),
"log": sql.FunctionN(NewLog),
"rpad": sql.FunctionN(NewPadFunc(rPadType)),
"lpad": sql.FunctionN(NewPadFunc(lPadType)),
"sqrt": sql.Function1(NewSqrt),
"pow": sql.Function2(NewPower),
"power": sql.Function2(NewPower),
"ltrim": sql.Function1(NewTrimFunc(lTrimType)),
"rtrim": sql.Function1(NewTrimFunc(rTrimType)),
"trim": sql.Function1(NewTrimFunc(bTrimType)),
"reverse": sql.Function1(NewReverse),
"repeat": sql.Function2(NewRepeat),
"replace": sql.Function3(NewReplace),
"ifnull": sql.Function2(NewIfNull),
"nullif": sql.Function2(NewNullIf),
"now": sql.Function0(NewNow),
var Defaults = []sql.Function{
sql.Function1{
Name: "count",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewCount(e) },
},
sql.Function1{
Name: "min",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMin(e) },
},
sql.Function1{
Name: "max",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMax(e) },
},
sql.Function1{
Name: "avg",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) },
},
sql.Function1{
Name: "sum",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) },
},
sql.Function1{Name: "is_binary", Fn: NewIsBinary},
sql.FunctionN{Name: "substring", Fn: NewSubstring},
sql.FunctionN{Name: "mid", Fn: NewSubstring},
sql.FunctionN{Name: "substr", Fn: NewSubstring},
sql.Function1{Name: "year", Fn: NewYear},
sql.Function1{Name: "month", Fn: NewMonth},
sql.Function1{Name: "day", Fn: NewDay},
sql.Function1{Name: "weekday", Fn: NewWeekday},
sql.Function1{Name: "hour", Fn: NewHour},
sql.Function1{Name: "minute", Fn: NewMinute},
sql.Function1{Name: "second", Fn: NewSecond},
sql.Function1{Name: "dayofweek", Fn: NewDayOfWeek},
sql.Function1{Name: "dayofyear", Fn: NewDayOfYear},
sql.Function1{Name: "array_length", Fn: NewArrayLength},
sql.Function2{Name: "split", Fn: NewSplit},
sql.FunctionN{Name: "concat", Fn: NewConcat},
sql.FunctionN{Name: "concat_ws", Fn: NewConcatWithSeparator},
sql.FunctionN{Name: "coalesce", Fn: NewCoalesce},
sql.Function1{Name: "lower", Fn: NewLower},
sql.Function1{Name: "upper", Fn: NewUpper},
sql.Function1{Name: "ceiling", Fn: NewCeil},
sql.Function1{Name: "ceil", Fn: NewCeil},
sql.Function1{Name: "floor", Fn: NewFloor},
sql.FunctionN{Name: "round", Fn: NewRound},
sql.Function0{Name: "connection_id", Fn: NewConnectionID},
sql.Function1{Name: "soundex", Fn: NewSoundex},
sql.FunctionN{Name: "json_extract", Fn: NewJSONExtract},
sql.Function1{Name: "ln", Fn: NewLogBaseFunc(float64(math.E))},
sql.Function1{Name: "log2", Fn: NewLogBaseFunc(float64(2))},
sql.Function1{Name: "log10", Fn: NewLogBaseFunc(float64(10))},
sql.FunctionN{Name: "log", Fn: NewLog},
sql.FunctionN{Name: "rpad", Fn: NewPadFunc(rPadType)},
sql.FunctionN{Name: "lpad", Fn: NewPadFunc(lPadType)},
sql.Function1{Name: "sqrt", Fn: NewSqrt},
sql.Function2{Name: "pow", Fn: NewPower},
sql.Function2{Name: "power", Fn: NewPower},
sql.Function1{Name: "ltrim", Fn: NewTrimFunc(lTrimType)},
sql.Function1{Name: "rtrim", Fn: NewTrimFunc(rTrimType)},
sql.Function1{Name: "trim", Fn: NewTrimFunc(bTrimType)},
sql.Function1{Name: "reverse", Fn: NewReverse},
sql.Function2{Name: "repeat", Fn: NewRepeat},
sql.Function3{Name: "replace", Fn: NewReplace},
sql.Function2{Name: "ifnull", Fn: NewIfNull},
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function0{Name: "now", Fn: NewNow},
}
134 changes: 90 additions & 44 deletions sql/functionregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,122 +4,163 @@ import (
"gopkg.in/src-d/go-errors.v1"
)

// ErrFunctionAlreadyRegistered is thrown when a function is already registered
var ErrFunctionAlreadyRegistered = errors.NewKind("A function: '%s' is already registered.")

// ErrFunctionNotFound is thrown when a function is not found
var ErrFunctionNotFound = errors.NewKind("function not found: %s")
var ErrFunctionNotFound = errors.NewKind("A function: '%s' not found.")

// ErrInvalidArgumentNumber is returned when the number of arguments to call a
// function is different from the function arity.
var ErrInvalidArgumentNumber = errors.NewKind("%s: expecting %v arguments for calling this function, %d received")
var ErrInvalidArgumentNumber = errors.NewKind("A function: '%s' expected %d arguments, %d received.")

// Function is a function defined by the user that can be applied in a SQL
// query.
// Function is a function defined by the user that can be applied in a SQL query.
type Function interface {
// Call invokes the function.
Call(...Expression) (Expression, error)
// Function name
name() string
// isFunction will restrict implementations of Function
isFunction()
}

type (
// Function0 is a function with 0 arguments.
Function0 func() Expression
Function0 struct {
Name string
Fn func() Expression
}
// Function1 is a function with 1 argument.
Function1 func(e Expression) Expression
Function1 struct {
Name string
Fn func(e Expression) Expression
}
// Function2 is a function with 2 arguments.
Function2 func(e1, e2 Expression) Expression
Function2 struct {
Name string
Fn func(e1, e2 Expression) Expression
}
// Function3 is a function with 3 arguments.
Function3 func(e1, e2, e3 Expression) Expression
Function3 struct {
Name string
Fn func(e1, e2, e3 Expression) Expression
}
// Function4 is a function with 4 arguments.
Function4 func(e1, e2, e3, e4 Expression) Expression
Function4 struct {
Name string
Fn func(e1, e2, e3, e4 Expression) Expression
}
// Function5 is a function with 5 arguments.
Function5 func(e1, e2, e3, e4, e5 Expression) Expression
Function5 struct {
Name string
Fn func(e1, e2, e3, e4, e5 Expression) Expression
}
// Function6 is a function with 6 arguments.
Function6 func(e1, e2, e3, e4, e5, e6 Expression) Expression
Function6 struct {
Name string
Fn func(e1, e2, e3, e4, e5, e6 Expression) Expression
}
// Function7 is a function with 7 arguments.
Function7 func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression
Function7 struct {
Name string
Fn func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression
}
// FunctionN is a function with variable number of arguments. This function
// is expected to return ErrInvalidArgumentNumber if the arity does not
// match, since the check has to be done in the implementation.
FunctionN func(...Expression) (Expression, error)
FunctionN struct {
Name string
Fn func(...Expression) (Expression, error)
}
)

// Call implements the Function interface.
func (fn Function0) Call(args ...Expression) (Expression, error) {
if len(args) != 0 {
return nil, ErrInvalidArgumentNumber.New(0, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args))
}

return fn(), nil
return fn.Fn(), nil
}

// Call implements the Function interface.
func (fn Function1) Call(args ...Expression) (Expression, error) {
if len(args) != 1 {
return nil, ErrInvalidArgumentNumber.New(1, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args))
}

return fn(args[0]), nil
return fn.Fn(args[0]), nil
}

// Call implements the Function interface.
func (fn Function2) Call(args ...Expression) (Expression, error) {
if len(args) != 2 {
return nil, ErrInvalidArgumentNumber.New(2, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args))
}

return fn(args[0], args[1]), nil
return fn.Fn(args[0], args[1]), nil
}

// Call implements the Function interface.
func (fn Function3) Call(args ...Expression) (Expression, error) {
if len(args) != 3 {
return nil, ErrInvalidArgumentNumber.New(3, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args))
}

return fn(args[0], args[1], args[2]), nil
return fn.Fn(args[0], args[1], args[2]), nil
}

// Call implements the Function interface.
func (fn Function4) Call(args ...Expression) (Expression, error) {
if len(args) != 4 {
return nil, ErrInvalidArgumentNumber.New(4, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args))
}

return fn(args[0], args[1], args[2], args[3]), nil
return fn.Fn(args[0], args[1], args[2], args[3]), nil
}

// Call implements the Function interface.
func (fn Function5) Call(args ...Expression) (Expression, error) {
if len(args) != 5 {
return nil, ErrInvalidArgumentNumber.New(5, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil
}

// Call implements the Function interface.
func (fn Function6) Call(args ...Expression) (Expression, error) {
if len(args) != 6 {
return nil, ErrInvalidArgumentNumber.New(6, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil
}

// Call implements the Function interface.
func (fn Function7) Call(args ...Expression) (Expression, error) {
if len(args) != 7 {
return nil, ErrInvalidArgumentNumber.New(7, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil
}

// Call implements the Function interface.
func (fn FunctionN) Call(args ...Expression) (Expression, error) {
return fn(args...)
return fn.Fn(args...)
}

func (fn Function0) name() string { return fn.Name }
func (fn Function1) name() string { return fn.Name }
func (fn Function2) name() string { return fn.Name }
func (fn Function3) name() string { return fn.Name }
func (fn Function4) name() string { return fn.Name }
func (fn Function5) name() string { return fn.Name }
func (fn Function6) name() string { return fn.Name }
func (fn Function7) name() string { return fn.Name }
func (fn FunctionN) name() string { return fn.Name }

func (Function0) isFunction() {}
func (Function1) isFunction() {}
func (Function2) isFunction() {}
Expand All @@ -134,32 +175,37 @@ func (FunctionN) isFunction() {}
// and User-Defined Functions.
type FunctionRegistry map[string]Function

// Functions is a map of functions identified by their name.
type Functions map[string]Function

// NewFunctionRegistry creates a new FunctionRegistry.
func NewFunctionRegistry() FunctionRegistry {
return make(FunctionRegistry)
}

// RegisterFunction registers a function with the given name.
func (r FunctionRegistry) RegisterFunction(name string, f Function) {
r[name] = f
// Register registers functions.
// If function with that name is already registered,
// the ErrFunctionAlreadyRegistered will be returned
func (r FunctionRegistry) Register(fn ...Function) error {
for _, f := range fn {
if _, ok := r[f.name()]; ok {
return ErrFunctionAlreadyRegistered.New(f.name())
}
r[f.name()] = f
}
return nil
}

// RegisterFunctions registers a map of functions.
func (r FunctionRegistry) RegisterFunctions(funcs Functions) {
for name, f := range funcs {
r[name] = f
// MustRegister registers functions.
// If function with that name is already registered, it will panic!
func (r FunctionRegistry) MustRegister(fn ...Function) {
if err := r.Register(fn...); err != nil {
panic(err)
}
}

// Function returns a function with the given name.
func (r FunctionRegistry) Function(name string) (Function, error) {
e, ok := r[name]
if !ok {
return nil, ErrFunctionNotFound.New(name)
if fn, ok := r[name]; ok {
return fn, nil
}

return e, nil
return nil, ErrFunctionNotFound.New(name)
}
Loading

0 comments on commit d03de5f

Please sign in to comment.