diff --git a/tfx/orchestration/kubeflow/kubeflow_dag_runner_test.py b/tfx/orchestration/kubeflow/kubeflow_dag_runner_test.py index c070605d4b..9335a6123c 100644 --- a/tfx/orchestration/kubeflow/kubeflow_dag_runner_test.py +++ b/tfx/orchestration/kubeflow/kubeflow_dag_runner_test.py @@ -172,6 +172,17 @@ def testVolumeMountingPipelineOperatorFuncs(self): ] self.assertEqual(2, len(container_templates)) + volumes = [{ + 'name': 'my-volume-name', + 'persistentVolumeClaim': { + 'claimName': 'my-persistent-volume-claim' + } + }] + + # Check that the PVC is specified for kfp<=0.1.31.1. + if 'volumes' in pipeline['spec']: + self.assertEqual(volumes, pipeline['spec']['volumes']) + for template in container_templates: # Check that each container has the volume mounted. self.assertEqual([{ @@ -179,13 +190,9 @@ def testVolumeMountingPipelineOperatorFuncs(self): 'mountPath': '/mnt/volume-mount-path' }], template['container']['volumeMounts']) - # Check that each template has the PVC specified. - self.assertEqual([{ - 'name': 'my-volume-name', - 'persistentVolumeClaim': { - 'claimName': 'my-persistent-volume-claim' - } - }], template['volumes']) + # Check that each template has the PVC specified for kfp>=0.1.31.2. + if 'volumes' in template: + self.assertEqual(volumes, template['volumes']) if __name__ == '__main__':