Skip to content

Commit

Permalink
SDK - Hiding Argo's workflow.uid placeholder behind DSL
Browse files Browse the repository at this point in the history
Fixes #1673
  • Loading branch information
Ark-kun committed Jul 29, 2019
1 parent bb339a9 commit 995d9de
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion components/sample/keras/train_classifier/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ training_set_labels_gcs_path = os.path.join(input_data_gcs_dir, 'training_set_la
gfile.Copy(training_set_features_local_path, training_set_features_gcs_path)
gfile.Copy(training_set_labels_local_path, training_set_labels_gcs_path)

output_model_uri_template = os.path.join(output_data_gcs_dir, '{{workflow.uid}}/{{pod.name}}/output_model_uri/data')
output_model_uri_template = os.path.join(output_data_gcs_dir, kfp.dsl.task_id_placeholder, 'output_model_uri', 'data')

xor_model_config = requests.get(test_data_url_prefix + 'model_config.json').content

Expand Down
2 changes: 1 addition & 1 deletion samples/kubeflow-tf/kubeflow-training-classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def kubeflow_training(output, project,
preprocess_mode='local',
predict_mode='local',
):
output_template = str(output) + '/{{workflow.uid}}/{{pod.name}}/data'
output_template = str(output) + '/' + dsl.task_id_placeholder + '/data'

# set the flag to use GPU trainer
use_gpu = False
Expand Down
4 changes: 2 additions & 2 deletions samples/tfx/taxi-cab-classification-pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def taxi_cab_classification(
steps=3000,
analyze_slice_column='trip_start_hour'
):
output_template = str(output) + '/{{workflow.uid}}/{{pod.name}}/data'
output_template = str(output) + '/' + dsl.task_id_placeholder + '/data'
target_lambda = """lambda x: (x['target'] > x['fare'] * 0.2)"""
target_class_lambda = """lambda x: 1 if (x['target'] > x['fare'] * 0.2) else 0"""

tf_server_name = 'taxi-cab-classification-model-{{workflow.uid}}'
tf_server_name = 'taxi-cab-classification-model-' + dsl.task_id_placeholder

if platform != 'GCP':
vop = dsl.VolumeOp(
Expand Down
2 changes: 1 addition & 1 deletion samples/xgboost-spark/xgboost-training-cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def xgb_train_pipeline(
workers=2,
true_label='ACTION',
):
output_template = str(output) + '/{{workflow.uid}}/{{pod.name}}/data'
output_template = str(output) + '/' + dsl.task_id_placeholder + '/data'

delete_cluster_op = dataproc_delete_cluster_op(
project,
Expand Down
5 changes: 4 additions & 1 deletion sdk/python/kfp/dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@
from ._volume_snapshot_op import VolumeSnapshotOp
from ._ops_group import OpsGroup, ExitHandler, Condition
from ._component import python_component, graph_component, component
from ._artifact_location import ArtifactLocation
from ._artifact_location import ArtifactLocation

task_id_placeholder = '{{workflow.uid}}-{{pod.name}}'
run_id_placeholder = '{{workflow.uid}}'

0 comments on commit 995d9de

Please sign in to comment.