Skip to content

Commit

Permalink
ENH: add utility function for convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
kohr-h committed Jul 16, 2019
1 parent 1fdcdea commit 5b5007f
Showing 1 changed file with 97 additions and 3 deletions.
100 changes: 97 additions & 3 deletions odl/oplib/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from __future__ import division
import numpy as np

from odl.discr import DiscreteLpElement
from odl.operator import Operator
from odl.space import tensor_space
from odl.space.base_tensors import TensorSpace, Tensor
Expand All @@ -20,7 +21,8 @@
is_real_dtype, is_floating_dtype, dtype_str, writable_array)
from odl.util.npy_compat import roll

__all__ = ('DiscreteConvolution',)

__all__ = ('DiscreteConvolution', 'convolve')


class DiscreteConvolution(Operator):
Expand Down Expand Up @@ -55,8 +57,8 @@ def __init__(self, domain, kernel, range=None, axis=None, impl='fft',
float data type is used in that case to store the kernel.
- If the convolution kernel is complex, the ``domain`` **must**
be a complex space. Real-to-complex convolutions are not
allowed and need to defined by composition with
`ComplexEmbedding`.
allowed and need to be defined by composition with
`ComplexEmbedding` instead.
See Examples for further clarification.
Expand Down Expand Up @@ -732,6 +734,98 @@ def _prepare_for_fft(kernel, padded_shape, axes=None, variant='forward'):
return roll(padded, shifts, axis=axes)


def convolve(x, y, out=None, **kwargs):
"""Return the convolution ``x * y``.
This is a convenience function for quickly computing a convolution
without having to explicitly create a `DiscreteConvolution` instance.
Parameters
----------
x : array-like
Array or discrete function that is supposed to be convolved with
``y``. Its type determines the return type.
y : array-like
The kernel with which ``x`` is convolved. It must have the same
number of dimensions as ``x``, and its shape can be at most equal
to ``x.shape``. In axes with size 1, broadcasting is applied.
**Important:**
- The kernel must **always** have the same number of dimensions
as ``x``, even for convolutions along axes.
- In axes where no convolution is performed, the shape of the
kernel must either be 1 (broadcasting along these axes), or
equal to the ``x.shape`` (stacked kernel, only supported
for `impl='fft'`).
- The ``'fft'`` implementation needs the kernel to have
floating-point ``dtype``, hence the smallest possible
float data type is used in that case to store the kernel.
- If the convolution kernel is complex, ``x`` will be cast to
complex dtype, and the returned object will be complex as well.
See Examples for further clarification.
out : `numpy.ndarray` or `Tensor`, optional
Object to which the result of the convolution should be written.
Its shape and data type must be compatible with the result of the
convolution, which can be determined by ::
res_dtype = np.result_type(x.dtype, y.dtype, np.float16)
axis : int or sequence of ints, optional
Coordinate axis or axes in which to take the convolution.
``None`` means all input axes.
impl : {'fft', 'real'}
Implementation of the convolution as FFT-based or using
direct summation. The fastest available FFT backend is
chosen automatically. Real space convolution is based on
`scipy.signal.convolve`.
See Notes for further information on the backends.
padding : int or sequence of ints, optional
Zero-padding used before Fourier transform in the FFT backend.
Does not apply for ``impl='real'``. A sequence is applied per
axis, with padding values corresponding to ``axis`` entries
as provided.
Default: ``min(kernel.shape - 1, 64)``
padded_shape : sequence of ints, optional
Apply zero-padding with this target shape. Cannot be used
together with ``padding``.
"""
y = np.asarray(y)
y_is_complex = issubclass(y.dtype.type, np.complexfloating)

if not isinstance(x, (DiscreteLpElement, Tensor)):
x = np.asarray(x)
x = tensor_space(x.shape, dtype=x.dtype).element(x)

dom_dtype = np.promote_types(x.dtype, y.dtype) if y_is_complex else x.dtype
if y_is_complex:
dom_dtype = np.promote_types(x.dtype, y.dtype)
else:
dom_dtype = x.dtype

domain = x.space.astype(dom_dtype)

conv = DiscreteConvolution(domain, y, **kwargs)
if out is None:
out = conv(x)
elif isinstance(out, np.ndarray):
res = conv.range.element(out)
if res.data is not out:
raise TypeError('`out` {!r} is not compatible with the range {!r}'
'of the convolution'.format(out, conv.range))
conv(x, out=res)
else:
res = conv.range.element(out)
if out is not res and out is not getattr(res, 'tensor', None):
raise TypeError('`out` {!r} is not compatible with the range {!r}'
'of the convolution'.format(out, conv.range))
conv(x, out=res)

return out


if __name__ == '__main__':
from odl.util import run_doctests
run_doctests()

0 comments on commit 5b5007f

Please sign in to comment.