From b0f0ead1aa45436433c845fa9a17fbcfe6881950 Mon Sep 17 00:00:00 2001 From: you06 Date: Wed, 15 Nov 2023 17:34:14 +0900 Subject: [PATCH] tso: merge lastTS and lastArrivalTS into an atomic pointer (#1054) * fix the issue that stale timestamp may be a future one Signed-off-by: you06 * add regression test Signed-off-by: you06 * lazy init lastTSO Signed-off-by: you06 * fix panic Signed-off-by: you06 * address comment Signed-off-by: you06 --------- Signed-off-by: you06 --- oracle/oracles/export_test.go | 18 ++------ oracle/oracles/pd.go | 82 +++++++++++++++-------------------- oracle/oracles/pd_test.go | 32 ++++++++++++++ 3 files changed, 70 insertions(+), 62 deletions(-) diff --git a/oracle/oracles/export_test.go b/oracle/oracles/export_test.go index 6519f28beb..08df25783d 100644 --- a/oracle/oracles/export_test.go +++ b/oracle/oracles/export_test.go @@ -63,20 +63,8 @@ func NewEmptyPDOracle() oracle.Oracle { func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) { switch o := oc.(type) { case *pdOracle: - lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64)) - lastTSPointer := lastTSInterface.(*uint64) - atomic.StoreUint64(lastTSPointer, ts) - lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64)) - lasTSArrivalPointer := lasTSArrivalInterface.(*uint64) - atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000)) - } - setEmptyPDOracleLastArrivalTs(oc, ts) -} - -// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test. -func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) { - switch o := oc.(type) { - case *pdOracle: - o.setLastArrivalTS(ts, oracle.GlobalTxnScope) + lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, &atomic.Pointer[lastTSO]{}) + lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) + lastTSPointer.Store(&lastTSO{tso: ts, arrival: ts}) } } diff --git a/oracle/oracles/pd.go b/oracle/oracles/pd.go index 969af24139..8f81060d5e 100644 --- a/oracle/oracles/pd.go +++ b/oracle/oracles/pd.go @@ -56,11 +56,15 @@ const slowDist = 30 * time.Millisecond // pdOracle is an Oracle that uses a placement driver client as source. type pdOracle struct { c pd.Client - // txn_scope (string) -> lastTSPointer (*uint64) + // txn_scope (string) -> lastTSPointer (*atomic.Pointer[lastTSO]) lastTSMap sync.Map - // txn_scope (string) -> lastArrivalTSPointer (*uint64) - lastArrivalTSMap sync.Map - quit chan struct{} + quit chan struct{} +} + +// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched. +type lastTSO struct { + tso uint64 + arrival uint64 } // NewPdOracle create an Oracle that uses a pd client source. @@ -163,63 +167,51 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) { if txnScope == "" { txnScope = oracle.GlobalTxnScope } - lastTSInterface, ok := o.lastTSMap.Load(txnScope) - if !ok { - lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64)) - } - lastTSPointer := lastTSInterface.(*uint64) - for { - lastTS := atomic.LoadUint64(lastTSPointer) - if ts <= lastTS { - return - } - if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) { - break - } - } - o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope) -} - -func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) { - if txnScope == "" { - txnScope = oracle.GlobalTxnScope + current := &lastTSO{ + tso: ts, + arrival: o.getArrivalTimestamp(), } - lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope) + lastTSInterface, ok := o.lastTSMap.Load(txnScope) if !ok { - lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64)) + pointer := &atomic.Pointer[lastTSO]{} + pointer.Store(current) + // do not handle the stored case, because it only runs once. + lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer) } - lastTSPointer := lastTSInterface.(*uint64) + lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) for { - lastTS := atomic.LoadUint64(lastTSPointer) - if ts <= lastTS { + last := lastTSPointer.Load() + if current.tso <= last.tso || current.arrival <= last.arrival { return } - if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) { + if lastTSPointer.CompareAndSwap(last, current) { return } } } func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) { - if txnScope == "" { - txnScope = oracle.GlobalTxnScope - } - lastTSInterface, ok := o.lastTSMap.Load(txnScope) - if !ok { + last, exist := o.getLastTSWithArrivalTS(txnScope) + if !exist { return 0, false } - return atomic.LoadUint64(lastTSInterface.(*uint64)), true + return last.tso, true } -func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) { +func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) { if txnScope == "" { txnScope = oracle.GlobalTxnScope } - lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope) + lastTSInterface, ok := o.lastTSMap.Load(txnScope) if !ok { - return 0, false + return nil, false + } + lastTSPointer := lastTSInterface.(*atomic.Pointer[lastTSO]) + last := lastTSPointer.Load() + if last == nil { + return nil, false } - return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true + return last, true } func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) { @@ -293,22 +285,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac } func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) { - ts, ok := o.getLastTS(txnScope) + last, ok := o.getLastTSWithArrivalTS(txnScope) if !ok { return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope) } - arrivalTS, ok := o.getLastArrivalTS(txnScope) - if !ok { - return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope) - } + ts, arrivalTS := last.tso, last.arrival arrivalTime := oracle.GetTimeFromTS(arrivalTS) physicalTime := oracle.GetTimeFromTS(ts) if uint64(physicalTime.Unix()) <= prevSecond { return 0, errors.Errorf("invalid prevSecond %v", prevSecond) } - staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second))) - + staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime)) return oracle.GoTimeToTS(staleTime), nil } diff --git a/oracle/oracles/pd_test.go b/oracle/oracles/pd_test.go index 2ad1467e5e..376e3fa5ab 100644 --- a/oracle/oracles/pd_test.go +++ b/oracle/oracles/pd_test.go @@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) { assert.NotNil(t, err) assert.Regexp(t, ".*invalid prevSecond.*", err.Error()) } + +func TestNonFutureStaleTSO(t *testing.T) { + o := oracles.NewEmptyPDOracle() + oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now())) + for i := 0; i < 100; i++ { + time.Sleep(10 * time.Millisecond) + now := time.Now() + upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift + + closeCh := make(chan struct{}) + go func() { + time.Sleep(100 * time.Microsecond) + oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now)) + close(closeCh) + }() + CHECK: + for { + select { + case <-closeCh: + break CHECK + default: + ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0) + assert.Nil(t, err) + staleTime := oracle.GetTimeFromTS(ts) + if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ { + assert.Less(t, staleTime, upperBound, i) + t.FailNow() + } + } + } + } +}