diff --git a/pkg/security/resolvers/mount/resolver.go b/pkg/security/resolvers/mount/resolver.go index 4e914649cae9ee..16674f0ee93cad 100644 --- a/pkg/security/resolvers/mount/resolver.go +++ b/pkg/security/resolvers/mount/resolver.go @@ -16,7 +16,6 @@ import ( "sync" "time" - "github.com/hashicorp/golang-lru/v2/simplelru" "github.com/moby/sys/mountinfo" "go.uber.org/atomic" @@ -29,6 +28,8 @@ import ( "github.com/DataDog/datadog-agent/pkg/security/secl/containerutils" "github.com/DataDog/datadog-agent/pkg/security/secl/model" "github.com/DataDog/datadog-agent/pkg/security/utils" + "github.com/DataDog/datadog-agent/pkg/security/utils/cache" + "github.com/DataDog/datadog-agent/pkg/security/utils/lru/simplelru" "github.com/DataDog/datadog-agent/pkg/util/kernel" ) @@ -91,8 +92,8 @@ type Resolver struct { cgroupsResolver *cgroup.Resolver statsdClient statsd.ClientInterface lock sync.RWMutex - mounts map[uint32]*model.Mount - pidToMounts map[uint32]map[uint32]*model.Mount + mounts *simplelru.LRU[uint32, *model.Mount] + pidToMounts *cache.TwoLayersLRU[uint32, uint32, *model.Mount] minMountID uint32 // used to find the first userspace visible mount ID redemption *simplelru.LRU[uint32, *redemptionEntry] fallbackLimiter *utils.Limiter[uint64] @@ -126,7 +127,7 @@ func (mr *Resolver) SyncCache(pid uint32) error { // store the minimal mount ID found to use it as a reference if pid == 1 { - for mountID := range mr.mounts { + for mountID := range mr.mounts.KeysIter() { if mr.minMountID == 0 || mr.minMountID > mountID { mr.minMountID = mountID } @@ -143,7 +144,7 @@ func (mr *Resolver) syncPid(pid uint32) error { } for _, mnt := range mnts { - if m, exists := mr.mounts[uint32(mnt.ID)]; exists { + if m, exists := mr.mounts.Get(uint32(mnt.ID)); m != nil && exists { mr.updatePidMapping(m, pid) continue } @@ -178,14 +179,14 @@ func (mr *Resolver) delete(mount *model.Mount) { mr.deleteOne(mount, now) - openQueue := make([]uint32, 0, len(mr.mounts)) + openQueue := make([]uint32, 0, mr.mounts.Len()) openQueue = append(openQueue, mount.MountID) for len(openQueue) != 0 { curr, rest := openQueue[len(openQueue)-1], openQueue[:len(openQueue)-1] openQueue = rest - for _, child := range mr.mounts { + for child := range mr.mounts.ValuesIter() { if child.ParentPathKey.MountID == curr { openQueue = append(openQueue, child.MountID) mr.deleteOne(child, now) @@ -195,10 +196,8 @@ func (mr *Resolver) delete(mount *model.Mount) { } func (mr *Resolver) deleteOne(curr *model.Mount, now time.Time) { - delete(mr.mounts, curr.MountID) - for _, mounts := range mr.pidToMounts { - delete(mounts, curr.MountID) - } + mr.mounts.Remove(curr.MountID) + mr.pidToMounts.RemoveKey2(curr.MountID) entry := redemptionEntry{ mount: curr, @@ -208,7 +207,7 @@ func (mr *Resolver) deleteOne(curr *model.Mount, now time.Time) { } func (mr *Resolver) finalize(mount *model.Mount) { - delete(mr.mounts, mount.MountID) + mr.mounts.Remove(mount.MountID) } // Delete a mount from the cache @@ -216,7 +215,7 @@ func (mr *Resolver) Delete(mountID uint32) error { mr.lock.Lock() defer mr.lock.Unlock() - if m, exists := mr.mounts[mountID]; exists { + if m, exists := mr.mounts.Get(mountID); exists { mr.delete(m) } else { return &ErrMountNotFound{MountID: mountID} @@ -257,12 +256,7 @@ func (mr *Resolver) updatePidMapping(m *model.Mount, pid uint32) { return } - mounts := mr.pidToMounts[pid] - if mounts == nil { - mounts = make(map[uint32]*model.Mount) - mr.pidToMounts[pid] = mounts - } - mounts[m.MountID] = m + mr.pidToMounts.Add(pid, m.MountID, m) } // DelPid removes the pid form the pid mapping @@ -274,12 +268,12 @@ func (mr *Resolver) DelPid(pid uint32) { mr.lock.Lock() defer mr.lock.Unlock() - delete(mr.pidToMounts, pid) + mr.pidToMounts.RemoveKey1(pid) } func (mr *Resolver) insert(m *model.Mount, pid uint32) { // umount the previous one if exists - if prev, ok := mr.mounts[m.MountID]; ok { + if prev, ok := mr.mounts.Get(m.MountID); prev != nil && ok { // put the prev entry and the all the children in the redemption list mr.delete(prev) // force a finalize on the entry itself as it will be overridden by the new one @@ -299,7 +293,7 @@ func (mr *Resolver) insert(m *model.Mount, pid uint32) { mr.minMountID = m.MountID } - mr.mounts[m.MountID] = m + mr.mounts.Add(m.MountID, m) mr.updatePidMapping(m, pid) } @@ -313,8 +307,7 @@ func (mr *Resolver) getFromRedemption(mountID uint32) *model.Mount { } func (mr *Resolver) lookupByMountID(mountID uint32) *model.Mount { - mount := mr.mounts[mountID] - if mount != nil { + if mount, ok := mr.mounts.Get(mountID); mount != nil && ok { return mount } @@ -324,17 +317,17 @@ func (mr *Resolver) lookupByMountID(mountID uint32) *model.Mount { func (mr *Resolver) lookupByDevice(device uint32, pid uint32) *model.Mount { var result *model.Mount - mounts := mr.pidToMounts[pid] - - for _, mount := range mounts { + mr.pidToMounts.WalkInner(pid, func(_ uint32, mount *model.Mount) bool { if mount.Device == device { // should be consistent across all the mounts if result != nil && result.MountPointStr != mount.MountPointStr { - return nil + result = nil + return false } result = mount } - } + return true + }) return result } @@ -515,8 +508,7 @@ func (mr *Resolver) resolveMount(mountID uint32, device uint32, pid uint32, cont return nil, model.MountSourceUnknown, model.MountOriginUnknown, err } - mount = mr.mounts[mountID] - if mount != nil { + if mount, ok := mr.mounts.Get(mountID); mount != nil && ok { mr.procHitsStats.Inc() return mount, model.MountSourceMountID, mount.Origin, nil } @@ -601,7 +593,7 @@ func (mr *Resolver) SendStats() error { return err } - return mr.statsdClient.Gauge(metrics.MetricMountResolverCacheSize, float64(len(mr.mounts)), []string{}, 1.0) + return mr.statsdClient.Gauge(metrics.MetricMountResolverCacheSize, float64(mr.mounts.Len()), []string{}, 1.0) } // ToJSON return a json version of the cache @@ -613,7 +605,7 @@ func (mr *Resolver) ToJSON() ([]byte, error) { mr.lock.RLock() defer mr.lock.RUnlock() - for _, mount := range mr.mounts { + for mount := range mr.mounts.ValuesIter() { d, err := json.Marshal(mount) if err == nil { dump.Entries = append(dump.Entries, d) @@ -623,15 +615,33 @@ func (mr *Resolver) ToJSON() ([]byte, error) { return json.Marshal(dump) } +const ( + // mounts LRU limit: 100000 mounts + mountsLimit = 100000 + // pidToMounts LRU limits: 1000 pids, and 1000 mounts per pid + pidLimit = 1000 + mountsPerPidLimit = 1000 +) + // NewResolver instantiates a new mount resolver func NewResolver(statsdClient statsd.ClientInterface, cgroupsResolver *cgroup.Resolver, opts ResolverOpts) (*Resolver, error) { + mounts, err := simplelru.NewLRU[uint32, *model.Mount](mountsLimit, nil) + if err != nil { + return nil, err + } + + pidToMounts, err := cache.NewTwoLayersLRU[uint32, uint32, *model.Mount](pidLimit * mountsPerPidLimit) + if err != nil { + return nil, err + } + mr := &Resolver{ opts: opts, statsdClient: statsdClient, cgroupsResolver: cgroupsResolver, lock: sync.RWMutex{}, - mounts: make(map[uint32]*model.Mount), - pidToMounts: make(map[uint32]map[uint32]*model.Mount), + mounts: mounts, + pidToMounts: pidToMounts, cacheHitsStats: atomic.NewInt64(0), procHitsStats: atomic.NewInt64(0), cacheMissStats: atomic.NewInt64(0), diff --git a/pkg/security/resolvers/mount/resolver_test.go b/pkg/security/resolvers/mount/resolver_test.go index e4b5894cf465af..c708bbcb6c6021 100644 --- a/pkg/security/resolvers/mount/resolver_test.go +++ b/pkg/security/resolvers/mount/resolver_test.go @@ -475,95 +475,105 @@ func TestMountResolver(t *testing.T) { } func TestMountGetParentPath(t *testing.T) { - mr := &Resolver{ - mounts: map[uint32]*model.Mount{ - 1: { - MountID: 1, - MountPointStr: "/", + mounts := map[uint32]*model.Mount{ + 1: { + MountID: 1, + MountPointStr: "/", + }, + 2: { + MountID: 2, + ParentPathKey: model.PathKey{ + MountID: 1, }, - 2: { + MountPointStr: "/a", + }, + 3: { + MountID: 3, + ParentPathKey: model.PathKey{ MountID: 2, - ParentPathKey: model.PathKey{ - MountID: 1, - }, - MountPointStr: "/a", }, - 3: { + MountPointStr: "/b", + }, + 4: { + MountID: 4, + ParentPathKey: model.PathKey{ MountID: 3, - ParentPathKey: model.PathKey{ - MountID: 2, - }, - MountPointStr: "/b", - }, - 4: { - MountID: 4, - ParentPathKey: model.PathKey{ - MountID: 3, - }, - MountPointStr: "/c", }, + MountPointStr: "/c", }, } + // Create mount resolver + cr, _ := cgroup.NewResolver(nil) + mr, _ := NewResolver(nil, cr, ResolverOpts{}) + for _, m := range mounts { + mr.mounts.Add(m.MountID, m) + } + parentPath, _, _, err := mr.getMountPath(4, 44, 1) assert.NoError(t, err) assert.Equal(t, "/a/b/c", parentPath) } func TestMountLoop(t *testing.T) { - mr := &Resolver{ - mounts: map[uint32]*model.Mount{ - 1: { - MountID: 1, - MountPointStr: "/", + mounts := map[uint32]*model.Mount{ + 1: { + MountID: 1, + MountPointStr: "/", + }, + 2: { + MountID: 2, + ParentPathKey: model.PathKey{ + MountID: 4, }, - 2: { + MountPointStr: "/a", + }, + 3: { + MountID: 3, + ParentPathKey: model.PathKey{ MountID: 2, - ParentPathKey: model.PathKey{ - MountID: 4, - }, - MountPointStr: "/a", }, - 3: { + MountPointStr: "/b", + }, + 4: { + MountID: 4, + ParentPathKey: model.PathKey{ MountID: 3, - ParentPathKey: model.PathKey{ - MountID: 2, - }, - MountPointStr: "/b", - }, - 4: { - MountID: 4, - ParentPathKey: model.PathKey{ - MountID: 3, - }, - MountPointStr: "/c", }, + MountPointStr: "/c", }, } + // Create mount resolver + cr, _ := cgroup.NewResolver(nil) + mr, _ := NewResolver(nil, cr, ResolverOpts{}) + for _, m := range mounts { + mr.mounts.Add(m.MountID, m) + } + parentPath, _, _, err := mr.getMountPath(3, 44, 1) assert.Equal(t, ErrMountLoop, err) assert.Equal(t, "", parentPath) } func BenchmarkGetParentPath(b *testing.B) { - mr := &Resolver{ - mounts: make(map[uint32]*model.Mount), - } + // Create mount resolver + cr, _ := cgroup.NewResolver(nil) + mr, _ := NewResolver(nil, cr, ResolverOpts{}) - mr.mounts[1] = &model.Mount{ + mr.mounts.Add(1, &model.Mount{ MountID: 1, MountPointStr: "/", - } + }) for i := uint32(1); i != 100; i++ { - mr.mounts[i+1] = &model.Mount{ + mr.mounts.Add(i+1, &model.Mount{ MountID: i + 1, ParentPathKey: model.PathKey{ MountID: i, }, MountPointStr: fmt.Sprintf("/%d", i+1), - } + }) } b.ResetTimer() diff --git a/pkg/security/utils/cache/lru_2layers.go b/pkg/security/utils/cache/lru_2layers.go index 66f0a71ef09d21..a052674c68406e 100644 --- a/pkg/security/utils/cache/lru_2layers.go +++ b/pkg/security/utils/cache/lru_2layers.go @@ -7,9 +7,11 @@ package cache import ( + "iter" "sync" - "github.com/hashicorp/golang-lru/v2/simplelru" + "github.com/DataDog/datadog-agent/pkg/security/utils/lru/simplelru" + "go.uber.org/atomic" ) @@ -88,26 +90,47 @@ func (tll *TwoLayersLRU[K1, K2, V]) RemoveKey1(k1 K1) bool { return true } -// RemoveKey2 remove the entry in the second layer -func (tll *TwoLayersLRU[K1, K2, V]) RemoveKey2(k1 K1, k2 K2) bool { +// RemoveKey2 removes the entry in the second layer for the given K1 keys. +// If no keys are provided, the function will try to remove the entry for all the keys. +// Returns the total number of entries that were removed from the cache. +func (tll *TwoLayersLRU[K1, K2, V]) RemoveKey2(k2 K2, keys ...K1) int { tll.Lock() defer tll.Unlock() - l2LRU, exists := tll.cache.Peek(k1) - if !exists { - return false - } - if !l2LRU.Remove(k2) { - return false + var k1Iter iter.Seq[K1] + if len(keys) == 0 { + k1Iter = tll.cache.KeysIter() + } else { + k1Iter = func(yield func(K1) bool) { + for _, k := range keys { + if !yield(k) { + return + } + } + } } - if l2LRU.Len() == 0 { - tll.cache.Remove(k1) + removed := 0 + for k1 := range k1Iter { + l2LRU, exists := tll.cache.Peek(k1) + if !exists { + continue + } + + if !l2LRU.Remove(k2) { + continue + } + + if l2LRU.Len() == 0 { + tll.cache.Remove(k1) + } + + removed++ } - tll.len.Dec() + tll.len.Sub(uint64(removed)) - return true + return removed } // RemoveOldest removes the oldest element @@ -160,9 +183,9 @@ func (tll *TwoLayersLRU[K1, K2, V]) Walk(cb func(k1 K1, k2 K2, v V)) { tll.RLock() defer tll.RUnlock() - for _, k1 := range tll.cache.Keys() { + for k1 := range tll.cache.KeysIter() { if l2LRU, exists := tll.cache.Peek(k1); exists { - for _, k2 := range l2LRU.Keys() { + for k2 := range l2LRU.KeysIter() { if value, exists := l2LRU.Peek(k2); exists { cb(k1, k2, value) } @@ -170,3 +193,19 @@ func (tll *TwoLayersLRU[K1, K2, V]) Walk(cb func(k1 K1, k2 K2, v V)) { } } } + +// WalkInner through all the keys of the inner LRU +func (tll *TwoLayersLRU[K1, K2, V]) WalkInner(k1 K1, cb func(k2 K2, v V) bool) { + tll.RLock() + defer tll.RUnlock() + + if l2LRU, exists := tll.cache.Peek(k1); exists { + for k2 := range l2LRU.KeysIter() { + if value, exists := l2LRU.Peek(k2); exists { + if continu := cb(k2, value); !continu { + return + } + } + } + } +} diff --git a/pkg/security/utils/cache/lru_2layers_test.go b/pkg/security/utils/cache/lru_2layers_test.go index aa287e58bf8183..b3e3cc02e193e8 100644 --- a/pkg/security/utils/cache/lru_2layers_test.go +++ b/pkg/security/utils/cache/lru_2layers_test.go @@ -41,8 +41,8 @@ func TestTwoLayersLRU(t *testing.T) { }) t.Run("remove-key2", func(t *testing.T) { - exists := cache.RemoveKey2("a", 2) - assert.True(t, exists) + removed := cache.RemoveKey2(2, "a") + assert.Equal(t, 1, removed) assert.Equal(t, cache.Len(), 1) })