Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Pass function name to the registry #666

Merged
merged 1 commit into from
Apr 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ func New(c *sql.Catalog, a *analyzer.Analyzer, cfg *Config) *Engine {
versionPostfix = cfg.VersionPostfix
}

c.RegisterFunctions(function.Defaults)
c.RegisterFunction("version", sql.FunctionN(function.NewVersion(versionPostfix)))
c.RegisterFunction("database", sql.Function0(function.NewDatabase(c)))
c.MustRegister(
sql.FunctionN{
Name: "version",
Fn: function.NewVersion(versionPostfix),
},
sql.Function0{
Name: "database",
Fn: function.NewDatabase(c),
})
c.MustRegister(function.Defaults...)

// use auth.None if auth is not specified
var au auth.Auth
Expand Down
127 changes: 66 additions & 61 deletions sql/expression/function/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,65 +8,70 @@ import (
)

// Defaults is the function map with all the default functions.
var Defaults = sql.Functions{
"count": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewCount(e)
}),
"min": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewMin(e)
}),
"max": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewMax(e)
}),
"avg": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewAvg(e)
}),
"sum": sql.Function1(func(e sql.Expression) sql.Expression {
return aggregation.NewSum(e)
}),
"is_binary": sql.Function1(NewIsBinary),
"substring": sql.FunctionN(NewSubstring),
"mid": sql.FunctionN(NewSubstring),
"substr": sql.FunctionN(NewSubstring),
"year": sql.Function1(NewYear),
"month": sql.Function1(NewMonth),
"day": sql.Function1(NewDay),
"weekday": sql.Function1(NewWeekday),
"hour": sql.Function1(NewHour),
"minute": sql.Function1(NewMinute),
"second": sql.Function1(NewSecond),
"dayofweek": sql.Function1(NewDayOfWeek),
"dayofyear": sql.Function1(NewDayOfYear),
"array_length": sql.Function1(NewArrayLength),
"split": sql.Function2(NewSplit),
"concat": sql.FunctionN(NewConcat),
"concat_ws": sql.FunctionN(NewConcatWithSeparator),
"coalesce": sql.FunctionN(NewCoalesce),
"lower": sql.Function1(NewLower),
"upper": sql.Function1(NewUpper),
"ceiling": sql.Function1(NewCeil),
"ceil": sql.Function1(NewCeil),
"floor": sql.Function1(NewFloor),
"round": sql.FunctionN(NewRound),
"connection_id": sql.Function0(NewConnectionID),
"soundex": sql.Function1(NewSoundex),
"json_extract": sql.FunctionN(NewJSONExtract),
"ln": sql.Function1(NewLogBaseFunc(float64(math.E))),
"log2": sql.Function1(NewLogBaseFunc(float64(2))),
"log10": sql.Function1(NewLogBaseFunc(float64(10))),
"log": sql.FunctionN(NewLog),
"rpad": sql.FunctionN(NewPadFunc(rPadType)),
"lpad": sql.FunctionN(NewPadFunc(lPadType)),
"sqrt": sql.Function1(NewSqrt),
"pow": sql.Function2(NewPower),
"power": sql.Function2(NewPower),
"ltrim": sql.Function1(NewTrimFunc(lTrimType)),
"rtrim": sql.Function1(NewTrimFunc(rTrimType)),
"trim": sql.Function1(NewTrimFunc(bTrimType)),
"reverse": sql.Function1(NewReverse),
"repeat": sql.Function2(NewRepeat),
"replace": sql.Function3(NewReplace),
"ifnull": sql.Function2(NewIfNull),
"nullif": sql.Function2(NewNullIf),
"now": sql.Function0(NewNow),
var Defaults = []sql.Function{
sql.Function1{
Name: "count",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewCount(e) },
},
sql.Function1{
Name: "min",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMin(e) },
},
sql.Function1{
Name: "max",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMax(e) },
},
sql.Function1{
Name: "avg",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) },
},
sql.Function1{
Name: "sum",
Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) },
},
sql.Function1{Name: "is_binary", Fn: NewIsBinary},
sql.FunctionN{Name: "substring", Fn: NewSubstring},
sql.FunctionN{Name: "mid", Fn: NewSubstring},
sql.FunctionN{Name: "substr", Fn: NewSubstring},
sql.Function1{Name: "year", Fn: NewYear},
sql.Function1{Name: "month", Fn: NewMonth},
sql.Function1{Name: "day", Fn: NewDay},
sql.Function1{Name: "weekday", Fn: NewWeekday},
sql.Function1{Name: "hour", Fn: NewHour},
sql.Function1{Name: "minute", Fn: NewMinute},
sql.Function1{Name: "second", Fn: NewSecond},
sql.Function1{Name: "dayofweek", Fn: NewDayOfWeek},
sql.Function1{Name: "dayofyear", Fn: NewDayOfYear},
sql.Function1{Name: "array_length", Fn: NewArrayLength},
sql.Function2{Name: "split", Fn: NewSplit},
sql.FunctionN{Name: "concat", Fn: NewConcat},
sql.FunctionN{Name: "concat_ws", Fn: NewConcatWithSeparator},
sql.FunctionN{Name: "coalesce", Fn: NewCoalesce},
sql.Function1{Name: "lower", Fn: NewLower},
sql.Function1{Name: "upper", Fn: NewUpper},
sql.Function1{Name: "ceiling", Fn: NewCeil},
sql.Function1{Name: "ceil", Fn: NewCeil},
sql.Function1{Name: "floor", Fn: NewFloor},
sql.FunctionN{Name: "round", Fn: NewRound},
sql.Function0{Name: "connection_id", Fn: NewConnectionID},
sql.Function1{Name: "soundex", Fn: NewSoundex},
sql.FunctionN{Name: "json_extract", Fn: NewJSONExtract},
sql.Function1{Name: "ln", Fn: NewLogBaseFunc(float64(math.E))},
sql.Function1{Name: "log2", Fn: NewLogBaseFunc(float64(2))},
sql.Function1{Name: "log10", Fn: NewLogBaseFunc(float64(10))},
sql.FunctionN{Name: "log", Fn: NewLog},
sql.FunctionN{Name: "rpad", Fn: NewPadFunc(rPadType)},
sql.FunctionN{Name: "lpad", Fn: NewPadFunc(lPadType)},
sql.Function1{Name: "sqrt", Fn: NewSqrt},
sql.Function2{Name: "pow", Fn: NewPower},
sql.Function2{Name: "power", Fn: NewPower},
sql.Function1{Name: "ltrim", Fn: NewTrimFunc(lTrimType)},
sql.Function1{Name: "rtrim", Fn: NewTrimFunc(rTrimType)},
sql.Function1{Name: "trim", Fn: NewTrimFunc(bTrimType)},
sql.Function1{Name: "reverse", Fn: NewReverse},
sql.Function2{Name: "repeat", Fn: NewRepeat},
sql.Function3{Name: "replace", Fn: NewReplace},
sql.Function2{Name: "ifnull", Fn: NewIfNull},
sql.Function2{Name: "nullif", Fn: NewNullIf},
sql.Function0{Name: "now", Fn: NewNow},
}
134 changes: 90 additions & 44 deletions sql/functionregistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,122 +4,163 @@ import (
"gopkg.in/src-d/go-errors.v1"
)

// ErrFunctionAlreadyRegistered is thrown when a function is already registered
var ErrFunctionAlreadyRegistered = errors.NewKind("A function: '%s' is already registered.")

// ErrFunctionNotFound is thrown when a function is not found
var ErrFunctionNotFound = errors.NewKind("function not found: %s")
var ErrFunctionNotFound = errors.NewKind("A function: '%s' not found.")

// ErrInvalidArgumentNumber is returned when the number of arguments to call a
// function is different from the function arity.
var ErrInvalidArgumentNumber = errors.NewKind("%s: expecting %v arguments for calling this function, %d received")
var ErrInvalidArgumentNumber = errors.NewKind("A function: '%s' expected %d arguments, %d received.")

// Function is a function defined by the user that can be applied in a SQL
// query.
// Function is a function defined by the user that can be applied in a SQL query.
type Function interface {
// Call invokes the function.
Call(...Expression) (Expression, error)
// Function name
name() string
// isFunction will restrict implementations of Function
isFunction()
}

type (
// Function0 is a function with 0 arguments.
Function0 func() Expression
Function0 struct {
Name string
Fn func() Expression
}
// Function1 is a function with 1 argument.
Function1 func(e Expression) Expression
Function1 struct {
Name string
Fn func(e Expression) Expression
}
// Function2 is a function with 2 arguments.
Function2 func(e1, e2 Expression) Expression
Function2 struct {
Name string
Fn func(e1, e2 Expression) Expression
}
// Function3 is a function with 3 arguments.
Function3 func(e1, e2, e3 Expression) Expression
Function3 struct {
Name string
Fn func(e1, e2, e3 Expression) Expression
}
// Function4 is a function with 4 arguments.
Function4 func(e1, e2, e3, e4 Expression) Expression
Function4 struct {
Name string
Fn func(e1, e2, e3, e4 Expression) Expression
}
// Function5 is a function with 5 arguments.
Function5 func(e1, e2, e3, e4, e5 Expression) Expression
Function5 struct {
Name string
Fn func(e1, e2, e3, e4, e5 Expression) Expression
}
// Function6 is a function with 6 arguments.
Function6 func(e1, e2, e3, e4, e5, e6 Expression) Expression
Function6 struct {
Name string
Fn func(e1, e2, e3, e4, e5, e6 Expression) Expression
}
// Function7 is a function with 7 arguments.
Function7 func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression
Function7 struct {
Name string
Fn func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression
}
// FunctionN is a function with variable number of arguments. This function
// is expected to return ErrInvalidArgumentNumber if the arity does not
// match, since the check has to be done in the implementation.
FunctionN func(...Expression) (Expression, error)
FunctionN struct {
Name string
Fn func(...Expression) (Expression, error)
}
)

// Call implements the Function interface.
func (fn Function0) Call(args ...Expression) (Expression, error) {
if len(args) != 0 {
return nil, ErrInvalidArgumentNumber.New(0, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args))
}

return fn(), nil
return fn.Fn(), nil
}

// Call implements the Function interface.
func (fn Function1) Call(args ...Expression) (Expression, error) {
if len(args) != 1 {
return nil, ErrInvalidArgumentNumber.New(1, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args))
}

return fn(args[0]), nil
return fn.Fn(args[0]), nil
}

// Call implements the Function interface.
func (fn Function2) Call(args ...Expression) (Expression, error) {
if len(args) != 2 {
return nil, ErrInvalidArgumentNumber.New(2, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args))
}

return fn(args[0], args[1]), nil
return fn.Fn(args[0], args[1]), nil
}

// Call implements the Function interface.
func (fn Function3) Call(args ...Expression) (Expression, error) {
if len(args) != 3 {
return nil, ErrInvalidArgumentNumber.New(3, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args))
}

return fn(args[0], args[1], args[2]), nil
return fn.Fn(args[0], args[1], args[2]), nil
}

// Call implements the Function interface.
func (fn Function4) Call(args ...Expression) (Expression, error) {
if len(args) != 4 {
return nil, ErrInvalidArgumentNumber.New(4, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args))
}

return fn(args[0], args[1], args[2], args[3]), nil
return fn.Fn(args[0], args[1], args[2], args[3]), nil
}

// Call implements the Function interface.
func (fn Function5) Call(args ...Expression) (Expression, error) {
if len(args) != 5 {
return nil, ErrInvalidArgumentNumber.New(5, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil
}

// Call implements the Function interface.
func (fn Function6) Call(args ...Expression) (Expression, error) {
if len(args) != 6 {
return nil, ErrInvalidArgumentNumber.New(6, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil
}

// Call implements the Function interface.
func (fn Function7) Call(args ...Expression) (Expression, error) {
if len(args) != 7 {
return nil, ErrInvalidArgumentNumber.New(7, len(args))
return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args))
}

return fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil
return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil
}

// Call implements the Function interface.
func (fn FunctionN) Call(args ...Expression) (Expression, error) {
return fn(args...)
return fn.Fn(args...)
}

func (fn Function0) name() string { return fn.Name }
func (fn Function1) name() string { return fn.Name }
func (fn Function2) name() string { return fn.Name }
func (fn Function3) name() string { return fn.Name }
func (fn Function4) name() string { return fn.Name }
func (fn Function5) name() string { return fn.Name }
func (fn Function6) name() string { return fn.Name }
func (fn Function7) name() string { return fn.Name }
func (fn FunctionN) name() string { return fn.Name }

func (Function0) isFunction() {}
func (Function1) isFunction() {}
func (Function2) isFunction() {}
Expand All @@ -134,32 +175,37 @@ func (FunctionN) isFunction() {}
// and User-Defined Functions.
type FunctionRegistry map[string]Function

// Functions is a map of functions identified by their name.
type Functions map[string]Function

// NewFunctionRegistry creates a new FunctionRegistry.
func NewFunctionRegistry() FunctionRegistry {
return make(FunctionRegistry)
}

// RegisterFunction registers a function with the given name.
func (r FunctionRegistry) RegisterFunction(name string, f Function) {
r[name] = f
// Register registers functions.
// If function with that name is already registered,
// the ErrFunctionAlreadyRegistered will be returned
func (r FunctionRegistry) Register(fn ...Function) error {
for _, f := range fn {
if _, ok := r[f.name()]; ok {
return ErrFunctionAlreadyRegistered.New(f.name())
}
r[f.name()] = f
}
return nil
}

// RegisterFunctions registers a map of functions.
func (r FunctionRegistry) RegisterFunctions(funcs Functions) {
for name, f := range funcs {
r[name] = f
// MustRegister registers functions.
// If function with that name is already registered, it will panic!
func (r FunctionRegistry) MustRegister(fn ...Function) {
if err := r.Register(fn...); err != nil {
panic(err)
}
}

// Function returns a function with the given name.
func (r FunctionRegistry) Function(name string) (Function, error) {
e, ok := r[name]
if !ok {
return nil, ErrFunctionNotFound.New(name)
if fn, ok := r[name]; ok {
return fn, nil
}

return e, nil
return nil, ErrFunctionNotFound.New(name)
}
Loading