diff --git a/executor/temporary_table_test.go b/executor/temporary_table_test.go index ebc8aff0bae95..5174a91976a21 100644 --- a/executor/temporary_table_test.go +++ b/executor/temporary_table_test.go @@ -26,34 +26,40 @@ import ( "github.com/stretchr/testify/require" ) -func TestTemporaryTableNoNetwork(t *testing.T) { - t.Run("global", func(t *testing.T) { - assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { - tk.MustExec("create global temporary table tmp_t (id int primary key, a int, b int, index(a)) on commit delete rows") - tk.MustExec("begin") - }) +func TestNormalGlobalTemporaryTableNoNetwork(t *testing.T) { + assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { + tk.MustExec("create global temporary table tmp_t (id int primary key, a int, b int, index(a)) on commit delete rows") + tk.MustExec("begin") }) +} + +func TestGlobalTemporaryTableNoNetworkWithCreateAndTruncate(t *testing.T) { + assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { + tk.MustExec("create global temporary table tmp_t (id int primary key, a int, b int, index(a)) on commit delete rows") + tk.MustExec("truncate table tmp_t") + tk.MustExec("begin") + }) +} - t.Run("global create and then truncate", func(t *testing.T) { - assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { - tk.MustExec("create global temporary table tmp_t (id int primary key, a int, b int, index(a)) on commit delete rows") - tk.MustExec("truncate table tmp_t") - tk.MustExec("begin") - }) +func TestGlobalTemporaryTableNoNetworkWithCreateAndThenCreateNormalTable(t *testing.T) { + assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { + tk.MustExec("create global temporary table tmp_t (id int primary key, a int, b int, index(a)) on commit delete rows") + tk.MustExec("create table txx(a int)") + tk.MustExec("begin") }) +} - t.Run("local", func(t *testing.T) { - assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { - tk.MustExec("create temporary table tmp_t (id int primary key, a int, b int, index(a))") - tk.MustExec("begin") - }) +func TestLocalTemporaryTableNoNetworkWithCreateOutsideTxn(t *testing.T) { + assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { + tk.MustExec("create temporary table tmp_t (id int primary key, a int, b int, index(a))") + tk.MustExec("begin") }) +} - t.Run("local and create table inside txn", func(t *testing.T) { - assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { - tk.MustExec("begin") - tk.MustExec("create temporary table tmp_t (id int primary key, a int, b int, index(a))") - }) +func TestLocalTemporaryTableNoNetworkWithInsideTxn(t *testing.T) { + assertTemporaryTableNoNetwork(t, func(tk *testkit.TestKit) { + tk.MustExec("begin") + tk.MustExec("create temporary table tmp_t (id int primary key, a int, b int, index(a))") }) } diff --git a/infoschema/builder.go b/infoschema/builder.go index ff3366bc25149..23dbaf1c8d153 100644 --- a/infoschema/builder.go +++ b/infoschema/builder.go @@ -786,6 +786,7 @@ func (b *Builder) InitWithOldInfoSchema(oldSchema InfoSchema) *Builder { b.copySchemasMap(oldIS) b.copyBundlesMap(oldIS) b.copyPoliciesMap(oldIS) + b.copyTemporaryTableIDsMap(oldIS) copy(b.is.sortedTablesBuckets, oldIS.sortedTablesBuckets) return b @@ -811,6 +812,19 @@ func (b *Builder) copyPoliciesMap(oldIS *infoSchema) { } } +func (b *Builder) copyTemporaryTableIDsMap(oldIS *infoSchema) { + is := b.is + if len(oldIS.temporaryTableIDs) == 0 { + is.temporaryTableIDs = nil + return + } + + is.temporaryTableIDs = make(map[int64]struct{}) + for tblID := range oldIS.temporaryTableIDs { + is.temporaryTableIDs[tblID] = struct{}{} + } +} + // getSchemaAndCopyIfNecessary creates a new schemaTables instance when a table in the database has changed. // It also does modifications on the new one because old schemaTables must be read-only. // And it will only copy the changed database once in the lifespan of the Builder. diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index 27a6aadd497cc..a1d5a4a689d98 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -327,18 +327,19 @@ func TestBuildSchemaWithGlobalTemporaryTable(t *testing.T) { require.True(t, ok) doChange := func(changes ...func(m *meta.Meta, builder *infoschema.Builder)) infoschema.InfoSchema { - builder, err := infoschema.NewBuilder(store, nil).InitWithDBInfos([]*model.DBInfo{db}, nil, is.SchemaMetaVersion()) - require.NoError(t, err) ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - err = kv.RunInNewTxn(ctx, store, true, func(ctx context.Context, txn kv.Transaction) error { + curIs := is + err := kv.RunInNewTxn(ctx, store, true, func(ctx context.Context, txn kv.Transaction) error { m := meta.NewMeta(txn) for _, change := range changes { + builder := infoschema.NewBuilder(store, nil).InitWithOldInfoSchema(curIs) change(m, builder) + curIs = builder.Build() } return nil }) require.NoError(t, err) - return builder.Build() + return curIs } createGlobalTemporaryTableChange := func(tblID int64) func(m *meta.Meta, builder *infoschema.Builder) { @@ -354,6 +355,18 @@ func TestBuildSchemaWithGlobalTemporaryTable(t *testing.T) { } } + createNormalTableChange := func(tblID int64) func(m *meta.Meta, builder *infoschema.Builder) { + return func(m *meta.Meta, builder *infoschema.Builder) { + err := m.CreateTableOrView(db.ID, &model.TableInfo{ + ID: tblID, + State: model.StatePublic, + }) + require.NoError(t, err) + _, err = builder.ApplyDiff(m, &model.SchemaDiff{Type: model.ActionCreateTable, SchemaID: db.ID, TableID: tblID}) + require.NoError(t, err) + } + } + dropTableChange := func(tblID int64) func(m *meta.Meta, builder *infoschema.Builder) { return func(m *meta.Meta, builder *infoschema.Builder) { err := m.DropTableOrView(db.ID, tblID) @@ -437,6 +450,16 @@ func TestBuildSchemaWithGlobalTemporaryTable(t *testing.T) { createGlobalTemporaryTableChange(tbID2), dropTableChange(tbID), ).HasTemporaryTable()) + + // create temporary and then create normal + tbID, err = genGlobalID(store) + require.NoError(t, err) + tbID2, err = genGlobalID(store) + require.NoError(t, err) + require.True(t, doChange( + createGlobalTemporaryTableChange(tbID), + createNormalTableChange(tbID2), + ).HasTemporaryTable()) } func TestBuildBundle(t *testing.T) {