diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index c21f895631c..54130445044 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -237,6 +237,15 @@ func (e *InvalidColumnInCheckConstraintError) Error() string { sqlescape.EscapeID(e.Column), sqlescape.EscapeID(e.Constraint), sqlescape.EscapeID(e.Table)) } +type ForeignKeyDependencyUnresolvedError struct { + Table string +} + +func (e *ForeignKeyDependencyUnresolvedError) Error() string { + return fmt.Sprintf("table %s has unresolved/loop foreign key dependencies", + sqlescape.EscapeID(e.Table)) +} + type InvalidColumnInForeignKeyConstraintError struct { Table string Constraint string @@ -244,10 +253,54 @@ type InvalidColumnInForeignKeyConstraintError struct { } func (e *InvalidColumnInForeignKeyConstraintError) Error() string { - return fmt.Sprintf("invalid column %s referenced by foreign key constraint %s in table %s", + return fmt.Sprintf("invalid column %s covered by foreign key constraint %s in table %s", sqlescape.EscapeID(e.Column), sqlescape.EscapeID(e.Constraint), sqlescape.EscapeID(e.Table)) } +type InvalidReferencedColumnInForeignKeyConstraintError struct { + Table string + Constraint string + ReferencedTable string + ReferencedColumn string +} + +func (e *InvalidReferencedColumnInForeignKeyConstraintError) Error() string { + return fmt.Sprintf("invalid column %s.%s referenced by foreign key constraint %s in table %s", + sqlescape.EscapeID(e.ReferencedTable), sqlescape.EscapeID(e.ReferencedColumn), sqlescape.EscapeID(e.Constraint), sqlescape.EscapeID(e.Table)) +} + +type ForeignKeyColumnCountMismatchError struct { + Table string + Constraint string + ColumnCount int + ReferencedTable string + ReferencedColumnCount int +} + +func (e *ForeignKeyColumnCountMismatchError) Error() string { + return fmt.Sprintf("mismatching column count %d referenced by foreign key constraint %s in table %s. 
Expected %d", + e.ReferencedColumnCount, sqlescape.EscapeID(e.Constraint), sqlescape.EscapeID(e.Table), e.ColumnCount) +} + +type ForeignKeyColumnTypeMismatchError struct { + Table string + Constraint string + Column string + ReferencedTable string + ReferencedColumn string +} + +func (e *ForeignKeyColumnTypeMismatchError) Error() string { + return fmt.Sprintf("mismatching column type %s.%s and %s.%s referenced by foreign key constraint %s in table %s", + sqlescape.EscapeID(e.ReferencedTable), + sqlescape.EscapeID(e.ReferencedColumn), + sqlescape.EscapeID(e.Table), + sqlescape.EscapeID(e.Column), + sqlescape.EscapeID(e.Constraint), + sqlescape.EscapeID(e.Table), + ) +} + type ViewDependencyUnresolvedError struct { View string } diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index 3cb6f4436b6..939abc47960 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -35,6 +35,9 @@ type Schema struct { named map[string]Entity sorted []Entity + + foreignKeyParents []*CreateTableEntity // subset of tables + foreignKeyChildren []*CreateTableEntity // subset of tables } // newEmptySchema is used internally to initialize a Schema object @@ -44,6 +47,9 @@ func newEmptySchema() *Schema { views: []*CreateViewEntity{}, named: map[string]Entity{}, sorted: []Entity{}, + + foreignKeyParents: []*CreateTableEntity{}, + foreignKeyChildren: []*CreateTableEntity{}, } return schema } @@ -122,6 +128,18 @@ func NewSchemaFromSQL(sql string) (*Schema, error) { return NewSchemaFromStatements(statements) } +// getForeignKeyParentTableNames analyzes a CREATE TABLE definition and extracts all referenced foreign key table names. 
+// A table name may appear twice in the result output, if it is referenced by more than one foreign key +func getForeignKeyParentTableNames(createTable *sqlparser.CreateTable) (names []string, err error) { + for _, cs := range createTable.TableSpec.Constraints { + if check, ok := cs.Details.(*sqlparser.ForeignKeyDefinition); ok { + parentTableName := check.ReferenceDefinition.ReferencedTable.Name.String() + names = append(names, parentTableName) + } + } + return names, err +} + // getViewDependentTableNames analyzes a CREATE VIEW definition and extracts all tables/views read by this view func getViewDependentTableNames(createView *sqlparser.CreateView) (names []string, err error) { err = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { @@ -175,10 +193,6 @@ func (s *Schema) normalize() error { // If a view v1 depends on v2, then v2 must come before v1, even though v1 // precedes v2 alphabetically dependencyLevels := make(map[string]int, len(s.tables)+len(s.views)) - for _, t := range s.tables { - s.sorted = append(s.sorted, t) - dependencyLevels[t.Name()] = 0 - } allNamesFoundInLowerLevel := func(names []string, level int) bool { for _, name := range names { @@ -198,12 +212,58 @@ func (s *Schema) normalize() error { return true } + // We now iterate all tables. We iterate "dependency levels": + // - first we want all tables that don't have foreign keys or which only reference themselves + // - then we only want tables that reference 1st level tables. these are 2nd level tables + // - etc. + // we stop when we have been unable to find a table in an iteration. + fkParents := map[string]bool{} + iterationLevel := 0 + for { + handledAnyTablesInIteration := false + for _, t := range s.tables { + name := t.Name() + if _, ok := dependencyLevels[name]; ok { + // already handled; skip + continue + } + // Not handled. Is this table dependent on already handled objects? 
+ referencedTableNames, err := getForeignKeyParentTableNames(t.CreateTable) + if err != nil { + return err + } + if len(referencedTableNames) > 0 { + s.foreignKeyChildren = append(s.foreignKeyChildren, t) + } + nonSelfReferenceNames := []string{} + for _, referencedTableName := range referencedTableNames { + if referencedTableName != name { + nonSelfReferenceNames = append(nonSelfReferenceNames, referencedTableName) + } + fkParents[referencedTableName] = true + } + if allNamesFoundInLowerLevel(nonSelfReferenceNames, iterationLevel) { + s.sorted = append(s.sorted, t) + dependencyLevels[t.Name()] = iterationLevel + handledAnyTablesInIteration = true + } + } + if !handledAnyTablesInIteration { + break + } + iterationLevel++ + } + for _, t := range s.tables { + if fkParents[t.Name()] { + s.foreignKeyParents = append(s.foreignKeyParents, t) + } + } // We now iterate all views. We iterate "dependency levels": // - first we want all views that only depend on tables. These are 1st level views. // - then we only want views that depend on 1st level views or on tables. These are 2nd level views. // - etc. // we stop when we have been unable to find a view in an iteration. - for iterationLevel := 1; ; iterationLevel++ { + for { handledAnyViewsInIteration := false for _, v := range s.views { name := v.Name() @@ -225,11 +285,21 @@ func (s *Schema) normalize() error { if !handledAnyViewsInIteration { break } + iterationLevel++ } if len(s.sorted) != len(s.tables)+len(s.views) { - // We have leftover views. This can happen if the schema definition is invalid: + // We have leftover tables or views. 
This can happen if the schema definition is invalid: + // - a table's foreign key references a nonexistent table + // - two or more tables have circular FK dependency // - a view depends on a nonexistent table - // - two views have a circular dependency + // - two or more views have a circular dependency + for _, t := range s.tables { + if _, ok := dependencyLevels[t.Name()]; !ok { + // We _know_ that in this iteration, at least one table is found unassigned a dependency level. + // We return the first one. + return &ForeignKeyDependencyUnresolvedError{Table: t.Name()} + } + } for _, v := range s.views { if _, ok := dependencyLevels[v.Name()]; !ok { // We _know_ that in this iteration, at least one view is found unassigned a dependency level. @@ -238,6 +308,69 @@ func (s *Schema) normalize() error { } } } + + // Validate table definitions + for _, t := range s.tables { + if err := t.validate(); err != nil { + return err + } + } + colTypeEqualForForeignKey := func(a, b sqlparser.ColumnType) bool { + return a.Type == b.Type && + a.Unsigned == b.Unsigned && + a.Zerofill == b.Zerofill && + sqlparser.Equals.ColumnCharset(a.Charset, b.Charset) && + sqlparser.Equals.SliceOfString(a.EnumValues, b.EnumValues) + } + + // Now validate foreign key columns: + // - referenced table columns must exist + // - foreign key columns must match in count and type to referenced table columns + // - referenced table has an appropriate index over referenced columns + for _, t := range s.tables { + if len(t.TableSpec.Constraints) == 0 { + continue + } + + tableColumns := map[string]*sqlparser.ColumnDefinition{} + for _, col := range t.CreateTable.TableSpec.Columns { + colName := col.Name.Lowered() + tableColumns[colName] = col + } + + for _, cs := range t.TableSpec.Constraints { + check, ok := cs.Details.(*sqlparser.ForeignKeyDefinition) + if !ok { + continue + } + referencedTableName := check.ReferenceDefinition.ReferencedTable.Name.String() + referencedTable := s.Table(referencedTableName) 
// we know this exists because we validated foreign key dependencies earlier on + + referencedColumns := map[string]*sqlparser.ColumnDefinition{} + for _, col := range referencedTable.CreateTable.TableSpec.Columns { + colName := col.Name.Lowered() + referencedColumns[colName] = col + } + // Thanks to table validation, we already know the foreign key covered columns count is equal to the + // referenced table column count. Now ensure their types are identical + for i, col := range check.Source { + coveredColumn, ok := tableColumns[col.Lowered()] + if !ok { + return &InvalidColumnInForeignKeyConstraintError{Table: t.Name(), Constraint: cs.Name.String(), Column: col.String()} + } + referencedColumnName := check.ReferenceDefinition.ReferencedColumns[i].Lowered() + referencedColumn, ok := referencedColumns[referencedColumnName] + if !ok { + return &InvalidReferencedColumnInForeignKeyConstraintError{Table: t.Name(), Constraint: cs.Name.String(), ReferencedTable: referencedTableName, ReferencedColumn: referencedColumnName} + } + if !colTypeEqualForForeignKey(coveredColumn.Type, referencedColumn.Type) { + return &ForeignKeyColumnTypeMismatchError{Table: t.Name(), Constraint: cs.Name.String(), Column: coveredColumn.Name.String(), ReferencedTable: referencedTableName, ReferencedColumn: referencedColumnName} + } + } + + // TODO(shlomi): find a valid index + } + } return nil } diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index 308333bf641..cd2baa8527a 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -21,6 +21,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" ) var createQueries = []string{ @@ -153,3 +156,223 @@ func TestCopy(t *testing.T) { assert.Equal(t, schema.ToSQL(), schemaClone.ToSQL()) assert.False(t, schema == schemaClone) } + +func TestGetViewDependentTableNames(t *testing.T) { + tt := []struct { + name string 
+ view string + tables []string + }{ + { + view: "create view v6 as select * from v4", + tables: []string{"v4"}, + }, + { + view: "create view v2 as select * from v3, t2", + tables: []string{"v3", "t2"}, + }, + { + view: "create view v3 as select * from t3 as t3", + tables: []string{"t3"}, + }, + { + view: "create view v3 as select * from t3 as something_else", + tables: []string{"t3"}, + }, + { + view: "create view v5 as select * from t1, (select * from v3) as some_alias", + tables: []string{"t1", "v3"}, + }, + { + view: "create view v0 as select 1 from DUAL", + tables: []string{"dual"}, + }, + { + view: "create view v9 as select 1", + tables: []string{"dual"}, + }, + } + for _, ts := range tt { + t.Run(ts.view, func(t *testing.T) { + stmt, err := sqlparser.ParseStrictDDL(ts.view) + require.NoError(t, err) + createView, ok := stmt.(*sqlparser.CreateView) + require.True(t, ok) + + tables, err := getViewDependentTableNames(createView) + assert.NoError(t, err) + assert.Equal(t, ts.tables, tables) + }) + } +} + +func TestGetForeignKeyParentTableNames(t *testing.T) { + tt := []struct { + name string + table string + tables []string + }{ + { + table: "create table t1 (id int primary key, i int, foreign key (i) references parent(id))", + tables: []string{"parent"}, + }, + { + table: "create table t1 (id int primary key, i int, constraint f foreign key (i) references parent(id))", + tables: []string{"parent"}, + }, + { + table: "create table t1 (id int primary key, i int, constraint f foreign key (i) references parent(id) on delete cascade)", + tables: []string{"parent"}, + }, + { + table: "create table t1 (id int primary key, i int, i2 int, constraint f foreign key (i) references parent(id) on delete cascade, constraint f2 foreign key (i2) references parent2(id) on delete restrict)", + tables: []string{"parent", "parent2"}, + }, + { + table: "create table t1 (id int primary key, i int, i2 int, constraint f foreign key (i) references parent(id) on delete cascade, 
constraint f2 foreign key (i2) references parent(id) on delete restrict)", + tables: []string{"parent", "parent"}, + }, + } + for _, ts := range tt { + t.Run(ts.table, func(t *testing.T) { + stmt, err := sqlparser.ParseStrictDDL(ts.table) + require.NoError(t, err) + createTable, ok := stmt.(*sqlparser.CreateTable) + require.True(t, ok) + + tables, err := getForeignKeyParentTableNames(createTable) + assert.NoError(t, err) + assert.Equal(t, ts.tables, tables) + }) + } +} + +func TestTableForeignKeyOrdering(t *testing.T) { + fkQueries := []string{ + "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(id) on delete restrict, constraint f20 foreign key (i) references t20(id) on delete restrict)", + "create table t15(id int, primary key(id))", + "create view v09 as select * from v13, t17", + "create table t20 (id int primary key, i int, constraint f15 foreign key (i) references t15(id) on delete restrict)", + "create view v13 as select * from t20", + "create table t12 (id int primary key, i int, constraint f15 foreign key (i) references t15(id) on delete restrict)", + "create table t17 (id int primary key, i int, constraint f11 foreign key (i) references t11(id) on delete restrict, constraint f15 foreign key (i) references t15(id) on delete restrict)", + "create table t16 (id int primary key, i int, constraint f11 foreign key (i) references t11(id) on delete restrict, constraint f15 foreign key (i) references t15(id) on delete restrict)", + "create table t14 (id int primary key, i int, constraint f14 foreign key (i) references t14(id) on delete restrict)", + } + expectSortedTableNames := []string{ + "t14", + "t15", + "t12", + "t20", + "t11", + "t16", + "t17", + } + expectSortedViewNames := []string{ + "v13", + "v09", + } + schema, err := NewSchemaFromQueries(fkQueries) + require.NoError(t, err) + assert.NotNil(t, schema) + + assert.Equal(t, append(expectSortedTableNames, expectSortedViewNames...), schema.EntityNames()) + 
assert.Equal(t, expectSortedTableNames, schema.TableNames()) + assert.Equal(t, expectSortedViewNames, schema.ViewNames()) +} + +func TestInvalidSchema(t *testing.T) { + tt := []struct { + schema string + expectErr error + }{ + { + schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t11(id) on delete restrict)", + }, + { + schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f10 foreign key (i) references t10(id) on delete restrict)", + }, + { + schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i7) references t11(id) on delete restrict)", + expectErr: &InvalidColumnInForeignKeyConstraintError{Table: "t11", Constraint: "f11", Column: "i7"}, + }, + { + schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t11(id, i) on delete restrict)", + expectErr: &ForeignKeyColumnCountMismatchError{Table: "t11", Constraint: "f11", ColumnCount: 1, ReferencedTable: "t11", ReferencedColumnCount: 2}, + }, + { + schema: "create table t11 (id int primary key, i1 int, i2 int, constraint f11 foreign key (i1, i2) references t11(i1) on delete restrict)", + expectErr: &ForeignKeyColumnCountMismatchError{Table: "t11", Constraint: "f11", ColumnCount: 2, ReferencedTable: "t11", ReferencedColumnCount: 1}, + }, + { + schema: "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(id) on delete restrict)", + expectErr: &ForeignKeyDependencyUnresolvedError{Table: "t11"}, + }, + { + schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t11(id2) on delete restrict)", + expectErr: &InvalidReferencedColumnInForeignKeyConstraintError{Table: "t11", Constraint: "f11", ReferencedTable: "t11", ReferencedColumn: "id2"}, + }, + { + schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f10 foreign key (i) references 
t10(x) on delete restrict)", + expectErr: &InvalidReferencedColumnInForeignKeyConstraintError{Table: "t11", Constraint: "f10", ReferencedTable: "t10", ReferencedColumn: "x"}, + }, + { + schema: "create table t10(id int primary key); create table t11 (id int primary key, i int unsigned, constraint f10 foreign key (i) references t10(id) on delete restrict)", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, + }, + { + schema: "create table t10(id int primary key); create table t11 (id int primary key, i bigint, constraint f10 foreign key (i) references t10(id) on delete restrict)", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, + }, + { + schema: "create table t10(id bigint primary key); create table t11 (id int primary key, i int, constraint f10 foreign key (i) references t10(id) on delete restrict)", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, + }, + { + schema: "create table t10(id bigint primary key); create table t11 (id int primary key, i varchar(100), constraint f10 foreign key (i) references t10(id) on delete restrict)", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, + }, + { + // InnoDB allows different string length + schema: "create table t10(id varchar(50) primary key); create table t11 (id int primary key, i varchar(100), constraint f10 foreign key (i) references t10(id) on delete restrict)", + }, + { + schema: "create table t10(id varchar(50) charset utf8mb3 primary key); create table t11 (id int primary key, i varchar(100) charset utf8mb4, constraint f10 foreign key (i) references t10(id) on delete restrict)", + expectErr: &ForeignKeyColumnTypeMismatchError{Table: "t11", 
Constraint: "f10", Column: "i", ReferencedTable: "t10", ReferencedColumn: "id"}, + }, + } + for _, ts := range tt { + t.Run(ts.schema, func(t *testing.T) { + + _, err := NewSchemaFromSQL(ts.schema) + if ts.expectErr == nil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.EqualError(t, err, ts.expectErr.Error()) + } + }) + } +} + +func TestInvalidTableForeignKeyReference(t *testing.T) { + { + fkQueries := []string{ + "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(id) on delete restrict)", + "create table t15(id int, primary key(id))", + } + _, err := NewSchemaFromQueries(fkQueries) + assert.Error(t, err) + assert.EqualError(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t11"}).Error()) + } + { + fkQueries := []string{ + "create table t13 (id int primary key, i int, constraint f11 foreign key (i) references t11(id) on delete restrict)", + "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(id) on delete restrict)", + "create table t12 (id int primary key, i int, constraint f13 foreign key (i) references t13(id) on delete restrict)", + } + _, err := NewSchemaFromQueries(fkQueries) + assert.Error(t, err) + assert.EqualError(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t11"}).Error()) + } +} diff --git a/go/vt/schemadiff/table.go b/go/vt/schemadiff/table.go index 06aa05e3578..156228fe0fb 100644 --- a/go/vt/schemadiff/table.go +++ b/go/vt/schemadiff/table.go @@ -2130,17 +2130,22 @@ func (c *CreateTableEntity) validate() error { } } } - // validate all columns referenced by foreign key constraints do in fact exist + // validate all columns used by foreign key constraints do in fact exist, + // and that there exists an index over those columns for _, cs := range c.CreateTable.TableSpec.Constraints { check, ok := cs.Details.(*sqlparser.ForeignKeyDefinition) if !ok { continue } + if len(check.Source) != len(check.ReferenceDefinition.ReferencedColumns) { + 
return &ForeignKeyColumnCountMismatchError{Table: c.Name(), Constraint: cs.Name.String(), ColumnCount: len(check.Source), ReferencedTable: check.ReferenceDefinition.ReferencedTable.Name.String(), ReferencedColumnCount: len(check.ReferenceDefinition.ReferencedColumns)} + } for _, col := range check.Source { if !columnExists[col.Lowered()] { return &InvalidColumnInForeignKeyConstraintError{Table: c.Name(), Constraint: cs.Name.String(), Column: col.String()} } } + // TODO(shlomi): find a valid index } // validate all columns referenced by constraint checks do in fact exist for _, cs := range c.CreateTable.TableSpec.Constraints { diff --git a/go/vt/schemadiff/table_test.go b/go/vt/schemadiff/table_test.go index a0a2334c168..7a8504a55fb 100644 --- a/go/vt/schemadiff/table_test.go +++ b/go/vt/schemadiff/table_test.go @@ -1555,6 +1555,12 @@ func TestValidate(t *testing.T) { alter: "alter table t add constraint f foreign key (z) references parent(id)", expectErr: &InvalidColumnInForeignKeyConstraintError{Table: "t", Constraint: "f", Column: "z"}, }, + { + name: "mismatching column count in foreign key", + from: "create table t (id int primary key, i int, constraint f foreign key (i) references parent(id, z))", + alter: "alter table t engine=innodb", + expectErr: &ForeignKeyColumnCountMismatchError{Table: "t", Constraint: "f", ColumnCount: 1, ReferencedTable: "parent", ReferencedColumnCount: 2}, + }, { name: "change with constraints with uppercase columns", from: "CREATE TABLE `Machine` (id int primary key, `a` int, `B` int, CONSTRAINT `chk` CHECK (`B` >= `a`))",