Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding pickle support by using serialization #709

Merged
merged 5 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions wrap/gtwrap/matlab_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class MatlabWrapper(object):
}
"""Methods that should not be wrapped directly"""
whitelist = ['serializable', 'serialize']
"""Methods that should be ignored"""
ignore_methods = ['pickle']
"""Datatypes that do not need to be checked in methods"""
not_check_type = []
"""Data types that are primitive types"""
Expand Down Expand Up @@ -563,6 +565,8 @@ def class_comment(self, instantiated_class):
for method in methods:
if method.name in self.whitelist:
continue
if method.name in self.ignore_methods:
continue

comment += '%{name}({args})'.format(name=method.name, args=self._wrap_args(method.args))

Expand Down Expand Up @@ -612,6 +616,9 @@ def wrap_methods(self, methods, globals=False, global_ns=None):
methods = self._group_methods(methods)

for method in methods:
if method in self.ignore_methods:
continue

if globals:
self._debug("[wrap_methods] wrapping: {}..{}={}".format(method[0].parent.name, method[0].name,
type(method[0].parent.name)))
Expand Down Expand Up @@ -861,6 +868,8 @@ def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=[Fal
method_name = method[0].name
if method_name in self.whitelist and method_name != 'serialize':
continue
if method_name in self.ignore_methods:
continue

if method_name == 'serialize':
serialize[0] = True
Expand Down Expand Up @@ -932,6 +941,9 @@ def wrap_static_methods(self, namespace_name, instantiated_class, serialize):
format_name = list(static_method[0].name)
format_name[0] = format_name[0].upper()

if static_method[0].name in self.ignore_methods:
continue

method_text += textwrap.indent(textwrap.dedent('''\
function varargout = {name}(varargin)
'''.format(name=''.join(format_name))),
Expand Down
16 changes: 16 additions & 0 deletions wrap/gtwrap/pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""):
gtsam::deserialize(serialized, *self);
}}, py::arg("serialized"))
'''.format(class_inst=cpp_class + '*'))
if cpp_method == "pickle":
if not cpp_class in self._serializing_classes:
raise ValueError("Cannot pickle a class which is not serializable")
return textwrap.dedent('''
.def(py::pickle(
[](const {cpp_class} &a){{ // __getstate__
/* Returns a string that encodes the state of the object */
return py::make_tuple(gtsam::serialize(a));
}},
[](py::tuple t){{ // __setstate__
{cpp_class} obj;
gtsam::deserialize(t[0].cast<std::string>(), obj);
return obj;
}}))
'''.format(cpp_class=cpp_class))

is_method = isinstance(method, instantiator.InstantiatedMethod)
is_static = isinstance(method, parser.StaticMethod)
Expand Down Expand Up @@ -318,3 +333,4 @@ def wrap(self):
wrapped_namespace=wrapped_namespace,
boost_class_export=boost_class_export,
)

22 changes: 22 additions & 0 deletions wrap/tests/expected-python/geometry_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ PYBIND11_MODULE(geometry_py, m_) {
[](gtsam::Point2* self, string serialized){
gtsam::deserialize(serialized, *self);
}, py::arg("serialized"))

.def(py::pickle(
[](const gtsam::Point2 &a){ // __getstate__
/* Returns a string that encodes the state of the object */
return py::make_tuple(gtsam::serialize(a));
},
[](py::tuple t){ // __setstate__
gtsam::Point2 obj;
gtsam::deserialize(t[0].cast<std::string>(), obj);
return obj;
}))
;

py::class_<gtsam::Point3, std::shared_ptr<gtsam::Point3>>(m_gtsam, "Point3")
Expand All @@ -62,6 +73,17 @@ PYBIND11_MODULE(geometry_py, m_) {
gtsam::deserialize(serialized, *self);
}, py::arg("serialized"))

.def(py::pickle(
[](const gtsam::Point3 &a){ // __getstate__
/* Returns a string that encodes the state of the object */
return py::make_tuple(gtsam::serialize(a));
},
[](py::tuple t){ // __setstate__
gtsam::Point3 obj;
gtsam::deserialize(t[0].cast<std::string>(), obj);
return obj;
}))

.def_static("staticFunction",[](){return gtsam::Point3::staticFunction();})
.def_static("StaticFunctionRet",[]( double z){return gtsam::Point3::StaticFunctionRet(z);}, py::arg("z"));

Expand Down
6 changes: 6 additions & 0 deletions wrap/tests/geometry.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Point2 {
VectorNotEigen vectorConfusion();

void serializable() const; // Sets flag and creates export, but does not make serialization functions

// enable pickling in python
void pickle() const;
};

#include <gtsam/geometry/Point3.h>
Expand All @@ -35,6 +38,9 @@ class Point3 {

// enabling serialization functionality
void serialize() const; // Just triggers a flag internally and removes actual function

// enable pickling in python
void pickle() const;
};

}
Expand Down