Skip to content

Commit

Permalink
feat: move sync entry to internal package
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Su <ghosind@gmail.com>
  • Loading branch information
ghosind committed Dec 17, 2024
1 parent 890d1f5 commit e02f5bf
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 127 deletions.
183 changes: 58 additions & 125 deletions dict/sync_dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,86 +6,19 @@ import (
"sync/atomic"

"github.com/ghosind/collection"
"github.com/ghosind/collection/internal"
)

// SyncDict is a thread-safe map implementation based on sync.Map's algorithm.
type SyncDict[K comparable, V any] struct {
mu sync.Mutex
read atomic.Pointer[syncReadOnly[K, V]]
dirty map[K]*syncEntry[V]
read atomic.Pointer[internal.SyncReadOnly[K, V]]
dirty map[K]*internal.SyncEntry[V]
misses int
zero V
expunged *V
}

type syncReadOnly[K comparable, V any] struct {
m map[K]*syncEntry[V]
amended bool
}

type syncEntry[T any] struct {
p atomic.Pointer[T]
expunged *T
}

func newSyncEntry[T any](v T, expunged *T) *syncEntry[T] {
e := new(syncEntry[T])
e.p.Store(&v)
e.expunged = expunged
return e
}

func (e *syncEntry[T]) load(val T) (value T, ok bool) {
p := e.p.Load()
if p == nil || p == e.expunged {
return val, false
}
return *p, true
}

func (e *syncEntry[T]) trySwap(val *T) (*T, bool) {
for {
p := e.p.Load()
if p == e.expunged {
return nil, false
}
if e.p.CompareAndSwap(p, val) {
return p, true
}
}
}

func (e *syncEntry[T]) delete() (*T, bool) {
for {
p := e.p.Load()
if p == nil || p == e.expunged {
return nil, false
}
if e.p.CompareAndSwap(p, nil) {
return p, true
}
}
}

func (e *syncEntry[T]) unexpungeLocked() bool {
return e.p.CompareAndSwap(e.expunged, nil)
}

func (e *syncEntry[T]) swapLocked(v *T) *T {
return e.p.Swap(v)
}

func (e *syncEntry[T]) tryExpungeLocked() bool {
p := e.p.Load()
for p == nil {
if e.p.CompareAndSwap(nil, e.expunged) {
return true
}
p = e.p.Load()
}
return p == e.expunged
}

// NewSyncDict creates a new SyncDict.
func NewSyncDict[K comparable, V any]() *SyncDict[K, V] {
d := new(SyncDict[K, V])
Expand All @@ -95,20 +28,20 @@ func NewSyncDict[K comparable, V any]() *SyncDict[K, V] {
return d
}

func (d *SyncDict[K, V]) loadReadOnly() syncReadOnly[K, V] {
func (d *SyncDict[K, V]) loadReadOnly() internal.SyncReadOnly[K, V] {
if p := d.read.Load(); p != nil {
return *p
}
return syncReadOnly[K, V]{}
return internal.SyncReadOnly[K, V]{}
}

func (d *SyncDict[K, V]) loadPresentReadOnly() syncReadOnly[K, V] {
func (d *SyncDict[K, V]) loadPresentReadOnly() internal.SyncReadOnly[K, V] {
read := d.loadReadOnly()
if read.amended {
if read.Amended {
d.mu.Lock()
read = d.loadReadOnly()
if read.amended {
read = syncReadOnly[K, V]{m: d.dirty}
if read.Amended {
read = internal.SyncReadOnly[K, V]{M: d.dirty}
copyRead := read
d.read.Store(&copyRead)
d.dirty = nil
Expand All @@ -126,9 +59,9 @@ func (d *SyncDict[K, V]) dirtyLocked() {
}

read := d.loadReadOnly()
d.dirty = make(map[K]*syncEntry[V], len(read.m))
for k, e := range read.m {
if !e.tryExpungeLocked() {
d.dirty = make(map[K]*internal.SyncEntry[V], len(read.M))
for k, e := range read.M {
if !e.TryExpungeLocked() {
d.dirty[k] = e
}
}
Expand All @@ -140,20 +73,20 @@ func (d *SyncDict[K, V]) missLocked() {
return
}

d.read.Store(&syncReadOnly[K, V]{m: d.dirty})
d.read.Store(&internal.SyncReadOnly[K, V]{M: d.dirty})
d.dirty = nil
d.misses = 0
}

// Get returns the value which associated to the specified key.
func (d *SyncDict[K, V]) get(key K, val V) (V, bool) {
read := d.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
e, ok := read.M[key]
if !ok && read.Amended {
d.mu.Lock()
read = d.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = read.M[key]
if !ok && read.Amended {
e, ok = d.dirty[key]
d.missLocked()
}
Expand All @@ -162,36 +95,36 @@ func (d *SyncDict[K, V]) get(key K, val V) (V, bool) {
if !ok {
return val, false
}
return e.load(val)
return e.Load(val)
}

func (d *SyncDict[K, V]) swap(key K, val V, ignore bool) (*V, bool) {
read := d.loadReadOnly()
if e, ok := read.m[key]; ok {
return e.trySwap(&val)
if e, ok := read.M[key]; ok {
return e.TrySwap(&val)
}

d.mu.Lock()
defer d.mu.Unlock()

read = d.loadReadOnly()
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
if e, ok := read.M[key]; ok {
if e.UnexpungeLocked() {
d.dirty[key] = e
}
if v := e.swapLocked(&val); v != nil {
if v := e.SwapLocked(&val); v != nil {
return v, true
}
} else if e, ok := d.dirty[key]; ok {
if v := e.swapLocked(&val); v != nil {
if v := e.SwapLocked(&val); v != nil {
return v, true
}
} else if !ignore {
if !read.amended {
if !read.Amended {
d.dirtyLocked()
d.read.Store(&syncReadOnly[K, V]{m: read.m, amended: true})
d.read.Store(&internal.SyncReadOnly[K, V]{M: read.M, Amended: true})
}
d.dirty[key] = newSyncEntry(val, d.expunged)
d.dirty[key] = internal.NewSyncEntry(val, d.expunged)
}
return nil, false
}
Expand All @@ -201,31 +134,31 @@ func (d *SyncDict[K, V]) Clear() {
d.mu.Lock()
defer d.mu.Unlock()
read := d.loadReadOnly()
if read.amended {
if read.Amended {
d.dirty = nil
d.misses = 0
}
read = syncReadOnly[K, V]{m: make(map[K]*syncEntry[V])}
read = internal.SyncReadOnly[K, V]{M: make(map[K]*internal.SyncEntry[V])}
copyRead := read
d.read.Store(&copyRead)
}

// Clone returns a copy of this dictionary.
func (d *SyncDict[K, V]) Clone() collection.Dict[K, V] {
read := d.loadPresentReadOnly()
m := make(map[K]*syncEntry[V])
m := make(map[K]*internal.SyncEntry[V])
expunged := new(V)

for k, e := range read.m {
v, ok := e.load(d.zero)
for k, e := range read.M {
v, ok := e.Load(d.zero)
if ok {
m[k] = newSyncEntry(v, expunged)
m[k] = internal.NewSyncEntry(v, expunged)
}
}

newDict := new(SyncDict[K, V])
newDict.expunged = expunged
newDict.read.Store(&syncReadOnly[K, V]{m: m})
newDict.read.Store(&internal.SyncReadOnly[K, V]{M: m})

return newDict
}
Expand All @@ -249,18 +182,18 @@ func (d *SyncDict[K, V]) Equals(o any) bool {
rs := 0
os := 0

for k, e := range read.m {
dv, ok := e.load(d.zero)
for k, e := range read.M {
dv, ok := e.Load(d.zero)
if !ok {
continue
}
rs++

oe, ok := oRead.m[k]
oe, ok := oRead.M[k]
if !ok {
return false
}
ov, ok := oe.load(od.zero)
ov, ok := oe.Load(od.zero)
if !ok {
return false
}
Expand All @@ -270,8 +203,8 @@ func (d *SyncDict[K, V]) Equals(o any) bool {
}
}

for _, e := range oRead.m {
_, ok := e.load(d.zero)
for _, e := range oRead.M {
_, ok := e.Load(d.zero)
if !ok {
continue
}
Expand All @@ -289,8 +222,8 @@ func (d *SyncDict[K, V]) Equals(o any) bool {
func (d *SyncDict[K, V]) ForEach(handler func(K, V) error) error {
read := d.loadPresentReadOnly()

for k, e := range read.m {
v, ok := e.load(d.zero)
for k, e := range read.M {
v, ok := e.Load(d.zero)
if !ok {
continue
}
Expand All @@ -317,12 +250,12 @@ func (d *SyncDict[K, V]) GetDefault(key K, defaultVal V) V {
// IsEmpty returns true if this dictionary is empty.
func (d *SyncDict[K, V]) IsEmpty() bool {
read := d.loadPresentReadOnly()
if len(read.m) == 0 {
if len(read.M) == 0 {
return true
}

for _, e := range read.m {
_, ok := e.load(d.zero)
for _, e := range read.M {
_, ok := e.Load(d.zero)
if ok {
return false
}
Expand All @@ -335,9 +268,9 @@ func (d *SyncDict[K, V]) IsEmpty() bool {
func (d *SyncDict[K, V]) Keys() []K {
read := d.loadPresentReadOnly()

keys := make([]K, 0, len(read.m))
for k, e := range read.m {
_, ok := e.load(d.zero)
keys := make([]K, 0, len(read.M))
for k, e := range read.M {
_, ok := e.Load(d.zero)
if !ok {
continue
}
Expand All @@ -360,20 +293,20 @@ func (d *SyncDict[K, V]) Put(key K, val V) V {
// Remove removes the key-value pair with the specified key.
func (d *SyncDict[K, V]) Remove(key K) V {
read := d.loadReadOnly()
e, ok := read.m[key]
if !ok && read.amended {
e, ok := read.M[key]
if !ok && read.Amended {
d.mu.Lock()
read = d.loadReadOnly()
e, ok = read.m[key]
if !ok && read.amended {
e, ok = read.M[key]
if !ok && read.Amended {
e, ok = d.dirty[key]
delete(d.dirty, key)
d.missLocked()
}
d.mu.Unlock()
}
if ok {
vp, ok := e.delete()
vp, ok := e.Delete()
if ok {
return *vp
}
Expand All @@ -397,8 +330,8 @@ func (d *SyncDict[K, V]) Size() int {
read := d.loadPresentReadOnly()
size := 0

for _, e := range read.m {
_, ok := e.load(d.zero)
for _, e := range read.M {
_, ok := e.Load(d.zero)
if !ok {
continue
}
Expand All @@ -412,9 +345,9 @@ func (d *SyncDict[K, V]) Size() int {
func (d *SyncDict[K, V]) Values() []V {
read := d.loadPresentReadOnly()

keys := make([]V, 0, len(read.m))
for _, e := range read.m {
v, ok := e.load(d.zero)
keys := make([]V, 0, len(read.M))
for _, e := range read.M {
v, ok := e.Load(d.zero)
if !ok {
continue
}
Expand Down
4 changes: 2 additions & 2 deletions dict/sync_dict_go123.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ func (d *SyncDict[K, V]) Iter() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
read := d.loadPresentReadOnly()

for k, e := range read.m {
v, ok := e.load(d.zero)
for k, e := range read.M {
v, ok := e.Load(d.zero)
if !ok {
continue
}
Expand Down
Loading

0 comments on commit e02f5bf

Please sign in to comment.