From 88a6afedca81d367d47042cced5d49d68144d051 Mon Sep 17 00:00:00 2001 From: Bryce Kahle Date: Mon, 3 Jun 2024 13:07:25 -0700 Subject: [PATCH] add Flush for manual Read/ReadInto wakeup Signed-off-by: Bryce Kahle --- internal/epoll/poller.go | 79 +++++++++++++++++++++++++++-------- internal/epoll/poller_test.go | 25 +++++++++++ internal/errors.go | 15 +++++++ perf/reader.go | 22 +++++++--- perf/reader_test.go | 76 +++++++++++++++++++++++++++++++++ ringbuf/reader.go | 24 +++++++++-- ringbuf/reader_test.go | 78 ++++++++++++++++++++++++++++++++-- 7 files changed, 289 insertions(+), 30 deletions(-) diff --git a/internal/epoll/poller.go b/internal/epoll/poller.go index 2235553b5..84c2acbe6 100644 --- a/internal/epoll/poller.go +++ b/internal/epoll/poller.go @@ -5,6 +5,7 @@ import ( "math" "os" "runtime" + "slices" "sync" "time" @@ -21,8 +22,9 @@ type Poller struct { epollMu sync.Mutex epollFd int - eventMu sync.Mutex - event *eventFd + eventMu sync.Mutex + closeEvent *eventFd + flushEvent *eventFd } func New() (*Poller, error) { @@ -32,16 +34,31 @@ func New() (*Poller, error) { } p := &Poller{epollFd: epollFd} - p.event, err = newEventFd() + p.closeEvent, err = newEventFd() if err != nil { unix.Close(epollFd) return nil, err } - if err := p.Add(p.event.raw, 0); err != nil { + p.flushEvent, err = newEventFd() + if err != nil { + p.closeEvent.close() + unix.Close(epollFd) + return nil, err + } + + if err := p.Add(p.closeEvent.raw, 0); err != nil { unix.Close(epollFd) - p.event.close() - return nil, fmt.Errorf("add eventfd: %w", err) + p.closeEvent.close() + p.flushEvent.close() + return nil, fmt.Errorf("add close eventfd: %w", err) + } + + if err := p.Add(p.flushEvent.raw, 0); err != nil { + unix.Close(epollFd) + p.closeEvent.close() + p.flushEvent.close() + return nil, fmt.Errorf("add flush eventfd: %w", err) } runtime.SetFinalizer(p, (*Poller).Close) @@ -55,8 +72,8 @@ func New() (*Poller, error) { func (p *Poller) Close() error { runtime.SetFinalizer(p, nil) - // Interrupt Wait() via the event fd if it's currently blocked. - if err := p.wakeWait(); err != nil { + // Interrupt Wait() via the closeEvent fd if it's currently blocked. + if err := p.wakeWaitForClose(); err != nil { return err } @@ -73,9 +90,14 @@ func (p *Poller) Close() error { p.epollFd = -1 } - if p.event != nil { - p.event.close() - p.event = nil + if p.closeEvent != nil { + p.closeEvent.close() + p.closeEvent = nil + } + + if p.flushEvent != nil { + p.flushEvent.close() + p.flushEvent = nil } return nil @@ -154,13 +176,22 @@ func (p *Poller) Wait(events []unix.EpollEvent, deadline time.Time) (int, error) return 0, fmt.Errorf("epoll wait: %w", os.ErrDeadlineExceeded) } - for _, event := range events[:n] { - if int(event.Fd) == p.event.raw { - // Since we don't read p.event the event is never cleared and + for i := 0; i < n; { + event := events[i] + if int(event.Fd) == p.closeEvent.raw { + // Since we don't read p.closeEvent the event is never cleared and // we'll keep getting this wakeup until Close() acquires the // lock and sets p.epollFd = -1. return 0, fmt.Errorf("epoll wait: %w", os.ErrClosed) } + if int(event.Fd) == p.flushEvent.raw { + // read event to prevent it from continuing to wake + p.flushEvent.read() + events = slices.Delete(events, i, i+1) + n -= 1 + continue + } + i++ } return n, nil @@ -171,16 +202,28 @@ type temporaryError interface { Temporary() bool } -// wakeWait unblocks Wait if it's epoll_wait. -func (p *Poller) wakeWait() error { +// wakeWaitForClose unblocks Wait if it's epoll_wait. +func (p *Poller) wakeWaitForClose() error { + p.eventMu.Lock() + defer p.eventMu.Unlock() + + if p.closeEvent == nil { + return fmt.Errorf("epoll wake: %w", os.ErrClosed) + } + + return p.closeEvent.add(1) +} + +// Flush unblocks Wait if it's epoll_wait, for purposes of reading pending samples +func (p *Poller) Flush() error { p.eventMu.Lock() defer p.eventMu.Unlock() - if p.event == nil { + if p.flushEvent == nil { return fmt.Errorf("epoll wake: %w", os.ErrClosed) } - return p.event.add(1) + return p.flushEvent.add(1) } // eventFd wraps a Linux eventfd. diff --git a/internal/epoll/poller_test.go b/internal/epoll/poller_test.go index 3641fb141..0d557ffdf 100644 --- a/internal/epoll/poller_test.go +++ b/internal/epoll/poller_test.go @@ -57,6 +57,31 @@ func TestPoller(t *testing.T) { t.Fatal(err) } + go func() { + defer func() { + done <- struct{}{} + }() + + events := make([]unix.EpollEvent, 1) + + n, err := poller.Wait(events, time.Time{}) + if err != nil { + t.Error("error from Wait:", err) + return + } + if n != 0 { + t.Errorf("got %d instead of 0 events", n) + } + }() + if err := poller.Flush(); err != nil { + t.Fatal("Flush returns an error:", err) + } + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timed out") + } + go read() select { case <-done: diff --git a/internal/errors.go b/internal/errors.go index bda01e2fd..ef337b096 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -196,3 +196,18 @@ func (le *VerifierError) Format(f fmt.State, verb rune) { fmt.Fprintf(f, "%%!%c(BADVERB)", verb) } } + +type FlushCompleteError struct { + Err error +} + +func (fe *FlushCompleteError) Error() string { + if fe.Err == nil { + return "flush complete" + } + return fmt.Sprintf("flush complete: %s", fe.Err.Error()) +} + +func (fe *FlushCompleteError) Unwrap() error { + return fe.Err +} diff --git a/perf/reader.go b/perf/reader.go index 51ad6ced5..9fe580728 100644 --- a/perf/reader.go +++ b/perf/reader.go @@ -20,8 +20,11 @@ import ( var ( ErrClosed = os.ErrClosed errEOR = errors.New("end of ring") + errFlush = errors.New("ring flush") ) +type FlushCompleteError = internal.FlushCompleteError + var perfEventHeaderSize = binary.Size(perfEventHeader{}) // perfEventHeader must match 'struct perf_event_header` in . @@ -160,6 +163,8 @@ type Reader struct { overwritable bool bufferSize int + + pendingErr error } // ReaderOptions control the behaviour of the user @@ -356,13 +361,13 @@ func (pr *Reader) ReadInto(rec *Record) error { return fmt.Errorf("perf ringbuffer: %w", ErrClosed) } - deadlineWasExceeded := false for { if len(pr.epollRings) == 0 { - if deadlineWasExceeded { - // All rings were empty when the deadline expired, return + if pe := pr.pendingErr; pe != nil { + // All rings have been emptied since the error occurred, return // appropriate error. - return os.ErrDeadlineExceeded + pr.pendingErr = nil + return pe } // NB: The deferred pauseMu.Unlock will panic if Wait panics, which @@ -374,7 +379,7 @@ func (pr *Reader) ReadInto(rec *Record) error { if errors.Is(err, os.ErrDeadlineExceeded) { // We've hit the deadline, check whether there is any data in // the rings that we've not been woken up for. - deadlineWasExceeded = true + pr.pendingErr = &FlushCompleteError{Err: os.ErrDeadlineExceeded} } else if err != nil { return err } @@ -463,6 +468,13 @@ func (pr *Reader) BufferSize() int { return pr.bufferSize } +// Flush unblocks Read/ReadInto and successive Read/ReadInto calls will return pending samples at this point, +// until you receive a FlushCompleteError error. +func (pr *Reader) Flush() error { + pr.pendingErr = &FlushCompleteError{} + return pr.poller.Flush() +} + // NB: Has to be preceded by a call to ring.loadHead. func (pr *Reader) readRecordFromRing(rec *Record, ring *perfEventRing) error { defer ring.writeTail() diff --git a/perf/reader_test.go b/perf/reader_test.go index 39eac2653..9115a4ea0 100644 --- a/perf/reader_test.go +++ b/perf/reader_test.go @@ -68,6 +68,82 @@ func TestReaderSetDeadline(t *testing.T) { if _, err := rd.Read(); !errors.Is(err, os.ErrDeadlineExceeded) { t.Error("Expected os.ErrDeadlineExceeded from second Read, got:", err) } + + rd.SetDeadline(time.Now().Add(10 * time.Millisecond)) + if _, err := rd.Read(); !errors.Is(err, os.ErrDeadlineExceeded) { + t.Error("Expected os.ErrDeadlineExceeded from third Read, got:", err) + } +} + +func TestReaderSetDeadlinePendingEvents(t *testing.T) { + events := perfEventArray(t) + + rd, err := NewReaderWithOptions(events, 4096, ReaderOptions{WakeupEvents: 2}) + if err != nil { + t.Fatal(err) + } + defer rd.Close() + + outputSamples(t, events, 5) + + rd.SetDeadline(time.Now().Add(-time.Second)) + _, rem := checkRecord(t, rd) + qt.Assert(t, qt.Equals(rem, 0), qt.Commentf("expected zero Remaining")) + + outputSamples(t, events, 5) + + // another sample should not be returned before we get FlushCompleteError to indicate initial set of samples read + var fe *FlushCompleteError + _, err = rd.Read() + if !errors.As(err, &fe) { + t.Error("Expected FlushCompleteError from second Read, got:", err) + } + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Error("Expected os.ErrDeadlineExceeded from second Read, got:", err) + } + + // the second sample should now be read + _, rem = checkRecord(t, rd) +} + +func TestReaderFlushPendingEvents(t *testing.T) { + testutils.LockOSThreadToSingleCPU(t) + events := perfEventArray(t) + + rd, err := NewReaderWithOptions(events, 4096, ReaderOptions{WakeupEvents: 2}) + if err != nil { + t.Fatal(err) + } + defer rd.Close() + + outputSamples(t, events, 5) + + wait := make(chan int) + go func() { + wait <- 0 + _, rem := checkRecord(t, rd) + wait <- rem + }() + + <-wait + time.Sleep(10 * time.Millisecond) + err = rd.Flush() + qt.Assert(t, qt.IsNil(err)) + + rem := <-wait + qt.Assert(t, qt.Equals(rem, 0), qt.Commentf("expected zero Remaining")) + + outputSamples(t, events, 5) + + // another sample should not be returned before we get FlushCompleteError to indicate initial set of samples read + var fe *FlushCompleteError + _, err = rd.Read() + if !errors.As(err, &fe) { + t.Error("Expected FlushCompleteError from second Read, got:", err) + } + + // the second sample should now be read + _, rem = checkRecord(t, rd) } func outputSamples(tb testing.TB, events *ebpf.Map, sampleSizes ...byte) { diff --git a/ringbuf/reader.go b/ringbuf/reader.go index 82010e27b..cee90acf9 100644 --- a/ringbuf/reader.go +++ b/ringbuf/reader.go @@ -8,6 +8,7 @@ import ( "time" "github.com/cilium/ebpf" + "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/epoll" "github.com/cilium/ebpf/internal/unix" ) @@ -18,6 +19,8 @@ var ( errBusy = errors.New("sample not committed yet") ) +type FlushCompleteError = internal.FlushCompleteError + // ringbufHeader from 'struct bpf_ringbuf_hdr' in kernel/bpf/ringbuf.c type ringbufHeader struct { Len uint32 @@ -55,6 +58,8 @@ type Reader struct { haveData bool deadline time.Time bufferSize int + + pendingErr error } // NewReader creates a new BPF ringbuf reader. @@ -146,13 +151,17 @@ func (r *Reader) ReadInto(rec *Record) error { for { if !r.haveData { + if pe := r.pendingErr; pe != nil { + r.pendingErr = nil + return pe + } + _, err := r.poller.Wait(r.epollEvents[:cap(r.epollEvents)], r.deadline) - if errors.Is(err, os.ErrDeadlineExceeded) && !r.ring.isEmpty() { + if errors.Is(err, os.ErrDeadlineExceeded) { // Ignoring this for reading a valid entry after timeout // This can occur if the producer submitted to the ring buffer with BPF_RB_NO_WAKEUP - err = nil - } - if err != nil { + r.pendingErr = &FlushCompleteError{Err: os.ErrDeadlineExceeded} + } else if err != nil { return err } r.haveData = true @@ -178,3 +187,10 @@ func (r *Reader) ReadInto(rec *Record) error { func (r *Reader) BufferSize() int { return r.bufferSize } + +// Flush unblocks Read/ReadInto and successive Read/ReadInto calls will return pending samples at this point, +// until you receive a FlushCompleteError error. +func (r *Reader) Flush() error { + r.pendingErr = &FlushCompleteError{} + return r.poller.Flush() +} diff --git a/ringbuf/reader_test.go b/ringbuf/reader_test.go index bed7d0096..e64e9434e 100644 --- a/ringbuf/reader_test.go +++ b/ringbuf/reader_test.go @@ -7,13 +7,15 @@ import ( "testing" "time" + "github.com/go-quicktest/qt" + "github.com/google/go-cmp/cmp" + "github.com/cilium/ebpf" "github.com/cilium/ebpf/asm" "github.com/cilium/ebpf/internal" "github.com/cilium/ebpf/internal/testutils" "github.com/cilium/ebpf/internal/testutils/fdtrace" "github.com/cilium/ebpf/internal/unix" - "github.com/google/go-cmp/cmp" ) type sampleMessage struct { @@ -284,7 +286,7 @@ func TestReaderNoWakeup(t *testing.T) { t.Error("Expected no error from first Read, got:", err) } if len(record.RawSample) != 5 { - t.Errorf("Expected to read 5 bytes bot got %d", len(record.RawSample)) + t.Errorf("Expected to read 5 bytes but got %d", len(record.RawSample)) } record, err = rd.Read() @@ -293,7 +295,77 @@ func TestReaderNoWakeup(t *testing.T) { t.Error("Expected no error from second Read, got:", err) } if len(record.RawSample) != 7 { - t.Errorf("Expected to read 7 bytes bot got %d", len(record.RawSample)) + t.Errorf("Expected to read 7 bytes but got %d", len(record.RawSample)) + } + + var fe *FlushCompleteError + _, err = rd.Read() + if !errors.As(err, &fe) { + t.Errorf("Expected FlushCompleteError from third Read but got %v", err) + } + if !errors.Is(err, os.ErrDeadlineExceeded) { + t.Errorf("Expected os.ErrDeadlineExceeded from third Read but got %v", err) + } +} + +func TestReaderFlushPendingEvents(t *testing.T) { + testutils.SkipOnOldKernel(t, "5.8", "BPF ring buffer") + + prog, events := mustOutputSamplesProg(t, + sampleMessage{size: 5, flags: unix.BPF_RB_NO_WAKEUP}, // Read after Flush + sampleMessage{size: 6, flags: unix.BPF_RB_NO_WAKEUP}, // Discard + sampleMessage{size: 7, flags: unix.BPF_RB_NO_WAKEUP}) // Read won't block + + rd, err := NewReader(events) + if err != nil { + t.Fatal(err) + } + defer rd.Close() + + ret, _, err := prog.Test(internal.EmptyBPFContext) + testutils.SkipIfNotSupported(t, err) + if err != nil { + t.Fatal(err) + } + + if errno := syscall.Errno(-int32(ret)); errno != 0 { + t.Fatal("Expected 0 as return value, got", errno) + } + + wait := make(chan *Record) + go func() { + wait <- nil + record, err := rd.Read() + qt.Assert(t, qt.IsNil(err)) + wait <- &record + }() + + <-wait + time.Sleep(10 * time.Millisecond) + err = rd.Flush() + qt.Assert(t, qt.IsNil(err)) + + waitRec := <-wait + if waitRec == nil { + t.Error("Expected to read record but got nil") + } + if waitRec != nil && len(waitRec.RawSample) != 5 { + t.Errorf("Expected to read 5 bytes but got %d", len(waitRec.RawSample)) + } + + record, err := rd.Read() + + if err != nil { + t.Error("Expected no error from second Read, got:", err) + } + if len(record.RawSample) != 7 { + t.Errorf("Expected to read 7 bytes but got %d", len(record.RawSample)) + } + + var fe *FlushCompleteError + _, err = rd.Read() + if !errors.As(err, &fe) { + t.Errorf("Expected FlushCompleteError from third Read but got %v", err) } }