Skip to content

Commit

Permalink
Refactor escape reader again to do batch reads
Browse files Browse the repository at this point in the history
Thanks to @fspmarshall's transcription of the OpenSSH version:
#3752 (review)

Also fix buffer limit to be 10MB, not 10KB. Shame on me :(
  • Loading branch information
Andrew Lytvynov authored and awly committed Jun 2, 2020
1 parent 3ba6502 commit d96a7e8
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 81 deletions.
111 changes: 56 additions & 55 deletions lib/client/escape/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

const (
readerBufferLimit = 10 * 1 << 10 // 10MB
readerBufferLimit = 10 * 1024 * 1024 // 10MB

// Note: on a raw terminal, "\r\n" is needed to move a cursor to the start
// of next line.
Expand Down Expand Up @@ -101,80 +101,81 @@ func newUnstartedReader(in io.Reader, out io.Writer, onDisconnect func(error)) *
}

func (r *Reader) runReads() {
// prev contains the last read escape sequence character.
// Possible values are:
// '\r' or '\n' after a fresh newline
// '~' after a newline and ~
// '\000' (null) in any other case
prev := byte('\r')
// Read one character at a time to simplify the logic.
readBuf := make([]byte, 1)
outer:
readBuf := make([]byte, 1024)
// writeBuf is a copy of data in readBuf after filtering out any escape
// sequences.
writeBuf := make([]byte, 0, 1024)
// newLine is set iff the previous character was a newline.
// escape is set iff the two previous characters were a newline and '~'.
//
// Note: at most one of these is ever set. When escape is true, then
// newLine is false.
newLine, escape := true, false
for {
n, err := r.inner.Read(readBuf)
if err != nil {
r.setErr(err)
return
}
if n == 0 {
continue outer
}

// forward contains the characters to add to the internal buffer.
forward := readBuf
c := readBuf[0]
switch prev {
case '\r', '\n':
// Detect a tilde after a newline.
if c == '~' {
prev = '~'
// Do not send the tilde to remote end right way.
continue outer
}
prev = '\000'
case '~':
// We saw a newline and a tilde. Time to complete the escape
// sequence or abort it.
switch c {
// Reset the output buffer from previous state.
writeBuf = writeBuf[:0]
inner:
for _, b := range readBuf[:n] {
// Note: this switch only filters and updates newLine and escape.
// b is written to writeBuf afterwards.
switch b {
case '\r', '\n':
if escape {
// An incomplete escape sequence, send out a '~' that was
// previously suppressed.
writeBuf = append(writeBuf, '~')
}
newLine, escape = true, false
case '~':
if newLine {
// Start escape sequence, don't write the '~' just yet.
newLine, escape = false, true
continue inner
} else if escape {
newLine, escape = false, false
}
case '?':
r.printHelp()
// Reset as if we're right after a newline.
prev = '\r'
// Do not send the help escape sequence to remote end.
continue outer
if escape {
// Complete help sequence.
r.printHelp()
newLine, escape = false, false
continue inner
}
newLine = false
case '.':
// Disconnect and abort future reads. Previously-read data is
// still available.
r.setErr(ErrDisconnect)
return
case '~':
// Escaped tilde, let only one tilde through and reset prev to
// ignore all characters until the next newline.
prev = '\000'
if escape {
// Complete disconnect sequence.
r.setErr(ErrDisconnect)
return
}
newLine = false
default:
// Not an escape sequence. Send over the blocked tilde and
// whatever character was typed in.
forward = []byte{prev, c}
// Reset prev to ignore all characters until the next newline.
prev = '\000'
}
default:
// If we're not in an escape sequence, ignore everything until a
// newline restarts a new potential sequence.
if c == '\r' || c == '\n' {
prev = c
if escape {
// An incomplete escape sequence, send out a '~' that was
// previously suppressed.
writeBuf = append(writeBuf, '~')
}
newLine, escape = false, false
}
// Write the character out as-is, it wasn't filtered out above.
writeBuf = append(writeBuf, b)
}

// Add new data to internal buffer.
r.cond.L.Lock()
if len(r.buf)+len(forward) > r.bufferLimit {
if len(r.buf)+len(writeBuf) > r.bufferLimit {
// Unlock because setErr will want to lock too.
r.cond.L.Unlock()
r.setErr(ErrTooMuchBufferedData)
return
}
r.buf = append(r.buf, forward...)
r.buf = append(r.buf, writeBuf...)
// Notify blocked Read calls about new data.
r.cond.Broadcast()
r.cond.L.Unlock()
Expand Down
77 changes: 51 additions & 26 deletions lib/client/escape/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type ReaderSuite struct {
var _ = check.Suite(&ReaderSuite{})

type readerTestCase struct {
inStream []byte
inChunks [][]byte
inErr error

wantReadErr error
Expand All @@ -27,7 +27,7 @@ type readerTestCase struct {
}

func (*ReaderSuite) runCase(c *check.C, t readerTestCase) {
in := &mockReader{data: t.inStream, finalErr: t.inErr}
in := &mockReader{chunks: t.inChunks, finalErr: t.inErr}
helpOut := new(bytes.Buffer)
out := new(bytes.Buffer)
var disconnectErr error
Expand All @@ -46,31 +46,31 @@ func (*ReaderSuite) runCase(c *check.C, t readerTestCase) {
func (s *ReaderSuite) TestNormalReads(c *check.C) {
c.Log("normal read")
s.runCase(c, readerTestCase{
inStream: []byte("hello world"),
inChunks: [][]byte{[]byte("hello world")},
wantOut: "hello world",
})

c.Log("incomplete help sequence")
c.Log("incomplete sequence")
s.runCase(c, readerTestCase{
inStream: []byte("hello\r~world"),
inChunks: [][]byte{[]byte("hello\r~world")},
wantOut: "hello\r~world",
})

c.Log("escaped tilde character")
s.runCase(c, readerTestCase{
inStream: []byte("hello\r~~world"),
inChunks: [][]byte{[]byte("hello\r~~world")},
wantOut: "hello\r~world",
})

c.Log("other character between newline and tilde")
s.runCase(c, readerTestCase{
inStream: []byte("hello\rw~orld"),
inChunks: [][]byte{[]byte("hello\rw~orld")},
wantOut: "hello\rw~orld",
})

c.Log("other character between newline and disconnect sequence")
s.runCase(c, readerTestCase{
inStream: []byte("hello\rw~.orld"),
inChunks: [][]byte{[]byte("hello\rw~.orld")},
wantOut: "hello\rw~.orld",
})
}
Expand All @@ -79,7 +79,7 @@ func (s *ReaderSuite) TestReadError(c *check.C) {
customErr := errors.New("oh no")

s.runCase(c, readerTestCase{
inStream: []byte("hello world"),
inChunks: [][]byte{[]byte("hello world")},
inErr: customErr,
wantOut: "hello world",
wantReadErr: customErr,
Expand All @@ -90,45 +90,74 @@ func (s *ReaderSuite) TestReadError(c *check.C) {
func (s *ReaderSuite) TestEscapeHelp(c *check.C) {
c.Log("single help sequence between reads")
s.runCase(c, readerTestCase{
inStream: []byte("hello\r~?world"),
inChunks: [][]byte{[]byte("hello\r~?world")},
wantOut: "hello\rworld",
wantHelp: helpText,
})

c.Log("single help sequence before any data")
s.runCase(c, readerTestCase{
inStream: []byte("~?hello world"),
inChunks: [][]byte{[]byte("~?hello world")},
wantOut: "hello world",
wantHelp: helpText,
})

c.Log("repeated help sequences")
s.runCase(c, readerTestCase{
inStream: []byte("hello\r~?world\n~?"),
inChunks: [][]byte{[]byte("hello\r~?world\n~?")},
wantOut: "hello\rworld\n",
wantHelp: helpText + helpText,
})

c.Log("help sequence split across reads")
s.runCase(c, readerTestCase{
inChunks: [][]byte{
[]byte("hello\r"),
[]byte("~"),
[]byte("?"),
[]byte("world"),
},
wantOut: "hello\rworld",
wantHelp: helpText,
})
}

func (s *ReaderSuite) TestEscapeDisconnect(c *check.C) {
c.Log("single disconnect sequence between reads")
s.runCase(c, readerTestCase{
inStream: []byte("hello\r~.world"),
wantOut: "hello\r",
inChunks: [][]byte{
[]byte("hello"),
[]byte("\r~."),
[]byte("world"),
},
wantOut: "hello",
wantReadErr: ErrDisconnect,
wantDisconnectErr: ErrDisconnect,
})

c.Log("disconnect sequence before any data")
s.runCase(c, readerTestCase{
inStream: []byte("~.hello world"),
inChunks: [][]byte{[]byte("~.hello world")},
wantReadErr: ErrDisconnect,
wantDisconnectErr: ErrDisconnect,
})

c.Log("disconnect sequence split across reads")
s.runCase(c, readerTestCase{
inChunks: [][]byte{
[]byte("hello\r"),
[]byte("~"),
[]byte("."),
[]byte("world"),
},
wantOut: "hello\r",
wantReadErr: ErrDisconnect,
wantDisconnectErr: ErrDisconnect,
})
}

func (*ReaderSuite) TestBufferOverflow(c *check.C) {
in := &mockReader{data: make([]byte, 100)}
in := &mockReader{chunks: [][]byte{make([]byte, 100)}}
helpOut := new(bytes.Buffer)
out := new(bytes.Buffer)
var disconnectErr error
Expand All @@ -145,24 +174,20 @@ func (*ReaderSuite) TestBufferOverflow(c *check.C) {
}

type mockReader struct {
data []byte
chunks [][]byte
finalErr error
}

func (r *mockReader) Read(buf []byte) (int, error) {
if len(r.data) == 0 {
if len(r.chunks) == 0 {
if r.finalErr != nil {
return 0, r.finalErr
}
return 0, io.EOF
}

n := len(buf)
if n > len(r.data) {
n = len(r.data)
}
copy(buf, r.data)
r.data = r.data[n:]
return n, nil

chunk := r.chunks[0]
r.chunks = r.chunks[1:]
copy(buf, chunk)
return len(chunk), nil
}

0 comments on commit d96a7e8

Please sign in to comment.