From e12ec7422cac9e5d65435ecc42c9d4877ba63518 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Thu, 5 Oct 2017 17:29:31 -0700 Subject: [PATCH 01/24] refactoring --- Makefile | 7 +- docker/.env | 3 + docker/Makefile | 93 ++------- docker/docker-compose.yml | 96 +++++++++ docker/env.file | 2 + docker/one-proxy.yaml | 23 +++ docker/one.yaml | 6 +- docker/two-auth.yaml | 4 +- docker/two-node.yaml | 5 + docker/two-proxy.yaml | 4 + docker/two-tc.yaml | 5 +- lib/auth/api.go | 9 + lib/auth/apiserver.go | 77 +++++++ lib/auth/apiserver_test.go | 7 + lib/auth/auth_with_roles.go | 47 ++++- lib/auth/clt.go | 73 +++++++ lib/auth/permissions.go | 1 + lib/defaults/defaults.go | 4 + lib/reversetunnel/peer.go | 182 +++++++++++++++++ lib/reversetunnel/remotesite.go | 59 ++++-- lib/reversetunnel/srv.go | 291 +++++++++++++++++++++++---- lib/service/service.go | 22 +- lib/services/local/presence.go | 106 +++++++++- lib/services/local/services_test.go | 4 + lib/services/presence.go | 15 ++ lib/services/resource.go | 3 + lib/services/suite/suite.go | 49 +++++ lib/services/trustedcluster.go | 2 - lib/services/tunnelconn.go | 227 +++++++++++++++++++++ lib/srv/proxy.go | 38 ++-- lib/srv/proxy_test.go | 10 +- lib/srv/sshserver.go | 10 +- lib/state/cachingaccesspoint.go | 72 +++++++ lib/state/cachingaccesspoint_test.go | 17 ++ lib/utils/conn.go | 54 +++++ lib/utils/proxy/proxy.go | 21 +- lib/web/apiserver.go | 2 - lib/web/terminal.go | 4 + 38 files changed, 1453 insertions(+), 201 deletions(-) create mode 100644 docker/.env create mode 100644 docker/docker-compose.yml create mode 100644 docker/env.file create mode 100644 docker/one-proxy.yaml create mode 100644 lib/reversetunnel/peer.go create mode 100644 lib/services/tunnelconn.go create mode 100644 lib/utils/conn.go diff --git a/Makefile b/Makefile index 16b1dc049ca6a..46359c383a41c 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,7 @@ LIBS = $(shell find lib -type f -name '*.go') *.go TCTLSRC = $(shell find tool/tctl -type f -name '*.go') TELEPORTSRC = $(shell find tool/teleport -type f -name '*.go') TSHSRC = $(shell find tool/tsh -type f -name '*.go') +TELEPORTVENDOR = $(shell find vendor -type f -name '*.go') # # 'make all' builds all 3 executables and plaaces them in a current directory @@ -44,13 +45,13 @@ all: $(VERSRC) go install $(BUILDFLAGS) ./lib/... $(MAKE) -s -j 4 $(BINARIES) -$(BUILDDIR)/tctl: $(LIBS) $(TCTLSRC) +$(BUILDDIR)/tctl: $(LIBS) $(TELEPORTSRC) $(TELEPORTVENDOR) go build -o $(BUILDDIR)/tctl -i $(BUILDFLAGS) ./tool/tctl -$(BUILDDIR)/teleport: $(LIBS) $(TELEPORTSRC) +$(BUILDDIR)/teleport: $(LIBS) $(TELEPORTSRC) $(TELEPORTVENDOR) go build -o $(BUILDDIR)/teleport -i $(BUILDFLAGS) ./tool/teleport -$(BUILDDIR)/tsh: $(LIBS) $(TSHSRC) +$(BUILDDIR)/tsh: $(LIBS) $(TELEPORTSRC) $(TELEPORTVENDOR) go build -o $(BUILDDIR)/tsh -i $(BUILDFLAGS) ./tool/tsh # diff --git a/docker/.env b/docker/.env new file mode 100644 index 0000000000000..9e2aadd1a2150 --- /dev/null +++ b/docker/.env @@ -0,0 +1,3 @@ +# file used by docker-compose itself (variables in yaml) +DEBUG=1 +CONTAINERHOME=/root/go/src/github.com/gravitational/teleport diff --git a/docker/Makefile b/docker/Makefile index d6ceaa0163567..ff9fab5def4c2 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -1,26 +1,21 @@ -TELEBOX=teleport:latest -HOMEDIR=$(abspath ..) 
-CONTAINERHOME=/root/go/src/github.com/gravitational/teleport -THISDIR=`pwd` -NETNAME=telenet -DOCKEROPS=--net $(NETNAME) -w $(CONTAINERHOME) -v $(HOMEDIR):$(CONTAINERHOME) - # # Default target starts two Teleport clusters # -.PHONY:run -run: prepare - $(MAKE) one - $(MAKE) two +.PHONY: up +up: + docker-compose up + +.PHONY: reup +reup: + cd .. && make + docker-compose up -# 'make stop' stops all Teleport containers, deletes them +# 'make down' stops all Teleport containers, deletes them # and their network # -.PHONY:stop -stop: - $(MAKE) stop-one - $(MAKE) stop-two - -@docker network rm $(NETNAME) +.PHONY:down +down: + docker-compose down # `make enter-one` gives you shell inside auth server # of cluster "one" @@ -50,64 +45,6 @@ enter-two-proxy: enter-two-node: docker exec -ti two-node /bin/bash -# `make shell` drops you into a bash shell inside an empty container, -# without Teleport running. Useful if you want to start it manually -# from the inside -.PHONY:shell -shell: prepare - -docker run --name=one --rm=true -ti \ - --hostname one \ - --ip 172.10.1.1 \ - --volume $(THISDIR)/data/one:/var/lib/teleport \ - $(DOCKEROPS) $(TELEBOX) /bin/bash - -docker network rm $(NETNAME) - -# `make one` starts the "One" container with single-node Teleport cluster -.PHONY:one -one: - docker run --name=one --detach=true \ - --hostname one \ - --ip 172.10.1.1 \ - --publish 3080:3080 -p 3023:3023 -p 4025:3025 \ - --volume $(THISDIR)/data/one:/var/lib/teleport \ - -e DEBUG=1 \ - $(DOCKEROPS) $(TELEBOX) build/teleport start -d -c $(CONTAINERHOME)/docker/one.yaml - -# 'make two' starts the three-node cluster in a container named "two" -.PHONY:two -two: - docker run --name=two-auth --detach=true \ - --hostname two-auth \ - --ip 172.10.1.2 \ - --volume $(THISDIR)/data/two/auth:/var/lib/teleport \ - -e DEBUG=1 \ - $(DOCKEROPS) $(TELEBOX) build/teleport start -d -c $(CONTAINERHOME)/docker/two-auth.yaml - docker run --name=two-proxy --detach=true \ - --hostname two-proxy \ - --ip 172.10.1.3 \ - --publish 5080:5080 -p 5023:5023 \ - --volume $(THISDIR)/data/two/proxy:/var/lib/teleport \ - -e DEBUG=1 \ - $(DOCKEROPS) $(TELEBOX) build/teleport start -d -c $(CONTAINERHOME)/docker/two-proxy.yaml - docker run --name=two-node --detach=true \ - --hostname two-node \ - --ip 172.10.1.4 \ - --volume $(THISDIR)/data/two/node:/var/lib/teleport \ - -e DEBUG=1 \ - $(DOCKEROPS) $(TELEBOX) build/teleport start -d -c $(CONTAINERHOME)/docker/two-node.yaml - - -# prepare is a sub-target: it creates a container image and a network -.PHONY:prepare -prepare: - docker build -t $(TELEBOX) . 
- -docker network create --subnet=172.10.0.0/16 $(NETNAME) - mkdir -p data/one data/two/proxy data/two/node data/two/auth - -.PHONY:stop-two -stop-two: - docker rm -f two-auth two-proxy two-node - -.PHONY:stop-one -stop-one: - docker rm -f one +.PHONY: setup-tc +setup-tc: + docker exec -i two-auth /bin/bash -c "tctl -c /root/go/src/github.com/gravitational/teleport/docker/two-auth.yaml create -f /root/go/src/github.com/gravitational/teleport/docker/two-tc.yaml" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 0000000000000..3fbd604775e82 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,96 @@ +version: '2' +services: + # + # one is a single-node Teleport cluster called "one" (runs all 3 roles: proxy, auth and node) + # + one: + image: teleport:latest + container_name: one + command: ${CONTAINERHOME}/build/teleport start -d -c ${CONTAINERHOME}/docker/one.yaml + ports: + - "3080:3080" + - "3023:3023" + - "3025:3025" + env_file: env.file + volumes: + - ./data/one:/var/lib/teleport + - ../:/root/go/src/github.com/gravitational/teleport + networks: + teleport: + ipv4_address: 172.10.1.1 + + # + # one-proxy is a second xproxy of the first cluster + # + one-proxy: + image: teleport:latest + container_name: one-proxy + command: ${CONTAINERHOME}/build/teleport start -d -c ${CONTAINERHOME}/docker/one-proxy.yaml + ports: + - "4080:3080" + - "4023:3023" + env_file: env.file + volumes: + - ./data/one-proxy:/var/lib/teleport + - ../:/root/go/src/github.com/gravitational/teleport + networks: + teleport: + ipv4_address: 172.10.1.10 + + # + # two-auth is a auth server of the second cluster + # + two-auth: + image: teleport:latest + container_name: two-auth + command: ${CONTAINERHOME}/build/teleport start -d -c ${CONTAINERHOME}/docker/two-auth.yaml --insecure + env_file: env.file + volumes: + - ./data/two/auth:/var/lib/teleport + - ../:/root/go/src/github.com/gravitational/teleport + networks: + teleport: + ipv4_address: 172.10.1.2 + + # + # two-proxy is a proxy service for the second cluster + # + two-proxy: + image: teleport:latest + container_name: two-proxy + command: ${CONTAINERHOME}/build/teleport start -d -c ${CONTAINERHOME}/docker/two-proxy.yaml + env_file: env.file + ports: + - "5080:5080" + - "5023:5023" + volumes: + - ./data/two/proxy:/var/lib/teleport + - ../:/root/go/src/github.com/gravitational/teleport + networks: + teleport: + ipv4_address: 172.10.1.3 + + # + # two-node is a node service for the second cluster + # + two-node: + image: teleport:latest + container_name: two-node + command: ${CONTAINERHOME}/build/teleport start -d -c ${CONTAINERHOME}/docker/two-node.yaml + env_file: env.file + volumes: + - ./data/two/node:/var/lib/teleport + - ../:/root/go/src/github.com/gravitational/teleport + networks: + teleport: + ipv4_address: 172.10.1.4 + +networks: + teleport: + driver: bridge + ipam: + driver: default + config: + - subnet: 172.10.1.0/16 + ip_range: 172.10.1.0/24 + gateway: 172.10.1.254 diff --git a/docker/env.file b/docker/env.file new file mode 100644 index 0000000000000..b3587613c73b6 --- /dev/null +++ b/docker/env.file @@ -0,0 +1,2 @@ +DEBUG=1 +CONTAINERHOME=/root/go/src/github.com/gravitational/teleport \ No newline at end of file diff --git a/docker/one-proxy.yaml b/docker/one-proxy.yaml new file mode 100644 index 0000000000000..d87601a7d9f98 --- /dev/null +++ b/docker/one-proxy.yaml @@ -0,0 +1,23 @@ +# standalone proxy connected to +teleport: + auth_token: foo + nodename: one-proxy + advertise_ip: 172.10.1.10 + log: + output: 
/var/lib/teleport/teleport.log + severity: INFO + auth_servers: + - one:3025 + data_dir: /var/lib/teleport + storage: + path: /var/lib/teleport/backend + type: dir + +auth_service: + enabled: no + +ssh_service: + enabled: no + +proxy_service: + enabled: yes diff --git a/docker/one.yaml b/docker/one.yaml index 00e91899ee544..aceb5d5390948 100644 --- a/docker/one.yaml +++ b/docker/one.yaml @@ -1,13 +1,14 @@ # Single-node Teleport cluster called "one" (runs all 3 roles: proxy, auth and node) teleport: nodename: one + advertise_ip: 172.10.1.1 log: output: /var/lib/teleport/teleport.log severity: INFO - data_dir: /root/go/src/github.com/gravitational/teleport/docker/data/one + data_dir: /var/lib/teleport storage: - path: /root/go/src/github.com/gravitational/teleport/docker/data/one/backend + path: /var/lib/teleport/backend type: dir auth_service: @@ -33,3 +34,4 @@ ssh_service: proxy_service: enabled: yes + diff --git a/docker/two-auth.yaml b/docker/two-auth.yaml index 0d921c70755b3..8c2ad047ae9f2 100644 --- a/docker/two-auth.yaml +++ b/docker/two-auth.yaml @@ -5,9 +5,9 @@ teleport: output: /var/lib/teleport/teleport.log severity: INFO - data_dir: /root/go/src/github.com/gravitational/teleport/docker/data/two + data_dir: /var/lib/teleport storage: - path: /root/go/src/github.com/gravitational/teleport/docker/data/two/backend + path: /var/lib/teleport/backend type: dir auth_service: diff --git a/docker/two-node.yaml b/docker/two-node.yaml index 7bb8b192c7609..7680c5c482742 100644 --- a/docker/two-node.yaml +++ b/docker/two-node.yaml @@ -3,9 +3,14 @@ teleport: nodename: node-on-second-cluster auth_servers: ["two-auth"] auth_token: foo + advertise_ip: 172.10.1.4 log: output: /var/lib/teleport/teleport.log severity: INFO + data_dir: /var/lib/teleport + storage: + path: /var/lib/teleport/backend + type: dir ssh_service: enabled: yes diff --git a/docker/two-proxy.yaml b/docker/two-proxy.yaml index 4dc9ec3674568..a8d9edb67df0b 100644 --- a/docker/two-proxy.yaml +++ b/docker/two-proxy.yaml @@ -6,6 +6,10 @@ teleport: log: output: /var/lib/teleport/teleport.log severity: INFO + data_dir: /var/lib/teleport + storage: + path: /var/lib/teleport/backend + type: dir auth_service: enabled: no diff --git a/docker/two-tc.yaml b/docker/two-tc.yaml index c0c7d8d1b3ca1..9921e611c429d 100644 --- a/docker/two-tc.yaml +++ b/docker/two-tc.yaml @@ -1,12 +1,9 @@ kind: trusted_cluster -version: v1 +version: v2 metadata: - description: "" name: "one" - namespace: "default" spec: enabled: true - roles: ["admin"] token: "bar" tunnel_addr: one:3024 web_proxy_addr: one:3080 diff --git a/lib/auth/api.go b/lib/auth/api.go index b231231079516..60dbad345604c 100644 --- a/lib/auth/api.go +++ b/lib/auth/api.go @@ -60,4 +60,13 @@ type AccessPoint interface { // GetRoles returns a list of roles GetRoles() ([]services.Role, error) + + // UpsertTunnelConnection upserts tunnel connection + UpsertTunnelConnection(conn services.TunnelConnection) error + + // GetTunnelConnections returns tunnel connections for a given cluster + GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error) + + // GetAllTunnelConnections returns all tunnel connections + GetAllTunnelConnections() ([]services.TunnelConnection, error) } diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 7d197d5199afa..c1b0ebc2e9810 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -102,6 +102,11 @@ func NewAPIServer(config *APIConfig) http.Handler { srv.GET("/:version/authservers", srv.withAuth(srv.getAuthServers)) 
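+	// Tunnel connections are heartbeat records that every proxy publishes
+	// for the reverse tunnels it terminates; the routes registered below
+	// expose their CRUD surface:
+	//
+	//   POST   /:version/tunnelconnections           upsert one record
+	//   GET    /:version/tunnelconnections            list records for all clusters
+	//   GET    /:version/tunnelconnections/:cluster   list records for one cluster
+	//   DELETE /:version/tunnelconnections            delete all records
+	//   DELETE /:version/tunnelconnections/:cluster   delete one cluster's records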
srv.POST("/:version/proxies", srv.withAuth(srv.upsertProxy)) srv.GET("/:version/proxies", srv.withAuth(srv.getProxies)) + srv.POST("/:version/tunnelconnections", srv.withAuth(srv.upsertTunnelConnection)) + srv.GET("/:version/tunnelconnections/:cluster", srv.withAuth(srv.getTunnelConnections)) + srv.GET("/:version/tunnelconnections", srv.withAuth(srv.getAllTunnelConnections)) + srv.DELETE("/:version/tunnelconnections/:cluster", srv.withAuth(srv.deleteTunnelConnections)) + srv.DELETE("/:version/tunnelconnections", srv.withAuth(srv.deleteAllTunnelConnections)) // Reverse tunnels srv.POST("/:version/reversetunnels", srv.withAuth(srv.upsertReverseTunnel)) @@ -1734,6 +1739,78 @@ func (s *APIServer) setClusterAuthPreference(auth ClientI, w http.ResponseWriter return message(fmt.Sprintf("cluster authenticaton preference set: %+v", cap)), nil } +type upsertTunnelConnectionRawReq struct { + TunnelConnection json.RawMessage `json:"tunnel_connection"` +} + +// upsertTunnelConnection updates or inserts tunnel connection +func (s *APIServer) upsertTunnelConnection(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + var req upsertTunnelConnectionRawReq + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + conn, err := services.UnmarshalTunnelConnection(req.TunnelConnection) + if err != nil { + return nil, trace.Wrap(err) + } + if err := auth.UpsertTunnelConnection(conn); err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + +// getTunnelConnections returns a list of tunnel connections from a cluster +func (s *APIServer) getTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + conns, err := auth.GetTunnelConnections(p.ByName("cluster")) + if err != nil { + return nil, trace.Wrap(err) + } + items := make([]json.RawMessage, len(conns)) + for i, conn := range conns { + data, err := services.MarshalTunnelConnection(conn, services.WithVersion(version)) + if err != nil { + return nil, trace.Wrap(err) + } + items[i] = data + } + return items, nil +} + +// getAllTunnelConnections returns a list of tunnel connections from a cluster +func (s *APIServer) getAllTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + conns, err := auth.GetAllTunnelConnections() + if err != nil { + return nil, trace.Wrap(err) + } + items := make([]json.RawMessage, len(conns)) + for i, conn := range conns { + data, err := services.MarshalTunnelConnection(conn, services.WithVersion(version)) + if err != nil { + return nil, trace.Wrap(err) + } + items[i] = data + } + return items, nil +} + +// deleteTunnelConnections deletes all tunnel connections for cluster +func (s *APIServer) deleteTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + err := auth.DeleteTunnelConnections(p.ByName("cluster")) + if err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + +// deleteAllTunnelConnections deletes all tunnel connections +func (s *APIServer) deleteAllTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { + err := auth.DeleteAllTunnelConnections() + if err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + func message(msg string) map[string]interface{} { return 
map[string]interface{}{"message": msg} } diff --git a/lib/auth/apiserver_test.go b/lib/auth/apiserver_test.go index 2b15ab62f84bc..21f60c8dbae16 100644 --- a/lib/auth/apiserver_test.go +++ b/lib/auth/apiserver_test.go @@ -210,6 +210,13 @@ func (s *APISuite) TestReadOwnRole(c *C) { c.Assert(err, NotNil) } +func (s *APISuite) TestTunnelConnectionsCRUD(c *C) { + suite := &suite.ServicesTestSuite{ + PresenceS: s.clt, + } + suite.TunnelConnectionsCRUD(c) +} + func (s *APISuite) TestGenerateKeysAndCerts(c *C) { priv, pub, err := s.clt.GenerateKeyPair("") c.Assert(err, IsNil) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e8f4fd634384d..53880970322d2 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -419,8 +419,7 @@ func (a *AuthWithRoles) GenerateKeyPair(pass string) ([]byte, []byte, error) { } func (a *AuthWithRoles) GenerateHostCert( - key []byte, hostID, nodeName, clusterName string, roles teleport.Roles, - ttl time.Duration) ([]byte, error) { + key []byte, hostID, nodeName, clusterName string, roles teleport.Roles, ttl time.Duration) ([]byte, error) { if err := a.action(defaults.Namespace, services.KindHostCert, services.VerbCreate); err != nil { return nil, trace.Wrap(err) @@ -919,6 +918,50 @@ func (a *AuthWithRoles) DeleteTrustedCluster(name string) error { return a.authServer.DeleteTrustedCluster(name) } +func (a *AuthWithRoles) UpsertTunnelConnection(conn services.TunnelConnection) error { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbCreate); err != nil { + return trace.Wrap(err) + } + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbUpdate); err != nil { + return trace.Wrap(err) + } + return a.authServer.UpsertTunnelConnection(conn) +} + +func (a *AuthWithRoles) GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error) { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil { + return nil, trace.Wrap(err) + } + return a.authServer.GetTunnelConnections(clusterName) +} + +func (a *AuthWithRoles) GetAllTunnelConnections() ([]services.TunnelConnection, error) { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil { + return nil, trace.Wrap(err) + } + return a.authServer.GetAllTunnelConnections() +} + +func (a *AuthWithRoles) DeleteTunnelConnections(clusterName string) error { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil { + return trace.Wrap(err) + } + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbDelete); err != nil { + return trace.Wrap(err) + } + return a.authServer.DeleteTunnelConnections(clusterName) +} + +func (a *AuthWithRoles) DeleteAllTunnelConnections() error { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil { + return trace.Wrap(err) + } + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbDelete); err != nil { + return trace.Wrap(err) + } + return a.authServer.DeleteAllTunnelConnections() +} + func (a *AuthWithRoles) Close() error { return a.authServer.Close() } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 675312d2ecf6a..d4a3c04048823 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -471,6 +471,79 @@ func (c *Client) DeleteReverseTunnel(domainName string) error { return trace.Wrap(err) } +// UpsertTunnelConnection upserts tunnel connection 
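+// by marshaling it to JSON and POSTing it to the /tunnelconnections
+// endpoint handled by upsertTunnelConnection on the auth server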
+func (c *Client) UpsertTunnelConnection(conn services.TunnelConnection) error { + data, err := services.MarshalTunnelConnection(conn) + if err != nil { + return trace.Wrap(err) + } + args := &upsertTunnelConnectionRawReq{ + TunnelConnection: data, + } + _, err = c.PostJSON(c.Endpoint("tunnelconnections"), args) + return trace.Wrap(err) +} + +// GetTunnelConnections returns tunnel connections for a given cluster +func (c *Client) GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name parameter") + } + out, err := c.Get(c.Endpoint("tunnelconnections", clusterName), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]services.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// GetAllTunnelConnections returns all tunnel connections +func (c *Client) GetAllTunnelConnections() ([]services.TunnelConnection, error) { + out, err := c.Get(c.Endpoint("tunnelconnections"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]services.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// DeleteTunnelConnections deletes all tunnel connections for cluster +func (c *Client) DeleteTunnelConnections(clusterName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + _, err := c.Delete(c.Endpoint("tunnelconnections", clusterName)) + return trace.Wrap(err) +} + +// DeleteAllTunnelConnections deletes all tunnel connections +func (c *Client) DeleteAllTunnelConnections() error { + _, err := c.Delete(c.Endpoint("tunnelconnections")) + return trace.Wrap(err) +} + // UpsertAuthServer is used by auth servers to report their presense // to other auth servers in form of hearbeat expiring after ttl period. 
func (c *Client) UpsertAuthServer(s services.Server) error { diff --git a/lib/auth/permissions.go b/lib/auth/permissions.go index 3f8ba1f0bc1e1..f7c060db30758 100644 --- a/lib/auth/permissions.go +++ b/lib/auth/permissions.go @@ -204,6 +204,7 @@ func GetCheckerForBuiltinRole(role teleport.Role) (services.AccessChecker, error services.NewRule(services.KindClusterAuthPreference, services.RO()), services.NewRule(services.KindClusterName, services.RO()), services.NewRule(services.KindStaticTokens, services.RO()), + services.NewRule(services.KindTunnelConnection, services.RW()), }, }, }) diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 27a8e88fdee89..189632fba4c04 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -151,6 +151,10 @@ var ( // ReverseTunnelAgentHeartbeatPeriod is the period between agent heartbeat messages ReverseTunnelAgentHeartbeatPeriod = 5 * time.Second + // ReverseTunnelOfflineThreshold is the threshold of missed heartbeats + // after which we are going to declare the reverse tunnel offline + ReverseTunnelOfflineThreshold = 5 * ReverseTunnelAgentHeartbeatPeriod + // ServerHeartbeatTTL is a period between heartbeats // Median sleep time between node pings is this value / 2 + random // deviation added to this time to avoid lots of simultaneous diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go new file mode 100644 index 0000000000000..4e6b26ed42b17 --- /dev/null +++ b/lib/reversetunnel/peer.go @@ -0,0 +1,182 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +*/ + +package reversetunnel + +import ( + "fmt" + "net" + "time" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" + + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" +) + +func newClusterPeers(clusterName string) *clusterPeers { + return &clusterPeers{ + clusterName: clusterName, + peers: make(map[string]*clusterPeer), + } +} + +// clusterPeers is a collection of cluster peers to a given cluster +type clusterPeers struct { + clusterName string + peers map[string]*clusterPeer +} + +func (p *clusterPeers) pickPeer() (*clusterPeer, error) { + var currentPeer *clusterPeer + for _, peer := range p.peers { + if currentPeer == nil || peer.connInfo.GetLastHeartbeat().After(currentPeer.connInfo.GetLastHeartbeat()) { + currentPeer = peer + } + } + if currentPeer == nil { + return nil, trace.NotFound("no active peers found for %v") + } + return currentPeer, nil +} + +func (p *clusterPeers) updatePeer(conn services.TunnelConnection) bool { + peer, ok := p.peers[conn.GetName()] + if !ok { + return false + } + peer.connInfo = conn + return true +} + +func (p *clusterPeers) addPeer(peer *clusterPeer) { + p.peers[peer.connInfo.GetName()] = peer +} + +func (p *clusterPeers) removePeer(connInfo services.TunnelConnection) { + delete(p.peers, connInfo.GetName()) +} + +func (p *clusterPeers) CachingAccessPoint() (auth.AccessPoint, error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.CachingAccessPoint() +} + +func (p *clusterPeers) GetClient() (auth.ClientI, error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.GetClient() +} + +func (p *clusterPeers) String() string { + return fmt.Sprintf("clusterPeer(%v)", p.clusterName) +} + +func (p *clusterPeers) GetStatus() string { + peer, err := p.pickPeer() + if err != nil { + return RemoteSiteStatusOffline + } + return peer.GetStatus() +} + +func (p *clusterPeers) GetName() string { + return p.clusterName +} + +func (p *clusterPeers) GetLastConnected() time.Time { + peer, err := p.pickPeer() + if err != nil { + return time.Time{} + } + return peer.GetLastConnected() +} + +// Dial is used to connect a requesting client (say, tsh) to an SSH server +// located in a remote connected site, the connection goes through the +// reverse proxy tunnel. 
+func (p *clusterPeers) Dial(from, to net.Addr) (conn net.Conn, err error) { + return nil, trace.ConnectionProblem(nil, "lost connection to reverse tunnel") +} + +// newClusterPeer returns new cluster peer +func newClusterPeer(srv *server, connInfo services.TunnelConnection) (*clusterPeer, error) { + clusterPeer := &clusterPeer{ + srv: srv, + connInfo: connInfo, + log: log.WithFields(log.Fields{ + teleport.Component: teleport.ComponentReverseTunnel, + teleport.ComponentFields: map[string]string{ + "cluster": connInfo.GetClusterName(), + "side": "server", + }, + }), + } + + return clusterPeer, nil +} + +// clusterPeer is a remote cluster that has established +// a tunnel to the peers +type clusterPeer struct { + log *log.Entry + connInfo services.TunnelConnection + srv *server +} + +func (s *clusterPeer) CachingAccessPoint() (auth.AccessPoint, error) { + return nil, trace.ConnectionProblem(nil, "lost connection to reverse tunnel") +} + +func (s *clusterPeer) GetClient() (auth.ClientI, error) { + return nil, trace.ConnectionProblem(nil, "lost connection to reverse tunnel") +} + +func (s *clusterPeer) String() string { + return fmt.Sprintf("clusterPeer(%v)", s.connInfo) +} + +func (s *clusterPeer) GetStatus() string { + diff := time.Now().Sub(s.connInfo.GetLastHeartbeat()) + if diff > defaults.ReverseTunnelOfflineThreshold { + return RemoteSiteStatusOffline + } + return RemoteSiteStatusOnline +} + +func (s *clusterPeer) GetName() string { + return s.connInfo.GetClusterName() +} + +func (s *clusterPeer) GetLastConnected() time.Time { + return s.connInfo.GetLastHeartbeat() +} + +// Dial is used to connect a requesting client (say, tsh) to an SSH server +// located in a remote connected site, the connection goes through the +// reverse proxy tunnel. +func (s *clusterPeer) Dial(from, to net.Addr) (conn net.Conn, err error) { + return nil, trace.ConnectionProblem(nil, "lost connection to remote proxy") +} diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 6e95f3654e626..46be0a0e9e5a1 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/roundtrip" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" @@ -37,7 +38,7 @@ import ( "golang.org/x/crypto/ssh" ) -// remoteSite is a remote site who established the inbound connecton to +// remoteSite is a remote site that established the inbound connecton to // the local reverse tunnel server, and now it can provide access to the // cluster behind it. 
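+// Besides live SSH connections, remoteSite now keeps a
+// services.TunnelConnection record (connInfo) that it heartbeats to the
+// auth server, letting peer proxies discover this tunnel.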
type remoteSite struct { @@ -53,6 +54,7 @@ type remoteSite struct { transport *http.Transport clt *auth.Client accessPoint auth.AccessPoint + connInfo services.TunnelConnection } func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) { @@ -109,37 +111,60 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro return rc, nil } +func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, error) { + conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) + if err != nil { + s.log.Warningf("[TUNNEL] failed to fetch tunnel statuses: %v", err) + return nil, trace.Wrap(err) + } + var lastConn services.TunnelConnection + for i := range conns { + conn := conns[i] + if lastConn == nil || conn.GetLastHeartbeat().After(lastConn.GetLastHeartbeat()) { + lastConn = conn + } + } + if lastConn == nil { + return nil, trace.NotFound("no connections from %v found in the cluster", s.domainName) + } + return lastConn, nil +} + func (s *remoteSite) GetStatus() string { - s.Lock() - defer s.Unlock() - diff := time.Now().Sub(s.lastActive) - if diff > 2*defaults.ReverseTunnelAgentHeartbeatPeriod { + connInfo, err := s.getLatestTunnelConnection() + if err != nil { + return RemoteSiteStatusOffline + } + diff := time.Now().Sub(connInfo.GetLastHeartbeat()) + if diff > defaults.ReverseTunnelOfflineThreshold { return RemoteSiteStatusOffline } return RemoteSiteStatusOnline } -func (s *remoteSite) setLastActive(t time.Time) { - s.Lock() - defer s.Unlock() - s.lastActive = t +func (s *remoteSite) registerHeartbeat(t time.Time) { + s.connInfo.SetLastHeartbeat(t) + err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) + if err != nil { + log.Warningf("[TUNNEL] failed to register heartbeat: %v", err) + } } func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { defer func() { - s.log.Infof("[TUNNEL] site connection closed: %v", s.domainName) + s.log.Infof("[TUNNEL] cluster connection closed: %v", s.domainName) conn.Close() }() for { select { case req := <-reqC: if req == nil { - s.log.Infof("[TUNNEL] site disconnected: %v", s.domainName) + s.log.Infof("[TUNNEL] cluster disconnected: %v", s.domainName) conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr()) - s.setLastActive(time.Now()) + go s.registerHeartbeat(time.Now()) case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod): conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats")) } @@ -151,9 +176,11 @@ func (s *remoteSite) GetName() string { } func (s *remoteSite) GetLastConnected() time.Time { - s.Lock() - defer s.Unlock() - return s.lastActive + connInfo, err := s.getLatestTunnelConnection() + if err != nil { + return time.Time{} + } + return connInfo.GetLastHeartbeat() } // dialAccessPoint establishes a connection from the proxy (reverse tunnel server) @@ -246,7 +273,7 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { // didn't connect and no error? this means we didn't have any connected // tunnels to try if err == nil { - err = trace.Errorf("%v is offline", s.GetName()) + err = trace.ConnectionProblem(nil, "%v is offline", s.GetName()) } return nil, err } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 0cddf15a8a721..7ff735f6486de 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -18,12 +18,14 @@ limitations under the License. 
package reversetunnel import ( + "context" "fmt" "net" "net/http" "strings" "sync" "sync/atomic" + "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" @@ -44,6 +46,7 @@ import ( // (also known as 'reverse tunnel agents'. type server struct { sync.RWMutex + Config // localAuth points to the cluster's auth server API localAuth auth.AccessPoint @@ -61,66 +64,108 @@ type server struct { // usually each of them is a local proxy. localSites []*localSite + // clusterPeers is a map of clusters connected to peer proxies + // via reverse tunnels + clusterPeers map[string]*clusterPeers + // newAccessPoint returns new caching access point newAccessPoint state.NewCachingAccessPoint + + // cancel function will cancel the + cancel context.CancelFunc + + // ctx is a context used for signalling and broadcast + ctx context.Context } -// ServerOption sets reverse tunnel server options -type ServerOption func(s *server) error +// DirectCluster is used to access cluster directly +type DirectCluster struct { + // Name is a cluster name + Name string + // Client is a client to the cluster + Client auth.ClientI +} + +// Config is a reverse tunnel server configuration +type Config struct { + // ID is the ID of this server proxy + ID string + // ListenAddr is a listening address for reverse tunnel server + ListenAddr utils.NetAddr + // HostSigners is a list of host signers + HostSigners []ssh.Signer + // HostKeyCallback + // Limiter is optional request limiter + Limiter *limiter.Limiter + // AccessPoint is access point + AccessPoint auth.AccessPoint + // NewCachingAccessPoint returns new caching access points + // per remote cluster + NewCachingAccessPoint state.NewCachingAccessPoint + // DirectClusters is a list of clusters accessed directly + DirectClusters []DirectCluster + // Context is a signalling context + Context context.Context +} -// DirectSite instructs server to proxy access to this site not using -// reverse tunnel -func DirectSite(domainName string, clt auth.ClientI) ServerOption { - return func(s *server) error { - site, err := newlocalSite(s, domainName, clt) +// CheckAndSetDefaults checks parameters and sets default values +func (cfg *Config) CheckAndSetDefaults() error { + if cfg.ID == "" { + return trace.BadParameter("missing parameter ID") + } + if cfg.ListenAddr.IsEmpty() { + return trace.BadParameter("missing parameter ListenAddr") + } + if cfg.Context == nil { + cfg.Context = context.TODO() + } + if cfg.Limiter == nil { + var err error + cfg.Limiter, err = limiter.NewLimiter(limiter.LimiterConfig{}) if err != nil { return trace.Wrap(err) } - s.localSites = append(s.localSites, site) - return nil - } -} - -// SetLimiter sets rate limiter for reverse tunnel -func SetLimiter(limiter *limiter.Limiter) ServerOption { - return func(s *server) error { - s.limiter = limiter - return nil } + return nil } // NewServer creates and returns a reverse tunnel server which is fully // initialized but hasn't been started yet -func NewServer(addr utils.NetAddr, hostSigners []ssh.Signer, - authAPI auth.AccessPoint, fn state.NewCachingAccessPoint, opts ...ServerOption) (Server, error) { - +func NewServer(cfg Config) (Server, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + ctx, cancel := context.WithCancel(cfg.Context) srv := &server{ + Config: cfg, localSites: []*localSite{}, remoteSites: []*remoteSite{}, - localAuth: authAPI, - newAccessPoint: fn, - } - var err error - srv.limiter, err = 
limiter.NewLimiter(limiter.LimiterConfig{}) - if err != nil { - return nil, trace.Wrap(err) + localAuth: cfg.AccessPoint, + newAccessPoint: cfg.NewCachingAccessPoint, + limiter: cfg.Limiter, + ctx: ctx, + cancel: cancel, + clusterPeers: make(map[string]*clusterPeers), } - for _, o := range opts { - if err := o(srv); err != nil { + for _, clusterInfo := range cfg.DirectClusters { + cluster, err := newlocalSite(srv, clusterInfo.Name, clusterInfo.Client) + if err != nil { return nil, trace.Wrap(err) } + srv.localSites = append(srv.localSites, cluster) } + var err error s, err := sshutils.NewServer( teleport.ComponentReverseTunnel, - addr, + cfg.ListenAddr, srv, - hostSigners, + cfg.HostSigners, sshutils.AuthMethods{ PublicKey: srv.keyAuth, }, - sshutils.SetLimiter(srv.limiter), + sshutils.SetLimiter(cfg.Limiter), ) if err != nil { return nil, err @@ -128,9 +173,143 @@ func NewServer(addr utils.NetAddr, hostSigners []ssh.Signer, srv.hostCertChecker = ssh.CertChecker{IsAuthority: srv.isHostAuthority} srv.userCertChecker = ssh.CertChecker{IsAuthority: srv.isUserAuthority} srv.srv = s + go srv.periodicFetchClusterPeers() return srv, nil } +func (s *server) periodicFetchClusterPeers() { + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) + defer ticker.Stop() + if err := s.fetchClusterPeers(); err != nil { + log.Warningf("[TUNNEL] failed to fetch cluster peers: %v", err) + } + for { + select { + case <-s.ctx.Done(): + log.Debugf("[TUNNEL] closing") + return + case <-ticker.C: + err := s.fetchClusterPeers() + if err != nil { + log.Warningf("[TUNNEL] failed to fetch cluster peers: %v", err) + } + } + } +} + +func (s *server) fetchClusterPeers() error { + conns, err := s.AccessPoint.GetAllTunnelConnections() + if err != nil { + return trace.Wrap(err) + } + newConns := make(map[string]services.TunnelConnection) + for i := range conns { + newConn := conns[i] + // filter out peer records for our own proxy + if newConn.GetProxyName() == s.ID { + continue + } + newConns[newConn.GetName()] = newConn + } + existingConns := s.existingConns() + connsToAdd, connsToUpdate, connsToRemove := s.diffConns(newConns, existingConns) + s.removeClusterPeers(connsToRemove) + s.updateClusterPeers(connsToUpdate) + return s.addClusterPeers(connsToAdd) +} + +func (s *server) addClusterPeers(conns map[string]services.TunnelConnection) error { + for key := range conns { + connInfo := conns[key] + peer, err := newClusterPeer(s, connInfo) + if err != nil { + return trace.Wrap(err) + } + s.addClusterPeer(peer) + } + return nil +} + +func (s *server) updateClusterPeers(conns map[string]services.TunnelConnection) { + for key := range conns { + connInfo := conns[key] + s.updateClusterPeer(connInfo) + } +} + +func (s *server) addClusterPeer(peer *clusterPeer) { + s.Lock() + defer s.Unlock() + clusterName := peer.connInfo.GetClusterName() + peers, ok := s.clusterPeers[clusterName] + if !ok { + peers = newClusterPeers(clusterName) + s.clusterPeers[clusterName] = peers + } + peers.addPeer(peer) +} + +func (s *server) updateClusterPeer(conn services.TunnelConnection) bool { + s.Lock() + defer s.Unlock() + clusterName := conn.GetClusterName() + peers, ok := s.clusterPeers[clusterName] + if !ok { + return false + } + return peers.updatePeer(conn) +} + +func (s *server) removeClusterPeers(conns []services.TunnelConnection) { + s.Lock() + defer s.Unlock() + for _, conn := range conns { + peers, ok := s.clusterPeers[conn.GetClusterName()] + if !ok { + log.Warningf("[TUNNEL] failed to remove cluster peer, not found peers for 
%v", conn) + continue + } + peers.removePeer(conn) + log.Debugf("[TUNNEL] removed cluster peer %v", conn) + } +} + +func (s *server) existingConns() map[string]services.TunnelConnection { + s.RLock() + defer s.RUnlock() + conns := make(map[string]services.TunnelConnection) + for _, peers := range s.clusterPeers { + for _, cluster := range peers.peers { + conns[cluster.connInfo.GetName()] = cluster.connInfo + } + } + return conns +} + +func (s *server) diffConns(newConns, existingConns map[string]services.TunnelConnection) (map[string]services.TunnelConnection, map[string]services.TunnelConnection, []services.TunnelConnection) { + connsToAdd := make(map[string]services.TunnelConnection) + connsToUpdate := make(map[string]services.TunnelConnection) + var connsToRemove []services.TunnelConnection + + for existingKey := range existingConns { + conn := existingConns[existingKey] + if _, ok := newConns[existingKey]; !ok { // tunnel was removed + connsToRemove = append(connsToRemove, conn) + } + } + + for newKey := range newConns { + conn := newConns[newKey] + if _, ok := existingConns[newKey]; !ok { // tunnel was added + connsToAdd[newKey] = conn + } else { + connsToUpdate[newKey] = conn + } + } + + return connsToAdd, connsToUpdate, connsToRemove +} + func (s *server) Wait() { s.srv.Wait() } @@ -140,6 +319,7 @@ func (s *server) Start() error { } func (s *server) Close() error { + s.cancel() return s.srv.Close() } @@ -353,38 +533,56 @@ func (s *server) upsertSite(conn net.Conn, sshConn *ssh.ServerConn) (*remoteSite } s.remoteSites = append(s.remoteSites, site) } - log.Infof("[TUNNEL] site %v connected from %v. sites: %d", + log.Infof("[TUNNEL] cluster %v connected from %v. clusters: %d", domainName, conn.RemoteAddr(), len(s.remoteSites)) + // treat first connection as a registered heartbeat, + // otherwise the connection information will appear after initial + // heartbeat delay + go site.registerHeartbeat(time.Now()) return site, remoteConn, nil } func (s *server) GetSites() []RemoteSite { s.RLock() defer s.RUnlock() - out := make([]RemoteSite, 0, len(s.remoteSites)+len(s.localSites)) + out := make([]RemoteSite, 0, len(s.remoteSites)+len(s.localSites)+len(s.clusterPeers)) for i := range s.localSites { out = append(out, s.localSites[i]) } + haveLocalConnection := make(map[string]bool) for i := range s.remoteSites { - out = append(out, s.remoteSites[i]) + site := s.remoteSites[i] + haveLocalConnection[site.GetName()] = true + out = append(out, site) + } + for i := range s.clusterPeers { + cluster := s.clusterPeers[i] + if _, ok := haveLocalConnection[cluster.GetName()]; !ok { + out = append(out, cluster) + } } return out } -func (s *server) GetSite(domainName string) (RemoteSite, error) { +func (s *server) GetSite(name string) (RemoteSite, error) { s.RLock() defer s.RUnlock() for i := range s.remoteSites { - if s.remoteSites[i].domainName == domainName { + if s.remoteSites[i].GetName() == name { return s.remoteSites[i], nil } } for i := range s.localSites { - if s.localSites[i].domainName == domainName { + if s.localSites[i].GetName() == name { return s.localSites[i], nil } } - return nil, trace.NotFound("site '%v' not found", domainName) + for i := range s.clusterPeers { + if s.clusterPeers[i].GetName() == name { + return s.clusterPeers[i], nil + } + } + return nil, trace.NotFound("cluster %q is not found", name) } func (s *server) RemoveSite(domainName string) error { @@ -402,7 +600,7 @@ func (s *server) RemoveSite(domainName string) error { return nil } } - return trace.NotFound("site '%v' not 
found", domainName) + return trace.NotFound("cluster %q is not found", domainName) } type remoteConn struct { @@ -431,9 +629,22 @@ func (rc *remoteConn) isInvalid() bool { // newRemoteSite helper creates and initializes 'remoteSite' instance func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { + connInfo, err := services.NewTunnelConnection( + fmt.Sprintf("%v-%v", srv.ID, domainName), + services.TunnelConnectionSpecV2{ + ClusterName: domainName, + ProxyName: srv.ID, + LastHeartbeat: time.Now().UTC(), + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + remoteSite := &remoteSite{ srv: srv, domainName: domainName, + connInfo: connInfo, log: log.WithFields(log.Fields{ teleport.Component: teleport.ComponentReverseTunnel, teleport.ComponentFields: map[string]string{ diff --git a/lib/service/service.go b/lib/service/service.go index 4c699326625a1..ff98c8f75e1f7 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -689,14 +689,20 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { } tsrv, err := reversetunnel.NewServer( - cfg.Proxy.ReverseTunnelListenAddr, - []ssh.Signer{conn.Identity.KeySigner}, - authClient, - process.newLocalCache, - reversetunnel.SetLimiter(reverseTunnelLimiter), - reversetunnel.DirectSite(conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], - conn.Client), - ) + reversetunnel.Config{ + ID: conn.Identity.ID.HostUUID, + ListenAddr: cfg.Proxy.ReverseTunnelListenAddr, + HostSigners: []ssh.Signer{conn.Identity.KeySigner}, + AccessPoint: authClient, + NewCachingAccessPoint: process.newLocalCache, + Limiter: reverseTunnelLimiter, + DirectClusters: []reversetunnel.DirectCluster{ + { + Name: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], + Client: conn.Client, + }, + }, + }) if err != nil { return trace.Wrap(err) } diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index f16d300e0f154..0b7577db89d5a 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -336,11 +336,105 @@ func (s *PresenceService) DeleteTrustedCluster(name string) error { return trace.Wrap(err) } +// UpsertTunnelConnection updates or creates tunnel connection +func (s *PresenceService) UpsertTunnelConnection(conn services.TunnelConnection) error { + if err := conn.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + bytes, err := services.MarshalTunnelConnection(conn) + if err != nil { + return trace.Wrap(err) + } + metadata := conn.GetMetadata() + ttl := backend.TTL(s.Clock(), metadata.Expiry()) + err = s.UpsertVal([]string{tunnelConnectionsPrefix, conn.GetClusterName()}, conn.GetName(), bytes, ttl) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// GetTunnelConnection returns connection by cluster name and connection name +func (s *PresenceService) GetTunnelConnection(clusterName, connectionName string) (services.TunnelConnection, error) { + data, err := s.GetVal([]string{tunnelConnectionsPrefix, clusterName}, connectionName) + if err != nil { + if trace.IsNotFound(err) { + return nil, trace.NotFound("trusted cluster connection %q is not found", connectionName) + } + return nil, trace.Wrap(err) + } + return services.UnmarshalTunnelConnection(data) +} + +// GetTunnelConnections returns connections for a trusted cluster +func (s *PresenceService) GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name") + } + var conns 
[]services.TunnelConnection + keys, err := s.GetKeys([]string{tunnelConnectionsPrefix, clusterName}) + if err != nil { + if trace.IsNotFound(err) { + return nil, nil + } + return nil, trace.Wrap(err) + } + for _, key := range keys { + conn, err := s.GetTunnelConnection(clusterName, key) + if err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + } + conns = append(conns, conn) + } + return conns, nil +} + +// GetAllTunnelConnections returns all tunnel connections +func (s *PresenceService) GetAllTunnelConnections() ([]services.TunnelConnection, error) { + var conns []services.TunnelConnection + clusters, err := s.GetKeys([]string{tunnelConnectionsPrefix}) + if err != nil { + if trace.IsNotFound(err) { + return nil, nil + } + return nil, trace.Wrap(err) + } + for _, clusterName := range clusters { + clusterConns, err := s.GetTunnelConnections(clusterName) + if err != nil { + return nil, trace.Wrap(err) + } + conns = append(conns, clusterConns...) + } + return conns, nil +} + +// DeleteTunnelConnections deletes all tunnel connections for cluster +func (s *PresenceService) DeleteTunnelConnections(clusterName string) error { + err := s.DeleteBucket([]string{tunnelConnectionsPrefix}, clusterName) + if trace.IsNotFound(err) { + return nil + } + return err +} + +// DeleteAllTunnelConnections deletes all tunnel connections +func (s *PresenceService) DeleteAllTunnelConnections() error { + err := s.DeleteBucket([]string{}, tunnelConnectionsPrefix) + if trace.IsNotFound(err) { + return nil + } + return err +} + const ( - localClusterPrefix = "localCluster" - reverseTunnelsPrefix = "reverseTunnels" - nodesPrefix = "nodes" - namespacesPrefix = "namespaces" - authServersPrefix = "authservers" - proxiesPrefix = "proxies" + localClusterPrefix = "localCluster" + reverseTunnelsPrefix = "reverseTunnels" + tunnelConnectionsPrefix = "tunnelConnections" + nodesPrefix = "nodes" + namespacesPrefix = "namespaces" + authServersPrefix = "authservers" + proxiesPrefix = "proxies" ) diff --git a/lib/services/local/services_test.go b/lib/services/local/services_test.go index 76558851344e0..aace2fd8a8df7 100644 --- a/lib/services/local/services_test.go +++ b/lib/services/local/services_test.go @@ -105,3 +105,7 @@ func (s *ServicesSuite) TestU2FCRUD(c *C) { func (s *ServicesSuite) TestSAMLCRUD(c *C) { s.suite.SAMLCRUD(c) } + +func (s *ServicesSuite) TestTunnelConnectionsCRUD(c *C) { + s.suite.TunnelConnectionsCRUD(c) +} diff --git a/lib/services/presence.go b/lib/services/presence.go index 273591baee341..dcf599d54f8e1 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -95,6 +95,21 @@ type Presence interface { // DeleteTrustedCluster removes a TrustedCluster from the backend by name. 
DeleteTrustedCluster(string) error + + // UpsertTunnelConnection upserts tunnel connection + UpsertTunnelConnection(TunnelConnection) error + + // GetTunnelConnections returns tunnel connections for a given cluster + GetTunnelConnections(clusterName string) ([]TunnelConnection, error) + + // GetAllTunnelConnections returns all tunnel connections + GetAllTunnelConnections() ([]TunnelConnection, error) + + // DeleteTunnelConnections deletes all tunnel connections for cluster + DeleteTunnelConnections(clusterName string) error + + // DeleteAllTunnelConnections deletes all tunnel connections for cluster + DeleteAllTunnelConnections() error } // NewNamespace returns new namespace diff --git a/lib/services/resource.go b/lib/services/resource.go index 3e4cbba72d94d..8b27258bb084a 100644 --- a/lib/services/resource.go +++ b/lib/services/resource.go @@ -131,6 +131,9 @@ const ( // KindAuthConnector allows access to OIDC and SAML connectors. KindAuthConnector = "auth_connector" + // KindTunnelConection specifies connection of a reverse tunnel to proxy + KindTunnelConnection = "tunnel_connection" + // V3 is the third version of resources. V3 = "v3" diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 0af98a7c95ca1..54973f8f32969 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -552,3 +552,52 @@ func (s *ServicesTestSuite) SAMLCRUD(c *C) { _, err = s.WebS.GetSAMLConnector(connector.GetName(), true) c.Assert(trace.IsNotFound(err), Equals, true, Commentf("expected not found, got %T", err)) } + +func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *C) { + clusterName := "example.com" + out, err := s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 0) + + dt := time.Date(2015, 6, 5, 4, 3, 2, 1, time.UTC).UTC() + conn, err := services.NewTunnelConnection("conn1", services.TunnelConnectionSpecV2{ + ClusterName: clusterName, + ProxyName: "p1", + LastHeartbeat: dt, + }) + c.Assert(err, IsNil) + + err = s.PresenceS.UpsertTunnelConnection(conn) + c.Assert(err, IsNil) + + out, err = s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 1) + fixtures.DeepCompare(c, out[0], conn) + + out, err = s.PresenceS.GetAllTunnelConnections() + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 1) + fixtures.DeepCompare(c, out[0], conn) + + dt = dt.Add(time.Hour) + conn.SetLastHeartbeat(dt) + + err = s.PresenceS.UpsertTunnelConnection(conn) + c.Assert(err, IsNil) + + out, err = s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 1) + fixtures.DeepCompare(c, out[0], conn) + + err = s.PresenceS.DeleteAllTunnelConnections() + c.Assert(err, IsNil) + + out, err = s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 0) + + err = s.PresenceS.DeleteAllTunnelConnections() + c.Assert(err, IsNil) +} diff --git a/lib/services/trustedcluster.go b/lib/services/trustedcluster.go index 5e675b23d34ca..1d0012ffad552 100644 --- a/lib/services/trustedcluster.go +++ b/lib/services/trustedcluster.go @@ -28,7 +28,6 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" ) // TrustedCluster holds information needed for a cluster that can not be directly @@ -200,7 +199,6 @@ func (r RoleMap) Map(remoteRoles []string) ([]string, error) { outRoles = append(outRoles, wildcardMatch...) 
return outRoles, nil } - log.Debugf("%v: direct match: %v wildcard match: %v", r, directMatch, wildcardMatch) for _, remoteRole := range remoteRoles { match, ok := directMatch[remoteRole] if ok {
diff --git a/lib/services/tunnelconn.go b/lib/services/tunnelconn.go new file mode 100644 index 0000000000000..d8fd4ce6268e4 --- /dev/null +++ b/lib/services/tunnelconn.go
@@ -0,0 +1,227 @@ +package services + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" +) + +// TunnelConnection is an SSH reverse tunnel connection +// established to a reverse tunnel proxy +type TunnelConnection interface { + // Resource provides common methods for resource objects + Resource + // GetClusterName returns name of the cluster + // this connection is for + GetClusterName() string + // GetProxyName returns the proxy name this connection is established to + GetProxyName() string + // GetLastHeartbeat returns time of the last heartbeat received from + // the tunnel over the connection + GetLastHeartbeat() time.Time + // SetLastHeartbeat sets last heartbeat time + SetLastHeartbeat(time.Time) + // Check checks tunnel connection for errors + Check() error + // CheckAndSetDefaults checks and sets default values for any missing fields. + CheckAndSetDefaults() error + // String returns user-friendly representation of this connection + String() string +} + +// MustCreateTunnelConnection returns a new connection from its V2 spec or panics if +// parameters are incorrect +func MustCreateTunnelConnection(name string, spec TunnelConnectionSpecV2) TunnelConnection { + conn, err := NewTunnelConnection(name, spec) + if err != nil { + panic(err) + } + return conn +} + +// NewTunnelConnection returns a new connection from its V2 spec +func NewTunnelConnection(name string, spec TunnelConnectionSpecV2) (TunnelConnection, error) { + conn := &TunnelConnectionV2{ + Kind: KindTunnelConnection, + Version: V2, + Metadata: Metadata{ + Name: name, + Namespace: defaults.Namespace, + }, + Spec: spec, + } + if err := conn.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return conn, nil +} + +// TunnelConnectionV2 is the version 2 resource spec of the reverse tunnel connection +type TunnelConnectionV2 struct { + // Kind is a resource kind + Kind string `json:"kind"` + // Version is a resource version + Version string `json:"version"` + // Metadata is resource metadata + Metadata Metadata `json:"metadata"` + // Spec contains user specification + Spec TunnelConnectionSpecV2 `json:"spec"` +} + +// String returns user-friendly description of this connection +func (r *TunnelConnectionV2) String() string { + return fmt.Sprintf("TunnelConnection(name=%v, cluster=%v, proxy=%v)", r.Metadata.Name, r.Spec.ClusterName, r.Spec.ProxyName) +} + +// GetMetadata returns object metadata +func (r *TunnelConnectionV2) GetMetadata() Metadata { + return r.Metadata +} + +// SetExpiry sets expiry time for the object +func (r *TunnelConnectionV2) SetExpiry(expires time.Time) { + r.Metadata.SetExpiry(expires) +} + +// Expiry returns object expiry setting +func (r *TunnelConnectionV2) Expiry() time.Time { + return r.Metadata.Expiry() +} + +// SetTTL sets Expires header using realtime clock +func (r *TunnelConnectionV2) SetTTL(clock clockwork.Clock, ttl time.Duration) { + r.Metadata.SetTTL(clock, ttl) +} + +// GetName returns the name of the tunnel connection +func (r *TunnelConnectionV2) GetName() string { + return r.Metadata.Name +} + +// SetName sets the name of the tunnel connection +func (r *TunnelConnectionV2) SetName(e string) { + r.Metadata.Name = e +} + +// V2 returns V2 version of the resource +func (r *TunnelConnectionV2) V2() *TunnelConnectionV2 { + return r +} + +// CheckAndSetDefaults checks and sets default values for any missing fields +func (r *TunnelConnectionV2) CheckAndSetDefaults() error { + err := r.Metadata.CheckAndSetDefaults() + if err != nil { + return trace.Wrap(err) + } + + err = r.Check() + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +// GetClusterName returns name of the cluster +func (r *TunnelConnectionV2) GetClusterName() string { + return r.Spec.ClusterName +} + +// GetProxyName returns the name of the proxy +func (r *TunnelConnectionV2) GetProxyName() string { + return r.Spec.ProxyName +} + +// GetLastHeartbeat returns last heartbeat +func (r *TunnelConnectionV2) GetLastHeartbeat() time.Time { + return r.Spec.LastHeartbeat +} + +// SetLastHeartbeat sets last heartbeat time +func (r *TunnelConnectionV2) SetLastHeartbeat(tm time.Time) { + r.Spec.LastHeartbeat = tm +} + +// Check returns nil if all parameters are good, error otherwise +func (r *TunnelConnectionV2) Check() error { + if r.Version == "" { + return trace.BadParameter("missing version") + } + if strings.TrimSpace(r.Spec.ClusterName) == "" { + return trace.BadParameter("empty cluster name") + } + + if len(r.Spec.ProxyName) == 0 { + return trace.BadParameter("missing parameter proxy name") + } + + return nil +} + +// TunnelConnectionSpecV2 is a specification for V2 tunnel connection +type TunnelConnectionSpecV2 struct { + // ClusterName is a name of the cluster + ClusterName string `json:"cluster_name"` + // ProxyName is the name of the proxy server + ProxyName string `json:"proxy_name"` + // LastHeartbeat is a time of the last heartbeat + LastHeartbeat time.Time `json:"last_heartbeat"` +} + +// TunnelConnectionSpecV2Schema is JSON schema for tunnel connection spec +const TunnelConnectionSpecV2Schema = `{ + "type": "object", + "additionalProperties": false, + "required": ["cluster_name", "proxy_name", "last_heartbeat"], + "properties": { + "cluster_name": {"type": "string"}, + "proxy_name": {"type": "string"}, + "last_heartbeat": {"type": "string"} + } +}` + +// GetTunnelConnectionSchema returns tunnel connection schema with optionally injected +// schema for extensions +func GetTunnelConnectionSchema() string { + return fmt.Sprintf(V2SchemaTemplate, MetadataSchema, TunnelConnectionSpecV2Schema, DefaultDefinitions) +} + +// UnmarshalTunnelConnection unmarshals tunnel connection from JSON or YAML, +// sets defaults and checks the schema +func UnmarshalTunnelConnection(data []byte) (TunnelConnection, error) { + if len(data) == 0 { + return nil, trace.BadParameter("missing tunnel data") + } + var h ResourceHeader + err := json.Unmarshal(data, &h) + if err != nil { + return nil, trace.Wrap(err) + } + switch h.Version { + case V2: + var r TunnelConnectionV2 + + if err := utils.UnmarshalWithSchema(GetTunnelConnectionSchema(), &r, data); err != nil { + return nil, trace.BadParameter(err.Error()) + } + + if err := r.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &r, nil + } + return nil, trace.BadParameter("tunnel connection version %v is not supported", h.Version) +} + +// MarshalTunnelConnection marshals tunnel connection +func MarshalTunnelConnection(rt TunnelConnection, opts ...MarshalOption) ([]byte, error) { + return json.Marshal(rt) +}
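As a quick illustration of the resource introduced above, here is a minimal sketch of creating, serializing and restoring a tunnel connection. It uses only the services APIs added in this patch; the cluster and proxy names are made up:

    import (
        "fmt"
        "time"

        "github.com/gravitational/teleport/lib/services"
    )

    func tunnelConnectionRoundTrip() error {
        // NewTunnelConnection validates the spec and fills in metadata defaults
        conn, err := services.NewTunnelConnection("one-p1", services.TunnelConnectionSpecV2{
            ClusterName:   "one",
            ProxyName:     "p1",
            LastHeartbeat: time.Now().UTC(),
        })
        if err != nil {
            return err
        }
        data, err := services.MarshalTunnelConnection(conn)
        if err != nil {
            return err
        }
        // UnmarshalTunnelConnection checks the JSON schema before accepting the data
        restored, err := services.UnmarshalTunnelConnection(data)
        if err != nil {
            return err
        }
        fmt.Println(restored) // TunnelConnection(name=one-p1, cluster=one, proxy=p1)
        return nil
    }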
diff --git a/lib/srv/proxy.go b/lib/srv/proxy.go index 2d5ea6fdd6a01..cf430a85ab61e 100644 --- a/lib/srv/proxy.go +++ b/lib/srv/proxy.go
@@ -41,14 +41,14 @@ import ( // proxySubsys implements an SSH subsystem for proxying listening sockets from // remote hosts to a proxy client (AKA port mapping) type proxySubsys struct { - srv *Server - host string - port string - namespace string - siteName string - closeC chan struct{} - error error - closeOnce sync.Once + srv *Server + host string + port string + namespace string + clusterName string + closeC chan struct{} + error error + closeOnce sync.Once } // parseProxySubsys looks at the requested subsystem name and returns a fully configured
@@ -110,18 +110,18 @@ func parseProxySubsys(request string, srv *Server) (*proxySubsys, error) { } } return &proxySubsys{ - namespace: namespace, - srv: srv, - host: targetHost, - port: targetPort, - siteName: clusterName, - closeC: make(chan struct{}), + namespace: namespace, + srv: srv, + host: targetHost, + port: targetPort, + clusterName: clusterName, + closeC: make(chan struct{}), }, nil } func (t *proxySubsys) String() string { return fmt.Sprintf("proxySubsys(cluster=%s/%s, host=%s, port=%s)", - t.namespace, t.siteName, t.host, t.port) + t.namespace, t.clusterName, t.host, t.port) } // start is called by Golang's ssh when it needs to engage this subsystem (typically to establish
@@ -145,9 +145,9 @@ func (t *proxySubsys) start(sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Requ clientAddr = a } } - // get the site by name: - if t.siteName != "" { - site, err = tunnel.GetSite(t.siteName) + // get the cluster by name: + if t.clusterName != "" { + site, err = tunnel.GetSite(t.clusterName) if err != nil { log.Warn(err) return trace.Wrap(err)
@@ -299,7 +299,7 @@ func (t *proxySubsys) proxyToHost( return trace.Wrap(err) } // this custom SSH handshake allows SSH proxy to relay the client's IP - // address to the SSH erver: + // address to the SSH server doHandshake(remoteAddr, ch, conn) go func() {
diff --git a/lib/srv/proxy_test.go b/lib/srv/proxy_test.go index 0dd8245836ddb..13ce0249409b3 100644 --- a/lib/srv/proxy_test.go +++ b/lib/srv/proxy_test.go
@@ -40,7 +40,7 @@ func (s *ProxyTestSuite) TestParseProxyRequest(c *check.C) { c.Assert(subsys.srv, check.Equals, s.srv) c.Assert(subsys.host, check.Equals, "host") c.Assert(subsys.port, check.Equals, "22") - c.Assert(subsys.siteName, check.Equals, "") + c.Assert(subsys.clusterName, check.Equals, "") // similar request, just with '@' at the end (missing site) subsys, err = parseProxySubsys("proxy:host:22@", s.srv)
@@ -48,7 +48,7 @@ c.Assert(subsys.srv, check.Equals, s.srv) c.Assert(subsys.host, check.Equals, "host") c.Assert(subsys.port, check.Equals, "22") - c.Assert(subsys.siteName, check.Equals, "") + c.Assert(subsys.clusterName, check.Equals, "") // proxy request for just the sitename subsys, err = parseProxySubsys("proxy:@moon", s.srv)
@@ -57,7 +57,7 @@ c.Assert(subsys.srv, check.Equals, s.srv) c.Assert(subsys.host, check.Equals, "") c.Assert(subsys.port, check.Equals, "") - c.Assert(subsys.siteName, check.Equals, "moon") + c.Assert(subsys.clusterName, check.Equals, "moon") // proxy request for the host:port@sitename subsys, err = parseProxySubsys("proxy:station:100@moon", s.srv)
@@ -66,7 +66,7 @@ c.Assert(subsys.srv, check.Equals, s.srv) c.Assert(subsys.host, check.Equals, "station") c.Assert(subsys.port, check.Equals, "100") - c.Assert(subsys.siteName, check.Equals, "moon") + c.Assert(subsys.clusterName, check.Equals,
"moon") // proxy request for the host:port@namespace@cluster subsys, err = parseProxySubsys("proxy:station:100@system@moon", s.srv) @@ -75,7 +75,7 @@ func (s *ProxyTestSuite) TestParseProxyRequest(c *check.C) { c.Assert(subsys.srv, check.Equals, s.srv) c.Assert(subsys.host, check.Equals, "station") c.Assert(subsys.port, check.Equals, "100") - c.Assert(subsys.siteName, check.Equals, "moon") + c.Assert(subsys.clusterName, check.Equals, "moon") c.Assert(subsys.namespace, check.Equals, "system") } diff --git a/lib/srv/sshserver.go b/lib/srv/sshserver.go index 266b5f4a80a81..e5b4054578fda 100644 --- a/lib/srv/sshserver.go +++ b/lib/srv/sshserver.go @@ -56,9 +56,10 @@ import ( type Server struct { sync.Mutex - namespace string - addr utils.NetAddr - hostname string + namespace string + addr utils.NetAddr + hostname string + // certChecker checks the CA of the connecting user certChecker ssh.CertChecker srv *sshutils.Server hostSigner ssh.Signer @@ -649,6 +650,7 @@ func (s *Server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permiss log.Debugf("[SSH] need a valid principal for key %v", fingerprint) return nil, trace.BadParameter("need a valid principal for key %v", fingerprint) } + if len(cert.KeyId) == 0 { log.Debugf("[SSH] need a valid key ID for key %v", fingerprint) return nil, trace.BadParameter("need a valid key for key %v", fingerprint) @@ -689,7 +691,7 @@ func (s *Server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permiss } } - // this is the only way I know of to pass valid principal with the + // this is the only way we know of to pass valid principal with the // connection permissions.Extensions[utils.CertTeleportUser] = teleportUser diff --git a/lib/state/cachingaccesspoint.go b/lib/state/cachingaccesspoint.go index 231c75f163713..eeff683a55337 100644 --- a/lib/state/cachingaccesspoint.go +++ b/lib/state/cachingaccesspoint.go @@ -152,6 +152,20 @@ func (cs *CachingAuthClient) fetchAll() error { errors = append(errors, err) _, err = cs.GetUsers() errors = append(errors, err) + conns, err := cs.ap.GetAllTunnelConnections() + if err != nil { + errors = append(errors, err) + } + clusters := map[string]bool{} + for _, conn := range conns { + clusterName := conn.GetClusterName() + if _, ok := clusters[clusterName]; ok { + continue + } + clusters[clusterName] = true + _, err = cs.GetTunnelConnections(clusterName) + errors = append(errors, err) + } return trace.NewAggregate(errors...) 
} @@ -414,6 +428,58 @@ func (cs *CachingAuthClient) GetUsers() (users []services.User, err error) { return users, err } +// GetTunnelConnections is a part of auth.AccessPoint implementation +func (cs *CachingAuthClient) GetTunnelConnections(clusterName string) (conns []services.TunnelConnection, err error) { + err = cs.try(func() error { + conns, err = cs.ap.GetTunnelConnections(clusterName) + return err + }) + if err != nil { + if trace.IsConnectionProblem(err) { + return cs.presence.GetTunnelConnections(clusterName) + } + return conns, err + } + if err := cs.presence.DeleteTunnelConnections(clusterName); err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + } + for _, conn := range conns { + cs.setTTL(conn) + if err := cs.presence.UpsertTunnelConnection(conn); err != nil { + return nil, trace.Wrap(err) + } + } + return conns, err +} + +// GetAllTunnelConnections is a part of auth.AccessPoint implementation +func (cs *CachingAuthClient) GetAllTunnelConnections() (conns []services.TunnelConnection, err error) { + err = cs.try(func() error { + conns, err = cs.ap.GetAllTunnelConnections() + return err + }) + if err != nil { + if trace.IsConnectionProblem(err) { + return cs.presence.GetAllTunnelConnections() + } + return conns, err + } + if err := cs.presence.DeleteAllTunnelConnections(); err != nil { + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + } + for _, conn := range conns { + cs.setTTL(conn) + if err := cs.presence.UpsertTunnelConnection(conn); err != nil { + return nil, trace.Wrap(err) + } + } + return conns, err +} + // UpsertNode is part of auth.AccessPoint implementation func (cs *CachingAuthClient) UpsertNode(s services.Server) error { cs.setTTL(s) @@ -426,6 +492,12 @@ func (cs *CachingAuthClient) UpsertProxy(s services.Server) error { return cs.ap.UpsertProxy(s) } +// UpsertTunnelConnection is a part of auth.AccessPoint implementation +func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnection) error { + cs.setTTL(conn) + return cs.ap.UpsertTunnelConnection(conn) +} + // try calls a given function f and checks for errors. If f() fails, the current // time is recorded. 
Future calls to f will be ignored until sufficient time passes // since the last error
diff --git a/lib/state/cachingaccesspoint_test.go b/lib/state/cachingaccesspoint_test.go index 7d7dd99967ee5..6645b9e60a962 100644 --- a/lib/state/cachingaccesspoint_test.go +++ b/lib/state/cachingaccesspoint_test.go
@@ -88,6 +88,14 @@ var ( }, }, } + TunnelConnections = []services.TunnelConnection{ + services.MustCreateTunnelConnection("conn1", services.TunnelConnectionSpecV2{ + ClusterName: "example.com", + ProxyName: "p1", + LastHeartbeat: time.Date(2015, 6, 5, 4, 3, 2, 1, time.UTC).UTC(), + }), + } ) type ClusterSnapshotSuite struct {
@@ -151,6 +159,11 @@ func (s *ClusterSnapshotSuite) SetUpTest(c *check.C) { err = s.authServer.UpsertUser(v2) c.Assert(err, check.IsNil) } + // add tunnel connections + for _, conn := range TunnelConnections { + conn.SetTTL(s.clock, defaults.ServerHeartbeatTTL) + err = s.authServer.UpsertTunnelConnection(conn) + c.Assert(err, check.IsNil) + } } func (s *ClusterSnapshotSuite) TearDownTest(c *check.C) {
@@ -184,6 +197,10 @@ func (s *ClusterSnapshotSuite) TestEverything(c *check.C) { proxies, err := snap.GetProxies() c.Assert(err, check.IsNil) c.Assert(proxies, check.HasLen, len(Proxies)) + + conns, err := snap.GetTunnelConnections("example.com") + c.Assert(err, check.IsNil) + c.Assert(conns, check.HasLen, len(TunnelConnections)) } func (s *ClusterSnapshotSuite) TestTry(c *check.C) {
diff --git a/lib/utils/conn.go b/lib/utils/conn.go new file mode 100644 index 0000000000000..fe659a735506b --- /dev/null +++ b/lib/utils/conn.go
@@ -0,0 +1,54 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "io" + "net" + + "github.com/gravitational/trace" +) + +// NewCloserConn returns a new connection wrapper that, +// when closed, also closes the passed closers +func NewCloserConn(conn net.Conn, closers ...io.Closer) *CloserConn { + return &CloserConn{ + Conn: conn, + closers: closers, + } +} + +// CloserConn wraps connection and attaches additional closers to it +type CloserConn struct { + net.Conn + closers []io.Closer +} + +// AddCloser adds a closer that will be closed +// whenever the connection is closed +func (c *CloserConn) AddCloser(closer io.Closer) { + c.closers = append(c.closers, closer) +} + +// Close closes the connection and all attached closers +func (c *CloserConn) Close() error { + var errors []error + for _, closer := range c.closers { + errors = append(errors, closer.Close()) + } + errors = append(errors, c.Conn.Close()) + return trace.NewAggregate(errors...) +}
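The CloserConn wrapper above is handy whenever a connection owns auxiliary resources that must die with it. A small usage sketch (the session log file is an invented example, not part of this patch):

    import (
        "net"
        "os"

        "github.com/gravitational/teleport/lib/utils"
    )

    func dialWithSessionLog(addr, logPath string) (net.Conn, error) {
        conn, err := net.Dial("tcp", addr)
        if err != nil {
            return nil, err
        }
        f, err := os.Create(logPath)
        if err != nil {
            conn.Close()
            return nil, err
        }
        // closing the returned connection now also closes the log file,
        // and any close errors are aggregated into one
        return utils.NewCloserConn(conn, f), nil
    }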
diff --git a/lib/utils/proxy/proxy.go b/lib/utils/proxy/proxy.go index e83bf292252d8..66bd0aa49a55b 100644 --- a/lib/utils/proxy/proxy.go +++ b/lib/utils/proxy/proxy.go
@@ -33,14 +33,8 @@ import ( log "github.com/sirupsen/logrus" ) -// DialWithDeadline works around the case when net.DialWithTimeout -// succeeds, but key exchange hangs. Setting deadline on connection -// prevents this case from happening -func DialWithDeadline(network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { - conn, err := net.DialTimeout(network, addr, config.Timeout) - if err != nil { - return nil, err - } +// NewClientConnWithDeadline establishes a new client connection with the specified deadline +func NewClientConnWithDeadline(conn net.Conn, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { if config.Timeout > 0 { conn.SetReadDeadline(time.Now().Add(config.Timeout)) }
@@ -54,6 +48,17 @@ return ssh.NewClient(c, chans, reqs), nil } +// DialWithDeadline works around the case when net.DialWithTimeout +// succeeds, but key exchange hangs. Setting a deadline on the connection +// prevents this case from happening +func DialWithDeadline(network string, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + return NewClientConnWithDeadline(conn, addr, config) +} + // A Dialer is a means for a client to establish a SSH connection. type Dialer interface { // Dial establishes a client connection to a SSH server.
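The refactoring above splits dialing from handshaking so that callers that already hold a net.Conn can reuse the same deadline guard. A hedged usage sketch of both entry points (the function names wrapping them here are illustrative):

    import (
        "net"

        "golang.org/x/crypto/ssh"

        "github.com/gravitational/teleport/lib/utils/proxy"
    )

    // dial and handshake in one step, guarding both with config.Timeout
    func dialReverseTunnel(addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
        return proxy.DialWithDeadline("tcp", addr, config)
    }

    // reuse the same guard when the net.Conn was obtained elsewhere,
    // e.g. through an HTTP CONNECT proxy
    func handshakeExisting(conn net.Conn, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
        return proxy.NewClientConnWithDeadline(conn, addr, config)
    }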
diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 6c58fd739c59c..d8ed4144a7c91 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go
@@ -1069,7 +1069,6 @@ Successful response: {"namespaces": [{..namespace resource...}]} */ func (m *Handler) getSiteNamespaces(w http.ResponseWriter, r *http.Request, _ httprouter.Params, c *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { - log.Debugf("[web] GET /namespaces") clt, err := site.GetClient() if err != nil { return nil, trace.Wrap(err) }
@@ -1089,7 +1088,6 @@ type nodeWithSessions struct { } func (m *Handler) siteNodesGet(w http.ResponseWriter, r *http.Request, p httprouter.Params, c *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { - log.Debugf("[web] GET /nodes") clt, err := site.GetClient() if err != nil { return nil, trace.Wrap(err) }
diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 9c364f0dae638..1050b50dbc6ae 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go
@@ -185,12 +185,14 @@ func (t *TerminalHandler) Run(w http.ResponseWriter, r *http.Request) { // retrieves them directly from auth server API: agent, err := t.ctx.GetAgent() if err != nil { + log.Warningf("failed to get agent: %v", err) errToTerm(err, ws) return } defer agent.Close() principal, auth, err := getUserCredentials(agent) if err != nil { + log.Warningf("failed to get user credentials: %v", err) errToTerm(err, ws) return }
@@ -219,6 +221,7 @@ } tc, err := client.NewClient(clientConfig) if err != nil { + log.Warningf("failed to create client: %v", err) errToTerm(err, ws) return }
@@ -230,6 +233,7 @@ return false, nil } if err = tc.SSH(context.TODO(), t.params.InteractiveCommand, false); err != nil { + log.Warningf("failed to SSH: %v", err) errToTerm(err, ws) return }
From a2cd00de8fa643691ba3c6083a289b54c2625aa7 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Thu, 5 Oct 2017 18:41:59 -0700 Subject: [PATCH 02/24] a bit of refactoring --- lib/reversetunnel/agent.go | 22 +++++++++------- lib/reversetunnel/agentpool.go | 46 ++++++++++++++++++++++------------ 2 files changed, 43 insertions(+), 25 deletions(-)
diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 6be46c1fce883..5ada5844b496a 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go
@@ -21,6 +21,7 @@ limitations under the License. package reversetunnel import ( + "context" "fmt" "io" "net"
@@ -50,8 +51,8 @@ type Agent struct { remoteDomainName string // clientName format is "hostid.domain" (where 'domain' is local domain name) clientName string - broadcastClose *utils.CloseBroadcaster - disconnectC chan bool + ctx context.Context + cancel context.CancelFunc hostKeyCallback utils.HostKeyCallback authMethods []ssh.AuthMethod accessPoint auth.AccessPoint
@@ -71,10 +72,12 @@ func NewAgent( clientName string, signers []ssh.Signer, clt *auth.TunClient, - accessPoint auth.AccessPoint) (*Agent, error) { + accessPoint auth.AccessPoint, parentContext context.Context) (*Agent, error) { log.Debugf("reversetunnel.NewAgent %s -> %s", clientName, remoteDomainName) + ctx, cancel := context.WithCancel(parentContext) + a := &Agent{ log: log.WithFields(log.Fields{ teleport.Component: teleport.ComponentReverseTunnel,
@@ -88,10 +91,10 @@ func NewAgent( addr: addr, remoteDomainName: remoteDomainName, clientName: clientName, - broadcastClose: utils.NewCloseBroadcaster(), - disconnectC: make(chan bool, 10), authMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)}, accessPoint: accessPoint, + ctx: ctx, + cancel: cancel, } a.hostKeyCallback = a.checkHostSignature return a, nil
@@ -99,7 +102,8 @@ // Close signals to close all connections func (a *Agent) Close() error { - return a.broadcastClose.Close() + a.cancel() + return nil } // Start starts agent that attempts to connect to remote server part
@@ -221,7 +225,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { var req *ssh.Request select { - case <-a.broadcastClose.C: + case <-a.ctx.Done(): a.log.Infof("is closed, returning") return case req = <-reqC:
@@ -328,7 +332,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { for { select { // need to exit: - case <-a.broadcastClose.C: + case <-a.ctx.Done(): return nil // time to ping: case <-ticker.C:
@@ -385,7 +389,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { if err != nil || conn == nil { select { // abort if asked to stop: - case <-a.broadcastClose.C: + case <-a.ctx.Done(): return // reconnect case <-ticker.C:
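The move from CloseBroadcaster to context.Context above follows the standard Go cancellation pattern: a cancel function fans out shutdown to every goroutine selecting on ctx.Done(). The sketch below restates the pattern in isolation; the names are generic, not Teleport APIs:

    import (
        "context"
        "time"
    )

    type worker struct {
        ctx    context.Context
        cancel context.CancelFunc
    }

    func newWorker(parent context.Context) *worker {
        ctx, cancel := context.WithCancel(parent)
        return &worker{ctx: ctx, cancel: cancel}
    }

    func (w *worker) run() {
        ticker := time.NewTicker(time.Second)
        defer ticker.Stop()
        for {
            select {
            case <-w.ctx.Done(): // Close was called, or the parent was canceled
                return
            case <-ticker.C:
                // periodic work, e.g. a heartbeat
            }
        }
    }

    // Close signals every goroutine watching ctx to exit
    func (w *worker) Close() error {
        w.cancel()
        return nil
    }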
parameter") } if cfg.AccessPoint == nil { - return nil, trace.BadParameter("missing 'AccessPoint' parameter") + return trace.BadParameter("missing 'AccessPoint' parameter") } if len(cfg.HostSigners) == 0 { - return nil, trace.BadParameter("missing 'HostSigners' parameter") + return trace.BadParameter("missing 'HostSigners' parameter") } if len(cfg.HostUUID) == 0 { - return nil, trace.BadParameter("missing 'HostUUID' parameter") + return trace.BadParameter("missing 'HostUUID' parameter") } + if cfg.Context == nil { + cfg.Context = context.TODO() + } + return nil +} + +// NewAgentPool returns new isntance of the agent pool +func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { + ctx, cancel := context.WithCancel(cfg.Context) pool := &AgentPool{ - agents: make(map[agentKey]*Agent), - cfg: cfg, - closeBroadcast: utils.NewCloseBroadcaster(), + agents: make(map[agentKey]*Agent), + cfg: cfg, + ctx: ctx, + cancel: cancel, } pool.Entry = log.WithFields(log.Fields{ teleport.Component: teleport.ComponentReverseTunnel, @@ -79,13 +93,13 @@ func (m *AgentPool) Start() error { // Stop stops the agent pool func (m *AgentPool) Stop() { - m.closeBroadcast.Close() + m.cancel() } // Wait returns when agent pool is closed func (m *AgentPool) Wait() error { select { - case <-m.closeBroadcast.C: + case <-m.ctx.Done(): break } return nil @@ -110,7 +124,7 @@ func (m *AgentPool) pollAndSyncAgents() { m.FetchAndSyncAgents() for { select { - case <-m.closeBroadcast.C: + case <-m.ctx.Done(): m.Debugf("closing") m.Lock() defer m.Unlock() @@ -146,7 +160,7 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { for _, key := range agentsToAdd { m.Debugf("adding %v", &key) - agent, err := NewAgent(key.addr, key.domainName, m.cfg.HostUUID, m.cfg.HostSigners, m.cfg.Client, m.cfg.AccessPoint) + agent, err := NewAgent(key.addr, key.domainName, m.cfg.HostUUID, m.cfg.HostSigners, m.cfg.Client, m.cfg.AccessPoint, m.ctx) if err != nil { return trace.Wrap(err) } From 53f4a0128e9d02ab3db3244e227a919a3e0ca6c3 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Fri, 6 Oct 2017 15:38:15 -0700 Subject: [PATCH 03/24] introduce curiosity protocol and fix logs --- constants.go | 11 +- lib/reversetunnel/agent.go | 197 +++++++++++++------ lib/reversetunnel/agentpool.go | 54 +++-- lib/reversetunnel/discovery.go | 52 +++++ lib/reversetunnel/localsite.go | 10 +- lib/reversetunnel/peer.go | 3 +- lib/reversetunnel/remotesite.go | 124 ++++++++++-- lib/reversetunnel/srv.go | 42 +++- lib/utils/cli.go | 3 +- vendor/github.com/gravitational/trace/log.go | 129 +++++++++++- 10 files changed, 501 insertions(+), 124 deletions(-) create mode 100644 lib/reversetunnel/discovery.go diff --git a/constants.go b/constants.go index 57161eca4962b..b5e8eb35080ad 100644 --- a/constants.go +++ b/constants.go @@ -57,10 +57,15 @@ const ( // ComponentFields stores component-specific fields ComponentFields = "fields" - // ComponentReverseTunnel is reverse tunnel agent and server - // that together establish a bi-directional SSH revers tunnel + // ComponentReverseTunnelServer is reverse tunnel server + // that together with agent establish a bi-directional SSH revers tunnel // to bypass firewall restrictions - ComponentReverseTunnel = "reversetunnel" + ComponentReverseTunnelServer = "proxy:server" + + // ComponentReverseTunnel is reverse tunnel agent + // that together with server establish a bi-directional SSH revers tunnel + // to bypass firewall restrictions + ComponentReverseTunnelAgent = "proxy:agent" // ComponentAuth is the cluster CA 
node (auth server API) ComponentAuth = "auth"
diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 5ada5844b496a..18e47441fab27 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go
@@ -41,60 +41,84 @@ import ( "golang.org/x/crypto/ssh" ) +// AgentConfig holds configuration for agent +type AgentConfig struct { + // Addr is target address to dial + Addr utils.NetAddr + // RemoteCluster is a remote cluster name to connect to + RemoteCluster string + // Signers contains authentication signers + Signers []ssh.Signer + // Client is a client to the local auth servers + Client *auth.TunClient + // AccessPoint is a caching access point to the local auth servers + AccessPoint auth.AccessPoint + // Context is a parent context + Context context.Context + // DiscoveryC is a channel that receives discovery requests + // from reverse tunnel server + DiscoveryC chan *discoveryRequest + // Username is the name of this client used to authenticate on SSH + Username string +} + +// CheckAndSetDefaults checks parameters and sets default values +func (a *AgentConfig) CheckAndSetDefaults() error { + if a.Addr.IsEmpty() { + return trace.BadParameter("missing parameter Addr") + } + if a.DiscoveryC == nil { + return trace.BadParameter("missing parameter DiscoveryC") + } + if a.Context == nil { + return trace.BadParameter("missing parameter Context") + } + if a.Client == nil { + return trace.BadParameter("missing parameter Client") + } + if a.AccessPoint == nil { + return trace.BadParameter("missing parameter AccessPoint") + } + if len(a.Signers) == 0 { + return trace.BadParameter("missing parameter Signers") + } + if len(a.Username) == 0 { + return trace.BadParameter("missing parameter Username") + } + return nil +} + // Agent is a reverse tunnel agent running as a part of teleport Proxies // to establish outbound reverse tunnels to remote proxies type Agent struct { - log *log.Entry - addr utils.NetAddr - clt *auth.TunClient - // domain name of the tunnel server, used only for debugging & logging - remoteDomainName string - // clientName format is "hostid.domain" (where 'domain' is local domain name) - clientName string + *log.Entry + AgentConfig ctx context.Context cancel context.CancelFunc hostKeyCallback utils.HostKeyCallback authMethods []ssh.AuthMethod - accessPoint auth.AccessPoint } -// AgentOption specifies parameter that could be passed to Agents -type AgentOption func(a *Agent) error - // NewAgent returns a new reverse tunnel agent // Parameters: // addr points to the remote reverse tunnel server // remoteDomainName is the domain name of the tunnel server, used only for logging // clientName is hostid.domain (where 'domain' is local domain name) -func NewAgent( - addr utils.NetAddr, - remoteDomainName string, - clientName string, - signers []ssh.Signer, - clt *auth.TunClient, - accessPoint auth.AccessPoint, parentContext context.Context) (*Agent, error) { - - log.Debugf("reversetunnel.NewAgent %s -> %s", clientName, remoteDomainName) - - ctx, cancel := context.WithCancel(parentContext) +func NewAgent(cfg AgentConfig) (*Agent, error) { + ctx, cancel := context.WithCancel(cfg.Context) a := &Agent{ - log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + AgentConfig: cfg, + Entry: log.WithFields(log.Fields{ + teleport.Component: teleport.ComponentReverseTunnelAgent, teleport.ComponentFields: map[string]interface{}{ - "side": "agent", - "remote": addr.String(), - "mode": "agent", + "remote": cfg.Addr.String(), + "client":
cfg.Username, }, }), - clt: clt, - addr: addr, - remoteDomainName: remoteDomainName, - clientName: clientName, - authMethods: []ssh.AuthMethod{ssh.PublicKeys(signers...)}, - accessPoint: accessPoint, - ctx: ctx, - cancel: cancel, + ctx: ctx, + cancel: cancel, + authMethods: []ssh.AuthMethod{ssh.PublicKeys(cfg.Signers...)}, } a.hostKeyCallback = a.checkHostSignature return a, nil
@@ -110,8 +134,7 @@ func (a *Agent) Start() error { conn, err := a.connect() if err != nil { - log.Errorf("Failed to create remote tunnel for %v on %s(%s): %v", - a.clientName, a.remoteDomainName, a.addr.FullAddress(), err) + a.Warningf("Failed to create remote tunnel: %v", err) } // start heartbeat even if error happened, it will reconnect go a.runHeartbeat(conn)
@@ -125,7 +148,7 @@ func (a *Agent) Wait() error { // String returns a debug-friendly representation of the agent func (a *Agent) String() string { - return fmt.Sprintf("tunagent(remote=%s)", a.addr.String()) + return fmt.Sprintf("tunagent(remote=%s)", a.Addr.String()) } func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.PublicKey) error {
@@ -133,7 +156,7 @@ if !ok { return trace.BadParameter("expected certificate") } - cas, err := a.accessPoint.GetCertAuthorities(services.HostCA, false) + cas, err := a.AccessPoint.GetCertAuthorities(services.HostCA, false) if err != nil { return trace.Wrap(err, "failed to fetch remote certs") }
@@ -144,7 +167,7 @@ } for _, checker := range checkers { if sshutils.KeysEqual(checker, cert.SignatureKey) { - a.log.Debugf("matched key %v for %v", ca.GetName(), hostport) + a.Debugf("matched key %v for %v", ca.GetName(), hostport) return nil } }
@@ -154,14 +177,11 @@ } func (a *Agent) connect() (conn *ssh.Client, err error) { - if a.addr.IsEmpty() { - return nil, trace.BadParameter("reverse tunnel cannot be created: target address is empty") - } for _, authMethod := range a.authMethods { // if http_proxy is set, dial through the proxy dialer := proxy.DialerFromEnvironment() - conn, err = dialer.Dial(a.addr.AddrNetwork, a.addr.Addr, &ssh.ClientConfig{ - User: a.clientName, + conn, err = dialer.Dial(a.Addr.AddrNetwork, a.Addr.Addr, &ssh.ClientConfig{ + User: a.Username, Auth: []ssh.AuthMethod{authMethod}, HostKeyCallback: a.hostKeyCallback, Timeout: defaults.DefaultDialTimeout,
@@ -174,12 +194,12 @@ } func (a *Agent) proxyAccessPoint(ch ssh.Channel, req <-chan *ssh.Request) { - log.Debugf("[HA Agent] proxyAccessPoint") + a.Debugf("proxyAccessPoint") defer ch.Close() - conn, err := a.clt.GetDialer()() + conn, err := a.Client.GetDialer()() if err != nil { - a.log.Errorf("error dialing: %v", err) + a.Warningf("error dialing: %v", err) return }
@@ -215,7 +235,7 @@ // ch : SSH channel which received "teleport-transport" out-of-band request // reqC : request payload func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { - log.Debugf("[HA Agent] proxyTransport") + a.Debugf("proxyTransport") defer ch.Close() // always push space into stderr to make sure the caller can always
@@ -226,15 +246,15 @@ var req *ssh.Request select { case <-a.ctx.Done(): -
a.log.Infof("is closed, returning") + a.Infof("is closed, returning") return case req = <-reqC: if req == nil { - a.log.Infof("connection closed, returning") + a.Infof("connection closed, returning") return } case <-time.After(defaults.DefaultDialTimeout): - a.log.Errorf("timeout waiting for dial") + a.Warningf("timeout waiting for dial") return }
@@ -245,9 +265,9 @@ // list of auth servers and return that. otherwise try and connect to the // passed in server. if server == RemoteAuthServer { - authServers, err := a.clt.GetAuthServers() + authServers, err := a.Client.GetAuthServers() if err != nil { - a.log.Errorf("unable to find auth servers: %v", err) + a.Warningf("unable to find auth servers: %v", err) return } for _, as := range authServers {
@@ -257,7 +277,7 @@ servers = append(servers, server) } - log.Debugf("got out of band request %v", servers) + a.Debugf("got out of band request %v", servers) var conn net.Conn var err error
@@ -284,7 +304,7 @@ // successfully dialed req.Reply(true, []byte("connected")) - a.log.Infof("successfully dialed to %v, start proxying", server) + a.Debugf("successfully dialed to %v, start proxying", server) wg := sync.WaitGroup{} wg.Add(2)
@@ -317,7 +337,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { if conn == nil { return trace.Errorf("heartbeat cannot ping: need to reconnect") } - log.Infof("[TUNNEL CLIENT] connected to %s", conn.RemoteAddr()) + a.Infof("connected to %s", conn.RemoteAddr()) defer conn.Close() hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) if err != nil {
@@ -336,26 +356,26 @@ return nil // time to ping: case <-ticker.C: - log.Debugf("[TUNNEL CLIENT] pings \"%s\" at %s", a.remoteDomainName, conn.RemoteAddr()) _, err := hb.SendRequest("ping", false, nil) if err != nil { log.Error(err) return trace.Wrap(err) } + a.Debugf("ping -> %v", conn.RemoteAddr()) // ssh channel closed: case req := <-reqC: if req == nil { - return trace.Errorf("heartbeat: connection closed") + return trace.ConnectionProblem(nil, "heartbeat: connection closed") } // new access point request: case nch := <-newAccesspointC: if nch == nil { continue } - a.log.Infof("[TUNNEL CLIENT] access point request: %v", nch.ChannelType()) + a.Debugf("access point request: %v", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Errorf("failed to accept request: %v", err) + a.Warningf("failed to accept request: %v", err) continue } go a.proxyAccessPoint(ch, req)
@@ -364,10 +384,22 @@ // new transport request: case nch := <-newTransportC: if nch == nil { continue } - a.log.Infof("[TUNNEL CLIENT] transport request: %v", nch.ChannelType()) + a.Debugf("transport request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.Warningf("failed to accept request: %v", err) + continue + } + go a.proxyTransport(ch, req) + // new discovery request + case nch := <-newDiscoveryC: + if nch == nil { + continue + } + a.Debugf("discovery request: %v", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { - a.log.Errorf("failed to accept request: %v", err) + a.Warningf("failed to accept request: %v", err) continue } go a.handleDiscovery(ch, req)
@@ -398,11 +430,50 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { } }
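On the consuming side, whatever holds the other end of DiscoveryC can drain these requests and act on them. A speculative sketch, not part of this patch; ensureAgentFor is a hypothetical helper, not a Teleport API:

    func drainDiscovery(ctx context.Context, discoveryC <-chan *discoveryRequest) {
        for {
            select {
            case <-ctx.Done():
                return
            case req := <-discoveryC:
                if req == nil {
                    return
                }
                // each listed proxy currently lacks a reverse tunnel to us,
                // so make sure an agent is dialing it
                for _, p := range req.Proxies {
                    ensureAgentFor(p.GetName()) // hypothetical
                }
            }
        }
    }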
+// handleDiscovery receives discovery requests from the reverse tunnel +// server that inform the agent about proxies registered in the remote +// cluster and the reverse tunnels already established +// +// ch : SSH channel which received "teleport-discovery" out-of-band request +// reqC : request payload +func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { + a.Debugf("handleDiscovery") + defer ch.Close() + + for { + var req *ssh.Request + select { + case <-a.ctx.Done(): + a.Infof("is closed, returning") + return + case req = <-reqC: + if req == nil { + a.Infof("connection closed, returning") + return + } + r, err := unmarshalDiscoveryRequest(req.Payload) + if err != nil { + a.Warningf("bad payload: %v", err) + return + } + select { + case a.DiscoveryC <- r: + case <-a.ctx.Done(): + a.Infof("is closed, returning") + return + default: + } + req.Reply(true, []byte("thanks")) + } + } +} + const ( chanHeartbeat = "teleport-heartbeat" chanAccessPoint = "teleport-access-point" chanTransport = "teleport-transport" chanTransportDialReq = "teleport-transport-dial" + chanDiscovery = "teleport-discovery" ) const (
diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 069440b7001f2..4ccd2f74a3720 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go
@@ -24,10 +24,11 @@ type AgentPool struct { sync.Mutex *log.Entry - cfg AgentPoolConfig - agents map[agentKey]*Agent - ctx context.Context - cancel context.CancelFunc + cfg AgentPoolConfig + agents map[agentKey]*Agent + ctx context.Context + cancel context.CancelFunc + discoveryC chan *discoveryRequest } // AgentPoolConfig holds configuration parameters for the agent pool
@@ -68,19 +69,19 @@ func (cfg *AgentPoolConfig) CheckAndSetDefaults() error { // NewAgentPool returns new instance of the agent pool func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } ctx, cancel := context.WithCancel(cfg.Context) pool := &AgentPool{ - agents: make(map[agentKey]*Agent), - cfg: cfg, - ctx: ctx, - cancel: cancel, + agents: make(map[agentKey]*Agent), + cfg: cfg, + ctx: ctx, + cancel: cancel, + discoveryC: make(chan *discoveryRequest), } pool.Entry = log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, - teleport.ComponentFields: map[string]interface{}{ - "side": "agent", - "mode": "agentpool", - }, + teleport.Component: teleport.ComponentReverseTunnelAgent, }) return pool, nil }
@@ -105,6 +106,22 @@ func (m *AgentPool) Wait() error { return nil } +func (m *AgentPool) processDiscoveryRequests() { + for { + select { + case <-m.ctx.Done(): + m.Debugf("closing") + return + case req := <-m.discoveryC: + if req == nil { + m.Debugf("channel closed") + return + } + m.Debugf("got discovery request; the following proxies are not connected: %v", req.Proxies) + } + } +} + // FetchAndSyncAgents executes one time fetch and sync request // (used in tests instead of polling) func (m *AgentPool) FetchAndSyncAgents() error {
@@ -160,7 +177,16 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { for _, key := range agentsToAdd { m.Debugf("adding %v", &key) - agent, err := NewAgent(key.addr, key.domainName, m.cfg.HostUUID, m.cfg.HostSigners, m.cfg.Client, m.cfg.AccessPoint, m.ctx) + agent, err := NewAgent(AgentConfig{ + Addr: key.addr, + RemoteCluster: key.domainName, + Username: m.cfg.HostUUID, + Signers: m.cfg.HostSigners, + Client: m.cfg.Client, + AccessPoint: m.cfg.AccessPoint, + Context: m.ctx, + DiscoveryC:
m.discoveryC, + }) if err != nil { return trace.Wrap(err) } diff --git a/lib/reversetunnel/discovery.go b/lib/reversetunnel/discovery.go new file mode 100644 index 0000000000000..1f3584ffecd33 --- /dev/null +++ b/lib/reversetunnel/discovery.go @@ -0,0 +1,52 @@ +package reversetunnel + +import ( + "encoding/json" + + "github.com/gravitational/teleport/lib/services" + + "github.com/gravitational/trace" +) + +type discoveryRequest struct { + Proxies []services.Server `json:"proxies"` +} + +type discoveryRequestRaw struct { + Proxies []json.RawMessage `json:"proxies"` +} + +func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) { + var out discoveryRequestRaw + m := services.GetServerMarshaler() + for _, p := range req.Proxies { + data, err := m.MarshalServer(p) + if err != nil { + return nil, trace.Wrap(err) + } + out.Proxies = append(out.Proxies, data) + } + + return json.Marshal(out) +} + +func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) { + if len(data) == 0 { + return nil, trace.BadParameter("missing payload") + } + var raw discoveryRequestRaw + err := json.Unmarshal(data, &raw) + if err != nil { + return nil, trace.Wrap(err) + } + m := services.GetServerMarshaler() + var out discoveryRequest + for _, bytes := range raw.Proxies { + proxy, err := m.UnmarshalServer([]byte(bytes), services.KindProxy) + if err != nil { + return nil, trace.Wrap(err) + } + out.Proxies = append(out.Proxies, proxy) + } + return &out, nil +} diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 9f05968477820..5a38d4871f533 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -40,11 +40,9 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi accessPoint: accessPoint, domainName: domainName, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ - "domainName": domainName, - "side": "server", - "type": "localSite", + "cluster": domainName, }, }), }, nil @@ -77,7 +75,7 @@ func (s *localSite) GetClient() (auth.ClientI, error) { } func (s *localSite) String() string { - return fmt.Sprintf("localSite(%v)", s.domainName) + return fmt.Sprintf("local(%v)", s.domainName) } func (s *localSite) GetStatus() string { @@ -94,7 +92,7 @@ func (s *localSite) GetLastConnected() time.Time { // Dial dials a given host in this site (cluster). 
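The discovery wire format introduced above is plain JSON with each proxy serialized through the ServerMarshaler. A round-trip sketch from inside the reversetunnel package (the fmt import and the proxies slice, which would typically come from AccessPoint.GetProxies(), are assumed):

    func discoveryRoundTrip(proxies []services.Server) error {
        payload, err := marshalDiscoveryRequest(discoveryRequest{Proxies: proxies})
        if err != nil {
            return trace.Wrap(err)
        }
        decoded, err := unmarshalDiscoveryRequest(payload)
        if err != nil {
            return trace.Wrap(err)
        }
        // the servers come back re-validated as KindProxy resources
        for _, p := range decoded.Proxies {
            fmt.Printf("disconnected proxy: %v\n", p.GetName())
        }
        return nil
    }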
func (s *localSite) Dial(from net.Addr, to net.Addr) (net.Conn, error) { - s.log.Debugf("[PROXY] localSite.Dial(from=%v, to=%v)", from, to) + s.log.Debugf("local.Dial(from=%v, to=%v)", from, to) return net.Dial(to.Network(), to.String()) } diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 4e6b26ed42b17..7035fae25a5c8 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -127,10 +127,9 @@ func newClusterPeer(srv *server, connInfo services.TunnelConnection) (*clusterPe srv: srv, connInfo: connInfo, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ "cluster": connInfo.GetClusterName(), - "side": "server", }, }), } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 46be0a0e9e5a1..a299509abb2f6 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -14,9 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ + package reversetunnel import ( + "context" "fmt" "io/ioutil" "net" @@ -44,17 +46,17 @@ import ( type remoteSite struct { sync.Mutex - log *log.Entry + *log.Entry domainName string connections []*remoteConn lastUsed int lastActive time.Time srv *server - transport *http.Transport clt *auth.Client accessPoint auth.AccessPoint connInfo services.TunnelConnection + ctx context.Context } func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) { @@ -100,7 +102,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro rc := &remoteConn{ sshConn: sshConn, conn: conn, - log: s.log, + log: s.Entry, } s.Lock() @@ -114,7 +116,7 @@ func (s *remoteSite) addConn(conn net.Conn, sshConn ssh.Conn) (*remoteConn, erro func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, error) { conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) if err != nil { - s.log.Warningf("[TUNNEL] failed to fetch tunnel statuses: %v", err) + s.Warningf("failed to fetch tunnel statuses: %v", err) return nil, trace.Wrap(err) } var lastConn services.TunnelConnection @@ -125,7 +127,7 @@ func (s *remoteSite) getLatestTunnelConnection() (services.TunnelConnection, err } } if lastConn == nil { - return nil, trace.NotFound("no connections from %v found in the cluster", s.domainName) + return nil, trace.NotFound("no connections found") } return lastConn, nil } @@ -146,24 +148,27 @@ func (s *remoteSite) registerHeartbeat(t time.Time) { s.connInfo.SetLastHeartbeat(t) err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) if err != nil { - log.Warningf("[TUNNEL] failed to register heartbeat: %v", err) + log.Warningf("failed to register heartbeat: %v", err) } } func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) { defer func() { - s.log.Infof("[TUNNEL] cluster connection closed: %v", s.domainName) + s.Infof("cluster connection closed") conn.Close() }() for { select { + case <-s.ctx.Done(): + s.Infof("closing") + return case req := <-reqC: if req == nil { - s.log.Infof("[TUNNEL] cluster disconnected: %v", s.domainName) + s.Infof("cluster disconnected") conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } - log.Debugf("[TUNNEL] ping from \"%s\" %s", s.domainName, conn.conn.RemoteAddr()) + s.Debugf("ping <- %v", conn.conn.RemoteAddr()) go s.registerHeartbeat(time.Now()) case <-time.After(3 * 
defaults.ReverseTunnelAgentHeartbeatPeriod): conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats"))
@@ -183,10 +188,96 @@ func (s *remoteSite) GetLastConnected() time.Time { return connInfo.GetLastHeartbeat() } +func (s *remoteSite) periodicSendDiscoveryRequests() { + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) + defer ticker.Stop() + if err := s.sendDiscoveryRequest(); err != nil { + s.Warningf("failed to send initial discovery request: %v", err) + } + for { + select { + case <-s.ctx.Done(): + s.Debugf("closing") + return + case <-ticker.C: + err := s.sendDiscoveryRequest() + if err != nil { + s.Warningf("could not send discovery request: %v", err) + } + } + } +} + +// findDisconnectedProxies returns the proxies registered in the local cluster +// that do not have a tunnel connection from this remote cluster yet +func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { + conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) + if err != nil { + return nil, trace.Wrap(err) + } + connected := make(map[string]bool) + for _, conn := range conns { + connected[conn.GetProxyName()] = true + } + proxies, err := s.srv.AccessPoint.GetProxies() + if err != nil { + return nil, trace.Wrap(err) + } + var missing []services.Server + for i := range proxies { + proxy := proxies[i] + if !connected[proxy.GetName()] { + missing = append(missing, proxy) + } + } + return missing, nil +} + +func (s *remoteSite) sendDiscoveryRequest() error { + disconnectedProxies, err := s.findDisconnectedProxies() + if err != nil { + return trace.Wrap(err) + } + if len(disconnectedProxies) == 0 { + return nil + } + s.Infof("detected disconnected proxies: %v", disconnectedProxies) + req := discoveryRequest{ + Proxies: disconnectedProxies, + } + payload, err := marshalDiscoveryRequest(req) + if err != nil { + return trace.Wrap(err) + } + send := func() error { + remoteConn, err := s.nextConn() + if err != nil { + return trace.Wrap(err) + } + discoveryC, err := remoteConn.openDiscoveryChannel() + if err != nil { + return trace.Wrap(err) + } + _, err = discoveryC.SendRequest("ping", false, payload) + if err != nil { + remoteConn.markInvalid(err) + s.Errorf("disconnecting cluster on %v, err: %v", + remoteConn.conn.RemoteAddr(), + err) + return trace.Wrap(err) + } + return nil + } + + for i := 0; i < s.connectionCount(); i++ { + err := send() + if err != nil { + s.Warningf("%v", err) + } + } + return nil +} + // dialAccessPoint establishes a connection from the proxy (reverse tunnel server) // back into the client using previously established tunnel. func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { - s.log.Infof("[TUNNEL] dial to site '%s'", s.GetName()) + s.Debugf("dialAccessPoint") try := func() (net.Conn, error) { remoteConn, err := s.nextConn()
@@ -196,13 +287,12 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { ch, _, err := remoteConn.sshConn.OpenChannel(chanAccessPoint, nil) if err != nil { remoteConn.markInvalid(err) - s.log.Errorf("[TUNNEL] disconnecting site '%s' on %v. Err: %v", - s.GetName(), + s.Errorf("disconnecting cluster on %v, err: %v", remoteConn.conn.RemoteAddr(), err) return nil, trace.Wrap(err) } - s.log.Infof("[TUNNEL] success dialing to site '%s'", s.GetName()) + s.Infof("success dialing to cluster") return utils.NewChConn(remoteConn.sshConn, ch), nil }
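To make the discovery bookkeeping concrete, the set difference computed by findDisconnectedProxies reduces to this pattern. A worked example with plain strings (the proxy names are illustrative):

    connected := map[string]bool{"p1": true} // proxies that already have a tunnel
    all := []string{"p1", "p2", "p3"}        // every proxy registered in the cluster
    var missing []string
    for _, name := range all {
        if !connected[name] {
            missing = append(missing, name)
        }
    }
    // missing == ["p2", "p3"] — exactly what gets packed into the discovery request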
@@ -222,7 +312,7 @@ // located in a remote connected site, the connection goes through the // reverse proxy tunnel. func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { - s.log.Infof("[TUNNEL] dialing %v@%v through the tunnel", to, s.domainName) + s.Debugf("dialing %v through the tunnel", to) stop := false _, addr := to.Network(), to.String()
@@ -268,7 +358,7 @@ func (s *remoteSite) Dial(from, to net.Addr) (conn net.Conn, err error) { if err == nil { return conn, nil } - s.log.Errorf("[TUNNEL] Dial(addr=%v) failed: %v", addr, err) + s.Warningf("Dial(addr=%v) failed: %v", addr, err) } // didn't connect and no error? this means we didn't have any connected // tunnels to try
@@ -279,9 +369,9 @@ } func (s *remoteSite) handleAuthProxy(w http.ResponseWriter, r *http.Request) { - s.log.Infof("[TUNNEL] handleAuthProxy()") + s.Debugf("handleAuthProxy()") - fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.log)) + fwd, err := forward.New(forward.RoundTripper(s.transport), forward.Logger(s.Entry)) if err != nil { roundtrip.ReplyJSON(w, http.StatusInternalServerError, err.Error()) return
diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 7ff735f6486de..a9c274d99e8de 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go
@@ -158,7 +158,7 @@ func NewServer(cfg Config) (Server, error) { var err error s, err := sshutils.NewServer( - teleport.ComponentReverseTunnel, + teleport.ComponentReverseTunnelServer, cfg.ListenAddr, srv, cfg.HostSigners,
@@ -604,11 +604,29 @@ func (s *server) RemoveSite(domainName string) error { } type remoteConn struct { - sshConn ssh.Conn - conn net.Conn - invalid int32 - log *log.Entry - counter int32 + sshConn ssh.Conn + conn net.Conn + invalid int32 + log *log.Entry + counter int32 + discoveryC ssh.Channel + discoveryErr error +} + +func (rc *remoteConn) openDiscoveryChannel() (ssh.Channel, error) { + if rc.discoveryC != nil { + return rc.discoveryC, nil + } + if rc.discoveryErr != nil { + return nil, trace.Wrap(rc.discoveryErr) + } + discoveryC, _, err := rc.sshConn.OpenChannel(chanDiscovery, nil) + if err != nil { + rc.discoveryErr = err + return nil, trace.Wrap(err) + } + rc.discoveryC = discoveryC + return rc.discoveryC, nil } func (rc *remoteConn) String() string {
@@ -616,6 +634,10 @@ func (rc *remoteConn) Close() error { + if rc.discoveryC != nil { + rc.discoveryC.Close() + rc.discoveryC = nil + } return rc.sshConn.Close() }
@@ -645,13 +667,13 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { srv: srv, domainName: domainName, connInfo: connInfo, - log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnel, + Entry: log.WithFields(log.Fields{ + teleport.Component: teleport.ComponentReverseTunnelServer, teleport.ComponentFields: map[string]string{ - "domainName": domainName, - "side": "server", + "cluster": domainName, }, }), + ctx: srv.ctx, } // transport uses connection to dial out to the remote address remoteSite.transport = &http.Transport{
diff --git a/lib/utils/cli.go b/lib/utils/cli.go index 53316b326a348..248eee05b14a6 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go
@@ -44,8 +44,7 @@ const ( // InitLogger configures the global logger for a given purpose / verbosity level func InitLogger(purpose LoggingPurpose, level log.Level) { log.StandardLogger().Hooks = make(log.LevelHooks) - formatter := &trace.TextFormatter{} - formatter.DisableTimestamp = true + formatter := &trace.TextFormatter{DisableTimestamp: true}
log.SetFormatter(formatter) log.SetLevel(level) diff --git a/vendor/github.com/gravitational/trace/log.go b/vendor/github.com/gravitational/trace/log.go index 58f7858e49452..310a90af48692 100644 --- a/vendor/github.com/gravitational/trace/log.go +++ b/vendor/github.com/gravitational/trace/log.go @@ -18,7 +18,12 @@ limitations under the License. package trace import ( + "bytes" + "fmt" "regexp" + "sort" + "strings" + "time" log "github.com/sirupsen/logrus" @@ -40,20 +45,55 @@ const ( // TextFormatter is logrus-compatible formatter and adds // file and line details to every logged entry. type TextFormatter struct { - log.TextFormatter + DisableTimestamp bool } // Format implements logrus.Formatter interface and adds file and line func (tf *TextFormatter) Format(e *log.Entry) ([]byte, error) { + var file string if frameNo := findFrame(); frameNo != -1 { t := newTrace(frameNo, nil) - new := e.WithFields(log.Fields{FileField: t.Loc(), FunctionField: t.FuncName()}) - new.Time = e.Time - new.Level = e.Level - new.Message = e.Message - e = new + file = t.Loc() + } + + w := &writer{bytes.Buffer{}} + + // time + if !tf.DisableTimestamp { + w.writeField(e.Time.Format(time.RFC3339)) } - return (&tf.TextFormatter).Format(e) + + // level + w.writeField(strings.ToUpper(padMax(e.Level.String(), 4))) + + // component if present, highly visible + component, ok := e.Data[Component] + if ok { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.WriteByte('[') + w.WriteString(strings.ToUpper(padMax(fmt.Sprintf("%v", component), 11))) + w.WriteByte(']') + } + + // message + if e.Message != "" { + w.writeField(e.Message) + } + + // file, if present + if file != "" { + w.writeField(file) + } + + // rest of the fields + if len(e.Data) > 0 { + w.WriteByte(' ') + w.writeMap(e.Data) + } + w.WriteByte('\n') + return w.Bytes(), nil } // JSONFormatter implements logrus.Formatter interface and adds file and line @@ -91,3 +131,78 @@ func findFrame() int { } return -1 } + +type writer struct { + bytes.Buffer +} + +func (w *writer) writeField(value interface{}) { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.writeValue(value) +} + +func (w *writer) writeValue(value interface{}) { + stringVal, ok := value.(string) + if !ok { + stringVal = fmt.Sprint(value) + } + if !needsQuoting(stringVal) { + w.WriteString(stringVal) + } else { + w.WriteString(fmt.Sprintf("%q", stringVal)) + } +} + +func (w *writer) writeKeyValue(key string, value interface{}) { + if w.Len() > 0 { + w.WriteByte(' ') + } + w.WriteString(key) + w.WriteByte(':') + w.writeValue(value) +} + +func (w *writer) writeMap(m map[string]interface{}) { + if len(m) == 0 { + return + } + keys := make([]string, 0, len(m)) + for key := range m { + keys = append(keys, key) + } + sort.Strings(keys) + for _, key := range keys { + if key == Component { + continue + } + switch val := m[key].(type) { + case map[string]interface{}: + w.WriteString(key) + w.WriteString(":{") + w.writeMap(val) + w.WriteString(" }") + default: + w.writeKeyValue(key, val) + } + } +} + +func needsQuoting(text string) bool { + for _, ch := range text { + if ch < 32 { + return true + } + } + return false +} + +func padMax(in string, chars int) string { + switch { + case len(in) < chars: + return in + strings.Repeat(" ", chars-len(in)) + default: + return in[:chars] + } +} From 6e4d6b0cb2fe51415751ffbc5e5a4bb97daae958 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sat, 7 Oct 2017 18:11:03 -0700 Subject: [PATCH 04/24] more work, discovery works --- constants.go | 11 +- docker/Makefile | 4 + 
docker/docker-compose.yml | 4 + docker/two-tc.yaml | 4 +- lib/auth/permissions.go | 1 + lib/auth/tun.go | 28 ++-- lib/backend/dir/impl.go | 12 +- lib/reversetunnel/agent.go | 162 ++++++++++++++++--- lib/reversetunnel/agentpool.go | 148 +++++++++++++---- lib/reversetunnel/discovery.go | 45 +++++- lib/reversetunnel/localsite.go | 4 +- lib/reversetunnel/peer.go | 4 +- lib/reversetunnel/remotesite.go | 48 +++--- lib/reversetunnel/srv.go | 46 ++++-- lib/service/service.go | 3 +- lib/services/local/presence.go | 7 +- lib/services/tunnelconn.go | 6 +- lib/srv/sshserver.go | 6 +- lib/sshutils/server.go | 31 ++-- lib/state/cachingaccesspoint.go | 11 +- lib/utils/proxy/proxy.go | 8 +- vendor/github.com/gravitational/trace/log.go | 18 ++- 22 files changed, 469 insertions(+), 142 deletions(-)
diff --git a/constants.go b/constants.go index b5e8eb35080ad..d5e4c3e6a3653 100644 --- a/constants.go +++ b/constants.go
@@ -51,12 +51,6 @@ const ( ) const ( - // Component indicates a component of teleport, used for logging - Component = "component" - - // ComponentFields stores component-specific fields - ComponentFields = "fields" // ComponentReverseTunnelServer is reverse tunnel server // that together with the agent establishes a bi-directional SSH reverse tunnel // to bypass firewall restrictions
@@ -77,7 +71,10 @@ const ( ComponentProxy = "proxy" // ComponentTunClient is a tunnel client - ComponentTunClient = "tunclient" + ComponentTunClient = "client:tunnel" + + // ComponentCachingClient is a caching auth client + ComponentCachingClient = "client:cache" // DebugEnvVar tells tests to use verbose debug output DebugEnvVar = "DEBUG"
diff --git a/docker/Makefile b/docker/Makefile index ff9fab5def4c2..ffe4e9d0274c5 100644 --- a/docker/Makefile +++ b/docker/Makefile
@@ -48,3 +48,7 @@ enter-two-node: .PHONY: setup-tc setup-tc: docker exec -i two-auth /bin/bash -c "tctl -c /root/go/src/github.com/gravitational/teleport/docker/two-auth.yaml create -f /root/go/src/github.com/gravitational/teleport/docker/two-tc.yaml" + +.PHONY: delete-tc +delete-tc: + docker exec -i two-auth /bin/bash -c "tctl -c /root/go/src/github.com/gravitational/teleport/docker/two-auth.yaml rm tc/one"
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 3fbd604775e82..f76d9b0fefa31 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml
@@ -18,6 +18,8 @@ services: networks: teleport: ipv4_address: 172.10.1.1 + aliases: + - one-lb # # one-proxy is a second xproxy of the first cluster
@@ -36,6 +38,8 @@ networks: teleport: ipv4_address: 172.10.1.10 + aliases: + - one-lb # # two-auth is a auth server of the second cluster
diff --git a/docker/two-tc.yaml b/docker/two-tc.yaml index 9921e611c429d..e41b993d5d962 100644 --- a/docker/two-tc.yaml +++ b/docker/two-tc.yaml
@@ -5,5 +5,5 @@ metadata: spec: enabled: true token: "bar" - tunnel_addr: one:3024 - web_proxy_addr: one:3080 + tunnel_addr: one-lb:3024 + web_proxy_addr: one-lb:3080
diff --git a/lib/auth/permissions.go b/lib/auth/permissions.go index f7c060db30758..a2975511b2aa0 100644 --- a/lib/auth/permissions.go +++ b/lib/auth/permissions.go
@@ -176,6 +176,7 @@ func GetCheckerForBuiltinRole(role teleport.Role) (services.AccessChecker, error services.NewRule(services.KindRole, services.RO()), services.NewRule(services.KindAuthServer, services.RO()), services.NewRule(services.KindReverseTunnel, services.RO()), + services.NewRule(services.KindTunnelConnection, services.RO()), }, }, })
diff --git a/lib/auth/tun.go index
d17eef51869f2..a21d3c0eacd68 100644 --- a/lib/auth/tun.go +++ b/lib/auth/tun.go @@ -76,6 +76,8 @@ type TunClient struct { // embed auth API HTTP client Client + *log.Entry + user string // static auth servers are CAs set via configuration (--auth flag) and @@ -776,6 +778,12 @@ func NewTunClient(purpose string, return nil, trace.Wrap(err) } tc := &TunClient{ + Entry: log.WithFields(log.Fields{ + trace.Component: teleport.ComponentTunClient, + trace.ComponentFields: log.Fields{ + "purpose": purpose, + }, + }), purpose: purpose, user: user, staticAuthServers: authServers, @@ -786,7 +794,7 @@ func NewTunClient(purpose string, for _, o := range opts { o(tc) } - log.Debugf("NewTunClient(%v) with auth: %v", purpose, authServers) + tc.Debugf("created, auth servers: %v", authServers) clt, err := NewClient("http://stub:0", tc.Dial) if err != nil { @@ -799,7 +807,7 @@ func NewTunClient(purpose string, cachedAuthServers, err := tc.addrStorage.GetAddresses() if err != nil { if !trace.IsNotFound(err) { - log.Warnf("unable to load the auth server cache: %s", err.Error()) + tc.Warnf("unable to load the auth server cache: %s", err.Error()) } } else { tc.setAuthServers(cachedAuthServers) @@ -828,7 +836,7 @@ func (c *TunClient) String() string { // Close releases all the resources allocated for this client func (c *TunClient) Close() error { if c != nil { - log.Debugf("%v.Close()", c) + c.Debugf("is closing") c.GetTransport().CloseIdleConnections() c.closeOnce.Do(func() { close(c.closeC) @@ -850,7 +858,7 @@ func (c *TunClient) GetDialer() AccessPointDialer { } time.Sleep(4 * time.Duration(attempt) * dialRetryInterval) } - log.Errorf("%v: ", err) + c.Errorf("%v", err) return nil, trace.Wrap(err) } } @@ -877,7 +885,7 @@ func (c *TunClient) GetAgent() (AgentCloser, error) { // Dial dials to Auth server's HTTP API over SSH tunnel. func (c *TunClient) Dial(network, address string) (net.Conn, error) { - log.Debugf("TunClient[%s].Dial()", c.purpose) + c.Debugf("dialing %v %v", network, address) client, err := c.getClient() if err != nil { @@ -918,7 +926,7 @@ func (c *TunClient) fetchAndSync() error { // authServersSyncLoop continuously refreshes the list of available auth servers // for this client func (c *TunClient) authServersSyncLoop() { - log.Debugf("%v: authServersSyncLoop() started", c) + c.Debugf("authServersSyncLoop started") defer c.refreshTicker.Stop() // initial fetch for quick start-ups @@ -930,7 +938,7 @@ func (c *TunClient) authServersSyncLoop() { c.fetchAndSync() // received a signal to quit?
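// A minimal sketch (not taken verbatim from this patch) of the logging
// pattern the tun.go hunks above adopt: embedding a *logrus.Entry that is
// pre-populated with trace.Component fields, so every Debugf/Warnf call on
// the struct is component-tagged. tunLogger and newTunLogger are
// hypothetical names; assumes log is "github.com/sirupsen/logrus" and the
// trace constants introduced later in this series.
type tunLogger struct {
	*log.Entry
}

func newTunLogger(purpose string) *tunLogger {
	return &tunLogger{
		Entry: log.WithFields(log.Fields{
			trace.Component:       teleport.ComponentTunClient,
			trace.ComponentFields: log.Fields{"purpose": purpose},
		}),
	}
}

// With the new trace.TextFormatter, newTunLogger("proxy").Debugf("dialing %v", addr)
// renders with an upper-cased, padded [CLIENT:TUNN] component column.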
case <-c.closeC: - log.Debugf("%v: authServersSyncLoop() exited", c) + c.Debugf("authServersSyncLoop exited") return } } @@ -995,7 +1003,7 @@ func (c *TunClient) getClient() (client *ssh.Client, err error) { if len(authServers) == 0 { return nil, trace.ConnectionProblem(nil, "all auth servers are offline") } - log.Debugf("%v.authServers: %v", c, authServers) + c.Debugf("auth servers: %v", authServers) // try to connect to the 1st one who will pick up: for _, authServer := range authServers { @@ -1011,7 +1019,7 @@ func (c *TunClient) getClient() (client *ssh.Client, err error) { if trace.IsAccessDenied(err) { return nil, trace.Wrap(err) } - log.Errorf("%v.getClient() error while connecting to auth server %v: %v: throttling", c, authServer, err) + c.Errorf("getClient error while connecting to auth server %v: %v: throttling", authServer, err) c.throttleAuthServer(authServer.String()) } return nil, trace.ConnectionProblem(nil, "all auth servers are offline") @@ -1025,7 +1033,7 @@ func (c *TunClient) dialAuthServer(authServer utils.NetAddr) (sshClient *ssh.Cli } const dialRetryTimes = 1 for attempt := 0; attempt < dialRetryTimes; attempt++ { - log.Debugf("%v.Dial(to=%v, attempt=%d)", c, authServer.Addr, attempt+1) + c.Debugf("dialing %v, attempt %d", authServer.Addr, attempt+1) sshClient, err = ssh.Dial(authServer.AddrNetwork, authServer.Addr, config) // success -> get out of here if err == nil { diff --git a/lib/backend/dir/impl.go b/lib/backend/dir/impl.go index d972423ac3a46..492f50a092bcf 100644 --- a/lib/backend/dir/impl.go +++ b/lib/backend/dir/impl.go @@ -58,6 +58,8 @@ type Backend struct { // InternalClock is a test-friendly source of current time InternalClock clockwork.Clock + + *log.Entry } func (b *Backend) Clock() clockwork.Clock { @@ -82,6 +84,12 @@ func New(params backend.Params) (backend.Backend, error) { bk := &Backend{ RootDir: rootDir, InternalClock: clockwork.NewRealClock(), + Entry: log.WithFields(log.Fields{ + trace.Component: "fs", + trace.ComponentFields: log.Fields{ + "dir": rootDir, + }, + }), } locksDir := path.Join(bk.RootDir, locksBucket) @@ -244,7 +252,7 @@ func removeFiles(dir string) error { // AcquireLock grabs a lock that will be released automatically in TTL func (bk *Backend) AcquireLock(token string, ttl time.Duration) (err error) { - log.Debugf("fs.AcquireLock(%s)", token) + bk.Debugf("AcquireLock(%s)", token) if err = backend.ValidateLockTTL(ttl); err != nil { return trace.Wrap(err) @@ -271,7 +279,7 @@ func (bk *Backend) AcquireLock(token string, ttl time.Duration) (err error) { // ReleaseLock forces lock release before TTL func (bk *Backend) ReleaseLock(token string) (err error) { - log.Debugf("fs.ReleaseLock(%s)", token) + bk.Debugf("ReleaseLock(%s)", token) if err = bk.DeleteKey([]string{locksBucket}, token); err != nil { if !os.IsNotExist(err) { diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 18e47441fab27..3a51c4dfba321 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -37,10 +37,29 @@ import ( "github.com/gravitational/teleport/lib/utils/proxy" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) +const ( + // agentStateConnecting is when agent is connecting to the target + // without particular purpose + agentStateConnecting = "connecting" + // agentStateDiscovering is when agent is created with a goal + // to discover one or many proxies + agentStateDiscovering = "discovering" + // agentStateConnected means that agent has 
connected to instance + agentStateConnected = "connected" + // agentStateDiscovered means that agent has discovered the right proxy + agentStateDiscovered = "discovered" + // agentStateMissed means that agent has connected, + // but not to the one of the instances it was targeted to discover + agentStateMissed = "missed" + // agentStateClosed is for closed agents + agentStateClosed = "closed" +) + // AgentConfig holds configuration for agent type AgentConfig struct { // Addr is target address to dial @@ -60,6 +79,12 @@ type AgentConfig struct { DiscoveryC chan *discoveryRequest // Username is the name of this client used to authenticate on SSH Username string + // DiscoverProxies is set when the agent is created in discovery mode + // and is set to connect to one of the target proxies from the list + DiscoverProxies []services.Server + // Clock is a clock passed in tests, if not set wall clock + // will be used + Clock clockwork.Clock } // CheckAndSetDefaults checks parameters and sets default values @@ -85,45 +110,93 @@ func (a *AgentConfig) CheckAndSetDefaults() error { if len(a.Username) == 0 { return trace.BadParameter("missing parameter Username") } + if a.Clock == nil { + a.Clock = clockwork.NewRealClock() + } return nil } // Agent is a reverse tunnel agent running as a part of teleport Proxies // to establish outbound reverse tunnels to remote proxies type Agent struct { + sync.RWMutex *log.Entry AgentConfig ctx context.Context cancel context.CancelFunc hostKeyCallback utils.HostKeyCallback authMethods []ssh.AuthMethod + // state is the state of this agent + state string + // stateChange records last time the state was changed + stateChange time.Time + // principals is the list of principals of the server this agent + // is currently connected to + principals []string } // NewAgent returns a new reverse tunnel agent -// Parameters: -// addr points to the remote reverse tunnel server -// remoteDomainName is the domain name of the runnel server, used only for logging -// clientName is hostid.domain (where 'domain' is local domain name) func NewAgent(cfg AgentConfig) (*Agent, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } ctx, cancel := context.WithCancel(cfg.Context) - a := &Agent{ AgentConfig: cfg, - Entry: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnelAgent, - teleport.ComponentFields: map[string]interface{}{ - "remote": cfg.Addr.String(), - "client": cfg.Username, - }, - }), ctx: ctx, cancel: cancel, authMethods: []ssh.AuthMethod{ssh.PublicKeys(cfg.Signers...)}, } + if len(cfg.DiscoverProxies) == 0 { + a.state = agentStateConnecting + } else { + a.state = agentStateDiscovering + } + a.Entry = log.WithFields(log.Fields{ + trace.Component: teleport.ComponentReverseTunnelAgent, + trace.ComponentFields: map[string]interface{}{ + "remote": cfg.Addr.String(), + }, + }) a.hostKeyCallback = a.checkHostSignature + + if len(a.DiscoverProxies) != 0 { + a.setState(agentStateDiscovering) + } else { + a.setState(agentStateConnecting) + } + return a, nil } +func (a *Agent) String() string { + if len(a.DiscoverProxies) == 0 { + return fmt.Sprintf("agent -> cluster %v, target %v", a.RemoteCluster, a.Addr.String()) + } + return fmt.Sprintf("agent -> cluster %v, target %v, discover %v", a.RemoteCluster, a.Addr.String(), Proxies(a.DiscoverProxies)) +} + +func (a *Agent) getLastStateChange() time.Time { + a.RLock() + defer a.RUnlock() + return a.stateChange +} + +func (a *Agent) setState(state string) { + a.Lock() + defer 
a.Unlock() + prev := a.state + a.Debugf("changing state %v -> %v", prev, state) + a.state = state + a.stateChange = a.Clock.Now().UTC() +} + +func (a *Agent) getState() string { + a.RLock() + defer a.RUnlock() + return a.state +} + // Close signals to close all connections func (a *Agent) Close() error { a.cancel() @@ -146,9 +219,39 @@ func (a *Agent) Wait() error { return nil } -// String returns debug-friendly -func (a *Agent) String() string { - return fmt.Sprintf("tunagent(remote=%s)", a.Addr.String()) +func (a *Agent) connectedToRightProxy() bool { + principals := a.getPrincipals() + for _, proxy := range a.DiscoverProxies { + proxyID := fmt.Sprintf("%v.%v", proxy.GetName(), a.RemoteCluster) + if _, ok := principals[proxyID]; ok { + return true + } + } + return false +} + +func (a *Agent) setPrincipals(principals []string) { + a.Lock() + defer a.Unlock() + a.principals = principals +} + +func (a *Agent) getPrincipalsList() []string { + a.RLock() + defer a.RUnlock() + out := make([]string, len(a.principals)) + copy(out, a.principals) + return out +} + +func (a *Agent) getPrincipals() map[string]struct{} { + a.RLock() + defer a.RUnlock() + out := make(map[string]struct{}, len(a.principals)) + for _, p := range a.principals { + out[p] = struct{}{} + } + return out } func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.PublicKey) error { @@ -167,7 +270,7 @@ func (a *Agent) checkHostSignature(hostport string, remote net.Addr, key ssh.Pub } for _, checker := range checkers { if sshutils.KeysEqual(checker, cert.SignatureKey) { - a.Debugf("matched key %v for %v", ca.GetName(), hostport) + a.setPrincipals(cert.ValidPrincipals) return nil } } @@ -187,6 +290,16 @@ func (a *Agent) connect() (conn *ssh.Client, err error) { Timeout: defaults.DefaultDialTimeout, }) if conn != nil { + if len(a.DiscoverProxies) != 0 { + if a.connectedToRightProxy() { + a.setState(agentStateDiscovered) + } else { + a.Debugf("missed, connected to %v instead of %v", a.getPrincipalsList(), Proxies(a.DiscoverProxies)) + a.setState(agentStateMissed) + } + } else { + a.setState(agentStateConnected) + } break } } @@ -345,6 +458,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { } newAccesspointC := conn.HandleChannelOpen(chanAccessPoint) newTransportC := conn.HandleChannelOpen(chanTransport) + newDiscoveryC := conn.HandleChannelOpen(chanDiscovery) // send first ping right away, then start a ping timer: hb.SendRequest("ping", false, nil) @@ -392,17 +506,17 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { } go a.proxyTransport(ch, req) // new discovery request - case nch := <-newTransportC: + case nch := <-newDiscoveryC: if nch == nil { continue } - a.Debugf("transport request: %v", nch.ChannelType()) + a.Debugf("discovery request: %v", nch.ChannelType()) ch, req, err := nch.Accept() if err != nil { a.Warningf("failed to accept request: %v", err) continue } - go a.proxyTransport(ch, req) + go a.handleDiscovery(ch, req) } } } @@ -410,12 +524,17 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { // run heartbeat loop, and when it fails (probably means that a tunnel got disconnected) // keep repeating to reconnect until we're asked to stop err := heartbeatLoop() + if len(a.DiscoverProxies) != 0 { + a.setState(agentStateDiscovering) + } else { + a.setState(agentStateConnecting) + } // when this happens, this is #1 issue we have right now with Teleport. So I'm making // it EASY to see in the logs. 
This condition should never be permanent (like repeates // every XX seconds) if err != nil { - log.Warn(err) + a.Warn(err) } if err != nil || conn == nil { @@ -456,6 +575,8 @@ func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { a.Warningf("bad payload: %v", err) return } + r.ClusterName = a.RemoteCluster + r.ClusterAddr = a.Addr select { case a.DiscoveryC <- r: case <-a.ctx.Done(): @@ -463,7 +584,6 @@ func (a *Agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) { return default: } - req.Reply(true, []byte("thanks")) } } } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 4ccd2f74a3720..289baf113a999 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -25,7 +25,7 @@ type AgentPool struct { sync.Mutex *log.Entry cfg AgentPoolConfig - agents map[agentKey]*Agent + agents map[agentKey][]*Agent ctx context.Context cancel context.CancelFunc discoveryC chan *discoveryRequest @@ -45,6 +45,8 @@ type AgentPoolConfig struct { HostUUID string // Context is an optional context Context context.Context + // Cluster is a cluster name + Cluster string } // CheckAndSetDefaults checks and sets defaults @@ -74,14 +76,17 @@ func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { } ctx, cancel := context.WithCancel(cfg.Context) pool := &AgentPool{ - agents: make(map[agentKey]*Agent), + agents: make(map[agentKey][]*Agent), cfg: cfg, ctx: ctx, cancel: cancel, discoveryC: make(chan *discoveryRequest), } pool.Entry = log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnelAgent, + trace.Component: teleport.ComponentReverseTunnelAgent, + trace.ComponentFields: map[string]interface{}{ + "cluster": cfg.Cluster, + }, }) return pool, nil } @@ -89,6 +94,7 @@ func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { // Start starts the agent pool func (m *AgentPool) Start() error { go m.pollAndSyncAgents() + go m.processDiscoveryRequests() return nil } @@ -117,11 +123,38 @@ func (m *AgentPool) processDiscoveryRequests() { m.Debugf("channel closed") return } - m.Debugf("got discovery request, following proxies are not connected %v", req.Proxies) + m.tryDiscover(*req) } } } +func (m *AgentPool) tryDiscover(req discoveryRequest) { + m.Debugf("will need to discover: %v", Proxies(req.Proxies)) + m.Lock() + defer m.Unlock() + proxies := Proxies(req.Proxies) + matchKey := req.key() + var foundAgent bool + // close agents that are discovering proxies that are somehow + // different from discovery request + m.closeAgentsIf(&matchKey, func(agent *Agent) bool { + if agent.getState() != agentStateDiscovering { + return false + } + if proxies.Equal(agent.DiscoverProxies) { + foundAgent = true + agent.Debugf("is already discovering %v, nothing to do", req) + return false + } + return true + }) + + // if we haven't found any discovery agent + if !foundAgent { + m.addAgent(req.key(), req.Proxies) + } +} + // FetchAndSyncAgents executes one time fetch and sync request // (used in tests instead of polling) func (m *AgentPool) FetchAndSyncAgents() error { @@ -135,6 +168,45 @@ func (m *AgentPool) FetchAndSyncAgents() error { return nil } +func (m *AgentPool) withLock(f func()) { + m.Lock() + defer m.Unlock() + f() +} + +type matchAgentFn func(a *Agent) bool + +func (m *AgentPool) closeAgentsIf(matchKey *agentKey, matchAgent matchAgentFn) { + if matchKey != nil { + m.agents[*matchKey] = filterAndClose(m.agents[*matchKey], matchAgent) + return + } + for key, agents := range m.agents { + m.agents[key] = 
filterAndClose(agents, matchAgent) + } +} + +func filterAndClose(agents []*Agent, matchAgent matchAgentFn) []*Agent { + var filtered []*Agent + for i := range agents { + agent := agents[i] + if matchAgent(agent) { + agent.Debugf("pool is closing agent") + agent.Close() + } else { + filtered = append(filtered, agent) + } + } + return filtered +} + +func (m *AgentPool) closeAgents(matchKey *agentKey) { + m.closeAgentsIf(matchKey, func(*Agent) bool { + // close all agents matching the matchKey + return true + }) +} + func (m *AgentPool) pollAndSyncAgents() { ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() @@ -142,12 +214,10 @@ func (m *AgentPool) pollAndSyncAgents() { for { select { case <-m.ctx.Done(): + m.withLock(func() { + m.closeAgents(nil) + }) m.Debugf("closing") - m.Lock() - defer m.Unlock() - for _, a := range m.agents { - a.Close() - } return case <-ticker.C: err := m.FetchAndSyncAgents() @@ -159,6 +229,31 @@ func (m *AgentPool) pollAndSyncAgents() { } } +func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) error { + agent, err := NewAgent(AgentConfig{ + Addr: key.addr, + RemoteCluster: key.domainName, + Username: m.cfg.HostUUID, + Signers: m.cfg.HostSigners, + Client: m.cfg.Client, + AccessPoint: m.cfg.AccessPoint, + Context: m.ctx, + DiscoveryC: m.discoveryC, + DiscoverProxies: discoverProxies, + }) + if err != nil { + return trace.Wrap(err) + } + m.Debugf("adding %v", agent) + // start the agent in a goroutine. no need to handle Start() errors: Start() will be + // retrying itself until the agent is closed + go agent.Start() + agents, _ := m.agents[key] + agents = append(agents, agent) + m.agents[key] = agents + return nil +} + func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { m.Lock() defer m.Unlock() @@ -168,33 +263,24 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { return trace.Wrap(err) } agentsToAdd, agentsToRemove := diffTunnels(m.agents, keys) + // remove agents from deleted reverse tunnels for _, key := range agentsToRemove { - m.Debugf("removing %v", &key) - agent := m.agents[key] - delete(m.agents, key) - agent.Close() + m.closeAgents(&key) } - + // add agents from added reverse tunnels for _, key := range agentsToAdd { - m.Debugf("adding %v", &key) - agent, err := NewAgent(AgentConfig{ - Addr: key.addr, - RemoteCluster: key.domainName, - Username: m.cfg.HostUUID, - Signers: m.cfg.HostSigners, - Client: m.cfg.Client, - AccessPoint: m.cfg.AccessPoint, - Context: m.ctx, - DiscoveryC: m.discoveryC, - }) - if err != nil { + if err := m.addAgent(key, nil); err != nil { return trace.Wrap(err) } - // start the agent in a goroutine. 
no need to handle Start() errors: Start() will be - // retrying itself until the agent is closed - go agent.Start() - m.agents[key] = agent } + // garbage collect agents that have connected, but to the wrong proxy + m.closeAgentsIf(nil, func(agent *Agent) bool { + if agent.getState() == agentStateMissed { + agent.Debugf("closing agent that could not discover clusters") + return true + } + return false + }) return nil } @@ -224,7 +310,7 @@ func tunnelToAgentKeys(tunnel services.ReverseTunnel) ([]agentKey, error) { return out, nil } -func diffTunnels(existingTunnels map[agentKey]*Agent, arrivedKeys map[agentKey]bool) ([]agentKey, []agentKey) { +func diffTunnels(existingTunnels map[agentKey][]*Agent, arrivedKeys map[agentKey]bool) ([]agentKey, []agentKey) { var agentsToRemove, agentsToAdd []agentKey for existingKey := range existingTunnels { if _, ok := arrivedKeys[existingKey]; !ok { // agent was removed diff --git a/lib/reversetunnel/discovery.go b/lib/reversetunnel/discovery.go index 1f3584ffecd33..8d71dab63572f 100644 --- a/lib/reversetunnel/discovery.go +++ b/lib/reversetunnel/discovery.go @@ -2,14 +2,57 @@ package reversetunnel import ( "encoding/json" + "fmt" + "strings" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" ) type discoveryRequest struct { - Proxies []services.Server `json:"proxies"` + ClusterName string `json:"-"` + ClusterAddr utils.NetAddr `json:"-"` + Proxies []services.Server `json:"proxies"` +} + +type Proxies []services.Server + +func (proxies Proxies) String() string { + var out []string + for _, proxy := range proxies { + out = append(out, proxy.GetName()) + } + return strings.Join(out, ",") +} + +func (proxies Proxies) Equal(other []services.Server) bool { + if len(proxies) != len(other) { + return false + } + proxiesMap, otherMap := make(map[string]bool), make(map[string]bool) + for i := range proxies { + proxiesMap[proxies[i].GetName()] = true + } + for i := range other { + otherMap[other[i].GetName()] = true + } + for key := range otherMap { + if !proxiesMap[key] { + return false + } + } + return true +} + +func (r discoveryRequest) key() agentKey { + return agentKey{domainName: r.ClusterName, addr: r.ClusterAddr} +} + +func (r discoveryRequest) String() string { + return fmt.Sprintf("discovery request, cluster name: %v, address: %v, proxies: %v", + r.ClusterName, r.ClusterAddr, Proxies(r.Proxies)) } type discoveryRequestRaw struct { diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 5a38d4871f533..e3c822f24c197 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -40,8 +40,8 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi accessPoint: accessPoint, domainName: domainName, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: map[string]string{ + trace.Component: teleport.ComponentReverseTunnelServer, + trace.ComponentFields: map[string]string{ "cluster": domainName, }, }), diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 7035fae25a5c8..a4f83a3f661ad 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -127,8 +127,8 @@ func newClusterPeer(srv *server, connInfo services.TunnelConnection) (*clusterPe srv: srv, connInfo: connInfo, log: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: map[string]string{ + 
trace.Component: teleport.ComponentReverseTunnelServer, + trace.ComponentFields: map[string]string{ "cluster": connInfo.GetClusterName(), }, }), diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index a299509abb2f6..c35e12af5a0a5 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -34,9 +34,9 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - + "github.com/jonboulle/clockwork" "github.com/mailgun/oxy/forward" + log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -57,6 +57,7 @@ type remoteSite struct { accessPoint auth.AccessPoint connInfo services.TunnelConnection ctx context.Context + clock clockwork.Clock } func (s *remoteSite) CachingAccessPoint() (auth.AccessPoint, error) { @@ -137,11 +138,10 @@ func (s *remoteSite) GetStatus() string { if err != nil { return RemoteSiteStatusOffline } - diff := time.Now().Sub(connInfo.GetLastHeartbeat()) - if diff > defaults.ReverseTunnelOfflineThreshold { - return RemoteSiteStatusOffline + if s.isOnline(connInfo) { + return RemoteSiteStatusOnline } - return RemoteSiteStatusOnline + return RemoteSiteStatusOffline } func (s *remoteSite) registerHeartbeat(t time.Time) { @@ -192,7 +192,7 @@ func (s *remoteSite) periodicSendDiscoveryRequests() { ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() if err := s.sendDiscoveryRequest(); err != nil { - s.Warningf("failed to fetch cluster peers: %v", err) + s.Warningf("failed to send discovery: %v", err) } for { select { @@ -202,12 +202,17 @@ func (s *remoteSite) periodicSendDiscoveryRequests() { case <-ticker.C: err := s.sendDiscoveryRequest() if err != nil { - s.Warningf("could not send discovery request: %v", err) + s.Warningf("could not send discovery request: %v", trace.DebugReport(err)) } } } } +func (s *remoteSite) isOnline(conn services.TunnelConnection) bool { + diff := s.clock.Now().Sub(conn.GetLastHeartbeat()) + return diff < defaults.ReverseTunnelOfflineThreshold +} + // findDisconnectedProxies func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { conns, err := s.srv.AccessPoint.GetTunnelConnections(s.domainName) @@ -216,7 +221,9 @@ func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { } connected := make(map[string]bool) for _, conn := range conns { - connected[conn.GetProxyName()] = true + if s.isOnline(conn) { + connected[conn.GetProxyName()] = true + } } proxies, err := s.srv.AccessPoint.GetProxies() if err != nil { @@ -240,7 +247,7 @@ func (s *remoteSite) sendDiscoveryRequest() error { if len(disconnectedProxies) == 0 { return nil } - s.Infof("detected disconnected proxies: %v", disconnectedProxies) + s.Debugf("going to request discovery for: %v", Proxies(disconnectedProxies)) req := discoveryRequest{ Proxies: disconnectedProxies, } @@ -257,18 +264,21 @@ func (s *remoteSite) sendDiscoveryRequest() error { if err != nil { return trace.Wrap(err) } - _, err = discoveryC.SendRequest("ping", false, payload) - remoteConn.markInvalid(err) - s.Errorf("disconnecting cluster on %v, err: %v", - remoteConn.conn.RemoteAddr(), - err) - return trace.Wrap(err) + _, err = discoveryC.SendRequest("discovery", false, payload) + if err != nil { + remoteConn.markInvalid(err) + s.Errorf("disconnecting cluster on %v, err: %v", + remoteConn.conn.RemoteAddr(), + err) + return trace.Wrap(err) + } + return nil } for i := 0; i < s.connectionCount(); i++ { err := send() if err != nil { - 
s.Warningf("%v") + s.Warningf("%v", err) } } return nil @@ -277,8 +287,6 @@ func (s *remoteSite) sendDiscoveryRequest() error { // dialAccessPoint establishes a connection from the proxy (reverse tunnel server) // back into the client using previously established tunnel. func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { - s.Debugf("dialAccessPoint") - try := func() (net.Conn, error) { remoteConn, err := s.nextConn() if err != nil { @@ -292,7 +300,7 @@ func (s *remoteSite) dialAccessPoint(network, addr string) (net.Conn, error) { err) return nil, trace.Wrap(err) } - s.Infof("success dialing to cluster") + s.Debugf("success dialing to cluster") return utils.NewChConn(remoteConn.sshConn, ch), nil } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index a9c274d99e8de..d0d63e9195576 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" ) @@ -76,6 +77,8 @@ type server struct { // ctx is a context used for signalling and broadcast ctx context.Context + + *log.Entry } // DirectCluster is used to access cluster directly @@ -106,6 +109,9 @@ type Config struct { DirectClusters []DirectCluster // Context is a signalling context Context context.Context + // Clock is a clock used in the server, set up to + // wall clock if not set + Clock clockwork.Clock } // CheckAndSetDefaults checks parameters and sets default values @@ -126,6 +132,9 @@ func (cfg *Config) CheckAndSetDefaults() error { return trace.Wrap(err) } } + if cfg.Clock == nil { + cfg.Clock = clockwork.NewRealClock() + } return nil } @@ -146,6 +155,9 @@ func NewServer(cfg Config) (Server, error) { ctx: ctx, cancel: cancel, clusterPeers: make(map[string]*clusterPeers), + Entry: log.WithFields(log.Fields{ + trace.Component: teleport.ComponentReverseTunnelServer, + }), } for _, clusterInfo := range cfg.DirectClusters { @@ -181,17 +193,17 @@ func (s *server) periodicFetchClusterPeers() { ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) defer ticker.Stop() if err := s.fetchClusterPeers(); err != nil { - log.Warningf("[TUNNEL] failed to fetch cluster peers: %v", err) + s.Warningf("failed to fetch cluster peers: %v", err) } for { select { case <-s.ctx.Done(): - log.Debugf("[TUNNEL] closing") + s.Debugf("closing") return case <-ticker.C: err := s.fetchClusterPeers() if err != nil { - log.Warningf("[TUNNEL] failed to fetch cluster peers: %v", err) + s.Warningf("failed to fetch cluster peers: %v", err) } } } @@ -266,11 +278,11 @@ func (s *server) removeClusterPeers(conns []services.TunnelConnection) { for _, conn := range conns { peers, ok := s.clusterPeers[conn.GetClusterName()] if !ok { - log.Warningf("[TUNNEL] failed to remove cluster peer, not found peers for %v", conn) + s.Warningf("failed to remove cluster peer, not found peers for %v", conn) continue } peers.removePeer(conn) - log.Debugf("[TUNNEL] removed cluster peer %v", conn) + s.Debugf("removed cluster peer %v", conn) } } @@ -338,13 +350,13 @@ func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.New if ct == "session" { msg = "Cannot open new SSH session on reverse tunnel. Are you connecting to the right port?" 
} - log.Warningf(msg) + s.Warning(msg) nch.Reject(ssh.ConnectionFailed, msg) return } - log.Debugf("[TUNNEL] new tunnel from %s", sconn.RemoteAddr()) + s.Debugf("new tunnel from %s", sconn.RemoteAddr()) if sconn.Permissions.Extensions[extCertType] != extCertTypeHost { - log.Error(trace.BadParameter("can't retrieve certificate type in certType")) + s.Error(trace.BadParameter("can't retrieve certificate type in certType")) return } // add the incoming site (cluster) to the list of active connections: @@ -369,7 +381,7 @@ func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.New func (s *server) isHostAuthority(auth ssh.PublicKey) bool { keys, err := s.getTrustedCAKeys(services.HostCA) if err != nil { - log.Errorf("failed to retrieve trusted keys, err: %v", err) + s.Errorf("failed to retrieve trusted keys, err: %v", err) return false } for _, k := range keys { @@ -385,7 +397,7 @@ func (s *server) isHostAuthority(auth ssh.PublicKey) bool { func (s *server) isUserAuthority(auth ssh.PublicKey) bool { keys, err := s.getTrustedCAKeys(services.UserCA) if err != nil { - log.Errorf("failed to retrieve trusted keys, err: %v", err) + s.Errorf("failed to retrieve trusted keys, err: %v", err) return false } for _, k := range keys { @@ -435,7 +447,7 @@ func (s *server) checkTrustedKey(CertType services.CertAuthType, domainName stri } func (s *server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - logger := log.WithFields(log.Fields{ + logger := s.WithFields(log.Fields{ "remote": conn.RemoteAddr(), "user": conn.User(), }) @@ -533,8 +545,7 @@ func (s *server) upsertSite(conn net.Conn, sshConn *ssh.ServerConn) (*remoteSite } s.remoteSites = append(s.remoteSites, site) } - log.Infof("[TUNNEL] cluster %v connected from %v. 
clusters: %d", - domainName, conn.RemoteAddr(), len(s.remoteSites)) + site.Infof("connection <- %v, clusters: %d", conn.RemoteAddr(), len(s.remoteSites)) // treat first connection as a registered heartbeat, // otherwise the connection information will appear after initial // heartbeat delay @@ -668,12 +679,13 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { domainName: domainName, connInfo: connInfo, Entry: log.WithFields(log.Fields{ - teleport.Component: teleport.ComponentReverseTunnelServer, - teleport.ComponentFields: map[string]string{ + trace.Component: teleport.ComponentReverseTunnelServer, + trace.ComponentFields: map[string]interface{}{ "cluster": domainName, }, }), - ctx: srv.ctx, + ctx: srv.ctx, + clock: srv.Clock, } // transport uses connection do dial out to the remote address remoteSite.transport = &http.Transport{ @@ -692,6 +704,8 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { remoteSite.accessPoint = accessPoint + go remoteSite.periodicSendDiscoveryRequests() + return remoteSite, nil } diff --git a/lib/service/service.go b/lib/service/service.go index ff98c8f75e1f7..573be7ed98c71 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -690,7 +690,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { tsrv, err := reversetunnel.NewServer( reversetunnel.Config{ - ID: conn.Identity.ID.HostUUID, + ID: process.Config.HostUUID, ListenAddr: cfg.Proxy.ReverseTunnelListenAddr, HostSigners: []ssh.Signer{conn.Identity.KeySigner}, AccessPoint: authClient, @@ -732,6 +732,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { Client: conn.Client, AccessPoint: authClient, HostSigners: []ssh.Signer{conn.Identity.KeySigner}, + Cluster: conn.Identity.Cert.Extensions[utils.CertExtensionAuthority], }) if err != nil { return trace.Wrap(err) diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index 0b7577db89d5a..e6c93f3b02040 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" ) // PresenceService records and reports the presence of all components @@ -363,7 +364,11 @@ func (s *PresenceService) GetTunnelConnection(clusterName, connectionName string } return nil, trace.Wrap(err) } - return services.UnmarshalTunnelConnection(data) + conn, err := services.UnmarshalTunnelConnection(data) + if err != nil { + log.Debugf("got some problem with data: %q", string(data)) + } + return conn, err } // GetTunnelConnections returns connections for a trusted cluster diff --git a/lib/services/tunnelconn.go b/lib/services/tunnelconn.go index d8fd4ce6268e4..f7127bf12b2d1 100644 --- a/lib/services/tunnelconn.go +++ b/lib/services/tunnelconn.go @@ -3,6 +3,7 @@ package services import ( "encoding/json" "fmt" + "runtime/debug" "strings" "time" @@ -172,7 +173,7 @@ type TunnelConnectionSpecV2 struct { // ProxyName is the name of the proxy server ProxyName string `json:"proxy_name"` // LastHeartbeat is a time of the last heartbeat - LastHeartbeat time.Time `json:"last_heartbeat"` + LastHeartbeat time.Time `json:"last_heartbeat,omitempty"` } // TunnelConnectionSpecV2Schema is JSON schema for reverse tunnel spec @@ -197,7 +198,8 @@ func GetTunnelConnectionSchema() string { // sets defaults and checks the schema func UnmarshalTunnelConnection(data []byte) (TunnelConnection, error) { if len(data) == 0 { - return 
nil, trace.BadParameter("missing tunnel data") + debug.PrintStack() + return nil, trace.BadParameter("missing tunnel connection data") } var h ResourceHeader err := json.Unmarshal(data, &h) diff --git a/lib/srv/sshserver.go b/lib/srv/sshserver.go index e5b4054578fda..fe93f48ae7893 100644 --- a/lib/srv/sshserver.go +++ b/lib/srv/sshserver.go @@ -301,7 +301,7 @@ func (s *Server) getNamespace() string { return services.ProcessNamespace(s.namespace) } -func (s *Server) logFields(fields map[string]interface{}) log.Fields { +func (s *Server) logFields(fields log.Fields) log.Fields { var component string if s.proxyMode { component = teleport.ComponentProxy @@ -309,8 +309,8 @@ func (s *Server) logFields(fields map[string]interface{}) log.Fields { component = teleport.ComponentNode } return log.Fields{ - teleport.Component: component, - teleport.ComponentFields: fields, + trace.Component: component, + trace.ComponentFields: fields, } } diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index 8554fecfe21d6..d8f9cd1e26712 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -42,6 +42,8 @@ import ( // Server is a generic implementation of an SSH server. All Teleport // services (auth, proxy, ssh) use this as a base to accept SSH connections. type Server struct { + *log.Entry + // component is a name of the facility which uses this server, // used for logging/debugging. typically it's "proxy" or "auth api", etc component string @@ -107,10 +109,13 @@ func NewServer( return nil, err } s := &Server{ - component: component, + Entry: log.WithFields(log.Fields{ + trace.Component: "ssh:" + component, + }), addr: a, newChanHandler: h, closeC: make(chan struct{}), + component: component, } s.limiter, err = limiter.NewLimiter(limiter.LimiterConfig{}) if err != nil { @@ -151,7 +156,7 @@ func SetRequestHandler(req RequestHandler) ServerOption { func SetCiphers(ciphers []string) ServerOption { return func(s *Server) error { - log.Debugf("[SSH:%v] Supported Ciphers: %q", s.component, ciphers) + s.Debugf("supported ciphers: %q", ciphers) if ciphers != nil { s.cfg.Ciphers = ciphers } @@ -161,7 +166,7 @@ func SetCiphers(ciphers []string) ServerOption { func SetKEXAlgorithms(kexAlgorithms []string) ServerOption { return func(s *Server) error { - log.Debugf("[SSH:%v] Supported KEX algorithms: %q", s.component, kexAlgorithms) + s.Debugf("supported KEX algorithms: %q", kexAlgorithms) if kexAlgorithms != nil { s.cfg.KeyExchanges = kexAlgorithms } @@ -171,7 +176,7 @@ func SetKEXAlgorithms(kexAlgorithms []string) ServerOption { func SetMACAlgorithms(macAlgorithms []string) ServerOption { return func(s *Server) error { - log.Debugf("[SSH:%v] Supported MAC algorithms: %q", s.component, macAlgorithms) + s.Debugf("supported MAC algorithms: %q", macAlgorithms) if macAlgorithms != nil { s.cfg.MACs = macAlgorithms } @@ -190,7 +195,7 @@ func (s *Server) Start() error { return err } s.listener = socket - log.Infof("[SSH:%s] listening socket: %v", s.component, socket.Addr()) + s.Infof("listening socket: %v", socket.Addr()) go s.acceptConnections() return nil } @@ -215,21 +220,21 @@ func (s *Server) Close() error { func (s *Server) acceptConnections() { defer s.notifyClosed() addr := s.Addr() - log.Infof("[SSH:%v] is listening on %v", s.component, addr) + s.Infof("listening on %v", addr) for { conn, err := s.listener.Accept() if err != nil { if s.askedToClose { - log.Infof("[SSH:%v] server %v exited", s.component, addr) + s.Infof("server %v exited", addr) s.askedToClose = false return } // our best shot to avoid 
excessive logging if op, ok := err.(*net.OpError); ok && !op.Timeout() { - log.Debugf("[SSH:%v] closed socket %v", s.component, op) + s.Debugf("closed socket %v", op) return } - log.Errorf("SSH:%v accept error: %T %v", s.component, err, err) + s.Errorf("accept error: %T %v", err, err) return } go s.handleConnection(conn) } } @@ -278,12 +283,12 @@ func (s *Server) handleConnection(conn net.Conn) { return } // Connection successfully initiated - log.Infof("[SSH:%v] new connection %v -> %v vesion: %v", - s.component, sconn.RemoteAddr(), sconn.LocalAddr(), string(sconn.ClientVersion())) + s.Debugf("incoming connection %v -> %v version: %v", + sconn.RemoteAddr(), sconn.LocalAddr(), string(sconn.ClientVersion())) // will be called when the connection is closed connClosed := func() { - log.Infof("[SSH:%v] closed connection", s.component) + s.Debugf("closed connection %v", sconn.RemoteAddr()) } // The keepalive ticket will ensure that SSH keepalive requests are being sent @@ -300,7 +305,7 @@ func (s *Server) handleConnection(conn net.Conn) { connClosed() return } - log.Infof("[SSH:%v] recieved out-of-band request: %+v", s.component, req) + s.Debugf("received out-of-band request: %+v", req) if s.reqHandler != nil { go s.reqHandler.HandleRequest(req) } diff --git a/lib/state/cachingaccesspoint.go index eeff683a55337..dcd587564809c 100644 --- a/lib/state/cachingaccesspoint.go +++ b/lib/state/cachingaccesspoint.go @@ -21,6 +21,7 @@ import ( "fmt" "time" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/defaults" @@ -57,6 +58,7 @@ const ( // This which can be used if the upstream AccessPoint goes offline type CachingAuthClient struct { Config + *log.Entry // ap points to the access ponit we're caching access to: ap auth.AccessPoint @@ -115,6 +117,9 @@ func NewCachingAuthClient(config Config) (*CachingAuthClient, error) { trust: local.NewCAService(config.Backend), access: local.NewAccessService(config.Backend), presence: local.NewPresenceService(config.Backend), + Entry: log.WithFields(log.Fields{ + trace.Component: teleport.ComponentCachingClient, + }), } if !cs.SkipPreload { err := cs.fetchAll() @@ -122,7 +127,7 @@ func NewCachingAuthClient(config Config) (*CachingAuthClient, error) { // we almost always get some "access denied" errors here because // not all cacheable resources are available (for example nodes do // not have access to tunnels) - log.Debugf("Auth cache: %v", err) + cs.Debugf("auth cache: %v", err) } } return cs, nil @@ -504,14 +509,14 @@ func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnecti func (cs *CachingAuthClient) try(f func() error) error { tooSoon := cs.lastErrorTime.Add(backoffDuration).After(time.Now()) if tooSoon { - log.Warnf("Backoff: using cached value due to recent errors") + cs.Warnf("backoff: using cached value due to recent errors") return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors") } accessPointRequests.Inc() err := trace.ConvertSystemError(f()) if trace.IsConnectionProblem(err) { cs.lastErrorTime = time.Now() - log.Warningf("Connection Problem: failed connect to the auth servers, using local cache") + cs.Warningf("connection problem: failed to connect to the auth servers, using local cache") } return err } diff --git a/lib/utils/proxy/proxy.go index 66bd0aa49a55b..5a5616b484863 100644 --- a/lib/utils/proxy/proxy.go +++
b/lib/utils/proxy/proxy.go @@ -162,20 +162,22 @@ func getProxyAddress() string { strings.ToLower(teleport.HTTPProxy), } + l := log.WithFields(log.Fields{trace.Component: "http:proxy"}) + for _, v := range envs { addr := os.Getenv(v) if addr != "" { proxyaddr, err := parse(addr) if err != nil { - log.Debugf("[HTTP PROXY] Unable to parse environment variable %q: %q.", v, addr) + l.Debugf("unable to parse environment variable %q: %q.", v, addr) continue } - log.Debugf("[HTTP PROXY] Successfully parsed environment variable %q: %q to %q", v, addr, proxyaddr) + l.Debugf("successfully parsed environment variable %q: %q to %q", v, addr, proxyaddr) return proxyaddr } } - log.Debugf("[HTTP PROXY] No valid environment variables found.") + l.Debugf("no valid environment variables found.") return "" } diff --git a/vendor/github.com/gravitational/trace/log.go index 310a90af48692..f42edcc2e7f9b 100644 --- a/vendor/github.com/gravitational/trace/log.go +++ b/vendor/github.com/gravitational/trace/log.go @@ -39,17 +39,26 @@ const ( LevelField = "level" // Component is a field that represents component - e.g. service or // function - Component = "component" + Component = "trace.component" + // ComponentFields stores component-specific fields + ComponentFields = "trace.fields" ) // TextFormatter is logrus-compatible formatter and adds // file and line details to every logged entry. type TextFormatter struct { + // DisableTimestamp disables timestamp output (useful when outputting to + // systemd logs) DisableTimestamp bool } // Format implements logrus.Formatter interface and adds file and line -func (tf *TextFormatter) Format(e *log.Entry) ([]byte, error) { +func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { + defer func() { + if r := recover(); r != nil { + err = BadParameter("recovered from panic: %v", r) + } + }() var file string if frameNo := findFrame(); frameNo != -1 { t := newTrace(frameNo, nil) file = t.Loc() } @@ -183,6 +192,11 @@ func (w *writer) writeMap(m map[string]interface{}) { w.WriteString(":{") w.writeMap(val) w.WriteString(" }") + case log.Fields: + w.WriteString(key) + w.WriteString(":{") + w.writeMap(val) + w.WriteString(" }") default: w.writeKeyValue(key, val) } From d3f05872ccb4d7f6adc0d528e99c313a6afc041a Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Oct 2017 12:42:16 -0700 Subject: [PATCH 05/24] fix some backend problems --- lib/backend/dir/impl.go | 15 ++++++++++++--- lib/reversetunnel/remotesite.go | 5 ++--- lib/services/local/presence.go | 10 +++++++--- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/lib/backend/dir/impl.go index 492f50a092bcf..6135aa0152bff 100644 --- a/lib/backend/dir/impl.go +++ b/lib/backend/dir/impl.go @@ -85,7 +85,7 @@ func New(params backend.Params) (backend.Backend, error) { RootDir: rootDir, InternalClock: clockwork.NewRealClock(), Entry: log.WithFields(log.Fields{ - trace.Component: "fs", + trace.Component: "backend:dir", trace.ComponentFields: log.Fields{ "dir": rootDir, }, @@ -176,17 +176,22 @@ func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) { } if expired { bk.DeleteKey(bucket, key) - return nil, trace.NotFound("key '%s' is not found", key) + return nil, trace.NotFound("key %q is not found", key) } fp := path.Join(dirPath, key) bytes, err := ioutil.ReadFile(fp) if err != nil { // GetVal() on a bucket must return 'BadParameter' error: if fi, _ := os.Stat(fp); fi != nil && fi.IsDir() { return nil, trace.BadParameter("%s is not a
valid key", key) + return nil, trace.BadParameter("%q is not a valid key", key) } return nil, trace.ConvertSystemError(err) } + // this could happen if we delete the file concurrently + // with the read, apparently we can read empty file back + if len(bytes) == 0 { + return nil, trace.NotFound("key %q is not found", key) + } return bytes, nil } @@ -314,6 +319,10 @@ func (bk *Backend) checkTTL(dirPath string, key string) (expired bool, err error } return false, trace.Wrap(err) } + // this could happen if file was deleted, we can sometimes read empty contents + if len(bytes) == 0 { + return false, nil + } var expiryTime time.Time if err = expiryTime.UnmarshalText(bytes); err != nil { return false, trace.Wrap(err) diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index c35e12af5a0a5..c7dfc6435c06a 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -146,6 +146,7 @@ func (s *remoteSite) GetStatus() string { func (s *remoteSite) registerHeartbeat(t time.Time) { s.connInfo.SetLastHeartbeat(t) + s.connInfo.SetExpiry(s.clock.Now().Add(defaults.ReverseTunnelOfflineThreshold)) err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) if err != nil { log.Warningf("failed to register heartbeat: %v", err) @@ -221,9 +222,7 @@ func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { } connected := make(map[string]bool) for _, conn := range conns { - if s.isOnline(conn) { - connected[conn.GetProxyName()] = true - } + connected[conn.GetProxyName()] = true } proxies, err := s.srv.AccessPoint.GetProxies() if err != nil { diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go index e6c93f3b02040..035fa99d96626 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -123,17 +123,20 @@ func (s *PresenceService) getServers(kind, prefix string) ([]services.Server, er if err != nil { return nil, trace.Wrap(err) } - servers := make([]services.Server, len(keys)) - for i, key := range keys { + servers := make([]services.Server, 0, len(keys)) + for _, key := range keys { data, err := s.GetVal([]string{prefix}, key) if err != nil { + if trace.IsNotFound(err) { + continue + } return nil, trace.Wrap(err) } server, err := services.GetServerMarshaler().UnmarshalServer(data, kind) if err != nil { return nil, trace.Wrap(err) } - servers[i] = server + servers = append(servers, server) } // sorting helps with tests and makes it all deterministic sort.Sort(services.SortedServers(servers)) @@ -390,6 +393,7 @@ func (s *PresenceService) GetTunnelConnections(clusterName string) ([]services.T if !trace.IsNotFound(err) { return nil, trace.Wrap(err) } + continue } conns = append(conns, conn) } From bb5f77854e316fecaa5858b0a6b0934de219a34b Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Oct 2017 18:07:01 -0700 Subject: [PATCH 06/24] before refactoring --- lib/reversetunnel/agent.go | 56 ++++++++++++++++------ lib/reversetunnel/agentpool.go | 83 +++++++++++++++++++++++++++------ lib/reversetunnel/remotesite.go | 23 +++++++-- lib/utils/cli.go | 2 +- 4 files changed, 129 insertions(+), 35 deletions(-) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 3a51c4dfba321..8d24a51837694 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -53,11 +53,6 @@ const ( agentStateConnected = "connected" // agentStateDiscovered means that agent has discovered the right proxy agentStateDiscovered = "discovered" - // agentStateMissed means that agent has connected, 
- but not to the one of the instances it was targeted to discover - agentStateMissed = "missed" - // agentStateClosed is for closed agents - agentStateClosed = "closed" ) // AgentConfig holds configuration for agent @@ -171,9 +166,9 @@ func NewAgent(cfg AgentConfig) (*Agent, error) { func (a *Agent) String() string { if len(a.DiscoverProxies) == 0 { - return fmt.Sprintf("agent -> cluster %v, target %v", a.RemoteCluster, a.Addr.String()) + return fmt.Sprintf("agent(%v) -> %v:%v", a.getState(), a.RemoteCluster, a.Addr.String()) } - return fmt.Sprintf("agent -> cluster %v, target %v, discover %v", a.RemoteCluster, a.Addr.String(), Proxies(a.DiscoverProxies)) + return fmt.Sprintf("agent(%v) -> %v:%v, discover %v", a.getState(), a.RemoteCluster, a.Addr.String(), Proxies(a.DiscoverProxies)) } func (a *Agent) getLastStateChange() time.Time { @@ -182,6 +177,13 @@ func (a *Agent) getLastStateChange() time.Time { return a.stateChange } +func (a *Agent) setStateAndPrincipals(state string, principals []string) { + prev := a.state + a.Debugf("changing state %v -> %v", prev, state) + a.state = state + a.stateChange = a.Clock.Now().UTC() + a.principals = principals +} func (a *Agent) setState(state string) { a.Lock() defer a.Unlock() @@ -219,11 +221,18 @@ func (a *Agent) Wait() error { return nil } -func (a *Agent) connectedToRightProxy() bool { +func (a *Agent) connectedTo(proxy services.Server) bool { principals := a.getPrincipals() + proxyID := fmt.Sprintf("%v.%v", proxy.GetName(), a.RemoteCluster) + if _, ok := principals[proxyID]; ok { + return true + } + return false +} + +func (a *Agent) connectedToRightProxy() bool { for _, proxy := range a.DiscoverProxies { - proxyID := fmt.Sprintf("%v.%v", proxy.GetName(), a.RemoteCluster) - if _, ok := principals[proxyID]; ok { + if a.connectedTo(proxy) { return true } } @@ -295,7 +304,9 @@ func (a *Agent) connect() (conn *ssh.Client, err error) { a.setState(agentStateDiscovered) } else { a.Debugf("missed, connected to %v instead of %v", a.getPrincipalsList(), Proxies(a.DiscoverProxies)) - a.setState(agentStateMissed) + a.setStateAndPrincipals(agentStateDiscovering, nil) + conn.Close() + return nil, trace.ConnectionProblem(nil, "did not discover the right proxy") } } else { a.setState(agentStateConnected) @@ -439,6 +450,20 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { wg.Wait() } +func (a *Agent) run() { + // connect with exponential backoff until asked to stop + for { + conn, err := a.connect() + if err != nil { + a.Warningf("failed to create remote tunnel: %v", err) + } else { + a.Infof("connected to %s", conn.RemoteAddr()) + // start heartbeat even if error happened, it will reconnect + a.runHeartbeat(conn) + } + } +} + // runHeartbeat is a blocking function which runs in a loop sending heartbeats // to the given SSH connection.
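// The comment in run() above promises exponential backoff between connect
// attempts; a minimal sketch of that shape (reconnectWithBackoff is a
// hypothetical helper, not part of this patch):
func reconnectWithBackoff(connect func() error, maxDelay time.Duration) {
	delay := time.Second
	for {
		if err := connect(); err == nil {
			return
		}
		// sleep, then double the delay after every failure, capped at maxDelay
		time.Sleep(delay)
		if delay *= 2; delay > maxDelay {
			delay = maxDelay
		}
	}
}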
// @@ -448,7 +473,7 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { heartbeatLoop := func() error { if conn == nil { - return trace.Errorf("heartbeat cannot ping: need to reconnect") + return trace.ConnectionProblem(nil, "heartbeat cannot ping: need to reconnect") } a.Infof("connected to %s", conn.RemoteAddr()) defer conn.Close() @@ -470,7 +495,8 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { return nil // time to ping: case <-ticker.C: - _, err := hb.SendRequest("ping", false, nil) + bytes, _ := a.Clock.Now().UTC().MarshalText() + _, err := hb.SendRequest("ping", false, bytes) if err != nil { log.Error(err) return trace.Wrap(err) @@ -525,9 +551,9 @@ func (a *Agent) runHeartbeat(conn *ssh.Client) { // keep repeating to reconnect until we're asked to stop err := heartbeatLoop() if len(a.DiscoverProxies) != 0 { - a.setState(agentStateDiscovering) + a.setStateAndPrincipals(agentStateDiscovering, nil) } else { - a.setState(agentStateConnecting) + a.setStateAndPrincipals(agentStateConnecting, nil) } // when this happens, this is #1 issue we have right now with Teleport. So I'm making diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 289baf113a999..996cee5a3c8c8 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -13,9 +13,20 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" +) - log "github.com/sirupsen/logrus" +var ( + tunnelStats = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "tunnels", + Help: "Number of tunnels per state", + }, + []string{"cluster", "state"}, + ) ) // AgentPool manages the pool of outbound reverse tunnel agents. 
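// A minimal sketch of how a prometheus.GaugeVec like tunnelStats above is
// typically wired up. The registration below is an assumption for
// illustration; this hunk only declares the collector.
func init() {
	prometheus.MustRegister(tunnelStats)
}

// Example update, recording three connected agents for cluster "two":
// tunnelStats.WithLabelValues("two", agentStateConnected).Set(3)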
@@ -47,6 +58,9 @@ type AgentPoolConfig struct { Context context.Context // Cluster is a cluster name Cluster string + // Clock is a clock used to get time, if not set, + // system clock is used + Clock clockwork.Clock } // CheckAndSetDefaults checks and sets defaults @@ -66,6 +80,9 @@ func (cfg *AgentPoolConfig) CheckAndSetDefaults() error { if cfg.Context == nil { cfg.Context = context.TODO() } + if cfg.Clock == nil { + cfg.Clock = clockwork.NewRealClock() + } return nil } @@ -128,24 +145,50 @@ func (m *AgentPool) processDiscoveryRequests() { } } +func foundInOneOf(proxy services.Server, agents []*Agent) bool { + for _, agent := range agents { + if agent.connectedTo(proxy) { + return true + } + } + return false +} + func (m *AgentPool) tryDiscover(req discoveryRequest) { - m.Debugf("will need to discover: %v", Proxies(req.Proxies)) + proxies := Proxies(req.Proxies) m.Lock() defer m.Unlock() - proxies := Proxies(req.Proxies) + matchKey := req.key() - var foundAgent bool + + // if one of the proxies has been discovered or connected to, + // remove it from the discovery request + var filtered Proxies + agents := m.agents[matchKey] + for i := range proxies { + proxy := proxies[i] + if !foundInOneOf(proxy, agents) { + filtered = append(filtered, proxy) + } + } + m.Debugf("tryDiscover original(%v) -> filtered(%v)", proxies, filtered) + // nothing to do + if len(filtered) == 0 { + return + } // close agents that are discovering proxies that are somehow // different from discovery request + var foundAgent bool m.closeAgentsIf(&matchKey, func(agent *Agent) bool { if agent.getState() != agentStateDiscovering { return false } - if proxies.Equal(agent.DiscoverProxies) { + if filtered.Equal(agent.DiscoverProxies) { foundAgent = true - agent.Debugf("is already discovering %v, nothing to do", req) + agent.Debugf("agent is already discovering the same proxies as requested in %v", filtered) return false } + agent.Debugf("agent in state %v discovering %v is obsolete, going to close", agent.getState(), agent.DiscoverProxies) return true }) @@ -254,6 +297,25 @@ func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) er return nil } +// reportStats reports the number of agents in each state for every cluster +func (m *AgentPool) reportStats() { + for key, agents := range m.agents { + countPerState := make(map[string]int) + for _, a := range agents { + countPerState[a.getState()]++ + } + for state, count := range countPerState { + gauge, err := tunnelStats.GetMetricWithLabelValues(key.domainName, state) + if err != nil { + m.Warningf("%v", err) + continue + } + gauge.Set(float64(count)) + } + m.Debugf("STATS: %v -> %v", key.domainName, countPerState) + } +} + func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { m.Lock() defer m.Unlock() @@ -273,14 +335,7 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { return trace.Wrap(err) } } - // garbage collect agents that have connected, but to the wrong proxy - m.closeAgentsIf(nil, func(agent *Agent) bool { - if agent.getState() == agentStateMissed { - agent.Debugf("closing agent that could not discover clusters") - return true - } - return false - }) + m.reportStats() return nil } diff --git a/lib/reversetunnel/remotesite.go index c7dfc6435c06a..bfd89933738f2 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -149,7 +149,7 @@ func (s *remoteSite) registerHeartbeat(t time.Time) { s.connInfo.SetExpiry(s.clock.Now().Add(defaults.ReverseTunnelOfflineThreshold)) err :=
s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) if err != nil { - log.Warningf("failed to register heartbeat: %v", err) + s.Warningf("failed to register heartbeat: %v", err) } } @@ -169,10 +169,21 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected")) return } - s.Debugf("ping <- %v", conn.conn.RemoteAddr()) + var timeSent time.Time + var roundtrip time.Duration + if req.Payload != nil { + if err := timeSent.UnmarshalText(req.Payload); err == nil { + roundtrip = s.srv.Clock.Now().Sub(timeSent) + } + } + if roundtrip != 0 { + s.Debugf("ping <- %v rtt(%v)", conn.conn.RemoteAddr(), roundtrip) + } else { + s.Debugf("ping <- %v", conn.conn.RemoteAddr()) + } go s.registerHeartbeat(time.Now()) - case <-time.After(3 * defaults.ReverseTunnelAgentHeartbeatPeriod): - conn.markInvalid(trace.ConnectionProblem(nil, "agent missed 3 heartbeats")) + case <-time.After(defaults.ReverseTunnelOfflineThreshold): + conn.markInvalid(trace.ConnectionProblem(nil, "no heartbeats for %v", defaults.ReverseTunnelOfflineThreshold)) } } } @@ -222,7 +233,9 @@ func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) { } connected := make(map[string]bool) for _, conn := range conns { - connected[conn.GetProxyName()] = true + if s.isOnline(conn) { + connected[conn.GetProxyName()] = true + } } proxies, err := s.srv.AccessPoint.GetProxies() if err != nil { diff --git a/lib/utils/cli.go b/lib/utils/cli.go index 248eee05b14a6..38f4c6687fbd5 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -44,7 +44,7 @@ const ( // InitLogger configures the global logger for a given purpose / verbosity level func InitLogger(purpose LoggingPurpose, level log.Level) { log.StandardLogger().Hooks = make(log.LevelHooks) - formatter := &trace.TextFormatter{DisableTimestamp: true} + formatter := &trace.TextFormatter{DisableTimestamp: false} log.SetFormatter(formatter) log.SetLevel(level) From eb4cfa12d9814e5dc1090d640501b7f872ac733d Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Oct 2017 23:07:39 -0700 Subject: [PATCH 07/24] refactoring complete --- lib/defaults/defaults.go | 3 + lib/reversetunnel/agent.go | 241 +++++++++++++++----------------- lib/reversetunnel/agentpool.go | 1 + lib/state/cachingaccesspoint.go | 6 +- 4 files changed, 116 insertions(+), 135 deletions(-) diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 189632fba4c04..252fde3a9b2ba 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -173,6 +173,9 @@ var ( // TerminalSizeRefreshPeriod is how frequently clients who share sessions sync up // their terminal sizes TerminalSizeRefreshPeriod = 2 * time.Second + + // NewtworkBackoffDuration is a standard backoff on network requests + NetworkBackoffDuration = time.Second * 10 ) // Default connection limits, they can be applied separately on any of the Teleport diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 8d24a51837694..0745b76499025 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -154,13 +154,6 @@ func NewAgent(cfg AgentConfig) (*Agent, error) { }, }) a.hostKeyCallback = a.checkHostSignature - - if len(a.DiscoverProxies) != 0 { - a.setState(agentStateDiscovering) - } else { - a.setState(agentStateConnecting) - } - return a, nil } @@ -199,21 +192,15 @@ func (a *Agent) getState() string { return a.state } -// Close signals to close all connections +// Close signals to close all connections and operations func (a 
*Agent) Close() error { a.cancel() return nil } -// Start starts agent that attempts to connect to remote server part -func (a *Agent) Start() error { - conn, err := a.connect() - if err != nil { - a.Warningf("Failed to create remote tunnel: %v", err) - } - // start heartbeat even if error happend, it will reconnect - go a.runHeartbeat(conn) - return err +// Start starts agent that attempts to connect to remote server +func (a *Agent) Start() { + go a.run() } // Wait waits until all outstanding operations are completed @@ -299,18 +286,6 @@ func (a *Agent) connect() (conn *ssh.Client, err error) { Timeout: defaults.DefaultDialTimeout, }) if conn != nil { - if len(a.DiscoverProxies) != 0 { - if a.connectedToRightProxy() { - a.setState(agentStateDiscovered) - } else { - a.Debugf("missed, connected to %v instead of %v", a.getPrincipalsList(), Proxies(a.DiscoverProxies)) - a.setStateAndPrincipals(agentStateDiscovering, nil) - conn.Close() - return nil, trace.ConnectionProblem(nil, "did not discover the right proxy") - } - } else { - a.setState(agentStateConnected) - } break } } @@ -450,127 +425,133 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) { wg.Wait() } +// run is the main agent loop, constantly tries to re-establish +// the connection until stopped func (a *Agent) run() { - // connect with exponential backoff untill asked to stop + backoff := time.NewTicker(defaults.NetworkBackoffDuration) + defer backoff.Stop() + firstAttempt := true for { - conn, err := a.connect() - if err != nil { - a.Warningf("failed to create remote tunnel: %v", err) + if len(a.DiscoverProxies) != 0 { + a.setStateAndPrincipals(agentStateDiscovering, nil) } else { - a.Infof("connected to %s", conn.RemoteAddr()) - // start heartbeat even if error happend, it will reconnect - a.runHeartbeat(conn) + a.setStateAndPrincipals(agentStateConnecting, nil) } - } -} -// runHeartbeat is a blocking function which runs in a loop sending heartbeats -// to the given SSH connection. 
-// -func (a *Agent) runHeartbeat(conn *ssh.Client) { - ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) - defer ticker.Stop() - - heartbeatLoop := func() error { - if conn == nil { - return trace.ConnectionProblem(nil, "heartbeat cannot ping: need to reconnect") - } - a.Infof("connected to %s", conn.RemoteAddr()) - defer conn.Close() - hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) - if err != nil { - return trace.Wrap(err) - } - newAccesspointC := conn.HandleChannelOpen(chanAccessPoint) - newTransportC := conn.HandleChannelOpen(chanTransport) - newDiscoveryC := conn.HandleChannelOpen(chanDiscovery) - - // send first ping right away, then start a ping timer: - hb.SendRequest("ping", false, nil) - - for { + // ignore timer and context on the first attempt + if !firstAttempt { select { - // need to exit: + // abort if asked to stop: case <-a.ctx.Done(): - return nil - // time to ping: - case <-ticker.C: - bytes, _ := a.Clock.Now().UTC().MarshalText() - _, err := hb.SendRequest("ping", false, bytes) - if err != nil { - log.Error(err) - return trace.Wrap(err) - } - a.Debugf("ping -> %v", conn.RemoteAddr()) - // ssh channel closed: - case req := <-reqC: - if req == nil { - return trace.ConnectionProblem(nil, "heartbeat: connection closed") - } - // new access point request: - case nch := <-newAccesspointC: - if nch == nil { - continue - } - a.Debugf("access point request: %v", nch.ChannelType()) - ch, req, err := nch.Accept() - if err != nil { - a.Warningf("failed to accept request: %v", err) - continue - } - go a.proxyAccessPoint(ch, req) - // new transport request: - case nch := <-newTransportC: - if nch == nil { - continue - } - a.Debugf("transport request: %v", nch.ChannelType()) - ch, req, err := nch.Accept() - if err != nil { - a.Warningf("failed to accept request: %v", err) - continue - } - go a.proxyTransport(ch, req) - // new discovery request - case nch := <-newDiscoveryC: - if nch == nil { - continue - } - a.Debugf("discovery request: %v", nch.ChannelType()) - ch, req, err := nch.Accept() - if err != nil { - a.Warningf("failed to accept request: %v", err) + a.Debugf("agent has closed, exiting") + return + // wait backoff on network retries + case <-backoff.C: + } + } + + conn, err := a.connect() + firstAttempt = false + if err != nil || conn == nil { + a.Warningf("failed to create remote tunnel: %v, conn: %v", err, conn) + continue + } else { + a.Infof("connected to %s", conn.RemoteAddr()) + if len(a.DiscoverProxies) != 0 { + if !a.connectedToRightProxy() { + a.Debugf("missed, connected to %v instead of %v", a.getPrincipalsList(), Proxies(a.DiscoverProxies)) + conn.Close() continue } - go a.handleDiscovery(ch, req) + a.setState(agentStateDiscovered) + } else { + a.setState(agentStateConnected) + } + // start heartbeat even if error happend, it will reconnect + // when this happens, this is #1 issue we have right now with Teleport. So we are making + // it EASY to see in the logs. 
This condition should never be permanent (repeates + // every XX seconds) + if err := a.processRequests(conn); err != nil { + log.Warn(err) } } } +} - // run heartbeat loop, and when it fails (probably means that a tunnel got disconnected) - // keep repeating to reconnect until we're asked to stop - err := heartbeatLoop() - if len(a.DiscoverProxies) != 0 { - a.setStateAndPrincipals(agentStateDiscovering, nil) - } else { - a.setStateAndPrincipals(agentStateConnecting, nil) - } +// processRequests is a blocking function which runs in a loop sending heartbeats +// to the given SSH connection and processes inbound requests from the +// remote proxy +func (a *Agent) processRequests(conn *ssh.Client) error { + defer conn.Close() + ticker := time.NewTicker(defaults.ReverseTunnelAgentHeartbeatPeriod) + defer ticker.Stop() - // when this happens, this is #1 issue we have right now with Teleport. So I'm making - // it EASY to see in the logs. This condition should never be permanent (like repeates - // every XX seconds) + hb, reqC, err := conn.OpenChannel(chanHeartbeat, nil) if err != nil { - a.Warn(err) + return trace.Wrap(err) } + newAccesspointC := conn.HandleChannelOpen(chanAccessPoint) + newTransportC := conn.HandleChannelOpen(chanTransport) + newDiscoveryC := conn.HandleChannelOpen(chanDiscovery) - if err != nil || conn == nil { + // send first ping right away, then start a ping timer: + hb.SendRequest("ping", false, nil) + + for { select { - // abort if asked to stop: + // need to exit: case <-a.ctx.Done(): - return - // reconnect + return trace.ConnectionProblem(nil, "heartbeat: agent is stopped") + // time to ping: case <-ticker.C: - a.Start() + bytes, _ := a.Clock.Now().UTC().MarshalText() + _, err := hb.SendRequest("ping", false, bytes) + if err != nil { + log.Error(err) + return trace.Wrap(err) + } + a.Debugf("ping -> %v", conn.RemoteAddr()) + // ssh channel closed: + case req := <-reqC: + if req == nil { + return trace.ConnectionProblem(nil, "heartbeat: connection closed") + } + // new access point request: + case nch := <-newAccesspointC: + if nch == nil { + continue + } + a.Debugf("access point request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.Warningf("failed to accept request: %v", err) + continue + } + go a.proxyAccessPoint(ch, req) + // new transport request: + case nch := <-newTransportC: + if nch == nil { + continue + } + a.Debugf("transport request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.Warningf("failed to accept request: %v", err) + continue + } + go a.proxyTransport(ch, req) + // new discovery request + case nch := <-newDiscoveryC: + if nch == nil { + continue + } + a.Debugf("discovery request: %v", nch.ChannelType()) + ch, req, err := nch.Accept() + if err != nil { + a.Warningf("failed to accept request: %v", err) + continue + } + go a.handleDiscovery(ch, req) } } } diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 996cee5a3c8c8..71acea105da0d 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -335,6 +335,7 @@ func (m *AgentPool) syncAgents(tunnels []services.ReverseTunnel) error { return trace.Wrap(err) } } + m.reportStats() return nil } diff --git a/lib/state/cachingaccesspoint.go b/lib/state/cachingaccesspoint.go index dcd587564809c..3f3f96f9aba3d 100644 --- a/lib/state/cachingaccesspoint.go +++ b/lib/state/cachingaccesspoint.go @@ -48,10 +48,6 @@ func init() { prometheus.MustRegister(accessPointRequests) } -const ( - backoffDuration 
= time.Second * 10 -) - // CachingAuthClient implements auth.AccessPoint interface and remembers // the previously returned upstream value for each API call. // @@ -507,7 +503,7 @@ func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnecti // time is recorded. Future calls to f will be ingored until sufficient time passes // since th last error func (cs *CachingAuthClient) try(f func() error) error { - tooSoon := cs.lastErrorTime.Add(backoffDuration).After(time.Now()) + tooSoon := cs.lastErrorTime.Add(defaults.NetworkBackoffDuration).After(time.Now()) if tooSoon { cs.Warnf("backoff: using cached value due to recent errors") return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors") From 8839b855395450111b3ccd2e01eadbc662fba5a5 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Mon, 9 Oct 2017 12:15:56 -0700 Subject: [PATCH 08/24] update trace --- lib/utils/cli.go | 2 +- vendor/github.com/gravitational/trace/log.go | 48 +++++++++++++++----- 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/lib/utils/cli.go b/lib/utils/cli.go index 38f4c6687fbd5..248eee05b14a6 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -44,7 +44,7 @@ const ( // InitLogger configures the global logger for a given purpose / verbosity level func InitLogger(purpose LoggingPurpose, level log.Level) { log.StandardLogger().Hooks = make(log.LevelHooks) - formatter := &trace.TextFormatter{DisableTimestamp: false} + formatter := &trace.TextFormatter{DisableTimestamp: true} log.SetFormatter(formatter) log.SetLevel(level) diff --git a/vendor/github.com/gravitational/trace/log.go b/vendor/github.com/gravitational/trace/log.go index f42edcc2e7f9b..d0ad4555b72df 100644 --- a/vendor/github.com/gravitational/trace/log.go +++ b/vendor/github.com/gravitational/trace/log.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "regexp" + rundebug "runtime/debug" "sort" "strings" "time" @@ -42,6 +43,10 @@ const ( Component = "trace.component" // ComponentFields is a fields compoonent ComponentFields = "trace.fields" + // DefaultComponentPadding is a default char padding for component + DefaultComponentPadding = 11 + // DefaultLevelPadding is a default char padding for component + DefaultLevelPadding = 4 ) // TextFormatter is logrus-compatible formatter and adds @@ -50,13 +55,17 @@ type TextFormatter struct { // DisableTimestamp disables timestamp output (useful when outputting to // systemd logs) DisableTimestamp bool + // ComponentPadding is a padding to pick when displaying + // and formatting component field, default is set to 11 + ComponentPadding int } // Format implements logrus.Formatter interface and adds file and line func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { defer func() { if r := recover(); r != nil { - err = BadParameter("recovered from panic: %v", r) + data = append([]byte("panic in log formatter\n"), rundebug.Stack()...) 
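+			// return the stack dump as the formatted output (rather than an error)
+			// so a formatter panic still produces a visible log line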
+ return } }() var file string @@ -73,18 +82,32 @@ func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { } // level - w.writeField(strings.ToUpper(padMax(e.Level.String(), 4))) + w.writeField(strings.ToUpper(padMax(e.Level.String(), DefaultLevelPadding))) - // component if present, highly visible - component, ok := e.Data[Component] - if ok { - if w.Len() > 0 { - w.WriteByte(' ') - } - w.WriteByte('[') - w.WriteString(strings.ToUpper(padMax(fmt.Sprintf("%v", component), 11))) - w.WriteByte(']') + // component, always output + componentI, ok := e.Data[Component] + if !ok { + componentI = "" + } + component, ok := componentI.(string) + if !ok { + component = fmt.Sprintf("%v", componentI) + } + padding := DefaultComponentPadding + if tf.ComponentPadding != 0 { + padding = tf.ComponentPadding + } + if w.Len() > 0 { + w.WriteByte(' ') + } + if component != "" { + component = fmt.Sprintf("[%v]", component) + } + component = strings.ToUpper(padMax(component, padding)) + if component[len(component)-1] != ' ' { + component = component[:len(component)-1] + "]" } + w.WriteString(component) // message if e.Message != "" { @@ -102,7 +125,8 @@ func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { w.writeMap(e.Data) } w.WriteByte('\n') - return w.Bytes(), nil + data = w.Bytes() + return } // JSONFormatter implements logrus.Formatter interface and adds file and line From a55116dd008df8da93b461ce71cc153f2afd8c5d Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Mon, 9 Oct 2017 15:56:18 -0700 Subject: [PATCH 09/24] fixes before revendoring --- lib/reversetunnel/agent.go | 4 ++-- lib/reversetunnel/remotesite.go | 2 +- lib/reversetunnel/srv.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 0745b76499025..8b29bbd1a1e49 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -149,8 +149,8 @@ func NewAgent(cfg AgentConfig) (*Agent, error) { } a.Entry = log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelAgent, - trace.ComponentFields: map[string]interface{}{ - "remote": cfg.Addr.String(), + trace.ComponentFields: log.Fields{ + "target": cfg.Addr.String(), }, }) a.hostKeyCallback = a.checkHostSignature diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index bfd89933738f2..5c1aff0d61184 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -177,7 +177,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch } } if roundtrip != 0 { - s.Debugf("ping <- %v rtt(%v)", conn.conn.RemoteAddr(), roundtrip) + s.WithFields(log.Fields{"rtt": roundtrip}).Debugf("ping <- %v", conn.conn.RemoteAddr()) } else { s.Debugf("ping <- %v", conn.conn.RemoteAddr()) } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index d0d63e9195576..0236a81209c33 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -680,7 +680,7 @@ func newRemoteSite(srv *server, domainName string) (*remoteSite, error) { connInfo: connInfo, Entry: log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelServer, - trace.ComponentFields: map[string]interface{}{ + trace.ComponentFields: log.Fields{ "cluster": domainName, }, }), From f12024031a4c14f6cda52a441b901100e7a0cc61 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Mon, 9 Oct 2017 18:58:24 -0700 Subject: [PATCH 10/24] more work on logging and stats --- Gopkg.lock | 6 +- Gopkg.toml | 2 +- 
lib/defaults/defaults.go | 3 + lib/reversetunnel/agentpool.go | 13 ++- lib/reversetunnel/remotesite.go | 2 +- vendor/github.com/gravitational/trace/log.go | 36 +++--- .../gravitational/trace/trace_test.go | 108 +++++++++++++++++- 7 files changed, 141 insertions(+), 29 deletions(-) diff --git a/Gopkg.lock b/Gopkg.lock index 65f927d2fc78c..edf3215d656ca 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -150,8 +150,8 @@ [[projects]] name = "github.com/gravitational/trace" packages = ["."] - revision = "640274f75816ce3975e730e844290c14936f6288" - version = "1.0.0" + revision = "6b5a7fe88920c524ac56d22ed6f371ca25ec7e13" + version = "1.1.1" [[projects]] branch = "master" @@ -390,6 +390,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "206dacde9d25202422550789814d38dc5a453f29beaeb6e00d059a8f4101515a" + inputs-digest = "d1eb2a47c4fafb650c2c5ffef90bf7949a795641da2bf5ac41cf84eb3b74c806" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 8a19836c52867..d358d99198777 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -126,4 +126,4 @@ [[constraint]] name = "github.com/gravitational/trace" - version = "1.0.0" + version = "1.1.1" diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go index 252fde3a9b2ba..fe6d2d40d28dd 100644 --- a/lib/defaults/defaults.go +++ b/lib/defaults/defaults.go @@ -176,6 +176,9 @@ var ( // NewtworkBackoffDuration is a standard backoff on network requests NetworkBackoffDuration = time.Second * 10 + + // ReportingPeriod is a period for reports in logs + ReportingPeriod = 5 * time.Minute ) // Default connection limits, they can be applied separately on any of the Teleport diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 71acea105da0d..887519cd6b619 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -40,6 +40,8 @@ type AgentPool struct { ctx context.Context cancel context.CancelFunc discoveryC chan *discoveryRequest + // lastReport is the last time the agent has reported the stats + lastReport time.Time } // AgentPoolConfig holds configuration parameters for the agent pool @@ -101,7 +103,7 @@ func NewAgentPool(cfg AgentPoolConfig) (*AgentPool, error) { } pool.Entry = log.WithFields(log.Fields{ trace.Component: teleport.ComponentReverseTunnelAgent, - trace.ComponentFields: map[string]interface{}{ + trace.ComponentFields: log.Fields{ "cluster": cfg.Cluster, }, }) @@ -299,6 +301,11 @@ func (m *AgentPool) addAgent(key agentKey, discoverProxies []services.Server) er // reportStats submits report about agents state once in a while func (m *AgentPool) reportStats() { + var logReport bool + if m.cfg.Clock.Now().Sub(m.lastReport) > defaults.ReportingPeriod { + m.lastReport = m.cfg.Clock.Now() + logReport = true + } for key, agents := range m.agents { countPerState := make(map[string]int) for _, a := range agents { @@ -312,7 +319,9 @@ func (m *AgentPool) reportStats() { } gauge.Set(float64(count)) } - m.Debugf("STATS: %v -> %v", key.domainName, countPerState) + if logReport { + m.WithFields(log.Fields{"target": key.domainName, "stats": countPerState}).Infof("outbound tunnel stats") + } } } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 5c1aff0d61184..357433e6fa7b6 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -177,7 +177,7 @@ func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-ch } } if roundtrip != 0 { - s.WithFields(log.Fields{"rtt": roundtrip}).Debugf("ping <- 
%v", conn.conn.RemoteAddr()) + s.WithFields(log.Fields{"latency": roundtrip}).Debugf("ping <- %v", conn.conn.RemoteAddr()) } else { s.Debugf("ping <- %v", conn.conn.RemoteAddr()) } diff --git a/vendor/github.com/gravitational/trace/log.go b/vendor/github.com/gravitational/trace/log.go index d0ad4555b72df..df89bac03163d 100644 --- a/vendor/github.com/gravitational/trace/log.go +++ b/vendor/github.com/gravitational/trace/log.go @@ -23,6 +23,7 @@ import ( "regexp" rundebug "runtime/debug" "sort" + "strconv" "strings" "time" @@ -41,11 +42,11 @@ const ( // Component is a field that represents component - e.g. service or // function Component = "trace.component" - // ComponentFields is a fields compoonent + // ComponentFields is a fields component ComponentFields = "trace.fields" - // DefaultComponentPadding is a default char padding for component + // DefaultComponentPadding is a default padding for component field DefaultComponentPadding = 11 - // DefaultLevelPadding is a default char padding for component + // DefaultLevelPadding is a default padding for level field DefaultLevelPadding = 4 ) @@ -56,7 +57,7 @@ type TextFormatter struct { // systemd logs) DisableTimestamp bool // ComponentPadding is a padding to pick when displaying - // and formatting component field, default is set to 11 + // and formatting component field, defaults to DefaultComponentPadding ComponentPadding int } @@ -74,7 +75,7 @@ func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { file = t.Loc() } - w := &writer{bytes.Buffer{}} + w := &writer{} // time if !tf.DisableTimestamp { @@ -114,16 +115,16 @@ func (tf *TextFormatter) Format(e *log.Entry) (data []byte, err error) { w.writeField(e.Message) } - // file, if present - if file != "" { - w.writeField(file) - } - // rest of the fields if len(e.Data) > 0 { - w.WriteByte(' ') w.writeMap(e.Data) } + + // file, if present, always last + if file != "" { + w.writeField(file) + } + w.WriteByte('\n') data = w.Bytes() return @@ -143,6 +144,7 @@ func (j *JSONFormatter) Format(e *log.Entry) ([]byte, error) { FileField: t.Loc(), FunctionField: t.FuncName(), }) + new.Time = e.Time new.Level = e.Level new.Message = e.Message e = new @@ -211,16 +213,8 @@ func (w *writer) writeMap(m map[string]interface{}) { continue } switch val := m[key].(type) { - case map[string]interface{}: - w.WriteString(key) - w.WriteString(":{") - w.writeMap(val) - w.WriteString(" }") case log.Fields: - w.WriteString(key) - w.WriteString(":{") w.writeMap(val) - w.WriteString(" }") default: w.writeKeyValue(key, val) } @@ -228,8 +222,8 @@ func (w *writer) writeMap(m map[string]interface{}) { } func needsQuoting(text string) bool { - for _, ch := range text { - if ch < 32 { + for _, r := range text { + if !strconv.IsPrint(r) { return true } } diff --git a/vendor/github.com/gravitational/trace/trace_test.go b/vendor/github.com/gravitational/trace/trace_test.go index cfb430db26bf3..c1b360897286c 100644 --- a/vendor/github.com/gravitational/trace/trace_test.go +++ b/vendor/github.com/gravitational/trace/trace_test.go @@ -105,7 +105,6 @@ func (s *TraceSuite) TestWrapStdlibErrors(c *C) { } func (s *TraceSuite) TestLogFormatter(c *C) { - for _, f := range []log.Formatter{&TextFormatter{}, &JSONFormatter{}} { log.SetFormatter(f) @@ -123,6 +122,113 @@ func (s *TraceSuite) TestLogFormatter(c *C) { } } +type panicker string + +func (p panicker) String() string { + panic(p) +} + +func (s *TraceSuite) TestTextFormatter(c *C) { + padding := 6 + f := &TextFormatter{ + DisableTimestamp: true, + 
ComponentPadding: padding, + } + log.SetFormatter(f) + + type testCase struct { + log func() + match string + comment string + } + + testCases := []testCase{ + { + comment: "padding fits in", + log: func() { + log.WithFields(log.Fields{ + Component: "test", + }).Infof("hello") + }, + match: `^INFO \[TEST\] hello.*`, + }, + { + comment: "padding overflow", + log: func() { + log.WithFields(log.Fields{ + Component: "longline", + }).Infof("hello") + }, + match: `^INFO \[LONG\] hello.*`, + }, + { + comment: "padded with extra spaces", + log: func() { + log.WithFields(log.Fields{ + Component: "abc", + }).Infof("hello") + }, + match: `^INFO \[ABC\] hello.*`, + }, + { + comment: "missing component will be padded", + log: func() { + log.Infof("hello") + }, + match: `^INFO hello.*`, + }, + { + comment: "panic in component is handled", + log: func() { + log.WithFields(log.Fields{ + Component: panicker("panic"), + }).Infof("hello") + }, + match: `.*panic.*`, + }, + { + comment: "nested fields are reflected", + log: func() { + log.WithFields(log.Fields{ + ComponentFields: log.Fields{"key": "value"}, + }).Infof("hello") + }, + match: `.*key:value.*`, + }, + { + comment: "fields are reflected", + log: func() { + log.WithFields(log.Fields{ + "a": "b", + }).Infof("hello") + }, + match: `.*a:b.*`, + }, + { + comment: "non control characters are quoted", + log: func() { + log.Infof("\n") + }, + match: `.*"\\n".*`, + }, + { + comment: "printable strings are not quoted", + log: func() { + log.Infof("printable string") + }, + match: `.*[^"]printable string[^"].*`, + }, + } + + for i, tc := range testCases { + comment := Commentf("test case %v %v, expected match: %v", i+1, tc.comment, tc.match) + buf := &bytes.Buffer{} + log.SetOutput(buf) + tc.log() + c.Assert(line(buf.String()), Matches, tc.match, comment) + } +} + func (s *TraceSuite) TestGenericErrors(c *C) { testCases := []struct { Err error From aa62a1d627da8ade7ff2c3ceec78279b44f9fd35 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Mon, 9 Oct 2017 19:59:14 -0700 Subject: [PATCH 11/24] document the discovery algo --- lib/reversetunnel/agentpool.go | 16 +++++++ lib/reversetunnel/doc.go | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 lib/reversetunnel/doc.go diff --git a/lib/reversetunnel/agentpool.go b/lib/reversetunnel/agentpool.go index 887519cd6b619..59d897f551922 100644 --- a/lib/reversetunnel/agentpool.go +++ b/lib/reversetunnel/agentpool.go @@ -1,3 +1,19 @@ +/* +Copyright 2015 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package reversetunnel import ( diff --git a/lib/reversetunnel/doc.go b/lib/reversetunnel/doc.go new file mode 100644 index 0000000000000..efc73ab3b0dde --- /dev/null +++ b/lib/reversetunnel/doc.go @@ -0,0 +1,85 @@ +/* +Copyright 2015 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +/* package reversetunnel provides tools for accessing remote clusters + via reverse tunnels and directly + + Proxy server Proxy agent + Reverse tunnel + +----------+ +---------+ + | <----------------------+ | + | | | | ++-----+----------+ +---------+-----+ +| | | | +| | | | ++----------------+ +---------------+ + Proxy Cluster "A" Proxy Cluster "B" + + +Reverse tunnel is established from the cluster "B" Proxy +to the cluster "A" proxy, and clients of cluster "A" +can access servers of cluster "B" via reverse tunnel, +even if the cluster "B" is behind the firewall. + +Multiple Proxies Design + +With multiple proxies behind the load balancer, +proxy agents will eventually discover and establish connections to all +proxies in a cluster. + +* Initially Proxy Agent connects to the Proxy 1. +* Proxy 1 starts sending information about all the other proxies +in the cluster and whether they proxies have received the connection +from the agent or not. + + ++----------+ +| <--------+ +| | | ++----------+ | +-----------+ +----------+ + Proxy 1 +-------------------------------+ | + | | | | + +-----------+ +----------+ + Load Balancer Proxy Agent ++----------+ +| | +| | ++----------+ + Proxy 2 + +* Agent will use this information to establish new connections +and check if it connected and "discovered" all the proxies. +* Assuming that load balancer uses fair load balancing algorithm, +agent will eventually discover all the proxies and connect back to them all + ++----------+ +| <--------+ +| | | ++----------+ | +-----------+ +----------+ + Proxy 1 +-------------------------------+ | + | | | | | + | +-----------+ +----------+ + | Load Balancer Proxy Agent ++----------+ | +| <--------+ +| | ++----------+ + Proxy 2 + + + +*/ +package reversetunnel From e82ac5601a1bbff0e63977d6dd3973193615b310 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Tue, 10 Oct 2017 09:26:33 -0700 Subject: [PATCH 12/24] tweak the docs --- lib/reversetunnel/doc.go | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/lib/reversetunnel/doc.go b/lib/reversetunnel/doc.go index efc73ab3b0dde..bd45b38d16cc4 100644 --- a/lib/reversetunnel/doc.go +++ b/lib/reversetunnel/doc.go @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -/* package reversetunnel provides tools for accessing remote clusters - via reverse tunnels and directly +/* package reversetunnel provides interfaces for accessing remote clusters + via reverse tunnels and directly. + +Reverse Tunnels Proxy server Proxy agent Reverse tunnel @@ -29,21 +31,21 @@ limitations under the License. Proxy Cluster "A" Proxy Cluster "B" -Reverse tunnel is established from the cluster "B" Proxy -to the cluster "A" proxy, and clients of cluster "A" -can access servers of cluster "B" via reverse tunnel, +Reverse tunnel is established from a cluster "B" Proxy +to the a cluster "A" proxy, and clients of the cluster "A" +can access servers of the cluster "B" via reverse tunnel connection, even if the cluster "B" is behind the firewall. 
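(Aside: a rough, hypothetical sketch of the proxy-discovery bookkeeping described in this doc. The real logic is tryDiscover and foundInOneOf in agentpool.go earlier in this series; the names below are invented for illustration.)

	package main

	import "fmt"

	// filterUndiscovered keeps only the proxies from a discovery request
	// that no existing agent is connected to yet.
	func filterUndiscovered(requested []string, connected map[string]bool) []string {
		var remaining []string
		for _, proxy := range requested {
			if !connected[proxy] {
				remaining = append(remaining, proxy)
			}
		}
		return remaining
	}

	func main() {
		requested := []string{"proxy-1", "proxy-2"}
		connected := map[string]bool{"proxy-1": true}
		// prints [proxy-2]: only the undiscovered proxy needs a new agent
		fmt.Println(filterUndiscovered(requested, connected))
	}
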
-Multiple Proxies Design +Multiple Proxies and Revese Tunnels With multiple proxies behind the load balancer, proxy agents will eventually discover and establish connections to all -proxies in a cluster. +proxies in cluster. -* Initially Proxy Agent connects to the Proxy 1. -* Proxy 1 starts sending information about all the other proxies -in the cluster and whether they proxies have received the connection -from the agent or not. +* Initially Proxy Agent connects to Proxy 1. +* Proxy 1 starts sending information about all available proxies +that have not received connection from the Proxy Agent yet. This +process is called "sending discovery request". +----------+ @@ -60,10 +62,11 @@ from the agent or not. +----------+ Proxy 2 -* Agent will use this information to establish new connections -and check if it connected and "discovered" all the proxies. +* Agent will use the discovery request to establish new connections +and check if it has connected and "discovered" all the proxies specified + in the discovery request. * Assuming that load balancer uses fair load balancing algorithm, -agent will eventually discover all the proxies and connect back to them all +agent will eventually discover and connect back to all the proxies. +----------+ | <--------+ From 35e380ac9c703473e7c61c4700c885a39aa47340 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Wed, 11 Oct 2017 15:31:19 -0700 Subject: [PATCH 13/24] add TCP load balancer to use in tests --- lib/utils/addr.go | 5 + lib/utils/loadbalancer.go | 216 +++++++++++++++++++++++++++++++++ lib/utils/loadbalancer_test.go | 211 ++++++++++++++++++++++++++++++++ 3 files changed, 432 insertions(+) create mode 100644 lib/utils/loadbalancer.go create mode 100644 lib/utils/loadbalancer_test.go diff --git a/lib/utils/addr.go b/lib/utils/addr.go index 89e328fb895db..4842691b01ad5 100644 --- a/lib/utils/addr.go +++ b/lib/utils/addr.go @@ -38,6 +38,11 @@ type NetAddr struct { Path string `json:"path,omitempty"` } +// Equals returns true if address is equal to other +func (a *NetAddr) Equals(other NetAddr) bool { + return a.Addr == other.Addr && a.AddrNetwork == other.AddrNetwork && a.Path == other.Path +} + // IsLocal returns true if this is a local address func (a *NetAddr) IsLocal() bool { host, _, err := net.SplitHostPort(a.Addr) diff --git a/lib/utils/loadbalancer.go b/lib/utils/loadbalancer.go new file mode 100644 index 0000000000000..2f71833d6b0e5 --- /dev/null +++ b/lib/utils/loadbalancer.go @@ -0,0 +1,216 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package utils
+
+import (
+	"context"
+	"io"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/gravitational/trace"
+	log "github.com/sirupsen/logrus"
+)
+
+// NewLoadBalancer returns a new load balancer listening on the frontend
+// and redirecting requests to backends using a round robin algorithm
+func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends []NetAddr) (*LoadBalancer, error) {
+	if ctx == nil {
+		return nil, trace.BadParameter("missing parameter context")
+	}
+	return &LoadBalancer{
+		frontend:     frontend,
+		ctx:          ctx,
+		backends:     backends,
+		currentIndex: -1,
+		Entry: log.WithFields(log.Fields{
+			trace.Component: "loadbalancer",
+			trace.ComponentFields: log.Fields{
+				"listen": frontend.String(),
+			},
+		}),
+	}, nil
+}
+
+// LoadBalancer implements a naive round robin TCP load
+// balancer used in tests.
+type LoadBalancer struct {
+	sync.RWMutex
+	*log.Entry
+	frontend       NetAddr
+	backends       []NetAddr
+	ctx            context.Context
+	currentIndex   int
+	listener       net.Listener
+	listenerClosed bool
+}
+
+// AddBackend adds backend
+func (l *LoadBalancer) AddBackend(b NetAddr) {
+	l.Lock()
+	defer l.Unlock()
+	l.backends = append(l.backends, b)
+	l.Debugf("backends %v", l.backends)
+}
+
+// RemoveBackend removes backend
+func (l *LoadBalancer) RemoveBackend(b NetAddr) {
+	l.Lock()
+	defer l.Unlock()
+	l.currentIndex = -1
+	for i := range l.backends {
+		if l.backends[i].Equals(b) {
+			l.backends = append(l.backends[:i], l.backends[i+1:]...)
+			return
+		}
+	}
+}
+
+func (l *LoadBalancer) nextBackend() (*NetAddr, error) {
+	l.Lock()
+	defer l.Unlock()
+	if len(l.backends) == 0 {
+		return nil, trace.ConnectionProblem(nil, "no backends")
+	}
+	l.currentIndex = ((l.currentIndex + 1) % len(l.backends))
+	return &l.backends[l.currentIndex], nil
+}
+
+func (l *LoadBalancer) closeListener() {
+	l.Lock()
+	defer l.Unlock()
+	if l.listener == nil {
+		return
+	}
+	if l.listenerClosed {
+		return
+	}
+	l.listenerClosed = true
+	l.listener.Close()
+}
+
+func (l *LoadBalancer) isClosed() bool {
+	l.RLock()
+	defer l.RUnlock()
+	return l.listenerClosed
+}
+
+func (l *LoadBalancer) Close() error {
+	l.closeListener()
+	return nil
+}
+
+// ListenAndServe starts listening socket and serves connections on it
+func (l *LoadBalancer) ListenAndServe() error {
+	if err := l.Listen(); err != nil {
+		return trace.Wrap(err)
+	}
+	return l.Serve()
+}
+
+// Listen creates a listener on the frontend addr
+func (l *LoadBalancer) Listen() error {
+	var err error
+	l.listener, err = net.Listen(l.frontend.AddrNetwork, l.frontend.Addr)
+	if err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	l.Debugf("created listening socket")
+	return nil
+}
+
+// Serve starts accepting connections
+func (l *LoadBalancer) Serve() error {
+	backoffTimer := time.NewTicker(5 * time.Second)
+	defer backoffTimer.Stop()
+	for {
+		conn, err := l.listener.Accept()
+		if err != nil {
+			if l.isClosed() {
+				return trace.ConnectionProblem(nil, "listener is closed")
+			}
+			select {
+			case <-backoffTimer.C:
+				l.Debugf("backoff on network error")
+			case <-l.ctx.Done():
+				return trace.ConnectionProblem(nil, "context is closing")
+			}
+			// retry the accept instead of forwarding a nil connection
+			continue
+		}
+		go l.forwardConnection(conn)
+	}
+}
+
+func (l *LoadBalancer) forwardConnection(conn net.Conn) {
+	err := l.forward(conn)
+	if err != nil {
+		l.Warningf("failed to forward connection: %v", err)
+	}
+}
+
+func (l *LoadBalancer) forward(conn net.Conn) error {
+	defer conn.Close()
+
+	backend, err := l.nextBackend()
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	backendConn, err := net.Dial(backend.AddrNetwork, backend.Addr)
+ if err != nil { + return trace.ConvertSystemError(err) + } + defer backendConn.Close() + + logger := l.WithFields(log.Fields{ + "source": conn.RemoteAddr(), + "dest": backendConn.RemoteAddr(), + }) + logger.Debugf("forward") + + messagesC := make(chan error, 2) + + go func() { + defer conn.Close() + defer backendConn.Close() + _, err := io.Copy(conn, backendConn) + messagesC <- err + }() + + go func() { + defer conn.Close() + defer backendConn.Close() + _, err := io.Copy(backendConn, conn) + messagesC <- err + }() + + var lastErr error + for i := 0; i < 2; i++ { + select { + case err := <-messagesC: + if err != nil && err != io.EOF { + logger.Warningf("connection problem: %v %T", trace.DebugReport(err), err) + lastErr = err + } + case <-l.ctx.Done(): + return trace.ConnectionProblem(nil, "context is closing") + } + } + + return lastErr +} diff --git a/lib/utils/loadbalancer_test.go b/lib/utils/loadbalancer_test.go new file mode 100644 index 0000000000000..d972e974834d2 --- /dev/null +++ b/lib/utils/loadbalancer_test.go @@ -0,0 +1,211 @@ +/* +Copyright 2017 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "bufio" + "context" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "net/url" + + "gopkg.in/check.v1" +) + +type LBSuite struct { +} + +var _ = check.Suite(&LBSuite{}) + +func (s *LBSuite) SetUpSuite(c *check.C) { + InitLoggerForTests() +} + +func (s *LBSuite) TestSingleBackendLB(c *check.C) { + backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + })) + defer backend1.Close() + + ports, err := GetFreeTCPPorts(1) + c.Assert(err, check.IsNil) + + frontend := localAddr(ports[0]) + + lb, err := NewLoadBalancer(context.TODO(), frontend, []NetAddr{urlToNetAddr(backend1.URL)}) + c.Assert(err, check.IsNil) + err = lb.Listen() + c.Assert(err, check.IsNil) + go lb.Serve() + defer lb.Close() + + out, err := roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") +} + +func (s *LBSuite) TestTwoBackendsLB(c *check.C) { + backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + })) + defer backend1.Close() + + backend2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 2") + })) + defer backend2.Close() + + backend1Addr, backend2Addr := urlToNetAddr(backend1.URL), urlToNetAddr(backend2.URL) + + ports, err := GetFreeTCPPorts(1) + c.Assert(err, check.IsNil) + + frontend := localAddr(ports[0]) + + lb, err := NewLoadBalancer(context.TODO(), frontend, nil) + c.Assert(err, check.IsNil) + err = lb.Listen() + c.Assert(err, check.IsNil) + go lb.Serve() + defer lb.Close() + + // no endpoints + _, err = roundtrip(frontend.String()) + c.Assert(err, check.NotNil) + + lb.AddBackend(backend1Addr) + out, err := roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") + + 
lb.AddBackend(backend2Addr) + out, err = roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 2") +} + +func (s *LBSuite) TestOneFailingBackend(c *check.C) { + backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + })) + defer backend1.Close() + + backend2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 2") + })) + backend2.Close() + + backend1Addr, backend2Addr := urlToNetAddr(backend1.URL), urlToNetAddr(backend2.URL) + + ports, err := GetFreeTCPPorts(1) + c.Assert(err, check.IsNil) + + frontend := localAddr(ports[0]) + + lb, err := NewLoadBalancer(context.TODO(), frontend, nil) + c.Assert(err, check.IsNil) + err = lb.Listen() + c.Assert(err, check.IsNil) + go lb.Serve() + defer lb.Close() + + lb.AddBackend(backend1Addr) + lb.AddBackend(backend2Addr) + + out, err := roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") + + out, err = roundtrip(frontend.String()) + c.Assert(err, check.NotNil) + + out, err = roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") +} + +func (s *LBSuite) TestClose(c *check.C) { + backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + })) + defer backend1.Close() + + ports, err := GetFreeTCPPorts(1) + c.Assert(err, check.IsNil) + + frontend := localAddr(ports[0]) + + lb, err := NewLoadBalancer(context.TODO(), frontend, []NetAddr{urlToNetAddr(backend1.URL)}) + c.Assert(err, check.IsNil) + err = lb.Listen() + c.Assert(err, check.IsNil) + go lb.Serve() + defer lb.Close() + + out, err := roundtrip(frontend.String()) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") + + lb.Close() + // second close works + lb.Close() + + // requests are failing + out, err = roundtrip(frontend.String()) + c.Assert(err, check.NotNil) +} + +func urlToNetAddr(u string) NetAddr { + parsed, err := url.Parse(u) + if err != nil { + panic(err) + } + return *MustParseAddr(parsed.Host) +} + +func localURL(port string) string { + return fmt.Sprintf("http://127.0.0.1:%v", port) +} + +func localAddr(port string) NetAddr { + return *MustParseAddr(fmt.Sprintf("127.0.0.1:%v", port)) +} + +// roundtrip is a single connection simplistic HTTP client +// that allows us to bypass a connection pool to test load balancing +func roundtrip(addr string) (string, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return "", err + } + defer conn.Close() + fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n\r\n") + + re, err := http.ReadResponse(bufio.NewReader(conn), nil) + if err != nil { + return "", err + } + defer re.Body.Close() + out, err := ioutil.ReadAll(re.Body) + if err != nil { + return "", err + } + return string(out), nil +} From d7e3ff1416eb91e34959524c3d0297c1dd3effb4 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Wed, 11 Oct 2017 15:31:53 -0700 Subject: [PATCH 14/24] remove extra newline --- lib/utils/loadbalancer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/loadbalancer_test.go b/lib/utils/loadbalancer_test.go index d972e974834d2..7258ef5ff4f84 100644 --- a/lib/utils/loadbalancer_test.go +++ b/lib/utils/loadbalancer_test.go @@ -196,7 +196,7 @@ func roundtrip(addr string) (string, error) { return "", err } defer conn.Close() - fmt.Fprintf(conn, "GET / 
HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n\r\n")
+	fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n")
 
 	re, err := http.ReadResponse(bufio.NewReader(conn), nil)
 	if err != nil {

From 9c31410a4deb678452466bf18239a5a50fbcb405 Mon Sep 17 00:00:00 2001
From: Sasha Klizhentas
Date: Wed, 11 Oct 2017 16:36:25 -0700
Subject: [PATCH 15/24] start working on integration testing

---
 integration/integration_test.go | 41 +++++++++++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/integration/integration_test.go b/integration/integration_test.go
index c94308b0b8206..0a91abab6381a 100644
--- a/integration/integration_test.go
+++ b/integration/integration_test.go
@@ -852,6 +852,47 @@ func (s *IntSuite) TestMapRoles(c *check.C) {
 	c.Assert(aux.Stop(true), check.IsNil)
 }
 
+// TestDiscovery tests the case of multiple proxies and a reverse tunnel
+// agent that eventually connects to the right proxy
+func (s *IntSuite) TestDiscovery(c *check.C) {
+	username := s.me.Username
+
+	a := NewInstance("cluster-a", HostID, Host, s.getPorts(5), s.priv, s.pub)
+	b := NewInstance("cluster-b", HostID, Host, s.getPorts(5), s.priv, s.pub)
+
+	a.AddUser(username, []string{username})
+	b.AddUser(username, []string{username})
+
+	c.Assert(b.Create(a.Secrets.AsSlice(), false, nil), check.IsNil)
+	c.Assert(a.Create(b.Secrets.AsSlice(), true, nil), check.IsNil)
+
+	c.Assert(b.Start(), check.IsNil)
+	c.Assert(a.Start(), check.IsNil)
+
+	// wait for both sites to see each other via their reverse tunnels (for up to 10 seconds)
+	abortTime := time.Now().Add(time.Second * 10)
+	for len(a.Tunnel.GetSites()) < 2 && len(b.Tunnel.GetSites()) < 2 {
+		time.Sleep(time.Millisecond * 2000)
+		if time.Now().After(abortTime) {
+			c.Fatalf("two clusters do not see each other: tunnels are not working")
+		}
+	}
+
+	cmd := []string{"echo", "hello world"}
+	tc, err := b.NewClient(username, "cluster-a", "127.0.0.1", a.GetPortSSHInt())
+	c.Assert(err, check.IsNil)
+	output := &bytes.Buffer{}
+	tc.Stdout = output
+	c.Assert(err, check.IsNil)
+	err = tc.SSH(context.TODO(), cmd, false)
+	c.Assert(err, check.IsNil)
+	c.Assert(output.String(), check.Equals, "hello world\n")
+
+	// stop cluster and remaining nodes
+	c.Assert(a.Stop(true), check.IsNil)
+	c.Assert(b.Stop(true), check.IsNil)
+}
+
 // getPorts helper returns a range of unallocated ports available for listening on
 func (s *IntSuite) getPorts(num int) []int {
 	if len(s.ports) < num {

From 7b82e31150b150b92f36235a34bdc26286f10e5a Mon Sep 17 00:00:00 2001
From: Sasha Klizhentas
Date: Wed, 11 Oct 2017 17:23:03 -0700
Subject: [PATCH 16/24] add fast and slow pace tickers

---
 lib/defaults/defaults.go   | 11 +++++-
 lib/reversetunnel/agent.go | 12 +++++--
 lib/utils/ticker.go        | 74 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 93 insertions(+), 4 deletions(-)
 create mode 100644 lib/utils/ticker.go

diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go
index fe6d2d40d28dd..5fb389728556e 100644
--- a/lib/defaults/defaults.go
+++ b/lib/defaults/defaults.go
@@ -175,7 +175,16 @@ var (
 	TerminalSizeRefreshPeriod = 2 * time.Second
 
 	// NewtworkBackoffDuration is a standard backoff on network requests
-	NetworkBackoffDuration = time.Second * 10
+	// usually is slow, e.g. once in 30 seconds
+	NetworkBackoffDuration = time.Second * 30
+
+	// NetworkRetryDuration is a standard retry on network requests
+	// to retry quickly, e.g. 
once in one second
+	NetworkRetryDuration = time.Second
+
+	// FastAttempts is the initial amount of fast retry attempts
+	// before switching to slow mode
+	FastAttempts = 10
 
 	// ReportingPeriod is a period for reports in logs
 	ReportingPeriod = 5 * time.Minute
 )
 
 // Default connection limits, they can be applied separately on any of the Teleport
diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go
index 8b29bbd1a1e49..f7c720c1742b9 100644
--- a/lib/reversetunnel/agent.go
+++ b/lib/reversetunnel/agent.go
@@ -428,8 +428,12 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
 // run is the main agent loop, constantly tries to re-establish
 // the connection until stopped
 func (a *Agent) run() {
-	backoff := time.NewTicker(defaults.NetworkBackoffDuration)
-	defer backoff.Stop()
+	ticker, err := utils.NewSwitchTicker(defaults.FastAttempts, defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
+	if err != nil {
+		log.Errorf("failed to run: %v", err)
+		return
+	}
+	defer ticker.Stop()
 	firstAttempt := true
 	for {
 		if len(a.DiscoverProxies) != 0 {
@@ -446,16 +450,18 @@ func (a *Agent) run() {
 			a.Debugf("agent has closed, exiting")
 			return
 		// wait backoff on network retries
-		case <-backoff.C:
+		case <-ticker.Channel():
 		}
 	}
 
 	conn, err := a.connect()
 	firstAttempt = false
 	if err != nil || conn == nil {
+		ticker.IncrementFailureCount()
 		a.Warningf("failed to create remote tunnel: %v, conn: %v", err, conn)
 		continue
 	} else {
+		ticker.Reset()
 		a.Infof("connected to %s", conn.RemoteAddr())
 		if len(a.DiscoverProxies) != 0 {
 			if !a.connectedToRightProxy() {
diff --git a/lib/utils/ticker.go b/lib/utils/ticker.go
new file mode 100644
index 0000000000000..de145fb586622
--- /dev/null
+++ b/lib/utils/ticker.go
@@ -0,0 +1,74 @@
+/*
+Copyright 2017 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package utils + +import ( + "sync/atomic" + "time" + + "github.com/gravitational/trace" +) + +// NewSwitchTicker returns new instance of the switch ticker +func NewSwitchTicker(threshold int, slowPeriod time.Duration, fastPeriod time.Duration) (*SwitchTicker, error) { + if threshold == 0 { + return nil, trace.BadParameter("missing threshold") + } + if slowPeriod <= 0 || fastPeriod <= 0 { + return nil, trace.BadParameter("bad slow period or fast period parameters") + } + return &SwitchTicker{ + threshold: int64(threshold), + slowTicker: time.NewTicker(slowPeriod), + fastTicker: time.NewTicker(fastPeriod), + }, nil +} + +// SwitchTicker switches between slow and fast +// ticker based on the number of failures +type SwitchTicker struct { + threshold int64 + failCount int64 + slowTicker *time.Ticker + fastTicker *time.Ticker +} + +// IncrementFailureCount increments internal failure count +func (c *SwitchTicker) IncrementFailureCount() { + atomic.AddInt64(&c.failCount, 1) +} + +// Channel returns either channel with fast ticker or slow ticker +// based on whether failure count exceeds threshold or not +func (c *SwitchTicker) Channel() <-chan time.Time { + failCount := atomic.LoadInt64(&c.failCount) + if failCount > c.threshold { + return c.fastTicker.C + } + return c.slowTicker.C +} + +// Reset resets internal failure counter and switches back to fast retry period +func (c *SwitchTicker) Reset() { + atomic.StoreInt64(&c.failCount, 0) +} + +// Stop stops tickers and has to be called to prevent timer leaks +func (c *SwitchTicker) Stop() { + c.slowTicker.Stop() + c.fastTicker.Stop() +} From 0290cccb573b802dada6196a63cf9b5c33cf08a0 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Thu, 12 Oct 2017 10:35:46 -0700 Subject: [PATCH 17/24] integration tests for proxies --- integration/helpers.go | 108 +++++++++++++++++++----- integration/integration_test.go | 78 +++++++++++------ lib/auth/native/native.go | 6 +- lib/auth/testauthority/testauthority.go | 4 +- lib/utils/loadbalancer.go | 2 +- lib/utils/loadbalancer_test.go | 8 +- 6 files changed, 150 insertions(+), 56 deletions(-) diff --git a/integration/helpers.go b/integration/helpers.go index 3f1e52779bd4e..84be43060b030 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -388,6 +388,51 @@ func (i *TeleInstance) StartNode(name string, sshPort, proxyWebPort, proxySSHPor return process.Start() } +// ProxyConfig is a set of configuration parameters for Proxy +type ProxyConfig struct { + // Name is a proxy name + Name string + // SSHPort is SSH proxy port + SSHPort int + // WebPort is web proxy port + WebPort int + // ReverseTunnelPort is a port for reverse tunnel addresses + ReverseTunnelPort int +} + +// StartProxy starts proxy server and adds it to the cluster +func (i *TeleInstance) StartProxy(cfg ProxyConfig) error { + dataDir, err := ioutil.TempDir("", "cluster-"+i.Secrets.SiteName+"-"+cfg.Name) + if err != nil { + return trace.Wrap(err) + } + tconf := service.MakeDefaultConfig() + tconf.HostUUID = cfg.Name + tconf.Hostname = cfg.Name + tconf.DataDir = dataDir + tconf.Auth.Enabled = false + tconf.Proxy.Enabled = true + tconf.SSH.Enabled = false + authServer := utils.MustParseAddr(net.JoinHostPort(i.Hostname, i.GetPortAuth())) + tconf.AuthServers = append(tconf.AuthServers, *authServer) + tconf.Token = "token" + tconf.Proxy.Enabled = true + tconf.Proxy.SSHAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort)) + tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", 
cfg.SSHPort)) + tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.WebPort)) + tconf.Proxy.DisableReverseTunnel = false + tconf.Proxy.DisableWebService = true + // Enable caching + tconf.CachePolicy = service.CachePolicy{Enabled: true} + + process, err := service.NewTeleport(tconf) + if err != nil { + return trace.Wrap(err) + } + i.Nodes = append(i.Nodes, process) + return process.Start() +} + // Reset re-creates the teleport instance based on the same configuration // This is needed if you want to stop the instance, reset it and start again func (i *TeleInstance) Reset() (err error) { @@ -471,9 +516,23 @@ func (i *TeleInstance) Start() (err error) { return err } +// ClientConfig is a client configuration +type ClientConfig struct { + // Login is SSH login name + Login string + // Cluster is a cluster name to connect to + Cluster string + // Host string is a target host to connect to + Host string + // Port is a target port to connect to + Port int + // Proxy is an optional alternative proxy to use + Proxy *ProxyConfig +} + // NewClient returns a fully configured and pre-authenticated client // (pre-authenticated with server CAs and signed session key) -func (i *TeleInstance) NewClient(login string, site string, host string, port int) (tc *client.TeleportClient, err error) { +func (i *TeleInstance) NewClient(cfg ClientConfig) (tc *client.TeleportClient, err error) { keyDir, err := ioutil.TempDir(i.Config.DataDir, "tsh") if err != nil { return nil, err @@ -485,26 +544,33 @@ func (i *TeleInstance) NewClient(login string, site string, host string, port in if err != nil { return nil, trace.Wrap(err) } - proxySSHPort, err := strconv.Atoi(sp) - if err != nil { - return nil, trace.Wrap(err) - } - _, sp, err = net.SplitHostPort(proxyConf.WebAddr.Addr) - if err != nil { - return nil, trace.Wrap(err) - } - proxyWebPort, err := strconv.Atoi(sp) - if err != nil { - return nil, trace.Wrap(err) + + // use alternative proxy if necessary + var proxySSHPort, proxyWebPort int + if cfg.Proxy == nil { + proxySSHPort, err = strconv.Atoi(sp) + if err != nil { + return nil, trace.Wrap(err) + } + _, sp, err = net.SplitHostPort(proxyConf.WebAddr.Addr) + if err != nil { + return nil, trace.Wrap(err) + } + proxyWebPort, err = strconv.Atoi(sp) + if err != nil { + return nil, trace.Wrap(err) + } + } else { + proxySSHPort, proxyWebPort = cfg.Proxy.SSHPort, cfg.Proxy.WebPort } cconf := &client.Config{ - Username: login, - Host: host, - HostPort: port, - HostLogin: login, + Username: cfg.Login, + Host: cfg.Host, + HostPort: cfg.Port, + HostLogin: cfg.Login, InsecureSkipVerify: true, KeysDir: keyDir, - SiteName: site, + SiteName: cfg.Cluster, } cconf.SetProxy(proxyHost, proxyWebPort, proxySSHPort) @@ -513,14 +579,14 @@ func (i *TeleInstance) NewClient(login string, site string, host string, port in return nil, err } // confnigures the client authenticate using the keys from 'secrets': - user, ok := i.Secrets.Users[login] + user, ok := i.Secrets.Users[cfg.Login] if !ok { - return nil, trace.Errorf("unknown login '%v'", login) + return nil, trace.BadParameter("unknown login %q", cfg.Login) } if user.Key == nil { - return nil, trace.Errorf("user %v has no key", login) + return nil, trace.BadParameter("user %q has no key", cfg.Login) } - _, err = tc.AddKey(host, user.Key) + _, err = tc.AddKey(cfg.Host, user.Key) if err != nil { return nil, trace.Wrap(err) } diff --git a/integration/integration_test.go b/integration/integration_test.go index 0a91abab6381a..1cce55df78781 100644 --- 
a/integration/integration_test.go +++ b/integration/integration_test.go @@ -144,7 +144,7 @@ func (s *IntSuite) TestAudit(c *check.C) { endC := make(chan error, 0) myTerm := NewTerminal(250) go func() { - cl, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt()) + cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) cl.Stdout = &myTerm cl.Stdin = &myTerm @@ -326,7 +326,7 @@ func (s *IntSuite) TestInteroperability(c *check.C) { for i, tt := range tests { // create new teleport client - cl, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt()) + cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) // hook up stdin and stdout to a buffer for reading and writing @@ -376,7 +376,7 @@ func (s *IntSuite) TestInteractive(c *check.C) { // PersonA: SSH into the server, wait one second, then type some commands on stdin: openSession := func() { - cl, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt()) + cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) cl.Stdout = &personA cl.Stdin = &personA @@ -399,7 +399,7 @@ func (s *IntSuite) TestInteractive(c *check.C) { sessionID = string(sessions[0].ID) break } - cl, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt()) + cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) cl.Stdout = &personB for i := 0; i < 10; i++ { @@ -433,7 +433,7 @@ func (s *IntSuite) TestEnvironmentVariables(c *check.C) { cmd := []string{"printenv", testKey} // make sure sessions set run command - tc, err := t.NewClient(s.me.Username, Site, Host, t.GetPortSSHInt()) + tc, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) tc.Env = map[string]string{testKey: testVal} @@ -455,7 +455,7 @@ func (s *IntSuite) TestInvalidLogins(c *check.C) { cmd := []string{"echo", "success"} // try the wrong site: - tc, err := t.NewClient(s.me.Username, "wrong-site", Host, t.GetPortSSHInt()) + tc, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: "wrong-site", Host: Host, Port: t.GetPortSSHInt()}) c.Assert(err, check.IsNil) err = tc.SSH(context.TODO(), cmd, false) c.Assert(err, check.ErrorMatches, "cluster wrong-site not found") @@ -513,7 +513,7 @@ func (s *IntSuite) TestTwoClusters(c *check.C) { cmd := []string{"echo", "hello world"} // directly: - tc, err := a.NewClient(username, "site-A", Host, sshPort) + tc, err := a.NewClient(ClientConfig{Login: username, Cluster: "site-A", Host: Host, Port: sshPort}) tc.Stdout = &outputA c.Assert(err, check.IsNil) err = tc.SSH(context.TODO(), cmd, false) @@ -521,7 +521,7 @@ func (s *IntSuite) TestTwoClusters(c *check.C) { c.Assert(outputA.String(), check.Equals, "hello world\n") // via tunnel b->a: - tc, err = b.NewClient(username, "site-A", Host, sshPort) + tc, err = b.NewClient(ClientConfig{Login: username, Cluster: "site-A", Host: Host, Port: sshPort}) tc.Stdout = &outputB c.Assert(err, check.IsNil) err = tc.SSH(context.TODO(), cmd, false) @@ -631,7 +631,7 @@ func (s *IntSuite) TestHA(c *check.C) { } cmd := []string{"echo", "hello world"} - tc, err := b.NewClient(username, "cluster-a", "127.0.0.1", sshPort) + tc, err := b.NewClient(ClientConfig{Login: username, Cluster: "cluster-a", Host: 
"127.0.0.1", Port: sshPort}) c.Assert(err, check.IsNil) output := &bytes.Buffer{} tc.Stdout = output @@ -761,7 +761,7 @@ func (s *IntSuite) TestMapRoles(c *check.C) { } cmd := []string{"echo", "hello world"} - tc, err := main.NewClient(username, clusterAux, "127.0.0.1", sshPort) + tc, err := main.NewClient(ClientConfig{Login: username, Cluster: clusterAux, Host: "127.0.0.1", Port: sshPort}) c.Assert(err, check.IsNil) output := &bytes.Buffer{} tc.Stdout = output @@ -818,28 +818,28 @@ func (s *IntSuite) TestMapRoles(c *check.C) { } for i, tt := range tests { - cid := services.CertAuthID{services.UserCA, "cluster-main"} + cid := services.CertAuthID{Type: services.UserCA, DomainName: "cluster-main"} mainUserCAs, err := tt.inCluster.Process.GetAuthServer().GetCertAuthority(cid, true) c.Assert(err, tt.outChkMainUserCA) if tt.outChkMainUserCA == check.IsNil { c.Assert(mainUserCAs.GetSigningKeys(), check.HasLen, tt.outLenMainUserCA, check.Commentf("Test %v, Main User CA", i)) } - cid = services.CertAuthID{services.HostCA, "cluster-main"} + cid = services.CertAuthID{Type: services.HostCA, DomainName: "cluster-main"} mainHostCAs, err := tt.inCluster.Process.GetAuthServer().GetCertAuthority(cid, true) c.Assert(err, tt.outChkMainHostCA) if tt.outChkMainHostCA == check.IsNil { c.Assert(mainHostCAs.GetSigningKeys(), check.HasLen, tt.outLenMainHostCA, check.Commentf("Test %v, Main Host CA", i)) } - cid = services.CertAuthID{services.UserCA, "cluster-aux"} + cid = services.CertAuthID{Type: services.UserCA, DomainName: "cluster-aux"} auxUserCAs, err := tt.inCluster.Process.GetAuthServer().GetCertAuthority(cid, true) c.Assert(err, tt.outChkAuxUserCA) if tt.outChkAuxUserCA == check.IsNil { c.Assert(auxUserCAs.GetSigningKeys(), check.HasLen, tt.outLenAuxUserCA, check.Commentf("Test %v, Aux User CA", i)) } - cid = services.CertAuthID{services.HostCA, "cluster-aux"} + cid = services.CertAuthID{Type: services.HostCA, DomainName: "cluster-aux"} auxHostCAs, err := tt.inCluster.Process.GetAuthServer().GetCertAuthority(cid, true) c.Assert(err, tt.outChkAuxHostCA) if tt.outChkAuxHostCA == check.IsNil { @@ -857,29 +857,55 @@ func (s *IntSuite) TestMapRoles(c *check.C) { func (s *IntSuite) TestDiscovery(c *check.C) { username := s.me.Username - a := NewInstance("cluster-a", HostID, Host, s.getPorts(5), s.priv, s.pub) - b := NewInstance("cluster-b", HostID, Host, s.getPorts(5), s.priv, s.pub) + // create load balancer for main cluster proxies + frontend := *utils.MustParseAddr(fmt.Sprintf("127.0.0.1:%v", s.getPorts(1)[0])) + lb, err := utils.NewLoadBalancer(context.TODO(), frontend) + c.Assert(err, check.IsNil) + c.Assert(lb.Listen(), check.IsNil) + defer lb.Close() - a.AddUser(username, []string{username}) - b.AddUser(username, []string{username}) + remote := NewInstance("cluster-remote", HostID, Host, s.getPorts(5), s.priv, s.pub) + main := NewInstance("cluster-main", HostID, Host, s.getPorts(5), s.priv, s.pub) - c.Assert(b.Create(a.Secrets.AsSlice(), false, nil), check.IsNil) - c.Assert(a.Create(b.Secrets.AsSlice(), true, nil), check.IsNil) + remote.AddUser(username, []string{username}) + main.AddUser(username, []string{username}) - c.Assert(b.Start(), check.IsNil) - c.Assert(a.Start(), check.IsNil) + c.Assert(main.Create(remote.Secrets.AsSlice(), false, nil), check.IsNil) + mainSecrets := main.Secrets + // switch listen address of the main cluster to load balancer + lb.AddBackend(*utils.MustParseAddr(mainSecrets.ListenAddr)) + mainSecrets.ListenAddr = frontend.String() + c.Assert(remote.Create(mainSecrets.AsSlice(), 
true, nil), check.IsNil)
+
+	c.Assert(main.Start(), check.IsNil)
+	c.Assert(remote.Start(), check.IsNil)
 
 	// wait for both sites to see each other via their reverse tunnels (for up to 10 seconds)
 	abortTime := time.Now().Add(time.Second * 10)
-	for len(b.Tunnel.GetSites()) < 2 && len(b.Tunnel.GetSites()) < 2 {
+	for len(main.Tunnel.GetSites()) < 2 && len(main.Tunnel.GetSites()) < 2 {
 		time.Sleep(time.Millisecond * 2000)
 		if time.Now().After(abortTime) {
 			c.Fatalf("two clusters do not see each other: tunnels are not working")
 		}
 	}
 
+	// start second proxy
+	nodePorts := s.getPorts(3)
+	proxyReverseTunnelPort, proxyWebPort, proxySSHPort := nodePorts[0], nodePorts[1], nodePorts[2]
+	err = main.StartProxy(ProxyConfig{
+		Name:              "cluster-main-proxy",
+		SSHPort:           proxySSHPort,
+		WebPort:           proxyWebPort,
+		ReverseTunnelPort: proxyReverseTunnelPort,
+	})
+	c.Assert(err, check.IsNil)
+
+	// add second proxy as a backend to the load balancer
+	lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.0:%v", proxyReverseTunnelPort)))
+
+	// execute the connection via first proxy
 	cmd := []string{"echo", "hello world"}
-	tc, err := b.NewClient(username, "cluster-a", "127.0.0.1", a.GetPortSSHInt())
+	tc, err := main.NewClient(ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()})
 	c.Assert(err, check.IsNil)
 	output := &bytes.Buffer{}
 	tc.Stdout = output
@@ -889,8 +915,8 @@ func (s *IntSuite) TestDiscovery(c *check.C) {
 	c.Assert(output.String(), check.Equals, "hello world\n")
 
 	// stop cluster and remaining nodes
-	c.Assert(a.Stop(true), check.IsNil)
-	c.Assert(b.Stop(true), check.IsNil)
+	c.Assert(remote.Stop(true), check.IsNil)
+	c.Assert(main.Stop(true), check.IsNil)
 }
 
 // getPorts helper returns a range of unallocated ports available for listening on
diff --git a/lib/auth/native/native.go b/lib/auth/native/native.go
index 80627c63b6f2f..0e7044102546b 100644
--- a/lib/auth/native/native.go
+++ b/lib/auth/native/native.go
@@ -150,7 +150,7 @@ func (n *nauth) GenerateHostCert(c services.HostCertParams) ([]byte, error) {
 		return nil, trace.Wrap(err)
 	}
 
-	principals := buildPrincipals(c.HostID, c.NodeName, c.ClusterName, c.Roles)
+	principals := BuildPrincipals(c.HostID, c.NodeName, c.ClusterName, c.Roles)
 
 	// create certificate
 	validBefore := uint64(ssh.CertTimeInfinity)
@@ -230,12 +230,12 @@ func (n *nauth) GenerateUserCert(c services.UserCertParams) ([]byte, error) {
 	return ssh.MarshalAuthorizedKey(cert), nil
 }
 
-// buildPrincipals takes a hostID, nodeName, clusterName, and role and builds a list of
+// BuildPrincipals takes a hostID, nodeName, clusterName, and role and builds a list of
 // principals to insert into a certificate. This function is backward compatible with
 // older clients which means:
 // * If RoleAdmin is in the list of roles, only a single principal is returned: hostID
 // * If nodename is empty, it is not included in the list of principals.
-func buildPrincipals(hostID string, nodeName string, clusterName string, roles teleport.Roles) []string {
+func BuildPrincipals(hostID string, nodeName string, clusterName string, roles teleport.Roles) []string {
 	// TODO(russjones): This should probably be clusterName, but we need to
 	// verify changing this won't break older clients.
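 	// RoleAdmin certificates keep the host ID as their only principal (see the contract above).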
if roles.Include(teleport.RoleAdmin) { diff --git a/lib/auth/testauthority/testauthority.go b/lib/auth/testauthority/testauthority.go index fd1c56583604c..9a33854030c2f 100644 --- a/lib/auth/testauthority/testauthority.go +++ b/lib/auth/testauthority/testauthority.go @@ -22,6 +22,7 @@ import ( "time" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -54,8 +55,9 @@ func (n *Keygen) GenerateHostCert(c services.HostCertParams) ([]byte, error) { b := time.Now().Add(c.TTL) validBefore = uint64(b.Unix()) } + principals := native.BuildPrincipals(c.HostID, c.NodeName, c.ClusterName, c.Roles) cert := &ssh.Certificate{ - ValidPrincipals: []string{c.HostID, c.NodeName}, + ValidPrincipals: principals, Key: pubKey, ValidBefore: validBefore, CertType: ssh.HostCert, diff --git a/lib/utils/loadbalancer.go b/lib/utils/loadbalancer.go index 2f71833d6b0e5..7396fe9623c9b 100644 --- a/lib/utils/loadbalancer.go +++ b/lib/utils/loadbalancer.go @@ -29,7 +29,7 @@ import ( // NewLoadBalancer returns new load balancer listening on frotend // and redirecting requests to backends using round robin algo -func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends []NetAddr) (*LoadBalancer, error) { +func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr) (*LoadBalancer, error) { if ctx == nil { return nil, trace.BadParameter("missing parameter context") } diff --git a/lib/utils/loadbalancer_test.go b/lib/utils/loadbalancer_test.go index 7258ef5ff4f84..ba4d0cc83dcb2 100644 --- a/lib/utils/loadbalancer_test.go +++ b/lib/utils/loadbalancer_test.go @@ -49,7 +49,7 @@ func (s *LBSuite) TestSingleBackendLB(c *check.C) { frontend := localAddr(ports[0]) - lb, err := NewLoadBalancer(context.TODO(), frontend, []NetAddr{urlToNetAddr(backend1.URL)}) + lb, err := NewLoadBalancer(context.TODO(), frontend, urlToNetAddr(backend1.URL)) c.Assert(err, check.IsNil) err = lb.Listen() c.Assert(err, check.IsNil) @@ -79,7 +79,7 @@ func (s *LBSuite) TestTwoBackendsLB(c *check.C) { frontend := localAddr(ports[0]) - lb, err := NewLoadBalancer(context.TODO(), frontend, nil) + lb, err := NewLoadBalancer(context.TODO(), frontend) c.Assert(err, check.IsNil) err = lb.Listen() c.Assert(err, check.IsNil) @@ -119,7 +119,7 @@ func (s *LBSuite) TestOneFailingBackend(c *check.C) { frontend := localAddr(ports[0]) - lb, err := NewLoadBalancer(context.TODO(), frontend, nil) + lb, err := NewLoadBalancer(context.TODO(), frontend) c.Assert(err, check.IsNil) err = lb.Listen() c.Assert(err, check.IsNil) @@ -152,7 +152,7 @@ func (s *LBSuite) TestClose(c *check.C) { frontend := localAddr(ports[0]) - lb, err := NewLoadBalancer(context.TODO(), frontend, []NetAddr{urlToNetAddr(backend1.URL)}) + lb, err := NewLoadBalancer(context.TODO(), frontend, urlToNetAddr(backend1.URL)) c.Assert(err, check.IsNil) err = lb.Listen() c.Assert(err, check.IsNil) From e461b4e6bd43d66bb7050023d91a70d789622b26 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Thu, 12 Oct 2017 16:51:18 -0700 Subject: [PATCH 18/24] fix tests --- integration/helpers.go | 2 +- integration/integration_test.go | 68 ++++++++++++--- lib/auth/api.go | 3 + lib/auth/apiserver.go | 10 +++ lib/auth/auth_with_roles.go | 7 ++ lib/auth/clt.go | 12 +++ lib/auth/init_test.go | 2 +- lib/auth/tun.go | 2 +- lib/backend/dir/impl_test.go | 2 +- lib/reversetunnel/agent.go | 19 +++- lib/reversetunnel/remotesite.go | 24 ++++- lib/services/local/presence.go | 5 
++ lib/services/presence.go | 3 + lib/services/suite/suite.go | 16 ++++ lib/srv/sshserver_test.go | 126 ++++++++++++++++++--------- lib/state/cachingaccesspoint.go | 7 +- lib/state/cachingaccesspoint_test.go | 3 +- lib/utils/loadbalancer.go | 45 +++++++++- lib/utils/loadbalancer_test.go | 48 +++++++++- lib/web/apiserver_test.go | 21 +++-- 20 files changed, 349 insertions(+), 76 deletions(-) diff --git a/integration/helpers.go b/integration/helpers.go index 84be43060b030..d48975b765a56 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -418,7 +418,7 @@ func (i *TeleInstance) StartProxy(cfg ProxyConfig) error { tconf.Token = "token" tconf.Proxy.Enabled = true tconf.Proxy.SSHAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort)) - tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.SSHPort)) + tconf.Proxy.ReverseTunnelListenAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.ReverseTunnelPort)) tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, fmt.Sprintf("%v", cfg.WebPort)) tconf.Proxy.DisableReverseTunnel = false tconf.Proxy.DisableWebService = true diff --git a/integration/integration_test.go b/integration/integration_test.go index 1cce55df78781..060d594f66235 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -862,6 +862,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) { lb, err := utils.NewLoadBalancer(context.TODO(), frontend) c.Assert(err, check.IsNil) c.Assert(lb.Listen(), check.IsNil) + go lb.Serve() defer lb.Close() remote := NewInstance("cluster-remote", HostID, Host, s.getPorts(5), s.priv, s.pub) @@ -873,7 +874,8 @@ func (s *IntSuite) TestDiscovery(c *check.C) { c.Assert(main.Create(remote.Secrets.AsSlice(), false, nil), check.IsNil) mainSecrets := main.Secrets // switch listen address of the main cluster to load balancer - lb.AddBackend(*utils.MustParseAddr(mainSecrets.ListenAddr)) + mainProxyAddr := *utils.MustParseAddr(mainSecrets.ListenAddr) + lb.AddBackend(mainProxyAddr) mainSecrets.ListenAddr = frontend.String() c.Assert(remote.Create(mainSecrets.AsSlice(), true, nil), check.IsNil) @@ -892,33 +894,79 @@ func (s *IntSuite) TestDiscovery(c *check.C) { // start second proxy nodePorts := s.getPorts(3) proxyReverseTunnelPort, proxyWebPort, proxySSHPort := nodePorts[0], nodePorts[1], nodePorts[2] - err = main.StartProxy(ProxyConfig{ + proxyConfig := ProxyConfig{ Name: "cluster-main-proxy", SSHPort: proxySSHPort, WebPort: proxyWebPort, ReverseTunnelPort: proxyReverseTunnelPort, - }) + } + err = main.StartProxy(proxyConfig) c.Assert(err, check.IsNil) // add second proxy as a backend to the load balancer - lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.0:%v", proxyReverseTunnelPort))) + lb.AddBackend(*utils.MustParseAddr(fmt.Sprintf("127.0.0.1:%v", proxyReverseTunnelPort))) // execute the connection via first proxy - cmd := []string{"echo", "hello world"} - tc, err := main.NewClient(ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()}) + cfg := ClientConfig{Login: username, Cluster: "cluster-remote", Host: "127.0.0.1", Port: remote.GetPortSSHInt()} + output, err := runCommand(main, []string{"echo", "hello world"}, cfg, 1) c.Assert(err, check.IsNil) - output := &bytes.Buffer{} - tc.Stdout = output + c.Assert(output, check.Equals, "hello world\n") + + // execute the connection via second proxy, should work + cfgProxy := ClientConfig{ + Login: username, + Cluster: "cluster-remote", + Host: 
"127.0.0.1", + Port: remote.GetPortSSHInt(), + Proxy: &proxyConfig, + } + output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 10) c.Assert(err, check.IsNil) - err = tc.SSH(context.TODO(), cmd, false) + c.Assert(output, check.Equals, "hello world\n") + + // now disconnect the main proxy and make sure it will reconnect eventually + lb.RemoveBackend(mainProxyAddr) + + // requests going via main proxy will fail + output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 1) + c.Assert(err, check.NotNil) + + // requests going via second proxy will succeed + output, err = runCommand(main, []string{"echo", "hello world"}, cfgProxy, 1) c.Assert(err, check.IsNil) - c.Assert(output.String(), check.Equals, "hello world\n") + c.Assert(output, check.Equals, "hello world\n") + + // connect the main proxy back and make sure agents have reconnected over time + lb.AddBackend(mainProxyAddr) + output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 10) + c.Assert(err, check.IsNil) + c.Assert(output, check.Equals, "hello world\n") // stop cluster and remaining nodes c.Assert(remote.Stop(true), check.IsNil) c.Assert(main.Stop(true), check.IsNil) } +// runCommand is a shortcut for running SSH command, it creates +// a client connected to proxy hosted by instance +// and returns the result +func runCommand(instance *TeleInstance, cmd []string, cfg ClientConfig, attempts int) (string, error) { + tc, err := instance.NewClient(cfg) + if err != nil { + return "", trace.Wrap(err) + } + output := &bytes.Buffer{} + tc.Stdout = output + for i := 0; i < attempts; i++ { + err = tc.SSH(context.TODO(), cmd, false) + if err == nil { + break + } + time.Sleep(time.Millisecond * 50) + } + return output.String(), trace.Wrap(err) +} + // getPorts helper returns a range of unallocated ports available for litening on func (s *IntSuite) getPorts(num int) []int { if len(s.ports) < num { diff --git a/lib/auth/api.go b/lib/auth/api.go index 60dbad345604c..503429e36814a 100644 --- a/lib/auth/api.go +++ b/lib/auth/api.go @@ -64,6 +64,9 @@ type AccessPoint interface { // UpsertTunnelConnection upserts tunnel connection UpsertTunnelConnection(conn services.TunnelConnection) error + // DeleteTunnelConnection deletes tunnel connection + DeleteTunnelConnection(clusterName, connName string) error + // GetTunnelConnections returns tunnel connections for a given cluster GetTunnelConnections(clusterName string) ([]services.TunnelConnection, error) diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index c1b0ebc2e9810..1962ede180c4c 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -105,6 +105,7 @@ func NewAPIServer(config *APIConfig) http.Handler { srv.POST("/:version/tunnelconnections", srv.withAuth(srv.upsertTunnelConnection)) srv.GET("/:version/tunnelconnections/:cluster", srv.withAuth(srv.getTunnelConnections)) srv.GET("/:version/tunnelconnections", srv.withAuth(srv.getAllTunnelConnections)) + srv.DELETE("/:version/tunnelconnections/:cluster/:conn", srv.withAuth(srv.deleteTunnelConnection)) srv.DELETE("/:version/tunnelconnections/:cluster", srv.withAuth(srv.deleteTunnelConnections)) srv.DELETE("/:version/tunnelconnections", srv.withAuth(srv.deleteAllTunnelConnections)) @@ -1793,6 +1794,15 @@ func (s *APIServer) getAllTunnelConnections(auth ClientI, w http.ResponseWriter, return items, nil } +// deleteTunnelConnection deletes tunnel connection by name +func (s *APIServer) deleteTunnelConnection(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version 
string) (interface{}, error) { + err := auth.DeleteTunnelConnection(p.ByName("cluster"), p.ByName("conn")) + if err != nil { + return nil, trace.Wrap(err) + } + return message("ok"), nil +} + // deleteTunnelConnections deletes all tunnel connections for cluster func (s *APIServer) deleteTunnelConnections(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { err := auth.DeleteTunnelConnections(p.ByName("cluster")) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 53880970322d2..3411cec85e063 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -942,6 +942,13 @@ func (a *AuthWithRoles) GetAllTunnelConnections() ([]services.TunnelConnection, return a.authServer.GetAllTunnelConnections() } +func (a *AuthWithRoles) DeleteTunnelConnection(clusterName string, connName string) error { + if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbDelete); err != nil { + return trace.Wrap(err) + } + return a.authServer.DeleteTunnelConnection(clusterName, connName) +} + func (a *AuthWithRoles) DeleteTunnelConnections(clusterName string) error { if err := a.action(defaults.Namespace, services.KindTunnelConnection, services.VerbList); err != nil { return trace.Wrap(err) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index d4a3c04048823..02f9880a454a3 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -529,6 +529,18 @@ func (c *Client) GetAllTunnelConnections() ([]services.TunnelConnection, error) return conns, nil } +// DeleteTunnelConnection deletes tunnel connection by name +func (c *Client) DeleteTunnelConnection(clusterName string, connName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + if connName == "" { + return trace.BadParameter("missing parameter connection name") + } + _, err := c.Delete(c.Endpoint("tunnelconnections", clusterName, connName)) + return trace.Wrap(err) +} + // DeleteTunnelConnections deletes all tunnel connections for cluster func (c *Client) DeleteTunnelConnections(clusterName string) error { if clusterName == "" { diff --git a/lib/auth/init_test.go b/lib/auth/init_test.go index 2b15a017202c2..139b0ff59758d 100644 --- a/lib/auth/init_test.go +++ b/lib/auth/init_test.go @@ -83,7 +83,7 @@ func (s *AuthInitSuite) TestReadIdentity(c *C) { id, err := ReadIdentityFromKeyPair(priv, cert) c.Assert(err, IsNil) c.Assert(id.AuthorityDomain, Equals, "example.com") - c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1", Role: teleport.RoleNode}) + c.Assert(id.ID, DeepEquals, IdentityID{HostUUID: "id1.example.com", Role: teleport.RoleNode}) c.Assert(id.CertBytes, DeepEquals, cert) c.Assert(id.KeyBytes, DeepEquals, priv) diff --git a/lib/auth/tun.go b/lib/auth/tun.go index a21d3c0eacd68..3c6f2ca25c768 100644 --- a/lib/auth/tun.go +++ b/lib/auth/tun.go @@ -858,7 +858,7 @@ func (c *TunClient) GetDialer() AccessPointDialer { } time.Sleep(4 * time.Duration(attempt) * dialRetryInterval) } - c.Error("%v", err) + c.Errorf("%v", err) return nil, trace.Wrap(err) } } diff --git a/lib/backend/dir/impl_test.go b/lib/backend/dir/impl_test.go index 417f53d6528f3..1ac40cf2987d8 100644 --- a/lib/backend/dir/impl_test.go +++ b/lib/backend/dir/impl_test.go @@ -179,7 +179,7 @@ func (s *Suite) TestTTL(c *check.C) { v, err = s.bk.GetVal(bucket, "key") c.Assert(trace.IsNotFound(err), check.Equals, true) - c.Assert(err.Error(), check.Equals, `key 'key' is not found`) + c.Assert(err.Error(), check.Equals, `key "key" 
is not found`)
 	c.Assert(v, check.IsNil)
 }
 
diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go
index f7c720c1742b9..3b28fdba806d8 100644
--- a/lib/reversetunnel/agent.go
+++ b/lib/reversetunnel/agent.go
@@ -80,6 +80,8 @@ type AgentConfig struct {
 	// Clock is a clock passed in tests, if not set wall clock
 	// will be used
 	Clock clockwork.Clock
+	// EventsC is an optional events channel, used for testing purposes
+	EventsC chan string
 }
 
 // CheckAndSetDefaults checks parameters and sets default values
@@ -87,9 +89,6 @@ func (a *AgentConfig) CheckAndSetDefaults() error {
 	if a.Addr.IsEmpty() {
 		return trace.BadParameter("missing parameter Addr")
 	}
-	if a.DiscoveryC == nil {
-		return trace.BadParameter("missing parameter DiscoveryC")
-	}
 	if a.Context == nil {
 		return trace.BadParameter("missing parameter Context")
 	}
@@ -430,7 +429,7 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
 func (a *Agent) run() {
 	ticker, err := utils.NewSwitchTicker(defaults.FastAttempts, defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
 	if err != nil {
-		log.Error("failed to run: %v", err)
+		log.Errorf("failed to run: %v", err)
 		return
 	}
 	defer ticker.Stop()
@@ -473,6 +472,15 @@
 	} else {
 		a.setState(agentStateConnected)
 	}
+	if a.EventsC != nil {
+		select {
+		case a.EventsC <- ConnectedEvent:
+		case <-a.ctx.Done():
+			a.Debugf("context is closing")
+			return
+		default:
+		}
+	}
 	// start heartbeat even if error happened, it will reconnect
 	// when this happens, this is #1 issue we have right now with Teleport. So we are making
 	// it EASY to see in the logs. This condition should never be permanent (repeats
@@ -484,6 +492,9 @@
 	}
 }
 
+// ConnectedEvent is used to indicate that the reverse tunnel has connected
+const ConnectedEvent = "connected"
+
 // processRequests is a blocking function which runs in a loop sending heartbeats
 // to the given SSH connection and processes inbound requests from the
 // remote proxy
diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go
index 357433e6fa7b6..abe1e63d8ccfc 100644
--- a/lib/reversetunnel/remotesite.go
+++ b/lib/reversetunnel/remotesite.go
@@ -78,6 +78,18 @@ func (s *remoteSite) connectionCount() int {
 	return len(s.connections)
 }
 
+func (s *remoteSite) hasValidConnections() bool {
+	s.Lock()
+	defer s.Unlock()
+
+	for _, conn := range s.connections {
+		if !conn.isInvalid() {
+			return true
+		}
+	}
+	return false
+}
+
 func (s *remoteSite) nextConn() (*remoteConn, error) {
 	s.Lock()
 	defer s.Unlock()
@@ -153,6 +165,12 @@ func (s *remoteSite) registerHeartbeat(t time.Time) {
 	}
 }
 
+// deleteConnectionRecord deletes the connection record to let peer proxies
+// know that this node lost the connection and needs to be discovered
+func (s *remoteSite) deleteConnectionRecord() {
+	s.srv.AccessPoint.DeleteTunnelConnection(s.connInfo.GetClusterName(), s.connInfo.GetName())
+}
+
 func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
 	defer func() {
 		s.Infof("cluster connection closed")
@@ -165,8 +183,12 @@
 			return
 		case req := <-reqC:
 			if req == nil {
-				s.Infof("cluster disconnected")
+				s.Infof("cluster agent disconnected")
 				conn.markInvalid(trace.ConnectionProblem(nil, "agent disconnected"))
+				if !s.hasValidConnections() {
+					s.Debugf("deleting connection record")
+					s.deleteConnectionRecord()
+				}
 				return
 			}
 			var timeSent time.Time
diff --git a/lib/services/local/presence.go
b/lib/services/local/presence.go index 035fa99d96626..5abb1c70d486a 100644 --- a/lib/services/local/presence.go +++ b/lib/services/local/presence.go @@ -420,6 +420,11 @@ func (s *PresenceService) GetAllTunnelConnections() ([]services.TunnelConnection return conns, nil } +// DeleteTunnelConnection deletes tunnel connection by name +func (s *PresenceService) DeleteTunnelConnection(clusterName, connectionName string) error { + return s.DeleteKey([]string{tunnelConnectionsPrefix, clusterName}, connectionName) +} + // DeleteTunnelConnections deletes all tunnel connections for cluster func (s *PresenceService) DeleteTunnelConnections(clusterName string) error { err := s.DeleteBucket([]string{tunnelConnectionsPrefix}, clusterName) diff --git a/lib/services/presence.go b/lib/services/presence.go index dcf599d54f8e1..b38d614baaccf 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -105,6 +105,9 @@ type Presence interface { // GetAllTunnelConnections returns all tunnel connections GetAllTunnelConnections() ([]TunnelConnection, error) + // DeleteTunnelConnection deletes tunnel connection by name + DeleteTunnelConnection(clusterName string, connName string) error + // DeleteTunnelConnections deletes all tunnel connections for cluster DeleteTunnelConnections(clusterName string) error diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 54973f8f32969..9f0b92e1dee1f 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -600,4 +600,20 @@ func (s *ServicesTestSuite) TunnelConnectionsCRUD(c *C) { err = s.PresenceS.DeleteAllTunnelConnections() c.Assert(err, IsNil) + + // test delete individual connection + err = s.PresenceS.UpsertTunnelConnection(conn) + c.Assert(err, IsNil) + + out, err = s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 1) + fixtures.DeepCompare(c, out[0], conn) + + err = s.PresenceS.DeleteTunnelConnection(clusterName, conn.GetName()) + c.Assert(err, IsNil) + + out, err = s.PresenceS.GetTunnelConnections(clusterName) + c.Assert(err, IsNil) + c.Assert(len(out), Equals, 0) } diff --git a/lib/srv/sshserver_test.go b/lib/srv/sshserver_test.go index 2948456d07a5a..0a72d6f3f840d 100644 --- a/lib/srv/sshserver_test.go +++ b/lib/srv/sshserver_test.go @@ -87,6 +87,8 @@ func (s *SrvSuite) SetUpSuite(c *C) { utils.InitLoggerForTests() } +const hostID = "00000000-0000-0000-0000-000000000000" + func (s *SrvSuite) SetUpTest(c *C) { var err error s.dir = c.MkDir() @@ -155,7 +157,7 @@ func (s *SrvSuite) SetUpTest(c *C) { // set up host private key and certificate hpriv, hpub, err := s.a.GenerateKeyPair("") c.Assert(err, IsNil) - hcert, err := s.a.GenerateHostCert(hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0) + hcert, err := s.a.GenerateHostCert(hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0) c.Assert(err, IsNil) // set up user CA and set up a user that has access to the server @@ -469,12 +471,13 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) { reverseTunnelPort := s.freePorts[len(s.freePorts)-1] s.freePorts = s.freePorts[:len(s.freePorts)-1] reverseTunnelAddress := utils.NetAddr{AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort)} - reverseTunnelServer, err := reversetunnel.NewServer( - reverseTunnelAddress, - []ssh.Signer{s.signer}, - s.roleAuth, - state.NoCache, - ) + reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ + ID: 
s.domainName, + ListenAddr: reverseTunnelAddress, + HostSigners: []ssh.Signer{s.signer}, + AccessPoint: s.roleAuth, + NewCachingAccessPoint: state.NoCache, + }) c.Assert(err, IsNil) c.Assert(reverseTunnelServer.Start(), IsNil) @@ -500,14 +503,14 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) { c.Assert(tsrv.Start(), IsNil) tunClt, err := auth.NewTunClient("test", - []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) + []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() agentPool, err := reversetunnel.NewAgentPool(reversetunnel.AgentPoolConfig{ Client: tunClt, HostSigners: []ssh.Signer{s.signer}, - HostUUID: s.domainName, + HostUUID: hostID, AccessPoint: tunClt, }) c.Assert(err, IsNil) @@ -519,13 +522,27 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) { err = agentPool.FetchAndSyncAgents() c.Assert(err, IsNil) - rsAgent, err := reversetunnel.NewAgent( - reverseTunnelAddress, - "remote", - "localhost", - []ssh.Signer{s.signer}, tunClt, tunClt) + eventsC := make(chan string, 1) + rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{ + Context: context.TODO(), + Addr: reverseTunnelAddress, + RemoteCluster: "remote", + Username: hostID, + Signers: []ssh.Signer{s.signer}, + Client: tunClt, + AccessPoint: tunClt, + EventsC: eventsC, + }) c.Assert(err, IsNil) - c.Assert(rsAgent.Start(), IsNil) + rsAgent.Start() + + timeout := time.After(time.Second) + select { + case event := <-eventsC: + c.Assert(event, Equals, reversetunnel.ConnectedEvent) + case <-timeout: + c.Fatalf("timeout waiting for clusters to connect") + } sshConfig := &ssh.ClientConfig{ User: s.user, @@ -620,12 +637,13 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) { AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:%v", s.domainName, reverseTunnelPort), } - reverseTunnelServer, err := reversetunnel.NewServer( - reverseTunnelAddress, - []ssh.Signer{s.signer}, - s.roleAuth, - state.NoCache, - ) + reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ + ID: s.domainName, + ListenAddr: reverseTunnelAddress, + HostSigners: []ssh.Signer{s.signer}, + AccessPoint: s.roleAuth, + NewCachingAccessPoint: state.NoCache, + }) c.Assert(err, IsNil) c.Assert(reverseTunnelServer.Start(), IsNil) @@ -652,28 +670,49 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) { c.Assert(tsrv.Start(), IsNil) tunClt, err := auth.NewTunClient("test", - []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) + []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() // start agent and load balance requests - rsAgent, err := reversetunnel.NewAgent( - reverseTunnelAddress, - "remote", - "localhost", - []ssh.Signer{s.signer}, tunClt, tunClt) + eventsC := make(chan string, 2) + rsAgent, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{ + Context: context.TODO(), + Addr: reverseTunnelAddress, + RemoteCluster: "remote", + Username: hostID, + Signers: []ssh.Signer{s.signer}, + Client: tunClt, + AccessPoint: tunClt, + EventsC: eventsC, + }) c.Assert(err, IsNil) - c.Assert(rsAgent.Start(), IsNil) - - rsAgent2, err := reversetunnel.NewAgent( - reverseTunnelAddress, - "remote", - "localhost", - []ssh.Signer{s.signer}, tunClt, tunClt) + rsAgent.Start() + + rsAgent2, err := reversetunnel.NewAgent(reversetunnel.AgentConfig{ + Context: 
context.TODO(), + Addr: reverseTunnelAddress, + RemoteCluster: "remote", + Username: hostID, + Signers: []ssh.Signer{s.signer}, + Client: tunClt, + AccessPoint: tunClt, + EventsC: eventsC, + }) c.Assert(err, IsNil) - c.Assert(rsAgent2.Start(), IsNil) + rsAgent2.Start() defer rsAgent2.Close() + timeout := time.After(time.Second) + for i := 0; i < 2; i++ { + select { + case event := <-eventsC: + c.Assert(event, Equals, reversetunnel.ConnectedEvent) + case <-timeout: + c.Fatalf("timeout waiting for clusters to connect") + } + } + sshConfig := &ssh.ClientConfig{ User: s.user, Auth: []ssh.AuthMethod{ssh.PublicKeys(up.certSigner)}, @@ -700,13 +739,14 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) { AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:0", s.domainName), } - reverseTunnelServer, err := reversetunnel.NewServer( - reverseTunnelAddress, - []ssh.Signer{s.signer}, - s.roleAuth, - state.NoCache, - reversetunnel.DirectSite(s.domainName, s.roleAuth), - ) + reverseTunnelServer, err := reversetunnel.NewServer(reversetunnel.Config{ + ID: s.domainName, + ListenAddr: reverseTunnelAddress, + HostSigners: []ssh.Signer{s.signer}, + AccessPoint: s.roleAuth, + NewCachingAccessPoint: state.NoCache, + DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}}, + }) c.Assert(err, IsNil) proxy, err := New( @@ -731,7 +771,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) { c.Assert(tsrv.Start(), IsNil) tunClt, err := auth.NewTunClient("test", - []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) + []utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) defer tunClt.Close() diff --git a/lib/state/cachingaccesspoint.go b/lib/state/cachingaccesspoint.go index 3f3f96f9aba3d..e46747e8b0b06 100644 --- a/lib/state/cachingaccesspoint.go +++ b/lib/state/cachingaccesspoint.go @@ -499,11 +499,16 @@ func (cs *CachingAuthClient) UpsertTunnelConnection(conn services.TunnelConnecti return cs.ap.UpsertTunnelConnection(conn) } +// DeleteTunnelConnection is a part of auth.AccessPoint implementation +func (cs *CachingAuthClient) DeleteTunnelConnection(clusterName, connName string) error { + return cs.ap.DeleteTunnelConnection(clusterName, connName) +} + // try calls a given function f and checks for errors. If f() fails, the current // time is recorded. 
Future calls to f will be ignored until sufficient time passes
// since the last error
func (cs *CachingAuthClient) try(f func() error) error {
-	tooSoon := cs.lastErrorTime.Add(defaults.NetworkBackoffDuration).After(time.Now())
+	tooSoon := cs.lastErrorTime.Add(defaults.NetworkRetryDuration).After(time.Now())
 	if tooSoon {
 		cs.Warnf("backoff: using cached value due to recent errors")
 		return trace.ConnectionProblem(fmt.Errorf("backoff"), "backing off due to recent errors")
diff --git a/lib/state/cachingaccesspoint_test.go b/lib/state/cachingaccesspoint_test.go
index 6645b9e60a962..4b69dafcc8882 100644
--- a/lib/state/cachingaccesspoint_test.go
+++ b/lib/state/cachingaccesspoint_test.go
@@ -91,7 +91,6 @@ var (
 	TunnelConnections = []services.TunnelConnection{
 		services.MustCreateTunnelConnection("conn1", services.TunnelConnectionSpecV2{
 			ClusterName:   "example.com",
-			ProxyAddr:     "localhost:3025",
 			ProxyName:     "p1",
 			LastHeartbeat: time.Date(2015, 6, 5, 4, 3, 2, 1, time.UTC).UTC(),
 		}),
@@ -234,7 +233,7 @@ func (s *ClusterSnapshotSuite) TestTry(c *check.C) {
 	c.Assert(failedCalls, check.Equals, 1)
 
 	// "wait" for backoff duration and try again:
-	ap.lastErrorTime = time.Now().Add(-backoffDuration)
+	ap.lastErrorTime = time.Now().Add(-defaults.NetworkBackoffDuration)
 	ap.try(success)
 	ap.try(failure)
 
diff --git a/lib/utils/loadbalancer.go b/lib/utils/loadbalancer.go
index 7396fe9623c9b..5f0b7f2e87cfc 100644
--- a/lib/utils/loadbalancer.go
+++ b/lib/utils/loadbalancer.go
@@ -44,6 +44,7 @@ func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr)
 			"listen": frontend.String(),
 		},
 	}),
+		connections: make(map[NetAddr]map[int64]net.Conn),
 	}, nil
 }
 
@@ -51,6 +52,7 @@
 // balancer used in tests.
 type LoadBalancer struct {
 	sync.RWMutex
+	connID int64
 	*log.Entry
 	frontend       NetAddr
 	backends       []NetAddr
@@ -58,6 +60,41 @@
 	currentIndex   int
 	listener       net.Listener
 	listenerClosed bool
+	connections    map[NetAddr]map[int64]net.Conn
+}
+
+// trackConnection adds a connection to the connection tracker
+func (l *LoadBalancer) trackConnection(backend NetAddr, conn net.Conn) int64 {
+	l.Lock()
+	defer l.Unlock()
+	l.connID += 1
+	tracker, ok := l.connections[backend]
+	if !ok {
+		tracker = make(map[int64]net.Conn)
+		l.connections[backend] = tracker
+	}
+	tracker[l.connID] = conn
+	return l.connID
+}
+
+// untrackConnection removes a connection from the connection tracker
+func (l *LoadBalancer) untrackConnection(backend NetAddr, id int64) {
+	l.Lock()
+	defer l.Unlock()
+	tracker, ok := l.connections[backend]
+	if !ok {
+		return
+	}
+	delete(tracker, id)
+}
+
+// dropConnections drops all connections associated with a backend
+func (l *LoadBalancer) dropConnections(backend NetAddr) {
+	tracker := l.connections[backend]
+	for _, conn := range tracker {
+		conn.Close()
+	}
+	delete(l.connections, backend)
 }
 
 // AddBackend adds backend
@@ -76,10 +113,10 @@ func (l *LoadBalancer) RemoveBackend(b NetAddr) {
 	for i := range l.backends {
 		if l.backends[i].Equals(b) {
 			l.backends = append(l.backends[:i], l.backends[i+1:]...)
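+			// close any connections tracked for this backend so its clients are disconnected immediately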
+ l.dropConnections(b) return } } - } func (l *LoadBalancer) nextBackend() (*NetAddr, error) { @@ -171,12 +208,18 @@ func (l *LoadBalancer) forward(conn net.Conn) error { return trace.Wrap(err) } + connID := l.trackConnection(*backend, conn) + defer l.untrackConnection(*backend, connID) + backendConn, err := net.Dial(backend.AddrNetwork, backend.Addr) if err != nil { return trace.ConvertSystemError(err) } defer backendConn.Close() + backendConnID := l.trackConnection(*backend, backendConn) + defer l.untrackConnection(*backend, backendConnID) + logger := l.WithFields(log.Fields{ "source": conn.RemoteAddr(), "dest": backendConn.RemoteAddr(), diff --git a/lib/utils/loadbalancer_test.go b/lib/utils/loadbalancer_test.go index ba4d0cc83dcb2..050586e6718fc 100644 --- a/lib/utils/loadbalancer_test.go +++ b/lib/utils/loadbalancer_test.go @@ -172,6 +172,44 @@ func (s *LBSuite) TestClose(c *check.C) { c.Assert(err, check.NotNil) } +func (s *LBSuite) TestDropConnections(c *check.C) { + backend1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + })) + defer backend1.Close() + + ports, err := GetFreeTCPPorts(1) + c.Assert(err, check.IsNil) + + frontend := localAddr(ports[0]) + + backendAddr := urlToNetAddr(backend1.URL) + lb, err := NewLoadBalancer(context.TODO(), frontend, backendAddr) + c.Assert(err, check.IsNil) + err = lb.Listen() + c.Assert(err, check.IsNil) + go lb.Serve() + defer lb.Close() + + conn, err := net.Dial("tcp", frontend.String()) + c.Assert(err, check.IsNil) + defer conn.Close() + + out, err := roundtripWithConn(conn) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") + + // to make sure multiple requests work on the same wire + out, err = roundtripWithConn(conn) + c.Assert(err, check.IsNil) + c.Assert(out, check.Equals, "backend 1") + + // removing backend results in dropped connection to this backend + lb.RemoveBackend(backendAddr) + out, err = roundtripWithConn(conn) + c.Assert(err, check.NotNil) +} + func urlToNetAddr(u string) NetAddr { parsed, err := url.Parse(u) if err != nil { @@ -196,7 +234,15 @@ func roundtrip(addr string) (string, error) { return "", err } defer conn.Close() - fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n") + return roundtripWithConn(conn) +} + +// roundtripWithConn uses HTTP get on the existing connection +func roundtripWithConn(conn net.Conn) (string, error) { + _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\n\r\n") + if err != nil { + return "", err + } re, err := http.ReadResponse(bufio.NewReader(conn), nil) if err != nil { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 8e5bae28fd1bd..cee55c427ddf1 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -73,6 +73,8 @@ import ( kyaml "k8s.io/apimachinery/pkg/util/yaml" ) +const hostID = "00000000-0000-0000-0000-000000000000" + func TestWeb(t *testing.T) { TestingT(t) } @@ -241,7 +243,7 @@ func (s *WebSuite) SetUpTest(c *C) { hpriv, hpub, err := s.authServer.GenerateKeyPair("") c.Assert(err, IsNil) hcert, err := s.authServer.GenerateHostCert( - hpub, "00000000-0000-0000-0000-000000000000", s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0) + hpub, hostID, s.domainName, s.domainName, teleport.Roles{teleport.RoleAdmin}, 0) c.Assert(err, IsNil) // set up user CA and set up a user that has access to the server @@ -274,16 +276,17 @@ func (s *WebSuite) SetUpTest(c *C) { c.Assert(s.node.Start(), IsNil) // create reverse tunnel 
service: - revTunServer, err := reversetunnel.NewServer( - utils.NetAddr{ + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ + ID: node.ID(), + ListenAddr: utils.NetAddr{ AddrNetwork: "tcp", Addr: fmt.Sprintf("%v:0", s.domainName), }, - []ssh.Signer{s.signer}, - s.roleAuth, - state.NoCache, - reversetunnel.DirectSite(s.domainName, s.roleAuth), - ) + HostSigners: []ssh.Signer{s.signer}, + AccessPoint: s.roleAuth, + NewCachingAccessPoint: state.NoCache, + DirectClusters: []reversetunnel.DirectCluster{{Name: s.domainName, Client: s.roleAuth}}, + }) c.Assert(err, IsNil) apiPort := s.freePorts[len(s.freePorts)-1] @@ -307,7 +310,7 @@ func (s *WebSuite) SetUpTest(c *C) { // create a tun client tunClient, err := auth.NewTunClient("test", []utils.NetAddr{tunAddr}, - s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) + hostID, []ssh.AuthMethod{ssh.PublicKeys(s.signer)}) c.Assert(err, IsNil) // proxy server: From b2ed270bb66c91bb5b053fd6a5d148a25e423d62 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Thu, 12 Oct 2017 17:38:58 -0700 Subject: [PATCH 19/24] fix data race and update lock file digest --- Gopkg.lock | 2 +- lib/reversetunnel/agent.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Gopkg.lock b/Gopkg.lock index 6681f3bd79c6c..8568e8b714af8 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -392,6 +392,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "d1eb2a47c4fafb650c2c5ffef90bf7949a795641da2bf5ac41cf84eb3b74c806" + inputs-digest = "1ef02af2e287963c62775106667e36b6623f4e31a4a6c1c25d0c87aa672e66e6" solver-name = "gps-cdcl" solver-version = 1 diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go index 3b28fdba806d8..2c5b1609fe1bd 100644 --- a/lib/reversetunnel/agent.go +++ b/lib/reversetunnel/agent.go @@ -170,6 +170,8 @@ func (a *Agent) getLastStateChange() time.Time { } func (a *Agent) setStateAndPrincipals(state string, principals []string) { + a.Lock() + defer a.Unlock() prev := a.state a.Debugf("changing state %v -> %v", prev, state) a.state = state From 6471bc32da4ff4b51e1fc56e5354b76198470b19 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Fri, 13 Oct 2017 09:11:13 -0700 Subject: [PATCH 20/24] fix data race --- lib/reversetunnel/remotesite.go | 23 +++++++++++++++-------- lib/services/tunnelconn.go | 8 ++++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index abe1e63d8ccfc..6586e365b986a 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -44,7 +44,7 @@ import ( // the local reverse tunnel server, and now it can provide access to the // cluster behind it. 
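+// Note: connection lookups take the read lock, mutations take the write lock.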
type remoteSite struct { - sync.Mutex + sync.RWMutex *log.Entry domainName string @@ -73,14 +73,14 @@ func (s *remoteSite) String() string { } func (s *remoteSite) connectionCount() int { - s.Lock() - defer s.Unlock() + s.RLock() + defer s.RUnlock() return len(s.connections) } func (s *remoteSite) hasValidConnections() bool { - s.Lock() - defer s.Unlock() + s.RLock() + defer s.RUnlock() for _, conn := range s.connections { if !conn.isInvalid() { @@ -156,10 +156,17 @@ func (s *remoteSite) GetStatus() string { return RemoteSiteStatusOffline } +func (s *remoteSite) copyConnInfo() services.TunnelConnection { + s.RLock() + defer s.RUnlock() + return s.connInfo.Clone() +} + func (s *remoteSite) registerHeartbeat(t time.Time) { - s.connInfo.SetLastHeartbeat(t) - s.connInfo.SetExpiry(s.clock.Now().Add(defaults.ReverseTunnelOfflineThreshold)) - err := s.srv.AccessPoint.UpsertTunnelConnection(s.connInfo) + connInfo := s.copyConnInfo() + connInfo.SetLastHeartbeat(t) + connInfo.SetExpiry(s.clock.Now().Add(defaults.ReverseTunnelOfflineThreshold)) + err := s.srv.AccessPoint.UpsertTunnelConnection(connInfo) if err != nil { s.Warningf("failed to register heartbeat: %v", err) } diff --git a/lib/services/tunnelconn.go b/lib/services/tunnelconn.go index f7127bf12b2d1..7186ddd91814d 100644 --- a/lib/services/tunnelconn.go +++ b/lib/services/tunnelconn.go @@ -35,6 +35,8 @@ type TunnelConnection interface { CheckAndSetDefaults() error // String returns user friendly representation of this connection String() string + // Clone returns a copy of this tunnel connection + Clone() TunnelConnection } // MustCreateTunnelConnection returns new connection from V2 spec or panics if @@ -76,6 +78,12 @@ type TunnelConnectionV2 struct { Spec TunnelConnectionSpecV2 `json:"spec"` } +// Clone returns a copy of this tunnel connection +func (r *TunnelConnectionV2) Clone() TunnelConnection { + out := *r + return &out +} + // String returns user-friendly description of this connection func (r *TunnelConnectionV2) String() string { return fmt.Sprintf("TunnelConnection(name=%v, cluster=%v, proxy=%v)", r.Metadata.Name, r.Spec.ClusterName, r.Spec.ProxyName) From 4b36d77f31e249e71ab3868685a802a66333b646 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Fri, 13 Oct 2017 10:21:10 -0700 Subject: [PATCH 21/24] remove data race on channel close --- lib/reversetunnel/srv.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index 0236a81209c33..34a7951c67c1b 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -622,6 +622,7 @@ type remoteConn struct { counter int32 discoveryC ssh.Channel discoveryErr error + closed int32 } func (rc *remoteConn) openDiscoveryChannel() (ssh.Channel, error) { @@ -645,6 +646,10 @@ func (rc *remoteConn) String() string { } func (rc *remoteConn) Close() error { + if !atomic.CompareAndSwapInt32(&rc.closed, 0, 1) { + // already closed + return nil + } if rc.discoveryC != nil { rc.discoveryC.Close() rc.discoveryC = nil From d69f88978be48631f294857491e66fd5773724cb Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Fri, 13 Oct 2017 11:33:52 -0700 Subject: [PATCH 22/24] fix data race in tun client --- lib/auth/tun.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/lib/auth/tun.go b/lib/auth/tun.go index 94f3ca8514b44..48030a86139b1 100644 --- a/lib/auth/tun.go +++ b/lib/auth/tun.go @@ -897,6 +897,19 @@ func (c *TunClient) GetAgent() (AgentCloser, error) { return ta, nil } +func (c *TunClient) 
setupSyncLoop() {
+	c.Lock()
+	defer c.Unlock()
+	if c.disableRefresh {
+		return
+	}
+	if c.refreshTicker != nil {
+		return
+	}
+	c.refreshTicker = time.NewTicker(defaults.AuthServersRefreshPeriod)
+	go c.authServersSyncLoop()
+}
+
 // Dial dials to Auth server's HTTP API over SSH tunnel.
 func (c *TunClient) Dial(network, address string) (net.Conn, error) {
 	c.Debugf("dialing %v %v", network, address)
@@ -909,14 +922,9 @@ func (c *TunClient) Dial(network, address string) (net.Conn, error) {
 	if err != nil {
 		return nil, trace.ConnectionProblem(err, "can't connect to auth API")
 	}
-	// dialed & authenticated? lets start synchronizing the
-	// list of auth servers:
-	if c.disableRefresh == false {
-		if c.refreshTicker == nil {
-			c.refreshTicker = time.NewTicker(defaults.AuthServersRefreshPeriod)
-			go c.authServersSyncLoop()
-		}
-	}
+	// dialed & authenticated?
+	// let's start synchronizing the list of auth servers:
+	c.setupSyncLoop()
 	return &tunConn{client: client, Conn: conn}, nil
 }
 
From 039249507dca71955af18ae4e549d719f646b862 Mon Sep 17 00:00:00 2001
From: Sasha Klizhentas
Date: Fri, 13 Oct 2017 19:26:49 -0700
Subject: [PATCH 23/24] update according to code review comments

---
 lib/reversetunnel/agent.go      | 25 ++++++++++++++++++++++---
 lib/reversetunnel/remotesite.go | 12 ++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/lib/reversetunnel/agent.go b/lib/reversetunnel/agent.go
index 2c5b1609fe1bd..5abcc5c51b571 100644
--- a/lib/reversetunnel/agent.go
+++ b/lib/reversetunnel/agent.go
@@ -111,7 +111,15 @@ func (a *AgentConfig) CheckAndSetDefaults() error {
 }
 
 // Agent is a reverse tunnel agent running as a part of teleport Proxies
-// to establish outbound reverse tunnels to remote proxies
+// to establish outbound reverse tunnels to remote proxies.
+//
+// There are two operation modes for agents:
+// * A standard agent attempts to establish a connection
+// to any available proxy. Standard agents transition between
+// the "connecting" -> "connected" states.
+// * A discovering agent attempts to establish a connection to a subset
+// of remote proxies (specified in the config via the DiscoverProxies parameter).
+// Discovering agents transition between the "discovering" -> "discovered" states.
 type Agent struct {
 	sync.RWMutex
 	*log.Entry
@@ -427,9 +435,20 @@ func (a *Agent) proxyTransport(ch ssh.Channel, reqC <-chan *ssh.Request) {
 }
 
 // run is the main agent loop, constantly tries to re-establish
-// the connection until stopped
+// the connection until stopped. It has several operation modes:
+// at first it tries to connect with fast retries on errors,
+// but after a certain threshold it slows down the retry pace
+// by switching to longer delays between retries.
+//
+// Once run connects to a proxy it starts processing requests
+// from the proxy via SSH channels opened by the remote Proxy.
+//
+// Agent sends periodic heartbeats back to the Proxy
+// and that is how Proxy determines disconnects.
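+// A proxy that misses several heartbeats in a row marks the tunnel connection as invalid.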
+//
 func (a *Agent) run() {
-	ticker, err := utils.NewSwitchTicker(defaults.FastAttempts, defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
+	ticker, err := utils.NewSwitchTicker(defaults.FastAttempts,
+		defaults.NetworkRetryDuration, defaults.NetworkBackoffDuration)
 	if err != nil {
 		log.Errorf("failed to run: %v", err)
 		return
diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go
index 6586e365b986a..30ced5e12dc13 100644
--- a/lib/reversetunnel/remotesite.go
+++ b/lib/reversetunnel/remotesite.go
@@ -178,6 +178,9 @@ func (s *remoteSite) deleteConnectionRecord() {
 	s.srv.AccessPoint.DeleteTunnelConnection(s.connInfo.GetClusterName(), s.connInfo.GetName())
 }
 
+// handleHeartbeat receives heartbeat messages from the connected agent;
+// if the agent has missed several heartbeats in a row, the Proxy marks
+// the connection as invalid.
 func (s *remoteSite) handleHeartbeat(conn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
 	defer func() {
 		s.Infof("cluster connection closed")
@@ -280,6 +283,13 @@ func (s *remoteSite) findDisconnectedProxies() ([]services.Server, error) {
 	return missing, nil
 }
 
+// sendDiscoveryRequest sends a special "discovery request"
+// back to the connected agent.
+// A discovery request lists the proxies that are part
+// of the cluster but did not receive a connection from the agent.
+// The agent acts on a discovery request by attempting
+// to establish connections to the proxies that were not discovered.
+// See the package documentation for more details.
 func (s *remoteSite) sendDiscoveryRequest() error {
 	disconnectedProxies, err := s.findDisconnectedProxies()
 	if err != nil {
@@ -316,6 +326,8 @@
 		return nil
 	}
 
+	// loop over existing connections (reverse tunnels) and try to send discovery
+	// requests to the remote cluster
 	for i := 0; i < s.connectionCount(); i++ {
 		err := send()
 		if err != nil {

From 0938ce2dfbedac47cf1050ff99cab9bbbcb3346d Mon Sep 17 00:00:00 2001
From: Sasha Klizhentas
Date: Mon, 16 Oct 2017 09:38:06 -0700
Subject: [PATCH 24/24] fix typo

---
 lib/utils/loadbalancer.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/utils/loadbalancer.go b/lib/utils/loadbalancer.go
index 5f0b7f2e87cfc..cb06979f356f2 100644
--- a/lib/utils/loadbalancer.go
+++ b/lib/utils/loadbalancer.go
@@ -27,7 +27,7 @@ import (
 	log "github.com/sirupsen/logrus"
 )
 
-// NewLoadBalancer returns new load balancer listening on frotend
+// NewLoadBalancer returns new load balancer listening on frontend
 // and redirecting requests to backends using round robin algo
 func NewLoadBalancer(ctx context.Context, frontend NetAddr, backends ...NetAddr) (*LoadBalancer, error) {
 	if ctx == nil {
 		return nil, trace.BadParameter("missing parameter context")