Skip to content

Commit

Permalink
Merge pull request #36 from pFernbach/topic/pickle
Browse files Browse the repository at this point in the history
[Python] Add pickle support
  • Loading branch information
pFernbach authored Mar 30, 2020
2 parents 7d52657 + b5c0264 commit 4bbc08a
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 13 deletions.
64 changes: 52 additions & 12 deletions python/curves/curves_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "python_variables.h"
#include "archive_python_binding.h"
#include "optimization_python.h"
#include <curves/serialization/curves.hpp>

#include <boost/python.hpp>
#include <boost/python/class.hpp>
Expand Down Expand Up @@ -84,6 +85,29 @@ BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(curve_SE3_t_isEquivalent_overloads, curve

/* end base wrap of curve_abc */

/* Structure used to define pickle serialization of python curves */
template <typename Curve>
struct curve_pickle_suite : pickle_suite {

static object getstate (const Curve& curve) {
std::ostringstream os;
boost::archive::text_oarchive oa(os);
curves::serialization::register_types(oa);
oa << curve;
return str(os.str());
}

static void
setstate(Curve& curve, object entries) {
str s = extract<str> (entries)();
std::string st = extract<std::string> (s)();
std::istringstream is (st);
boost::archive::text_iarchive ia (is);
curves::serialization::register_types(ia);
ia >> curve;
}
};

/* Template constructor bezier */
template <typename Bezier, typename PointList, typename T_Point>
Bezier* wrapBezierConstructorTemplate(const PointList& array, const real T_min = 0., const real T_max = 1.) {
Expand Down Expand Up @@ -482,7 +506,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("saveAsBinary", pure_virtual(&curve_abc_t::saveAsBinary<curve_abc_t>), bp::args("filename"),
"Saves *this inside a binary file.")
.def("loadFromBinary", pure_virtual(&curve_abc_t::loadFromBinary<curve_abc_t>), bp::args("filename"),
"Loads *this from a binary file.");
"Loads *this from a binary file.")
.def_pickle(curve_pickle_suite<curve_abc_t>());

class_<curve_3_t, boost::noncopyable, bases<curve_abc_t>, boost::shared_ptr<curve_3_callback> >("curve3")
.def("__call__", &curve_3_t::operator(), "Evaluate the curve at the given time.",
Expand All @@ -498,7 +523,8 @@ BOOST_PYTHON_MODULE(curves) {
args("self", "N"))
.def("min", &curve_3_t::min, "Get the LOWER bound on interval definition of the curve.")
.def("max", &curve_3_t::max, "Get the HIGHER bound on interval definition of the curve.")
.def("dim", &curve_3_t::dim, "Get the dimension of the curve.");
.def("dim", &curve_3_t::dim, "Get the dimension of the curve.")
.def_pickle(curve_pickle_suite<curve_3_t>());

class_<curve_rotation_t, boost::noncopyable, bases<curve_abc_t>, boost::shared_ptr<curve_rotation_callback> >("curve_rotation")
.def("__call__", &curve_rotation_t::operator(), "Evaluate the curve at the given time.",
Expand All @@ -514,7 +540,8 @@ BOOST_PYTHON_MODULE(curves) {
args("self", "N"))
.def("min", &curve_rotation_t::min, "Get the LOWER bound on interval definition of the curve.")
.def("max", &curve_rotation_t::max, "Get the HIGHER bound on interval definition of the curve.")
.def("dim", &curve_rotation_t::dim, "Get the dimension of the curve.");
.def("dim", &curve_rotation_t::dim, "Get the dimension of the curve.")
.def_pickle(curve_pickle_suite<curve_rotation_t>());

class_<curve_SE3_t, boost::noncopyable, bases<curve_abc_t>, boost::shared_ptr<curve_SE3_callback> >("curve_SE3")
.def("__call__", &se3Return, "Evaluate the curve at the given time. Return as an homogeneous matrix.",
Expand All @@ -535,6 +562,7 @@ BOOST_PYTHON_MODULE(curves) {
args("self", "time"))
.def("translation", &se3returnTranslation, "Output the rotation (as a vector 3) at the given time.",
args("self", "time"))
.def_pickle(curve_pickle_suite<curve_SE3_t>())
#ifdef CURVES_WITH_PINOCCHIO_SUPPORT
.def("evaluateAsSE3", &se3ReturnPinocchio,
"Evaluate the curve at the given time. Return as a pinocchio.SE3 object", args("self", "t"))
Expand Down Expand Up @@ -572,7 +600,8 @@ BOOST_PYTHON_MODULE(curves) {
"Loads *this from a binary file.")
//.def(SerializableVisitor<bezier_t>())
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<bezier3_t>());
/** END bezier3 curve**/
/** BEGIN bezier curve**/
class_<bezier_t, bases<curve_abc_t>, boost::shared_ptr<bezier_t> >("bezier", init<>())
Expand All @@ -599,6 +628,7 @@ BOOST_PYTHON_MODULE(curves) {
.def(bp::self == bp::self)
.def(bp::self != bp::self)
//.def(SerializableVisitor<bezier_t>())
.def_pickle(curve_pickle_suite<bezier_t>())
;
/** END bezier curve**/
/** BEGIN variable points bezier curve**/
Expand Down Expand Up @@ -642,7 +672,8 @@ BOOST_PYTHON_MODULE(curves) {
.def_readonly("degree", &bezier_linear_variable_t::degree_)
.def_readonly("nbWaypoints", &bezier_linear_variable_t::size_)
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<bezier_linear_variable_t>());

class_<quadratic_variable_t>("cost", no_init)
.add_property("A", &cost_t_quad)
Expand Down Expand Up @@ -696,7 +727,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &polynomial_t::loadFromBinary<polynomial_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<polynomial_t>());

/** END polynomial function**/
/** BEGIN piecewise curve function **/
Expand Down Expand Up @@ -758,7 +790,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &piecewise_t::loadFromBinary<piecewise_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<piecewise_t>());

class_<piecewise_bezier_t, bases<curve_abc_t>, boost::shared_ptr<piecewise_bezier_t> >("piecewise_bezier", init<>())
.def("__init__", make_constructor(&wrapPiecewiseBezierConstructor, default_call_policies(), arg("curve")),
Expand Down Expand Up @@ -786,7 +819,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &piecewise_bezier_t::loadFromBinary<piecewise_bezier_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<piecewise_bezier_t>());

class_<piecewise_linear_bezier_t, bases<curve_abc_t>, boost::shared_ptr<piecewise_linear_bezier_t> >("piecewise_bezier_linear", init<>())
.def("__init__", make_constructor(&wrapPiecewiseBezierLinearConstructor, default_call_policies(), arg("curve")),
Expand Down Expand Up @@ -814,7 +848,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &piecewise_linear_bezier_t::loadFromBinary<piecewise_linear_bezier_t>,
bp::args("filename"), "Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<piecewise_linear_bezier_t>());

class_<piecewise_SE3_t, bases<curve_SE3_t>, boost::shared_ptr<piecewise_SE3_t> >("piecewise_SE3", init<>())
.def("__init__", make_constructor(&wrapPiecewiseSE3Constructor, default_call_policies(), arg("curve")),
Expand Down Expand Up @@ -847,6 +882,7 @@ BOOST_PYTHON_MODULE(curves) {
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<piecewise_SE3_t>())
#ifdef CURVES_WITH_PINOCCHIO_SUPPORT
.def("append", &addFinalSE3,
"Append a new linear SE3 curve at the end of the piecewise curve, defined between self.max() "
Expand Down Expand Up @@ -875,7 +911,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &exact_cubic_t::loadFromBinary<exact_cubic_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<exact_cubic_t>());

/** END exact_cubic curve**/
/** BEGIN cubic_hermite_spline **/
Expand All @@ -894,7 +931,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &cubic_hermite_spline_t::loadFromBinary<cubic_hermite_spline_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<cubic_hermite_spline_t>());

/** END cubic_hermite_spline **/
/** BEGIN curve constraints**/
Expand Down Expand Up @@ -941,7 +979,8 @@ BOOST_PYTHON_MODULE(curves) {
.def("loadFromBinary", &SO3Linear_t::loadFromBinary<SO3Linear_t>, bp::args("filename"),
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self);
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<SO3Linear_t>());

/** END SO3 Linear**/
/** BEGIN SE3 Curve**/
Expand Down Expand Up @@ -998,6 +1037,7 @@ BOOST_PYTHON_MODULE(curves) {
"Loads *this from a binary file.")
.def(bp::self == bp::self)
.def(bp::self != bp::self)
.def_pickle(curve_pickle_suite<SE3Curve_t>())
#ifdef CURVES_WITH_PINOCCHIO_SUPPORT
.def("__init__",
make_constructor(&wrapSE3CurveFromSE3Pinocchio, default_call_policies(),
Expand Down
43 changes: 42 additions & 1 deletion python/test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from numpy import array, array_equal, isclose, random, zeros
from numpy.linalg import norm

import pickle
from curves import (CURVES_WITH_PINOCCHIO_SUPPORT, Quaternion, SE3Curve, SO3Linear, bezier, bezier3, convert_to_bezier,
convert_to_hermite, convert_to_polynomial, cubic_hermite_spline, curve_constraints, exact_cubic,
piecewise, piecewise_SE3, polynomial)
Expand Down Expand Up @@ -136,6 +136,10 @@ def test_bezier(self):
b.loadFromText("serialization_curve.test")
self.assertTrue((a(0.4) == b(0.4)).all())
os.remove("serialization_curve.test")

a_pickled = pickle.dumps(a)
a_from_pickle = pickle.loads(a_pickled)
self.assertEqual(a_from_pickle, a)
return

def test_bezier3(self):
Expand Down Expand Up @@ -235,6 +239,9 @@ def test_bezier3(self):
b.loadFromText("serialization_curve.test")
self.assertTrue((a(0.4) == b(0.4)).all())
os.remove("serialization_curve.test")
a_pickled = pickle.dumps(a)
a_from_pickle = pickle.loads(a_pickled)
self.assertEqual(a_from_pickle, a)
return

def test_polynomial(self):
Expand Down Expand Up @@ -263,6 +270,9 @@ def test_polynomial(self):
b.loadFromText("serialization_curve.test")
self.assertTrue((a(0.4) == b(0.4)).all())
os.remove("serialization_curve.test")
a_pickled = pickle.dumps(a)
a_from_pickle = pickle.loads(a_pickled)
self.assertEqual(a_from_pickle, a)
return

def test_polynomial_from_boundary_condition(self):
Expand Down Expand Up @@ -338,6 +348,9 @@ def test_cubic_hermite_spline(self):
b.loadFromText("serialization_curve.test")
self.assertTrue((a(0.4) == b(0.4)).all())
os.remove("serialization_curve.test")
a_pickled = pickle.dumps(a)
a_from_pickle = pickle.loads(a_pickled)
self.assertEqual(a_from_pickle, a)
return

def test_piecewise_polynomial_curve(self):
Expand Down Expand Up @@ -371,6 +384,9 @@ def test_piecewise_polynomial_curve(self):
pc_test.loadFromText("serialization_pc.test")
self.assertTrue((pc(0.4) == pc_test(0.4)).all())
os.remove("serialization_pc.test")
pc_pickled = pickle.dumps(pc)
pc_from_pickle = pickle.loads(pc_pickled)
self.assertEqual(pc_from_pickle, pc)

waypoints3 = array([[1., 2., 3., 0.6, -9.], [-1., 1.6, 1.7, 6.7, 14]]).transpose()
c = polynomial(waypoints3, 3., 5.2)
Expand Down Expand Up @@ -490,6 +506,9 @@ def test_piecewise_bezier_curve(self):
pc_test.loadFromText("serialization_pc.test")
self.assertTrue((pc(0.4) == pc_test(0.4)).all())
os.remove("serialization_pc.test")
pc_pickled = pickle.dumps(pc)
pc_from_pickle = pickle.loads(pc_pickled)
self.assertEqual(pc_from_pickle, pc)
return

def test_piecewise_cubic_hermite_curve(self):
Expand Down Expand Up @@ -523,6 +542,9 @@ def test_piecewise_cubic_hermite_curve(self):
pc_test.loadFromText("serialization_pc.test")
self.assertTrue((pc(0.4) == pc_test(0.4)).all())
os.remove("serialization_pc.test")
pc_pickled = pickle.dumps(pc)
pc_from_pickle = pickle.loads(pc_pickled)
self.assertEqual(pc_from_pickle, pc)
return

def test_exact_cubic(self):
Expand All @@ -545,6 +567,9 @@ def test_exact_cubic(self):
b.loadFromText("serialization_pc.test")
self.assertTrue((a(0.4) == b(0.4)).all())
os.remove("serialization_pc.test")
a_pickled = pickle.dumps(a)
a_from_pickle = pickle.loads(a_pickled)
self.assertEqual(a_from_pickle, a)
return

def test_exact_cubic_constraint(self):
Expand Down Expand Up @@ -666,6 +691,10 @@ def test_piecewise_se3_curve(self):
pc_bin.loadFromBinary("serialization_curve")
self.compareCurves(pc, pc_bin)

pc_pickled = pickle.dumps(pc)
pc_from_pickle = pickle.loads(pc_pickled)
self.assertEqual(pc_from_pickle, pc)

se3_3 = SE3Curve(se3(max), se3_2(max2 - 0.5), max2, max2 + 1.5)
pc.append(se3_3)
self.assertFalse(pc.is_continuous(0))
Expand Down Expand Up @@ -885,6 +914,10 @@ def test_so3_linear_serialization(self):
so3_bin.loadFromBinary("serialization_curve")
self.compareCurves(so3Rot, so3_bin)

so3Rot_pickled = pickle.dumps(so3Rot)
so3Rot_from_pickle = pickle.loads(so3Rot_pickled)
self.assertEqual(so3Rot_from_pickle, so3Rot)

def test_se3_curve_linear(self):
print("test SE3 Linear")
init_quat = Quaternion.Identity()
Expand Down Expand Up @@ -1235,6 +1268,10 @@ def test_se3_serialization(self):
se3_bin.loadFromBinary("serialization_curve")
self.compareCurves(se3_linear, se3_bin)

se3_pickled = pickle.dumps(se3_linear)
se3_from_pickle = pickle.loads(se3_pickled)
self.assertEqual(se3_from_pickle, se3_linear)

# test from two curves :
init_quat = Quaternion.Identity()
end_quat = Quaternion(sqrt(2.) / 2., sqrt(2.) / 2., 0, 0)
Expand Down Expand Up @@ -1262,6 +1299,10 @@ def test_se3_serialization(self):
se3_bin.loadFromBinary("serialization_curve")
self.compareCurves(se3_curves, se3_bin)

se3_pickled = pickle.dumps(se3_curves)
se3_from_pickle = pickle.loads(se3_pickled)
self.assertEqual(se3_from_pickle, se3_curves)

def test_operatorEqual(self):
# test with bezier
waypoints = array([[1., 2., 3.], [4., 5., 6.], [4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]).transpose()
Expand Down

0 comments on commit 4bbc08a

Please sign in to comment.