Skip to content

Commit

Permalink
feat: implement CloseRead/CloseWrite
Browse files Browse the repository at this point in the history
fixes libp2p/go-libp2p-core#10

fix: avoid returning accept errors

Instead, wait for shutdown.
  • Loading branch information
Stebalien committed Aug 28, 2020
1 parent 0135c85 commit 7f5a301
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 98 deletions.
1 change: 1 addition & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//+build !race
package yamux

import (
Expand Down
18 changes: 11 additions & 7 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,18 @@ func (s *Session) Accept() (net.Conn, error) {
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
select {
case stream := <-s.acceptCh:
if err := stream.sendWindowUpdate(); err != nil {
return nil, err
for {
select {
case stream := <-s.acceptCh:
if err := stream.sendWindowUpdate(); err != nil {
// don't return accept errors.
s.logger.Printf("[WARN] error sending window update before accepting: %s", err)
continue
}
return stream, nil
case <-s.shutdownCh:
return nil, s.shutdownErr
}
return stream, nil
case <-s.shutdownCh:
return nil, s.shutdownErr
}
}

Expand Down
58 changes: 46 additions & 12 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ func TestSendData_Small(t *testing.T) {
t.Errorf("err: %v", err)
return
}
defer stream.Close()

if server.NumStreams() != 1 {
t.Errorf("bad")
Expand All @@ -430,7 +431,7 @@ func TestSendData_Small(t *testing.T) {
}
}

if err := stream.Close(); err != nil {
if err := stream.CloseWrite(); err != nil {
t.Errorf("err: %v", err)
return
}
Expand All @@ -442,11 +443,12 @@ func TestSendData_Small(t *testing.T) {

go func() {
defer wg.Done()
stream, err := client.Open()
stream, err := client.OpenStream()
if err != nil {
t.Errorf("err: %v", err)
return
}
defer stream.Close()

if client.NumStreams() != 1 {
t.Errorf("bad")
Expand All @@ -465,7 +467,7 @@ func TestSendData_Small(t *testing.T) {
}
}

if err := stream.Close(); err != nil {
if err := stream.CloseWrite(); err != nil {
t.Errorf("err: %v", err)
return
}
Expand Down Expand Up @@ -785,30 +787,56 @@ func TestManyStreams_PingPong(t *testing.T) {
wg.Wait()
}

func TestHalfClose(t *testing.T) {
func TestCloseRead(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

stream, err := client.Open()
stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if _, err = stream.Write([]byte("a")); err != nil {
t.Fatalf("err: %v", err)
}

stream2, err := server.Accept()
stream2, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
stream2.Close() // Half close
stream2.CloseRead()

buf := make([]byte, 4)
n, err := stream2.Read(buf)
if n != 0 || err == nil {
t.Fatalf("read after close: %d %s", n, err)
}
}

func TestHalfClose(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
if _, err = stream.Write([]byte("a")); err != nil {
t.Fatalf("err: %v", err)
}

stream2, err := server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
stream2.CloseWrite() // Half close

buf := make([]byte, 4)
n, err := io.ReadAtLeast(stream2, buf, 1)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
if n != 1 {
t.Fatalf("bad: %v", n)
}
Expand All @@ -817,11 +845,17 @@ func TestHalfClose(t *testing.T) {
if _, err = stream.Write([]byte("bcd")); err != nil {
t.Fatalf("err: %v", err)
}
stream.Close()
stream.CloseWrite()

// write after close
n, err = stream.Write([]byte("foobar"))
if n != 0 || err == nil {
t.Fatalf("wrote after close: %d %s", n, err)
}

// Read after close
n, err = stream2.Read(buf)
if err != nil {
n, err = io.ReadAtLeast(stream2, buf, 3)
if err != nil && err != io.EOF {
t.Fatalf("err: %v", err)
}
if n != 3 {
Expand Down Expand Up @@ -1131,7 +1165,6 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
t.Errorf("err: %v", err)
return
}
defer wr.Close()

sendWindow := atomic.LoadUint32(&wr.sendWindow)
if sendWindow != client.config.MaxStreamWindowSize {
Expand Down Expand Up @@ -1352,8 +1385,9 @@ func TestStreamHalfClose2(t *testing.T) {
if err != nil {
t.Error(err)
}
defer stream.Close()

stream.Close()
stream.CloseWrite()
wait <- struct{}{}

buf, err := ioutil.ReadAll(stream)
Expand Down
Loading

0 comments on commit 7f5a301

Please sign in to comment.