-
Notifications
You must be signed in to change notification settings - Fork 312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wip] JSON IDL V2 #2741
[wip] JSON IDL V2 #2741
Changes from 1 commit
158ae01
d31a82d
0adf1ae
b7e2bdb
dcf9b4b
24fe614
70e298c
848bd4b
b145637
e42b667
1782dab
313dc8f
717104e
69d9e8b
b1e4c66
ea45e1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,7 @@ | |
Scalar, | ||
Union, | ||
Void, | ||
Json | ||
) | ||
from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType | ||
|
||
|
@@ -196,6 +197,26 @@ def to_html(self, ctx: FlyteContext, python_val: T, expected_python_type: Type[T | |
""" | ||
return str(python_val) | ||
|
||
def to_json(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
json_str = json.dumps(python_val) | ||
serialization_format = expected.simple.JSON | ||
json_bytes = None | ||
if serialization_format == "UTF-8": | ||
json_bytes = json_str.encode("UTF-8") | ||
if json_bytes is None: | ||
json_bytes = json_str.encode("UTF-8") # default to UTF-8 | ||
return Literal(scalar=Scalar(json=Json(value=json_bytes, serialization_format="UTF-8"))) | ||
|
||
def from_json(self, ctx: FlyteContext, json_idl_object: Json, expected_python_type: Type[T]) -> T: | ||
value = json_idl_object.value | ||
serialization_format = json_idl_object.serialization_format | ||
json_str = None | ||
if serialization_format == "UTF-8": | ||
json_str = value.decode("UTF-8") | ||
if json_str is None: | ||
json_str = value.decode("UTF-8") # default to UTF-8 | ||
return json.loads(json_str) | ||
|
||
def __repr__(self): | ||
return f"{self._name} Transforms ({self._t}) to Flyte native" | ||
|
||
|
@@ -488,9 +509,83 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: | |
|
||
ts = TypeStructure(tag="", dataclass_type=literal_type) | ||
|
||
return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT, metadata=schema, structure=ts) | ||
return _type_models.LiteralType(simple=_type_models.SimpleType.JSON, metadata=schema, structure=ts) | ||
|
||
# We use UTF-8 as the default serialization format for JSON | ||
def to_json(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
if isinstance(python_val, dict): | ||
json_str = json.dumps(python_val) | ||
json_bytes = json_str.encode("UTF-8") | ||
return Literal(scalar=Scalar(json=Json(value=json_bytes, serialization_format="UTF-8"))) | ||
|
||
if not dataclasses.is_dataclass(python_val): | ||
Comment on lines
+521
to
+527
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't want to reuse the code in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I haven't come up with the scenario, but this is just my instinct. |
||
raise TypeTransformerFailedError( | ||
f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " | ||
f"user defined datatypes in Flytekit" | ||
) | ||
|
||
self._make_dataclass_json_serializable(python_val, python_type) | ||
|
||
# The `to_json` integrated through mashumaro's `DataClassJSONMixin` allows for more | ||
# functionality than JSONEncoder | ||
# We can't use hasattr(python_val, "to_json") here because we rely on mashumaro's API to customize the serialization behavior for Flyte types. | ||
if isinstance(python_val, DataClassJSONMixin): | ||
json_str = python_val.to_json() | ||
else: | ||
# The function looks up or creates a JSONEncoder specifically designed for the object's type. | ||
# This encoder is then used to convert a data class into a JSON string. | ||
try: | ||
encoder = self._encoder[python_type] | ||
except KeyError: | ||
encoder = JSONEncoder(python_type) | ||
self._encoder[python_type] = encoder | ||
|
||
try: | ||
json_str = encoder.encode(python_val) | ||
except NotImplementedError: | ||
# you can refer FlyteFile, FlyteDirectory and StructuredDataset to see how flyte types can be implemented. | ||
raise NotImplementedError( | ||
f"{python_type} should inherit from mashumaro.types.SerializableType" | ||
f" and implement _serialize and _deserialize methods." | ||
) | ||
|
||
json_bytes = json_str.encode("UTF-8") | ||
return Literal(scalar=Scalar(json=Json(value=json_bytes, serialization_format="UTF-8"))) | ||
|
||
# We use UTF-8 as the default serialization format for JSON | ||
def from_json(self, ctx: FlyteContext, json_idl_object: Json, expected_python_type: Type[T]) -> T: | ||
value = json_idl_object.value | ||
serialization_format = json_idl_object.serialization_format | ||
json_str = None | ||
|
||
if serialization_format == "UTF-8": | ||
json_str = value.decode("UTF-8") | ||
if json_str is None: | ||
json_str = value.decode("UTF-8") # default to UTF-8 | ||
|
||
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. | ||
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder | ||
# We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types. | ||
if issubclass(expected_python_type, DataClassJSONMixin): | ||
dc = expected_python_type.from_json(json_str) # type: ignore | ||
else: | ||
# The function looks up or creates a JSONDecoder specifically designed for the object's type. | ||
# This decoder is then used to convert a JSON string into a data class. | ||
try: | ||
decoder = self._decoder[expected_python_type] | ||
except KeyError: | ||
decoder = JSONDecoder(expected_python_type) | ||
self._decoder[expected_python_type] = decoder | ||
|
||
dc = decoder.decode(json_str) | ||
|
||
dc = self._fix_structured_dataset_type(expected_python_type, dc) | ||
return self._fix_dataclass_int(expected_python_type, dc) | ||
|
||
def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: | ||
if expected.simple == SimpleType.JSON: | ||
return self.to_json(ctx, python_val, python_type, expected) | ||
|
||
if isinstance(python_val, dict): | ||
json_str = json.dumps(python_val) | ||
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) | ||
|
@@ -570,7 +665,7 @@ def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing. | |
python_val.__setattr__(field.name, self._fix_structured_dataset_type(field.type, val)) | ||
return python_val | ||
|
||
def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any: | ||
def _make_dataclass_json_serializable(self, python_val: T, python_type: Type[T]) -> typing.Any: | ||
""" | ||
If any field inside the dataclass is flyte type, we should use flyte type transformer for that field. | ||
""" | ||
|
@@ -581,18 +676,18 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t | |
if UnionTransformer.is_optional_type(python_type): | ||
if python_val is None: | ||
return None | ||
return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) | ||
return self._make_dataclass_json_serializable(python_val, get_args(python_type)[0]) | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is list: | ||
if python_val is None: | ||
return None | ||
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] | ||
return [self._make_dataclass_json_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: | ||
if python_val is None: | ||
return None | ||
return { | ||
k: self._make_dataclass_serializable(v, get_args(python_type)[1]) | ||
k: self._make_dataclass_json_serializable(v, get_args(python_type)[1]) | ||
for k, v in cast(dict, python_val).items() | ||
} | ||
|
||
|
@@ -618,7 +713,7 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t | |
dataclass_attributes = typing.get_type_hints(python_type) | ||
for n, t in dataclass_attributes.items(): | ||
val = python_val.__getattribute__(n) | ||
python_val.__setattr__(n, self._make_dataclass_serializable(val, t)) | ||
python_val.__setattr__(n, self._make_dataclass_json_serializable(val, t)) | ||
return python_val | ||
|
||
def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any: | ||
|
@@ -672,6 +767,10 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: | |
"user defined datatypes in Flytekit" | ||
) | ||
|
||
json_idl_object = lv.scalar.json | ||
if json_idl_object: | ||
return self.from_json(ctx, json_idl_object, expected_python_type) | ||
|
||
json_str = _json_format.MessageToJson(lv.scalar.generic) | ||
|
||
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do this later, thank you.