Skip to content

Commit

Permalink
decompose affine into simpler transformations (#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato authored Oct 28, 2023
1 parent a3b1df6 commit da7cc58
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 0 deletions.
143 changes: 143 additions & 0 deletions src/spatialdata/transformations/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from warnings import warn

import numpy as np
import scipy
import xarray as xr
from xarray import DataArray

Expand Down Expand Up @@ -819,6 +820,148 @@ def _decompose_affine_into_linear_and_translation(affine: Affine) -> tuple[Affin
return linear_transformation, translation_transformation


def _compose_affine_from_linear_and_translation(
linear: ArrayLike, translation: ArrayLike, input_axes: tuple[ValidAxis_t, ...], output_axes: tuple[ValidAxis_t, ...]
) -> Affine:
matrix = np.zeros((linear.shape[0] + 1, linear.shape[1] + 1))
matrix[:-1, :-1] = linear
matrix[:-1, -1] = translation
matrix[-1, -1] = 1
return Affine(matrix, input_axes=input_axes, output_axes=output_axes)


def _decompose_transformation(
transformation: BaseTransformation, input_axes: tuple[ValidAxis_t, ...], simple_decomposition: bool = True
) -> Sequence:
"""
Decompose a given 2D transformation into a sequence of predetermined types of transformations.
Parameters
----------
transformation
The transformation to decompose. It is assumed to be of a type that can be represented as a single affine
transformation. It should leave the input axes unmodified, and it should not transform the c channel, if this
is present.
input_axes
The axes of the data the transformation is to be applied to
simple_decomposition
If true, decomposes a transformation into it's linear part (affine without translation) and translation part,
otherwise decomposes it into a sequence of reflection, rotation, shear, scale, translation.
Returns
-------
sequence
Returns a sequence of transformations (class :class:`~spatialdata.transformations.Sequence`) which operates only
on the spatial part (no c channel). The output sequence will contain either 2 either 5 transformations in the
following order (the first is applied first).
Case `simple_decomposition = True`.
1. Linear part (affine): linear part of the affine transformation, represented as a
:class:`~spatialdata.transformations.Affine` transformation.
2. Translation. Represented as a :class:`~spatialdata.transformations.Translation` transformation.
Case `simple_decomposition = False`.
1. Reflection. Represented as :class:`~spatialdata.transformations.Scale` transformation with elements in
{1, -1}.
2. Rotation. Represented as an :class:`~spatialdata.transformations.Affine` transformation which in its
matrix form presents itself as an homogeneous affine matrix with no translation part and determinant 1.
Please look at the source code of this function if you need to recover the angle theta.
3. Shear. Represented as an :class:`~spatialdata.transformations.Affine` transformation which in its matrix
form presents itself as an homogeneous affine matrix with no translation part. The matrix is upper
triangular with diagonal elements all equal to 1.
4. Scale. Represented as a :class:`~spatialdata.transformations.Scale` transformation with positive
elements.
5. Translation. Represented as a :class:`~spatialdata.transformations.Translation` transformation.
Note that some of these transformations may be identity transformations.
"""
output_axes = _get_current_output_axes(transformation=transformation, input_axes=input_axes)
if input_axes != output_axes:
raise ValueError("The transformation should leave the input axes unmodified.")
if "z" in input_axes:
raise ValueError("The transformation should not transform the z axis.")
affine = transformation.to_affine(input_axes=input_axes, output_axes=output_axes)
matrix = affine.matrix
if "c" in input_axes:
c_index = input_axes.index("c")
if (
matrix[c_index, c_index] != 1
or np.linalg.norm(matrix[c_index, :]) != 1
or np.linalg.norm(matrix[:, c_index]) != 1
):
raise ValueError("The transformation should not transform the c channel.")
axes = input_axes[:c_index] + input_axes[c_index + 1 :]
m = np.delete(matrix, c_index, 0)
m = np.delete(m, c_index, 1)
else:
axes = input_axes
m = matrix

translation_part = m[:-1, -1]
linear_part = m[:-1, :-1]

if simple_decomposition:
translation = Translation(translation_part, axes=axes)
linear = _compose_affine_from_linear_and_translation(
linear=linear_part,
translation=np.zeros(linear_part.shape[0]),
input_axes=axes,
output_axes=axes,
)
sequence = Sequence([linear, translation])
else:
# qr factorization
a = linear_part
r, q = scipy.linalg.rq(a)

theta = np.arctan2(q[1, 0], q[0, 0])
rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

scale_matrix = np.diag(np.abs(np.diag(r)))
shear_matrix = np.linalg.inv(scale_matrix) @ r
assert np.allclose(scale_matrix @ shear_matrix, r)
d = np.diag(np.diag(shear_matrix))

qq = rotation_matrix.T @ q
# check that qq is a diagonal matrix with diagonal values in {-1, 1}
assert np.allclose(np.diag(qq) ** 2, np.ones(qq.shape[0]))
assert np.isclose(np.sum(np.abs(qq.ravel())), qq.shape[0])
assert np.allclose(rotation_matrix @ qq, q)

adjusted_shear_matrix = shear_matrix @ d
adjusted_rotation_matrix = d @ rotation_matrix @ d
assert np.allclose(
adjusted_rotation_matrix @ adjusted_rotation_matrix.T, np.eye(adjusted_rotation_matrix.shape[0])
)
adjusted_qq = d @ qq

aaa = scale_matrix @ shear_matrix @ d @ d @ rotation_matrix @ d @ d @ qq
assert np.allclose(a, aaa)
aa = scale_matrix @ adjusted_shear_matrix @ adjusted_rotation_matrix @ adjusted_qq
assert np.allclose(a, aa)

scale = Scale(np.diag(scale_matrix), axes=axes)
shear = _compose_affine_from_linear_and_translation(
linear=adjusted_shear_matrix,
translation=np.zeros(shear_matrix.shape[0]),
input_axes=axes,
output_axes=axes,
)
rotation = _compose_affine_from_linear_and_translation(
linear=adjusted_rotation_matrix,
translation=np.zeros(rotation_matrix.shape[0]),
input_axes=axes,
output_axes=axes,
)
inversion = Scale(np.diag(adjusted_qq), axes=axes)
translation = Translation(translation_part, axes=axes)
sequence = Sequence([inversion, rotation, shear, scale, translation])
check_m = sequence.to_affine_matrix(input_axes=input_axes, output_axes=input_axes)
assert np.allclose(check_m, matrix)
return sequence


TRANSFORMATIONS_MAP[NgffIdentity] = Identity
TRANSFORMATIONS_MAP[NgffMapAxis] = MapAxis
TRANSFORMATIONS_MAP[NgffTranslation] = Translation
Expand Down
122 changes: 122 additions & 0 deletions tests/transformations/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -29,6 +30,7 @@
Sequence,
Translation,
_decompose_affine_into_linear_and_translation,
_decompose_transformation,
_get_affine_for_element,
)
from xarray import DataArray
Expand Down Expand Up @@ -783,6 +785,126 @@ def test_decompose_affine_into_linear_and_translation():
assert np.allclose(translation.translation, np.array([10, 11]))


@pytest.mark.parametrize(
"matrix,input_axes,output_axes,valid",
[
# non-square matrix are not supported
(
np.array(
[
[1, 2, 3, 10],
[4, 5, 6, 11],
[0, 0, 0, 1],
]
),
("x", "y", "z"),
("x", "y"),
False,
),
(
np.array(
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[0, 0, 1],
]
),
("x", "y"),
("x", "y", "z"),
False,
),
# z axis should not be present
(
np.array(
[
[1, 2, 3, 10],
[4, 5, 6, 11],
[7, 8, 9, 12],
[0, 0, 0, 1],
]
),
("x", "y", "z"),
("x", "y", "z"),
False,
),
# c channel is modified
(
np.array(
[
[1, 2, 0, 4],
[4, 5, 0, 7],
[8, 9, 1, 10],
[0, 0, 0, 1],
]
),
("x", "y", "c"),
("x", "y", "c"),
False,
),
(
np.array(
[
[1, 2, 0, 4],
[4, 5, 0, 7],
[0, 0, 0, 0],
[0, 0, 0, 1],
]
),
("x", "y", "c"),
("x", "y", "c"),
False,
),
(
np.array(
[
[1, 2, 3, 4],
[4, 5, 6, 7],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
),
("x", "y", "c"),
("x", "y", "c"),
False,
),
# valid, no c channel
(
np.array(
[
[1, 2, 3],
[4, 5, 6],
[0, 0, 1],
]
),
("x", "y"),
("x", "y"),
True,
),
# valid, c channel
(
np.array(
[
[1, 2, 0, 4],
[4, 5, 0, 7],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
),
("x", "y", "c"),
("x", "y", "c"),
True,
),
],
)
@pytest.mark.parametrize("simple_decomposition", [True, False])
def test_decompose_transformation(matrix, input_axes, output_axes, valid, simple_decomposition):
affine = Affine(matrix, input_axes=input_axes, output_axes=output_axes)
context = nullcontext() if valid else pytest.raises(ValueError)
with context:
_ = _decompose_transformation(affine, input_axes=input_axes, simple_decomposition=simple_decomposition)


def test_assign_xy_scale_to_cyx_image():
scale = Scale(np.array([2, 3]), axes=("x", "y"))
image = Image2DModel.parse(np.zeros((10, 10, 10)), dims=("c", "y", "x"))
Expand Down

0 comments on commit da7cc58

Please sign in to comment.