diff --git a/lib/client/escape/reader.go b/lib/client/escape/reader.go index 528f758c428f7..af31782f928b7 100644 --- a/lib/client/escape/reader.go +++ b/lib/client/escape/reader.go @@ -25,7 +25,7 @@ import ( ) const ( - readerBufferLimit = 10 * 1 << 10 // 10MB + readerBufferLimit = 10 * 1024 * 1024 // 10MB // Note: on a raw terminal, "\r\n" is needed to move a cursor to the start // of next line. @@ -101,80 +101,81 @@ func newUnstartedReader(in io.Reader, out io.Writer, onDisconnect func(error)) * } func (r *Reader) runReads() { - // prev contains the last read escape sequence character. - // Possible values are: - // '\r' or '\n' after a fresh newline - // '~' after a newline and ~ - // '\000' (null) in any other case - prev := byte('\r') - // Read one character at a time to simplify the logic. - readBuf := make([]byte, 1) -outer: + readBuf := make([]byte, 1024) + // writeBuf is a copy of data in readBuf after filtering out any escape + // sequences. + writeBuf := make([]byte, 0, 1024) + // newLine is set iff the previous character was a newline. + // escape is set iff the two previous characters were a newline and '~'. + // + // Note: at most one of these is ever set. When escape is true, then + // newLine is false. + newLine, escape := true, false for { n, err := r.inner.Read(readBuf) if err != nil { r.setErr(err) return } - if n == 0 { - continue outer - } - // forward contains the characters to add to the internal buffer. - forward := readBuf - c := readBuf[0] - switch prev { - case '\r', '\n': - // Detect a tilde after a newline. - if c == '~' { - prev = '~' - // Do not send the tilde to remote end right way. - continue outer - } - prev = '\000' - case '~': - // We saw a newline and a tilde. Time to complete the escape - // sequence or abort it. - switch c { + // Reset the output buffer from previous state. + writeBuf = writeBuf[:0] + inner: + for _, b := range readBuf[:n] { + // Note: this switch only filters and updates newLine and escape. + // b is written to writeBuf afterwards. + switch b { + case '\r', '\n': + if escape { + // An incomplete escape sequence, send out a '~' that was + // previously suppressed. + writeBuf = append(writeBuf, '~') + } + newLine, escape = true, false + case '~': + if newLine { + // Start escape sequence, don't write the '~' just yet. + newLine, escape = false, true + continue inner + } else if escape { + newLine, escape = false, false + } case '?': - r.printHelp() - // Reset as if we're right after a newline. - prev = '\r' - // Do not send the help escape sequence to remote end. - continue outer + if escape { + // Complete help sequence. + r.printHelp() + newLine, escape = false, false + continue inner + } + newLine = false case '.': - // Disconnect and abort future reads. Previously-read data is - // still available. - r.setErr(ErrDisconnect) - return - case '~': - // Escaped tilde, let only one tilde through and reset prev to - // ignore all characters until the next newline. - prev = '\000' + if escape { + // Complete disconnect sequence. + r.setErr(ErrDisconnect) + return + } + newLine = false default: - // Not an escape sequence. Send over the blocked tilde and - // whatever character was typed in. - forward = []byte{prev, c} - // Reset prev to ignore all characters until the next newline. - prev = '\000' - } - default: - // If we're not in an escape sequence, ignore everything until a - // newline restarts a new potential sequence. - if c == '\r' || c == '\n' { - prev = c + if escape { + // An incomplete escape sequence, send out a '~' that was + // previously suppressed. + writeBuf = append(writeBuf, '~') + } + newLine, escape = false, false } + // Write the character out as-is, it wasn't filtered out above. + writeBuf = append(writeBuf, b) } // Add new data to internal buffer. r.cond.L.Lock() - if len(r.buf)+len(forward) > r.bufferLimit { + if len(r.buf)+len(writeBuf) > r.bufferLimit { // Unlock because setErr will want to lock too. r.cond.L.Unlock() r.setErr(ErrTooMuchBufferedData) return } - r.buf = append(r.buf, forward...) + r.buf = append(r.buf, writeBuf...) // Notify blocked Read calls about new data. r.cond.Broadcast() r.cond.L.Unlock() diff --git a/lib/client/escape/reader_test.go b/lib/client/escape/reader_test.go index 62e5f2ff2b572..c49521c76cafe 100644 --- a/lib/client/escape/reader_test.go +++ b/lib/client/escape/reader_test.go @@ -17,7 +17,7 @@ type ReaderSuite struct { var _ = check.Suite(&ReaderSuite{}) type readerTestCase struct { - inStream []byte + inChunks [][]byte inErr error wantReadErr error @@ -27,7 +27,7 @@ type readerTestCase struct { } func (*ReaderSuite) runCase(c *check.C, t readerTestCase) { - in := &mockReader{data: t.inStream, finalErr: t.inErr} + in := &mockReader{chunks: t.inChunks, finalErr: t.inErr} helpOut := new(bytes.Buffer) out := new(bytes.Buffer) var disconnectErr error @@ -46,31 +46,31 @@ func (*ReaderSuite) runCase(c *check.C, t readerTestCase) { func (s *ReaderSuite) TestNormalReads(c *check.C) { c.Log("normal read") s.runCase(c, readerTestCase{ - inStream: []byte("hello world"), + inChunks: [][]byte{[]byte("hello world")}, wantOut: "hello world", }) - c.Log("incomplete help sequence") + c.Log("incomplete sequence") s.runCase(c, readerTestCase{ - inStream: []byte("hello\r~world"), + inChunks: [][]byte{[]byte("hello\r~world")}, wantOut: "hello\r~world", }) c.Log("escaped tilde character") s.runCase(c, readerTestCase{ - inStream: []byte("hello\r~~world"), + inChunks: [][]byte{[]byte("hello\r~~world")}, wantOut: "hello\r~world", }) c.Log("other character between newline and tilde") s.runCase(c, readerTestCase{ - inStream: []byte("hello\rw~orld"), + inChunks: [][]byte{[]byte("hello\rw~orld")}, wantOut: "hello\rw~orld", }) c.Log("other character between newline and disconnect sequence") s.runCase(c, readerTestCase{ - inStream: []byte("hello\rw~.orld"), + inChunks: [][]byte{[]byte("hello\rw~.orld")}, wantOut: "hello\rw~.orld", }) } @@ -79,7 +79,7 @@ func (s *ReaderSuite) TestReadError(c *check.C) { customErr := errors.New("oh no") s.runCase(c, readerTestCase{ - inStream: []byte("hello world"), + inChunks: [][]byte{[]byte("hello world")}, inErr: customErr, wantOut: "hello world", wantReadErr: customErr, @@ -90,45 +90,74 @@ func (s *ReaderSuite) TestReadError(c *check.C) { func (s *ReaderSuite) TestEscapeHelp(c *check.C) { c.Log("single help sequence between reads") s.runCase(c, readerTestCase{ - inStream: []byte("hello\r~?world"), + inChunks: [][]byte{[]byte("hello\r~?world")}, wantOut: "hello\rworld", wantHelp: helpText, }) c.Log("single help sequence before any data") s.runCase(c, readerTestCase{ - inStream: []byte("~?hello world"), + inChunks: [][]byte{[]byte("~?hello world")}, wantOut: "hello world", wantHelp: helpText, }) c.Log("repeated help sequences") s.runCase(c, readerTestCase{ - inStream: []byte("hello\r~?world\n~?"), + inChunks: [][]byte{[]byte("hello\r~?world\n~?")}, wantOut: "hello\rworld\n", wantHelp: helpText + helpText, }) + + c.Log("help sequence split across reads") + s.runCase(c, readerTestCase{ + inChunks: [][]byte{ + []byte("hello\r"), + []byte("~"), + []byte("?"), + []byte("world"), + }, + wantOut: "hello\rworld", + wantHelp: helpText, + }) } func (s *ReaderSuite) TestEscapeDisconnect(c *check.C) { c.Log("single disconnect sequence between reads") s.runCase(c, readerTestCase{ - inStream: []byte("hello\r~.world"), - wantOut: "hello\r", + inChunks: [][]byte{ + []byte("hello"), + []byte("\r~."), + []byte("world"), + }, + wantOut: "hello", wantReadErr: ErrDisconnect, wantDisconnectErr: ErrDisconnect, }) c.Log("disconnect sequence before any data") s.runCase(c, readerTestCase{ - inStream: []byte("~.hello world"), + inChunks: [][]byte{[]byte("~.hello world")}, + wantReadErr: ErrDisconnect, + wantDisconnectErr: ErrDisconnect, + }) + + c.Log("disconnect sequence split across reads") + s.runCase(c, readerTestCase{ + inChunks: [][]byte{ + []byte("hello\r"), + []byte("~"), + []byte("."), + []byte("world"), + }, + wantOut: "hello\r", wantReadErr: ErrDisconnect, wantDisconnectErr: ErrDisconnect, }) } func (*ReaderSuite) TestBufferOverflow(c *check.C) { - in := &mockReader{data: make([]byte, 100)} + in := &mockReader{chunks: [][]byte{make([]byte, 100)}} helpOut := new(bytes.Buffer) out := new(bytes.Buffer) var disconnectErr error @@ -145,24 +174,20 @@ func (*ReaderSuite) TestBufferOverflow(c *check.C) { } type mockReader struct { - data []byte + chunks [][]byte finalErr error } func (r *mockReader) Read(buf []byte) (int, error) { - if len(r.data) == 0 { + if len(r.chunks) == 0 { if r.finalErr != nil { return 0, r.finalErr } return 0, io.EOF } - n := len(buf) - if n > len(r.data) { - n = len(r.data) - } - copy(buf, r.data) - r.data = r.data[n:] - return n, nil - + chunk := r.chunks[0] + r.chunks = r.chunks[1:] + copy(buf, chunk) + return len(chunk), nil }