Skip to content

Commit

Permalink
support tpu settings in dsl (#491)
Browse files Browse the repository at this point in the history
* support tpu settings in dsl

* fix issues from review comment
  • Loading branch information
hongye-sun authored and k8s-ci-robot committed Dec 7, 2018
1 parent a9d0689 commit 55e1826
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 7 deletions.
8 changes: 8 additions & 0 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def _op_to_template(self, op):
template['container']['env'] = list(map(self._convert_k8s_obj_to_dic, op.env_variables))
if op.volume_mounts:
template['container']['volumeMounts'] = list(map(self._convert_k8s_obj_to_dic, op.volume_mounts))

if op.pod_annotations or op.pod_labels:
template['metadata'] = {}
if op.pod_annotations:
template['metadata']['annotations'] = op.pod_annotations
if op.pod_labels:
template['metadata']['labels'] = op.pod_labels

return template

def _get_groups_for_ops(self, root_group):
Expand Down
38 changes: 31 additions & 7 deletions sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(self, name: str, image: str, command: str=None, arguments: str=None
self.volumes = []
self.volume_mounts = []
self.env_variables = []
self.pod_annotations = {}
self.pod_labels = {}

matches = []
if arguments:
Expand Down Expand Up @@ -127,16 +129,16 @@ def _validate_cpu_string(self, cpu_string):
raise ValueError('Invalid cpu string. Should be float or integer, or integer followed '
'by "m".')

def _validate_gpu_string(self, gpu_string):
"Validate a given string is valid for gpu limit."
def _validate_positive_number(self, str_value, param_name):
"Validate a given string is in positive integer format."

try:
gpu_value = int(gpu_string)
int_value = int(str_value)
except ValueError:
raise ValueError('Invalid gpu string. Should be integer.')
raise ValueError('Invalid {}. Should be integer.'.format(param_name))

if gpu_value <= 0:
raise ValueError('gpu must be positive integer.')
if int_value <= 0:
raise ValueError('{} must be positive integer.'.format(param_name))

def add_resource_limit(self, resource_name, value):
"""Add the resource limit of the container.
Expand Down Expand Up @@ -212,7 +214,7 @@ def set_gpu_limit(self, gpu, vendor = "nvidia"):
are: 'nvidia' (default), and 'amd'.
"""

self._validate_gpu_string(gpu)
self._validate_positive_number(gpu, 'gpu')
if vendor != 'nvidia' and vendor != 'amd':
raise ValueError('vendor can only be nvidia or amd.')

Expand Down Expand Up @@ -268,5 +270,27 @@ def add_node_selector_constraint(self, label_name, value):
self.node_selector[label_name] = value
return self

def add_pod_annotation(self, name: str, value: str):
    """Store a metadata annotation for this op's pod.

    Setting a name that is already present overwrites its previous value.

    Args:
      name: The name (key) of the annotation.
      value: The value of the annotation.

    Returns:
      This same object, so that calls can be chained.
    """
    annotations = self.pod_annotations
    annotations[name] = value
    return self

def add_pod_label(self, name: str, value: str):
    """Store a metadata label for this op's pod.

    Setting a name that is already present overwrites its previous value.

    Args:
      name: The name (key) of the label.
      value: The value of the label.

    Returns:
      This same object, so that calls can be chained.
    """
    labels = self.pod_labels
    labels[name] = value
    return self

def __repr__(self):
    """Return the string form of a one-entry dict mapping the class name to the instance ``__dict__``."""
    class_name = self.__class__.__name__
    state = self.__dict__
    return str({class_name: state})
22 changes: 22 additions & 0 deletions sdk/python/kfp/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,25 @@ def _use_gcp_secret(task):
)

return _use_gcp_secret

def use_tpu(tpu_cores: int, tpu_resource: str, tf_version: str):
    """Build an operator that configures the GCP TPU spec on a container op.

    Args:
      tpu_cores: Required. The number of cores of the TPU resource, for
        example 8, 32 or 128. It is converted with ``str()`` before being
        set as the resource limit.
        Check more details at: https://cloud.google.com/tpu/docs/kubernetes-engine-setup#pod-spec.
      tpu_resource: Required. The resource name of the TPU resource, for
        example 'v2', 'preemptible-v1', 'v3' or 'preemptible-v3'.
        Check more details at: https://cloud.google.com/tpu/docs/kubernetes-engine-setup#pod-spec.
      tf_version: Required. The TensorFlow version that the TPU nodes use,
        for example '1.12', '1.11', '1.9' or '1.8'.
        Check more details at: https://cloud.google.com/tpu/docs/supported-versions.

    Returns:
      A function that takes a task, adds the TensorFlow-version pod
      annotation and the TPU resource limit to it, and returns that task.
    """

    def _set_tpu_spec(task):
        # Annotation first, then the limit: same call order as before.
        task.add_pod_annotation('tf-version.cloud-tpus.google.com', tf_version)
        limit_name = 'cloud-tpus.google.com/{}'.format(tpu_resource)
        limit_value = str(tpu_cores)
        task.add_resource_limit(limit_name, limit_value)
        return task

    return _set_tpu_spec
2 changes: 2 additions & 0 deletions sdk/python/tests/compiler/testdata/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import kfp.dsl as dsl
import kfp.gcp as gcp


class GetFrequentWordOp(dsl.ContainerOp):
Expand Down Expand Up @@ -87,3 +88,4 @@ def save_most_frequent_word(message: str, outputpath: str):
saver.set_cpu_limit('0.5')
saver.set_gpu_limit('2')
saver.add_node_selector_constraint('cloud.google.com/gke-accelerator', 'nvidia-tesla-k80')
saver.apply(gcp.use_tpu(tpu_cores = 8, tpu_resource = 'v2', tf_version = '1.12'))
4 changes: 4 additions & 0 deletions sdk/python/tests/compiler/testdata/basic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,16 @@ spec:
image: google/cloud-sdk
resources:
limits:
cloud-tpus.google.com/v2: "8"
cpu: "0.5"
nvidia.com/gpu: "2"
inputs:
parameters:
- name: get-frequent-word
- name: outputpath
metadata:
annotations:
tf-version.cloud-tpus.google.com: "1.12"
name: save
nodeSelector:
cloud.google.com/gke-accelerator: nvidia-tesla-k80
Expand Down

0 comments on commit 55e1826

Please sign in to comment.