Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bindings to ASTRA cylindrical detector geometries #1634

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions odl/tomo/backends/astra_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ def create_ids(self):
elif proj_ndim == 3:
# The `u` and `v` axes of the projection data are swapped,
# see explanation in `astra_*_3d_geom_to_vec`.
astra_proj_shape = (proj_shape[1], proj_shape[0], proj_shape[2])
if self.geometry.det_curvature_radius is None:
astra_proj_shape = (proj_shape[1], proj_shape[0], proj_shape[2])
else:
astra_proj_shape = (proj_shape[2], proj_shape[0], proj_shape[1])
astra_vol_shape = self.vol_space.shape

self.vol_array = np.empty(astra_vol_shape, dtype='float32', order='C')
Expand Down Expand Up @@ -233,8 +236,14 @@ def _call_forward_real(self, vol_data, out=None, **kwargs):
if self.geometry.ndim == 2:
out[:] = self.proj_array
elif self.geometry.ndim == 3:
out[:] = np.swapaxes(self.proj_array, 0, 1).reshape(
self.proj_space.shape)
# TODO: Find a way not to have to do rollaxis(0, 3) for
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use moveaxis instead? It is recommended in the docs:
https://numpy.org/doc/stable/reference/generated/numpy.rollaxis.html

# cylindrical detectors (probably inside ASTRA)
if self.geometry.det_curvature_radius is None:
out[:] = np.swapaxes(self.proj_array, 0, 1).reshape(
self.proj_space.shape)
else:
out[:] = np.rollaxis(self.proj_array, 0, 3).reshape(
self.proj_space.shape)

# Fix scaling to weight by pixel size
if (
Expand Down Expand Up @@ -283,9 +292,16 @@ def _call_backward_real(self, proj_data, out=None, **kwargs):
elif self.geometry.ndim == 3:
shape = (-1,) + self.geometry.det_partition.shape
reshaped_proj_data = proj_data.asarray().reshape(shape)
swapped_proj_data = np.ascontiguousarray(
np.swapaxes(reshaped_proj_data, 0, 1)
)
# TODO: Find a way not to have to do rollaxis(2, 0) for
# cylindrical detectors (probably inside ASTRA)
if self.geometry.det_curvature_radius is None:
swapped_proj_data = np.ascontiguousarray(
np.swapaxes(reshaped_proj_data, 0, 1)
)
else:
swapped_proj_data = np.ascontiguousarray(
np.rollaxis(reshaped_proj_data, 2, 0)
)
astra.data3d.store(self.sino_id, swapped_proj_data)

# Run algorithm
Expand Down
97 changes: 94 additions & 3 deletions odl/tomo/backends/astra_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@

from odl.discr import DiscretizedSpace, DiscretizedSpaceElement
from odl.tomo.geometry import (
DivergentBeamGeometry, Flat1dDetector, Flat2dDetector, Geometry,
ParallelBeamGeometry)
DivergentBeamGeometry, Flat1dDetector, Flat2dDetector, CylindricalDetector,
Geometry, ParallelBeamGeometry)
from odl.tomo.util.utility import euler_matrix

try:
Expand Down Expand Up @@ -339,6 +339,81 @@ def astra_conebeam_3d_geom_to_vec(geometry):
return vectors


def astra_cyl_conebeam_3d_geom_to_vec(geometry):
"""Create vectors for ASTRA projection geometries from ODL geometry.

The 3D vectors are used to create an ASTRA projection geometry for
cone beam geometries with a cylindrical detector, see ``'cyl_cone_vec'``
in the `ASTRA projection geometry documentation`_.

Each row of the returned vectors corresponds to a single projection
and consists of ::

(srcX, srcY, srcZ, dX, dY, dZ, uX, uY, uZ, vX, vY, vZ, R)

with

- ``src``: the ray source position
- ``d`` : the center of the detector
- ``u`` : tangential direction at center of detector;
the length of u is the arc length of a detector pixel
- ``v`` : the vector from detector pixel ``(0,0)`` to ``(1,0)``
- ``R`` : the radius of the detector cylinder

Parameters
----------
geometry : `Geometry`
ODL projection geometry from which to create the ASTRA geometry.

Returns
-------
vectors : `numpy.ndarray`
Array of shape ``(num_angles, 13)`` containing the vectors.

References
----------
.. _ASTRA projection geometry documentation:
http://www.astra-toolbox.com/docs/geom3d.html#projection-geometries
"""
angles = geometry.angles
vectors = np.zeros((angles.size, 13))

# Source position
vectors[:, 0:3] = geometry.src_position(angles)

# Center of detector in 3D space
# FIXME: This is not correct: det_point_position returns the zero-point of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, det_point_position return the center of the detector, so this is correct. The problem is with geometry.det_axes(angles), because axes are "attached" to the zero-point and not to the center as ASTRA assumes.

Copy link
Contributor

@JevgenijaAksjonova JevgenijaAksjonova Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the next few lines should look as follows:

    # Center of detector in 3D space
    mid_pt = geometry.det_params.mid_pt
    vectors[:, 3:6] = geometry.det_point_position(angles, mid_pt)
    # reverse the sign, since the partition angle in CylindricalDetector 
    # increases in the clockwise direction, by analogy to flat detectors
    det_shift_angle = -mid_pt[0]

    # `det_axes` gives shape (N, 2, 3), swap to get (2, N, 3)
    det_axes = np.moveaxis(geometry.det_axes(angles + det_shift_angle), -2, 0)

The detector partition along the first dimension is given in radians, so the distance from the mid_pt[0] to the 0 point is an angle. It is sufficient to add this angle to source angles to rotate the detector axes, however the sign should be reversed, because angles increase clockwise on the detector (to be consistent with the default setup for flat detector)

This solution does not affect the situation, when the detector is positioned symmetrically. So, one could still implement the detector shift by using detector_shift_func and rotating the axes.

# the detector, and not the center of the detector. For quarter-pixel-shifted
# detector these two do not coincide.
mid_pt = geometry.det_params.mid_pt
vectors[:, 3:6] = geometry.det_point_position(angles, mid_pt)

# `det_axes` gives shape (N, 2, 3), swap to get (2, N, 3)
det_axes = np.moveaxis(geometry.det_axes(angles), -2, 0)
px_sizes = geometry.det_partition.cell_sides

# `px_sizes[0]` is angular partition; scale by radius to get arc length
# NB: For flat panel detector we swap the u and v axes to get a better
# memory layout. For cylindrical detectors this is (currently) not possible
# since both ODL and Astra have the v direction along the axial direction.
vectors[:, 6:9] = det_axes[0] * px_sizes[0] * geometry.det_curvature_radius
vectors[:, 9:12] = det_axes[1] * px_sizes[1]

# detector curvature radius
vectors[:, 12] = geometry.det_curvature_radius

# ASTRA has (z, y, x) axis convention, in contrast to (x, y, z) in ODL,
# so we need to adapt to this by changing the order.
newind = []
for i in range(4):
newind += [2 + 3 * i, 1 + 3 * i, 0 + 3 * i]
newind += [12]
vectors = vectors[:, newind]

return vectors



def astra_conebeam_2d_geom_to_vec(geometry):
"""Create vectors for ASTRA projection geometries from ODL geometry.

Expand Down Expand Up @@ -532,6 +607,20 @@ def astra_projection_geometry(geometry):
vec = astra_conebeam_3d_geom_to_vec(geometry)
proj_geom = astra.create_proj_geom('cone_vec', det_row_count,
det_col_count, vec)

elif (isinstance(geometry, DivergentBeamGeometry) and
isinstance(geometry.detector, CylindricalDetector) and
geometry.ndim == 3):
# NB: For flat panel detector we swap the u and v axes to get a better
# memory layout. For cylindrical detectors this is (currently) not
# possible since both ODL and Astra have the v direction along the
# axial direction.
det_row_count = geometry.det_partition.shape[1]
det_col_count = geometry.det_partition.shape[0]
vec = astra_cyl_conebeam_3d_geom_to_vec(geometry)
proj_geom = astra.create_proj_geom('cyl_cone_vec', det_row_count,
det_col_count, vec)

else:
raise NotImplementedError('unknown ASTRA geometry type {!r}'
''.format(geometry))
Expand Down Expand Up @@ -674,8 +763,10 @@ def astra_projector(astra_proj_type, astra_vol_geom, astra_proj_geom, ndim):
valid_proj_types = ['line_fanflat', 'strip_fanflat', 'cuda']
elif astra_geom in {'parallel3d', 'parallel3d_vec'}:
valid_proj_types = ['linear3d', 'cuda3d']
elif astra_geom in {'cone', 'cone_vec'}:
elif astra_geom in {'cone', 'cone_vec' }:
valid_proj_types = ['linearcone', 'cuda3d']
elif astra_geom in {'cyl_cone_vec' }:
valid_proj_types = ['cuda3d']
else:
raise ValueError('invalid geometry type {!r}'.format(astra_geom))

Expand Down
3 changes: 2 additions & 1 deletion odl/tomo/geometry/conebeam.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,8 @@ def __repr__(self):
posargs = [self.motion_partition, self.det_partition]
optargs = [('src_radius', self.src_radius, -1),
('det_radius', self.det_radius, -1),
('pitch', self.pitch, 0)
('det_curvature_radius', self.det_curvature_radius, None),
('pitch', self.pitch, 0),
]

if not np.allclose(self.axis, self._default_config['axis']):
Expand Down