Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #750 from erizocosmico/feature/refactor-natural-jo…
Browse files Browse the repository at this point in the history
…in-rule

sql/analyzer: refactor resolve_natural_joins rule
  • Loading branch information
ajnavarro authored Jun 17, 2019
2 parents 5336d8a + 4f6c4f8 commit 5f48ea3
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 235 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ require (
github.com/stretchr/testify v1.2.2
go.etcd.io/bbolt v1.3.2
golang.org/x/net v0.0.0-20190227022144-312bce6e941f // indirect
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b // indirect
google.golang.org/grpc v1.19.0 // indirect
gopkg.in/src-d/go-errors.v1 v1.0.0
gopkg.in/yaml.v2 v2.2.2
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekf
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.0 h1:kbxbvI4Un1LUWKxufD+BiE6AEExYYgkQLQmLFqA1LFk=
github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw=
Expand Down Expand Up @@ -137,6 +139,7 @@ golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72
golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190227022144-312bce6e941f h1:tbtX/qtlxzhZjgQue/7u7ygFwDEckd+DmS5+t8FgeKE=
golang.org/x/net v0.0.0-20190227022144-312bce6e941f/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down
321 changes: 96 additions & 225 deletions sql/analyzer/resolve_natural_joins.go
Original file line number Diff line number Diff line change
@@ -1,266 +1,137 @@
package analyzer

import (
"reflect"
"strings"

"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
)

type transformedJoin struct {
node sql.Node
condCols map[string]*transformedSource
}

type transformedSource struct {
correct string
wrong []string
}

func resolveNaturalJoins(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, _ := ctx.Span("resolve_natural_joins")
defer span.Finish()

if n.Resolved() {
return n, nil
}

var transformed []*transformedJoin
var aliasTables = map[string][]string{}
var colsToUnresolve = map[string]*transformedSource{}
a.Log("resolving natural joins, node of type %T", n)
node, err := n.TransformUp(func(n sql.Node) (sql.Node, error) {
a.Log("transforming node of type: %T", n)
var replacements = make(map[tableCol]tableCol)
var tableAliases = make(map[string]string)

if alias, ok := n.(*plan.TableAlias); ok {
table := alias.Child.(*plan.ResolvedTable).Name()
aliasTables[alias.Name()] = append(aliasTables[alias.Name()], table)
return n.TransformUp(func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.TableAlias:
alias := n.Name()
table := n.Child.(*plan.ResolvedTable).Name()
tableAliases[strings.ToLower(alias)] = table
return n, nil
}

if n.Resolved() {
case *plan.NaturalJoin:
return resolveNaturalJoin(n, replacements)
case sql.Expressioner:
return replaceExpressions(n, replacements, tableAliases)
default:
return n, nil
}
})
}

join, ok := n.(*plan.NaturalJoin)
if !ok {
return n, nil
}

// we need both leaves resolved before resolving this one
if !join.Left.Resolved() || !join.Right.Resolved() {
return n, nil
}

leftSchema, rightSchema := join.Left.Schema(), join.Right.Schema()

var conditions, common, left, right []sql.Expression
var seen = make(map[string]struct{})

for i, lcol := range leftSchema {
var found bool
leftCol := expression.NewGetFieldWithTable(
i,
lcol.Type,
lcol.Source,
lcol.Name,
lcol.Nullable,
)

for j, rcol := range rightSchema {
if lcol.Name == rcol.Name {
common = append(common, leftCol)

conditions = append(
conditions,
expression.NewEquals(
leftCol,
expression.NewGetFieldWithTable(
len(leftSchema)+j,
rcol.Type,
rcol.Source,
rcol.Name,
rcol.Nullable,
),
),
)

found = true
seen[lcol.Name] = struct{}{}
if source, ok := colsToUnresolve[lcol.Name]; ok {
source.correct = lcol.Source
source.wrong = append(source.wrong, rcol.Source)
} else {
colsToUnresolve[lcol.Name] = &transformedSource{
correct: lcol.Source,
wrong: []string{rcol.Source},
}
}

break
}
}
func resolveNaturalJoin(
n *plan.NaturalJoin,
replacements map[tableCol]tableCol,
) (sql.Node, error) {
// Both sides of the natural join need to be resolved in order to resolve
// the natural join itself.
if !n.Left.Resolved() || !n.Right.Resolved() {
return n, nil
}

if !found {
left = append(left, leftCol)
leftSchema := n.Left.Schema()
rightSchema := n.Right.Schema()

var conditions, common, left, right []sql.Expression
for i, lcol := range leftSchema {
leftCol := expression.NewGetFieldWithTable(
i,
lcol.Type,
lcol.Source,
lcol.Name,
lcol.Nullable,
)
if idx, rcol := findCol(rightSchema, lcol.Name); rcol != nil {
common = append(common, leftCol)
replacements[tableCol{strings.ToLower(rcol.Source), strings.ToLower(rcol.Name)}] = tableCol{
strings.ToLower(lcol.Source), strings.ToLower(lcol.Name),
}
}

if len(conditions) == 0 {
return plan.NewCrossJoin(join.Left, join.Right), nil
}

for i, col := range rightSchema {
if _, ok := seen[col.Name]; !ok {
right = append(
right,
conditions = append(
conditions,
expression.NewEquals(
leftCol,
expression.NewGetFieldWithTable(
len(leftSchema)+i,
col.Type,
col.Source,
col.Name,
col.Nullable,
len(leftSchema)+idx,
rcol.Type,
rcol.Source,
rcol.Name,
rcol.Nullable,
),
)
}
}

projections := append(append(common, left...), right...)

tj := &transformedJoin{
node: plan.NewProject(
projections,
plan.NewInnerJoin(
join.Left,
join.Right,
expression.JoinAnd(conditions...),
),
),
condCols: colsToUnresolve,
)
} else {
left = append(left, leftCol)
}

transformed = append(transformed, tj)

return tj.node, nil
})

if err != nil || len(transformed) == 0 {
return node, err
}

var transformedSeen bool
return node.TransformUp(func(node sql.Node) (sql.Node, error) {
if ok, _ := isTransformedNode(node, transformed); ok {
transformedSeen = true
return node, nil
}

if !transformedSeen {
return node, nil
}

expressioner, ok := node.(sql.Expressioner)
if !ok {
return node, nil
}

return expressioner.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
var col, table string
switch e := e.(type) {
case *expression.GetField:
col, table = e.Name(), e.Table()
case *expression.UnresolvedColumn:
col, table = e.Name(), e.Table()
default:
return e, nil
}

sources, ok := colsToUnresolve[col]
if !ok {
return e, nil
}

if !mustUnresolve(aliasTables, table, sources.wrong) {
return e, nil
}

return expression.NewUnresolvedQualifiedColumn(
sources.correct,
col,
), nil
})
})
}

func isTransformedNode(node sql.Node, transformed []*transformedJoin) (is bool, colsToUnresolve map[string]*transformedSource) {
var project *plan.Project
var join *plan.InnerJoin
switch n := node.(type) {
case *plan.Project:
var ok bool
join, ok = n.Child.(*plan.InnerJoin)
if !ok {
return
}

project = n
case *plan.InnerJoin:
join = n

default:
return
if len(conditions) == 0 {
return plan.NewCrossJoin(n.Left, n.Right), nil
}

for _, t := range transformed {
tproject, ok := t.node.(*plan.Project)
if !ok {
return
}

tjoin, ok := tproject.Child.(*plan.InnerJoin)
if !ok {
return
}

if project != nil && !reflect.DeepEqual(project.Projections, tproject.Projections) {
continue
}

if reflect.DeepEqual(join.Cond, tjoin.Cond) {
is = true
colsToUnresolve = t.condCols
for i, col := range rightSchema {
source := strings.ToLower(col.Source)
name := strings.ToLower(col.Name)
if _, ok := replacements[tableCol{source, name}]; !ok {
right = append(
right,
expression.NewGetFieldWithTable(
len(leftSchema)+i,
col.Type,
col.Source,
col.Name,
col.Nullable,
),
)
}
}

return
return plan.NewProject(
append(append(common, left...), right...),
plan.NewInnerJoin(n.Left, n.Right, expression.JoinAnd(conditions...)),
), nil
}

func mustUnresolve(aliasTable map[string][]string, table string, wrongSources []string) bool {
return isIn(table, wrongSources) || isAliasFor(aliasTable, table, wrongSources)
}

func isIn(s string, l []string) bool {
for _, e := range l {
if s == e {
return true
func findCol(s sql.Schema, name string) (int, *sql.Column) {
for i, c := range s {
if strings.ToLower(c.Name) == strings.ToLower(name) {
return i, c
}
}

return false
return -1, nil
}

func isAliasFor(aliasTable map[string][]string, table string, wrongSources []string) bool {
tables, ok := aliasTable[table]
if !ok {
return false
}
func replaceExpressions(
n sql.Expressioner,
replacements map[tableCol]tableCol,
tableAliases map[string]string,
) (sql.Node, error) {
return n.TransformExpressions(func(e sql.Expression) (sql.Expression, error) {
switch e := e.(type) {
case *expression.GetField, *expression.UnresolvedColumn:
var tableName = e.(sql.Tableable).Table()
if t, ok := tableAliases[strings.ToLower(tableName)]; ok {
tableName = t
}

for _, t := range tables {
if isIn(t, wrongSources) {
return true
name := e.(sql.Nameable).Name()
if col, ok := replacements[tableCol{strings.ToLower(tableName), strings.ToLower(name)}]; ok {
return expression.NewUnresolvedQualifiedColumn(col.table, col.col), nil
}
}
}

return false
return e, nil
})
}
Loading

0 comments on commit 5f48ea3

Please sign in to comment.