From c5eb449e08022363a431c507f7ca5410190a195b Mon Sep 17 00:00:00 2001 From: Ayush Baid Date: Mon, 8 Mar 2021 14:58:50 -0500 Subject: [PATCH 1/3] Squashed 'wrap/' changes from b28b3570d..d37b8a972 d37b8a972 Merge pull request #31 from ayushbaid/feature/pickle efd4a0fb4 removing tab 42fd231f3 adding newline for test compare 0aa316150 removing extra brackets 7fe1d7d0f adding pickle code to expected file 9c3ab7a8b fixing string format for new classname 2f89284e8 moving pickle support with the serialization code 5a8abc916 adding pickle support for select classes using serialization git-subtree-dir: wrap git-subtree-split: d37b8a972f6f3aa99a657831102fc22a73b59f21 --- gtwrap/pybind_wrapper.py | 13 ++++++++++++- tests/expected-python/geometry_pybind.cpp | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index c0e88e37af..ec54807857 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -75,7 +75,17 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized")) - '''.format(class_inst=cpp_class + '*')) + .def(py::pickle( + [](const {cpp_class} &a){{ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }}, + [](py::tuple t){{ // __setstate__ + {cpp_class} obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + }})) + '''.format(class_inst=cpp_class + '*', cpp_class=cpp_class)) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) @@ -318,3 +328,4 @@ def wrap(self): wrapped_namespace=wrapped_namespace, boost_class_export=boost_class_export, ) + diff --git a/tests/expected-python/geometry_pybind.cpp b/tests/expected-python/geometry_pybind.cpp index 3eee55bf44..53d33eabf3 100644 --- a/tests/expected-python/geometry_pybind.cpp +++ b/tests/expected-python/geometry_pybind.cpp @@ -47,6 +47,16 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point2* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) +.def(py::pickle( + [](const gtsam::Point2 &a){ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }, + [](py::tuple t){ // __setstate__ + gtsam::Point2 obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + })) ; py::class_>(m_gtsam, "Point3") @@ -61,6 +71,16 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point3* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) +.def(py::pickle( + [](const gtsam::Point3 &a){ // __getstate__ + /* Returns a string that encodes the state of the object */ + return py::make_tuple(gtsam::serialize(a)); + }, + [](py::tuple t){ // __setstate__ + gtsam::Point3 obj; + gtsam::deserialize(t[0].cast(), obj); + return obj; + })) .def_static("staticFunction",[](){return gtsam::Point3::staticFunction();}) .def_static("StaticFunctionRet",[]( double z){return gtsam::Point3::StaticFunctionRet(z);}, py::arg("z")); From afc9333a17167aed7a351d763e8b6fd06eb3bb3b Mon Sep 17 00:00:00 2001 From: Ayush Baid Date: Mon, 8 Mar 2021 19:53:26 -0500 Subject: [PATCH 2/3] Squashed 'wrap/' changes from d37b8a972..10e1efd6f 10e1efd6f Merge pull request #32 from ayushbaid/feature/pickle 55d5d7fbe ignoring pickle in matlab wrap 8d70c7fe2 adding newlines dee8aaee3 adding markers for pickling 46fc45d82 separating out the marker for pickle git-subtree-dir: wrap git-subtree-split: 10e1efd6f25f76b774868b3da33cb3954008b233 --- gtwrap/matlab_wrapper.py | 12 ++++++++++++ gtwrap/pybind_wrapper.py | 7 ++++++- tests/expected-python/geometry_pybind.cpp | 2 ++ tests/geometry.h | 6 ++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/gtwrap/matlab_wrapper.py b/gtwrap/matlab_wrapper.py index 669bf474f0..fe4ee7e191 100755 --- a/gtwrap/matlab_wrapper.py +++ b/gtwrap/matlab_wrapper.py @@ -49,6 +49,8 @@ class MatlabWrapper(object): } """Methods that should not be wrapped directly""" whitelist = ['serializable', 'serialize'] + """Methods that should be ignored""" + ignore_methods = ['pickle'] """Datatypes that do not need to be checked in methods""" not_check_type = [] """Data types that are primitive types""" @@ -563,6 +565,8 @@ def class_comment(self, instantiated_class): for method in methods: if method.name in self.whitelist: continue + if method.name in self.ignore_methods: + continue comment += '%{name}({args})'.format(name=method.name, args=self._wrap_args(method.args)) @@ -612,6 +616,9 @@ def wrap_methods(self, methods, globals=False, global_ns=None): methods = self._group_methods(methods) for method in methods: + if method in self.ignore_methods: + continue + if globals: self._debug("[wrap_methods] wrapping: {}..{}={}".format(method[0].parent.name, method[0].name, type(method[0].parent.name))) @@ -861,6 +868,8 @@ def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=[Fal method_name = method[0].name if method_name in self.whitelist and method_name != 'serialize': continue + if method_name in self.ignore_methods: + continue if method_name == 'serialize': serialize[0] = True @@ -932,6 +941,9 @@ def wrap_static_methods(self, namespace_name, instantiated_class, serialize): format_name = list(static_method[0].name) format_name[0] = format_name[0].upper() + if static_method[0].name in self.ignore_methods: + continue + method_text += textwrap.indent(textwrap.dedent('''\ function varargout = {name}(varargin) '''.format(name=''.join(format_name))), diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index ec54807857..a045afcbd6 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -75,6 +75,11 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized")) + '''.format(class_inst=cpp_class + '*')) + if cpp_method == "pickle": + if not cpp_class in self._serializing_classes: + raise ValueError("Cannot pickle a class which is not serializable") + return textwrap.dedent(''' .def(py::pickle( [](const {cpp_class} &a){{ // __getstate__ /* Returns a string that encodes the state of the object */ @@ -85,7 +90,7 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): gtsam::deserialize(t[0].cast(), obj); return obj; }})) - '''.format(class_inst=cpp_class + '*', cpp_class=cpp_class)) + '''.format(cpp_class=cpp_class)) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) diff --git a/tests/expected-python/geometry_pybind.cpp b/tests/expected-python/geometry_pybind.cpp index 53d33eabf3..be6482d89a 100644 --- a/tests/expected-python/geometry_pybind.cpp +++ b/tests/expected-python/geometry_pybind.cpp @@ -47,6 +47,7 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point2* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) + .def(py::pickle( [](const gtsam::Point2 &a){ // __getstate__ /* Returns a string that encodes the state of the object */ @@ -71,6 +72,7 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point3* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) + .def(py::pickle( [](const gtsam::Point3 &a){ // __getstate__ /* Returns a string that encodes the state of the object */ diff --git a/tests/geometry.h b/tests/geometry.h index 40d878c9fc..ec5d3b2774 100644 --- a/tests/geometry.h +++ b/tests/geometry.h @@ -22,6 +22,9 @@ class Point2 { VectorNotEigen vectorConfusion(); void serializable() const; // Sets flag and creates export, but does not make serialization functions + + // enable pickling in python + void pickle() const; }; #include @@ -35,6 +38,9 @@ class Point3 { // enabling serialization functionality void serialize() const; // Just triggers a flag internally and removes actual function + + // enable pickling in python + void pickle() const; }; } From 1670e68e2f5548722837072a7176128e23fe492b Mon Sep 17 00:00:00 2001 From: Ayush Baid Date: Mon, 8 Mar 2021 20:18:09 -0500 Subject: [PATCH 3/3] enabling markers and testing pickle roundtrip for few classes --- gtsam/gtsam.i | 100 ++++++++++++++++++++++++++++++ python/gtsam/tests/test_pickle.py | 46 ++++++++++++++ python/gtsam/utils/test_case.py | 12 ++++ 3 files changed, 158 insertions(+) create mode 100644 python/gtsam/tests/test_pickle.py diff --git a/gtsam/gtsam.i b/gtsam/gtsam.i index 22c2cc17df..85209ddfa8 100644 --- a/gtsam/gtsam.i +++ b/gtsam/gtsam.i @@ -105,6 +105,9 @@ * virtual class MyFactor : gtsam::NoiseModelFactor {...}; * - *DO NOT* re-define overriden function already declared in the external (forward-declared) base class * - This will cause an ambiguity problem in Pybind header file + * Pickle support in Python: + * - Add "void pickle()" to a class to enable pickling via gtwrap. In the current implementation, "void serialize()" + * and a public constructor with no-arguments in needed for successful build. */ /** @@ -144,6 +147,9 @@ class KeyList { void remove(size_t key); void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a FastSet @@ -169,6 +175,9 @@ class KeySet { bool count(size_t key) const; // returns true if value exists void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a vector @@ -190,6 +199,9 @@ class KeyVector { void push_back(size_t key) const; void serialize() const; + + // enable pickling in python + void pickle() const; }; // Actually a FastMap @@ -361,6 +373,9 @@ class Point2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; // std::vector @@ -422,6 +437,9 @@ class StereoPoint2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -446,6 +464,9 @@ class Point3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -491,6 +512,9 @@ class Rot2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -653,6 +677,9 @@ class Rot3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -708,6 +735,9 @@ class Pose2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -764,6 +794,9 @@ class Pose3 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; // std::vector @@ -797,6 +830,15 @@ class Unit3 { size_t dim() const; gtsam::Unit3 retract(Vector v) const; Vector localCoordinates(const gtsam::Unit3& s) const; + + // enabling serialization functionality + void serialize() const; + + // enable pickling in python + void pickle() const; + + // enabling function to compare objects + bool equals(const gtsam::Unit3& expected, double tol) const; }; #include @@ -856,6 +898,9 @@ class Cal3_S2 { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -884,6 +929,9 @@ virtual class Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -905,6 +953,9 @@ virtual class Cal3DS2 : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -931,6 +982,9 @@ virtual class Cal3Unified : gtsam::Cal3DS2_Base { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -988,6 +1042,9 @@ class Cal3Bundler { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1018,6 +1075,9 @@ class CalibratedCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1056,6 +1116,9 @@ class PinholeCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; // Forward declaration of PinholeCameraCalX is defined here. @@ -1103,6 +1166,9 @@ class StereoCamera { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1557,6 +1623,9 @@ class VectorValues { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1618,6 +1687,9 @@ virtual class JacobianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1649,6 +1721,9 @@ virtual class HessianFactor : gtsam::GaussianFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -1728,6 +1803,9 @@ class GaussianFactorGraph { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2033,6 +2111,9 @@ class Ordering { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2071,6 +2152,10 @@ class NonlinearFactorGraph { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; + void saveGraph(const string& s) const; }; @@ -2128,6 +2213,9 @@ class Values { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // New in 4.0, we have to specialize every insert/update/at to generate wrappers // Instead of the old: // void insert(size_t j, const gtsam::Value& value); @@ -2527,6 +2615,9 @@ virtual class PriorFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; @@ -2538,6 +2629,9 @@ virtual class BetweenFactor : gtsam::NoiseModelFactor { // enabling serialization functionality void serialize() const; + + // enable pickling in python + void pickle() const; }; #include @@ -2774,6 +2868,9 @@ class SfmTrack { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // enabling function to compare objects bool equals(const gtsam::SfmTrack& expected, double tol) const; }; @@ -2790,6 +2887,9 @@ class SfmData { // enabling serialization functionality void serialize() const; + // enable pickling in python + void pickle() const; + // enabling function to compare objects bool equals(const gtsam::SfmData& expected, double tol) const; }; diff --git a/python/gtsam/tests/test_pickle.py b/python/gtsam/tests/test_pickle.py new file mode 100644 index 0000000000..0acbf6765e --- /dev/null +++ b/python/gtsam/tests/test_pickle.py @@ -0,0 +1,46 @@ +""" +GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests to check pickling. + +Author: Ayush Baid +""" +from gtsam import Cal3Bundler, PinholeCameraCal3Bundler, Point2, Point3, Pose3, Rot3, SfmTrack, Unit3 + +from gtsam.utils.test_case import GtsamTestCase + +class TestPickle(GtsamTestCase): + """Tests pickling on some of the classes.""" + + def test_cal3Bundler_roundtrip(self): + obj = Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_pinholeCameraCal3Bundler_roundtrip(self): + obj = PinholeCameraCal3Bundler( + Pose3(Rot3.RzRyRx(0, 0.1, -0.05), Point3(1, 1, 0)), + Cal3Bundler(fx=100, k1=0.1, k2=0.2, u0=100, v0=70), + ) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_rot3_roundtrip(self): + obj = Rot3.RzRyRx(0, 0.05, 0.1) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_pose3_roundtrip(self): + obj = Pose3(Rot3.Ypr(0.0, 1.0, 0.0), Point3(1, 1, 0)) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_sfmTrack_roundtrip(self): + obj = SfmTrack(Point3(1, 1, 0)) + obj.add_measurement(0, Point2(-1, 5)) + obj.add_measurement(1, Point2(6, 2)) + self.assertEqualityOnPickleRoundtrip(obj) + + def test_unit3_roundtrip(self): + obj = Unit3(Point3(1, 1, 0)) + self.assertEqualityOnPickleRoundtrip(obj) diff --git a/python/gtsam/utils/test_case.py b/python/gtsam/utils/test_case.py index 3effd7f658..50af004f43 100644 --- a/python/gtsam/utils/test_case.py +++ b/python/gtsam/utils/test_case.py @@ -8,6 +8,7 @@ TestCase class with GTSAM assert utils. Author: Frank Dellaert """ +import pickle import unittest @@ -29,3 +30,14 @@ def gtsamAssertEquals(self, actual, expected, tol=1e-9): if not equal: raise self.failureException( "Values are not equal:\n{}!={}".format(actual, expected)) + + def assertEqualityOnPickleRoundtrip(self, obj: object, tol=1e-9) -> None: + """ Performs a round-trip using pickle and asserts equality. + + Usage: + self.assertEqualityOnPickleRoundtrip(obj) + Keyword Arguments: + tol {float} -- tolerance passed to 'equals', default 1e-9 + """ + roundTripObj = pickle.loads(pickle.dumps(obj)) + self.gtsamAssertEquals(roundTripObj, obj)