Skip to content

Commit

Permalink
*: refactor ExecuteInternal to return single resultset (#22546)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgo authored Feb 1, 2021
1 parent 109ad45 commit 7ca1629
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 70 deletions.
23 changes: 11 additions & 12 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,16 +661,16 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
origVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = origVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -873,23 +873,22 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
if rs != nil {
terror.Call(rs.Close)
}
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
16 changes: 6 additions & 10 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,22 +758,18 @@ func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) {
if len(partition) > 0 {
sql += partition
}
var recordSets []sqlexec.RecordSet
recordSets, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
var rs sqlexec.RecordSet
rs, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
if err != nil {
return
}
if len(recordSets) == 0 {
if rs == nil {
err = errors.Trace(errors.Errorf("empty record set"))
return
}
defer func() {
for _, r := range recordSets {
terror.Call(r.Close)
}
}()
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return
}
Expand Down
12 changes: 6 additions & 6 deletions session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,24 +1325,24 @@ func upgradeToVer61(s Session, ver int64) {
mustExecute(s, "COMMIT")
}()
mustExecute(s, h.LockBindInfoSQL())
var recordSets []sqlexec.RecordSet
recordSets, err = s.ExecuteInternal(context.Background(),
var rs sqlexec.RecordSet
rs, err = s.ExecuteInternal(context.Background(),
`SELECT bind_sql, default_db, status, create_time, charset, collation, source
FROM mysql.bind_info
WHERE source != 'builtin'
ORDER BY update_time DESC`)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
if len(recordSets) > 0 {
defer terror.Call(recordSets[0].Close)
if rs != nil {
defer terror.Call(rs.Close)
}
req := recordSets[0].NewChunk()
req := rs.NewChunk()
iter := chunk.NewIterator4Chunk(req)
p := parser.New()
now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3)
for {
err = recordSets[0].Next(context.TODO(), req)
err = rs.Next(context.TODO(), req)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
Expand Down
55 changes: 21 additions & 34 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,14 @@ type Session interface {
LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID.
LastMessage() string // LastMessage is the info message that may be generated by last command
AffectedRows() uint64 // Affected rows by latest executed stmt.
// Execute is deprecated, use ExecuteStmt() instead.
// Execute is deprecated, and only used by plugins. Use ExecuteStmt() instead.
Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement.
// ExecuteStmt executes a parsed statement.
ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error)
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -899,37 +900,22 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1259,7 +1245,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1278,17 +1264,18 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, err
}

rs, err := s.ExecuteStmt(ctx, stmtNode)
rs, err = s.ExecuteStmt(ctx, stmtNode)
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

// Execute is deprecated, we can remove it as soon as plugins are migrated.
func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.Execute", opentracing.ChildOf(span.Context()))
Expand Down Expand Up @@ -2224,18 +2211,18 @@ var (
// loadParameter loads read-only parameter from mysql.tidb
func loadParameter(se *session, name string) (string, error) {
sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'"
rss, errLoad := se.Execute(context.Background(), sql)
rs, errLoad := se.ExecuteInternal(context.Background(), sql)
if errLoad != nil {
return "", errLoad
}
// the record of mysql.tidb under where condition: variable_name = $name should shall only be one.
defer func() {
if err := rss[0].Close(); err != nil {
if err := rs.Close(); err != nil {
logutil.BgLogger().Error("close result set error", zap.Error(err))
}
}()
req := rss[0].NewChunk()
if err := rss[0].Next(context.Background(), req); err != nil {
req := rs.NewChunk()
if err := rs.Next(context.Background(), req); err != nil {
return "", err
}
if req.NumRows() == 0 {
Expand Down
8 changes: 4 additions & 4 deletions store/tikv/gcworker/gc_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1760,14 +1760,14 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) {
se := createSession(w.store)
defer se.Close()
rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key)
if len(rs) > 0 {
defer terror.Call(rs[0].Close)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return "", errors.Trace(err)
}
req := rs[0].NewChunk()
err = rs[0].Next(ctx, req)
req := rs.NewChunk()
err = rs.Next(ctx, req)
if err != nil {
return "", errors.Trace(err)
}
Expand Down
6 changes: 3 additions & 3 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ func (txn *wrapTxn) GetUnionStore() kv.UnionStore {

// Execute implements sqlexec.SQLExecutor Execute interface.
func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
return nil, errors.Errorf("Not Supported.")
}

// ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface.
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Supported.")
}

type mockDDLOwnerChecker struct{}
Expand Down
3 changes: 2 additions & 1 deletion util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,10 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias {
// For example, privilege/privileges package need execute SQL, if it use
// session.Session.Execute, then privilege/privileges and tidb would become a circle.
type SQLExecutor interface {
// Execute is only used by plugins. It can be removed soon.
Execute(ctx context.Context, sql string) ([]RecordSet, error)
// ExecuteInternal means execute sql as the internal sql.
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error)
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error)
}

// SQLParser is an interface provides parsing sql statement.
Expand Down

0 comments on commit 7ca1629

Please sign in to comment.