diff --git a/go/vt/vtctl/workflow/sequences.go b/go/vt/vtctl/workflow/sequences.go index 5a1185ad7cb..6e0bd62af89 100644 --- a/go/vt/vtctl/workflow/sequences.go +++ b/go/vt/vtctl/workflow/sequences.go @@ -70,11 +70,15 @@ type sequenceMetadata struct { } func (sm *sequenceMetadata) escapeValues() error { - usingCol, err := sqlescape.EnsureEscaped(sm.usingTableDefinition.AutoIncrement.Column) - if err != nil { - err = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid column name %s specified for sequence in table %s: %v", - sm.usingTableDefinition.AutoIncrement.Column, sm.usingTableName, err) - return err + usingCol := "" + var err error + if sm.usingTableDefinition != nil && sm.usingTableDefinition.AutoIncrement != nil { + usingCol, err = sqlescape.EnsureEscaped(sm.usingTableDefinition.AutoIncrement.Column) + if err != nil { + err = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid column name %s specified for sequence in table %s: %v", + sm.usingTableDefinition.AutoIncrement.Column, sm.usingTableName, err) + return err + } } sm.usingCol = usingCol usingDB, err := sqlescape.EnsureEscaped(sm.usingTableDBName) diff --git a/go/vt/vtctl/workflow/switcher_dry_run_test.go b/go/vt/vtctl/workflow/switcher_dry_run_test.go index ff0bc709fd4..eab61089126 100644 --- a/go/vt/vtctl/workflow/switcher_dry_run_test.go +++ b/go/vt/vtctl/workflow/switcher_dry_run_test.go @@ -22,6 +22,9 @@ import ( "testing" "time" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/vschema" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -370,21 +373,83 @@ func TestChangeRouting(t *testing.T) { } func TestDRInitializeTargetSequences(t *testing.T) { - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + workflowName := "wf1" + tableName := "t1" + sourceKeyspaceName := "sourceks" + targetKeyspaceName := "targetks" + + schema := map[string]*tabletmanagerdatapb.SchemaDefinition{ + tableName: { + TableDefinitions: []*tabletmanagerdatapb.TableDefinition{ + { + Name: tableName, + Schema: fmt.Sprintf("CREATE TABLE %s (id BIGINT, name VARCHAR(64), PRIMARY KEY (id))", tableName), + }, + }, + }, + } + + sourceKeyspace := &testKeyspace{ + KeyspaceName: sourceKeyspaceName, + ShardNames: []string{"0"}, + } + targetKeyspace := &testKeyspace{ + KeyspaceName: targetKeyspaceName, + ShardNames: []string{"0"}, + } + + env := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace) + defer env.close() + env.tmc.schema = schema + + ts, _, err := env.ws.getWorkflowState(ctx, targetKeyspaceName, workflowName) + require.NoError(t, err) drLog := NewLogRecorder() dr := &switcherDryRun{ drLog: drLog, + ts: ts, + } + + sm1 := sequenceMetadata{ + backingTableName: "seq1", + backingTableKeyspace: "sourceks", + backingTableDBName: "ks1", + usingTableName: "t1", + usingTableDBName: "targetks", + usingTableDefinition: &vschema.Table{ + AutoIncrement: &vschema.AutoIncrement{Column: "id", Sequence: "seq1"}, + }, } + sm2 := sm1 + sm2.backingTableName = "seq2" + sm2.usingTableName = "t2" + sm2.usingTableDefinition.AutoIncrement.Sequence = "seq2" + + sm3 := sm1 + sm3.backingTableName = "seq3" + sm3.usingTableName = "t3" + sm3.usingTableDefinition.AutoIncrement.Sequence = "seq3" tables := map[string]*sequenceMetadata{ - "t1": nil, - "t2": nil, - "t3": nil, + "t1": &sm1, + "t2": &sm2, + "t3": &sm3, + } + + for range tables { + env.tmc.expectVRQuery(200, "/select max.*", sqltypes.MakeTestResult(sqltypes.MakeTestFields("maxval", "int64"), "10")) + env.tmc.expectVRQuery(100, "/select next_id.*.*", sqltypes.MakeTestResult(sqltypes.MakeTestFields("next_id", "int64"), "1")) } - err := dr.initializeTargetSequences(ctx, tables) + env.tmc.expectVRQuery(100, "/select next_id.*", sqltypes.MakeTestResult(sqltypes.MakeTestFields("next_id", "int64"), "1")) + + err = dr.initializeTargetSequences(ctx, tables) require.NoError(t, err) - assert.Len(t, drLog.logs, 1) - assert.Contains(t, drLog.logs[0], "t1") - assert.Contains(t, drLog.logs[0], "t2") - assert.Contains(t, drLog.logs[0], "t3") + assert.Len(t, drLog.logs, 4) + assert.Contains(t, drLog.logs[0], "The following sequence backing tables used by tables being moved will be initialized:") + assert.Contains(t, drLog.logs[1], "Backing table: t1, current value 0, new value 1") + assert.Contains(t, drLog.logs[2], "Backing table: t2, current value 0, new value 1") + assert.Contains(t, drLog.logs[3], "Backing table: t3, current value 0, new value 1") }