Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CWS-3771] Remove maps from mount resolver #34637

Merged
merged 5 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 45 additions & 35 deletions pkg/security/resolvers/mount/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"sync"
"time"

"github.com/hashicorp/golang-lru/v2/simplelru"
"github.com/moby/sys/mountinfo"
"go.uber.org/atomic"

Expand All @@ -29,6 +28,8 @@ import (
"github.com/DataDog/datadog-agent/pkg/security/secl/containerutils"
"github.com/DataDog/datadog-agent/pkg/security/secl/model"
"github.com/DataDog/datadog-agent/pkg/security/utils"
"github.com/DataDog/datadog-agent/pkg/security/utils/cache"
"github.com/DataDog/datadog-agent/pkg/security/utils/lru/simplelru"
"github.com/DataDog/datadog-agent/pkg/util/kernel"
)

Expand Down Expand Up @@ -91,8 +92,8 @@ type Resolver struct {
cgroupsResolver *cgroup.Resolver
statsdClient statsd.ClientInterface
lock sync.RWMutex
mounts map[uint32]*model.Mount
pidToMounts map[uint32]map[uint32]*model.Mount
mounts *simplelru.LRU[uint32, *model.Mount]
pidToMounts *cache.TwoLayersLRU[uint32, uint32, *model.Mount]
minMountID uint32 // used to find the first userspace visible mount ID
redemption *simplelru.LRU[uint32, *redemptionEntry]
fallbackLimiter *utils.Limiter[uint64]
Expand Down Expand Up @@ -126,7 +127,7 @@ func (mr *Resolver) SyncCache(pid uint32) error {

// store the minimal mount ID found to use it as a reference
if pid == 1 {
for mountID := range mr.mounts {
for mountID := range mr.mounts.KeysIter() {
if mr.minMountID == 0 || mr.minMountID > mountID {
mr.minMountID = mountID
}
Expand All @@ -143,7 +144,7 @@ func (mr *Resolver) syncPid(pid uint32) error {
}

for _, mnt := range mnts {
if m, exists := mr.mounts[uint32(mnt.ID)]; exists {
if m, exists := mr.mounts.Get(uint32(mnt.ID)); m != nil && exists {
mr.updatePidMapping(m, pid)
continue
}
Expand Down Expand Up @@ -178,14 +179,14 @@ func (mr *Resolver) delete(mount *model.Mount) {

mr.deleteOne(mount, now)

openQueue := make([]uint32, 0, len(mr.mounts))
openQueue := make([]uint32, 0, mr.mounts.Len())
openQueue = append(openQueue, mount.MountID)

for len(openQueue) != 0 {
curr, rest := openQueue[len(openQueue)-1], openQueue[:len(openQueue)-1]
openQueue = rest

for _, child := range mr.mounts {
for child := range mr.mounts.ValuesIter() {
if child.ParentPathKey.MountID == curr {
openQueue = append(openQueue, child.MountID)
mr.deleteOne(child, now)
Expand All @@ -195,10 +196,8 @@ func (mr *Resolver) delete(mount *model.Mount) {
}

func (mr *Resolver) deleteOne(curr *model.Mount, now time.Time) {
delete(mr.mounts, curr.MountID)
for _, mounts := range mr.pidToMounts {
delete(mounts, curr.MountID)
}
mr.mounts.Remove(curr.MountID)
mr.pidToMounts.RemoveKey2(curr.MountID)

entry := redemptionEntry{
mount: curr,
Expand All @@ -208,15 +207,15 @@ func (mr *Resolver) deleteOne(curr *model.Mount, now time.Time) {
}

func (mr *Resolver) finalize(mount *model.Mount) {
delete(mr.mounts, mount.MountID)
mr.mounts.Remove(mount.MountID)
}

// Delete a mount from the cache
func (mr *Resolver) Delete(mountID uint32) error {
mr.lock.Lock()
defer mr.lock.Unlock()

if m, exists := mr.mounts[mountID]; exists {
if m, exists := mr.mounts.Get(mountID); exists {
mr.delete(m)
} else {
return &ErrMountNotFound{MountID: mountID}
Expand Down Expand Up @@ -257,12 +256,7 @@ func (mr *Resolver) updatePidMapping(m *model.Mount, pid uint32) {
return
}

mounts := mr.pidToMounts[pid]
if mounts == nil {
mounts = make(map[uint32]*model.Mount)
mr.pidToMounts[pid] = mounts
}
mounts[m.MountID] = m
mr.pidToMounts.Add(pid, m.MountID, m)
}

// DelPid removes the pid form the pid mapping
Expand All @@ -274,12 +268,12 @@ func (mr *Resolver) DelPid(pid uint32) {
mr.lock.Lock()
defer mr.lock.Unlock()

delete(mr.pidToMounts, pid)
mr.pidToMounts.RemoveKey1(pid)
}

func (mr *Resolver) insert(m *model.Mount, pid uint32) {
// umount the previous one if exists
if prev, ok := mr.mounts[m.MountID]; ok {
if prev, ok := mr.mounts.Get(m.MountID); prev != nil && ok {
// put the prev entry and the all the children in the redemption list
mr.delete(prev)
// force a finalize on the entry itself as it will be overridden by the new one
Expand All @@ -299,7 +293,7 @@ func (mr *Resolver) insert(m *model.Mount, pid uint32) {
mr.minMountID = m.MountID
}

mr.mounts[m.MountID] = m
mr.mounts.Add(m.MountID, m)

mr.updatePidMapping(m, pid)
}
Expand All @@ -313,8 +307,7 @@ func (mr *Resolver) getFromRedemption(mountID uint32) *model.Mount {
}

func (mr *Resolver) lookupByMountID(mountID uint32) *model.Mount {
mount := mr.mounts[mountID]
if mount != nil {
if mount, ok := mr.mounts.Get(mountID); mount != nil && ok {
return mount
}

Expand All @@ -324,17 +317,17 @@ func (mr *Resolver) lookupByMountID(mountID uint32) *model.Mount {
func (mr *Resolver) lookupByDevice(device uint32, pid uint32) *model.Mount {
var result *model.Mount

mounts := mr.pidToMounts[pid]

for _, mount := range mounts {
mr.pidToMounts.WalkInner(pid, func(_ uint32, mount *model.Mount) bool {
if mount.Device == device {
// should be consistent across all the mounts
if result != nil && result.MountPointStr != mount.MountPointStr {
return nil
result = nil
return false
}
result = mount
}
}
return true
})

return result
}
Expand Down Expand Up @@ -515,8 +508,7 @@ func (mr *Resolver) resolveMount(mountID uint32, device uint32, pid uint32, cont
return nil, model.MountSourceUnknown, model.MountOriginUnknown, err
}

mount = mr.mounts[mountID]
if mount != nil {
if mount, ok := mr.mounts.Get(mountID); mount != nil && ok {
mr.procHitsStats.Inc()
return mount, model.MountSourceMountID, mount.Origin, nil
}
Expand Down Expand Up @@ -601,7 +593,7 @@ func (mr *Resolver) SendStats() error {
return err
}

return mr.statsdClient.Gauge(metrics.MetricMountResolverCacheSize, float64(len(mr.mounts)), []string{}, 1.0)
return mr.statsdClient.Gauge(metrics.MetricMountResolverCacheSize, float64(mr.mounts.Len()), []string{}, 1.0)
}

// ToJSON return a json version of the cache
Expand All @@ -613,7 +605,7 @@ func (mr *Resolver) ToJSON() ([]byte, error) {
mr.lock.RLock()
defer mr.lock.RUnlock()

for _, mount := range mr.mounts {
for mount := range mr.mounts.ValuesIter() {
d, err := json.Marshal(mount)
if err == nil {
dump.Entries = append(dump.Entries, d)
Expand All @@ -623,15 +615,33 @@ func (mr *Resolver) ToJSON() ([]byte, error) {
return json.Marshal(dump)
}

const (
// mounts LRU limit: 100000 mounts
mountsLimit = 100000
// pidToMounts LRU limits: 1000 pids, and 1000 mounts per pid
pidLimit = 1000
mountsPerPidLimit = 1000
)

// NewResolver instantiates a new mount resolver
func NewResolver(statsdClient statsd.ClientInterface, cgroupsResolver *cgroup.Resolver, opts ResolverOpts) (*Resolver, error) {
mounts, err := simplelru.NewLRU[uint32, *model.Mount](mountsLimit, nil)
if err != nil {
return nil, err
}

pidToMounts, err := cache.NewTwoLayersLRU[uint32, uint32, *model.Mount](pidLimit * mountsPerPidLimit)
if err != nil {
return nil, err
}

mr := &Resolver{
opts: opts,
statsdClient: statsdClient,
cgroupsResolver: cgroupsResolver,
lock: sync.RWMutex{},
mounts: make(map[uint32]*model.Mount),
pidToMounts: make(map[uint32]map[uint32]*model.Mount),
mounts: mounts,
pidToMounts: pidToMounts,
cacheHitsStats: atomic.NewInt64(0),
procHitsStats: atomic.NewInt64(0),
cacheMissStats: atomic.NewInt64(0),
Expand Down
112 changes: 61 additions & 51 deletions pkg/security/resolvers/mount/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,95 +475,105 @@ func TestMountResolver(t *testing.T) {
}

func TestMountGetParentPath(t *testing.T) {
mr := &Resolver{
mounts: map[uint32]*model.Mount{
1: {
MountID: 1,
MountPointStr: "/",
mounts := map[uint32]*model.Mount{
1: {
MountID: 1,
MountPointStr: "/",
},
2: {
MountID: 2,
ParentPathKey: model.PathKey{
MountID: 1,
},
2: {
MountPointStr: "/a",
},
3: {
MountID: 3,
ParentPathKey: model.PathKey{
MountID: 2,
ParentPathKey: model.PathKey{
MountID: 1,
},
MountPointStr: "/a",
},
3: {
MountPointStr: "/b",
},
4: {
MountID: 4,
ParentPathKey: model.PathKey{
MountID: 3,
ParentPathKey: model.PathKey{
MountID: 2,
},
MountPointStr: "/b",
},
4: {
MountID: 4,
ParentPathKey: model.PathKey{
MountID: 3,
},
MountPointStr: "/c",
},
MountPointStr: "/c",
},
}

// Create mount resolver
cr, _ := cgroup.NewResolver(nil)
mr, _ := NewResolver(nil, cr, ResolverOpts{})
for _, m := range mounts {
mr.mounts.Add(m.MountID, m)
}

parentPath, _, _, err := mr.getMountPath(4, 44, 1)
assert.NoError(t, err)
assert.Equal(t, "/a/b/c", parentPath)
}

func TestMountLoop(t *testing.T) {
mr := &Resolver{
mounts: map[uint32]*model.Mount{
1: {
MountID: 1,
MountPointStr: "/",
mounts := map[uint32]*model.Mount{
1: {
MountID: 1,
MountPointStr: "/",
},
2: {
MountID: 2,
ParentPathKey: model.PathKey{
MountID: 4,
},
2: {
MountPointStr: "/a",
},
3: {
MountID: 3,
ParentPathKey: model.PathKey{
MountID: 2,
ParentPathKey: model.PathKey{
MountID: 4,
},
MountPointStr: "/a",
},
3: {
MountPointStr: "/b",
},
4: {
MountID: 4,
ParentPathKey: model.PathKey{
MountID: 3,
ParentPathKey: model.PathKey{
MountID: 2,
},
MountPointStr: "/b",
},
4: {
MountID: 4,
ParentPathKey: model.PathKey{
MountID: 3,
},
MountPointStr: "/c",
},
MountPointStr: "/c",
},
}

// Create mount resolver
cr, _ := cgroup.NewResolver(nil)
mr, _ := NewResolver(nil, cr, ResolverOpts{})
for _, m := range mounts {
mr.mounts.Add(m.MountID, m)
}

parentPath, _, _, err := mr.getMountPath(3, 44, 1)
assert.Equal(t, ErrMountLoop, err)
assert.Equal(t, "", parentPath)
}

func BenchmarkGetParentPath(b *testing.B) {
mr := &Resolver{
mounts: make(map[uint32]*model.Mount),
}
// Create mount resolver
cr, _ := cgroup.NewResolver(nil)
mr, _ := NewResolver(nil, cr, ResolverOpts{})

mr.mounts[1] = &model.Mount{
mr.mounts.Add(1, &model.Mount{
MountID: 1,
MountPointStr: "/",
}
})

for i := uint32(1); i != 100; i++ {
mr.mounts[i+1] = &model.Mount{
mr.mounts.Add(i+1, &model.Mount{
MountID: i + 1,
ParentPathKey: model.PathKey{
MountID: i,
},
MountPointStr: fmt.Sprintf("/%d", i+1),
}
})
}

b.ResetTimer()
Expand Down
Loading