Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schemadiff: validate views' referenced columns #12147

Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ require (
github.com/openark/golib v0.0.0-20210531070646-355f37940af8
github.com/planetscale/log v0.0.0-20221118170849-fb599bc35c50
github.com/slok/noglog v0.2.0
go.uber.org/multierr v1.9.0
mattlord marked this conversation as resolved.
Show resolved Hide resolved
go.uber.org/zap v1.23.0
golang.org/x/exp v0.0.0-20221114191408-850992195362
)
Expand Down Expand Up @@ -187,7 +188,6 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect
go.opencensus.io v0.24.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
go4.org/intern v0.0.0-20220617035311-6925f38cc365 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 // indirect
golang.org/x/exp/typeparams v0.0.0-20221114191408-850992195362 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,8 @@ go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0
go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8=
go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
go.uber.org/zap v1.23.0 h1:OjGQ5KQDEUawVHxNwQgPpiypGHOxo2mNZsOqTak4fFY=
Expand Down
28 changes: 28 additions & 0 deletions go/vt/schemadiff/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,31 @@ type ViewDependencyUnresolvedError struct {
func (e *ViewDependencyUnresolvedError) Error() string {
return fmt.Sprintf("view %s has unresolved/loop dependencies", sqlescape.EscapeID(e.View))
}

// InvalidColumnReferencedInViewError reports that a view references a table or
// a column that cannot be resolved against the schema.
type InvalidColumnReferencedInViewError struct {
// View is the name of the view that holds the invalid reference.
View string
// Table is the referenced table name. It is left empty when the error is
// about an unqualified column.
Table string
// Column is the referenced column name. It is left empty when the error is
// about a missing table rather than a missing column.
Column string
// NonUnique is set when an unqualified column was found in more than one
// of the tables referenced by the view.
NonUnique bool
}

// Error implements the error interface. The message is selected by the fields
// that are set, checked in this order: an empty Column, a non-empty Table,
// NonUnique, and finally the fallback for an unqualified column.
func (e *InvalidColumnReferencedInViewError) Error() string {
switch {
case e.Column == "":
// No column recorded: the problem is the table reference itself.
return fmt.Sprintf("view %s references non-existing table %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table))
mattlord marked this conversation as resolved.
Show resolved Hide resolved
case e.Table != "":
// Both table and column recorded: a qualified column reference.
return fmt.Sprintf("view %s references non existing column %s.%s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Table), sqlescape.EscapeID(e.Column))
mattlord marked this conversation as resolved.
Show resolved Hide resolved
case e.NonUnique:
// Unqualified column flagged as non unique.
return fmt.Sprintf("view %s references unqualified but non unique column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column))
default:
// Unqualified column that is not flagged as non unique.
return fmt.Sprintf("view %s references unqualified but non existing column %s", sqlescape.EscapeID(e.View), sqlescape.EscapeID(e.Column))
}
}

// EntityNotFoundError reports that an entity with the given name could not be found.
type EntityNotFoundError struct {
	// Name is the name of the entity that was looked up.
	Name string
}

// Error implements the error interface; the entity name is escaped as an identifier.
func (e *EntityNotFoundError) Error() string {
	escapedName := sqlescape.EscapeID(e.Name)
	return fmt.Sprintf("entity %s not found", escapedName)
}
263 changes: 263 additions & 0 deletions go/vt/schemadiff/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ import (
"sort"
"strings"

"go.uber.org/multierr"

"vitess.io/vitess/go/vt/sqlparser"
)

// tablesColumnsMap maps an entity name (table or view) to the set of column names it exposes.
type tablesColumnsMap map[string]map[string]struct{}

// Schema represents a database schema, which may contain entities such as tables and views.
// Schema is not in itself an Entity, since it is more of a collection of entities.
type Schema struct {
Expand Down Expand Up @@ -309,6 +313,11 @@ func (s *Schema) normalize() error {
}
}

// Validate views' referenced columns: do these columns actually exist in referenced tables/views?
if err := s.ValidateViewReferences(); err != nil {
return err
}

// Validate table definitions
for _, t := range s.tables {
if err := t.validate(); err != nil {
Expand Down Expand Up @@ -750,3 +759,257 @@ func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) {
}
return dup, nil
}

func (s *Schema) ValidateViewReferences() error {
var errs error
availableColumns := tablesColumnsMap{}

for _, e := range s.Entities() {
entityColumns, err := s.getEntityColumnNames(e.Name(), availableColumns)
if err != nil {
errs = multierr.Append(errs, err)
dbussink marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use the Vitess concurrency package here along with vterrors.Aggregate? You can see that used throughout the code base if you look for .AggrError(vterrors.Aggregate). Between the concurrency and vterrors packages we should have whatever related functionality you need here for error handling. For example:

func (ts *trafficSwitcher) ForAllSources(f func(source *workflow.MigrationSource) error) error {
var wg sync.WaitGroup
allErrors := &concurrency.AllErrorRecorder{}
for _, source := range ts.sources {
wg.Add(1)
go func(source *workflow.MigrationSource) {
defer wg.Done()
if err := f(source); err != nil {
allErrors.RecordError(err)
}
}(source)
}
wg.Wait()
return allErrors.AggrError(vterrors.Aggregate)
}

Copy link
Contributor

@dbussink dbussink Jan 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like it does use locking which is unnecessary here. Dunno if we care about the overhead of that here? That logic seems designed for concurrent error gathering which isn't what we're doing here.

Using multierr seems simpler here and it's already an indirect dependency? Not a really strong opinion though, we can also use this but it seems a bit off for what it was designed for.

I think this logic here is temporary anyway, since once golang/go#53435 is available with Go 1.20 we probably want to switch to that anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

heh, race condition

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How strongly do people feel about this? Looking at the two implementations, multierr does seem to be more fitting to our purpose, but I don't feel strongly.

Copy link
Contributor

@mattlord mattlord Jan 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you only need multiple errors you can use vterrors.Aggregate(), no need for the concurrency piece (as noted, geared toward N goroutines) like we do here e.g.:

var terrs []error
for !empty {
select {
case result := <-resultCh:
switch result.state {
case vcopierCopyTaskCancel:
// A task cancelation probably indicates an expired context due
// to a PlannedReparentShard or elapsed copy phase duration,
// neither of which are error conditions.
case vcopierCopyTaskComplete:
// Get the latest lastpk, purely for logging purposes.
lastpk = result.args.lastpk
case vcopierCopyTaskFail:
// Aggregate non-nil errors.
terrs = append(terrs, result.err)
}
default:
empty = true
}
}
if len(terrs) > 0 {
terr := vterrors.Aggregate(terrs)
log.Warningf("task error in workflow %s: %v", vc.vr.WorkflowName, terr)
return fmt.Errorf("task error: %v", terr)
}

I don't feel strongly about it though. Up to you.

continue
}
availableColumns[e.Name()] = map[string]struct{}{}
for _, col := range entityColumns {
availableColumns[e.Name()][col.Lowered()] = struct{}{}
}
}

// Add dual table with no explicit columns for dual style expressions in views.
availableColumns["dual"] = map[string]struct{}{}

for _, view := range s.Views() {
// First gather all referenced tables and table aliases
tableAliases := map[string]string{}
tableReferences := map[string]struct{}{}
err := gatherTableInformationForView(view, availableColumns, tableReferences, tableAliases)
errs = multierr.Append(errs, err)

// Now we can walk the view again and check each column expression
// to see if there's an existing column referenced.
err = gatherColumnReferenceInformationForView(view, availableColumns, tableReferences, tableAliases)
errs = multierr.Append(errs, err)
}
return errs
}

// gatherTableInformationForView walks the view's SELECT statement and inspects
// every aliased table expression that resolves to a table name.
//
// A name present in availableColumns is added to tableReferences, and when the
// expression carries an alias, the alias is recorded in tableAliases as
// pointing to that name. A name absent from availableColumns produces an
// InvalidColumnReferencedInViewError, at most once per name.
//
// The returned error aggregates all missing-table errors, or is nil when there
// are none. An error returned by the walk itself takes precedence.
func gatherTableInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error {
	var collected error
	// Names already reported as missing, so each is reported once per view.
	reported := map[string]struct{}{}

	visit := func(node sqlparser.SQLNode) (bool, error) {
		tableExpr, isTableExpr := node.(*sqlparser.AliasedTableExpr)
		if !isTableExpr {
			return true, nil
		}
		name := sqlparser.GetTableName(tableExpr.Expr).String()
		if name == "" {
			return true, nil
		}

		if _, known := availableColumns[name]; known {
			tableReferences[name] = struct{}{}
			if alias := tableExpr.As.String(); alias != "" {
				tableAliases[alias] = name
			}
			return true, nil
		}

		if _, seen := reported[name]; !seen {
			reported[name] = struct{}{}
			collected = multierr.Append(collected, &InvalidColumnReferencedInViewError{
				View:  view.Name(),
				Table: name,
			})
		}
		return true, nil
	}

	// Defensive check: the visitor above never returns an error of its own,
	// but an error from the walk outranks any reference issue collected so far.
	if walkErr := sqlparser.Walk(visit, view.Select); walkErr != nil {
		return walkErr
	}
	return collected
}

func gatherColumnReferenceInformationForView(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, tableAliases map[string]string) error {
var errs error
qualifiedColumnErrors := make(map[string]map[string]struct{})
unqualifiedColumnErrors := make(map[string]struct{})

err := sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
switch node := node.(type) {
case *sqlparser.ColName:
if node.Qualifier.IsEmpty() {
err := verifyUnqualifiedColumn(view, availableColumns, tableReferences, node.Name, unqualifiedColumnErrors)
errs = multierr.Append(errs, err)
} else {
err := verifyQualifiedColumn(view, availableColumns, tableAliases, node, qualifiedColumnErrors)
errs = multierr.Append(errs, err)
}
}
return true, nil
}, view.Select)
if err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI - sqlparser.Walk will only return the error you return from inside the visitor function, and since you don't return any errors from that, no errors will ever make it out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay ah, great! Thank you. I'll keep the check as it is, for safety, but good to know!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay Right, but the errors are gathered here in errs? And those are at the end returned if the walker itself doesn't error (which is not expected).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dbussink yeah, that was my point. Not really necessary to catch the returned error from sqlparser.Walk since that is not how we are dealing with the errors. When I know it will never return an error, I usually write:

	_ = sqlparser.Walk(...)

OTOH - it's probably good defensive programming to do as @shlomi-noach is doing here and catching and checking the error anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@systay Right, guess as someone not knowing the details too much about how Walk is implemented, I wouldn't know that it can't return an error and would still write it defensively 😄.

// parsing error. Forget about any view dependency issues we may have found. This is way more important
return err
}
return errs
}

// verifyUnqualifiedColumn checks a column reference that has no table qualifier
// against all tables referenced by the view. Such a reference is only valid
// when exactly one referenced table has the column.
//
// It returns nil when the column is found exactly once, and also when a
// referenced table is absent from availableColumns (missing tables are reported
// elsewhere). Otherwise it returns an InvalidColumnReferencedInViewError, with
// NonUnique set when the column was found in more than one table.
// unqualifiedColumnErrors remembers which columns were already reported, so a
// repeated offence yields nil instead of a duplicate error.
func verifyUnqualifiedColumn(view *CreateViewEntity, availableColumns tablesColumnsMap, tableReferences map[string]struct{}, nodeName sqlparser.IdentifierCI, unqualifiedColumnErrors map[string]struct{}) error {
	key := nodeName.Lowered()

	// reportOnce builds the error for this column unless one was already
	// recorded under the same key.
	reportOnce := func(nonUnique bool) error {
		if _, reported := unqualifiedColumnErrors[key]; reported {
			return nil
		}
		unqualifiedColumnErrors[key] = struct{}{}
		return &InvalidColumnReferencedInViewError{
			View:      view.Name(),
			Column:    nodeName.String(),
			NonUnique: nonUnique,
		}
	}

	matches := 0
	for table := range tableReferences {
		columns, known := availableColumns[table]
		if !known {
			// The missing table reference was already handled earlier;
			// nothing to add here.
			return nil
		}
		if _, present := columns[key]; !present {
			continue
		}
		matches++
		if matches > 1 {
			// Second table exposing the same column: the reference is ambiguous.
			return reportOnce(true)
		}
	}

	if matches == 1 {
		// Found in exactly one table: the reference is valid.
		return nil
	}
	return reportOnce(false)
}

// verifyQualifiedColumn checks a column reference that carries a table qualifier.
// The qualifier is first translated through tableAliases, then the column is
// looked up in availableColumns for the resulting table name.
//
// It returns nil when the table is not in availableColumns, when the column
// exists, or when the same table/column pair was already recorded in
// columnErrors. Otherwise it records the pair in columnErrors and returns an
// InvalidColumnReferencedInViewError.
func verifyQualifiedColumn(
view *CreateViewEntity,
availableColumns tablesColumnsMap,
tableAliases map[string]string, node *sqlparser.ColName,
columnErrors map[string]map[string]struct{},
) error {
tableName := node.Qualifier.Name.String()
// The qualifier may be an alias; if so, replace it with the name it stands for.
if aliased, ok := tableAliases[tableName]; ok {
tableName = aliased
}
cols, ok := availableColumns[tableName]
if !ok {
// Already dealt with missing tables earlier on, we don't have
// any error to add here.
return nil
}
_, ok = cols[node.Name.Lowered()]
if ok {
// Found the column in the table, all good.
return nil
}

// Lazily create the per-table set of already reported columns.
if _, ok := columnErrors[tableName]; !ok {
columnErrors[tableName] = make(map[string]struct{})
}

// Report each table/column pair only once.
if _, ok := columnErrors[tableName][node.Name.Lowered()]; ok {
return nil
dbussink marked this conversation as resolved.
Show resolved Hide resolved
}
columnErrors[tableName][node.Name.Lowered()] = struct{}{}
return &InvalidColumnReferencedInViewError{
View: view.Name(),
Table: tableName,
Column: node.Name.String(),
}
}

// getEntityColumnNames returns the names of the columns exposed by the entity
// with the given name, which may be either a table or a view.
//
// For a table the names come from getTableColumnNames. For a view they come
// from getViewColumnNames, which consults availableColumns.
//
// When no entity has that name, the pseudo-table DUAL (matched
// case-insensitively) yields no columns and no error. Any other unknown name,
// or an entity that is neither a table nor a view, yields an
// EntityNotFoundError.
func (s *Schema) getEntityColumnNames(entityName string, availableColumns tablesColumnsMap) (
	columnNames []*sqlparser.IdentifierCI,
	err error,
) {
	entity := s.Entity(entityName)
	if entity == nil {
		if strings.EqualFold(entityName, "dual") {
			// this is fine. DUAL does not exist but is allowed
			return nil, nil
		}
		return nil, &EntityNotFoundError{Name: entityName}
	}
	// The entity is either a table or a view
	switch entity := entity.(type) {
	case *CreateTableEntity:
		return s.getTableColumnNames(entity), nil
	case *CreateViewEntity:
		return s.getViewColumnNames(entity, availableColumns)
	}
	return nil, &EntityNotFoundError{Name: entityName}
}

// getTableColumnNames collects the name of every column in the table's spec,
// in the order the columns are listed there.
func (s *Schema) getTableColumnNames(t *CreateTableEntity) (columnNames []*sqlparser.IdentifierCI) {
	columns := t.TableSpec.Columns
	for _, column := range columns {
		columnNames = append(columnNames, &column.Name)
	}
	return columnNames
}

// getViewColumnNames collects the column names produced by the view's select
// expressions.
//
// An aliased expression contributes its alias or, when it has none, the
// rendered text of the expression. A qualified star (`t.*`) contributes every
// column recorded in availableColumns under that qualifier. An unqualified
// star contributes the recorded columns of every entity the view depends on.
//
// On error, the returned slice is nil.
func (s *Schema) getViewColumnNames(v *CreateViewEntity, availableColumns tablesColumnsMap) (
	columnNames []*sqlparser.IdentifierCI,
	err error,
) {
	// appendColumnsOf adds every column recorded for the given entity name.
	appendColumnsOf := func(entityName string) {
		for colName := range availableColumns[entityName] {
			identifier := sqlparser.NewIdentifierCI(colName)
			columnNames = append(columnNames, &identifier)
		}
	}

	visit := func(node sqlparser.SQLNode) (bool, error) {
		switch expr := node.(type) {
		case *sqlparser.AliasedExpr:
			if expr.As.String() == "" {
				// No alias: the column is named after the expression text.
				identifier := sqlparser.NewIdentifierCI(sqlparser.String(expr.Expr))
				columnNames = append(columnNames, &identifier)
				return true, nil
			}
			columnNames = append(columnNames, &expr.As)
		case *sqlparser.StarExpr:
			if qualifier := expr.TableName.Name.String(); qualifier != "" {
				appendColumnsOf(qualifier)
				return true, nil
			}
			dependentNames, err := getViewDependentTableNames(v.CreateView)
			if err != nil {
				return false, err
			}
			// add all columns from all referenced tables and views
			for _, entityName := range dependentNames {
				appendColumnsOf(entityName)
			}
		}
		return true, nil
	}

	if err = sqlparser.Walk(visit, v.Select.GetColumns()); err != nil {
		return nil, err
	}
	return columnNames, nil
}
Loading