diff --git a/config/samples/ray.io_v1_raycluster.yaml b/config/samples/ray.io_v1_raycluster.yaml
index bddbeae47..6ec0b612f 100644
--- a/config/samples/ray.io_v1_raycluster.yaml
+++ b/config/samples/ray.io_v1_raycluster.yaml
@@ -18,7 +18,7 @@ spec:
         runAsGroup: 0
         fsGroup: 0
       containers:
-        - image: kubeagi/ray-ml:2.9.0-py39-vllm
+        - image: kubeagi/ray-ml:2.9.3-py39-vllm
          name: ray-head
          resources:
            limits:
diff --git a/pkg/worker/runner.go b/pkg/worker/runner.go
index b7c884512..585bfa0c8 100644
--- a/pkg/worker/runner.go
+++ b/pkg/worker/runner.go
@@ -87,9 +87,17 @@ func (runner *RunnerFastchat) Build(ctx context.Context, model *arcadiav1alpha1.
 	return nil, fmt.Errorf("failed to get arcadia config with %w", err)
 	}
 
+	extraAgrs := ""
+	for _, envItem := range runner.w.Spec.AdditionalEnvs {
+		if envItem.Name == "EXTRA_ARGS" {
+			extraAgrs = envItem.Value
+			break
+		}
+	}
+
 	modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
 	additionalEnvs := []corev1.EnvVar{}
-	extraArgs := fmt.Sprintf("--device %s", runner.Device().String())
+	extraArgs := fmt.Sprintf("--device %s %s", runner.Device().String(), extraAgrs)
 	if runner.modelFileFromRemote {
 		m := arcadiav1alpha1.Model{}
 		if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
@@ -179,24 +187,29 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
 	return nil, fmt.Errorf("failed to get arcadia config with %w", err)
 	}
 
+	rayEnabled := false
 	rayClusterAddress := ""
 	pythonVersion := ""
+	extraAgrs := ""
+	additionalEnvs := []corev1.EnvVar{}
 
-	// Get the real GPU requirement from env if configured
-	// this will be the total GPU from ray resource pool, not the resource requests/limits
-	gpuCount, _ := strconv.Atoi(runner.NumberOfGPUs())
 	rayClusterIndex := 0
 	for _, envItem := range runner.w.Spec.AdditionalEnvs {
+		// Check if Ray is enabled using distributed inference
 		if envItem.Name == "NUMBER_GPUS" {
-			gpuCount, _ = strconv.Atoi(envItem.Value)
+			rayEnabled = true
 		}
 		if envItem.Name == "RAY_CLUSTER_INDEX" {
 			rayClusterIndex, _ = strconv.Atoi(envItem.Value)
+			rayEnabled = true
+		}
+		if envItem.Name == "EXTRA_ARGS" {
+			extraAgrs = envItem.Value
 		}
 	}
 
 	// Get ray config from configMap
-	if gpuCount > 1 {
+	if rayEnabled {
 		rayClusters, err := config.GetRayClusters(ctx, runner.c)
 		if err != nil || len(rayClusters) == 0 {
 			klog.Warningln("no ray cluster configured, fallback to local resource: ", err)
@@ -208,14 +221,15 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
 			klog.Infof("run worker using ray: %s, number of GPU: %s", rayClusterAddress, runner.NumberOfGPUs())
 		}
 	} else {
+		// Set gpu number to the number of GPUs in the worker's resource
+		additionalEnvs = append(additionalEnvs, corev1.EnvVar{Name: "NUMBER_GPUS", Value: runner.NumberOfGPUs()})
 		klog.Infof("run worker with %s GPU", runner.NumberOfGPUs())
 	}
 
 	modelFileDir := fmt.Sprintf("/data/models/%s", model.Name)
-	additionalEnvs := []corev1.EnvVar{}
 	// --enforce-eager to disable cupy
 	// TODO: remove --enforce-eager when https://github.com/kubeagi/arcadia/issues/878 is fixed
-	extraAgrs := "--trust-remote-code --enforce-eager"
+	extraAgrs = fmt.Sprintf("%s --trust-remote-code --enforce-eager", extraAgrs)
 	if runner.modelFileFromRemote {
 		m := arcadiav1alpha1.Model{}
 		if err := runner.c.Get(ctx, types.NamespacedName{Namespace: *model.Namespace, Name: model.Name}, &m); err != nil {
@@ -253,7 +267,6 @@ func (runner *RunnerFastchatVLLM) Build(ctx context.Context, model *arcadiav1alp
 			// Need python version and ray address for distributed inference
 			{Name: "PYTHON_VERSION", Value: pythonVersion},
 			{Name: "RAY_ADDRESS", Value: rayClusterAddress},
-			{Name: "NUMBER_GPUS", Value: strconv.Itoa(gpuCount)},
 		},
 		Ports: []corev1.ContainerPort{
 			{Name: "http", ContainerPort: arcadiav1alpha1.DefaultWorkerPort},
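Taken together, the runner.go changes let a worker forward extra fastchat flags through an EXTRA_ARGS entry in Spec.AdditionalEnvs, and switch the vLLM runner to gate Ray-based distributed inference on the presence of the NUMBER_GPUS / RAY_CLUSTER_INDEX envs instead of a parsed GPU count. Below is a minimal sketch of the env lookup both Build methods now perform inline; the helper name is hypothetical and not part of this PR.

package worker

import (
	corev1 "k8s.io/api/core/v1"
)

// extraArgsFromEnv returns the value of the EXTRA_ARGS entry in envs, or ""
// when no such entry is set. Hypothetical helper mirroring the inline lookup
// over runner.w.Spec.AdditionalEnvs added in both Build methods.
func extraArgsFromEnv(envs []corev1.EnvVar) string {
	for _, envItem := range envs {
		if envItem.Name == "EXTRA_ARGS" {
			return envItem.Value
		}
	}
	return ""
}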