Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpenSSH-like escape sequences in tsh #3752

Merged
merged 3 commits into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ type Config struct {
// UseLocalSSHAgent will write user certificates to the local ssh-agent (or
// similar) socket at $SSH_AUTH_SOCK.
UseLocalSSHAgent bool

// EnableEscapeSequences will scan Stdin for SSH escape sequences during
// command/shell execution. This also requires Stdin to be an interactive
// terminal.
EnableEscapeSequences bool
}

// CachePolicy defines cache policy for local clients
Expand All @@ -273,10 +278,11 @@ type CachePolicy struct {
// MakeDefaultConfig returns default client config
func MakeDefaultConfig() *Config {
return &Config{
Stdout: os.Stdout,
Stderr: os.Stderr,
Stdin: os.Stdin,
UseLocalSSHAgent: true,
Stdout: os.Stdout,
Stderr: os.Stderr,
Stdin: os.Stdin,
UseLocalSSHAgent: true,
EnableEscapeSequences: true,
}
}

Expand Down Expand Up @@ -1469,7 +1475,7 @@ func (tc *TeleportClient) runCommand(
if len(nodeAddresses) > 1 {
fmt.Printf("Running command on %v:\n", address)
}
nodeSession, err = newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient))
nodeSession, err = newSession(nodeClient, nil, tc.Config.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences)
if err != nil {
log.Error(err)
return
Expand Down Expand Up @@ -1503,7 +1509,7 @@ func (tc *TeleportClient) runCommand(
// runShell starts an interactive SSH session/shell.
// sessionID : when empty, creates a new shell. otherwise it tries to join the existing session.
func (tc *TeleportClient) runShell(nodeClient *NodeClient, sessToJoin *session.Session) error {
nodeSession, err := newSession(nodeClient, sessToJoin, tc.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient))
nodeSession, err := newSession(nodeClient, sessToJoin, tc.Env, tc.Stdin, tc.Stdout, tc.Stderr, tc.useLegacyID(nodeClient), tc.EnableEscapeSequences)
if err != nil {
return trace.Wrap(err)
}
Expand Down
4 changes: 2 additions & 2 deletions lib/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (s *ClientTestSuite) TestNewSession(c *check.C) {
}

// defaults:
ses, err := newSession(nc, nil, nil, nil, nil, nil, false)
ses, err := newSession(nc, nil, nil, nil, nil, nil, false, true)
c.Assert(err, check.IsNil)
c.Assert(ses, check.NotNil)
c.Assert(ses.NodeClient(), check.Equals, nc)
Expand All @@ -69,7 +69,7 @@ func (s *ClientTestSuite) TestNewSession(c *check.C) {
env := map[string]string{
sshutils.SessionEnvVar: "session-id",
}
ses, err = newSession(nc, nil, env, nil, nil, nil, false)
ses, err = newSession(nc, nil, env, nil, nil, nil, false, true)
c.Assert(err, check.IsNil)
c.Assert(ses, check.NotNil)
c.Assert(ses.env, check.DeepEquals, env)
Expand Down
220 changes: 220 additions & 0 deletions lib/client/escape/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
/*
Copyright 2020 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package escape implements client-side escape character logic.
// This logic mimics OpenSSH: https://man.openbsd.org/ssh#ESCAPE_CHARACTERS.
package escape

import (
"errors"
"io"
"sync"
)

const (
readerBufferLimit = 10 * 1024 * 1024 // 10MB

// Note: on a raw terminal, "\r\n" is needed to move a cursor to the start
// of next line.
helpText = "\r\ntsh escape characters:\r\n ~? - display a list of escape characters\r\n ~. - disconnect\r\n"
)

var (
// ErrDisconnect is returned when the user has entered a disconnect
// sequence, requesting connection to be interrupted.
ErrDisconnect = errors.New("disconnect escape sequence detected")
// ErrTooMuchBufferedData is returned when the Reader's internal buffer has
// filled over 10MB. Either the consumer of Reader can't keep up with the
// data or it's entirely stuck and not consuming the data.
ErrTooMuchBufferedData = errors.New("internal buffer has grown too big")
)

// Reader is an io.Reader wrapper that catches OpenSSH-like escape sequences in
// the input stream. See NewReader for more info.
//
// Reader is safe for concurrent use.
type Reader struct {
inner io.Reader
out io.Writer
onDisconnect func(error)
bufferLimit int

// cond protects buf and err and also announces to blocked readers that
// more data is available.
cond sync.Cond
buf []byte
err error
}

// NewReader creates a new Reader to catch escape sequences from 'in'.
//
// Two sequences are supported:
// - "~?": prints help text to 'out' listing supported sequences
// - "~.": disconnect stops any future reads from in; after this sequence,
// callers can still read any unread data up to this sequence from Reader but
// all future Read calls will return ErrDisconnect; onDisconnect will also be
// called with ErrDisconnect immediately
//
// NewReader starts consuming 'in' immediately in the background. This allows
// Reader to detect sequences without the caller actively calling Read (such as
// when it's stuck writing out the received data).
//
// Unread data is accumulated in an internal buffer. If this buffer grows to a
// limit (currently 10MB), Reader will stop permanently. onDisconnect will get
// called with ErrTooMuchBufferedData. Read can still be called to consume the
// internal buffer but all future reads after that will return
// ErrTooMuchBufferedData.
//
// If the internal buffer is empty, calls to Read will block until some data is
// available or an error occurs.
func NewReader(in io.Reader, out io.Writer, onDisconnect func(error)) *Reader {
r := newUnstartedReader(in, out, onDisconnect)
go r.runReads()
return r
}

// newUnstartedReader allows unit tests to mutate Reader before runReads
// starts.
func newUnstartedReader(in io.Reader, out io.Writer, onDisconnect func(error)) *Reader {
return &Reader{
inner: in,
out: out,
onDisconnect: onDisconnect,
bufferLimit: readerBufferLimit,
cond: sync.Cond{L: &sync.Mutex{}},
// note: no need to pre-allocate buf, it will allocate and grow as
// needed in runReads via append.
}
}

func (r *Reader) runReads() {
Copy link
Contributor

@fspmarshall fspmarshall May 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few notes based on comparing this function to the process_escapes function in openssh:

  • Help escape sequence is not filtered out (openssh appears to filter out the help sequence).
  • Disconnect escape sequence drops all preceding bytes (openssh appears to send all data up to the disconnect escape sequence).
  • \r~ starts an escape, but \n~ does not (openssh treats \r and \n interchangeably here).
  • Reads ending in partial escape sequence (e.g. \r~) will send the escape character (~) to remote (openssh does not send the escape character until the following character is available to ensure that partial escape sequences are not transmitted to remote).

I think its preferable if we keep parity with openssh on these points (potentially confusing if we don't, since we shadow their syntax).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, the escape character should probably be configurable (or at least the behavior should be able to be deactivated). I don't know what this might break, but its probably best not to roll out a behavioral change like this without the ability to work around it if it does break something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @fspmarshall covered this in the fourth bullet, but this will not correctly recognize sequences that are split between buffers. Processing character-by-character is safer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a meta-response: I agree that it's not a perfect replica of openssh behavior. But it works for the use-case it's addressing (killing stuck sessions).

That being said, I'll massage the code some more to prevent escape sequences from going out to the remote side.

Disconnect escape sequence drops all preceding bytes (openssh appears to send all data up to the disconnect escape sequence).

What does it do when remote side is unresponsive?

\r~ starts an escape, but \n~ does not (openssh treats \r and \n interchangeably here).

Nice catch, will fix!

Also, the escape character should probably be configurable (or at least the behavior should be able to be deactivated).

This makes sense, I'll add a tsh flag.
Although I'd like to not over-complicate this niche feature.

this will not correctly recognize sequences that are split between buffers.

It will recognize sequences split between buffers via the prev buffer. But it will not block them the same way as openssh does.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does it do when remote side is unresponsive?

Ah, good point. Take this with a grain of salt because my knowledge of both openssh and C in general is pretty weak, and the state-machine around channels in openssh is complex.... but I believe it performs a single non-blocking write attempt with any remaining pre-escape sequence data and then exits. Thats a much weaker guarantee than I originally assumed, so I'd say forget that suggestion and keep dropping the preceeding bytes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack. I don't see an easy way to do a non-blocking write at our level of abstraction in lib/client/session.go so I'll keep it as is.

Refactored the reader logic to read one character at a time and filter out escape sequences.
Also added handling for both \r and \n.

Next is the tsh flag plumbing...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added --no-enable-escape-sequences flag to tsh.
PTAL

readBuf := make([]byte, 1024)
// writeBuf is a copy of data in readBuf after filtering out any escape
// sequences.
writeBuf := make([]byte, 0, 1024)
// newLine is set iff the previous character was a newline.
// escape is set iff the two previous characters were a newline and '~'.
//
// Note: at most one of these is ever set. When escape is true, then
// newLine is false.
newLine, escape := true, false
for {
n, err := r.inner.Read(readBuf)
if err != nil {
r.setErr(err)
return
}

// Reset the output buffer from previous state.
writeBuf = writeBuf[:0]
inner:
for _, b := range readBuf[:n] {
// Note: this switch only filters and updates newLine and escape.
// b is written to writeBuf afterwards.
switch b {
case '\r', '\n':
if escape {
// An incomplete escape sequence, send out a '~' that was
// previously suppressed.
writeBuf = append(writeBuf, '~')
}
newLine, escape = true, false
case '~':
if newLine {
// Start escape sequence, don't write the '~' just yet.
newLine, escape = false, true
continue inner
} else if escape {
newLine, escape = false, false
}
case '?':
if escape {
// Complete help sequence.
r.printHelp()
newLine, escape = false, false
continue inner
}
newLine = false
case '.':
if escape {
// Complete disconnect sequence.
r.setErr(ErrDisconnect)
return
}
newLine = false
default:
if escape {
// An incomplete escape sequence, send out a '~' that was
// previously suppressed.
writeBuf = append(writeBuf, '~')
}
newLine, escape = false, false
}
// Write the character out as-is, it wasn't filtered out above.
writeBuf = append(writeBuf, b)
}

// Add new data to internal buffer.
r.cond.L.Lock()
if len(r.buf)+len(writeBuf) > r.bufferLimit {
// Unlock because setErr will want to lock too.
r.cond.L.Unlock()
r.setErr(ErrTooMuchBufferedData)
return
}
r.buf = append(r.buf, writeBuf...)
// Notify blocked Read calls about new data.
r.cond.Broadcast()
r.cond.L.Unlock()
}
}

// Read fills buf with available data. If no data is available, Read will
// block.
func (r *Reader) Read(buf []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
// Block until some data was read in runReads.
for len(r.buf) == 0 && r.err == nil {
r.cond.Wait()
}

// Have some data to return.
n := len(r.buf)
if n > len(buf) {
n = len(buf)
}
// Write n available bytes to buf and trim them from r.buf.
copy(buf, r.buf[:n])
r.buf = r.buf[n:]

return n, r.err
}

func (r *Reader) setErr(err error) {
r.cond.L.Lock()
r.err = err
r.cond.Broadcast()
// Skip EOF, it's a normal clean exit.
if err != io.EOF {
r.onDisconnect(err)
}
r.cond.L.Unlock()
}

func (r *Reader) printHelp() {
r.out.Write([]byte(helpText))
}
Loading