Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schemadiff: normalize PRIMARY KEY definition #12016

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 24 additions & 24 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func TestDiffTables(t *testing.T) {
{
name: "create",
to: "create table t(id int primary key)",
diff: "create table t (\n\tid int primary key\n)",
cdiff: "CREATE TABLE `t` (\n\t`id` int PRIMARY KEY\n)",
diff: "create table t (\n\tid int,\n\tprimary key (id)\n)",
cdiff: "CREATE TABLE `t` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
action: "create",
toName: "t",
},
Expand Down Expand Up @@ -411,21 +411,21 @@ func TestDiffSchemas(t *testing.T) {
name: "create table",
to: "create table t(id int primary key)",
diffs: []string{
"create table t (\n\tid int primary key\n)",
"create table t (\n\tid int,\n\tprimary key (id)\n)",
},
cdiffs: []string{
"CREATE TABLE `t` (\n\t`id` int PRIMARY KEY\n)",
"CREATE TABLE `t` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
},
},
{
name: "create table 2",
from: ";;; ; ; ;;;",
to: "create table t(id int primary key)",
diffs: []string{
"create table t (\n\tid int primary key\n)",
"create table t (\n\tid int,\n\tprimary key (id)\n)",
},
cdiffs: []string{
"CREATE TABLE `t` (\n\t`id` int PRIMARY KEY\n)",
"CREATE TABLE `t` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
},
},
{
Expand All @@ -444,13 +444,13 @@ func TestDiffSchemas(t *testing.T) {
to: "create table t4(id int primary key); create table t2(id bigint primary key); create table t3(id int primary key)",
diffs: []string{
"drop table t1",
"alter table t2 modify column id bigint primary key",
"create table t4 (\n\tid int primary key\n)",
"alter table t2 modify column id bigint",
"create table t4 (\n\tid int,\n\tprimary key (id)\n)",
},
cdiffs: []string{
"DROP TABLE `t1`",
"ALTER TABLE `t2` MODIFY COLUMN `id` bigint PRIMARY KEY",
"CREATE TABLE `t4` (\n\t`id` int PRIMARY KEY\n)",
"ALTER TABLE `t2` MODIFY COLUMN `id` bigint",
"CREATE TABLE `t4` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
},
},
{
Expand All @@ -459,11 +459,11 @@ func TestDiffSchemas(t *testing.T) {
to: "create table t1(id int primary key); create table t3(id int unsigned primary key);",
diffs: []string{
"drop table t2",
"create table t3 (\n\tid int unsigned primary key\n)",
"create table t3 (\n\tid int unsigned,\n\tprimary key (id)\n)",
},
cdiffs: []string{
"DROP TABLE `t2`",
"CREATE TABLE `t3` (\n\t`id` int unsigned PRIMARY KEY\n)",
"CREATE TABLE `t3` (\n\t`id` int unsigned,\n\tPRIMARY KEY (`id`)\n)",
},
},
{
Expand All @@ -486,17 +486,17 @@ func TestDiffSchemas(t *testing.T) {
"drop table t1a",
"drop table t2a",
"drop table t3a",
"create table t1b (\n\tid bigint primary key\n)",
"create table t2b (\n\tid int unsigned primary key\n)",
"create table t3b (\n\tid int primary key\n)",
"create table t1b (\n\tid bigint,\n\tprimary key (id)\n)",
"create table t2b (\n\tid int unsigned,\n\tprimary key (id)\n)",
"create table t3b (\n\tid int,\n\tprimary key (id)\n)",
},
cdiffs: []string{
"DROP TABLE `t1a`",
"DROP TABLE `t2a`",
"DROP TABLE `t3a`",
"CREATE TABLE `t1b` (\n\t`id` bigint PRIMARY KEY\n)",
"CREATE TABLE `t2b` (\n\t`id` int unsigned PRIMARY KEY\n)",
"CREATE TABLE `t3b` (\n\t`id` int PRIMARY KEY\n)",
"CREATE TABLE `t1b` (\n\t`id` bigint,\n\tPRIMARY KEY (`id`)\n)",
"CREATE TABLE `t2b` (\n\t`id` int unsigned,\n\tPRIMARY KEY (`id`)\n)",
"CREATE TABLE `t3b` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
},
},
{
Expand All @@ -505,13 +505,13 @@ func TestDiffSchemas(t *testing.T) {
to: "create table t1b(id bigint primary key); create table t2b(id int unsigned primary key); create table t3b(id int primary key); ",
diffs: []string{
"drop table t3a",
"create table t1b (\n\tid bigint primary key\n)",
"create table t1b (\n\tid bigint,\n\tprimary key (id)\n)",
"rename table t1a to t3b",
"rename table t2a to t2b",
},
cdiffs: []string{
"DROP TABLE `t3a`",
"CREATE TABLE `t1b` (\n\t`id` bigint PRIMARY KEY\n)",
"CREATE TABLE `t1b` (\n\t`id` bigint,\n\tPRIMARY KEY (`id`)\n)",
"RENAME TABLE `t1a` TO `t3b`",
"RENAME TABLE `t2a` TO `t2b`",
},
Expand Down Expand Up @@ -601,17 +601,17 @@ func TestDiffSchemas(t *testing.T) {
diffs: []string{
"drop table t1",
"drop view v1",
"alter table t2 modify column id bigint primary key",
"alter table t2 modify column id bigint",
"alter view v2 as select id from t2",
"create table t4 (\n\tid int primary key\n)",
"create table t4 (\n\tid int,\n\tprimary key (id)\n)",
"create view v0 as select * from v2, t2",
},
cdiffs: []string{
"DROP TABLE `t1`",
"DROP VIEW `v1`",
"ALTER TABLE `t2` MODIFY COLUMN `id` bigint PRIMARY KEY",
"ALTER TABLE `t2` MODIFY COLUMN `id` bigint",
"ALTER VIEW `v2` AS SELECT `id` FROM `t2`",
"CREATE TABLE `t4` (\n\t`id` int PRIMARY KEY\n)",
"CREATE TABLE `t4` (\n\t`id` int,\n\tPRIMARY KEY (`id`)\n)",
"CREATE VIEW `v0` AS SELECT * FROM `v2`, `t2`",
},
},
Expand Down
9 changes: 9 additions & 0 deletions go/vt/schemadiff/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ type InvalidColumnInKeyError struct {
Key string
}

type DuplicateKeyNameError struct {
Table string
Key string
}

func (e *DuplicateKeyNameError) Error() string {
return fmt.Sprintf("duplicate key %s in table %s", sqlescape.EscapeID(e.Key), sqlescape.EscapeID(e.Table))
}

func (e *InvalidColumnInKeyError) Error() string {
return fmt.Sprintf("invalid column %s referenced by key %s in table %s",
sqlescape.EscapeID(e.Column), sqlescape.EscapeID(e.Key), sqlescape.EscapeID(e.Table))
Expand Down
80 changes: 80 additions & 0 deletions go/vt/schemadiff/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,36 @@ func (c *CreateTableEntity) normalizePartitionOptions() {
}
}

func newPrimaryKeyIndexDefinitionSingleColumn(name sqlparser.IdentifierCI) *sqlparser.IndexDefinition {
index := &sqlparser.IndexDefinition{
Info: &sqlparser.IndexInfo{
Name: sqlparser.NewIdentifierCI("PRIMARY"),
Type: "PRIMARY KEY",
Primary: true,
Unique: true,
},
Columns: []*sqlparser.IndexColumn{{Column: name}},
}
return index
}

func (c *CreateTableEntity) normalizePrimaryKeyColumns() {
// normalize PRIMARY KEY:
// `create table t (id int primary key)`
// should turn into:
// `create table t (id int, primary key (id))`
// Also, PRIMARY KEY must come first before all other keys
for _, col := range c.CreateTable.TableSpec.Columns {
if col.Type.Options.KeyOpt == sqlparser.ColKeyPrimary {
c.CreateTable.TableSpec.Indexes = append([]*sqlparser.IndexDefinition{newPrimaryKeyIndexDefinitionSingleColumn(col.Name)}, c.CreateTable.TableSpec.Indexes...)
col.Type.Options.KeyOpt = sqlparser.ColKeyNone
}
}
}

func (c *CreateTableEntity) normalizeKeys() {
c.normalizePrimaryKeyColumns()

// let's ensure all keys have names
keyNameExists := map[string]bool{}
// first, we iterate and take note for all keys that do already have names
Expand Down Expand Up @@ -1508,6 +1537,17 @@ func heuristicallyDetectColumnRenames(
return dropColumns, addColumns, renameColumns
}

// primaryKeyColumns returns the columns covered by an existing PRIMARY KEY, or nil if there isn't
// a PRIMARY KEY
func (c *CreateTableEntity) primaryKeyColumns() []*sqlparser.IndexColumn {
for _, existingIndex := range c.CreateTable.TableSpec.Indexes {
if existingIndex.Info.Primary {
return existingIndex.Columns
}
}
return nil
}

// Create implements Entity interface
func (c *CreateTableEntity) Create() EntityDiff {
return &CreateTableEntityDiff{to: c, createTable: c.CreateTable}
Expand Down Expand Up @@ -1756,6 +1796,12 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
return &ApplyDuplicateColumnError{Table: c.Name(), Column: addedCol.Name.String()}
}
}
// if this column has the PRIMARY KEY option, verify there isn't already a PRIMARY KEY
if addedCol.Type.Options.KeyOpt == sqlparser.ColKeyPrimary {
if cols := c.primaryKeyColumns(); cols != nil {
return &DuplicateKeyNameError{Table: c.Name(), Key: "PRIMARY"}
}
}
c.TableSpec.Columns = append(c.TableSpec.Columns, addedCol)
// see if we need to position it anywhere other than end of table
if err := reorderColumn(len(c.TableSpec.Columns)-1, opt.First, opt.After); err != nil {
Expand All @@ -1779,6 +1825,24 @@ func (c *CreateTableEntity) apply(diff *AlterTableEntityDiff) error {
if !found {
return &ApplyColumnNotFoundError{Table: c.Name(), Column: opt.NewColDefinition.Name.String()}
}
// if this column has the PRIMARY KEY option:
// - validate there isn't already a PRIMARY KEY for other columns
// - if there isn't any PRIMARY KEY, create one
// - if there exists a PRIMARY KEY for exactly this column, noop
if opt.NewColDefinition.Type.Options.KeyOpt == sqlparser.ColKeyPrimary {
cols := c.primaryKeyColumns()
if cols == nil {
// add primary key
c.CreateTable.TableSpec.Indexes = append([]*sqlparser.IndexDefinition{newPrimaryKeyIndexDefinitionSingleColumn(opt.NewColDefinition.Name)}, c.CreateTable.TableSpec.Indexes...)
} else {
if len(cols) == 1 && strings.EqualFold(cols[0].Column.String(), opt.NewColDefinition.Name.String()) {
// existing PK is exactly this column. Nothing to do
} else {
return &DuplicateKeyNameError{Table: c.Name(), Key: "PRIMARY"}
}
}
}
opt.NewColDefinition.Type.Options.KeyOpt = sqlparser.ColKeyNone
case *sqlparser.RenameColumn:
// we expect the column to exist
found := false
Expand Down Expand Up @@ -1993,6 +2057,18 @@ func getKeyColumnNames(key *sqlparser.IndexDefinition) (colNames map[string]bool
return colNames
}

func (c *CreateTableEntity) validateDuplicateKeyNameError() error {
keyNames := map[string]bool{}
for _, key := range c.CreateTable.TableSpec.Indexes {
name := key.Info.Name
if _, ok := keyNames[name.Lowered()]; ok {
return &DuplicateKeyNameError{Table: c.Name(), Key: name.String()}
}
keyNames[name.Lowered()] = true
}
return nil
}

// validate checks that the table structure is valid:
// - all columns referenced by keys exist
func (c *CreateTableEntity) validate() error {
Expand Down Expand Up @@ -2089,6 +2165,10 @@ func (c *CreateTableEntity) validate() error {
}
}
}
// validate no two keys have same name
if err := c.validateDuplicateKeyNameError(); err != nil {
return err
}

if partition := c.CreateTable.TableSpec.PartitionOption; partition != nil {
// validate no two partitions have same name
Expand Down
Loading