From 56c7714488206491c64abdfbd83801c4778e60ce Mon Sep 17 00:00:00 2001 From: MyonKeminta Date: Wed, 6 Nov 2024 17:35:44 +0800 Subject: [PATCH] Add more tests Signed-off-by: MyonKeminta --- oracle/oracles/pd.go | 2 +- oracle/oracles/pd_test.go | 159 +++++++++++++++++++++++++++++++++++++- 2 files changed, 159 insertions(+), 2 deletions(-) diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 08b411ddcc..e57b73d3fb 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -612,7 +612,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error { latestTS, err := o.GetLowResolutionTimestamp(ctx, opt) - // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check. + // If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check. // But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function // loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls. if err != nil || readTS > latestTS { diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 46e3ba4de4..adef5e979b 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -180,13 +180,15 @@ func TestNonFutureStaleTSO(t *testing.T) { } } -func TestNextUpdateTSInterval(t *testing.T) { +func TestAdaptiveUpdateTSInterval(t *testing.T) { oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{ UpdateInterval: time.Second * 2, NoUpdateTS: true, }) assert.NoError(t, err) o := oracleInterface.(*pdOracle) + defer o.Close() + now := time.Now() mockTS := func(beforeNow time.Duration) uint64 { @@ -339,3 +341,158 @@ func TestNextUpdateTSInterval(t *testing.T) { assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, o.nextUpdateInterval(now, 0)) assert.Equal(t, adaptiveUpdateTSIntervalStateUnadjustable, o.adaptiveUpdateIntervalState.state) } + +func TestValidateSnapshotReadTS(t *testing.T) { + pdClient := MockPdClient{} + o, err := NewPdOracle(&pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + }) + defer o.Close() + + assert.NoError(t, err) + ctx := context.Background() + opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope} + ts, err := o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + assert.GreaterOrEqual(t, ts, uint64(1)) + + err = o.ValidateSnapshotReadTS(ctx, 1, opt) + assert.NoError(t, err) + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + // The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to + // the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass. + err = o.ValidateSnapshotReadTS(ctx, ts+1, opt) + assert.NoError(t, err) + // It can't pass if the readTS is newer than previous ts + 2. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + err = o.ValidateSnapshotReadTS(ctx, ts+2, opt) + assert.Error(t, err) + + // Simulate other PD clients requests a timestamp. + ts, err = o.GetTimestamp(ctx, opt) + assert.NoError(t, err) + pdClient.logicalTimestamp.Add(2) + err = o.ValidateSnapshotReadTS(ctx, ts+3, opt) + assert.NoError(t, err) +} + +type MockPDClientWithPause struct { + MockPdClient + mu sync.Mutex +} + +func (c *MockPDClientWithPause) GetTS(ctx context.Context) (int64, int64, error) { + c.mu.Lock() + defer c.mu.Unlock() + return c.MockPdClient.GetTS(ctx) +} + +func (c *MockPDClientWithPause) Pause() { + c.mu.Lock() +} + +func (c *MockPDClientWithPause) Resume() { + c.mu.Unlock() +} + +func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) { + pdClient := &MockPDClientWithPause{} + o, err := NewPdOracle(pdClient, &PDOracleOptions{ + UpdateInterval: time.Second * 2, + NoUpdateTS: true, + }) + assert.NoError(t, err) + defer o.Close() + + asyncValidate := func(ctx context.Context, readTS uint64) chan error { + ch := make(chan error, 1) + go func() { + err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + ch <- err + }() + return ch + } + + noResult := func(ch chan error) { + select { + case <-ch: + assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked") + default: + } + } + + cancelIndices := []int{-1, -1, 0, 1} + for i, ts := range []uint64{100, 200, 300, 400} { + // Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail. + + // We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls + // doesn't affect other calls that are waiting + cancelIndex := cancelIndices[i] + + pdClient.Pause() + + results := make([]chan error, 0, 5) + + ctx, cancel := context.WithCancel(context.Background()) + + getCtx := func(index int) context.Context { + if cancelIndex == index { + return ctx + } else { + return context.Background() + } + } + + results = append(results, asyncValidate(getCtx(0), ts-2)) + results = append(results, asyncValidate(getCtx(1), ts+2)) + results = append(results, asyncValidate(getCtx(2), ts-1)) + results = append(results, asyncValidate(getCtx(3), ts+1)) + results = append(results, asyncValidate(getCtx(4), ts)) + + expectedSucceeds := []bool{true, false, true, false, true} + + time.Sleep(time.Millisecond * 50) + for _, ch := range results { + noResult(ch) + } + + cancel() + + for i, ch := range results { + if i == cancelIndex { + select { + case err := <-ch: + assert.Errorf(t, err, "index: %v", i) + assert.Containsf(t, err.Error(), "context canceled", "index: %v", i) + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + } else { + noResult(ch) + } + } + + // ts will be the next ts returned to these validation calls. + pdClient.logicalTimestamp.Store(int64(ts - 1)) + pdClient.Resume() + for i, ch := range results { + if i == cancelIndex { + continue + } + + select { + case err = <-ch: + case <-time.After(time.Second): + assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i) + } + if expectedSucceeds[i] { + assert.NoErrorf(t, err, "index: %v", i) + } else { + assert.Errorf(t, err, "index: %v", i) + assert.NotContainsf(t, err.Error(), "context canceled", "index: %v", i) + } + } + } +}