KEP-2170: Add validation to Torch numProcPerNode field
Signed-off-by: Antonin Stefanutti <antonin@stefanutti.fr>
astefanutti committed Feb 14, 2025
1 parent 3f3a8d3 commit 492a90e
Showing 19 changed files with 61 additions and 42 deletions.
4 changes: 2 additions & 2 deletions api/openapi-spec/swagger.json
@@ -517,7 +517,7 @@
},
"numProcPerNode": {
"description": "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
}
}
},
@@ -716,7 +716,7 @@
},
"numProcPerNode": {
"description": "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
},
"resourcesPerNode": {
"description": "Compute resources for each training node.",
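The schema switch from a plain string to `IntOrString` changes what the API accepts and emits on the wire: integers travel as bare JSON numbers and the Torch keywords stay quoted strings. A minimal sketch of the round trip — the `spec` struct here is a hypothetical stand-in, not the generated API type:

```go
package main

import (
	"encoding/json"
	"fmt"

	"k8s.io/apimachinery/pkg/util/intstr"
)

// spec is a hypothetical stand-in for the generated API types; only the
// shape of the numProcPerNode field matters here.
type spec struct {
	NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
}

func main() {
	// Marshalling: the keyword form stays a JSON string, the numeric form a JSON number.
	auto := intstr.FromString("auto")
	eight := intstr.FromInt32(8)
	a, _ := json.Marshal(spec{NumProcPerNode: &auto})
	b, _ := json.Marshal(spec{NumProcPerNode: &eight})
	fmt.Println(string(a)) // {"numProcPerNode":"auto"}
	fmt.Println(string(b)) // {"numProcPerNode":8}

	// Unmarshalling: both wire forms decode into the same field.
	var s spec
	_ = json.Unmarshal([]byte(`{"numProcPerNode": 8}`), &s)
	fmt.Println(s.NumProcPerNode.String()) // 8
}
```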
@@ -585,12 +585,17 @@ spec:
type: integer
type: object
numProcPerNode:
+anyOf:
+- type: integer
+- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
-type: string
+x-kubernetes-int-or-string: true
+x-kubernetes-validations:
+- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
@@ -585,12 +585,17 @@ spec:
type: integer
type: object
numProcPerNode:
+anyOf:
+- type: integer
+- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
-type: string
+x-kubernetes-int-or-string: true
+x-kubernetes-validations:
+- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
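The new `x-kubernetes-validations` rule is evaluated by the API server's CEL engine at admission time. The sketch below is only a hand-rolled Go mirror of the rule's intent — a positive integer, or one of the three Torch keywords — using a hypothetical `validNumProcPerNode` helper; it is not how the apiserver itself evaluates CEL:

```go
package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/intstr"
)

// validNumProcPerNode mirrors the intent of the CEL rule
// `self > 0 || self in ['auto', 'cpu', 'gpu']`; it is illustrative only.
func validNumProcPerNode(v intstr.IntOrString) bool {
	if v.Type == intstr.Int {
		return v.IntVal > 0
	}
	switch v.StrVal {
	case "auto", "cpu", "gpu":
		return true
	default:
		return false
	}
}

func main() {
	for _, v := range []intstr.IntOrString{
		intstr.FromInt32(4),       // accepted: positive integer
		intstr.FromString("gpu"),  // accepted: keyword
		intstr.FromInt32(0),       // rejected: not > 0
		intstr.FromString("many"), // rejected: unknown keyword
	} {
		fmt.Printf("%-4s -> %t\n", v.String(), validNumProcPerNode(v))
	}
}
```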
5 changes: 4 additions & 1 deletion manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml
@@ -3138,11 +3138,14 @@ spec:
format: int32
type: integer
numProcPerNode:
+anyOf:
+- type: integer
+- type: string
description: |-
Number of processes/workers/slots on every training node.
For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
For the MPI runtime only int value can be set.
-type: string
+x-kubernetes-int-or-string: true
resourcesPerNode:
description: Compute resources for each training node.
properties:
5 changes: 3 additions & 2 deletions pkg/apis/trainer/v1alpha1/trainingruntime_types.go
@@ -19,6 +19,7 @@ package v1alpha1
import (
autoscalingv2 "k8s.io/api/autoscaling/v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

@@ -171,9 +172,9 @@ type TorchMLPolicySource struct {
// Number of processes per node.
// This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
// Supported values: `auto`, `cpu`, `gpu`, or int value.
-// TODO (andreyvelich): Add kubebuilder validation.
// Defaults to `auto`.
-NumProcPerNode *string `json:"numProcPerNode,omitempty"`
+// +kubebuilder:validation:XValidation:rule="self > 0 || self in ['auto', 'cpu', 'gpu']"
+NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`

// Elastic policy for the PyTorch training.
ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"`
3 changes: 2 additions & 1 deletion pkg/apis/trainer/v1alpha1/trainjob_types.go
@@ -19,6 +19,7 @@ package v1alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)

const (
@@ -194,7 +195,7 @@ type Trainer struct {
// Number of processes/workers/slots on every training node.
// For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
// For the MPI runtime only int value can be set.
-NumProcPerNode *string `json:"numProcPerNode,omitempty"`
+NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
}

// DatasetConfig represents the desired dataset configuration.
5 changes: 3 additions & 2 deletions pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default.

10 changes: 4 additions & 6 deletions pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Some generated files are not rendered by default.

5 changes: 3 additions & 2 deletions pkg/client/applyconfiguration/trainer/v1alpha1/trainer.go

Some generated files are not rendered by default.

7 changes: 4 additions & 3 deletions pkg/runtime/core/trainingruntime_test.go
@@ -26,6 +26,7 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"sigs.k8s.io/controller-runtime/pkg/client"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"

@@ -263,7 +264,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet with Torch values from the TrainJob": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
-TorchPolicy(100, "auto").
+TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
@@ -273,7 +274,7 @@
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(30).
NumProcPerNode("3").
NumProcPerNode(intstr.FromInt32(3)).
Obj(),
).
Obj(),
@@ -317,7 +318,7 @@
"succeeded to build JobSet with Torch values from the Runtime and envs.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
-TorchPolicy(100, "auto").
+TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
ContainerTrainerEnv(
[]corev1.EnvVar{
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/mpi/mpi.go
@@ -94,7 +94,7 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er

numProcPerNode := strconv.Itoa(int(*info.RuntimePolicy.MLPolicy.MPI.NumProcPerNode))
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
-numProcPerNode = *trainJob.Spec.Trainer.NumProcPerNode
+numProcPerNode = (*trainJob.Spec.Trainer.NumProcPerNode).String()
}
info.Trainer.NumProcPerNode = numProcPerNode

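The MPI plugin now relies on `IntOrString.String()`, which renders the integer variant as decimal digits and returns the string variant verbatim. A minimal sketch of that behavior (the MPI runtime itself only documents integer values for this field):

```go
package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/intstr"
)

func main() {
	n := intstr.FromInt32(4)
	s := intstr.FromString("auto")
	fmt.Println(n.String()) // "4"    — the textual form the MPI plugin propagates
	fmt.Println(s.String()) // "auto" — passed through verbatim
}
```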
7 changes: 4 additions & 3 deletions pkg/runtime/framework/plugins/torch/torch.go
@@ -21,6 +21,7 @@ import (
"fmt"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
@@ -66,9 +67,9 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
}
info.Trainer.NumNodes = numNodes

-numProcPerNode := info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode
+numProcPerNode := ptr.Deref(info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode, intstr.FromString("auto"))
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
-numProcPerNode = trainJob.Spec.Trainer.NumProcPerNode
+numProcPerNode = ptr.Deref(trainJob.Spec.Trainer.NumProcPerNode, intstr.FromString("auto"))
}

// Update envs for Info object.
@@ -83,7 +84,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob)
},
{
Name: constants.TorchEnvNumProcPerNode,
-Value: ptr.Deref(numProcPerNode, "auto"),
+Value: numProcPerNode.String(),
},
{
Name: constants.TorchEnvNodeRank,
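In the Torch plugin, `ptr.Deref` with an `auto` fallback collapses the nil handling: the runtime policy value applies unless the TrainJob's trainer overrides it. A condensed mirror of that precedence — `resolveNumProcPerNode` is a hypothetical helper, the real code operates on `runtime.Info` and the `TrainJob` spec:

```go
package main

import (
	"fmt"

	"k8s.io/apimachinery/pkg/util/intstr"
	"k8s.io/utils/ptr"
)

// resolveNumProcPerNode mirrors the precedence in EnforceMLPolicy above:
// the TrainJob's trainer value wins over the runtime MLPolicy value, and a
// missing value falls back to "auto".
func resolveNumProcPerNode(runtimeVal, trainJobVal *intstr.IntOrString) string {
	v := ptr.Deref(runtimeVal, intstr.FromString("auto"))
	if trainJobVal != nil {
		v = *trainJobVal
	}
	// The textual form is what ends up as the value of constants.TorchEnvNumProcPerNode.
	return v.String()
}

func main() {
	three := intstr.FromInt32(3)
	fmt.Println(resolveNumProcPerNode(nil, nil))    // auto
	fmt.Println(resolveNumProcPerNode(nil, &three)) // 3
}
```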
5 changes: 3 additions & 2 deletions pkg/util/testing/wrapper.go
@@ -22,6 +22,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
@@ -392,7 +393,7 @@ func (t *TrainJobTrainerWrapper) NumNodes(numNodes int32) *TrainJobTrainerWrappe
return t
}

-func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode string) *TrainJobTrainerWrapper {
+func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode intstr.IntOrString) *TrainJobTrainerWrapper {
t.Trainer.NumProcPerNode = &numProcPerNode
return t
}
@@ -689,7 +690,7 @@ func (s *TrainingRuntimeSpecWrapper) NumNodes(numNodes int32) *TrainingRuntimeSp
return s
}

-func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode string) *TrainingRuntimeSpecWrapper {
+func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode intstr.IntOrString) *TrainingRuntimeSpecWrapper {
s.MLPolicy = &trainer.MLPolicy{
NumNodes: &numNodes,
MLPolicySource: trainer.MLPolicySource{
2 changes: 1 addition & 1 deletion sdk/docs/TrainerV1alpha1TorchMLPolicySource.md
@@ -5,7 +5,7 @@ TorchMLPolicySource represents a PyTorch runtime configuration.
Name | Type | Description | Notes
------------ | ------------- | ------------- | -------------
**elastic_policy** | [**TrainerV1alpha1TorchElasticPolicy**](TrainerV1alpha1TorchElasticPolicy.md) | | [optional]
-**num_proc_per_node** | **str** | Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. | [optional]
+**num_proc_per_node** | [**K8sIoApimachineryPkgUtilIntstrIntOrString**](K8sIoApimachineryPkgUtilIntstrIntOrString.md) | | [optional]

[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)

2 changes: 1 addition & 1 deletion sdk/docs/TrainerV1alpha1Trainer.md
@@ -9,7 +9,7 @@ Name | Type | Description | Notes
**env** | [**list[V1EnvVar]**](V1EnvVar.md) | List of environment variables to set in the training container. These values will be merged with the TrainingRuntime's trainer environments. | [optional]
**image** | **str** | Docker image for the training container. | [optional]
**num_nodes** | **int** | Number of training nodes. | [optional]
-**num_proc_per_node** | **str** | Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set. | [optional]
+**num_proc_per_node** | [**K8sIoApimachineryPkgUtilIntstrIntOrString**](K8sIoApimachineryPkgUtilIntstrIntOrString.md) | | [optional]
**resources_per_node** | [**V1ResourceRequirements**](V1ResourceRequirements.md) | | [optional]

[[Back to Model list]](../README.md#documentation-for-models) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to README]](../README.md)
@@ -34,7 +34,7 @@ class TrainerV1alpha1TorchMLPolicySource(object):
"""
openapi_types = {
'elastic_policy': 'TrainerV1alpha1TorchElasticPolicy',
-'num_proc_per_node': 'str'
+'num_proc_per_node': 'K8sIoApimachineryPkgUtilIntstrIntOrString'
}

attribute_map = {
@@ -82,21 +82,19 @@ def elastic_policy(self, elastic_policy):
def num_proc_per_node(self):
"""Gets the num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
-Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. # noqa: E501
:return: The num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
-:rtype: str
+:rtype: K8sIoApimachineryPkgUtilIntstrIntOrString
"""
return self._num_proc_per_node

@num_proc_per_node.setter
def num_proc_per_node(self, num_proc_per_node):
"""Sets the num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource.
-Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. # noqa: E501
:param num_proc_per_node: The num_proc_per_node of this TrainerV1alpha1TorchMLPolicySource. # noqa: E501
-:type: str
+:type: K8sIoApimachineryPkgUtilIntstrIntOrString
"""

self._num_proc_per_node = num_proc_per_node