From 4eb13a73bf71caafcdd309c61659cad63fbdb240 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 22 Jul 2024 09:57:57 +0800 Subject: [PATCH] fix: wrong usage of RLock --- adapter/outboundgroup/loadbalance.go | 2 -- common/lru/lrucache.go | 32 ++++++++++++++++++++++++++++ common/queue/queue.go | 4 ++-- common/utils/callback.go | 4 ++-- component/sniffer/dispatcher.go | 20 ++++++----------- 5 files changed, 42 insertions(+), 20 deletions(-) diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index 4cb0db004f..738ed15479 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -205,7 +205,6 @@ func strategyStickySessions(url string) strategyFn { proxy := proxies[nowIdx] if proxy.AliveForTestUrl(url) { if nowIdx != idx { - lruCache.Delete(key) lruCache.Set(key, nowIdx) } @@ -215,7 +214,6 @@ func strategyStickySessions(url string) strategyFn { } } - lruCache.Delete(key) lruCache.Set(key, 0) return proxies[0] } diff --git a/common/lru/lrucache.go b/common/lru/lrucache.go index 6f32ed18b1..35f605b10c 100644 --- a/common/lru/lrucache.go +++ b/common/lru/lrucache.go @@ -223,6 +223,10 @@ func (c *LruCache[K, V]) Delete(key K) { c.mu.Lock() defer c.mu.Unlock() + c.delete(key) +} + +func (c *LruCache[K, V]) delete(key K) { if le, ok := c.cache[key]; ok { c.deleteElement(le) } @@ -255,6 +259,34 @@ func (c *LruCache[K, V]) Clear() error { return nil } +// Compute either sets the computed new value for the key or deletes +// the value for the key. When the delete result of the valueFn function +// is set to true, the value will be deleted, if it exists. When delete +// is set to false, the value is updated to the newValue. +// The ok result indicates whether value was computed and stored, thus, is +// present in the map. The actual result contains the new value in cases where +// the value was computed and stored. +func (c *LruCache[K, V]) Compute( + key K, + valueFn func(oldValue V, loaded bool) (newValue V, delete bool), +) (actual V, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + + if el := c.get(key); el != nil { + actual, ok = el.value, true + } + if newValue, del := valueFn(actual, ok); del { + if ok { // data not in cache, so needn't delete + c.delete(key) + } + return lo.Empty[V](), false + } else { + c.set(key, newValue) + return newValue, true + } +} + type entry[K comparable, V any] struct { key K value V diff --git a/common/queue/queue.go b/common/queue/queue.go index cb58e2f5a2..d1b6beebe5 100644 --- a/common/queue/queue.go +++ b/common/queue/queue.go @@ -59,8 +59,8 @@ func (q *Queue[T]) Copy() []T { // Len returns the number of items in this queue. func (q *Queue[T]) Len() int64 { - q.lock.Lock() - defer q.lock.Unlock() + q.lock.RLock() + defer q.lock.RUnlock() return int64(len(q.items)) } diff --git a/common/utils/callback.go b/common/utils/callback.go index df950d3a81..ad734c0fd6 100644 --- a/common/utils/callback.go +++ b/common/utils/callback.go @@ -17,8 +17,8 @@ func NewCallback[T any]() *Callback[T] { } func (c *Callback[T]) Register(item func(T)) io.Closer { - c.mutex.RLock() - defer c.mutex.RUnlock() + c.mutex.Lock() + defer c.mutex.Unlock() element := c.list.PushBack(item) return &callbackCloser[T]{ element: element, diff --git a/component/sniffer/dispatcher.go b/component/sniffer/dispatcher.go index 97bf162969..4438638dad 100644 --- a/component/sniffer/dispatcher.go +++ b/component/sniffer/dispatcher.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/netip" - "sync" "time" "github.com/metacubex/mihomo/common/lru" @@ -30,7 +29,6 @@ type SnifferDispatcher struct { forceDomain *trie.DomainSet skipSNI *trie.DomainSet skipList *lru.LruCache[string, uint8] - rwMux sync.RWMutex forceDnsMapping bool parsePureIp bool } @@ -85,14 +83,11 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata return false } - sd.rwMux.RLock() dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) if count, ok := sd.skipList.Get(dst); ok && count > 5 { log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst) - defer sd.rwMux.RUnlock() return false } - sd.rwMux.RUnlock() if host, err := sd.sniffDomain(conn, metadata); err != nil { sd.cacheSniffFailed(metadata) @@ -104,9 +99,7 @@ func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata return false } - sd.rwMux.RLock() sd.skipList.Delete(dst) - sd.rwMux.RUnlock() sd.replaceDomain(metadata, host, overrideDest) return true @@ -176,14 +169,13 @@ func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metad } func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) { - sd.rwMux.Lock() dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort) - count, _ := sd.skipList.Get(dst) - if count <= 5 { - count++ - } - sd.skipList.Set(dst, count) - sd.rwMux.Unlock() + sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) { + if oldValue <= 5 { + oldValue++ + } + return oldValue, false + }) } func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {