Skip to content

Commit

Permalink
[runner] Back up and restore ~/.ssh files
Browse files Browse the repository at this point in the history
Preserve `~/.ssh` directory state, namely `environment` and
`authorized_keys` files, as follows:

Before a job starts:

* If `<file>` exists:
    - rename `<file>` to `<file>.dstack.bak`
    - copy backup to a new file with original name `<file>`
* If `<file>` does not exist: remember this

After the job finishes:

* If `<file>` existed: rename `<file>.dstack.bak` back to `<file>`
* If `<file>` did not exist: remove `<file>`

Fixes: #2257
  • Loading branch information
un-def committed Feb 4, 2025
1 parent 1f74e37 commit 6ec169d
Showing 1 changed file with 93 additions and 53 deletions.
146 changes: 93 additions & 53 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,16 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
envMap.Update(ex.jobSpec.Env, false)

// As of 2024-11-29, ex.homeDir is always set to /root
if err := writeSSHEnvironment(envMap, -1, -1, ex.homeDir); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir)
if err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err)
} else {
rootSSHEnvPath := filepath.Join(rootSSHDir, "environment")
restoreRootSSHEnv := backupFile(ctx, rootSSHEnvPath)
defer restoreRootSSHEnv(ctx)
if err := writeSSHEnvironment(envMap, -1, -1, rootSSHEnvPath); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
}
}
if user != nil && *user.Uid != 0 {
// non-root user
Expand All @@ -305,12 +313,23 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
envMap["HOME"] = homeDir
if isHomeDirAccessible {
log.Trace(ctx, "provisioning homeDir", "path", homeDir)
if err := writeSSHEnvironment(envMap, uid, gid, homeDir); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", homeDir, "err", err)
}
akPath := filepath.Join(ex.homeDir, ".ssh/authorized_keys")
if err := copyAuthorizedKeys(akPath, uid, gid, homeDir); err != nil {
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
userSSHDir, err := prepareSSHDir(uid, gid, homeDir)
if err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
} else {
userSSHEnvPath := filepath.Join(userSSHDir, "environment")
restoreUserSSHEnv := backupFile(ctx, userSSHEnvPath)
defer restoreUserSSHEnv(ctx)
if err := writeSSHEnvironment(envMap, uid, gid, userSSHEnvPath); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", homeDir, "err", err)
}
rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys")
userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys")
restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath)
defer restoreUserSSHKeys(ctx)
if err := copyAuthorizedKeys(rootSSHKeysPath, uid, gid, userSSHKeysPath); err != nil {
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
}
}
} else {
log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir)
Expand Down Expand Up @@ -603,35 +622,31 @@ func prepareHomeDir(ctx context.Context, uid int, gid int, homeDir string) (stri
return homeDir, true
}

func writeSSHEnvironment(env map[string]string, uid int, gid int, homeDir string) error {
sshDir, err := joinRelPath(homeDir, ".ssh")
if err != nil {
return err
}
func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
sshDir := filepath.Join(homeDir, ".ssh")
info, err := os.Stat(sshDir)
if err == nil {
if !info.IsDir() {
return fmt.Errorf("not a directory: %s", sshDir)
return "", fmt.Errorf("not a directory: %s", sshDir)
}
if err = os.Chmod(sshDir, 0o700); err != nil {
return err
return "", err
}
} else if errors.Is(err, os.ErrNotExist) {
if err = os.MkdirAll(sshDir, 0o700); err != nil {
return err
return "", err
}
} else {
return err
return "", err
}
if err = os.Chown(sshDir, uid, gid); err != nil {
return err
return "", err
}
return sshDir, nil
}

envPath, err := joinRelPath(sshDir, "environment")
if err != nil {
return err
}
info, err = os.Stat(envPath)
func writeSSHEnvironment(env map[string]string, uid int, gid int, envPath string) error {
info, err := os.Stat(envPath)
if err == nil {
if info.IsDir() {
return fmt.Errorf("is a directory: %s", envPath)
Expand Down Expand Up @@ -684,42 +699,15 @@ func writeSSHEnvironment(env map[string]string, uid int, gid int, homeDir string
// without modifying the existing API/bootstrap process
// TODO: implement key delivery properly, i.e. submit keys to the runner and have it write them,
// instead of relying on the outer sh script that launches sshd and the runner
func copyAuthorizedKeys(src string, uid int, gid int, homeDir string) error {
srcFile, err := os.Open(src)
func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) error {
srcFile, err := os.Open(srcPath)
if err != nil {
return err
}
defer srcFile.Close()

sshDir, err := joinRelPath(homeDir, ".ssh")
if err != nil {
return err
}
info, err := os.Stat(sshDir)
if err == nil {
if !info.IsDir() {
return fmt.Errorf("not a directory: %s", sshDir)
}
if err = os.Chmod(sshDir, 0o700); err != nil {
return err
}
} else if errors.Is(err, os.ErrNotExist) {
if err = os.MkdirAll(sshDir, 0o700); err != nil {
return err
}
} else {
return err
}
if err = os.Chown(sshDir, uid, gid); err != nil {
return err
}

dstExists := false
dstPath, err := joinRelPath(sshDir, "authorized_keys")
if err != nil {
return err
}
info, err = os.Stat(dstPath)
info, err := os.Stat(dstPath)
if err == nil {
dstExists = true
if info.IsDir() {
Expand Down Expand Up @@ -753,3 +741,55 @@ func copyAuthorizedKeys(src string, uid int, gid int, homeDir string) error {

return nil
}

// backupFile moves `/path/to/file` aside to `/path/to/file.dstack.bak`,
// recreates `/path/to/file` as a copy of the backup, and returns a restore
// function to be called once the caller is done modifying the file.
//
// The returned function is never nil. Depending on what happened here, it:
//   - removes the file, if the original did not exist (a missing file at
//     restore time is not an error);
//   - renames the backup over the file, if the original was backed up;
//   - does nothing, if the original could not be backed up: no backup was made
//     by this call, so a leftover `.dstack.bak` must not replace the file.
//
// Backup and restore are best-effort: failures are logged as warnings (with the
// underlying error) and never returned to the caller.
// NB: The recreated file has default uid:gid and permissions, probably not
// the same as the original file.
func backupFile(ctx context.Context, path string) func(context.Context) {
	backupPath := path + ".dstack.bak"

	err := os.Rename(path, backupPath)
	switch {
	case errors.Is(err, os.ErrNotExist):
		// Nothing to back up; restoring means deleting whatever gets created.
		return func(ctx context.Context) {
			if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
				log.Warning(ctx, "failed to remove", "path", path, "err", err)
			}
		}
	case err != nil:
		// The original is still in place and there is no backup of ours to restore.
		log.Warning(ctx, "failed to back up", "path", path, "err", err)
		return func(context.Context) {}
	}

	restoreFunc := func(ctx context.Context) {
		if err := os.Rename(backupPath, path); err != nil && !errors.Is(err, os.ErrNotExist) {
			log.Warning(ctx, "failed to restore", "path", path, "err", err)
		}
	}

	// Recreate the original name with the backed-up content so that the caller
	// can update the file in place while the backup stays untouched.
	copyBackup := func() error {
		src, err := os.Open(backupPath)
		if err != nil {
			return err
		}
		defer src.Close()
		dst, err := os.Create(path)
		if err != nil {
			return err
		}
		if _, err := io.Copy(dst, src); err != nil {
			_ = dst.Close() // the copy error is the one worth reporting
			return err
		}
		// Close may report a delayed write error, so it must not be ignored.
		return dst.Close()
	}
	if err := copyBackup(); err != nil {
		log.Warning(ctx, "failed to copy backup", "path", path, "err", err)
	}
	return restoreFunc
}

0 comments on commit 6ec169d

Please sign in to comment.