Skip to content

Commit

Permalink
socket: handle ENOTSOCK from getsockopt for pidfd
Browse files Browse the repository at this point in the history
Signed-off-by: Matt Layher <mdlayher@gmail.com>
  • Loading branch information
mdlayher committed Apr 26, 2023
1 parent 41a913f commit 2a14cee
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
49 changes: 34 additions & 15 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,27 @@ type Conn struct {
// descriptors such as those created by accept(2).
name string

// Whether this is a streaming descriptor, as opposed to a
// packet-based descriptor like a UDP socket.
isStream bool

// Whether a zero byte read indicates EOF. This is false for a
// message based socket connection.
zeroReadIsEOF bool
// facts contains information we have determined about Conn to trigger
// alternate behavior in certain functions.
facts facts

// Provides access to the underlying file registered with the runtime
// network poller, and arbitrary raw I/O calls.
fd *os.File
rc syscall.RawConn
}

// facts contains facts about a Conn.
type facts struct {
// isStream reports whether this is a streaming descriptor, as opposed to a
// packet-based descriptor like a UDP socket.
isStream bool

// zeroReadIsEOF reports Whether a zero byte read indicates EOF. This is
// false for a message based socket connection.
zeroReadIsEOF bool
}

// A Config contains options for a Conn.
type Config struct {
// NetNS specifies the Linux network namespace the Conn will operate in.
Expand Down Expand Up @@ -109,14 +116,14 @@ func (c *Conn) Read(b []byte) (int, error) { return c.fd.Read(b) }
// ReadContext reads from the underlying file descriptor with added support for
// context cancelation.
func (c *Conn) ReadContext(ctx context.Context, b []byte) (int, error) {
if c.isStream && len(b) > maxRW {
if c.facts.isStream && len(b) > maxRW {
b = b[:maxRW]
}

n, err := readT(c, ctx, "read", func(fd int) (int, error) {
return unix.Read(fd, b)
})
if n == 0 && err == nil && c.zeroReadIsEOF {
if n == 0 && err == nil && c.facts.zeroReadIsEOF {
return 0, io.EOF
}

Expand All @@ -136,7 +143,7 @@ func (c *Conn) WriteContext(ctx context.Context, b []byte) (int, error) {

doErr := c.write(ctx, "write", func(fd int) error {
max := len(b)
if c.isStream && max-nn > maxRW {
if c.facts.isStream && max-nn > maxRW {
max = nn + maxRW
}

Expand Down Expand Up @@ -367,13 +374,25 @@ func New(fd int, name string) (*Conn, error) {
rc: rc,
}

// Probe the file descriptor for socket settings.
sotype, err := c.GetsockoptInt(unix.SOL_SOCKET, unix.SO_TYPE)
if err != nil {
switch {
case err == nil:
// File is a socket, check its properties.
c.facts = facts{
isStream: sotype == unix.SOCK_STREAM,
zeroReadIsEOF: sotype != unix.SOCK_DGRAM && sotype != unix.SOCK_RAW,
}
case errors.Is(err, unix.ENOTSOCK):
// File is not a socket, treat it as a regular file.
c.facts = facts{
isStream: true,
zeroReadIsEOF: true,
}
default:
return nil, err
}

c.isStream = sotype == unix.SOCK_STREAM
c.zeroReadIsEOF = sotype != unix.SOCK_DGRAM && sotype != unix.SOCK_RAW
return c, nil
}

Expand Down Expand Up @@ -544,7 +563,7 @@ func (c *Conn) Recvmsg(ctx context.Context, p, oob []byte, flags int) (int, int,
n, oobn, recvflags, from, err := unix.Recvmsg(fd, p, oob, flags)
return ret{n, oobn, recvflags, from}, err
})
if r.n == 0 && err == nil && c.zeroReadIsEOF {
if r.n == 0 && err == nil && c.facts.zeroReadIsEOF {
return 0, 0, 0, nil, io.EOF
}

Expand All @@ -562,7 +581,7 @@ func (c *Conn) Recvfrom(ctx context.Context, p []byte, flags int) (int, unix.Soc
n, addr, err := unix.Recvfrom(fd, p, flags)
return ret{n, addr}, err
})
if out.n == 0 && err == nil && c.zeroReadIsEOF {
if out.n == 0 && err == nil && c.facts.zeroReadIsEOF {
return 0, nil, io.EOF
}

Expand Down
15 changes: 15 additions & 0 deletions conn_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,18 @@ func TestLinuxDialVsockNoListener(t *testing.T) {
t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
}
}

func TestLinuxOpenPIDFD(t *testing.T) {
// Verify we can use regular files with socket by properly handling
// ENOTSOCK, as is the case with pidfds.
fd, err := unix.PidfdOpen(1, unix.PIDFD_NONBLOCK)
if err != nil {
t.Fatalf("failed to open pidfd for init: %v", err)
}

c, err := socket.New(fd, "pidfd")
if err != nil {
t.Fatalf("failed to open Conn for pidfd: %v", err)
}
_ = c.Close()
}

0 comments on commit 2a14cee

Please sign in to comment.