Skip to content

Commit

Permalink
Rewrite USING to ON condition for joins (#13931)
Browse files Browse the repository at this point in the history
Co-authored-by: Andres Taylor <andres@planetscale.com>
Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
frouioui and systay committed Sep 18, 2023
1 parent 7012754 commit 295e417
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 69 deletions.
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,11 @@ func TestBuggyOuterJoin(t *testing.T) {

mcmp.Exec("select t1.id1, t2.id1 from t1 left join t1 as t2 on t2.id1 = t2.id2")
}

// TestLeftJoinUsingUnsharded seeds uks.unsharded with five rows and then runs
// a self LEFT JOIN ... USING(id1) over the same table on mcmp's VtConn.
func TestLeftJoinUsingUnsharded(t *testing.T) {
	compare, cleanup := start(t)
	defer cleanup()

	// The statements are issued in order on the same connection.
	for _, query := range []string{
		"insert into uks.unsharded(id1) values (1),(2),(3),(4),(5)",
		"select * from uks.unsharded as A left join uks.unsharded as B using(id1)",
	} {
		utils.Exec(t, compare.VtConn, query)
	}
}
2 changes: 2 additions & 0 deletions go/vt/vterrors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ var (
VT09008 = errorWithoutState("VT09008", vtrpcpb.Code_FAILED_PRECONDITION, "vexplain queries/all will actually run queries", "vexplain queries/all will actually run queries. `/*vt+ EXECUTE_DML_QUERIES */` must be set to run DML queries in vtexplain. Example: `vexplain /*vt+ EXECUTE_DML_QUERIES */ queries delete from t1`")
VT09009 = errorWithoutState("VT09009", vtrpcpb.Code_FAILED_PRECONDITION, "stream is supported only for primary tablet type, current type: %v", "Stream is only supported for primary tablets, please use a stream on those tablets.")
VT09010 = errorWithoutState("VT09010", vtrpcpb.Code_FAILED_PRECONDITION, "SHOW VITESS_THROTTLER STATUS works only on primary tablet", "SHOW VITESS_THROTTLER STATUS works only on primary tablet.")
VT09015 = errorWithoutState("VT09015", vtrpcpb.Code_FAILED_PRECONDITION, "schema tracking required", "This query cannot be planned without more information on the SQL schema. Please turn on schema tracking or add authoritative columns information to your VSchema.")

VT10001 = errorWithoutState("VT10001", vtrpcpb.Code_ABORTED, "foreign key constraints are not allowed", "Foreign key constraints are not allowed, see https://vitess.io/blog/2021-06-15-online-ddl-why-no-fk/.")

Expand Down Expand Up @@ -123,6 +124,7 @@ var (
VT09008,
VT09009,
VT09010,
VT09015,
VT10001,
VT12001,
VT13001,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vterrors/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ func TestStackFormat(t *testing.T) {
// but the change in errors#27 made them incomparable. Assert that
// various kinds of errors have a functional equality operator, even
// if the result of that equality is always false.
func TestErrorEquality(t *testing.T) {
func TestErrorEquality(_ *testing.T) {
vals := []error{
nil,
io.EOF,
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/from_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -6499,5 +6499,27 @@
"zlookup_unique.t1"
]
}
},
{
"comment": "left join with using has to be transformed into left join with on condition",
"query": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)",
"plan": {
"QueryType": "SELECT",
"Original": "SELECT * FROM unsharded_authoritative as A LEFT JOIN unsharded_authoritative as B USING(col1)",
"Instructions": {
"OperatorType": "Route",
"Variant": "Unsharded",
"Keyspace": {
"Name": "main",
"Sharded": false
},
"FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1",
"Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1",
"Table": "unsharded_authoritative"
},
"TablesUsed": [
"main.unsharded_authoritative"
]
}
}
]
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/testdata/unsupported_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@
"comment": "join with USING construct",
"query": "select * from user join user_extra using(id)",
"v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries",
"gen4-plan": "can't handle JOIN USING without authoritative tables"
"gen4-plan": "VT09015: schema tracking required"
},
{
"comment": "join with USING construct with 3 tables",
"query": "select user.id from user join user_extra using(id) join music using(id2)",
"v3-plan": "VT12001: unsupported: JOIN with USING(column_list) clause for complex queries",
"gen4-plan": "can't handle JOIN USING without authoritative tables"
"gen4-plan": "VT09015: schema tracking required"
},
{
"comment": "natural left join",
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool {
return false
}

if err := a.rewriter.up(cursor); err != nil {
a.setError(err)
return true
}

a.leaveProjection(cursor)
return a.shouldContinue()
}
Expand Down
7 changes: 0 additions & 7 deletions go/vt/vtgate/semantics/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,6 @@ func (b *binder) up(cursor *sqlparser.Cursor) error {
}
currScope.joinUsing[ident.Lowered()] = deps.direct
}
if len(node.Using) > 0 {
err := rewriteJoinUsing(currScope, node.Using, b.org)
if err != nil {
return err
}
node.Using = nil
}
case *sqlparser.ColName:
currentScope := b.scoper.currentScope()
deps, err := b.resolveColumn(node, currentScope, false)
Expand Down
208 changes: 156 additions & 52 deletions go/vt/vtgate/semantics/early_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ limitations under the License.
package semantics

import (
"fmt"
"strconv"
"strings"

"vitess.io/vitess/go/vt/vtgate/evalengine"

Expand Down Expand Up @@ -108,6 +108,33 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error {
return nil
}

// up handles JoinTableExpr nodes that carry a USING list: it turns the USING
// list into an ON expression via rewriteJoinUsing and then runs the binder
// over that new expression. Every other node is left untouched.
//
// The work happens in the `up` phase because the scope must already have been
// filled in with the available tables.
func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error {
	join, isJoin := cursor.Node().(*sqlparser.JoinTableExpr)
	if !isJoin {
		return nil
	}
	if len(join.Condition.Using) == 0 {
		return nil
	}

	if err := rewriteJoinUsing(r.binder, join); err != nil {
		return err
	}

	// The binder has already been over this join, so it has to be invoked
	// again on the freshly built ON expression to bind its columns to the
	// right tables. The first binder error stops the walk and is returned.
	var bindErr error
	rebind := func(c *sqlparser.Cursor) bool {
		if e := r.binder.up(c); e != nil {
			bindErr = e
			return false
		}
		return true
	}
	sqlparser.Rewrite(join.Condition.On, nil, rebind)
	return bindErr
}

func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.SelectExprs) error {
currentScope := r.scoper.currentScope()
var selExprs sqlparser.SelectExprs
Expand Down Expand Up @@ -279,67 +306,144 @@ func rewriteOrFalse(orExpr sqlparser.OrExpr) sqlparser.Expr {
return nil
}

func rewriteJoinUsing(
current *scope,
using sqlparser.Columns,
org originable,
) error {
joinUsing := current.prepareUsingMap()
predicates := make([]sqlparser.Expr, 0, len(using))
for _, column := range using {
var foundTables []sqlparser.TableName
for _, tbl := range current.tables {
if !tbl.authoritative() {
return vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "can't handle JOIN USING without authoritative tables")
}
// rewriteJoinUsing replaces the USING clause of a single join with the
// equivalent ON condition. The predicates come from buildJoinPredicates; they
// are AND-ed together and stored in join.Condition.On, after which
// join.Condition.Using is cleared.
//
// For example, given the query:
//
//	SELECT * FROM t1 JOIN t2 USING (col1, col2)
//
// the join is rewritten to:
//
//	SELECT * FROM t1 JOIN t2 ON (t1.col1 = t2.col1 AND t1.col2 = t2.col2)
//
// Any error from buildJoinPredicates is returned unchanged. When no predicate
// is produced the join is left exactly as it was.
func rewriteJoinUsing(b *binder, join *sqlparser.JoinTableExpr) error {
	preds, err := buildJoinPredicates(b, join)
	switch {
	case err != nil:
		return err
	case len(preds) == 0:
		// nothing to install; keep the USING list in place
		return nil
	}

	join.Condition.On = sqlparser.AndExpressions(preds...)
	join.Condition.Using = nil
	return nil
}

currTable := tbl.getTableSet(org)
usingCols := joinUsing[currTable]
if usingCols == nil {
usingCols = map[string]TableSet{}
}
for _, col := range tbl.getColumns() {
_, found := usingCols[strings.ToLower(col.Name)]
if found {
tblName, err := tbl.Name()
if err != nil {
return err
}
// buildJoinPredicates constructs the join predicates for the USING columns of join.
// For every column in join.Condition.Using it asks findTablesWithColumn which tables
// carry that column and appends the predicates createComparisonPredicates builds for
// those tables. The first error from findTablesWithColumn stops the loop and is
// returned together with a nil slice.
func buildJoinPredicates(b *binder, join *sqlparser.JoinTableExpr) ([]sqlparser.Expr, error) {
var predicates []sqlparser.Expr

foundTables = append(foundTables, tblName)
break // no need to look at other columns in this table
}
for _, column := range join.Condition.Using {
foundTables, err := findTablesWithColumn(b, join, column)
if err != nil {
return nil, err
}

predicates = append(predicates, createComparisonPredicates(column, foundTables)...)
}

return predicates, nil
}

// findOnlyOneTableInfoThatHasColumn walks the table expression tbl and collects the
// TableInfo of the tables that expose the given column. A nil slice with a nil error
// means no table under tbl has the column.
//
// Despite the name, the result can hold more than one entry: the JoinTableExpr case
// concatenates the matches of both sides.
func findOnlyOneTableInfoThatHasColumn(b *binder, tbl sqlparser.TableExpr, column sqlparser.IdentifierCI) ([]TableInfo, error) {
switch tbl := tbl.(type) {
// A single table: look up its TableInfo and return it when one of the columns it
// reports matches column by name (column.EqualString).
case *sqlparser.AliasedTableExpr:
ts := b.tc.tableSetFor(tbl)
tblInfo := b.tc.Tables[ts.TableOffset()]
for _, info := range tblInfo.getColumns() {
if column.EqualString(info.Name) {
return []TableInfo{tblInfo}, nil
}
}
for i, lft := range foundTables {
for j := i + 1; j < len(foundTables); j++ {
rgt := foundTables[j]
predicates = append(predicates, &sqlparser.ComparisonExpr{
Operator: sqlparser.EqualOp,
Left: sqlparser.NewColNameWithQualifier(column.String(), lft),
Right: sqlparser.NewColNameWithQualifier(column.String(), rgt),
})
return nil, nil
// A nested join: the right side is resolved first, then the left; the result lists
// the left matches followed by the right matches.
// NOTE(review): matches on both sides are concatenated without an ambiguity check,
// unlike the ParenTableExpr case below — confirm this is intended.
case *sqlparser.JoinTableExpr:
tblInfoR, err := findOnlyOneTableInfoThatHasColumn(b, tbl.RightExpr, column)
if err != nil {
return nil, err
}
tblInfoL, err := findOnlyOneTableInfoThatHasColumn(b, tbl.LeftExpr, column)
if err != nil {
return nil, err
}

return append(tblInfoL, tblInfoR...), nil
// A parenthesised table list: at most one entry may provide the column. A second
// match is reported as vterrors.VT03021 for the column name.
case *sqlparser.ParenTableExpr:
var tblInfo []TableInfo
for _, parenTable := range tbl.Exprs {
newTblInfo, err := findOnlyOneTableInfoThatHasColumn(b, parenTable, column)
if err != nil {
return nil, err
}
if tblInfo != nil && newTblInfo != nil {
return nil, vterrors.VT03021(column.String())
}
if newTblInfo != nil {
tblInfo = newTblInfo
}
}
return tblInfo, nil
// NOTE(review): any other TableExpr type panics instead of returning an error,
// even though the function has an error return.
default:
panic(fmt.Sprintf("unsupported TableExpr type in JOIN: %T", tbl))
}
}

// now, we go up the scope until we find a SELECT with a where clause we can add this predicate to
for current != nil {
sel, found := current.stmt.(*sqlparser.Select)
if found {
if sel.Where == nil {
sel.Where = &sqlparser.Where{
Type: sqlparser.WhereClause,
Expr: sqlparser.AndExpressions(predicates...),
}
} else {
sel.Where.Expr = sqlparser.AndExpressions(append(predicates, sel.Where.Expr)...)
}
return nil
// findTablesWithColumn resolves a USING column against both sides of join and
// returns the names of the tables that carry it: left-hand tables first, then
// right-hand ones.
//
// If either side has no table with the column, the lookup fails with VT09015
// wrapped in a ShardedError. Errors from the per-side lookup and from
// TableInfo.Name are passed through as-is.
func findTablesWithColumn(b *binder, join *sqlparser.JoinTableExpr, column sqlparser.IdentifierCI) ([]sqlparser.TableName, error) {
	lhs, err := findOnlyOneTableInfoThatHasColumn(b, join.LeftExpr, column)
	if err != nil {
		return nil, err
	}
	rhs, err := findOnlyOneTableInfoThatHasColumn(b, join.RightExpr, column)
	if err != nil {
		return nil, err
	}

	// both sides of the join must expose the column for the rewrite to work
	if lhs == nil || rhs == nil {
		return nil, ShardedError{Inner: vterrors.VT09015()}
	}

	var names []sqlparser.TableName
	for _, side := range [][]TableInfo{lhs, rhs} {
		for _, info := range side {
			name, err := info.Name()
			if err != nil {
				return nil, err
			}
			names = append(names, name)
		}
	}
	return names, nil
}

// createComparisonPredicates pairs every table in foundTables with each table that
// follows it in the slice and, for every such pair, appends the predicate that
// createComparisonBetween builds on column. With fewer than two tables no pair
// exists and the returned slice is nil.
func createComparisonPredicates(column sqlparser.IdentifierCI, foundTables []sqlparser.TableName) []sqlparser.Expr {
var predicates []sqlparser.Expr
for i, lft := range foundTables {
for j := i + 1; j < len(foundTables); j++ {
rgt := foundTables[j]
predicates = append(predicates, createComparisonBetween(column, lft, rgt))
}
current = current.parent
}
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "did not find WHERE clause")
return predicates
}

// createComparisonBetween builds the equality `lft.column = rgt.column`, with
// each side being the column name qualified by the corresponding table.
func createComparisonBetween(column sqlparser.IdentifierCI, lft, rgt sqlparser.TableName) *sqlparser.ComparisonExpr {
	colName := column.String()

	eq := &sqlparser.ComparisonExpr{Operator: sqlparser.EqualOp}
	eq.Left = sqlparser.NewColNameWithQualifier(colName, lft)
	eq.Right = sqlparser.NewColNameWithQualifier(colName, rgt)
	return eq
}

func (r *earlyRewriter) expandTableColumns(
Expand Down
Loading

0 comments on commit 295e417

Please sign in to comment.