Skip to content

Commit

Permalink
tso: merge lastTS and lastArrivalTS into an atomic pointer (#1054) (#…
Browse files Browse the repository at this point in the history
…1064)

Signed-off-by: you06 <you1474600@gmail.com>
  • Loading branch information
you06 authored Nov 21, 2023
1 parent c1041a4 commit 1946394
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 63 deletions.
20 changes: 4 additions & 16 deletions oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
package oracles

import (
"sync/atomic"
"time"

"github.com/tikv/client-go/v2/oracle"
Expand Down Expand Up @@ -63,20 +62,9 @@ func NewEmptyPDOracle() oracle.Oracle {
func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lastTSPointer := lastTSInterface.(*uint64)
atomic.StoreUint64(lastTSPointer, ts)
lasTSArrivalInterface, _ := o.lastArrivalTSMap.LoadOrStore(oracle.GlobalTxnScope, new(uint64))
lasTSArrivalPointer := lasTSArrivalInterface.(*uint64)
atomic.StoreUint64(lasTSArrivalPointer, uint64(time.Now().Unix()*1000))
}
setEmptyPDOracleLastArrivalTs(oc, ts)
}

// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
now := &lastTSO{ts, ts}
lastTSInterface, _ := o.lastTSMap.LoadOrStore(oracle.GlobalTxnScope, NewLastTSOPointer(now))
lastTSPointer := lastTSInterface.(*lastTSOPointer)
lastTSPointer.store(&lastTSO{tso: ts, arrival: ts})
}
}
103 changes: 56 additions & 47 deletions oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"sync"
"sync/atomic"
"time"
"unsafe"

"github.com/pkg/errors"
"github.com/tikv/client-go/v2/internal/logutil"
Expand All @@ -56,11 +57,36 @@ const slowDist = 30 * time.Millisecond
// pdOracle is an Oracle that uses a placement driver client as source.
type pdOracle struct {
c pd.Client
// txn_scope (string) -> lastTSPointer (*uint64)
// txn_scope (string) -> lastTSPointer (*lastTSOPointer)
lastTSMap sync.Map
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
lastArrivalTSMap sync.Map
quit chan struct{}
quit chan struct{}
}

// lastTSO stores the last timestamp oracle gets from PD server and the local time when the TSO is fetched.
type lastTSO struct {
tso uint64
arrival uint64
}

// lastTSOPointer wrap the lastTSO struct into a pointer.
type lastTSOPointer struct {
p unsafe.Pointer
}

func NewLastTSOPointer(last *lastTSO) *lastTSOPointer {
return &lastTSOPointer{p: unsafe.Pointer(last)}
}

func (p *lastTSOPointer) load() *lastTSO {
return (*lastTSO)(atomic.LoadPointer(&p.p))
}

func (p *lastTSOPointer) store(last *lastTSO) {
atomic.StorePointer(&p.p, unsafe.Pointer(last))
}

func (p *lastTSOPointer) compareAndSwap(old, new *lastTSO) bool {
return atomic.CompareAndSwapPointer(&p.p, unsafe.Pointer(old), unsafe.Pointer(new))
}

// NewPdOracle create an Oracle that uses a pd client source.
Expand Down Expand Up @@ -163,63 +189,50 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
break
}
current := &lastTSO{
tso: ts,
arrival: o.getArrivalTimestamp(),
}
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
}

func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
pointer := NewLastTSOPointer(current)
// do not handle the stored case, because it only runs once.
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, pointer)
}
lastTSPointer := lastTSInterface.(*uint64)
lastTSPointer := lastTSInterface.(*lastTSOPointer)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
last := lastTSPointer.load()
if current.tso <= last.tso || current.arrival <= last.arrival {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
if lastTSPointer.compareAndSwap(last, current) {
return
}
}
}

func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
last, exist := o.getLastTSWithArrivalTS(txnScope)
if !exist {
return 0, false
}
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
return last.tso, true
}

func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
func (o *pdOracle) getLastTSWithArrivalTS(txnScope string) (*lastTSO, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
lastTSInterface, ok := o.lastTSMap.Load(txnScope)
if !ok {
return 0, false
return nil, false
}
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
lastTSPointer := lastTSInterface.(*lastTSOPointer)
last := lastTSPointer.load()
if last == nil {
return nil, false
}
return last, true
}

func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
Expand Down Expand Up @@ -293,22 +306,18 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
}

func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(txnScope)
last, ok := o.getLastTSWithArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
}
arrivalTS, ok := o.getLastArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
}
ts, arrivalTS := last.tso, last.arrival
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
physicalTime := oracle.GetTimeFromTS(ts)
if uint64(physicalTime.Unix()) <= prevSecond {
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
}

staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))

staleTime := physicalTime.Add(time.Now().Add(-time.Duration(prevSecond) * time.Second).Sub(arrivalTime))
return oracle.GoTimeToTS(staleTime), nil
}

Expand Down
32 changes: 32 additions & 0 deletions oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,35 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
assert.NotNil(t, err)
assert.Regexp(t, ".*invalid prevSecond.*", err.Error())
}

func TestNonFutureStaleTSO(t *testing.T) {
o := oracles.NewEmptyPDOracle()
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(time.Now()))
for i := 0; i < 100; i++ {
time.Sleep(10 * time.Millisecond)
now := time.Now()
upperBound := now.Add(5 * time.Millisecond) // allow 5ms time drift

closeCh := make(chan struct{})
go func() {
time.Sleep(100 * time.Microsecond)
oracles.SetEmptyPDOracleLastTs(o, oracle.GoTimeToTS(now))
close(closeCh)
}()
CHECK:
for {
select {
case <-closeCh:
break CHECK
default:
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 0)
assert.Nil(t, err)
staleTime := oracle.GetTimeFromTS(ts)
if staleTime.After(upperBound) && time.Since(now) < time.Millisecond /* only check staleTime within 1ms */ {
assert.Less(t, staleTime, upperBound, i)
t.FailNow()
}
}
}
}
}

0 comments on commit 1946394

Please sign in to comment.