Skip to content

Commit

Permalink
Move auth to platform to ssh command from up command
Browse files Browse the repository at this point in the history
  • Loading branch information
janekbaraniewski committed Dec 9, 2024
1 parent 9db086e commit 299cd28
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 51 deletions.
40 changes: 35 additions & 5 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"bytes"
"context"
"encoding/base64"
"fmt"
Expand All @@ -20,6 +21,7 @@ import (
client2 "github.com/loft-sh/devpod/pkg/client"
"github.com/loft-sh/devpod/pkg/client/clientimplementation"
"github.com/loft-sh/devpod/pkg/config"
"github.com/loft-sh/devpod/pkg/credentials"
dpFlags "github.com/loft-sh/devpod/pkg/flags"
"github.com/loft-sh/devpod/pkg/gpg"
"github.com/loft-sh/devpod/pkg/port"
Expand Down Expand Up @@ -59,6 +61,8 @@ type SSHCmd struct {
// ssh keepalive options
SSHKeepAliveInterval time.Duration `json:"sshKeepAliveInterval,omitempty"`

SetupLoftPlatformAccess bool

StartServices bool

Proxy bool
Expand Down Expand Up @@ -114,6 +118,7 @@ func NewSSHCmd(f *flags.GlobalFlags) *cobra.Command {
sshCmd.Flags().BoolVar(&cmd.Stdio, "stdio", false, "If true will tunnel connection through stdout and stdin")
sshCmd.Flags().BoolVar(&cmd.StartServices, "start-services", true, "If false will not start any port-forwarding or git / docker credentials helper")
sshCmd.Flags().DurationVar(&cmd.SSHKeepAliveInterval, "ssh-keepalive-interval", 55*time.Second, "How often should keepalive request be made (55s)")
sshCmd.Flags().BoolVar(&cmd.SetupLoftPlatformAccess, "setup-loft-platform-access", false, "should setup loft platform access")

return sshCmd
}
Expand Down Expand Up @@ -179,7 +184,7 @@ func (cmd *SSHCmd) startProxyTunnel(
})
},
func(ctx context.Context, containerClient *ssh.Client) error {
return cmd.startTunnel(ctx, devPodConfig, containerClient, client.Workspace(), log)
return cmd.startTunnel(ctx, devPodConfig, containerClient, client, log)
},
)
}
Expand Down Expand Up @@ -281,7 +286,7 @@ func (cmd *SSHCmd) jumpContainer(
unlockOnce.Do(client.Unlock)

// start ssh tunnel
return cmd.startTunnel(ctx, devPodConfig, containerClient, client.Workspace(), log)
return cmd.startTunnel(ctx, devPodConfig, containerClient, client, log)
}, devPodConfig, envVars)
}

Expand Down Expand Up @@ -389,7 +394,7 @@ func (cmd *SSHCmd) forwardPorts(
return <-errChan
}

func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config, containerClient *ssh.Client, workspaceName string, log log.Logger) error {
func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config, containerClient *ssh.Client, workspaceClient client2.BaseWorkspaceClient, log log.Logger) error {
// check if we should forward ports
if len(cmd.ForwardPorts) > 0 {
return cmd.forwardPorts(ctx, containerClient, log)
Expand Down Expand Up @@ -423,7 +428,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
}
}

workdir := filepath.Join("/workspaces", workspaceName)
workdir := filepath.Join("/workspaces", workspaceClient.Workspace())
if cmd.WorkDir != "" {
workdir = cmd.WorkDir
}
Expand All @@ -442,6 +447,10 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
return err
}

if cmd.Provider == "" {
cmd.Provider = "devpod-pro"
}

// Traffic is coming in from the outside, we need to forward it to the container
if cmd.Proxy || cmd.Stdio {
if cmd.Proxy {
Expand All @@ -454,8 +463,13 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
log.Error(err)
}
}()
}

go func() {
if err := cmd.setupLoftPlatformAccess(ctx, containerClient, cmd.Context, cmd.Provider, log); err != nil {
log.Error(err)
}
}()
}
return devssh.Run(ctx, containerClient, command, os.Stdin, os.Stdout, writer, envVars)
}

Expand All @@ -475,6 +489,22 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
)
}

func (cmd *SSHCmd) setupLoftPlatformAccess(ctx context.Context, sshClient *ssh.Client, context, provider string, log log.Logger) error {
port, err := credentials.GetPort()
if err != nil {
return fmt.Errorf("get port: %w", err)
}

buf := &bytes.Buffer{}
command := fmt.Sprintf("'%s' agent container setup-loft-platform-access --context %s --provider %s --port %d", agent.ContainerDevPodHelperLocation, context, provider, port)
err = devssh.Run(ctx, sshClient, command, nil, buf, buf, nil)
if err != nil {
log.Debugf("Failed to setup platform access: %s%v", buf.String(), err)
}

return nil
}

func (cmd *SSHCmd) startServices(
ctx context.Context,
devPodConfig *config.Config,
Expand Down
46 changes: 0 additions & 46 deletions cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/loft-sh/devpod/pkg/client/clientimplementation"
"github.com/loft-sh/devpod/pkg/command"
"github.com/loft-sh/devpod/pkg/config"
"github.com/loft-sh/devpod/pkg/credentials"
config2 "github.com/loft-sh/devpod/pkg/devcontainer/config"
"github.com/loft-sh/devpod/pkg/devcontainer/sshtunnel"
dpFlags "github.com/loft-sh/devpod/pkg/flags"
Expand Down Expand Up @@ -199,15 +198,6 @@ func (cmd *UpCmd) Run(
}
}

// setup loft platform access
context := devPodConfig.Current()
if cmd.SetupLoftPlatformAccess {
err = setupLoftPlatformAccess(devPodConfig.DefaultContext, context.DefaultProvider, user, client, log)
if err != nil {
return err
}
}

// setup dotfiles in the container
err = setupDotfiles(cmd.DotfilesSource, cmd.DotfilesScript, client, devPodConfig, log)
if err != nil {
Expand Down Expand Up @@ -1115,42 +1105,6 @@ func setupGitSSHSignature(signingKey string, client client2.BaseWorkspaceClient,
return nil
}

func setupLoftPlatformAccess(context, provider, user string, client client2.BaseWorkspaceClient, log log.Logger) error {
log.Infof("Setting up platform access")
execPath, err := os.Executable()
if err != nil {
return err
}

port, err := credentials.GetPort()
if err != nil {
return fmt.Errorf("get port: %w", err)
}

command := fmt.Sprintf("\"%s\" agent container setup-loft-platform-access --context %s --provider %s --port %d", agent.ContainerDevPodHelperLocation, context, provider, port)

log.Debugf("Executing command: %v", command)
var errb bytes.Buffer
cmd := exec.Command(
execPath,
"ssh",
"--start-services=true",
"--user",
user,
"--context",
client.Context(),
client.Workspace(),
"--command", command,
)
cmd.Stderr = &errb
err = cmd.Run()
if err != nil {
log.Debugf("failed to set up platform access in workspace: %s", errb.String())
}

return nil
}

func performGpgForwarding(
client client2.BaseWorkspaceClient,
log log.Logger,
Expand Down

0 comments on commit 299cd28

Please sign in to comment.