diff --git a/server/application/terminal.go b/server/application/terminal.go index 6424c89e97670..bea1f6ea6a110 100644 --- a/server/application/terminal.go +++ b/server/application/terminal.go @@ -6,6 +6,7 @@ import ( "net/http" "time" + util_session "github.com/argoproj/argo-cd/v2/util/session" "github.com/argoproj/gitops-engine/pkg/utils/kube" log "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" @@ -37,11 +38,12 @@ type terminalHandler struct { allowedShells []string namespace string enabledNamespaces []string + sessionManager util_session.SessionManager } // NewHandler returns a new terminal handler. func NewHandler(appLister applisters.ApplicationLister, namespace string, enabledNamespaces []string, db db.ArgoDB, enf *rbac.Enforcer, cache *servercache.Cache, - appResourceTree AppResourceTreeFn, allowedShells []string) *terminalHandler { + appResourceTree AppResourceTreeFn, allowedShells []string, sessionManager util_session.SessionManager) *terminalHandler { return &terminalHandler{ appLister: appLister, db: db, @@ -51,6 +53,7 @@ func NewHandler(appLister applisters.ApplicationLister, namespace string, enable allowedShells: allowedShells, namespace: namespace, enabledNamespaces: enabledNamespaces, + sessionManager: sessionManager, } } @@ -222,7 +225,7 @@ func (s *terminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fieldLog.Info("terminal session starting") - session, err := newTerminalSession(w, r, nil) + session, err := newTerminalSession(w, r, nil, s.sessionManager) if err != nil { http.Error(w, "Failed to start terminal session", http.StatusBadRequest) return @@ -282,6 +285,11 @@ type TerminalMessage struct { Cols uint16 `json:"cols"` } +// TerminalCommand is the struct for websocket commands,For example you need ask client to reconnect +type TerminalCommand struct { + Code int +} + // startProcess executes specified commands in the container and connects it up with the ptyHandler (a session) func startProcess(k8sClient kubernetes.Interface, cfg *rest.Config, namespace, podName, containerName string, cmd []string, ptyHandler PtyHandler) error { req := k8sClient.CoreV1().RESTClient().Post(). diff --git a/server/application/websocket.go b/server/application/websocket.go index fdac5a76c592b..faee91c4f47e4 100644 --- a/server/application/websocket.go +++ b/server/application/websocket.go @@ -3,6 +3,9 @@ package application import ( "encoding/json" "fmt" + "github.com/argoproj/argo-cd/v2/common" + httputil "github.com/argoproj/argo-cd/v2/util/http" + util_session "github.com/argoproj/argo-cd/v2/util/session" "net/http" "sync" "time" @@ -12,6 +15,11 @@ import ( "k8s.io/client-go/tools/remotecommand" ) +const ( + ReconnectCode = 1 + ReconnectMessage = "\nReconnect because the token was refreshed...\n" +) + var upgrader = func() websocket.Upgrader { upgrader := websocket.Upgrader{} upgrader.HandshakeTimeout = time.Second * 2 @@ -23,25 +31,40 @@ var upgrader = func() websocket.Upgrader { // terminalSession implements PtyHandler type terminalSession struct { - wsConn *websocket.Conn - sizeChan chan remotecommand.TerminalSize - doneChan chan struct{} - tty bool - readLock sync.Mutex - writeLock sync.Mutex + wsConn *websocket.Conn + sizeChan chan remotecommand.TerminalSize + doneChan chan struct{} + tty bool + readLock sync.Mutex + writeLock sync.Mutex + sessionManager util_session.SessionManager + token *string +} + +// getToken get auth token from web socket request +func getToken(r *http.Request) (string, error) { + cookies := r.Cookies() + return httputil.JoinCookies(common.AuthCookieName, cookies) } // newTerminalSession create terminalSession -func newTerminalSession(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*terminalSession, error) { +func newTerminalSession(w http.ResponseWriter, r *http.Request, responseHeader http.Header, sessionManager util_session.SessionManager) (*terminalSession, error) { + token, err := getToken(r) + if err != nil { + return nil, err + } + conn, err := upgrader.Upgrade(w, r, responseHeader) if err != nil { return nil, err } session := &terminalSession{ - wsConn: conn, - tty: true, - sizeChan: make(chan remotecommand.TerminalSize), - doneChan: make(chan struct{}), + wsConn: conn, + tty: true, + sizeChan: make(chan remotecommand.TerminalSize), + doneChan: make(chan struct{}), + sessionManager: sessionManager, + token: &token, } return session, nil } @@ -78,8 +101,40 @@ func (t *terminalSession) Next() *remotecommand.TerminalSize { } } +// reconnect send reconnect code to client and ask them init new ws session +func (t *terminalSession) reconnect() (int, error) { + reconnectCommand, _ := json.Marshal(TerminalCommand{ + Code: ReconnectCode, + }) + reconnectMessage, _ := json.Marshal(TerminalMessage{ + Operation: "stdout", + Data: ReconnectMessage, + }) + t.writeLock.Lock() + err := t.wsConn.WriteMessage(websocket.TextMessage, reconnectMessage) + if err != nil { + log.Errorf("write message err: %v", err) + return 0, err + } + err = t.wsConn.WriteMessage(websocket.TextMessage, reconnectCommand) + if err != nil { + log.Errorf("write message err: %v", err) + return 0, err + } + t.writeLock.Unlock() + return 0, nil +} + // Read called in a loop from remotecommand as long as the process is running func (t *terminalSession) Read(p []byte) (int, error) { + // check if token still valid + _, newToken, err := t.sessionManager.VerifyToken(*t.token) + // err in case if token is revoked, newToken in case if refresh happened + if err != nil || newToken != "" { + // need to send reconnect code in case if token was refreshed + return t.reconnect() + } + t.readLock.Lock() _, message, err := t.wsConn.ReadMessage() t.readLock.Unlock() diff --git a/server/application/websocket_test.go b/server/application/websocket_test.go new file mode 100644 index 0000000000000..30c5ffa232328 --- /dev/null +++ b/server/application/websocket_test.go @@ -0,0 +1,46 @@ +package application + +import ( + "encoding/json" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func reconnect(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + + ts := terminalSession{wsConn: c} + _, _ = ts.reconnect() +} + +func TestReconnect(t *testing.T) { + + s := httptest.NewServer(http.HandlerFunc(reconnect)) + defer s.Close() + + u := "ws" + strings.TrimPrefix(s.URL, "http") + + // Connect to the server + ws, _, err := websocket.DefaultDialer.Dial(u, nil) + assert.NoError(t, err) + + defer ws.Close() + + _, p, _ := ws.ReadMessage() + + var message TerminalMessage + + err = json.Unmarshal(p, &message) + + assert.NoError(t, err) + assert.Equal(t, message.Data, ReconnectMessage) + +} diff --git a/server/server.go b/server/server.go index e7e3ffb351068..28a1da6ec72fb 100644 --- a/server/server.go +++ b/server/server.go @@ -975,7 +975,7 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl } mux.Handle("/api/", handler) - terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells). + terminal := application.NewHandler(a.appLister, a.Namespace, a.ApplicationNamespaces, a.db, a.enf, a.Cache, appResourceTreeFn, a.settings.ExecShells, *a.sessionMgr). WithFeatureFlagMiddleware(a.settingsMgr.GetSettings) th := util_session.WithAuthMiddleware(a.DisableAuth, a.sessionMgr, terminal) mux.Handle("/terminal", th) @@ -988,6 +988,7 @@ func (a *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWebHandl // will be added in mux. registerExtensions(mux, a) } + mustRegisterGWHandler(versionpkg.RegisterVersionServiceHandler, ctx, gwmux, conn) mustRegisterGWHandler(clusterpkg.RegisterClusterServiceHandler, ctx, gwmux, conn) mustRegisterGWHandler(applicationpkg.RegisterApplicationServiceHandler, ctx, gwmux, conn) diff --git a/ui/src/app/applications/components/pod-terminal-viewer/pod-terminal-viewer.tsx b/ui/src/app/applications/components/pod-terminal-viewer/pod-terminal-viewer.tsx index 9fe2cf248e024..72f18f1f8b353 100644 --- a/ui/src/app/applications/components/pod-terminal-viewer/pod-terminal-viewer.tsx +++ b/ui/src/app/applications/components/pod-terminal-viewer/pod-terminal-viewer.tsx @@ -72,7 +72,14 @@ export const PodTerminalViewer: React.FC = ({ const onConnectionMessage = (e: MessageEvent) => { const msg = JSON.parse(e.data); - connSubject.next(msg); + if (!msg?.Code) { + connSubject.next(msg); + } else { + // Do reconnect due to refresh token event + onConnectionClose(); + setupConnection() + } + }; const onConnectionOpen = () => {