This repository has been archived by the owner on Jan 28, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #750 from erizocosmico/feature/refactor-natural-jo…
…in-rule sql/analyzer: refactor resolve_natural_joins rule
- Loading branch information
Showing
4 changed files
with
100 additions
and
235 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,266 +1,137 @@ | ||
package analyzer | ||
|
||
import ( | ||
"reflect" | ||
"strings" | ||
|
||
"github.com/src-d/go-mysql-server/sql" | ||
"github.com/src-d/go-mysql-server/sql/expression" | ||
"github.com/src-d/go-mysql-server/sql/plan" | ||
) | ||
|
||
type transformedJoin struct { | ||
node sql.Node | ||
condCols map[string]*transformedSource | ||
} | ||
|
||
type transformedSource struct { | ||
correct string | ||
wrong []string | ||
} | ||
|
||
func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) { | ||
span, _ := ctx.Span("resolve_natural_joins") | ||
defer span.Finish() | ||
|
||
if n.Resolved() { | ||
return n, nil | ||
} | ||
|
||
var transformed []*transformedJoin | ||
var aliasTables = map[string][]string{} | ||
var colsToUnresolve = map[string]*transformedSource{} | ||
a.Log("resolving natural joins, node of type %T", n) | ||
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) { | ||
a.Log("transforming node of type: %T", n) | ||
var replacements = make(map[tableCol]tableCol) | ||
var tableAliases = make(map[string]string) | ||
|
||
if alias, ok := n.(*plan.TableAlias); ok { | ||
table := alias.Child.(*plan.ResolvedTable).Name() | ||
aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table) | ||
return n.TransformUp(func(n sql.Node) (sql.Node, error) { | ||
switch n := n.(type) { | ||
case *plan.TableAlias: | ||
alias := n.Name() | ||
table := n.Child.(*plan.ResolvedTable).Name() | ||
tableAliases[strings.ToLower(alias)] = table | ||
return n, nil | ||
} | ||
|
||
if n.Resolved() { | ||
case *plan.NaturalJoin: | ||
return resolveNaturalJoin(n, replacements) | ||
case sql.Expressioner: | ||
return replaceExpressions(n, replacements, tableAliases) | ||
default: | ||
return n, nil | ||
} | ||
}) | ||
} | ||
|
||
join, ok := n.(*plan.NaturalJoin) | ||
if !ok { | ||
return n, nil | ||
} | ||
|
||
// we need both leaves resolved before resolving this one | ||
if !join.Left.Resolved() || !join.Right.Resolved() { | ||
return n, nil | ||
} | ||
|
||
leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema() | ||
|
||
var conditions, common, left, right []sql.Expression | ||
var seen = make(map[string]struct{}) | ||
|
||
for i, lcol := range leftSchema { | ||
var found bool | ||
leftCol := expression.NewGetFieldWithTable( | ||
i, | ||
lcol.Type, | ||
lcol.Source, | ||
lcol.Name, | ||
lcol.Nullable, | ||
) | ||
|
||
for j, rcol := range rightSchema { | ||
if lcol.Name == rcol.Name { | ||
common = append(common, leftCol) | ||
|
||
conditions = append( | ||
conditions, | ||
expression.NewEquals( | ||
leftCol, | ||
expression.NewGetFieldWithTable( | ||
len(leftSchema)+j, | ||
rcol.Type, | ||
rcol.Source, | ||
rcol.Name, | ||
rcol.Nullable, | ||
), | ||
), | ||
) | ||
|
||
found = true | ||
seen[lcol.Name] = struct{}{} | ||
if source, ok := colsToUnresolve[lcol.Name]; ok { | ||
source.correct = lcol.Source | ||
source.wrong = append(source.wrong, rcol.Source) | ||
} else { | ||
colsToUnresolve[lcol.Name] = &transformedSource{ | ||
correct: lcol.Source, | ||
wrong: []string{rcol.Source}, | ||
} | ||
} | ||
|
||
break | ||
} | ||
} | ||
func resolveNaturalJoin( | ||
n *plan.NaturalJoin, | ||
replacements map[tableCol]tableCol, | ||
) (sql.Node, error) { | ||
// Both sides of the natural join need to be resolved in order to resolve | ||
// the natural join itself. | ||
if !n.Left.Resolved() || !n.Right.Resolved() { | ||
return n, nil | ||
} | ||
|
||
if !found { | ||
left = append(left, leftCol) | ||
leftSchema := n.Left.Schema() | ||
rightSchema := n.Right.Schema() | ||
|
||
var conditions, common, left, right []sql.Expression | ||
for i, lcol := range leftSchema { | ||
leftCol := expression.NewGetFieldWithTable( | ||
i, | ||
lcol.Type, | ||
lcol.Source, | ||
lcol.Name, | ||
lcol.Nullable, | ||
) | ||
if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil { | ||
common = append(common, leftCol) | ||
replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{ | ||
strings.ToLower(lcol.Source), strings.ToLower(lcol.Name), | ||
} | ||
} | ||
|
||
if len(conditions) == 0 { | ||
return plan.NewCrossJoin(join.Left, join.Right), nil | ||
} | ||
|
||
for i, col := range rightSchema { | ||
if _, ok := seen[col.Name]; !ok { | ||
right = append( | ||
right, | ||
conditions = append( | ||
conditions, | ||
expression.NewEquals( | ||
leftCol, | ||
expression.NewGetFieldWithTable( | ||
len(leftSchema)+i, | ||
col.Type, | ||
col.Source, | ||
col.Name, | ||
col.Nullable, | ||
len(leftSchema)+idx, | ||
rcol.Type, | ||
rcol.Source, | ||
rcol.Name, | ||
rcol.Nullable, | ||
), | ||
) | ||
} | ||
} | ||
|
||
projections := append(append(common, left...), right...) | ||
|
||
tj := &transformedJoin{ | ||
node: plan.NewProject( | ||
projections, | ||
plan.NewInnerJoin( | ||
join.Left, | ||
join.Right, | ||
expression.JoinAnd(conditions...), | ||
), | ||
), | ||
condCols: colsToUnresolve, | ||
) | ||
} else { | ||
left = append(left, leftCol) | ||
} | ||
|
||
transformed = append(transformed, tj) | ||
|
||
return tj.node, nil | ||
}) | ||
|
||
if err != nil || len(transformed) == 0 { | ||
return node, err | ||
} | ||
|
||
var transformedSeen bool | ||
return node.TransformUp(func(node sql.Node) (sql.Node, error) { | ||
if ok, _ := isTransformedNode(node, transformed); ok { | ||
transformedSeen = true | ||
return node, nil | ||
} | ||
|
||
if !transformedSeen { | ||
return node, nil | ||
} | ||
|
||
expressioner, ok := node.(sql.Expressioner) | ||
if !ok { | ||
return node, nil | ||
} | ||
|
||
return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { | ||
var col, table string | ||
switch e := e.(type) { | ||
case *expression.GetField: | ||
col, table = e.Name(), e.Table() | ||
case *expression.UnresolvedColumn: | ||
col, table = e.Name(), e.Table() | ||
default: | ||
return e, nil | ||
} | ||
|
||
sources, ok := colsToUnresolve[col] | ||
if !ok { | ||
return e, nil | ||
} | ||
|
||
if !mustUnresolve(aliasTables, table, sources.wrong) { | ||
return e, nil | ||
} | ||
|
||
return expression.NewUnresolvedQualifiedColumn( | ||
sources.correct, | ||
col, | ||
), nil | ||
}) | ||
}) | ||
} | ||
|
||
func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) { | ||
var project *plan.Project | ||
var join *plan.InnerJoin | ||
switch n := node.(type) { | ||
case *plan.Project: | ||
var ok bool | ||
join, ok = n.Child.(*plan.InnerJoin) | ||
if !ok { | ||
return | ||
} | ||
|
||
project = n | ||
case *plan.InnerJoin: | ||
join = n | ||
|
||
default: | ||
return | ||
if len(conditions) == 0 { | ||
return plan.NewCrossJoin(n.Left, n.Right), nil | ||
} | ||
|
||
for _, t := range transformed { | ||
tproject, ok := t.node.(*plan.Project) | ||
if !ok { | ||
return | ||
} | ||
|
||
tjoin, ok := tproject.Child.(*plan.InnerJoin) | ||
if !ok { | ||
return | ||
} | ||
|
||
if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) { | ||
continue | ||
} | ||
|
||
if reflect.DeepEqual(join.Cond, tjoin.Cond) { | ||
is = true | ||
colsToUnresolve = t.condCols | ||
for i, col := range rightSchema { | ||
source := strings.ToLower(col.Source) | ||
name := strings.ToLower(col.Name) | ||
if _, ok := replacements[tableCol{source, name}]; !ok { | ||
right = append( | ||
right, | ||
expression.NewGetFieldWithTable( | ||
len(leftSchema)+i, | ||
col.Type, | ||
col.Source, | ||
col.Name, | ||
col.Nullable, | ||
), | ||
) | ||
} | ||
} | ||
|
||
return | ||
return plan.NewProject( | ||
append(append(common, left...), right...), | ||
plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)), | ||
), nil | ||
} | ||
|
||
func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool { | ||
return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources) | ||
} | ||
|
||
func isIn(s string, l []string) bool { | ||
for _, e := range l { | ||
if s == e { | ||
return true | ||
func findCol(s sql.Schema, name string) (int, *sql.Column) { | ||
for i, c := range s { | ||
if strings.ToLower(c.Name) == strings.ToLower(name) { | ||
return i, c | ||
} | ||
} | ||
|
||
return false | ||
return -1, nil | ||
} | ||
|
||
func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool { | ||
tables, ok := aliasTable[table] | ||
if !ok { | ||
return false | ||
} | ||
func replaceExpressions( | ||
n sql.Expressioner, | ||
replacements map[tableCol]tableCol, | ||
tableAliases map[string]string, | ||
) (sql.Node, error) { | ||
return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) { | ||
switch e := e.(type) { | ||
case *expression.GetField, *expression.UnresolvedColumn: | ||
var tableName = e.(sql.Tableable).Table() | ||
if t, ok := tableAliases[strings.ToLower(tableName)]; ok { | ||
tableName = t | ||
} | ||
|
||
for _, t := range tables { | ||
if isIn(t, wrongSources) { | ||
return true | ||
name := e.(sql.Nameable).Name() | ||
if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok { | ||
return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil | ||
} | ||
} | ||
} | ||
|
||
return false | ||
return e, nil | ||
}) | ||
} |
Oops, something went wrong.