From 55e18269c36ccc89c9a2b2e5d5d626b018644d55 Mon Sep 17 00:00:00 2001 From: hongye-sun <43763191+hongye-sun@users.noreply.github.com> Date: Thu, 6 Dec 2018 21:55:36 -0800 Subject: [PATCH] support tpu settings in dsl (#491) * support tpu settings in dsl * fix issues from review comment --- sdk/python/kfp/compiler/compiler.py | 8 ++++ sdk/python/kfp/dsl/_container_op.py | 38 +++++++++++++++---- sdk/python/kfp/gcp.py | 22 +++++++++++ sdk/python/tests/compiler/testdata/basic.py | 2 + sdk/python/tests/compiler/testdata/basic.yaml | 4 ++ 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index 20f3d0a0056..03f70b9e2b5 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -143,6 +143,14 @@ def _op_to_template(self, op): template['container']['env'] = list(map(self._convert_k8s_obj_to_dic, op.env_variables)) if op.volume_mounts: template['container']['volumeMounts'] = list(map(self._convert_k8s_obj_to_dic, op.volume_mounts)) + + if op.pod_annotations or op.pod_labels: + template['metadata'] = {} + if op.pod_annotations: + template['metadata']['annotations'] = op.pod_annotations + if op.pod_labels: + template['metadata']['labels'] = op.pod_labels + return template def _get_groups_for_ops(self, root_group): diff --git a/sdk/python/kfp/dsl/_container_op.py b/sdk/python/kfp/dsl/_container_op.py index de29767fe49..e1f5d7042c6 100644 --- a/sdk/python/kfp/dsl/_container_op.py +++ b/sdk/python/kfp/dsl/_container_op.py @@ -59,6 +59,8 @@ def __init__(self, name: str, image: str, command: str=None, arguments: str=None self.volumes = [] self.volume_mounts = [] self.env_variables = [] + self.pod_annotations = {} + self.pod_labels = {} matches = [] if arguments: @@ -127,16 +129,16 @@ def _validate_cpu_string(self, cpu_string): raise ValueError('Invalid cpu string. Should be float or integer, or integer followed ' 'by "m".') - def _validate_gpu_string(self, gpu_string): - "Validate a given string is valid for gpu limit." + def _validate_positive_number(self, str_value, param_name): + "Validate a given string is in positive integer format." try: - gpu_value = int(gpu_string) + int_value = int(str_value) except ValueError: - raise ValueError('Invalid gpu string. Should be integer.') + raise ValueError('Invalid {}. Should be integer.'.format(param_name)) - if gpu_value <= 0: - raise ValueError('gpu must be positive integer.') + if int_value <= 0: + raise ValueError('{} must be positive integer.'.format(param_name)) def add_resource_limit(self, resource_name, value): """Add the resource limit of the container. @@ -212,7 +214,7 @@ def set_gpu_limit(self, gpu, vendor = "nvidia"): are: 'nvidia' (default), and 'amd'. """ - self._validate_gpu_string(gpu) + self._validate_positive_number(gpu, 'gpu') if vendor != 'nvidia' and vendor != 'amd': raise ValueError('vendor can only be nvidia or amd.') @@ -268,5 +270,27 @@ def add_node_selector_constraint(self, label_name, value): self.node_selector[label_name] = value return self + def add_pod_annotation(self, name: str, value: str): + """Adds a pod's metadata annotation. + + Args: + name: The name of the annotation. + value: The value of the annotation. + """ + + self.pod_annotations[name] = value + return self + + def add_pod_label(self, name: str, value: str): + """Adds a pod's metadata label. + + Args: + name: The name of the label. + value: The value of the label. + """ + + self.pod_labels[name] = value + return self + def __repr__(self): return str({self.__class__.__name__: self.__dict__}) diff --git a/sdk/python/kfp/gcp.py b/sdk/python/kfp/gcp.py index e271d26e8d7..9b346d4b56b 100644 --- a/sdk/python/kfp/gcp.py +++ b/sdk/python/kfp/gcp.py @@ -56,3 +56,25 @@ def _use_gcp_secret(task): ) return _use_gcp_secret + +def use_tpu(tpu_cores: int, tpu_resource: str, tf_version: str): + """An operator that configures GCP TPU spec in a container op. + + Args: + tpu_cores: Required. The number of cores of TPU resource. + For example, the value can be '8', '32', '128', etc. + Check more details at: https://cloud.google.com/tpu/docs/kubernetes-engine-setup#pod-spec. + tpu_resource: Required. The resource name of the TPU resource. + For example, the value can be 'v2', 'preemptible-v1', 'v3' or 'preemptible-v3'. + Check more details at: https://cloud.google.com/tpu/docs/kubernetes-engine-setup#pod-spec. + tf_version: Required. The TensorFlow version that the TPU nodes use. + For example, the value can be '1.12', '1.11', '1.9' or '1.8'. + Check more details at: https://cloud.google.com/tpu/docs/supported-versions. + """ + + def _set_tpu_spec(task): + task.add_pod_annotation('tf-version.cloud-tpus.google.com', tf_version) + task.add_resource_limit('cloud-tpus.google.com/{}'.format(tpu_resource), str(tpu_cores)) + return task + + return _set_tpu_spec diff --git a/sdk/python/tests/compiler/testdata/basic.py b/sdk/python/tests/compiler/testdata/basic.py index 85f12d206a4..a81cf3a4d1f 100644 --- a/sdk/python/tests/compiler/testdata/basic.py +++ b/sdk/python/tests/compiler/testdata/basic.py @@ -14,6 +14,7 @@ import kfp.dsl as dsl +import kfp.gcp as gcp class GetFrequentWordOp(dsl.ContainerOp): @@ -87,3 +88,4 @@ def save_most_frequent_word(message: str, outputpath: str): saver.set_cpu_limit('0.5') saver.set_gpu_limit('2') saver.add_node_selector_constraint('cloud.google.com/gke-accelerator', 'nvidia-tesla-k80') + saver.apply(gcp.use_tpu(tpu_cores = 8, tpu_resource = 'v2', tf_version = '1.12')) diff --git a/sdk/python/tests/compiler/testdata/basic.yaml b/sdk/python/tests/compiler/testdata/basic.yaml index f9da8393bba..4298bf1e751 100644 --- a/sdk/python/tests/compiler/testdata/basic.yaml +++ b/sdk/python/tests/compiler/testdata/basic.yaml @@ -140,12 +140,16 @@ spec: image: google/cloud-sdk resources: limits: + cloud-tpus.google.com/v2: "8" cpu: "0.5" nvidia.com/gpu: "2" inputs: parameters: - name: get-frequent-word - name: outputpath + metadata: + annotations: + tf-version.cloud-tpus.google.com: "1.12" name: save nodeSelector: cloud.google.com/gke-accelerator: nvidia-tesla-k80