diff --git a/executor/executor.go b/executor/executor.go index d2b726f24adbb..491f3b2b4e27a 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1928,7 +1928,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() - if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors { + if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates { sc.InShowWarning = true sc.SetWarnings(vars.StmtCtx.GetWarnings()) } @@ -1936,6 +1936,11 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.IgnoreTruncate = false sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() + case *ast.SetSessionStatesStmt: + sc.InSetSessionStatesStmt = true + sc.IgnoreTruncate = true + sc.IgnoreZeroInDate = true + sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() default: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true @@ -1954,7 +1959,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID } sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt { + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt { sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) } else if vars.StmtCtx.InSelectStmt { sc.PrevAffectedRows = -1 diff --git a/session/session.go b/session/session.go index c5c1ead4c65b4..d01c3e7e549d3 100644 --- a/session/session.go +++ b/session/session.go @@ -3543,15 +3543,16 @@ func (s *session) EncodeSessionStates(ctx context.Context, sctx sessionctx.Conte // DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. func (s *session) DecodeSessionStates(ctx context.Context, sctx sessionctx.Context, sessionStates *sessionstates.SessionStates) (err error) { - if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { - return err - } - // Decode session variables. for name, val := range sessionStates.SystemVars { if err = variable.SetSessionSystemVar(s.sessionVars, name, val); err != nil { return err } } + + // Decode stmt ctx after session vars because setting session vars may override stmt ctx, such as warnings. + if err = s.sessionVars.DecodeSessionStates(ctx, sessionStates); err != nil { + return err + } return err } diff --git a/sessionctx/sessionstates/session_states.go b/sessionctx/sessionstates/session_states.go index baf876ff87b4f..10a2756dd04f4 100644 --- a/sessionctx/sessionstates/session_states.go +++ b/sessionctx/sessionstates/session_states.go @@ -18,6 +18,7 @@ import ( "time" ptypes "github.com/pingcap/tidb/parser/types" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" ) @@ -52,4 +53,7 @@ type SessionStates struct { FoundInBinding bool `json:"in-binding,omitempty"` SequenceLatestValues map[int64]int64 `json:"seq-values,omitempty"` MPPStoreLastFailTime map[string]time.Time `json:"store-fail-time,omitempty"` + LastAffectedRows int64 `json:"affected-rows,omitempty"` + LastInsertID uint64 `json:"last-insert-id,omitempty"` + Warnings []stmtctx.SQLWarn `json:"warnings,omitempty"` } diff --git a/sessionctx/sessionstates/session_states_test.go b/sessionctx/sessionstates/session_states_test.go index 847f50f4e9a2b..29101af06f392 100644 --- a/sessionctx/sessionstates/session_states_test.go +++ b/sessionctx/sessionstates/session_states_test.go @@ -435,6 +435,125 @@ func TestSessionCtx(t *testing.T) { } } +func TestStatementCtx(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("create table test.t1(id int auto_increment primary key, str char(1))") + + tests := []struct { + setFunc func(tk *testkit.TestKit) any + checkFunc func(tk *testkit.TestKit, param any) + }{ + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("show warnings") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("0")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastAffectedRows + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('a'), ('b'), ('c')") + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select row_count()").Check(testkit.Rows("3")) + tk.MustQuery("select row_count()").Check(testkit.Rows("-1")) + }, + }, + { + // check LastInsertID + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(testkit.Rows("0")) + }, + }, + { + // check LastInsertID + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("insert into test.t1(str) value('d')") + rows := tk.MustQuery("select @@last_insert_id").Rows() + require.NotEqual(t, "0", rows[0][0].(string)) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("select @@last_insert_id").Check(param.([][]any)) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustQuery("select 1") + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + return nil + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(testkit.Rows()) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("0 0")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustGetErrCode("insert into test.t1(str) value('ef')", errno.ErrDataTooLong) + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 1, len(rows)) + tk.MustQuery("show errors").Check(rows) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(param.([][]any)) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("1 1")) + }, + }, + { + // check Warning + setFunc: func(tk *testkit.TestKit) any { + tk.MustExec("set sql_mode=''") + tk.MustExec("insert into test.t1(str) value('ef'), ('ef')") + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, 2, len(rows)) + tk.MustQuery("show errors").Check(testkit.Rows()) + return rows + }, + checkFunc: func(tk *testkit.TestKit, param any) { + tk.MustQuery("show warnings").Check(param.([][]any)) + tk.MustQuery("show errors").Check(testkit.Rows()) + tk.MustQuery("select @@warning_count, @@error_count").Check(testkit.Rows("2 0")) + }, + }, + } + + for _, tt := range tests { + tk1 := testkit.NewTestKit(t, store) + var param any + if tt.setFunc != nil { + param = tt.setFunc(tk1) + } + tk2 := testkit.NewTestKit(t, store) + showSessionStatesAndSet(t, tk1, tk2) + tt.checkFunc(tk2, param) + } +} + func showSessionStatesAndSet(t *testing.T, tk1, tk2 *testkit.TestKit) { rows := tk1.MustQuery("show session_states").Rows() require.Len(t, rows, 1) diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 70cd4bec5f898..4d623015492cc 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -15,6 +15,7 @@ package stmtctx import ( + "encoding/json" "math" "sort" "strconv" @@ -22,10 +23,12 @@ import ( "sync/atomic" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/util/disk" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/memory" @@ -60,6 +63,43 @@ type SQLWarn struct { Err error } +type jsonSQLWarn struct { + Level string `json:"level"` + SQLErr *terror.Error `json:"err,omitempty"` + Msg string `json:"msg,omitempty"` +} + +// MarshalJSON implements the Marshaler.MarshalJSON interface. +func (warn *SQLWarn) MarshalJSON() ([]byte, error) { + w := &jsonSQLWarn{ + Level: warn.Level, + } + e := errors.Cause(warn.Err) + switch x := e.(type) { + case *terror.Error: + // Omit outter errors because only the most inner error matters. + w.SQLErr = x + default: + w.Msg = e.Error() + } + return json.Marshal(w) +} + +// UnmarshalJSON implements the Unmarshaler.UnmarshalJSON interface. +func (warn *SQLWarn) UnmarshalJSON(data []byte) error { + var w jsonSQLWarn + if err := json.Unmarshal(data, &w); err != nil { + return err + } + warn.Level = w.Level + if w.SQLErr != nil { + warn.Err = w.SQLErr + } else { + warn.Err = errors.New(w.Msg) + } + return nil +} + // StatementContext contains variables for a statement. // It should be reset before executing a statement. type StatementContext struct { @@ -76,6 +116,7 @@ type StatementContext struct { InLoadDataStmt bool InExplainStmt bool InCreateOrAlterStmt bool + InSetSessionStatesStmt bool InPreparedPlanBuilding bool IgnoreTruncate bool IgnoreZeroInDate bool @@ -406,6 +447,13 @@ func (sc *StatementContext) AddAffectedRows(rows uint64) { sc.mu.affectedRows += rows } +// SetAffectedRows sets affected rows. +func (sc *StatementContext) SetAffectedRows(rows uint64) { + sc.mu.Lock() + sc.mu.affectedRows = rows + sc.mu.Unlock() +} + // AffectedRows gets affected rows. func (sc *StatementContext) AffectedRows() uint64 { sc.mu.Lock() @@ -558,6 +606,7 @@ func (sc *StatementContext) SetWarnings(warns []SQLWarn) { sc.mu.Lock() defer sc.mu.Unlock() sc.mu.warnings = warns + sc.mu.errorCount = 0 for _, w := range warns { if w.Level == WarnLevelError { sc.mu.errorCount++ diff --git a/sessionctx/stmtctx/stmtctx_test.go b/sessionctx/stmtctx/stmtctx_test.go index 7a4ec77a90660..b8f36dcb25055 100644 --- a/sessionctx/stmtctx/stmtctx_test.go +++ b/sessionctx/stmtctx/stmtctx_test.go @@ -16,12 +16,15 @@ package stmtctx_test import ( "context" + "encoding/json" "fmt" "testing" "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/util/execdetails" "github.com/stretchr/testify/require" @@ -143,3 +146,43 @@ func TestWeakConsistencyRead(t *testing.T) { execAndCheck("execute s", testkit.Rows("1 1 2"), kv.SI) tk.MustExec("rollback") } + +func TestMarshalSQLWarn(t *testing.T) { + warns := []stmtctx.SQLWarn{ + { + Level: stmtctx.WarnLevelError, + Err: errors.New("any error"), + }, + { + Level: stmtctx.WarnLevelError, + Err: errors.Trace(errors.New("any error")), + }, + { + Level: stmtctx.WarnLevelWarning, + Err: variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown"), + }, + { + Level: stmtctx.WarnLevelWarning, + Err: errors.Trace(variable.ErrUnknownSystemVar.GenWithStackByArgs("unknown")), + }, + } + + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + // First query can trigger loading global variables, which produces warnings. + tk.MustQuery("select 1") + tk.Session().GetSessionVars().StmtCtx.SetWarnings(warns) + rows := tk.MustQuery("show warnings").Rows() + require.Equal(t, len(warns), len(rows)) + + // The unmarshalled result doesn't need to be exactly the same with the original one. + // We only need that the results of `show warnings` are the same. + bytes, err := json.Marshal(warns) + require.NoError(t, err) + var newWarns []stmtctx.SQLWarn + err = json.Unmarshal(bytes, &newWarns) + require.NoError(t, err) + tk.Session().GetSessionVars().StmtCtx.SetWarnings(newWarns) + tk.MustQuery("show warnings").Check(rows) +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index fe4f469e76134..12546cde3c0ad 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -1867,6 +1867,11 @@ func (s *SessionVars) EncodeSessionStates(ctx context.Context, sessionStates *se sessionStates.MPPStoreLastFailTime = s.MPPStoreLastFailTime sessionStates.FoundInPlanCache = s.PrevFoundInPlanCache sessionStates.FoundInBinding = s.PrevFoundInBinding + + // Encode StatementContext. We encode it here to avoid circle dependency. + sessionStates.LastAffectedRows = s.StmtCtx.PrevAffectedRows + sessionStates.LastInsertID = s.StmtCtx.PrevLastInsertID + sessionStates.Warnings = s.StmtCtx.GetWarnings() return } @@ -1902,6 +1907,11 @@ func (s *SessionVars) DecodeSessionStates(ctx context.Context, sessionStates *se } s.FoundInPlanCache = sessionStates.FoundInPlanCache s.FoundInBinding = sessionStates.FoundInBinding + + // Decode StatementContext. + s.StmtCtx.SetAffectedRows(uint64(sessionStates.LastAffectedRows)) + s.StmtCtx.PrevLastInsertID = sessionStates.LastInsertID + s.StmtCtx.SetWarnings(sessionStates.Warnings) return }