Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runner] Back up and restore ~/.ssh files #2261

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 95 additions & 53 deletions runner/internal/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,16 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
envMap.Update(ex.jobSpec.Env, false)

// As of 2024-11-29, ex.homeDir is always set to /root
if err := writeSSHEnvironment(envMap, -1, -1, ex.homeDir); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
rootSSHDir, err := prepareSSHDir(-1, -1, ex.homeDir)
if err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", ex.homeDir, "err", err)
} else {
rootSSHEnvPath := filepath.Join(rootSSHDir, "environment")
restoreRootSSHEnv := backupFile(ctx, rootSSHEnvPath)
defer restoreRootSSHEnv(ctx)
if err := writeSSHEnvironment(envMap, -1, -1, rootSSHEnvPath); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", ex.homeDir, "err", err)
}
}
if user != nil && *user.Uid != 0 {
// non-root user
Expand All @@ -305,12 +313,23 @@ func (ex *RunExecutor) execJob(ctx context.Context, jobLogFile io.Writer) error
envMap["HOME"] = homeDir
if isHomeDirAccessible {
log.Trace(ctx, "provisioning homeDir", "path", homeDir)
if err := writeSSHEnvironment(envMap, uid, gid, homeDir); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", homeDir, "err", err)
}
akPath := filepath.Join(ex.homeDir, ".ssh/authorized_keys")
if err := copyAuthorizedKeys(akPath, uid, gid, homeDir); err != nil {
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
userSSHDir, err := prepareSSHDir(uid, gid, homeDir)
if err != nil {
log.Warning(ctx, "failed to prepare ssh dir", "home", homeDir, "err", err)
} else {
userSSHEnvPath := filepath.Join(userSSHDir, "environment")
restoreUserSSHEnv := backupFile(ctx, userSSHEnvPath)
defer restoreUserSSHEnv(ctx)
if err := writeSSHEnvironment(envMap, uid, gid, userSSHEnvPath); err != nil {
log.Warning(ctx, "failed to write SSH environment", "path", homeDir, "err", err)
}
rootSSHKeysPath := filepath.Join(rootSSHDir, "authorized_keys")
userSSHKeysPath := filepath.Join(userSSHDir, "authorized_keys")
restoreUserSSHKeys := backupFile(ctx, userSSHKeysPath)
defer restoreUserSSHKeys(ctx)
if err := copyAuthorizedKeys(rootSSHKeysPath, uid, gid, userSSHKeysPath); err != nil {
log.Warning(ctx, "failed to copy authorized keys", "path", homeDir, "err", err)
}
}
} else {
log.Trace(ctx, "homeDir is not accessible, skipping provisioning", "path", homeDir)
Expand Down Expand Up @@ -603,35 +622,31 @@ func prepareHomeDir(ctx context.Context, uid int, gid int, homeDir string) (stri
return homeDir, true
}

func writeSSHEnvironment(env map[string]string, uid int, gid int, homeDir string) error {
sshDir, err := joinRelPath(homeDir, ".ssh")
if err != nil {
return err
}
func prepareSSHDir(uid int, gid int, homeDir string) (string, error) {
sshDir := filepath.Join(homeDir, ".ssh")
info, err := os.Stat(sshDir)
if err == nil {
if !info.IsDir() {
return fmt.Errorf("not a directory: %s", sshDir)
return "", fmt.Errorf("not a directory: %s", sshDir)
}
if err = os.Chmod(sshDir, 0o700); err != nil {
return err
return "", err
}
} else if errors.Is(err, os.ErrNotExist) {
if err = os.MkdirAll(sshDir, 0o700); err != nil {
return err
return "", err
}
} else {
return err
return "", err
}
if err = os.Chown(sshDir, uid, gid); err != nil {
return err
return "", err
}
return sshDir, nil
}

envPath, err := joinRelPath(sshDir, "environment")
if err != nil {
return err
}
info, err = os.Stat(envPath)
func writeSSHEnvironment(env map[string]string, uid int, gid int, envPath string) error {
info, err := os.Stat(envPath)
if err == nil {
if info.IsDir() {
return fmt.Errorf("is a directory: %s", envPath)
Expand Down Expand Up @@ -684,42 +699,15 @@ func writeSSHEnvironment(env map[string]string, uid int, gid int, homeDir string
// without modifying the existing API/bootstrap process
// TODO: implement key delivery properly, i.e. sumbit keys to and write by the runner,
// not the outer sh script that launches sshd and runner
func copyAuthorizedKeys(src string, uid int, gid int, homeDir string) error {
srcFile, err := os.Open(src)
func copyAuthorizedKeys(srcPath string, uid int, gid int, dstPath string) error {
srcFile, err := os.Open(srcPath)
if err != nil {
return err
}
defer srcFile.Close()

sshDir, err := joinRelPath(homeDir, ".ssh")
if err != nil {
return err
}
info, err := os.Stat(sshDir)
if err == nil {
if !info.IsDir() {
return fmt.Errorf("not a directory: %s", sshDir)
}
if err = os.Chmod(sshDir, 0o700); err != nil {
return err
}
} else if errors.Is(err, os.ErrNotExist) {
if err = os.MkdirAll(sshDir, 0o700); err != nil {
return err
}
} else {
return err
}
if err = os.Chown(sshDir, uid, gid); err != nil {
return err
}

dstExists := false
dstPath, err := joinRelPath(sshDir, "authorized_keys")
if err != nil {
return err
}
info, err = os.Stat(dstPath)
info, err := os.Stat(dstPath)
if err == nil {
dstExists = true
if info.IsDir() {
Expand Down Expand Up @@ -753,3 +741,57 @@ func copyAuthorizedKeys(src string, uid int, gid int, homeDir string) error {

return nil
}

// backupFile renames `/path/to/file` to `/path/to/file.dstack.bak`,
// creates a new file with the same content, and returns restore function that
// renames the backup back to the original name.
// If the original file does not exist, restore function removes the file if it is created.
// NB: A newly created file has default uid:gid and permissions, probably not
// the same as the original file.
func backupFile(ctx context.Context, path string) func(context.Context) {
var existed bool
backupPath := path + ".dstack.bak"

restoreFunc := func(ctx context.Context) {
if !existed {
err := os.Remove(path)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error(ctx, "failed to remove", "path", path, "err", err)
}
return
}
err := os.Rename(backupPath, path)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error(ctx, "failed to restore", "path", path, "err", err)
}
}

err := os.Rename(path, backupPath)
if errors.Is(err, os.ErrNotExist) {
existed = false
return restoreFunc
}
existed = true
if err != nil {
log.Error(ctx, "failed to back up", "path", path, "err", err)
return restoreFunc
}

src, err := os.Open(backupPath)
if err != nil {
log.Error(ctx, "failed to open backup src", "path", backupPath, "err", err)
return restoreFunc
}
defer src.Close()
dst, err := os.Create(path)
if err != nil {
log.Error(ctx, "failed to open backup dest", "path", path, "err", err)
return restoreFunc
}
defer dst.Close()
_, err = io.Copy(dst, src)
if err != nil {
log.Error(ctx, "failed to copy backup", "path", backupPath, "err", err)
}
return restoreFunc
}