Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

well-defined parameter types #978

Merged
merged 14 commits into from
Mar 20, 2019
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,8 @@ def __ge__(self, other):
def __hash__(self):
return hash((self.op_name, self.name))

def ignore_type(self):
"""ignore_type discards the type information so that type checking will pass for this parameter."""
self.param_type = TypeMeta()
return self

34 changes: 13 additions & 21 deletions sdk/python/kfp/dsl/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,63 @@ class BaseType:

# Primitive Types
class Integer(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "integer"
}

class String(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string"
}

class Float(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "number"
}

class Bool(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "boolean"
}

class List(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "array"
}

class Dict(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "object",
}

# GCP Types
class GCSPath(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string",
"pattern": "^gs://.*$"
}

def __init__(self, path_type='', file_type=''):
'''
Args
:param path_type: describes the paths, for example, bucket, directory, file, etc
:param file_type: describes the files, for example, JSON, CSV, etc.
'''
self.path_type = path_type
self.file_type = file_type

class GCRPath(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string",
"pattern": "^.*gcr\\.io/.*$"
}

class GCPRegion(BaseType):
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string"
}

class GCPProjectID(BaseType):
'''GCPProjectID: GCP project id'''
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string"
}

# General Types
class LocalPath(BaseType):
#TODO: add restriction to path
openapi_schema_validator = {
openAPIV3Schema = {
"type": "string"
}

Expand Down Expand Up @@ -121,8 +112,9 @@ def _check_valid_type_dict(payload):
if not isinstance(payload[type_name], dict):
return False
property_types = (int, str, float, bool)
property_value_types = (int, str, float, bool, dict)
for property_name in payload[type_name]:
if not isinstance(property_name, property_types) or not isinstance(payload[type_name][property_name], property_types):
if not isinstance(property_name, property_types) or not isinstance(payload[type_name][property_name], property_value_types):
return False
return True

Expand Down
134 changes: 117 additions & 17 deletions sdk/python/tests/components/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ def test_type_check_all_with_types(self):
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: csv}}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcsUri}
implementation:
Expand All @@ -532,7 +532,7 @@ def test_type_check_all_with_types(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcsUri}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -553,14 +553,14 @@ def test_type_check_all_with_types(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_lacking_types(self):
def test_type_check_with_lacking_types(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: csv}}}
- {name: field_n}
- {name: field_o, type: GcsUri}
implementation:
Expand All @@ -581,7 +581,7 @@ def test_type_check_all_with_lacking_types(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -602,14 +602,14 @@ def test_type_check_all_with_lacking_types(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_inconsistent_types_property_value(self):
def test_type_check_with_inconsistent_types_property_value(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: tsv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: tsv}}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcsUri}
implementation:
Expand All @@ -630,7 +630,7 @@ def test_type_check_all_with_inconsistent_types_property_value(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcsUri}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -652,14 +652,14 @@ def test_type_check_all_with_inconsistent_types_property_value(self):
with self.assertRaises(InconsistentTypeException):
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_inconsistent_types_type_name(self):
def test_type_check_with_inconsistent_types_type_name(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: csv}}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
Expand All @@ -680,7 +680,7 @@ def test_type_check_all_with_inconsistent_types_type_name(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcsUri}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -702,14 +702,14 @@ def test_type_check_all_with_inconsistent_types_type_name(self):
with self.assertRaises(InconsistentTypeException):
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_consistent_types_nonnamed_inputs(self):
def test_type_check_with_consistent_types_nonnamed_inputs(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: csv}}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcsUri}
implementation:
Expand All @@ -730,7 +730,7 @@ def test_type_check_all_with_consistent_types_nonnamed_inputs(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcsUri}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -751,14 +751,14 @@ def test_type_check_all_with_consistent_types_nonnamed_inputs(self):
a = task_factory_a(field_l=12)
b = task_factory_b(a.outputs['field_n'], field_z=a.outputs['field_m'], field_y=a.outputs['field_o'])

def test_type_check_all_with_inconsistent_types_disabled(self):
def test_type_check_with_inconsistent_types_disabled(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_m, type: {ArtifactA: {path_type: file, file_type: csv}}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
Expand All @@ -779,7 +779,7 @@ def test_type_check_all_with_inconsistent_types_disabled(self):
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcsUri}
- {name: field_z, type: {GCSPath: {path_type: file, file_type: csv}}}
- {name: field_z, type: {ArtifactA: {path_type: file, file_type: csv}}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
Expand All @@ -800,5 +800,105 @@ def test_type_check_all_with_inconsistent_types_disabled(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_with_openapi_shema(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
container:
image: gcr.io/ml-pipeline/component-a
command: [python3, /pipelines/component/src/train.py]
args: [
--field-l, {inputValue: field_l},
]
fileOutputs:
field_m: /schema.txt
field_n: /feature.txt
field_o: /output.txt
'''
component_b = '''\
name: component b
description: component b desc
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcrUri}
- {name: field_z, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
container:
image: gcr.io/ml-pipeline/component-a
command: [python3]
args: [
--field-x, {inputValue: field_x},
--field-y, {inputValue: field_y},
--field-z, {inputValue: field_z},
]
fileOutputs:
output_model_uri: /schema.txt
'''
kfp.TYPE_CHECK = True
task_factory_a = comp.load_component_from_text(text=component_a)
task_factory_b = comp.load_component_from_text(text=component_b)
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_ignore_type(self):
component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
container:
image: gcr.io/ml-pipeline/component-a
command: [python3, /pipelines/component/src/train.py]
args: [
--field-l, {inputValue: field_l},
]
fileOutputs:
field_m: /schema.txt
field_n: /feature.txt
field_o: /output.txt
'''
component_b = '''\
name: component b
description: component b desc
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcrUri}
- {name: field_z, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gcs://.*$ } }}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
container:
image: gcr.io/ml-pipeline/component-a
command: [python3]
args: [
--field-x, {inputValue: field_x},
--field-y, {inputValue: field_y},
--field-z, {inputValue: field_z},
]
fileOutputs:
output_model_uri: /schema.txt
'''
kfp.TYPE_CHECK = True
task_factory_a = comp.load_component_from_text(text=component_a)
task_factory_b = comp.load_component_from_text(text=component_b)
a = task_factory_a(field_l=12)
with self.assertRaises(InconsistentTypeException):
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'].ignore_type())

if __name__ == '__main__':
unittest.main()
Loading