-
Notifications
You must be signed in to change notification settings - Fork 308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wip] JSON IDL V2 #2741
[wip] JSON IDL V2 #2741
Changes from 12 commits
158ae01
d31a82d
0adf1ae
b7e2bdb
dcf9b4b
24fe614
70e298c
848bd4b
b145637
e42b667
1782dab
313dc8f
717104e
69d9e8b
b1e4c66
ea45e1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,15 +41,7 @@ | |
from flytekit.models import types as _type_models | ||
from flytekit.models.annotation import TypeAnnotation as TypeAnnotationModel | ||
from flytekit.models.core import types as _core_types | ||
from flytekit.models.literals import ( | ||
Literal, | ||
LiteralCollection, | ||
LiteralMap, | ||
Primitive, | ||
Scalar, | ||
Union, | ||
Void, | ||
) | ||
from flytekit.models.literals import Json, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void | ||
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType | ||
|
||
T = typing.TypeVar("T") | ||
|
@@ -196,6 +188,32 @@ | |
""" | ||
return str(python_val) | ||
|
||
def to_json(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
    """
    Serialize ``python_val`` with ``json.dumps`` and wrap the UTF-8 encoded bytes in a JSON IDL literal.

    ``ctx``, ``python_type`` and ``expected`` are not read by this implementation.
    """
    payload = json.dumps(python_val).encode("UTF-8")
    json_idl_object = Json(value=payload, serialization_format="UTF-8")
    return Literal(scalar=Scalar(json=json_idl_object))
|
||
def from_json(self, ctx: FlyteContext, json_idl_object: Json, expected_python_type: Type[T]) -> T:
    """
    Decode a JSON IDL object and cast the parsed value to ``expected_python_type``.

    :raises TypeTransformerFailedError: if the expected type is datetime, timedelta or date.
    :raises ValueError: if the serialization format is anything other than "UTF-8".
    """
    unsupported_types = [datetime.datetime, datetime.timedelta, datetime.date]
    if expected_python_type in unsupported_types:
        raise TypeTransformerFailedError(
            f"Unsupported Type Error: JSON IDL serialization/deserialization is not supported for Python type "
            f"'{expected_python_type.__name__}'.\n"
            f"Please ensure that the type is serializable or convert it to a supported format."
        )

    raw_bytes = json_idl_object.value
    fmt = json_idl_object.serialization_format
    # Guard clause: UTF-8 is the only format this transformer knows how to decode.
    if fmt != "UTF-8":
        raise ValueError(
            f"Bytes can't be converted to JSON String.\n"
            f"Unsupported serialization format: {fmt}"
        )

    parsed = json.loads(raw_bytes.decode("UTF-8"))
    # Call the expected type on the parsed value so the caller receives an instance of that type.
    return expected_python_type(parsed)  # type: ignore
|
||
def __repr__(self):
    """Return a short description built from this transformer's name and Python type."""
    return "{} Transforms ({}) to Flyte native".format(self._name, self._t)
|
||
|
@@ -240,6 +258,9 @@ | |
f"Cannot convert to type {expected_python_type}, only {self._type} is supported" | ||
) | ||
|
||
if lv.scalar.json: | ||
return self.from_json(ctx, lv.scalar.json, expected_python_type) | ||
|
||
try: # todo(maximsmol): this is quite ugly and each transformer should really check their Literal | ||
res = self._from_literal_transformer(lv) | ||
if type(res) != self._type: | ||
|
@@ -488,9 +509,85 @@ | |
|
||
ts = TypeStructure(tag="", dataclass_type=literal_type) | ||
|
||
return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) | ||
return _type_models.LiteralType(simple=_type_models.SimpleType.JSON, metadata=schema, structure=ts) | ||
|
||
# We use UTF-8 as the default serialization format for JSON | ||
def to_json(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
    """
    Convert a dict or a dataclass instance into a JSON IDL literal holding UTF-8 encoded JSON bytes.

    Plain dicts are dumped with ``json.dumps``. Dataclasses are first passed through
    ``_make_dataclass_json_serializable`` and then serialized either with mashumaro's
    ``DataClassJSONMixin.to_json`` or with a ``JSONEncoder`` cached per Python type.

    :raises TypeTransformerFailedError: if ``python_val`` is neither a dict nor a dataclass.
    :raises NotImplementedError: if the encoder raises NotImplementedError while encoding.
    """
    if isinstance(python_val, dict):
        serialized = json.dumps(python_val)
    else:
        if not dataclasses.is_dataclass(python_val):
            raise TypeTransformerFailedError(
                f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for "
                f"user defined datatypes in Flytekit"
            )

        self._make_dataclass_json_serializable(python_val, python_type)

        # mashumaro's mixin offers more than JSONEncoder. An isinstance check is used instead of
        # hasattr(python_val, "to_json") because Flyte types customize serialization through mashumaro's API.
        if isinstance(python_val, DataClassJSONMixin):
            serialized = python_val.to_json()
        else:
            # Reuse the encoder built for this type, creating and caching it on first use.
            try:
                type_encoder = self._encoder[python_type]
            except KeyError:
                type_encoder = JSONEncoder(python_type)
                self._encoder[python_type] = type_encoder

            try:
                serialized = type_encoder.encode(python_val)
            except NotImplementedError:
                # FlyteFile, FlyteDirectory and StructuredDataset show how Flyte types can be implemented.
                raise NotImplementedError(
                    f"{python_type} should inherit from mashumaro.types.SerializableType"
                    f" and implement _serialize and _deserialize methods."
                )

    # UTF-8 is the default serialization format for JSON.
    json_idl_object = Json(value=serialized.encode("UTF-8"), serialization_format="UTF-8")
    return Literal(scalar=Scalar(json=json_idl_object))
|
||
# We use UTF-8 as the default serialization format for JSON | ||
def from_json(self, ctx: FlyteContext, json_idl_object: Json, expected_python_type: Type[T]) -> T:
    """
    Rebuild a dataclass instance of ``expected_python_type`` from a JSON IDL object.

    The bytes are decoded as UTF-8, deserialized with mashumaro's ``from_json`` when the type
    mixes in ``DataClassJSONMixin`` (otherwise with a ``JSONDecoder`` cached per type), and then
    passed through ``_fix_structured_dataset_type`` and ``_fix_dataclass_int``.

    :raises ValueError: if the serialization format is anything other than "UTF-8".
    """
    raw_bytes = json_idl_object.value
    fmt = json_idl_object.serialization_format

    # Guard clause: UTF-8 is the default, and only supported, serialization format.
    if fmt != "UTF-8":
        raise ValueError(
            f"Bytes can't be converted to JSON String.\n"
            f"Unsupported serialization format: {fmt}"
        )
    text = raw_bytes.decode("UTF-8")

    # An issubclass check is used instead of hasattr(expected_python_type, "from_json") because
    # Flyte types customize deserialization through mashumaro's API.
    if issubclass(expected_python_type, DataClassJSONMixin):
        instance = expected_python_type.from_json(text)  # type: ignore
    else:
        # Reuse the decoder built for this type, creating and caching it on first use.
        try:
            type_decoder = self._decoder[expected_python_type]
        except KeyError:
            type_decoder = JSONDecoder(expected_python_type)
            self._decoder[expected_python_type] = type_decoder

        instance = type_decoder.decode(text)

    instance = self._fix_structured_dataset_type(expected_python_type, instance)
    return self._fix_dataclass_int(expected_python_type, instance)
|
||
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
if expected.simple == SimpleType.JSON: | ||
return self.to_json(ctx, python_val, python_type, expected) | ||
|
||
if isinstance(python_val, dict): | ||
json_str = json.dumps(python_val) | ||
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) | ||
|
@@ -501,7 +598,7 @@ | |
f"user defined datatypes in Flytekit" | ||
) | ||
|
||
self._make_dataclass_serializable(python_val, python_type) | ||
self._make_dataclass_json_serializable(python_val, python_type) | ||
|
||
# The `to_json` integrated through mashumaro's `DataClassJSONMixin` allows for more | ||
# functionality than JSONEncoder | ||
|
@@ -570,7 +667,7 @@ | |
python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val)) | ||
return python_val | ||
|
||
def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any: | ||
def _make_dataclass_json_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any: | ||
""" | ||
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. | ||
""" | ||
|
@@ -581,18 +678,18 @@ | |
if UnionTransformer.is_optional_type(python_type): | ||
if python_val is None: | ||
return None | ||
return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) | ||
return self._make_dataclass_json_serializable(python_val, get_args(python_type)[0]) | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is list: | ||
if python_val is None: | ||
return None | ||
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] | ||
return [self._make_dataclass_json_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: | ||
if python_val is None: | ||
return None | ||
return { | ||
k: self._make_dataclass_serializable(v, get_args(python_type)[1]) | ||
k: self._make_dataclass_json_serializable(v, get_args(python_type)[1]) | ||
for k, v in cast(dict, python_val).items() | ||
} | ||
|
||
|
@@ -618,7 +715,7 @@ | |
dataclass_attributes = typing.get_type_hints(python_type) | ||
for n, t in dataclass_attributes.items(): | ||
val = python_val.__getattribute__(n) | ||
python_val.__setattr__(n, self._make_dataclass_serializable(val, t)) | ||
python_val.__setattr__(n, self._make_dataclass_json_serializable(val, t)) | ||
return python_val | ||
|
||
def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: | ||
|
@@ -672,6 +769,10 @@ | |
"user defined datatypes in Flytekit" | ||
) | ||
|
||
json_idl_object = lv.scalar.json | ||
if json_idl_object: | ||
return self.from_json(ctx, json_idl_object, expected_python_type) # type: ignore | ||
|
||
json_str = _json_format.MessageToJson(lv.scalar.generic) | ||
|
||
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. | ||
|
@@ -1342,7 +1443,54 @@ | |
lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore | ||
return Literal(collection=LiteralCollection(literals=lit_list)) | ||
|
||
def from_json(self, ctx: FlyteContext, json_idl_object: Json, expected_python_type: Type[T]) -> typing.List[T]:
    """
    Convert a JSON IDL object into the Python value described by ``expected_python_type``.

    The payload is decoded as UTF-8, parsed with ``json.loads`` and then walked recursively so
    that nested annotations such as List[List[int]] or List[List[float]] are honoured.

    :param ctx: Flyte context; only forwarded to the recursive helper, which does not read it.
    :param json_idl_object: object exposing ``value`` (decoded with ``.decode``) and ``serialization_format``.
    :param expected_python_type: the annotation the parsed JSON should be converted to.
    :raises ValueError: if the serialization format is not "UTF-8", or if a leaf value cannot be
        cast to its expected type (the original error is attached as ``__cause__``).
    """

    def recursive_from_json(ctx: FlyteContext, json_value: typing.Any, expected_python_type: Type[T]) -> typing.Any:
        """
        Recursively convert ``json_value`` according to ``expected_python_type``.

        Lists and dicts recurse into their element / key / value types; anything else is treated
        as a simple type and constructed by calling the type on the value.
        """
        origin = typing.get_origin(expected_python_type)

        if origin is list:
            # The subtype is the type of the list's elements.
            sub_type = self.get_sub_type(expected_python_type)
            return [recursive_from_json(ctx, item, sub_type) for item in json_value]

        if origin is dict:
            key_type, val_type = typing.get_args(expected_python_type)
            return {
                recursive_from_json(ctx, k, key_type): recursive_from_json(ctx, v, val_type)
                for k, v in json_value.items()
            }

        # Base case: not a list or dict, so assume a simple type and cast to it.
        try:
            return expected_python_type(json_value)
        except Exception as e:
            # Chain explicitly so the underlying failure stays visible as the direct cause.
            raise ValueError(f"Could not cast {json_value} to {expected_python_type}: {e}") from e

    value = json_idl_object.value
    serialization_format = json_idl_object.serialization_format
    if serialization_format != "UTF-8":
        raise ValueError(f"Unknown serialization format {serialization_format}")

    json_value = json.loads(value.decode("utf-8"))

    # Hand off to the recursive helper to handle nested structures.
    return recursive_from_json(ctx, json_value, expected_python_type)
|
||
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore | ||
scalar = lv.scalar | ||
if scalar and scalar.json: | ||
return self.from_json(ctx, scalar.json, expected_python_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @wild-endeavor, propeller gets the list as a JSON string here: https://github.com/flyteorg/flyte/pull/5735/files#diff-ee7f936e440a7e043b3bc7acb4ea255ba991dea8f3144d24ab276c3a292de018R103-R113 |
||
try: | ||
lits = lv.collection.literals | ||
except AttributeError: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't want to reuse the code in `to_literal` and `to_python_value`, because I think keeping this logic separate will be far more readable and easier to customize in the future, when we want to make more flexible changes to the JSON IDL object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't come up with a concrete scenario yet; this is just my instinct.