Skip to content

Commit

Permalink
Make rotation creation batchable:
Browse files Browse the repository at this point in the history
- attempt to mimic the numpy like behaviour in casadi like vector
- implement stack function in CasadiSpatialMath
  • Loading branch information
GiulioRomualdi committed Jan 9, 2025
1 parent 454ae37 commit 89b6a0b
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 43 deletions.
74 changes: 69 additions & 5 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) Istituto Italiano di Tecnologia (IIT). All rights reserved.

from dataclasses import dataclass
from typing import Union
from typing import Union, Tuple

import casadi as cs
import numpy.typing as npt
Expand Down Expand Up @@ -93,14 +93,52 @@ def __setitem__(self, idx, value: Union["CasadiLike", npt.ArrayLike]):
else:
self.array[idx] = value.array if isinstance(value, CasadiLike) else value

@property
def shape(self) -> Tuple[int]:
"""
Returns:
Tuple[int]: shape of the array
"""

# We force to have the same interface as numpy
if self.array.shape[1] == 1 and self.array.shape[0] == 1:
return tuple()
elif self.array.shape[1] == 1:
return (self.array.shape[0],)

return self.array.shape

def reshape(self, *args) -> "CasadiLike":
"""
Args:
*args: new shape
"""
args = tuple(filter(None, args))
if len(args) > 2:
raise ValueError(f"Only 1D and 2D arrays are supported, The shape is {args}")

# For 1D reshape, just call CasADi reshape directly
if len(args) == 1:
new_array = cs.reshape(self.array, args[0], 1)
else:
# For 2D reshape, transpose before and after to mimic row-major behavior
new_array = cs.reshape(self.array.T, args[1], args[0]).T

return CasadiLike(new_array)

def __getitem__(self, idx) -> "CasadiLike":
"""Overrides get item operator"""
if idx is Ellipsis:
# Handle the case where only Ellipsis is passed
return CasadiLike(self.array)
elif isinstance(idx, tuple) and Ellipsis in idx:
# Handle the case where Ellipsis is part of a tuple
idx = tuple(slice(None) if k is Ellipsis else k for k in idx)
if len(self.shape) == 2:
idx = tuple(slice(None) if k is Ellipsis else k for k in idx)
else:
# take the values that are not Ellipsis
idx = idx[: idx.index(Ellipsis)] + idx[idx.index(Ellipsis) + 1 :]
idx = idx[0] if len(idx) == 1 else idx

return CasadiLike(self.array[idx])
else:
# For other cases, delegate to the CasADi object's __getitem__
Expand Down Expand Up @@ -235,7 +273,7 @@ def sin(x: npt.ArrayLike) -> "CasadiLike":
Returns:
CasadiLike: the sin value of x
"""
return CasadiLike(cs.sin(x))
return CasadiLike(cs.sin(x.array) if isinstance(x, CasadiLike) else cs.sin(x))

@staticmethod
def cos(x: npt.ArrayLike) -> "CasadiLike":
Expand All @@ -246,7 +284,7 @@ def cos(x: npt.ArrayLike) -> "CasadiLike":
Returns:
CasadiLike: the cos value of x
"""
return CasadiLike(cs.cos(x))
return CasadiLike(cs.cos(x.array) if isinstance(x, CasadiLike) else cs.cos(x))

@staticmethod
def outer(x: npt.ArrayLike, y: npt.ArrayLike) -> "CasadiLike":
Expand Down Expand Up @@ -283,6 +321,32 @@ def horzcat(*x) -> "CasadiLike":
y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x]
return CasadiLike(cs.horzcat(*y))

@staticmethod
def stack(x: Tuple[Union[CasadiLike, npt.ArrayLike]], axis: int = 0) -> CasadiLike:
"""
Args:
x (Tuple[Union[CasadiLike, npt.ArrayLike]]): tuple of arrays
axis (int): axis to stack
Returns:
CasadiLike: stacked array
Notes:
This function is here for compatibility with the numpy_like implementation.
"""

# check that the elements size are the same
for i in range(0, len(x)):
if len(x[i].shape) == 2:
raise ValueError(
f"All input arrays must shape[1] != 2, {x[i].shape} found"
)

if axis != -1 and axis != 1:
raise ValueError(f"Axis must be 1 or -1, {axis} found")

return SpatialMath.vertcat(*x)


if __name__ == "__main__":
math = SpatialMath()
Expand Down
90 changes: 67 additions & 23 deletions src/adam/core/spatial_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class SpatialMath:
"""

def __init__(self, factory: ArrayLikeFactory):
self._factory = factory
self._factory: ArrayLikeFactory = factory

@property
def factory(self) -> ArrayLikeFactory:
Expand Down Expand Up @@ -184,6 +184,63 @@ def cos(self, x: npt.ArrayLike) -> npt.ArrayLike:
def skew(self, x):
pass

@staticmethod
@abc.abstractmethod
def stack(x: npt.ArrayLike, axis: int = 0) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix
axis (int): axis
Returns:
npt.ArrayLike: stack matrix x along axis
"""
pass

def _axis_angle_rotation(self, axis: str, angle: npt.ArrayLike) -> npt.ArrayLike:
"""
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X", "Y", or "Z".
angle: a tensor of the form (B) or ()
Returns:
Rotation matrices as a tensor of shape (B, 3, 3) or (3, 3),
accordingly to the angle shape.
"""

# Use len(angle.shape) to check dimensions:
if len(angle.shape) == 0 and len(angle.shape) == 1:
raise ValueError(
f"Angle must be a vector or a scalar. The shape is {angle.shape}."
)

cos = self.cos(angle)
sin = self.sin(angle)
one = self.factory.ones_like(angle)
zero = self.factory.zeros_like(angle)

if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
elif axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
elif axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError('Axis must be one of {"X", "Y", "Z"}.')

# Stack into shape (..., 9)
R = self.stack(R_flat, axis=-1)

# Reshape to (..., 3, 3).
# If angle is scalar (), it becomes (3, 3).
# If angle has shape (B,), it becomes (B, 3, 3).
R = R.reshape(*angle.shape, 3, 3)

return R

def R_from_axis_angle(self, axis: npt.ArrayLike, q: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
Expand All @@ -208,13 +265,7 @@ def Rx(self, q: npt.ArrayLike) -> npt.ArrayLike:
Returns:
npt.ArrayLike: rotation matrix around x axis
"""
R = self.factory.eye(3)
cq, sq = self.cos(q), self.sin(q)
R[1, 1] = cq
R[1, 2] = -sq
R[2, 1] = sq
R[2, 2] = cq
return R
return self._axis_angle_rotation("X", q)

def Ry(self, q: npt.ArrayLike) -> npt.ArrayLike:
"""
Expand All @@ -224,13 +275,7 @@ def Ry(self, q: npt.ArrayLike) -> npt.ArrayLike:
Returns:
npt.ArrayLike: rotation matrix around y axis
"""
R = self.factory.eye(3)
cq, sq = self.cos(q), self.sin(q)
R[0, 0] = cq
R[0, 2] = sq
R[2, 0] = -sq
R[2, 2] = cq
return R
return self._axis_angle_rotation("Y", q)

def Rz(self, q: npt.ArrayLike) -> npt.ArrayLike:
"""
Expand All @@ -240,13 +285,8 @@ def Rz(self, q: npt.ArrayLike) -> npt.ArrayLike:
Returns:
npt.ArrayLike: rotation matrix around z axis
"""
R = self.factory.eye(3)
cq, sq = self.cos(q), self.sin(q)
R[0, 0] = cq
R[0, 1] = -sq
R[1, 0] = sq
R[1, 1] = cq
return R
print(q, q.shape)
return self._axis_angle_rotation("Z", q)

def H_revolute_joint(
self,
Expand Down Expand Up @@ -320,7 +360,11 @@ def R_from_RPY(self, rpy: npt.ArrayLike) -> npt.ArrayLike:
Returns:
npt.ArrayLike: Rotation matrix
"""
return self.Rz(rpy[2]) @ self.Ry(rpy[1]) @ self.Rx(rpy[0])
if isinstance(rpy, list):
rpy = self.factory.array(rpy)

print(rpy)
return self.Rz(rpy[..., 2]) @ self.Ry(rpy[..., 1]) @ self.Rx(rpy[..., 0])

def X_revolute_joint(
self,
Expand Down
24 changes: 20 additions & 4 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from dataclasses import dataclass
from typing import Union
from typing import Union, Tuple

import jax.numpy as jnp
import numpy.typing as npt
Expand Down Expand Up @@ -35,7 +35,7 @@ def shape(self):
return self.array.shape

def reshape(self, *args):
return self.array.reshape(*args)
return JaxLike(self.array.reshape(*args, order="C"))

@property
def T(self) -> "JaxLike":
Expand Down Expand Up @@ -176,7 +176,9 @@ def sin(x: npt.ArrayLike) -> "JaxLike":
Returns:
JaxLike: sin of x
"""
return JaxLike(jnp.sin(x))
return (
JaxLike(jnp.sin(x.array)) if isinstance(x, JaxLike) else JaxLike(jnp.sin(x))
)

@staticmethod
def cos(x: npt.ArrayLike) -> "JaxLike":
Expand All @@ -187,7 +189,9 @@ def cos(x: npt.ArrayLike) -> "JaxLike":
Returns:
JaxLike: cos of x
"""
return JaxLike(jnp.cos(x))
return (
JaxLike(jnp.cos(x.array)) if isinstance(x, JaxLike) else JaxLike(jnp.cos(x))
)

@staticmethod
def outer(x: npt.ArrayLike, y: npt.ArrayLike) -> "JaxLike":
Expand Down Expand Up @@ -240,3 +244,15 @@ def horzcat(*x) -> "JaxLike":
else:
v = jnp.hstack([x[i] for i in range(len(x))])
return JaxLike(v)

@staticmethod
def stack(x: Tuple[Union[JaxLike, npt.ArrayLike]], axis: int = 0) -> JaxLike:
"""
Returns:
JaxLike: Stack of x
"""
if isinstance(x[0], JaxLike):
v = jnp.stack([x[i].array for i in range(len(x))], axis=axis)
else:
v = jnp.stack(x, axis=axis)
return JaxLike(v)
38 changes: 31 additions & 7 deletions src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


from dataclasses import dataclass
from typing import Union
from typing import Union, Tuple

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -32,7 +32,7 @@ def shape(self):
return self.array.shape

def reshape(self, *args):
return self.array.reshape(*args)
return NumpyLike(self.array.reshape(*args))

@property
def T(self) -> "NumpyLike":
Expand Down Expand Up @@ -178,7 +178,11 @@ def sin(x: npt.ArrayLike) -> "NumpyLike":
Returns:
NumpyLike: sin value of x
"""
return NumpyLike(np.sin(x))
return (
NumpyLike(np.sin(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.sin(x))
)

@staticmethod
def cos(x: npt.ArrayLike) -> "NumpyLike":
Expand All @@ -189,7 +193,11 @@ def cos(x: npt.ArrayLike) -> "NumpyLike":
Returns:
NumpyLike: cos value of x
"""
return NumpyLike(np.cos(x))
return (
NumpyLike(np.cos(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.cos(x))
)

@staticmethod
def outer(x: npt.ArrayLike, y: npt.ArrayLike) -> "NumpyLike":
Expand Down Expand Up @@ -238,7 +246,23 @@ def skew(x: Union["NumpyLike", npt.ArrayLike]) -> "NumpyLike":
Returns:
NumpyLike: the skew symmetric matrix from x
"""
if not isinstance(x, NumpyLike):
return -np.cross(np.array(x), np.eye(3), axisa=0, axisb=0)
x = x.array
if isinstance(x, NumpyLike):
x = x.array

return NumpyLike(-np.cross(np.array(x), np.eye(3), axisa=0, axisb=0))

@staticmethod
def stack(x: Tuple[Union[NumpyLike, npt.ArrayLike]], axis: int = 0) -> NumpyLike:
"""
Args:
x (Tuple[Union[NumpyLike, npt.ArrayLike]]): elements to stack
axis (int): axis to stack
Returns:
NumpyLike: stacked elements
"""
if isinstance(x[0], NumpyLike):
v = np.stack([x[i].array for i in range(len(x))], axis=axis)
else:
v = np.stack(x, axis=axis)
return NumpyLike(v)
Loading

0 comments on commit 89b6a0b

Please sign in to comment.