Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SDK models pickleable #8746

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog.d/20241127_132256_roman_pickle_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Added

- \[SDK\] Model instances can now be pickled
(<https://github.com/cvat-ai/cvat/pull/8746>)
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
{{name}} ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}} [optional]{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}} # noqa: E501
{{/optionalVars}}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{#defaultValue}}
Expand All @@ -32,7 +31,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

self = super(OpenApiModel, cls).__new__(cls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{#defaultValue}}
Expand All @@ -37,7 +36,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', True)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

self = super(OpenApiModel, cls).__new__(cls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
Expand All @@ -39,7 +38,7 @@

_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

{{#requiredVars}}
{{^isReadOnly}}
Expand All @@ -42,7 +41,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
from {{packageName}}.configuration import Configuration

# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
Expand All @@ -45,7 +44,7 @@

_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_configuration = kwargs.pop('_configuration', Configuration())
_configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())

{{> model_templates/invalid_pos_args }}
Expand Down
41 changes: 21 additions & 20 deletions cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ class OpenApiModel(object):
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
return new_inst

def __setstate__(self, state):
# This is the same as the default implementation. We override it,
# because unpickling attempts to access `obj.__setstate__` on an uninitialized
# object, and if this method is not defined, it results in a call to `__getattr__`.
# This fails, because `__getattr__` relies on `self._data_store`, which doesn't
# exist in an uninitialized object.
self.__dict__.update(state)

class ModelSimple(OpenApiModel):
"""the parent class of models whose type != object in their
Expand Down Expand Up @@ -1084,7 +1091,7 @@ def deserialize_file(response_data, configuration, content_disposition=None):
(file_type): the deserialized file which is open
The user is responsible for closing and reading the file
"""
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path)
fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path if configuration else None)
os.close(fd)
os.remove(path)

Expand Down Expand Up @@ -1263,27 +1270,21 @@ def validate_and_convert_types(input_value, required_types_mixed, path_to_item,
input_class_simple = get_simple_class(input_value)
valid_type = is_valid_type(input_class_simple, valid_classes)
if not valid_type:
if (configuration
or (input_class_simple == dict
and dict not in valid_classes)):
# if input_value is not valid_type try to convert it
converted_instance = attempt_convert_item(
input_value,
valid_classes,
path_to_item,
configuration,
spec_property_naming,
key_type=False,
must_convert=True,
check_type=_check_type
)
return converted_instance
else:
raise get_type_error(input_value, path_to_item, valid_classes,
key_type=False)
# if input_value is not valid_type try to convert it
converted_instance = attempt_convert_item(
input_value,
valid_classes,
path_to_item,
configuration,
spec_property_naming,
key_type=False,
must_convert=True,
check_type=_check_type
)
return converted_instance

# input_value's type is in valid_classes
if len(valid_classes) > 1 and configuration:
if len(valid_classes) > 1:
# there are valid classes which are not the current class
valid_classes_coercible = remove_uncoercible(
valid_classes, input_value, spec_property_naming, must_convert=False)
Expand Down
10 changes: 10 additions & 0 deletions tests/python/sdk/test_api_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT

import pickle
from copy import deepcopy

from cvat_sdk import models
Expand Down Expand Up @@ -112,3 +113,12 @@ def test_models_do_not_return_internal_collections():
model_data2 = model.to_dict()

assert DeepDiff(model_data1_original, model_data2) == {}


def test_models_are_pickleable():
model = models.PatchedLabelRequest(id=5, name="person")
pickled_model = pickle.dumps(model)
unpickled_model = pickle.loads(pickled_model)

assert unpickled_model.id == model.id
assert unpickled_model.name == model.name
Loading