Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support restarting training job #901

Merged
merged 1 commit into from
Oct 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cmd/metricscollector/v1alpha3/file-metricscollector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"context"
"flag"
"os"
"path/filepath"
"strings"

"github.com/hpcloud/tail"
Expand Down Expand Up @@ -84,9 +85,10 @@ func main() {

go printMetricsFile(*metricsFileName)
wopts := common.WaitPidsOpts{
PollInterval: *pollInterval,
Timeout: *timeout,
WaitAll: *waitAll,
PollInterval: *pollInterval,
Timeout: *timeout,
WaitAll: *waitAll,
CompletedMarkedDirPath: filepath.Dir(*metricsFileName),
}
if err := common.Wait(wopts); err != nil {
klog.Fatalf("Failed to wait for worker container: %v", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_options():
raise Exception("Invalid katib manager service address: %s" %
opt.manager_server_addr)

WaitOtherMainProcesses()
WaitOtherMainProcesses(completed_marked_dir=opt.dir_path)

mc = MetricsCollector(opt.metric_names.split(','))
observation_log = mc.parse_file(opt.dir_path)
Expand Down
2 changes: 2 additions & 0 deletions pkg/metricscollector/v1alpha3/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ const (

MetricCollectorContainerName = "metrics-collector"
MetricLoggerCollectorContainerName = "metrics-logger-and-collector"

TrainingCompleted = "completed"
)
20 changes: 17 additions & 3 deletions pkg/metricscollector/v1alpha3/common/pns.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ package common

import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"time"

gops "github.com/mitchellh/go-ps"
Expand All @@ -27,9 +30,10 @@ import (
var ErrWaitPidTimeout = fmt.Errorf("Timed out waiting for PID to complete")

type WaitPidsOpts struct {
PollInterval time.Duration
Timeout time.Duration
WaitAll bool
PollInterval time.Duration
Timeout time.Duration
WaitAll bool
CompletedMarkedDirPath string
}

func Wait(opts WaitPidsOpts) error {
Expand Down Expand Up @@ -95,6 +99,16 @@ func WaitPIDS(pids []int, opts ...WaitPidsOpts) error {
_, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
if opts[0].CompletedMarkedDirPath != "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot figure out how this helps solve the problem, could you please explain more about it?

Copy link
Member Author

@hougangliu hougangliu Oct 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR appends `echo completed > $mountPath/$$$$.pid` to the training command, as shown below. If the training container succeeds, it creates a file named `$mountPath/$processID.pid` containing "completed".
The metrics collector container watches the training process; once it finds that the process has exited, it checks the file `$mountPath/$processID.pid` to judge whether the training process succeeded. If it succeeded, the collector parses the metrics file; otherwise it raises an exception and exits.
Before this PR, as soon as the metrics collector found that the training process had exited, it started parsing the metrics file, ignoring the training process's exit status.

For now, it is hard to check another process's exit code (I tried to use "strace" to implement this, but it requires extra Linux capabilities). Alternatively, we could call the Kubernetes API to get pod.status.containerStatus, but that would require adding an extra role to the worker pod, or another service to proxy the call.

  - args:
    - python /mxnet/example/image-classification/train_mnist.py --batch-size=64 --lr=0.02273874688380991
      --num-layers=3 --optimizer=sgd 1>/var/log/katib/metrics.log 2>&1 && echo completed
      > /var/log/katib/$$$$.pid
    command:
    - sh
    - -c

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this cause the similar pipe exit code problem as tee does?

Copy link
Member Author

@hougangliu hougangliu Oct 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No — in fact, with tee the container exit code is always 0, even if the training process fails.

for this PR:

  1. if the training process fails, the container returns the training process's exit code (`&& echo` is never executed)
  2. if the training process succeeds (exit code 0), `&& echo` also returns 0, so the container exit code is 0 as well

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, misunderstand the logic here, SGTM

markFile := filepath.Join(opts[0].CompletedMarkedDirPath, fmt.Sprintf("%d.pid", pid))
if data, err := ioutil.ReadFile(markFile); err != nil {
return fmt.Errorf("Process %d hadn't completed: %v", pid, err)
} else {
if strings.TrimSpace(string(data)) != TrainingCompleted {
return fmt.Errorf("Process %d hadn't completed", pid)
}
}
}
if waitAll {
finishedPids = append(finishedPids, pid)
if len(finishedPids) == len(pids) {
Expand Down
12 changes: 9 additions & 3 deletions pkg/metricscollector/v1alpha3/common/pns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def GetOtherMainProcesses():
pids.add(pid)
return pids

def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False):
def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False, completed_marked_dir=""):
start = 0
pids = set(pids)
if poll_interval_seconds <= 0:
Expand All @@ -26,6 +26,12 @@ def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False
if os.path.isdir(path):
continue
else:
if completed_marked_dir:
mark_file = os.path.join(completed_marked_dir, "%d.pid" % pid)
with open(mark_file) as file_obj:
contents = file_obj.read()
if contents.strip() != "completed":
raise Exception("Pid %d hadn't completed" % pid)
if is_wait_all:
stop_pids.add(pid)
else:
Expand All @@ -35,5 +41,5 @@ def WaitPIDs(pids, poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False
time.sleep(poll_interval_seconds)
start = start + poll_interval_seconds

def WaitOtherMainProcesses(poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False):
return WaitPIDs(GetOtherMainProcesses(), poll_interval_seconds, timeout_seconds, is_wait_all)
def WaitOtherMainProcesses(poll_interval_seconds=1, timeout_seconds=0, is_wait_all=False, completed_marked_dir=""):
    # Wait for every other main process in the pod, forwarding all polling and
    # completion-marker options straight through to WaitPIDs.
    other_pids = GetOtherMainProcesses()
    return WaitPIDs(other_pids, poll_interval_seconds, timeout_seconds, is_wait_all, completed_marked_dir)
23 changes: 17 additions & 6 deletions pkg/webhook/v1alpha3/pod/inject_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (s *sidecarInjector) Mutate(pod *v1.Pod, namespace string) (*v1.Pod, error)

if mountPath, pathKind := getMountPath(trial.Spec.MetricsCollector); mountPath != "" {
if err = wrapWorkerContainer(
mutatedPod, kind, mountPath, trial.Spec.MetricsCollector); err != nil {
mutatedPod, kind, mountPath, pathKind, trial.Spec.MetricsCollector); err != nil {
return nil, err
}
if err = mutateVolume(mutatedPod, kind, mountPath, sidecarContainerName, pathKind); err != nil {
Expand Down Expand Up @@ -229,10 +229,8 @@ func getMountPath(mc common.MetricsCollectorSpec) (string, common.FileSystemKind

func wrapWorkerContainer(
pod *v1.Pod, jobKind, metricsFile string,
pathKind common.FileSystemKind,
mc common.MetricsCollectorSpec) error {
if mc.Collector.Kind != common.StdOutCollector {
return nil
}
index := -1
for i, c := range pod.Spec.Containers {
jobProvider, err := jobv1alpha3.New(jobKind)
Expand All @@ -255,15 +253,28 @@ func wrapWorkerContainer(
if c.Args != nil {
args = append(args, c.Args...)
}
redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile)
args = append(args, redirectStr)
if mc.Collector.Kind == common.StdOutCollector {
redirectStr := fmt.Sprintf("1>%s 2>&1", metricsFile)
args = append(args, redirectStr)
}
args = append(args, "&&", getMarkCompletedCommand(metricsFile, pathKind))
argsStr := strings.Join(args, " ")
c.Command = command
c.Args = []string{argsStr}
}
return nil
}

// getMarkCompletedCommand returns the shell snippet appended after the
// training command; it writes the TrainingCompleted marker into a
// "<pid>.pid" file so the metrics collector can tell success from failure.
func getMarkCompletedCommand(mountPath string, pathKind common.FileSystemKind) string {
	markDir := mountPath
	// When the mount path points at a metrics file, place the marker in the
	// file's directory instead.
	if pathKind == common.FileKind {
		markDir = filepath.Dir(mountPath)
	}
	// $$ is process id in shell; it is written doubled here — presumably so a
	// later expansion pass collapses it back to "$$" before the shell runs it
	// (NOTE(review): confirm against how the container args are templated).
	markFile := filepath.Join(markDir, "$$$$.pid")
	return fmt.Sprintf("echo %s > %s", mccommon.TrainingCompleted, markFile)
}

func mutateVolume(pod *v1.Pod, jobKind, mountPath, sidecarContainerName string, pathKind common.FileSystemKind) error {
metricsVol := v1.Volume{
Name: common.MetricsVolume,
Expand Down