Skip to content

Commit

Permalink
fix: refine ray logic and support extraArgs in worker
Browse files Browse the repository at this point in the history
  • Loading branch information
nkwangleiGIT committed Mar 20, 2024
1 parent c12681c commit 32d82ce
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
2 changes: 1 addition & 1 deletion config/samples/ray.io_v1_raycluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ spec:
runAsGroup: 0
fsGroup: 0
containers:
- image: kubeagi/ray-ml:2.9.0-py39-vllm
- image: kubeagi/ray-ml:2.9.3-py39-vllm
name: ray-head
resources:
limits:
Expand Down
31 changes: 22 additions & 9 deletions pkg/worker/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,17 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
return nil, fmt.Errorf("failed to get arcadia config with %w", err)
}

extraAgrs := ""
for _, envItem := range runner.w.Spec.AdditionalEnvs {
if envItem.Name == "EXTRA_ARGS" {
extraAgrs = envItem.Value
break
}
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
extraArgs := fmt.Sprintf("--device %s", runner.Device().String())
extraArgs := fmt.Sprintf("--device %s %s", runner.Device().String(), extraAgrs)
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
Expand Down Expand Up @@ -179,24 +187,29 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
return nil, fmt.Errorf("failed to get arcadia config with %w", err)
}

rayEnabled := false
rayClusterAddress := ""
pythonVersion := ""
extraAgrs := ""
additionalEnvs := []corev1.EnvVar{}

// Get the real GPU requirement from env if configured
// this will be the total GPU from ray resource pool, not the resource requests/limits
gpuCount, _ := strconv.Atoi(runner.NumberOfGPUs())
rayClusterIndex := 0
for _, envItem := range runner.w.Spec.AdditionalEnvs {
// Check if Ray is enabled using distributed inference
if envItem.Name == "NUMBER_GPUS" {
gpuCount, _ = strconv.Atoi(envItem.Value)
rayEnabled = true
}
if envItem.Name == "RAY_CLUSTER_INDEX" {
rayClusterIndex, _ = strconv.Atoi(envItem.Value)
rayEnabled = true
}
if envItem.Name == "EXTRA_ARGS" {
extraAgrs = envItem.Value
}
}

// Get ray config from configMap
if gpuCount > 1 {
if rayEnabled {
rayClusters, err := config.GetRayClusters(ctx, runner.c)
if err != nil || len(rayClusters) == 0 {
klog.Warningln("no ray cluster configured, fallback to local resource: ", err)
Expand All @@ -208,14 +221,15 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
klog.Infof("run worker using ray: %s, number of GPU: %s", rayClusterAddress, runner.NumberOfGPUs())
}
} else {
// Set gpu number to the number of GPUs in the worker's resource
additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()})
klog.Infof("run worker with %s GPU", runner.NumberOfGPUs())
}

modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
additionalEnvs := []corev1.EnvVar{}
// --enforce-eager to disable cupy
// TODO: remove --enforce-eager when https://github.com/kubeagi/arcadia/issues/878 is fixed
extraAgrs := "--trust-remote-code --enforce-eager"
extraAgrs = fmt.Sprintf("%s --trust-remote-code --enforce-eager", extraAgrs)
if runner.modelFileFromRemote {
m := arcadiav1alpha1.Model{}
if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
Expand Down Expand Up @@ -253,7 +267,6 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
// Need python version and ray address for distributed inference
{Name: "PYTHON_VERSION", Value: pythonVersion},
{Name: "RAY_ADDRESS", Value: rayClusterAddress},
{Name: "NUMBER_GPUS", Value: strconv.Itoa(gpuCount)},
},
Ports: []corev1.ContainerPort{
{Name: "http", ContainerPort: arcadiav1alpha1.DefaultWorkerPort},
Expand Down

0 comments on commit 32d82ce

Please sign in to comment.