From d7e7d21f6486619c3cba7ef95854cd71dc174833 Mon Sep 17 00:00:00 2001 From: Miguel Molina Date: Fri, 2 Nov 2018 10:47:36 +0100 Subject: [PATCH] sql/analyzer: add rule to avoid unnecessary casts Signed-off-by: Miguel Molina --- sql/analyzer/optimization_rules.go | 19 ++++++++++ sql/analyzer/optimization_rules_test.go | 47 +++++++++++++++++++++++++ sql/analyzer/rules.go | 1 + 3 files changed, 67 insertions(+) diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index ad53f422c..787e8525b 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -245,6 +245,25 @@ func moveJoinConditionsToFilter(ctx *sql.Context, a *Analyzer, n sql.Node) (sql. }) } +func removeUnnecessaryConverts(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { + span, _ := ctx.Span("remove_unnecessary_converts") + defer span.Finish() + + if !n.Resolved() { + return n, nil + } + + a.Log("removing unnecessary converts, node of type: %T", n) + + return n.TransformExpressionsUp(func(e sql.Expression) (sql.Expression, error) { + if c, ok := e.(*expression.Convert); ok && c.Child.Type() == c.Type() { + return c.Child, nil + } + + return e, nil + }) +} + // containsSources checks that all `needle` sources are contained inside `haystack`. func containsSources(haystack, needle []string) bool { for _, s := range needle { diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index d97411e70..6e0d5c11e 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -396,3 +396,50 @@ func TestEvalFilter(t *testing.T) { }) } } + +func TestRemoveUnnecessaryConverts(t *testing.T) { + testCases := []struct { + name string + childExpr sql.Expression + castType string + expected sql.Expression + }{ + { + "unnecessary cast", + expression.NewLiteral([]byte{}, sql.Blob), + "binary", + expression.NewLiteral([]byte{}, sql.Blob), + }, + { + "necessary cast", + expression.NewLiteral("foo", sql.Text), + "signed", + expression.NewConvert( + expression.NewLiteral("foo", sql.Text), + "signed", + ), + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + + node := plan.NewProject([]sql.Expression{ + expression.NewConvert(tt.childExpr, tt.castType), + }, + plan.NewResolvedTable(mem.NewTable("foo", nil)), + ) + + result, err := removeUnnecessaryConverts( + sql.NewEmptyContext(), + NewDefault(nil), + node, + ) + require.NoError(err) + + resultExpr := result.(*plan.Project).Projections[0] + require.Equal(tt.expected, resultExpr) + }) + } +} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index f4ea6a919..15dd81aa2 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -32,6 +32,7 @@ var OnceBeforeDefault = []Rule{ // OnceAfterDefault contains the rules to be applied just once after the // DefaultRules. var OnceAfterDefault = []Rule{ + {"remove_unnecessary_converts", removeUnnecessaryConverts}, {"assign_catalog", assignCatalog}, {"pushdown", pushdown}, {"erase_projection", eraseProjection},