Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: Use strict validation for stale read ts & flashback ts (#57050) #57315

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DEPS.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4179,8 +4179,8 @@ def go_deps():
name = "com_github_tikv_client_go_v2",
build_file_proto_mode = "disable_global",
importpath = "github.com/tikv/client-go/v2",
sum = "h1:b3eCbSiRZ3az/eQPz0dqYmbYT4Xzix7FjFM1Uf+tKVM=",
version = "v2.0.8-0.20241108115434-4aab367743bf",
sum = "h1:DeZMstuDPx80CNsJdmN47biH+8PL5qK4TmTyNcDRDdI=",
version = "v2.0.8-0.20241111142004-e10335846244",
)
go_repository(
name = "com_github_tikv_pd",
Expand Down
13 changes: 4 additions & 9 deletions ddl/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
Expand Down Expand Up @@ -112,16 +111,12 @@

// ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS).
func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error {
currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0)
// If we fail to calculate currentTS from local time, fallback to get a timestamp from PD.
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate flashback timestamp: %v", err)
}
currentTS = currentVer.Ver
return errors.Errorf("fail to validate flashback timestamp: %v", err)

Check warning on line 116 in ddl/cluster.go

View check run for this annotation

Codecov / codecov/patch

ddl/cluster.go#L116

Added line #L116 was not covered by tests
}
currentTS := currentVer.Ver

oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS)
if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) {
return errors.Errorf("cannot set flashback timestamp to future time")
Expand Down
6 changes: 2 additions & 4 deletions executor/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,8 @@ func (e *SetExecutor) setSysVariable(ctx context.Context, name string, v *expres
newSnapshotTS := getSnapshotTSByName()
newSnapshotIsSet := newSnapshotTS > 0 && newSnapshotTS != oldSnapshotTS
if newSnapshotIsSet {
if name == variable.TiDBTxnReadTS {
err = sessionctx.ValidateStaleReadTS(ctx, e.ctx, newSnapshotTS)
} else {
err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx, newSnapshotTS)
err = sessionctx.ValidateSnapshotReadTS(ctx, e.ctx.GetStore(), newSnapshotTS)
if name != variable.TiDBTxnReadTS {
// Also check gc safe point for snapshot read.
// We don't check snapshot with gc safe point for read_ts
// Client-go will automatically check the snapshotTS with gc safe point. It's unnecessary to check gc safe point during set executor.
Expand Down
25 changes: 21 additions & 4 deletions executor/stale_txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package executor_test
import (
"context"
"fmt"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -1406,14 +1407,30 @@ func TestStaleTSO(t *testing.T) {
tk.MustExec("create table t (id int)")

tk.MustExec("insert into t values(1)")
ts1, err := strconv.ParseUint(tk.MustQuery("select json_extract(@@tidb_last_txn_info, '$.commit_ts')").Rows()[0][0].(string), 10, 64)
require.NoError(t, err)

// Wait until the physical advances for 1s
var currentTS uint64
for {
tk.MustExec("begin")
currentTS, err = strconv.ParseUint(tk.MustQuery("select @@tidb_current_ts").Rows()[0][0].(string), 10, 64)
require.NoError(t, err)
tk.MustExec("rollback")
if oracle.GetTimeFromTS(currentTS).After(oracle.GetTimeFromTS(ts1).Add(time.Second)) {
break
}
time.Sleep(time.Millisecond * 100)
}

asOfExprs := []string{
"now(3) - interval 1 second",
"current_time() - interval 1 second",
"curtime() - interval 1 second",
"now(3) - interval 10 second",
"current_time() - interval 10 second",
"curtime() - interval 10 second",
}

nextTSO := oracle.GoTimeToTS(time.Now().Add(2 * time.Second))
nextPhysical := oracle.GetPhysical(oracle.GetTimeFromTS(currentTS).Add(10 * time.Second))
nextTSO := oracle.ComposeTS(nextPhysical, oracle.ExtractLogical(currentTS))
require.Nil(t, failpoint.Enable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO", fmt.Sprintf("return(%d)", nextTSO)))
defer failpoint.Disable("github.com/pingcap/tidb/sessiontxn/staleread/mockStaleReadTSO")
for _, expr := range asOfExprs {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ require (
github.com/stretchr/testify v1.8.2
github.com/tdakkota/asciicheck v0.2.0
github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2
github.com/tikv/client-go/v2 v2.0.8-0.20241108115434-4aab367743bf
github.com/tikv/client-go/v2 v2.0.8-0.20241111142004-e10335846244
github.com/tikv/pd/client v0.0.0-20240725070735-fb162bf0aa3f
github.com/timakin/bodyclose v0.0.0-20221125081123-e39cf3fc478e
github.com/twmb/murmur3 v1.1.6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -968,8 +968,8 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 h1:mbAskLJ0oJf
github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfKggNGDuadAa0LElHrByyrz4JPZ9fFx6Gs7nx7ZZU=
github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a h1:J/YdBZ46WKpXsxsW93SG+q0F8KI+yFrcIDT4c/RNoc4=
github.com/tiancaiamao/gp v0.0.0-20221230034425-4025bc8a4d4a/go.mod h1:h4xBhSNtOeEosLJ4P7JyKXX7Cabg7AVkWCK5gV2vOrM=
github.com/tikv/client-go/v2 v2.0.8-0.20241108115434-4aab367743bf h1:b3eCbSiRZ3az/eQPz0dqYmbYT4Xzix7FjFM1Uf+tKVM=
github.com/tikv/client-go/v2 v2.0.8-0.20241108115434-4aab367743bf/go.mod h1:45NuHB8x+VAoztMIjF6hEgXvPQXhXWPfMxDg0N8CoRY=
github.com/tikv/client-go/v2 v2.0.8-0.20241111142004-e10335846244 h1:DeZMstuDPx80CNsJdmN47biH+8PL5qK4TmTyNcDRDdI=
github.com/tikv/client-go/v2 v2.0.8-0.20241111142004-e10335846244/go.mod h1:45NuHB8x+VAoztMIjF6hEgXvPQXhXWPfMxDg0N8CoRY=
github.com/tikv/pd/client v0.0.0-20240725070735-fb162bf0aa3f h1:Szw9YxqGGEneSniBd4ep09jgB77cKUy+AuhKOmdGPdE=
github.com/tikv/pd/client v0.0.0-20240725070735-fb162bf0aa3f/go.mod h1:QCBn54O5lhfkYfxj8Tyiqaxue/mthHEMyi7AqJP/+n4=
github.com/timakin/bodyclose v0.0.0-20221125081123-e39cf3fc478e h1:MV6KaVu/hzByHP0UvJ4HcMGE/8a6A4Rggc/0wx2AvJo=
Expand Down
Empty file added pkg/sessionctx/BUILD.bazel
Empty file.
2 changes: 1 addition & 1 deletion planner/core/plan_cache_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ type PlanCacheStmt struct {
SQLDigest *parser.Digest
PlanDigest *parser.Digest
ForUpdateRead bool
SnapshotTSEvaluator func(sessionctx.Context) (uint64, error)
SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error)
NormalizedSQL4PC string
SQLDigest4PC string

Expand Down
4 changes: 2 additions & 2 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3568,7 +3568,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan,
if err != nil {
return nil, err
}
if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil {
if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil {
return nil, err
}
p.StaleTxnStartTS = startTS
Expand All @@ -3582,7 +3582,7 @@ func (b *PlanBuilder) buildSimple(ctx context.Context, node ast.StmtNode) (Plan,
if err != nil {
return nil, err
}
if err := sessionctx.ValidateStaleReadTS(ctx, b.ctx, startTS); err != nil {
if err := sessionctx.ValidateSnapshotReadTS(ctx, b.ctx.GetStore(), startTS); err != nil {
return nil, err
}
p.StaleTxnStartTS = startTS
Expand Down
2 changes: 1 addition & 1 deletion planner/core/preprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ var _ = PreprocessorReturn{}.initedLastSnapshotTS
type PreprocessorReturn struct {
initedLastSnapshotTS bool
IsStaleness bool
SnapshotTSEvaluator func(sessionctx.Context) (uint64, error)
SnapshotTSEvaluator func(context.Context, sessionctx.Context) (uint64, error)
// LastSnapshotTS is the last evaluated snapshotTS if any
// otherwise it defaults to zero
LastSnapshotTS uint64
Expand Down
2 changes: 0 additions & 2 deletions sessionctx/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ go_library(
deps = [
"//extension",
"//kv",
"//metrics",
"//parser/model",
"//sessionctx/sessionstates",
"//sessionctx/variable",
Expand All @@ -17,7 +16,6 @@ go_library(
"//util/plancache",
"//util/sli",
"//util/topsql/stmtstats",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_kvproto//pkg/kvrpcpb",
"@com_github_pingcap_tipb//go-binlog",
"@com_github_tikv_client_go_v2//oracle",
Expand Down
43 changes: 2 additions & 41 deletions sessionctx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,10 @@ package sessionctx
import (
"context"
"fmt"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/kvproto/pkg/kvrpcpb"
"github.com/pingcap/tidb/extension"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx/sessionstates"
"github.com/pingcap/tidb/sessionctx/variable"
Expand Down Expand Up @@ -223,44 +220,8 @@ const (
)

// ValidateSnapshotReadTS strictly validates that readTS does not exceed the PD timestamp
func ValidateSnapshotReadTS(ctx context.Context, sctx Context, readTS uint64) error {
latestTS, err := sctx.GetStore().GetOracle().GetLowResolutionTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
// If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check
if err != nil || readTS > latestTS {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
if readTS > currentVer.Ver {
return errors.Errorf("cannot set read timestamp to a future time")
}
}
return nil
}

// How far future from now ValidateStaleReadTS allows at most
const allowedTimeFromNow = 100 * time.Millisecond

// ValidateStaleReadTS validates that readTS does not exceed the current time not strictly.
func ValidateStaleReadTS(ctx context.Context, sctx Context, readTS uint64) error {
currentTS, err := sctx.GetSessionVars().StmtCtx.GetStaleTSO()
if currentTS == 0 || err != nil {
currentTS, err = sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0)
}
// If we fail to calculate currentTS from local time, fallback to get a timestamp from PD
if err != nil {
metrics.ValidateReadTSFromPDCount.Inc()
currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope)
if err != nil {
return errors.Errorf("fail to validate read timestamp: %v", err)
}
currentTS = currentVer.Ver
}
if oracle.GetTimeFromTS(readTS).After(oracle.GetTimeFromTS(currentTS).Add(allowedTimeFromNow)) {
return errors.Errorf("cannot set read timestamp to a future time")
}
return nil
func ValidateSnapshotReadTS(ctx context.Context, store kv.Storage, readTS uint64) error {
return store.GetOracle().ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
}

// SysProcTracker is used to track background sys processes
Expand Down
20 changes: 10 additions & 10 deletions sessiontxn/staleread/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
var _ Processor = &staleReadProcessor{}

// StalenessTSEvaluator is a function to get staleness ts
type StalenessTSEvaluator func(sctx sessionctx.Context) (uint64, error)
type StalenessTSEvaluator func(ctx context.Context, sctx sessionctx.Context) (uint64, error)

// Processor is an interface used to process stale read
type Processor interface {
Expand Down Expand Up @@ -100,7 +100,7 @@ func (p *baseProcessor) setEvaluatedTS(ts uint64) (err error) {
return err
}

return p.setEvaluatedValues(ts, is, func(sctx sessionctx.Context) (uint64, error) {
return p.setEvaluatedValues(ts, is, func(_ context.Context, sctx sessionctx.Context) (uint64, error) {
return ts, nil
})
}
Expand All @@ -116,7 +116,7 @@ func (p *baseProcessor) setEvaluatedTSWithoutEvaluator(ts uint64) (err error) {
}

func (p *baseProcessor) setEvaluatedEvaluator(evaluator StalenessTSEvaluator) error {
ts, err := evaluator(p.sctx)
ts, err := evaluator(p.ctx, p.sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -167,10 +167,10 @@ func (p *staleReadProcessor) OnSelectTable(tn *ast.TableName) error {
}

// If `stmtAsOfTS` is not 0, it means we use 'select ... from xxx as of timestamp ...'
evaluateTS := func(sctx sessionctx.Context) (uint64, error) {
return parseAndValidateAsOf(context.Background(), p.sctx, tn.AsOf)
evaluateTS := func(ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return parseAndValidateAsOf(ctx, p.sctx, tn.AsOf)
}
stmtAsOfTS, err := evaluateTS(p.sctx)
stmtAsOfTS, err := evaluateTS(p.ctx, p.sctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (p *staleReadProcessor) OnExecutePreparedStmt(preparedTSEvaluator Staleness
var stmtTS uint64
if preparedTSEvaluator != nil {
// If the `preparedTSEvaluator` is not nil, it means the prepared statement is stale read
if stmtTS, err = preparedTSEvaluator(p.sctx); err != nil {
if stmtTS, err = preparedTSEvaluator(p.ctx, p.sctx); err != nil {
return err
}
}
Expand Down Expand Up @@ -285,7 +285,7 @@ func parseAndValidateAsOf(ctx context.Context, sctx sessionctx.Context, asOf *as
return 0, err
}

if err = sessionctx.ValidateStaleReadTS(ctx, sctx, ts); err != nil {
if err = sessionctx.ValidateSnapshotReadTS(ctx, sctx.GetStore(), ts); err != nil {
return 0, err
}

Expand All @@ -298,8 +298,8 @@ func getTsEvaluatorFromReadStaleness(sctx sessionctx.Context) StalenessTSEvaluat
return nil
}

return func(sctx sessionctx.Context) (uint64, error) {
return CalculateTsWithReadStaleness(sctx, readStaleness)
return func(ctx context.Context, sctx sessionctx.Context) (uint64, error) {
return CalculateTsWithReadStaleness(ctx, sctx, readStaleness)
}
}

Expand Down
Loading