From 36ebf17d5fd3be241f26071f10adf05f4f2f4a17 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Fri, 29 Sep 2017 17:49:48 -0700 Subject: [PATCH] checks in in proxy subsystem, fixes #1336 --- lib/srv/proxy.go | 66 +++++++++++++++++++++++++------------------ lib/srv/proxy_test.go | 18 ++++++++++++ lib/utils/utils.go | 12 ++++++++ 3 files changed, 69 insertions(+), 27 deletions(-) diff --git a/lib/srv/proxy.go b/lib/srv/proxy.go index 4f7b78ab17871..bfa72f9eac734 100644 --- a/lib/srv/proxy.go +++ b/lib/srv/proxy.go @@ -59,56 +59,68 @@ type proxySubsys struct { // "proxy:@clustername" - Teleport request to connect to an auth server for cluster with name 'clustername' // "proxy:host:22@clustername" - Teleport request to connect to host:22 on cluster 'clustername' // "proxy:host:22@namespace@clustername" -func parseProxySubsys(name string, srv *Server) (*proxySubsys, error) { - log.Debugf("parse_proxy_subsys(%s)", name) +func parseProxySubsys(request string, srv *Server) (*proxySubsys, error) { + log.Debugf("parse_proxy_subsys(%q)", request) var ( - clusterName string - host string - port string - paramError = trace.BadParameter("invalid format for proxy request: '%v', expected 'proxy:host:port@site'", name) + clusterName string + targetHost string + targetPort string + paramMessage = fmt.Sprintf("invalid format for proxy request: %q, expected 'proxy:host:port@cluster'", request) ) const prefix = "proxy:" // get rid of 'proxy:' prefix: - if strings.Index(name, prefix) != 0 { - return nil, trace.Wrap(paramError) + if strings.Index(request, prefix) != 0 { + return nil, trace.BadParameter(paramMessage) } - name = strings.TrimPrefix(name, prefix) + requestBody := strings.TrimPrefix(request, prefix) namespace := defaults.Namespace - // find the site name in the argument: - parts := strings.Split(name, "@") - switch len(parts) { - case 2: - clusterName = strings.Join(parts[1:], "@") - name = parts[0] - case 3: + var err error + parts := strings.Split(requestBody, "@") + switch { + case len(parts) == 0: // "proxy:" + return nil, trace.BadParameter(paramMessage) + case len(parts) == 1: // "proxy:host:22" + targetHost, targetPort, err = utils.SplitHostPort(parts[0]) + if err != nil { + return nil, trace.BadParameter(paramMessage) + } + case len(parts) == 2: // "proxy:@clustername" or "proxy:host:22@clustername" + if parts[0] != "" { + targetHost, targetPort, err = utils.SplitHostPort(parts[0]) + if err != nil { + return nil, trace.BadParameter(paramMessage) + } + } + clusterName = parts[1] + if clusterName == "" && targetHost == "" { + return nil, trace.BadParameter("invalid format for proxy request: missing cluster name or target host in %q", request) + } + case len(parts) > 3: // "proxy:host:22@namespace@clustername" clusterName = strings.Join(parts[2:], "@") namespace = parts[1] - name = parts[0] - } - // find host & port in the arguments: - host, port, err := net.SplitHostPort(name) - if clusterName == "" && err != nil { - return nil, trace.Wrap(paramError) + targetHost, targetPort, err = utils.SplitHostPort(parts[0]) + if err != nil { + return nil, trace.BadParameter(paramMessage) + } } if clusterName != "" && srv.proxyTun != nil { _, err := srv.proxyTun.GetSite(clusterName) if err != nil { - return nil, trace.BadParameter("unknown cluster '%s'", clusterName) + return nil, trace.BadParameter("invalid format for proxy request: unknown cluster %q in %q", clusterName, request) } } - return &proxySubsys{ namespace: namespace, srv: srv, - host: host, - port: port, + host: targetHost, + port: targetPort, siteName: clusterName, closeC: make(chan struct{}), }, nil } func (t *proxySubsys) String() string { - return fmt.Sprintf("proxySubsys(site=%s/%s, host=%s, port=%s)", + return fmt.Sprintf("proxySubsys(cluster=%s/%s, host=%s, port=%s)", t.namespace, t.siteName, t.host, t.port) } diff --git a/lib/srv/proxy_test.go b/lib/srv/proxy_test.go index 72adaa559a3a9..5c49a85cf5c9e 100644 --- a/lib/srv/proxy_test.go +++ b/lib/srv/proxy_test.go @@ -68,3 +68,21 @@ func (s *ProxyTestSuite) TestParseProxyRequest(c *check.C) { c.Assert(subsys.port, check.Equals, "100") c.Assert(subsys.siteName, check.Equals, "moon") } + +func (s *ProxyTestSuite) TestParseBadRequests(c *check.C) { + testCases := []string{ + // empty request + "proxy:", + // missing hostname + "proxy::80", + // missing hostname and missing cluster name + "proxy:@", + // just random string + "this is bad string", + } + for _, input := range testCases { + comment := check.Commentf("test case: %q", input) + _, err := parseProxySubsys(input, s.srv) + c.Assert(err, check.NotNil, comment) + } +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 37b757e643427..b873d071c399c 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -33,6 +33,18 @@ import ( "golang.org/x/crypto/ssh" ) +// SplitHostPort splits host and port and checks that host is not empty +func SplitHostPort(hostname string) (string, string, error) { + host, port, err := net.SplitHostPort(hostname) + if err != nil { + return "", "", trace.Wrap(err) + } + if host == "" { + return "", "", trace.BadParameter("empty hostname") + } + return host, port, nil +} + type HostKeyCallback func(hostID string, remote net.Addr, key ssh.PublicKey) error func ReadPath(path string) ([]byte, error) {