Skip to content

Commit

Permalink
add PipeListener type
Browse files Browse the repository at this point in the history
  • Loading branch information
boxofrad committed Apr 12, 2023
1 parent 3d617bd commit 8fdefff
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
84 changes: 84 additions & 0 deletions agent/grpc-internal/pipe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package internal

import (
"context"
"errors"
"net"
"sync/atomic"
)

// ErrPipeClosed is returned when calling Accept or DialContext on a closed
// PipeListener.
var ErrPipeClosed = errors.New("pipe listener has been closed")

// PipeListener implements the net.Listener interface using a net.Pipe so that
// you can interact with a gRPC service in the same process without going over
// the network.
type PipeListener struct {
conns chan net.Conn
closed atomic.Bool
done chan struct{}
}

var _ net.Listener = (*PipeListener)(nil)

// NewPipeListener creates a new PipeListener.
func NewPipeListener() *PipeListener {
return &PipeListener{
conns: make(chan net.Conn),
done: make(chan struct{}),
}
}

// Accept a connection.
func (p *PipeListener) Accept() (net.Conn, error) {
select {
case conn := <-p.conns:
return conn, nil
case <-p.done:
return nil, ErrPipeClosed
}
}

// Close the listener.
func (p *PipeListener) Close() error {
if p.closed.CompareAndSwap(false, true) {
close(p.done)
}
return nil
}

// DialContext dials the server over an in-process pipe.
func (p *PipeListener) DialContext(ctx context.Context, _ string) (net.Conn, error) {
if p.closed.Load() {
return nil, ErrPipeClosed
}

serverConn, clientConn := net.Pipe()

select {
// Send the server connection to whatever is accepting connections from the
// PipeListener. This will block until something has accepted the conn.
case p.conns <- serverConn:
return clientConn, nil
case <-ctx.Done():
serverConn.Close()
clientConn.Close()
return nil, ctx.Err()
case <-p.done:
serverConn.Close()
clientConn.Close()
return nil, ErrPipeClosed
}
}

// Add returns the listener's address.
func (*PipeListener) Addr() net.Addr { return pipeAddr{} }

type pipeAddr struct{}

func (pipeAddr) Network() string { return "pipe" }
func (pipeAddr) String() string { return "pipe" }
70 changes: 70 additions & 0 deletions agent/grpc-internal/pipe_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package internal

import (
"bufio"
"context"
"net"
"testing"

"github.com/stretchr/testify/require"
)

func TestPipeListener_RoundTrip(t *testing.T) {
lis := NewPipeListener()
t.Cleanup(func() { _ = lis.Close() })

go echoServer(lis)

conn, err := lis.DialContext(context.Background(), "")
require.NoError(t, err)
t.Cleanup(func() { _ = conn.Close() })

input := []byte("Hello World\n")
_, err = conn.Write(input)
require.NoError(t, err)

output := make([]byte, len(input))
_, err = conn.Read(output)
require.NoError(t, err)

require.Equal(t, string(input), string(output))
}

func TestPipeListener_Closed(t *testing.T) {
lis := NewPipeListener()
require.NoError(t, lis.Close())

_, err := lis.Accept()
require.ErrorIs(t, err, ErrPipeClosed)

_, err = lis.DialContext(context.Background(), "")
require.ErrorIs(t, err, ErrPipeClosed)
}

func echoServer(lis net.Listener) {
handleConn := func(conn net.Conn) {
defer conn.Close()

reader := bufio.NewReader(conn)
for {
msg, err := reader.ReadBytes('\n')
if err != nil {
return
}
if _, err := conn.Write(msg); err != nil {
return
}
}
}

for {
conn, err := lis.Accept()
if err != nil {
return
}
go handleConn(conn)
}
}

0 comments on commit 8fdefff

Please sign in to comment.