Skip to content

Commit

Permalink
Merge pull request kubernetes#100951 from saschagrunert/automated-che…
Browse files Browse the repository at this point in the history
…rry-pick-of-#99839-upstream-release-1.21

Automated cherry pick of kubernetes#99839: Cleanup portforward streams after their usage
  • Loading branch information
k8s-ci-robot authored May 8, 2021
2 parents 5e853bf + b14bd44 commit 9745a35
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 4 deletions.
4 changes: 4 additions & 0 deletions pkg/kubelet/cri/streaming/portforward/httpstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ func (h *httpStreamHandler) removeStreamPair(requestID string) {
h.streamPairsLock.Lock()
defer h.streamPairsLock.Unlock()

if h.conn != nil {
pair := h.streamPairs[requestID]
h.conn.RemoveStreams(pair.dataStream, pair.errorStream)
}
delete(h.streamPairs, requestID)
}

Expand Down
21 changes: 21 additions & 0 deletions pkg/kubelet/cri/streaming/portforward/httpstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,23 @@ func TestHTTPStreamReceived(t *testing.T) {
}
}

type fakeConn struct {
removeStreamsCalled bool
}

func (*fakeConn) CreateStream(headers http.Header) (httpstream.Stream, error) { return nil, nil }
func (*fakeConn) Close() error { return nil }
func (*fakeConn) CloseChan() <-chan bool { return nil }
func (*fakeConn) SetIdleTimeout(timeout time.Duration) {}
func (f *fakeConn) RemoveStreams(streams ...httpstream.Stream) { f.removeStreamsCalled = true }

func TestGetStreamPair(t *testing.T) {
timeout := make(chan time.Time)

conn := &fakeConn{}
h := &httpStreamHandler{
streamPairs: make(map[string]*httpStreamPair),
conn: conn,
}

// test adding a new entry
Expand Down Expand Up @@ -158,6 +170,11 @@ func TestGetStreamPair(t *testing.T) {
// make sure monitorStreamPair completed
<-monitorDone

if !conn.removeStreamsCalled {
t.Fatalf("connection remove stream not called")
}
conn.removeStreamsCalled = false

// make sure the pair was removed
if h.hasStreamPair("1") {
t.Fatal("expected removal of pair after both data and error streams received")
Expand All @@ -171,6 +188,7 @@ func TestGetStreamPair(t *testing.T) {
if p == nil {
t.Fatal("expected p not to be nil")
}

monitorDone = make(chan struct{})
go func() {
h.monitorStreamPair(p, timeout)
Expand All @@ -183,6 +201,9 @@ func TestGetStreamPair(t *testing.T) {
if h.hasStreamPair("2") {
t.Fatal("expected stream pair to be removed")
}
if !conn.removeStreamsCalled {
t.Fatalf("connection remove stream not called")
}
}

func TestRequestID(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ type Connection interface {
// SetIdleTimeout sets the amount of time the connection may remain idle before
// it is automatically closed.
SetIdleTimeout(timeout time.Duration)
// RemoveStreams can be used to remove a set of streams from the Connection.
RemoveStreams(streams ...Stream)
}

// Stream represents a bidirectional communications channel that is part of an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// streams.
type connection struct {
conn *spdystream.Connection
streams []httpstream.Stream
streams map[uint32]httpstream.Stream
streamLock sync.Mutex
newStreamHandler httpstream.NewStreamHandler
ping func() (time.Duration, error)
Expand Down Expand Up @@ -85,7 +85,12 @@ func NewServerConnectionWithPings(conn net.Conn, newStreamHandler httpstream.New
// will be invoked when the server receives a newly created stream from the
// client.
func newConnection(conn *spdystream.Connection, newStreamHandler httpstream.NewStreamHandler, pingPeriod time.Duration, pingFn func() (time.Duration, error)) httpstream.Connection {
c := &connection{conn: conn, newStreamHandler: newStreamHandler, ping: pingFn}
c := &connection{
conn: conn,
newStreamHandler: newStreamHandler,
ping: pingFn,
streams: make(map[uint32]httpstream.Stream),
}
go conn.Serve(c.newSpdyStream)
if pingPeriod > 0 && pingFn != nil {
go c.sendPings(pingPeriod)
Expand All @@ -105,7 +110,7 @@ func (c *connection) Close() error {
// calling Reset instead of Close ensures that all streams are fully torn down
s.Reset()
}
c.streams = make([]httpstream.Stream, 0)
c.streams = make(map[uint32]httpstream.Stream, 0)
c.streamLock.Unlock()

// now that all streams are fully torn down, it's safe to call close on the underlying connection,
Expand All @@ -114,6 +119,15 @@ func (c *connection) Close() error {
return c.conn.Close()
}

// RemoveStreams can be used to removes a set of streams from the Connection.
func (c *connection) RemoveStreams(streams ...httpstream.Stream) {
c.streamLock.Lock()
for _, stream := range streams {
delete(c.streams, stream.Identifier())
}
c.streamLock.Unlock()
}

// CreateStream creates a new stream with the specified headers and registers
// it with the connection.
func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error) {
Expand All @@ -133,7 +147,7 @@ func (c *connection) CreateStream(headers http.Header) (httpstream.Stream, error
// it owns.
func (c *connection) registerStream(s httpstream.Stream) {
c.streamLock.Lock()
c.streams = append(c.streams, s)
c.streams[s.Identifier()] = s
c.streamLock.Unlock()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,41 @@ func TestConnectionPings(t *testing.T) {
t.Errorf("timed out waiting for server to exit")
}
}

type fakeStream struct{ id uint32 }

func (*fakeStream) Read(p []byte) (int, error) { return 0, nil }
func (*fakeStream) Write(p []byte) (int, error) { return 0, nil }
func (*fakeStream) Close() error { return nil }
func (*fakeStream) Reset() error { return nil }
func (*fakeStream) Headers() http.Header { return nil }
func (f *fakeStream) Identifier() uint32 { return f.id }

func TestConnectionRemoveStreams(t *testing.T) {
c := &connection{streams: make(map[uint32]httpstream.Stream)}
stream0 := &fakeStream{id: 0}
stream1 := &fakeStream{id: 1}
stream2 := &fakeStream{id: 2}

c.registerStream(stream0)
c.registerStream(stream1)

if len(c.streams) != 2 {
t.Fatalf("should have two streams, has %d", len(c.streams))
}

// not exists
c.RemoveStreams(stream2)

if len(c.streams) != 2 {
t.Fatalf("should have two streams, has %d", len(c.streams))
}

// remove all existing
c.RemoveStreams(stream0, stream1)

if len(c.streams) != 0 {
t.Fatalf("should not have any streams, has %d", len(c.streams))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ func (c *fakeConnection) CloseChan() <-chan bool {
return c.closeChan
}

func (c *fakeConnection) RemoveStreams(_ ...httpstream.Stream) {
}

func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
// no-op
}
Expand Down

0 comments on commit 9745a35

Please sign in to comment.