// notifyHost tells the host (usually systemd) that the container reported
// READY, then announces pid1 as MAINPID and finally synchronizes with the
// host by means of a BARRIER=1 message.
//
// ready is the notification payload without a trailing newline; it is written
// to client with a '\n' appended in a single Write call. The MAINPID line is
// written with a second Write call. The first error encountered while writing
// or while waiting on the barrier is returned.
func notifyHost(client *net.UnixConn, ready []byte, pid1 int) error {
	var out bytes.Buffer
	// Writes to a bytes.Buffer never fail: the returned error is always nil.
	out.Write(ready)
	out.WriteByte('\n')

	if _, err := client.Write(out.Bytes()); err != nil {
		return err
	}

	// Now we can inform systemd to use pid1 as the pid to monitor.
	newPid := "MAINPID=" + strconv.Itoa(pid1) + "\n"
	if _, err := client.Write([]byte(newPid)); err != nil {
		return err
	}

	// Wait for systemd to acknowledge the communication.
	return sdNotifyBarrier(client)
}

// errUnexpectedRead is reported when actual data was read from the pipe used
// to synchronize with systemd. Usually, that pipe is only closed.
var errUnexpectedRead = errors.New("unexpected read from synchronization pipe")

// sdNotifyBarrier synchronizes with systemd by means of the sd_notify_barrier
// protocol: the write end of a fresh pipe is passed to the peer together with
// a BARRIER=1 message, and the function then waits (for at most five seconds)
// until every copy of that write end has been closed.
//
// It returns nil once EOF is observed on the pipe, errUnexpectedRead if data
// shows up on the pipe instead, and the underlying error otherwise (including
// the deadline error when the peer does not close the pipe in time).
func sdNotifyBarrier(client *net.UnixConn) error {
	// Create a pipe for communicating with systemd daemon.
	pipeR, pipeW, err := os.Pipe()
	if err != nil {
		return err
	}
	defer pipeR.Close()
	// Safety net for the early error returns below. On the regular path pipeW
	// is closed explicitly; closing an *os.File a second time merely returns
	// an error, which is ignored here.
	defer pipeW.Close()

	// Get the FD for the unix socket file to be able to perform
	// syscall.Sendmsg. File returns a duplicate that we own and must close.
	clientFd, err := client.File()
	if err != nil {
		return err
	}
	defer clientFd.Close()

	// Send the write end of the pipe along with a BARRIER=1 message.
	fdRights := syscall.UnixRights(int(pipeW.Fd()))
	err = syscall.Sendmsg(int(clientFd.Fd()), []byte("BARRIER=1"), fdRights, nil, 0)
	if err != nil {
		return err
	}

	// Close our copy of pipeW so that EOF on pipeR depends on the peer only.
	if err := pipeW.Close(); err != nil {
		return err
	}

	// Expect the read end of the pipe to be closed within 5 seconds. Without
	// a deadline the barrier could not be bounded, so a failure to set it is
	// reported instead of being mistaken for a successful synchronization.
	if err := pipeR.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return err
	}

	// Read a single byte expecting EOF.
	var buf [1]byte
	n, err := pipeR.Read(buf[:])
	switch {
	case n != 0 || err == nil:
		return errUnexpectedRead
	case err == io.EOF: //nolint:errorlint // comparison with io.EOF is legit.
		return nil
	default:
		return err
	}
}
+func TestNotifyHost(t *testing.T) { + addr := net.UnixAddr{ + Name: t.TempDir() + "/testsocket", + Net: "unixgram", + } + + server, err := net.ListenUnixgram("unixgram", &addr) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + client, err := net.DialUnix("unixgram", nil, &addr) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + // run notifyHost in a separate goroutine + notifyHostChan := make(chan error) + go func() { + notifyHostChan <- notifyHost(client, []byte("READY=42"), 1337) + }() + + // mock a host process listening for runc's notifications + expectRead(t, server, "READY=42\n") + expectRead(t, server, "MAINPID=1337\n") + expectBarrier(t, server, notifyHostChan) +} + +func expectRead(t *testing.T, r io.Reader, expected string) { + var buf [1024]byte + n, err := r.Read(buf[:]) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf[:n], []byte(expected)) { + t.Fatalf("Expected to read '%s' but runc sent '%s' instead", expected, buf[:n]) + } +} + +func expectBarrier(t *testing.T, conn *net.UnixConn, notifyHostChan <-chan error) { + var msg, oob [1024]byte + n, oobn, _, _, err := conn.ReadMsgUnix(msg[:], oob[:]) + if err != nil { + t.Fatal("Failed to receive BARRIER message", err) + } + if !bytes.Equal(msg[:n], []byte("BARRIER=1")) { + t.Fatalf("Expected to receive 'BARRIER=1' but got '%s' instead.", msg[:n]) + } + + fd := mustExtractFd(t, oob[:oobn]) + + // Test whether notifyHost actually honors the barrier + timer := time.NewTimer(500 * time.Millisecond) + select { + case <-timer.C: + // this is the expected case + break + case <-notifyHostChan: + t.Fatal("runc has terminated before barrier was lifted") + } + + // Lift the barrier + err = syscall.Close(fd) + if err != nil { + t.Fatal(err) + } + + // Expect notifyHost to terminate now + err = <-notifyHostChan + if err != nil { + t.Fatal("notifyHost function returned with error", err) + } +} + +func mustExtractFd(t *testing.T, buf []byte) int { + cmsgs, err := 
syscall.ParseSocketControlMessage(buf) + if err != nil { + t.Fatal("Failed to parse control message", err) + } + + fd := 0 + seenScmRights := false + for _, cmsg := range cmsgs { + if cmsg.Header.Type != syscall.SCM_RIGHTS { + continue + } + if seenScmRights { + t.Fatal("Expected to see exactly one SCM_RIGHTS message, but got a second one") + } + seenScmRights = true + fds, err := syscall.ParseUnixRights(&cmsg) + if err != nil { + t.Fatal("Failed to parse SCM_RIGHTS message", err) + } + if len(fds) != 1 { + t.Fatal("Expected to read exactly one file descriptor, but got", len(fds)) + } + fd = fds[0] + } + if !seenScmRights { + t.Fatal("Control messages didn't contain an SCM_RIGHTS message") + } + + return fd +}