From 3d98abd9265f1a91aae8e866fdc36e9bd4a24f6c Mon Sep 17 00:00:00 2001 From: kuba-- Date: Wed, 10 Apr 2019 14:19:43 +0200 Subject: [PATCH] Pass functiona name to the registry Signed-off-by: kuba-- --- engine.go | 13 ++- sql/expression/function/registry.go | 127 +++++++++++++------------- sql/functionregistry.go | 134 +++++++++++++++++++--------- sql/functionregistry_test.go | 7 +- 4 files changed, 170 insertions(+), 111 deletions(-) diff --git a/engine.go b/engine.go index c8140105a..f42cf8fb6 100644 --- a/engine.go +++ b/engine.go @@ -34,9 +34,16 @@ func New(c *sql.Catalog, a *analyzer.Analyzer, cfg *Config) *Engine { versionPostfix = cfg.VersionPostfix } - c.RegisterFunctions(function.Defaults) - c.RegisterFunction("version", sql.FunctionN(function.NewVersion(versionPostfix))) - c.RegisterFunction("database", sql.Function0(function.NewDatabase(c))) + c.MustRegister( + sql.FunctionN{ + Name: "version", + Fn: function.NewVersion(versionPostfix), + }, + sql.Function0{ + Name: "database", + Fn: function.NewDatabase(c), + }) + c.MustRegister(function.Defaults...) // use auth.None if auth is not specified var au auth.Auth diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 6c10e2646..4248e6874 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -8,65 +8,70 @@ import ( ) // Defaults is the function map with all the default functions. -var Defaults = sql.Functions{ - "count": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewCount(e) - }), - "min": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewMin(e) - }), - "max": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewMax(e) - }), - "avg": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewAvg(e) - }), - "sum": sql.Function1(func(e sql.Expression) sql.Expression { - return aggregation.NewSum(e) - }), - "is_binary": sql.Function1(NewIsBinary), - "substring": sql.FunctionN(NewSubstring), - "mid": sql.FunctionN(NewSubstring), - "substr": sql.FunctionN(NewSubstring), - "year": sql.Function1(NewYear), - "month": sql.Function1(NewMonth), - "day": sql.Function1(NewDay), - "weekday": sql.Function1(NewWeekday), - "hour": sql.Function1(NewHour), - "minute": sql.Function1(NewMinute), - "second": sql.Function1(NewSecond), - "dayofweek": sql.Function1(NewDayOfWeek), - "dayofyear": sql.Function1(NewDayOfYear), - "array_length": sql.Function1(NewArrayLength), - "split": sql.Function2(NewSplit), - "concat": sql.FunctionN(NewConcat), - "concat_ws": sql.FunctionN(NewConcatWithSeparator), - "coalesce": sql.FunctionN(NewCoalesce), - "lower": sql.Function1(NewLower), - "upper": sql.Function1(NewUpper), - "ceiling": sql.Function1(NewCeil), - "ceil": sql.Function1(NewCeil), - "floor": sql.Function1(NewFloor), - "round": sql.FunctionN(NewRound), - "connection_id": sql.Function0(NewConnectionID), - "soundex": sql.Function1(NewSoundex), - "json_extract": sql.FunctionN(NewJSONExtract), - "ln": sql.Function1(NewLogBaseFunc(float64(math.E))), - "log2": sql.Function1(NewLogBaseFunc(float64(2))), - "log10": sql.Function1(NewLogBaseFunc(float64(10))), - "log": sql.FunctionN(NewLog), - "rpad": sql.FunctionN(NewPadFunc(rPadType)), - "lpad": sql.FunctionN(NewPadFunc(lPadType)), - "sqrt": sql.Function1(NewSqrt), - "pow": sql.Function2(NewPower), - "power": sql.Function2(NewPower), - "ltrim": sql.Function1(NewTrimFunc(lTrimType)), - "rtrim": sql.Function1(NewTrimFunc(rTrimType)), - "trim": sql.Function1(NewTrimFunc(bTrimType)), - "reverse": sql.Function1(NewReverse), - "repeat": sql.Function2(NewRepeat), - "replace": sql.Function3(NewReplace), - "ifnull": sql.Function2(NewIfNull), - "nullif": sql.Function2(NewNullIf), - "now": sql.Function0(NewNow), +var Defaults = []sql.Function{ + sql.Function1{ + Name: "count", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewCount(e) }, + }, + sql.Function1{ + Name: "min", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMin(e) }, + }, + sql.Function1{ + Name: "max", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewMax(e) }, + }, + sql.Function1{ + Name: "avg", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) }, + }, + sql.Function1{ + Name: "sum", + Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) }, + }, + sql.Function1{Name: "is_binary", Fn: NewIsBinary}, + sql.FunctionN{Name: "substring", Fn: NewSubstring}, + sql.FunctionN{Name: "mid", Fn: NewSubstring}, + sql.FunctionN{Name: "substr", Fn: NewSubstring}, + sql.Function1{Name: "year", Fn: NewYear}, + sql.Function1{Name: "month", Fn: NewMonth}, + sql.Function1{Name: "day", Fn: NewDay}, + sql.Function1{Name: "weekday", Fn: NewWeekday}, + sql.Function1{Name: "hour", Fn: NewHour}, + sql.Function1{Name: "minute", Fn: NewMinute}, + sql.Function1{Name: "second", Fn: NewSecond}, + sql.Function1{Name: "dayofweek", Fn: NewDayOfWeek}, + sql.Function1{Name: "dayofyear", Fn: NewDayOfYear}, + sql.Function1{Name: "array_length", Fn: NewArrayLength}, + sql.Function2{Name: "split", Fn: NewSplit}, + sql.FunctionN{Name: "concat", Fn: NewConcat}, + sql.FunctionN{Name: "concat_ws", Fn: NewConcatWithSeparator}, + sql.FunctionN{Name: "coalesce", Fn: NewCoalesce}, + sql.Function1{Name: "lower", Fn: NewLower}, + sql.Function1{Name: "upper", Fn: NewUpper}, + sql.Function1{Name: "ceiling", Fn: NewCeil}, + sql.Function1{Name: "ceil", Fn: NewCeil}, + sql.Function1{Name: "floor", Fn: NewFloor}, + sql.FunctionN{Name: "round", Fn: NewRound}, + sql.Function0{Name: "connection_id", Fn: NewConnectionID}, + sql.Function1{Name: "soundex", Fn: NewSoundex}, + sql.FunctionN{Name: "json_extract", Fn: NewJSONExtract}, + sql.Function1{Name: "ln", Fn: NewLogBaseFunc(float64(math.E))}, + sql.Function1{Name: "log2", Fn: NewLogBaseFunc(float64(2))}, + sql.Function1{Name: "log10", Fn: NewLogBaseFunc(float64(10))}, + sql.FunctionN{Name: "log", Fn: NewLog}, + sql.FunctionN{Name: "rpad", Fn: NewPadFunc(rPadType)}, + sql.FunctionN{Name: "lpad", Fn: NewPadFunc(lPadType)}, + sql.Function1{Name: "sqrt", Fn: NewSqrt}, + sql.Function2{Name: "pow", Fn: NewPower}, + sql.Function2{Name: "power", Fn: NewPower}, + sql.Function1{Name: "ltrim", Fn: NewTrimFunc(lTrimType)}, + sql.Function1{Name: "rtrim", Fn: NewTrimFunc(rTrimType)}, + sql.Function1{Name: "trim", Fn: NewTrimFunc(bTrimType)}, + sql.Function1{Name: "reverse", Fn: NewReverse}, + sql.Function2{Name: "repeat", Fn: NewRepeat}, + sql.Function3{Name: "replace", Fn: NewReplace}, + sql.Function2{Name: "ifnull", Fn: NewIfNull}, + sql.Function2{Name: "nullif", Fn: NewNullIf}, + sql.Function0{Name: "now", Fn: NewNow}, } diff --git a/sql/functionregistry.go b/sql/functionregistry.go index 6a6c0ef3a..b78d7d373 100644 --- a/sql/functionregistry.go +++ b/sql/functionregistry.go @@ -4,122 +4,163 @@ import ( "gopkg.in/src-d/go-errors.v1" ) +// ErrFunctionAlreadyRegistered is thrown when a function is already registered +var ErrFunctionAlreadyRegistered = errors.NewKind("A function: '%s' is already registered.") + // ErrFunctionNotFound is thrown when a function is not found -var ErrFunctionNotFound = errors.NewKind("function not found: %s") +var ErrFunctionNotFound = errors.NewKind("A function: '%s' not found.") // ErrInvalidArgumentNumber is returned when the number of arguments to call a // function is different from the function arity. -var ErrInvalidArgumentNumber = errors.NewKind("%s: expecting %v arguments for calling this function, %d received") +var ErrInvalidArgumentNumber = errors.NewKind("A function: '%s' expected %d arguments, %d received.") -// Function is a function defined by the user that can be applied in a SQL -// query. +// Function is a function defined by the user that can be applied in a SQL query. type Function interface { // Call invokes the function. Call(...Expression) (Expression, error) + // Function name + name() string // isFunction will restrict implementations of Function isFunction() } type ( // Function0 is a function with 0 arguments. - Function0 func() Expression + Function0 struct { + Name string + Fn func() Expression + } // Function1 is a function with 1 argument. - Function1 func(e Expression) Expression + Function1 struct { + Name string + Fn func(e Expression) Expression + } // Function2 is a function with 2 arguments. - Function2 func(e1, e2 Expression) Expression + Function2 struct { + Name string + Fn func(e1, e2 Expression) Expression + } // Function3 is a function with 3 arguments. - Function3 func(e1, e2, e3 Expression) Expression + Function3 struct { + Name string + Fn func(e1, e2, e3 Expression) Expression + } // Function4 is a function with 4 arguments. - Function4 func(e1, e2, e3, e4 Expression) Expression + Function4 struct { + Name string + Fn func(e1, e2, e3, e4 Expression) Expression + } // Function5 is a function with 5 arguments. - Function5 func(e1, e2, e3, e4, e5 Expression) Expression + Function5 struct { + Name string + Fn func(e1, e2, e3, e4, e5 Expression) Expression + } // Function6 is a function with 6 arguments. - Function6 func(e1, e2, e3, e4, e5, e6 Expression) Expression + Function6 struct { + Name string + Fn func(e1, e2, e3, e4, e5, e6 Expression) Expression + } // Function7 is a function with 7 arguments. - Function7 func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression + Function7 struct { + Name string + Fn func(e1, e2, e3, e4, e5, e6, e7 Expression) Expression + } // FunctionN is a function with variable number of arguments. This function // is expected to return ErrInvalidArgumentNumber if the arity does not // match, since the check has to be done in the implementation. - FunctionN func(...Expression) (Expression, error) + FunctionN struct { + Name string + Fn func(...Expression) (Expression, error) + } ) // Call implements the Function interface. func (fn Function0) Call(args ...Expression) (Expression, error) { if len(args) != 0 { - return nil, ErrInvalidArgumentNumber.New(0, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 0, len(args)) } - return fn(), nil + return fn.Fn(), nil } // Call implements the Function interface. func (fn Function1) Call(args ...Expression) (Expression, error) { if len(args) != 1 { - return nil, ErrInvalidArgumentNumber.New(1, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 1, len(args)) } - return fn(args[0]), nil + return fn.Fn(args[0]), nil } // Call implements the Function interface. func (fn Function2) Call(args ...Expression) (Expression, error) { if len(args) != 2 { - return nil, ErrInvalidArgumentNumber.New(2, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 2, len(args)) } - return fn(args[0], args[1]), nil + return fn.Fn(args[0], args[1]), nil } // Call implements the Function interface. func (fn Function3) Call(args ...Expression) (Expression, error) { if len(args) != 3 { - return nil, ErrInvalidArgumentNumber.New(3, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 3, len(args)) } - return fn(args[0], args[1], args[2]), nil + return fn.Fn(args[0], args[1], args[2]), nil } // Call implements the Function interface. func (fn Function4) Call(args ...Expression) (Expression, error) { if len(args) != 4 { - return nil, ErrInvalidArgumentNumber.New(4, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 4, len(args)) } - return fn(args[0], args[1], args[2], args[3]), nil + return fn.Fn(args[0], args[1], args[2], args[3]), nil } // Call implements the Function interface. func (fn Function5) Call(args ...Expression) (Expression, error) { if len(args) != 5 { - return nil, ErrInvalidArgumentNumber.New(5, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 5, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4]), nil } // Call implements the Function interface. func (fn Function6) Call(args ...Expression) (Expression, error) { if len(args) != 6 { - return nil, ErrInvalidArgumentNumber.New(6, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 6, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5]), nil } // Call implements the Function interface. func (fn Function7) Call(args ...Expression) (Expression, error) { if len(args) != 7 { - return nil, ErrInvalidArgumentNumber.New(7, len(args)) + return nil, ErrInvalidArgumentNumber.New(fn.Name, 7, len(args)) } - return fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil + return fn.Fn(args[0], args[1], args[2], args[3], args[4], args[5], args[6]), nil } // Call implements the Function interface. func (fn FunctionN) Call(args ...Expression) (Expression, error) { - return fn(args...) + return fn.Fn(args...) } +func (fn Function0) name() string { return fn.Name } +func (fn Function1) name() string { return fn.Name } +func (fn Function2) name() string { return fn.Name } +func (fn Function3) name() string { return fn.Name } +func (fn Function4) name() string { return fn.Name } +func (fn Function5) name() string { return fn.Name } +func (fn Function6) name() string { return fn.Name } +func (fn Function7) name() string { return fn.Name } +func (fn FunctionN) name() string { return fn.Name } + func (Function0) isFunction() {} func (Function1) isFunction() {} func (Function2) isFunction() {} @@ -134,32 +175,37 @@ func (FunctionN) isFunction() {} // and User-Defined Functions. type FunctionRegistry map[string]Function -// Functions is a map of functions identified by their name. -type Functions map[string]Function - // NewFunctionRegistry creates a new FunctionRegistry. func NewFunctionRegistry() FunctionRegistry { return make(FunctionRegistry) } -// RegisterFunction registers a function with the given name. -func (r FunctionRegistry) RegisterFunction(name string, f Function) { - r[name] = f +// Register registers functions. +// If function with that name is already registered, +// the ErrFunctionAlreadyRegistered will be returned +func (r FunctionRegistry) Register(fn ...Function) error { + for _, f := range fn { + if _, ok := r[f.name()]; ok { + return ErrFunctionAlreadyRegistered.New(f.name()) + } + r[f.name()] = f + } + return nil } -// RegisterFunctions registers a map of functions. -func (r FunctionRegistry) RegisterFunctions(funcs Functions) { - for name, f := range funcs { - r[name] = f +// MustRegister registers functions. +// If function with that name is already registered, it will panic! +func (r FunctionRegistry) MustRegister(fn ...Function) { + if err := r.Register(fn...); err != nil { + panic(err) } } // Function returns a function with the given name. func (r FunctionRegistry) Function(name string) (Function, error) { - e, ok := r[name] - if !ok { - return nil, ErrFunctionNotFound.New(name) + if fn, ok := r[name]; ok { + return fn, nil } - return e, nil + return nil, ErrFunctionNotFound.New(name) } diff --git a/sql/functionregistry_test.go b/sql/functionregistry_test.go index 706779fcf..7c9fdef08 100644 --- a/sql/functionregistry_test.go +++ b/sql/functionregistry_test.go @@ -14,9 +14,10 @@ func TestFunctionRegistry(t *testing.T) { c := sql.NewCatalog() name := "func" var expected sql.Expression = expression.NewStar() - c.RegisterFunction(name, sql.Function1(func(arg sql.Expression) sql.Expression { - return expected - })) + c.MustRegister(sql.Function1{ + Name: name, + Fn: func(arg sql.Expression) sql.Expression { return expected }, + }) f, err := c.Function(name) require.NoError(err)