Skip to content

Commit

Permalink
*: refactor ExecuteInternal to return single resultset (#22546) (#22655)
Browse files Browse the repository at this point in the history
Signed-off-by: ti-srebot <ti-srebot@pingcap.com>
  • Loading branch information
ti-srebot authored Feb 4, 2021
1 parent 404d743 commit 80a3b1d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 59 deletions.
23 changes: 11 additions & 12 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,16 +673,16 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
origVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = origVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -882,23 +882,22 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
if rs != nil {
terror.Call(rs.Close)
}
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
16 changes: 6 additions & 10 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,22 +754,18 @@ func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) {
if len(partition) > 0 {
sql += partition
}
var recordSets []sqlexec.RecordSet
recordSets, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
var rs sqlexec.RecordSet
rs, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
if err != nil {
return
}
if len(recordSets) == 0 {
if rs == nil {
err = errors.Trace(errors.Errorf("empty record set"))
return
}
defer func() {
for _, r := range recordSets {
terror.Call(r.Close)
}
}()
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return
}
Expand Down
50 changes: 17 additions & 33 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ type Session interface {
// This function only saves you from processing potentially unsafe parameters.
ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -885,37 +885,21 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1150,7 +1134,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1172,15 +1156,15 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, errors.New("Executing multiple statements internally is not supported")
}

rs, err := s.ExecuteStmt(ctx, stmtNodes[0])
rs, err = s.ExecuteStmt(ctx, stmtNodes[0])
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
Expand Down Expand Up @@ -1999,18 +1983,18 @@ var (
// loadParameter loads read-only parameter from mysql.tidb
func loadParameter(se *session, name string) (string, error) {
sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'"
rss, errLoad := se.Execute(context.Background(), sql)
rs, errLoad := se.ExecuteInternal(context.Background(), sql)
if errLoad != nil {
return "", errLoad
}
// the record of mysql.tidb under where condition: variable_name = $name should shall only be one.
defer func() {
if err := rss[0].Close(); err != nil {
if err := rs.Close(); err != nil {
logutil.BgLogger().Error("close result set error", zap.Error(err))
}
}()
req := rss[0].NewChunk()
if err := rss[0].Next(context.Background(), req); err != nil {
req := rs.NewChunk()
if err := rs.Next(context.Background(), req); err != nil {
return "", err
}
if req.NumRows() == 0 {
Expand Down
6 changes: 3 additions & 3 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ func (txn *wrapTxn) Valid() bool {

// Execute implements sqlexec.SQLExecutor Execute interface.
func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
return nil, errors.Errorf("Not Supported.")
}

// ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface.
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Supported.")
}

type mockDDLOwnerChecker struct{}
Expand Down
2 changes: 1 addition & 1 deletion util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias {
type SQLExecutor interface {
Execute(ctx context.Context, sql string) ([]RecordSet, error)
// ExecuteInternal means execute sql as the internal sql.
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error)
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error)
}

// SQLParser is an interface provides parsing sql statement.
Expand Down

0 comments on commit 80a3b1d

Please sign in to comment.