Skip to content

Commit

Permalink
Merge branch 'master' into fix-priv-sql-api
Browse files Browse the repository at this point in the history
  • Loading branch information
AilinKid authored Feb 1, 2021
2 parents bf9f9e3 + 7ca1629 commit 199c105
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 93 deletions.
23 changes: 11 additions & 12 deletions bindinfo/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,16 +661,16 @@ func (h *BindHandle) CaptureBaselines() {
func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) {
origVals := sctx.GetSessionVars().UsePlanBaselines
sctx.GetSessionVars().UsePlanBaselines = false
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), fmt.Sprintf("explain format='hint' %s", sql))
sctx.GetSessionVars().UsePlanBaselines = origVals
if len(recordSets) > 0 {
defer terror.Log(recordSets[0].Close())
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil {
return "", err
}
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -873,23 +873,22 @@ func runSQL(ctx context.Context, sctx sessionctx.Context, sql string, resultChan
resultChan <- fmt.Errorf("run sql panicked: %v", string(buf))
}
}()
recordSets, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
rs, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql)
if err != nil {
if len(recordSets) > 0 {
terror.Call(recordSets[0].Close)
if rs != nil {
terror.Call(rs.Close)
}
resultChan <- err
return
}
recordSet := recordSets[0]
chk := recordSets[0].NewChunk()
chk := rs.NewChunk()
for {
err = recordSet.Next(ctx, chk)
err = rs.Next(ctx, chk)
if err != nil || chk.NumRows() == 0 {
break
}
}
terror.Call(recordSets[0].Close)
terror.Call(rs.Close)
resultChan <- err
}

Expand Down
16 changes: 6 additions & 10 deletions executor/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,22 +758,18 @@ func (e *AnalyzeFastExec) calculateEstimateSampleStep() (err error) {
if len(partition) > 0 {
sql += partition
}
var recordSets []sqlexec.RecordSet
recordSets, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
var rs sqlexec.RecordSet
rs, err = e.ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), sql)
if err != nil {
return
}
if len(recordSets) == 0 {
if rs == nil {
err = errors.Trace(errors.Errorf("empty record set"))
return
}
defer func() {
for _, r := range recordSets {
terror.Call(r.Close)
}
}()
chk := recordSets[0].NewChunk()
err = recordSets[0].Next(context.TODO(), chk)
defer terror.Call(rs.Close)
chk := rs.NewChunk()
err = rs.Next(context.TODO(), chk)
if err != nil {
return
}
Expand Down
12 changes: 6 additions & 6 deletions session/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -1325,24 +1325,24 @@ func upgradeToVer61(s Session, ver int64) {
mustExecute(s, "COMMIT")
}()
mustExecute(s, h.LockBindInfoSQL())
var recordSets []sqlexec.RecordSet
recordSets, err = s.ExecuteInternal(context.Background(),
var rs sqlexec.RecordSet
rs, err = s.ExecuteInternal(context.Background(),
`SELECT bind_sql, default_db, status, create_time, charset, collation, source
FROM mysql.bind_info
WHERE source != 'builtin'
ORDER BY update_time DESC`)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
if len(recordSets) > 0 {
defer terror.Call(recordSets[0].Close)
if rs != nil {
defer terror.Call(rs.Close)
}
req := recordSets[0].NewChunk()
req := rs.NewChunk()
iter := chunk.NewIterator4Chunk(req)
p := parser.New()
now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3)
for {
err = recordSets[0].Next(context.TODO(), req)
err = rs.Next(context.TODO(), req)
if err != nil {
logutil.BgLogger().Fatal("upgradeToVer61 error", zap.Error(err))
}
Expand Down
65 changes: 28 additions & 37 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,14 @@ type Session interface {
LastInsertID() uint64 // LastInsertID is the last inserted auto_increment ID.
LastMessage() string // LastMessage is the info message that may be generated by last command
AffectedRows() uint64 // Affected rows by latest executed stmt.
// Execute is deprecated, use ExecuteStmt() instead.
// Execute is deprecated, and only used by plugins. Use ExecuteStmt() instead.
Execute(context.Context, string) ([]sqlexec.RecordSet, error) // Execute a sql statement.
// ExecuteStmt executes a parsed statement.
ExecuteStmt(context.Context, ast.StmtNode) (sqlexec.RecordSet, error)
// Parse is deprecated, use ParseWithParams() instead.
Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
// ExecuteInternal is a helper around ParseWithParams() and ExecuteStmt(). It is not allowed to execute multiple statements.
ExecuteInternal(context.Context, string, ...interface{}) ([]sqlexec.RecordSet, error)
ExecuteInternal(context.Context, string, ...interface{}) (sqlexec.RecordSet, error)
String() string // String is used to debug.
CommitTxn(context.Context) error
RollbackTxn(context.Context)
Expand Down Expand Up @@ -899,37 +900,22 @@ func (s *session) ExecRestrictedSQLWithSnapshot(sql string) ([]chunk.Row, []*ast
func execRestrictedSQL(ctx context.Context, se *session, sql string) ([]chunk.Row, []*ast.ResultField, error) {
ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{})
startTime := time.Now()
recordSets, err := se.ExecuteInternal(ctx, sql)
defer func() {
for _, rs := range recordSets {
closeErr := rs.Close()
if closeErr != nil && err == nil {
err = closeErr
}
}
}()
if err != nil {
rs, err := se.ExecuteInternal(ctx, sql)
if rs != nil {
defer terror.Call(rs.Close)
}
if err != nil || rs == nil {
return nil, nil, err
}

var (
rows []chunk.Row
fields []*ast.ResultField
)
// Execute all recordset, take out the first one as result.
for i, rs := range recordSets {
tmp, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

if i == 0 {
rows = tmp
fields = rs.Fields()
}
rows, err := drainRecordSet(ctx, se, rs)
if err != nil {
return nil, nil, err
}

metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal).Observe(time.Since(startTime).Seconds())
return rows, fields, err
return rows, rs.Fields(), err
}

func createSessionFunc(store kv.Storage) pools.Factory {
Expand Down Expand Up @@ -1259,7 +1245,7 @@ func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecu
s.processInfo.Store(&pi)
}

func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (recordSets []sqlexec.RecordSet, err error) {
func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...interface{}) (rs sqlexec.RecordSet, err error) {
origin := s.sessionVars.InRestrictedSQL
s.sessionVars.InRestrictedSQL = true
defer func() {
Expand All @@ -1278,17 +1264,18 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter
return nil, err
}

rs, err := s.ExecuteStmt(ctx, stmtNode)
rs, err = s.ExecuteStmt(ctx, stmtNode)
if err != nil {
s.sessionVars.StmtCtx.AppendError(err)
}
if rs == nil {
return nil, err
}

return []sqlexec.RecordSet{rs}, err
return rs, err
}

// Execute is deprecated, we can remove it as soon as plugins are migrated.
func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) {
if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil {
span1 := span.Tracer().StartSpan("session.Execute", opentracing.ChildOf(span.Context()))
Expand Down Expand Up @@ -1350,10 +1337,14 @@ func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error)
}

// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode.
// Note that it will not do escaping if no variable arguments are passed.
func (s *session) ParseWithParams(ctx context.Context, sql string, args ...interface{}) (ast.StmtNode, error) {
sql, err := EscapeSQL(sql, args...)
if err != nil {
return nil, err
var err error
if len(args) > 0 {
sql, err = sqlexec.EscapeSQL(sql, args...)
if err != nil {
return nil, err
}
}

internal := s.isInternal()
Expand Down Expand Up @@ -2220,18 +2211,18 @@ var (
// loadParameter loads read-only parameter from mysql.tidb
func loadParameter(se *session, name string) (string, error) {
sql := "select variable_value from mysql.tidb where variable_name = '" + name + "'"
rss, errLoad := se.Execute(context.Background(), sql)
rs, errLoad := se.ExecuteInternal(context.Background(), sql)
if errLoad != nil {
return "", errLoad
}
// the record of mysql.tidb under where condition: variable_name = $name should shall only be one.
defer func() {
if err := rss[0].Close(); err != nil {
if err := rs.Close(); err != nil {
logutil.BgLogger().Error("close result set error", zap.Error(err))
}
}()
req := rss[0].NewChunk()
if err := rss[0].Next(context.Background(), req); err != nil {
req := rs.NewChunk()
if err := rs.Next(context.Background(), req); err != nil {
return "", err
}
if req.NumRows() == 0 {
Expand Down
20 changes: 15 additions & 5 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4166,25 +4166,35 @@ func (s *testSessionSerialSuite) TestParseWithParams(c *C) {
defer func() {
se.GetSessionVars().InRestrictedSQL = origin
}()
_, err := exec.ParseWithParams(context.Background(), "SELECT 4")
_, err := exec.ParseWithParams(context.TODO(), "SELECT 4")
c.Assert(err, IsNil)

// test charset attack
stmts, err := exec.ParseWithParams(context.Background(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
stmt, err := exec.ParseWithParams(context.TODO(), "SELECT * FROM test WHERE name = %? LIMIT 1", "\xbf\x27 OR 1=1 /*")
c.Assert(err, IsNil)

var sb strings.Builder
ctx := format.NewRestoreCtx(0, &sb)
err = stmts.Restore(ctx)
err = stmt.Restore(ctx)
c.Assert(err, IsNil)
// FIXME: well... so the restore function is vulnerable...
c.Assert(sb.String(), Equals, "SELECT * FROM test WHERE name=_utf8mb4\xbf' OR 1=1 /* LIMIT 1")

// test invalid sql
_, err = exec.ParseWithParams(context.Background(), "SELECT")
_, err = exec.ParseWithParams(context.TODO(), "SELECT")
c.Assert(err, ErrorMatches, ".*You have an error in your SQL syntax.*")

// test invalid arguments to escape
_, err = exec.ParseWithParams(context.Background(), "SELECT %?")
_, err = exec.ParseWithParams(context.TODO(), "SELECT %?, %?", 3)
c.Assert(err, ErrorMatches, "missing arguments.*")

// test noescape
stmt, err = exec.ParseWithParams(context.TODO(), "SELECT 3")
c.Assert(err, IsNil)

sb.Reset()
ctx = format.NewRestoreCtx(0, &sb)
err = stmt.Restore(ctx)
c.Assert(err, IsNil)
c.Assert(sb.String(), Equals, "SELECT 3")
}
Loading

0 comments on commit 199c105

Please sign in to comment.