Skip to content

Commit

Permalink
Merge pull request #7 from ami-iit/slerp
Browse files Browse the repository at this point in the history
Add slerp and override getattr method
  • Loading branch information
Giulero authored Mar 6, 2023
2 parents dc549fd + c974075 commit a3a1d9b
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 10 deletions.
11 changes: 7 additions & 4 deletions examples/manifold_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# conda install matplotlib scipy

import casadi as cs
import matplotlib

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation
Expand All @@ -25,13 +27,13 @@
for k in range(N):
vector_SO3 = SO3Tangent(vel[k] * dt)
rotation_SO3 = SO3(quat[k])
opti.subject_to(quat[k + 1] == (vector_SO3 + rotation_SO3).as_quat().coeffs())
opti.subject_to(quat[k + 1] == (vector_SO3 + rotation_SO3).as_quat())


C = sum(cs.sumsqr(vel[i]) for i in range(N)) + T

# Initial rotation and velocity
opti.subject_to(quat[0] == SO3.Identity().as_quat().coeffs())
opti.subject_to(quat[0] == SO3.Identity().as_quat())
opti.subject_to(vel[0] == 0)
opti.subject_to(opti.bounded(0, T, 10))

Expand All @@ -47,7 +49,7 @@
opti.subject_to(vel[N - 1] == 0)
final_delta_increment = SO3Tangent([cs.pi / 3, cs.pi / 6, cs.pi / 2])

opti.subject_to(quat[N] == (final_delta_increment + SO3.Identity()).as_quat().coeffs())
opti.subject_to(quat[N] == (final_delta_increment + SO3.Identity()).as_quat())

opti.minimize(C)

Expand All @@ -71,7 +73,8 @@
plt.plot(np.linspace(0, time, N), v)

figure = plt.figure()
axes = mplot3d.Axes3D(figure)
axes = figure.add_subplot(projection="3d")

x_cords = np.array([1, 0, 0])
y_cords = np.array([0, 1, 0])
z_cords = np.array([0, 0, 1])
Expand Down
90 changes: 90 additions & 0 deletions examples/slerp_so3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Please note that for running this example you need to install `matplotlib` and `scipy`.
# You can do this by running the following command in your terminal:
# pip install matplotlib scipy
# If you are using anaconda, you can also run the following command:
# conda install matplotlib scipy

import casadi as cs
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

from liecasadi import SO3, SO3Tangent

N = 10

r1 = SO3.Identity()

final_delta_increment = SO3Tangent([cs.pi / 3, cs.pi / 6, cs.pi / 2])

r2 = final_delta_increment + SO3.Identity()

x = SO3.slerp(r1, r2, N)

# If you want to work directly with quaternion, you can use the following code:
# x = Quaternion.slerp(q1, q1, N)
# where q1 and q2 are Quaternion objects.

figure = plt.figure()
axes = figure.add_subplot(projection="3d")
x_cords = np.array([1, 0, 0])
y_cords = np.array([0, 1, 0])
z_cords = np.array([0, 0, 1])

axes.set_box_aspect((1, 1, 1))

(xax,) = axes.plot([0, 1], [0, 0], [0, 0], "red")
(yax,) = axes.plot([0, 0], [0, 1], [0, 0], "green")
(zax,) = axes.plot([0, 0], [0, 0], [0, 1], "blue")

print("qui", x[N - 1].act(x_cords))

# final orientation
x_N = np.array(x[N - 1].act(x_cords)).reshape(
3,
)
y_N = np.array(x[N - 1].act(y_cords)).reshape(
3,
)
z_N = np.array(x[N - 1].act(z_cords)).reshape(
3,
)

(xaxN,) = axes.plot([0, x_N[0]], [0, x_N[1]], [0, x_N[2]], "red")
(yaxN,) = axes.plot([0, y_N[0]], [0, y_N[1]], [0, y_N[2]], "green")
(zaxN,) = axes.plot([0, z_N[0]], [0, z_N[1]], [0, z_N[2]], "blue")


def update_points(i):
x_i = np.array(x[i].act(x_cords)).reshape(
3,
)
y_i = np.array(x[i].act(y_cords)).reshape(
3,
)
z_i = np.array(x[i].act(z_cords)).reshape(
3,
)
# update properties
xax.set_data(np.array([[0, x_i[0]], [0, x_i[1]]]))
xax.set_3d_properties(np.array([0, x_i[2]]), "z")

yax.set_data(np.array([[0, y_i[0]], [0, y_i[1]]]))
yax.set_3d_properties(np.array([0, y_i[2]]), "z")

zax.set_data(np.array([[0, z_i[0]], [0, z_i[1]]]))
zax.set_3d_properties(np.array([0, z_i[2]]), "z")

# return modified axis
return (
xax,
yax,
zax,
)


ani = animation.FuncAnimation(figure, update_points, frames=N, repeat=False)
writergif = animation.PillowWriter(fps=5)
ani.save("animation.gif", writer=writergif)

plt.show()
46 changes: 46 additions & 0 deletions src/liecasadi/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# GNU Lesser General Public License v2.1 or any later version.

import dataclasses
from typing import List

import casadi as cs

Expand All @@ -13,6 +14,9 @@
class Quaternion:
xyzw: Vector

def __getattr__(self, attr):
return getattr(self.xyzw, attr)

def __repr__(self) -> str:
return f"Quaternion: {self.xyzw}"

Expand Down Expand Up @@ -45,6 +49,9 @@ def __neg__(self) -> "Quaternion":
def __rsub__(self, other: "Quaternion") -> "Quaternion":
return Quaternion(xyzw=self.xyzw - other.xyzw)

def __truediv__(self, other: Scalar) -> "Quaternion":
return Quaternion(xyzw=self.xyzw / other)

def conjugate(self) -> "Quaternion":
return Quaternion(xyzw=cs.vertcat(-self.xyzw[:3], self.xyzw[3]))

Expand Down Expand Up @@ -79,3 +86,42 @@ def z(self) -> float:
@property
def w(self) -> float:
return self.xyzw[3]

def inverse(self):
return self.conjugate() / cs.dot(self.xyzw, self.xyzw)

@staticmethod
def slerp(q1: "Quaternion", q2: "Quaternion", n: Scalar) -> List["Quaternion"]:
"""Spherical linear interpolation between two quaternions
check https://en.wikipedia.org/wiki/Slerp for more details
Args:
q1 (Quaternion): First quaternion
q2 (Quaternion): Second quaternion
n (Scalar): Number of interpolation steps
Returns:
List[Quaternion]: Interpolated quaternion
"""
q1 = q1.coeffs()
q2 = q2.coeffs()
return [Quaternion.slerp_step(q1, q2, t) for t in cs.np.linspace(0, 1, n)]

@staticmethod
def slerp_step(q1: Vector, q2: Vector, t: Scalar) -> Vector:
"""Step for the splerp function
Args:
q1 (Vector): First quaternion
q2 (Vector): Second quaternion
t (Scalar): Interpolation parameter
Returns:
Vector: Interpolated quaternion
"""

dot = cs.dot(q1, q2)
angle = cs.acos(dot)
return Quaternion(
(cs.sin((1.0 - t) * angle) * q1 + cs.sin(t * angle) * q2) / cs.sin(angle)
)
23 changes: 17 additions & 6 deletions src/liecasadi/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dataclasses
from dataclasses import field
from typing import Union
from typing import Union, List

import casadi as cs
import numpy as np
Expand Down Expand Up @@ -116,7 +116,6 @@ def quaternion_derivative(
omega_in_body_fixed: bool = False,
baumgarte_coefficient: Union[float, None] = None,
):

if baumgarte_coefficient is not None:
baumgarte_term = (
baumgarte_coefficient
Expand All @@ -139,10 +138,22 @@ def quaternion_derivative(
).coeffs()

@staticmethod
def product(q1: Vector, q2: Vector) -> Vector:
p1 = q1[3] * q2[3] - cs.dot(q1[:3], q2[:3])
p2 = q1[3] * q2[:3] + q2[3] * q1[:3] + cs.cross(q1[:3], q2[:3])
return cs.vertcat(p2, p1)
def slerp(r1: "SO3", r2: "SO3", n: int) -> List["SO3"]:
"""
Spherical linear interpolation between two rotations.
Args:
r1 (SO3): First quaternion
r2 (SO3): Second quaternion
n (Scalar): Number of interpolation steps
Returns:
List[SO3]: Interpolated rotations
"""
q1 = r1.as_quat()
q2 = r2.as_quat()
interpolated_quats = Quaternion.slerp(q1, q2, n)
return [SO3(xyzw=q.coeffs()) for q in interpolated_quats]


@dataclasses.dataclass
Expand Down

0 comments on commit a3a1d9b

Please sign in to comment.