diff --git a/odl/oplib/convolution.py b/odl/oplib/convolution.py index 4b697ada12a..42a2c91d3a1 100644 --- a/odl/oplib/convolution.py +++ b/odl/oplib/convolution.py @@ -11,6 +11,7 @@ from __future__ import division import numpy as np +from odl.discr import DiscreteLpElement from odl.operator import Operator from odl.space import tensor_space from odl.space.base_tensors import TensorSpace, Tensor @@ -20,7 +21,8 @@ is_real_dtype, is_floating_dtype, dtype_str, writable_array) from odl.util.npy_compat import roll -__all__ = ('DiscreteConvolution',) + +__all__ = ('DiscreteConvolution', 'convolve') class DiscreteConvolution(Operator): @@ -55,8 +57,8 @@ def __init__(self, domain, kernel, range=None, axis=None, impl='fft', float data type is used in that case to store the kernel. - If the convolution kernel is complex, the ``domain`` **must** be a complex space. Real-to-complex convolutions are not - allowed and need to defined by composition with - `ComplexEmbedding`. + allowed and need to be defined by composition with + `ComplexEmbedding` instead. See Examples for further clarification. @@ -732,6 +734,98 @@ def _prepare_for_fft(kernel, padded_shape, axes=None, variant='forward'): return roll(padded, shifts, axis=axes) +def convolve(x, y, out=None, **kwargs): + """Return the convolution ``x * y``. + + This is a convenience function for quickly computing a convolution + without having to explicitly create a `DiscreteConvolution` instance. + + Parameters + ---------- + x : array-like + Array or discrete function that is supposed to be convolved with + ``y``. Its type determines the return type. + y : array-like + The kernel with which ``x`` is convolved. It must have the same + number of dimensions as ``x``, and its shape can be at most equal + to ``x.shape``. In axes with size 1, broadcasting is applied. + + **Important:** + + - The kernel must **always** have the same number of dimensions + as ``x``, even for convolutions along axes. + - In axes where no convolution is performed, the shape of the + kernel must either be 1 (broadcasting along these axes), or + equal to the ``x.shape`` (stacked kernel, only supported + for `impl='fft'`). + - The ``'fft'`` implementation needs the kernel to have + floating-point ``dtype``, hence the smallest possible + float data type is used in that case to store the kernel. + - If the convolution kernel is complex, ``x`` will be cast to + complex dtype, and the returned object will be complex as well. + + See Examples for further clarification. + + out : `numpy.ndarray` or `Tensor`, optional + Object to which the result of the convolution should be written. + Its shape and data type must be compatible with the result of the + convolution, which can be determined by :: + + res_dtype = np.result_type(x.dtype, y.dtype, np.float16) + + axis : int or sequence of ints, optional + Coordinate axis or axes in which to take the convolution. + ``None`` means all input axes. + impl : {'fft', 'real'} + Implementation of the convolution as FFT-based or using + direct summation. The fastest available FFT backend is + chosen automatically. Real space convolution is based on + `scipy.signal.convolve`. + See Notes for further information on the backends. + padding : int or sequence of ints, optional + Zero-padding used before Fourier transform in the FFT backend. + Does not apply for ``impl='real'``. A sequence is applied per + axis, with padding values corresponding to ``axis`` entries + as provided. + Default: ``min(kernel.shape - 1, 64)`` + padded_shape : sequence of ints, optional + Apply zero-padding with this target shape. Cannot be used + together with ``padding``. + """ + y = np.asarray(y) + y_is_complex = issubclass(y.dtype.type, np.complexfloating) + + if not isinstance(x, (DiscreteLpElement, Tensor)): + x = np.asarray(x) + x = tensor_space(x.shape, dtype=x.dtype).element(x) + + dom_dtype = np.promote_types(x.dtype, y.dtype) if y_is_complex else x.dtype + if y_is_complex: + dom_dtype = np.promote_types(x.dtype, y.dtype) + else: + dom_dtype = x.dtype + + domain = x.space.astype(dom_dtype) + + conv = DiscreteConvolution(domain, y, **kwargs) + if out is None: + out = conv(x) + elif isinstance(out, np.ndarray): + res = conv.range.element(out) + if res.data is not out: + raise TypeError('`out` {!r} is not compatible with the range {!r}' + 'of the convolution'.format(out, conv.range)) + conv(x, out=res) + else: + res = conv.range.element(out) + if out is not res and out is not getattr(res, 'tensor', None): + raise TypeError('`out` {!r} is not compatible with the range {!r}' + 'of the convolution'.format(out, conv.range)) + conv(x, out=res) + + return out + + if __name__ == '__main__': from odl.util import run_doctests run_doctests()