Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: refactor ExecuteInternal to return single resultset (#22546) #22655

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,16 +673,16 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
origVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = origVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -882,23 +882,22 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
if rs != nil {
terror.Call(rs.Close)
}
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
16 changes: 6 additions & 10 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,22 +754,18 @@ func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) {
if len(partition) > 0 {
sql += partition
}
var recordSets []sqlexec.RecordSet
recordSets, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
var rs sqlexec.RecordSet
rs, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
if err != nil {
return
}
if len(recordSets) == 0 {
if rs == nil {
err = errors.Trace(errors.Errorf("empty record set"))
return
}
defer func() {
for _, r := range recordSets {
terror.Call(r.Close)
}
}()
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return
}
Expand Down
50 changes: 17 additions & 33 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ type Session interface {
// This function only saves you from processing potentially unsafe parameters.
ParseWithParams(ctx context.Context, sql string, args ...interface{}) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -885,37 +885,21 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}
metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1150,7 +1134,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1172,15 +1156,15 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, errors.New("Executing multiple statements internally is not supported")
}

rs, err := s.ExecuteStmt(ctx, stmtNodes[0])
rs, err = s.ExecuteStmt(ctx, stmtNodes[0])
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
Expand Down Expand Up @@ -1999,18 +1983,18 @@ var (
// loadParameter loads read-only parameter from mysql.tidb
func loadParameter(se *session, name string) (string, error) {
sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'"
rss, errLoad := se.Execute(context.Background(), sql)
rs, errLoad := se.ExecuteInternal(context.Background(), sql)
if errLoad != nil {
return "", errLoad
}
// the record of mysql.tidb under where condition: variable_name = $name should shall only be one.
defer func() {
if err := rss[0].Close(); err != nil {
if err := rs.Close(); err != nil {
logutil.BgLogger().Error("close result set error", zap.Error(err))
}
}()
req := rss[0].NewChunk()
if err := rss[0].Next(context.Background(), req); err != nil {
req := rs.NewChunk()
if err := rs.Next(context.Background(), req); err != nil {
return "", err
}
if req.NumRows() == 0 {
Expand Down
6 changes: 3 additions & 3 deletions util/mock/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ func (txn *wrapTxn) Valid() bool {

// Execute implements sqlexec.SQLExecutor Execute interface.
func (c *Context) Execute(ctx context.Context, sql string) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
return nil, errors.Errorf("Not Supported.")
}

// ExecuteInternal implements sqlexec.SQLExecutor ExecuteInternal interface.
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Support.")
func (c *Context) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (sqlexec.RecordSet, error) {
return nil, errors.Errorf("Not Supported.")
}

type mockDDLOwnerChecker struct{}
Expand Down
2 changes: 1 addition & 1 deletion util/sqlexec/restricted_sql_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func ExecOptionWithSnapshot(snapshot uint64) OptionFuncAlias {
type SQLExecutor interface {
Execute(ctx context.Context, sql string) ([]RecordSet, error)
// ExecuteInternal means execute sql as the internal sql.
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) ([]RecordSet, error)
ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (RecordSet, error)
}

// SQLParser is an interface provides parsing sql statement.
Expand Down