Skip to content

Commit

Permalink
Refactor array handling in RBDAlgorithms and KinDynComputations; upda…
Browse files Browse the repository at this point in the history
…te type hints in TorchLike class
GiulioRomualdi committed Jan 9, 2025

Unverified

The committer email address is not verified.
1 parent 05d4afd commit 9ea1952
Showing 3 changed files with 68 additions and 37 deletions.
6 changes: 4 additions & 2 deletions src/adam/core/rbd_algorithms.py
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ def crba(
if link_i.name == self.root_link:
# The first "real" link. The joint is universal.
X_p[i] = self.math.spatial_transform(
self.math.factory.eye(3), self.math.factory.zeros(3, 1)
self.math.factory.eye(3), self.math.factory.zeros(3)
)
Phi[i] = self.math.factory.eye(6)
else:
@@ -408,6 +408,8 @@ def rnea(
Returns:
tau (T): generalized force variables
"""
print(base_velocity.shape)

# TODO: add accelerations
tau = self.math.factory.zeros(self.NDoF + 6, 1)
model_len = self.model.N
@@ -453,7 +455,7 @@ def rnea(
if link_i.name == self.root_link:
# The first "real" link. The joint is universal.
X_p[i] = self.math.spatial_transform(
self.math.factory.eye(3), self.math.factory.zeros(3, 1)
self.math.factory.eye(3), self.math.factory.zeros(3)
)
Phi[i] = self.math.factory.eye(6)
v[i] = B_X_BI @ base_velocity
6 changes: 3 additions & 3 deletions src/adam/pytorch/computations.py
Original file line number Diff line number Diff line change
@@ -210,7 +210,7 @@ def bias_force(
return self.rbdalgos.rnea(
base_transform,
joint_positions,
base_velocity.reshape(6, 1),
base_velocity,
joint_velocities,
self.g,
).array.squeeze()
@@ -238,7 +238,7 @@ def coriolis_term(
return self.rbdalgos.rnea(
base_transform,
joint_positions,
base_velocity.reshape(6, 1),
base_velocity,
joint_velocities,
torch.zeros(6),
).array.squeeze()
@@ -259,7 +259,7 @@ def gravity_term(
return self.rbdalgos.rnea(
base_transform,
joint_positions,
torch.zeros(6).reshape(6, 1),
torch.zeros(6),
torch.zeros(self.NDoF),
self.g,
).array.squeeze()
93 changes: 61 additions & 32 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
from typing import Union, Tuple

import numpy as np
import numpy.typing as ntp
import numpy.typing as npt
import torch

from adam.core.spatial_math import ArrayLike, ArrayLikeFactory, SpatialMath
@@ -21,7 +21,7 @@ def __post_init__(self):
if self.array.dtype != torch.float64:
self.array = self.array.double()

def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __setitem__(self, idx, value: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides set item operator"""
if type(self) is type(value):
self.array[idx] = value.array.reshape(self.array[idx].shape)
@@ -52,7 +52,7 @@ def T(self) -> "TorchLike":
x = self.array
return TorchLike(x.permute(*torch.arange(x.ndim - 1, -1, -1)))

def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __matmul__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides @ operator"""

if type(self) is type(other):
@@ -62,54 +62,54 @@ def __matmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
else:
return TorchLike(self.array @ torch.tensor(other))

def __rmatmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __rmatmul__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides @ operator"""
if type(self) is type(other):
return TorchLike(other.array @ self.array)
else:
return TorchLike(torch.tensor(other) @ self.array)

def __mul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __mul__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides * operator"""
if type(self) is type(other):
return TorchLike(self.array * other.array)
else:
return TorchLike(self.array * other)

def __rmul__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __rmul__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides * operator"""
if type(self) is type(other):
return TorchLike(other.array * self.array)
else:
return TorchLike(other * self.array)

def __truediv__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __truediv__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides / operator"""
if type(self) is type(other):
return TorchLike(self.array / other.array)
else:
return TorchLike(self.array / other)

def __add__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __add__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides + operator"""
if type(self) is not type(other):
return TorchLike(self.array.squeeze() + other.squeeze())
return TorchLike(self.array.squeeze() + other.array.squeeze())

def __radd__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __radd__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides + operator"""
if type(self) is not type(other):
return TorchLike(self.array.squeeze() + other.squeeze())
return TorchLike(self.array.squeeze() + other.array.squeeze())

def __sub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __sub__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides - operator"""
if type(self) is type(other):
return TorchLike(self.array.squeeze() - other.array.squeeze())
else:
return TorchLike(self.array.squeeze() - other.squeeze())

def __rsub__(self, other: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def __rsub__(self, other: Union["TorchLike", npt.ArrayLike]) -> "TorchLike":
"""Overrides - operator"""
if type(self) is type(other):
return TorchLike(other.array.squeeze() - self.array.squeeze())
@@ -142,7 +142,7 @@ def eye(x: int) -> "TorchLike":
return TorchLike(torch.eye(x))

@staticmethod
def array(x: ntp.ArrayLike) -> "TorchLike":
def array(x: npt.ArrayLike) -> "TorchLike":
"""
Returns:
TorchLike: vector wrapping x
@@ -185,10 +185,10 @@ def __init__(self):
super().__init__(TorchLikeFactory())

@staticmethod
def sin(x: ntp.ArrayLike) -> "TorchLike":
def sin(x: npt.ArrayLike) -> "TorchLike":
"""
Args:
x (ntp.ArrayLike): angle value
x (npt.ArrayLike): angle value
Returns:
TorchLike: sin value of x
@@ -202,10 +202,10 @@ def sin(x: ntp.ArrayLike) -> "TorchLike":
)

@staticmethod
def cos(x: ntp.ArrayLike) -> "TorchLike":
def cos(x: npt.ArrayLike) -> "TorchLike":
"""
Args:
x (ntp.ArrayLike): angle value
x (npt.ArrayLike): angle value
Returns:
TorchLike: cos value of x
@@ -220,37 +220,66 @@ def cos(x: ntp.ArrayLike) -> "TorchLike":
)

@staticmethod
def outer(x: ntp.ArrayLike, y: ntp.ArrayLike) -> "TorchLike":
def outer(x: npt.ArrayLike, y: npt.ArrayLike) -> "TorchLike":
"""
Args:
x (ntp.ArrayLike): vector
y (ntp.ArrayLike): vector
x (npt.ArrayLike): vector
y (npt.ArrayLike): vector
Returns:
TorchLike: outer product of x and y
"""
return TorchLike(torch.outer(torch.tensor(x), torch.tensor(y)))

@staticmethod
def skew(x: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike":
def skew(x: Union[TorchLike, npt.ArrayLike]) -> TorchLike:
"""
Construct the skew-symmetric matrix from a 3D vector.
Args:
x (Union[TorchLike, ntp.ArrayLike]): vector
x (Union[TorchLike, npt.ArrayLike]): A 3D vector or a batch of 3D vectors.
Returns:
TorchLike: skew matrix from x
TorchLike: The skew-symmetric matrix (3x3 for a single vector, Nx3x3 for a batch).
"""
if not isinstance(x, TorchLike):
return TorchLike(
torch.tensor([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])
# Handle non-TorchLike inputs
if isinstance(x, TorchLike):
x = x.array # Convert to torch.Tensor if necessary
elif not isinstance(x, torch.Tensor):
x = torch.tensor(x)

# Check shape: must be either (3,) or (..., 3)
if x.shape[-1] != 3:
raise ValueError(
f"Input must be a 3D vector or a batch of 3D vectors, but got shape: {x.shape}"
)
x = x.array
return TorchLike(
torch.tensor([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])

# Determine if the input has a batch dimension
has_batch = len(x.shape) > 1

# Add a batch dimension if the input is a single vector
if not has_batch:
x = x.unsqueeze(0)

# Compute skew-symmetric matrix for each vector
zero = torch.zeros_like(x[..., 0])
skew_matrices = torch.stack(
(
torch.stack((zero, -x[..., 2], x[..., 1]), dim=-1),
torch.stack((x[..., 2], zero, -x[..., 0]), dim=-1),
torch.stack((-x[..., 1], x[..., 0], zero), dim=-1),
),
dim=-2,
)

# Squeeze back to remove the added batch dimension only if the input was not batched
if not has_batch:
skew_matrices = skew_matrices.squeeze(0)

return TorchLike(skew_matrices)

@staticmethod
def vertcat(*x: ntp.ArrayLike) -> "TorchLike":
def vertcat(*x: npt.ArrayLike) -> "TorchLike":
"""
Returns:
TorchLike: vertical concatenation of x
@@ -262,7 +291,7 @@ def vertcat(*x: ntp.ArrayLike) -> "TorchLike":
return TorchLike(v)

@staticmethod
def horzcat(*x: ntp.ArrayLike) -> "TorchLike":
def horzcat(*x: npt.ArrayLike) -> "TorchLike":
"""
Returns:
TorchLike: horizontal concatenation of x
@@ -274,10 +303,10 @@ def horzcat(*x: ntp.ArrayLike) -> "TorchLike":
return TorchLike(v)

@staticmethod
def stack(x: Tuple[Union[TorchLike, ntp.ArrayLike]], axis: int = 0) -> TorchLike:
def stack(x: Tuple[Union[TorchLike, npt.ArrayLike]], axis: int = 0) -> TorchLike:
"""
Args:
x (Tuple[Union[TorchLike, ntp.ArrayLike]]): elements to stack
x (Tuple[Union[TorchLike, npt.ArrayLike]]): elements to stack
axis (int, optional): axis to stack. Defaults to 0.
Returns:

0 comments on commit 9ea1952

Please sign in to comment.