diff --git a/internal/sys/fd.go b/internal/sys/fd.go index 028be0045..399edbd20 100644 --- a/internal/sys/fd.go +++ b/internal/sys/fd.go @@ -4,8 +4,10 @@ import ( "fmt" "math" "os" + "path/filepath" "runtime" "strconv" + "strings" "github.com/cilium/ebpf/internal/testutils/fdtrace" "github.com/cilium/ebpf/internal/unix" @@ -118,3 +120,43 @@ func (fd *FD) File(name string) *os.File { return os.NewFile(uintptr(fd.disown()), name) } + +// ObjGetTyped wraps [ObjGet] with a readlink call to extract the type of the +// underlying bpf object. +func ObjGetTyped(attr *ObjGetAttr) (*FD, ObjType, error) { + fd, err := ObjGet(attr) + if err != nil { + return nil, 0, err + } + + typ, err := readType(fd) + if err != nil { + _ = fd.Close() + return nil, 0, fmt.Errorf("reading fd type: %w", err) + } + + return fd, typ, nil +} + +// readType returns the bpf object type of the file descriptor by calling +// readlink(3). Returns an error if the file descriptor does not represent a bpf +// object. +func readType(fd *FD) (ObjType, error) { + s, err := os.Readlink(filepath.Join("/proc/self/fd/", fd.String())) + if err != nil { + return 0, fmt.Errorf("readlink fd %d: %w", fd.Int(), err) + } + + s = strings.TrimPrefix(s, "anon_inode:") + + switch s { + case "bpf-map": + return BPF_TYPE_MAP, nil + case "bpf-prog": + return BPF_TYPE_PROG, nil + case "bpf-link": + return BPF_TYPE_LINK, nil + } + + return 0, fmt.Errorf("unknown type %s of fd %d", s, fd.Int()) +} diff --git a/link/link.go b/link/link.go index eef834a81..796769f8e 100644 --- a/link/link.go +++ b/link/link.go @@ -78,7 +78,9 @@ func NewFromID(id ID) (Link, error) { return wrapRawLink(&RawLink{fd, ""}) } -// LoadPinnedLink loads a link that was persisted into a bpffs. +// LoadPinnedLink loads a Link from a pin (file) on the BPF virtual filesystem. +// +// Requires at least Linux 5.7. func LoadPinnedLink(fileName string, opts *ebpf.LoadPinOptions) (Link, error) { raw, err := loadPinnedRawLink(fileName, opts) if err != nil { @@ -350,7 +352,7 @@ func AttachRawLink(opts RawLinkOptions) (*RawLink, error) { } func loadPinnedRawLink(fileName string, opts *ebpf.LoadPinOptions) (*RawLink, error) { - fd, err := sys.ObjGet(&sys.ObjGetAttr{ + fd, typ, err := sys.ObjGetTyped(&sys.ObjGetAttr{ Pathname: sys.NewStringPointer(fileName), FileFlags: opts.Marshal(), }) @@ -358,6 +360,11 @@ func loadPinnedRawLink(fileName string, opts *ebpf.LoadPinOptions) (*RawLink, er return nil, fmt.Errorf("load pinned link: %w", err) } + if typ != sys.BPF_TYPE_LINK { + _ = fd.Close() + return nil, fmt.Errorf("%s is not a Link", fileName) + } + return &RawLink{fd, fileName}, nil } diff --git a/link/link_test.go b/link/link_test.go index 9b68f9604..ff994c448 100644 --- a/link/link_test.go +++ b/link/link_test.go @@ -370,3 +370,30 @@ func mustLoadProgram(tb testing.TB, typ ebpf.ProgramType, attachType ebpf.Attach return prog } + +func TestLoadWrongPin(t *testing.T) { + cg, p := mustCgroupFixtures(t) + + l, err := AttachRawLink(RawLinkOptions{ + Target: int(cg.Fd()), + Program: p, + Attach: ebpf.AttachCGroupInetEgress, + }) + testutils.SkipIfNotSupported(t, err) + t.Cleanup(func() { l.Close() }) + + tmp := testutils.TempBPFFS(t) + + ppath := filepath.Join(tmp, "prog") + lpath := filepath.Join(tmp, "link") + + qt.Assert(t, qt.IsNil(p.Pin(ppath))) + qt.Assert(t, qt.IsNil(l.Pin(lpath))) + + _, err = LoadPinnedLink(ppath, nil) + qt.Assert(t, qt.IsNotNil(err)) + + ll, err := LoadPinnedLink(lpath, nil) + qt.Assert(t, qt.IsNil(err)) + qt.Assert(t, qt.IsNil(ll.Close())) +} diff --git a/map.go b/map.go index c5010b419..f3f98b46d 100644 --- a/map.go +++ b/map.go @@ -1560,9 +1560,11 @@ func (m *Map) unmarshalValue(value any, buf sysenc.Buffer) error { return buf.Unmarshal(value) } -// LoadPinnedMap loads a Map from a BPF file. +// LoadPinnedMap opens a Map from a pin (file) on the BPF virtual filesystem. +// +// Requires at least Linux 4.5. func LoadPinnedMap(fileName string, opts *LoadPinOptions) (*Map, error) { - fd, err := sys.ObjGet(&sys.ObjGetAttr{ + fd, typ, err := sys.ObjGetTyped(&sys.ObjGetAttr{ Pathname: sys.NewStringPointer(fileName), FileFlags: opts.Marshal(), }) @@ -1570,6 +1572,11 @@ func LoadPinnedMap(fileName string, opts *LoadPinOptions) (*Map, error) { return nil, err } + if typ != sys.BPF_TYPE_MAP { + _ = fd.Close() + return nil, fmt.Errorf("%s is not a Map", fileName) + } + m, err := newMapFromFD(fd) if err == nil { m.pinnedPath = fileName diff --git a/map_test.go b/map_test.go index dc66a205e..d690eecd2 100644 --- a/map_test.go +++ b/map_test.go @@ -1894,6 +1894,36 @@ func TestPerfEventArrayCompatible(t *testing.T) { qt.Assert(t, qt.IsNotNil(ms.Compatible(m))) } +func TestLoadWrongPin(t *testing.T) { + p := mustSocketFilter(t) + m := newHash(t) + tmp := testutils.TempBPFFS(t) + + ppath := filepath.Join(tmp, "prog") + mpath := filepath.Join(tmp, "map") + + qt.Assert(t, qt.IsNil(m.Pin(mpath))) + qt.Assert(t, qt.IsNil(p.Pin(ppath))) + + t.Run("Program", func(t *testing.T) { + _, err := LoadPinnedProgram(mpath, nil) + qt.Assert(t, qt.IsNotNil(err)) + + lp, err := LoadPinnedProgram(ppath, nil) + qt.Assert(t, qt.IsNil(err)) + qt.Assert(t, qt.IsNil(lp.Close())) + }) + + t.Run("Map", func(t *testing.T) { + _, err := LoadPinnedMap(ppath, nil) + qt.Assert(t, qt.IsNotNil(err)) + + lm, err := LoadPinnedMap(mpath, nil) + qt.Assert(t, qt.IsNil(err)) + qt.Assert(t, qt.IsNil(lm.Close())) + }) +} + type benchValue struct { ID uint32 Val16 uint16 diff --git a/prog.go b/prog.go index bde48a56b..1c97ae3de 100644 --- a/prog.go +++ b/prog.go @@ -906,11 +906,12 @@ func marshalProgram(p *Program, length int) ([]byte, error) { return buf, nil } -// LoadPinnedProgram loads a Program from a BPF file. +// LoadPinnedProgram loads a Program from a pin (file) on the BPF virtual +// filesystem. // // Requires at least Linux 4.11. func LoadPinnedProgram(fileName string, opts *LoadPinOptions) (*Program, error) { - fd, err := sys.ObjGet(&sys.ObjGetAttr{ + fd, typ, err := sys.ObjGetTyped(&sys.ObjGetAttr{ Pathname: sys.NewStringPointer(fileName), FileFlags: opts.Marshal(), }) @@ -918,6 +919,11 @@ func LoadPinnedProgram(fileName string, opts *LoadPinOptions) (*Program, error) return nil, err } + if typ != sys.BPF_TYPE_PROG { + _ = fd.Close() + return nil, fmt.Errorf("%s is not a Program", fileName) + } + info, err := newProgramInfoFromFd(fd) if err != nil { _ = fd.Close()