Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

well-defined parameter types #978

Merged
merged 14 commits into from
Mar 20, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,12 @@ def _process_args(self, raw_args, argument_inputs):
processed_args = list(map(str, raw_args))
for i, _ in enumerate(processed_args):
# unsanitized_argument_inputs stores a dict: string of sanitized param -> string of unsanitized param
matches = []
matches += _match_serialized_pipelineparam(str(processed_args[i]))
param_tuples = []
param_tuples += _match_serialized_pipelineparam(str(processed_args[i]))
unsanitized_argument_inputs = {}
for x in list(set(matches)):
if len(x) == 3 or (len(x) == 4 and x[3] == ''):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2]))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2]))
elif len(x) == 4:
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(x[1]), K8sHelper.sanitize_k8s_name(x[0]), x[2], TypeMeta.from_dict_or_str(x[3])))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(x[1], x[0], x[2], TypeMeta.from_dict_or_str(x[3])))
for param_tuple in list(set(param_tuples)):
sanitized_str = str(dsl.PipelineParam(K8sHelper.sanitize_k8s_name(param_tuple.name), K8sHelper.sanitize_k8s_name(param_tuple.op), param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
unsanitized_argument_inputs[sanitized_str] = str(dsl.PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value, TypeMeta.deserialize(param_tuple.type)))
if argument_inputs:
for param in argument_inputs:
if str(param) in unsanitized_argument_inputs:
Expand Down
38 changes: 25 additions & 13 deletions sdk/python/kfp/dsl/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,36 @@ def to_dict_or_str(self):
return {self.name: self.properties}

@staticmethod
def from_dict_or_str(payload):
  '''from_dict_or_str accepts a payload object and returns a TypeMeta instance.

  Args:
    payload (str/dict): either a plain type-name string (e.g. 'Integer') or a
      one-entry dict mapping the type name to its property dict.

  Returns:
    TypeMeta: populated from the payload; an empty TypeMeta when the payload
      is neither a str nor a dict.

  Raises:
    ValueError: if a dict payload is not a valid type dict.
  '''
  type_meta = TypeMeta()
  if isinstance(payload, dict):
    if not _check_valid_type_dict(payload):
      # str(payload) is required here: concatenating the dict itself to a
      # str raised TypeError instead of the intended ValueError.
      raise ValueError(str(payload) + ' is not a valid type string')
    type_meta.name, type_meta.properties = list(payload.items())[0]
    # Convert possible OrderedDict to dict
    type_meta.properties = dict(type_meta.properties)
  elif isinstance(payload, str):
    type_meta.name = payload
  return type_meta

def serialize(self):
  """Serialize this TypeMeta to its canonical string representation."""
  payload = self.to_dict_or_str()
  return str(payload)

@staticmethod
def deserialize(payload):
  '''deserialize converts a serialized payload back into a TypeMeta instance.

  Args:
    payload: usually a str; if it is the str() form of a Python dict literal
      it is converted back to a dict before delegating to from_dict_or_str.
  '''
  import ast
  # ast.literal_eval raises ValueError/SyntaxError for strings that are not
  # Python literals, and TypeError for non-string input; the original bare
  # `except:` also swallowed SystemExit/KeyboardInterrupt, which hides bugs.
  try:
    payload = ast.literal_eval(payload)
  except (ValueError, SyntaxError, TypeError):
    # Not a serialized dict — treat the payload as a plain type name.
    pass
  return TypeMeta.from_dict_or_str(payload)

class ParameterMeta(BaseMeta):
def __init__(self,
name: str,
Expand Down Expand Up @@ -128,13 +140,13 @@ def _annotation_to_typemeta(annotation):
TypeMeta
'''
if isinstance(annotation, BaseType):
arg_type = TypeMeta.from_dict_or_str(_instance_to_dict(annotation))
arg_type = TypeMeta.deserialize(_instance_to_dict(annotation))
elif isinstance(annotation, str):
arg_type = TypeMeta.from_dict_or_str(annotation)
arg_type = TypeMeta.deserialize(annotation)
elif isinstance(annotation, dict):
if not _check_valid_type_dict(annotation):
raise ValueError('Annotation ' + str(annotation) + ' is not a valid type dictionary.')
arg_type = TypeMeta.from_dict_or_str(annotation)
arg_type = TypeMeta.deserialize(annotation)
else:
return TypeMeta()
return arg_type
34 changes: 22 additions & 12 deletions sdk/python/kfp/dsl/_pipeline_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,26 @@
# TODO: Move this to a separate class
# For now, this identifies a condition with only "==" operator supported.
ConditionOperator = namedtuple('ConditionOperator', 'operator operand1 operand2')
PipelineParamTuple = namedtuple('PipelineParamTuple', 'name op value type')

def _match_serialized_pipelineparam(payload: str):
"""_match_serialized_pipelineparam matches the serialized pipelineparam.
Args:
payloads (str): a string that contains the serialized pipelineparam.

Returns:
List(tuple())"""
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload)
if len(match) == 0:
match = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload)
return match
PipelineParamTuple
"""
matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?);type=(.*?);}}', payload)
if len(matches) == 0:
matches = re.findall(r'{{pipelineparam:op=([\w\s_-]*);name=([\w\s_-]+);value=(.*?)}}', payload)
param_tuples = []
for match in matches:
if len(match) == 3:
param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=''))
elif len(match) == 4:
param_tuples.append(PipelineParamTuple(name=match[1], op=match[0], value=match[2], type=match[3]))
return param_tuples

# NOTE(review): the annotation `str or list[str]` evaluates to just `str`
# (the `or` short-circuits); kept as-is for interface stability — consider
# typing.Union in a follow-up.
def _extract_pipelineparams(payloads: str or list[str]):
  """_extract_pipelineparams extracts a list of PipelineParam instances from the payload string.

  Args:
    payloads (str or list[str]): a single payload string or a list of
      payloads, each of which may embed serialized pipelineparams.

  Returns:
    List[PipelineParam]: deduplicated params reconstructed from the payloads.
  """
  if isinstance(payloads, str):
    payloads = [payloads]
  param_tuples = []
  for payload in payloads:
    param_tuples += _match_serialized_pipelineparam(payload)
  pipeline_params = []
  # set() dedupes identical serializations before building PipelineParams.
  for param_tuple in set(param_tuples):
    pipeline_params.append(
        PipelineParam(param_tuple.name, param_tuple.op, param_tuple.value,
                      TypeMeta.deserialize(param_tuple.type)))
  return pipeline_params

class PipelineParam(object):
Expand Down Expand Up @@ -136,3 +141,8 @@ def __ge__(self, other):
def __hash__(self):
  """Hash on (op_name, name) so hashing stays consistent with equality."""
  identity = (self.op_name, self.name)
  return hash(identity)

def ignore_type(self):
  """Drop this param's type metadata so downstream type checking passes.

  Returns:
    PipelineParam: self, to allow chaining at the call site.
  """
  self.param_type = TypeMeta()
  return self

34 changes: 13 additions & 21 deletions sdk/python/kfp/dsl/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,72 +19,63 @@ class BaseType:

# Primitive Types
class Integer(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "integer"
  }

class String(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "string"
  }

class Float(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "number"
  }

class Bool(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "boolean"
  }

class List(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "array"
  }

class Dict(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "object",
  }

# GCP Types
class GCSPath(BaseType):
  # OpenAPI v3 schema: a string that must look like a gs:// URI.
  # NOTE(review): the previous __init__(path_type, file_type) was removed by
  # this change; the class now carries only the schema.
  openAPIV3Schema = {
    "type": "string",
    "pattern": "^gs://.*$"
  }

class GCRPath(BaseType):
  # OpenAPI v3 schema: a string containing a gcr.io image path.
  openAPIV3Schema = {
    "type": "string",
    "pattern": "^.*gcr\\.io/.*$"
  }

class GCPRegion(BaseType):
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "string"
  }

class GCPProjectID(BaseType):
  '''MetaGCPProjectID: GCP project id'''
  # OpenAPI v3 schema describing valid values of this type.
  openAPIV3Schema = {
    "type": "string"
  }

# General Types
class LocalPath(BaseType):
  # TODO: add restriction to path
  openAPIV3Schema = {
    "type": "string"
  }

Expand Down Expand Up @@ -121,8 +112,9 @@ def _check_valid_type_dict(payload):
if not isinstance(payload[type_name], dict):
return False
property_types = (int, str, float, bool)
property_value_types = (int, str, float, bool, dict)
for property_name in payload[type_name]:
if not isinstance(property_name, property_types) or not isinstance(payload[type_name][property_name], property_types):
if not isinstance(property_name, property_types) or not isinstance(payload[type_name][property_name], property_value_types):
return False
return True

Expand Down
110 changes: 105 additions & 5 deletions sdk/python/tests/components/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def test_type_check_all_with_types(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_lacking_types(self):
def test_type_check_with_lacking_types(self):
component_a = '''\
name: component a
description: component a desc
Expand Down Expand Up @@ -602,7 +602,7 @@ def test_type_check_all_with_lacking_types(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_inconsistent_types_property_value(self):
def test_type_check_with_inconsistent_types_property_value(self):
component_a = '''\
name: component a
description: component a desc
Expand Down Expand Up @@ -652,7 +652,7 @@ def test_type_check_all_with_inconsistent_types_property_value(self):
with self.assertRaises(InconsistentTypeException):
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_inconsistent_types_type_name(self):
def test_type_check_with_inconsistent_types_type_name(self):
component_a = '''\
name: component a
description: component a desc
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_type_check_all_with_inconsistent_types_type_name(self):
with self.assertRaises(InconsistentTypeException):
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_all_with_consistent_types_nonnamed_inputs(self):
def test_type_check_with_consistent_types_nonnamed_inputs(self):
component_a = '''\
name: component a
description: component a desc
Expand Down Expand Up @@ -751,7 +751,7 @@ def test_type_check_all_with_consistent_types_nonnamed_inputs(self):
a = task_factory_a(field_l=12)
b = task_factory_b(a.outputs['field_n'], field_z=a.outputs['field_m'], field_y=a.outputs['field_o'])

def test_type_check_all_with_inconsistent_types_disabled(self):
def test_type_check_with_inconsistent_types_disabled(self):
component_a = '''\
name: component a
description: component a desc
Expand Down Expand Up @@ -800,5 +800,105 @@ def test_type_check_all_with_inconsistent_types_disabled(self):
a = task_factory_a(field_l=12)
b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_with_openapi_schema(self):
  """Type check passes when matching output/input types carry identical
  openAPIV3Schema dicts (gs:// pattern on both GCSPath ends).

  NOTE(review): renamed from test_type_check_with_openapi_shema — 'shema'
  was a typo; unittest discovers the method by its test_ prefix either way.
  """
  component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
  container:
    image: gcr.io/ml-pipeline/component-a
    command: [python3, /pipelines/component/src/train.py]
    args: [
      --field-l, {inputValue: field_l},
    ]
    fileOutputs:
      field_m: /schema.txt
      field_n: /feature.txt
      field_o: /output.txt
'''
  component_b = '''\
name: component b
description: component b desc
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcrUri}
- {name: field_z, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
  container:
    image: gcr.io/ml-pipeline/component-a
    command: [python3]
    args: [
      --field-x, {inputValue: field_x},
      --field-y, {inputValue: field_y},
      --field-z, {inputValue: field_z},
    ]
    fileOutputs:
      output_model_uri: /schema.txt
'''
  kfp.TYPE_CHECK = True
  task_factory_a = comp.load_component_from_text(text=component_a)
  task_factory_b = comp.load_component_from_text(text=component_b)
  a = task_factory_a(field_l=12)
  # All three wired types match exactly, so no exception is expected.
  b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])

def test_type_check_ignore_type(self):
  """Type check fails across mismatched GCSPath patterns (gs:// vs gcs://),
  then passes once ignore_type() strips the output's type metadata."""
  component_a = '''\
name: component a
description: component a desc
inputs:
- {name: field_l, type: Integer}
outputs:
- {name: field_m, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gs://.*$ } }}}
- {name: field_n, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_o, type: GcrUri}
implementation:
  container:
    image: gcr.io/ml-pipeline/component-a
    command: [python3, /pipelines/component/src/train.py]
    args: [
      --field-l, {inputValue: field_l},
    ]
    fileOutputs:
      field_m: /schema.txt
      field_n: /feature.txt
      field_o: /output.txt
'''
  component_b = '''\
name: component b
description: component b desc
inputs:
- {name: field_x, type: {customized_type: {property_a: value_a, property_b: value_b}}}
- {name: field_y, type: GcrUri}
- {name: field_z, type: {GCSPath: {openAPIV3Schema: {type: string, pattern: ^gcs://.*$ } }}}
outputs:
- {name: output_model_uri, type: GcsUri}
implementation:
  container:
    image: gcr.io/ml-pipeline/component-a
    command: [python3]
    args: [
      --field-x, {inputValue: field_x},
      --field-y, {inputValue: field_y},
      --field-z, {inputValue: field_z},
    ]
    fileOutputs:
      output_model_uri: /schema.txt
'''
  kfp.TYPE_CHECK = True
  task_factory_a = comp.load_component_from_text(text=component_a)
  task_factory_b = comp.load_component_from_text(text=component_b)
  a = task_factory_a(field_l=12)
  # field_m's gs:// pattern conflicts with field_z's gcs:// pattern.
  with self.assertRaises(InconsistentTypeException):
    b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'])
  # ignore_type() clears the type metadata, so the same wiring now passes.
  b = task_factory_b(field_x=a.outputs['field_n'], field_y=a.outputs['field_o'], field_z=a.outputs['field_m'].ignore_type())

# Allow running this test module directly with `python test_components.py`.
if __name__ == '__main__':
  unittest.main()
Loading