Source code for sporco.fft

# -*- coding: utf-8 -*-
# Copyright (C) 2015-2020 by Brendt Wohlberg <brendt@ieee.org>
# All rights reserved. BSD 3-clause License.
# This file is part of the SPORCO package. Details of the copyright
# and user license can be found in the 'LICENSE.txt' file distributed
# with the package.

"""Variants of the Fast Fourier Transform and associated functions."""

from __future__ import division
from builtins import range

import multiprocessing
import numpy as np
from scipy import fftpack
try:
    import pyfftw
except ImportError:
    have_pyfftw = False
    import warnings
    warnings.warn('Module pyfftw could not be imported. FFT '
                  'computations will be performed using numpy.fft, '
                  'which may be substantially slower', RuntimeWarning)
    import numpy.fft as npfft
else:
    have_pyfftw = True



__author__ = """Brendt Wohlberg <brendt@ieee.org>"""


if have_pyfftw:
    pyfftw.interfaces.cache.enable()
    pyfftw.interfaces.cache.set_keepalive_time(300)

pyfftw_threads = multiprocessing.cpu_count()
"""Global variable setting the number of threads used in :mod:`pyfftw`
computations"""
pyfftw_planner_effort = 'FFTW_MEASURE'
"""FFTW planning rigor flag used in :mod:`pyfftw` computations"""


def is_complex_dtype(dtype):
    """Determine whether a dtype is complex.

    Parameters
    ----------
    dtype : dtype
      A dtype, e.g. np.float32, np.float64, np.complex128

    Returns
    -------
    bool
      True if the dtype is complex, otherwise False
    """

    return dtype.kind == 'c'



[docs] def complex_dtype(dtype): """Construct the corresponding complex dtype for a given real dtype. Construct the corresponding complex dtype for a given real dtype, e.g. the complex dtype corresponding to ``np.float32`` is ``np.complex64``. Parameters ---------- dtype : dtype A real dtype, e.g. np.float32, np.float64 Returns ------- dtype The complex dtype corresponding to the input dtype """ return (np.zeros(1, dtype) + 1j).dtype
[docs] def real_dtype(dtype): """Construct the corresponding real dtype for a given complex dtype. Construct the corresponding real dtype for a given complex dtype, e.g. the real dtype corresponding to ``np.complex64`` is ``np.float32``. Parameters ---------- dtype : dtype A complex dtype, e.g. np.complex64, np.complex128 Returns ------- dtype The real dtype corresponding to the input dtype """ return np.zeros(1, dtype).real.dtype
def byte_aligned(array, dtype=None, n=None): """Construct a byte-aligned array for FFTs. Construct a byte-aligned array for efficient use by :mod:`pyfftw`. This function is a wrapper for :func:`pyfftw.byte_align` Parameters ---------- array : ndarray Input array dtype : dtype, optional (default None) Output array dtype n : int, optional (default None) Output array should be aligned to n-byte boundary Returns ------- ndarray Array with required byte-alignment """ return pyfftw.byte_align(array, n=n, dtype=dtype) def empty_aligned(shape, dtype, order='C', n=None): """Construct an empty byte-aligned array for FFTs. Construct an empty byte-aligned array for efficient use by :mod:`pyfftw`. This function is a wrapper for :func:`pyfftw.empty_aligned` Parameters ---------- shape : sequence of ints Output array shape dtype : dtype Output array dtype order : {'C', 'F'}, optional (default 'C') Specify whether arrays should be stored in row-major (C-style) or column-major (Fortran-style) order n : int, optional (default None) Output array should be aligned to n-byte boundary Returns ------- ndarray Empty array with required byte-alignment """ return pyfftw.empty_aligned(shape, dtype, order, n) def rfftn_empty_aligned(shape, axes, dtype, order='C', n=None): """Construct an empty byte-aligned array for real FFTs. Construct an empty byte-aligned array for efficient use by :mod:`pyfftw` functions :func:`pyfftw.interfaces.numpy_fft.rfftn` and :func:`pyfftw.interfaces.numpy_fft.irfftn`. The shape of the empty array is appropriate for the output of :func:`pyfftw.interfaces.numpy_fft.rfftn` applied to an array of the shape specified by parameter `shape`, and for the input of the corresponding :func:`pyfftw.interfaces.numpy_fft.irfftn` call that reverses this operation. Parameters ---------- shape : sequence of ints Output array shape axes : sequence of ints Axes on which the FFT will be computed dtype : dtype Real dtype from which the complex dtype of the output array is derived order : {'C', 'F'}, optional (default 'C') Specify whether arrays should be stored in row-major (C-style) or column-major (Fortran-style) order n : int, optional (default None) Output array should be aligned to n-byte boundary Returns ------- ndarray Empty array with required byte-alignment """ ashp = list(shape) raxis = axes[-1] ashp[raxis] = ashp[raxis] // 2 + 1 cdtype = complex_dtype(dtype) return pyfftw.empty_aligned(ashp, cdtype, order, n) def fftn(a, s=None, axes=None): """Multi-dimensional discrete Fourier transform. Compute the multi-dimensional discrete Fourier transform. This function is a wrapper for :func:`pyfftw.interfaces.numpy_fft.fftn`, with an interface similar to that of :func:`numpy.fft.fftn`. Parameters ---------- a : array_like Input array (can be complex) s : sequence of ints, optional (default None) Shape of the output along each transformed axis (input is cropped or zero-padded to match). axes : sequence of ints, optional (default None) Axes over which to compute the DFT. Returns ------- complex ndarray DFT of input array """ return pyfftw.interfaces.numpy_fft.fftn( a, s=s, axes=axes, overwrite_input=False, planner_effort=pyfftw_planner_effort, threads=pyfftw_threads) def ifftn(a, s=None, axes=None): """Multi-dimensional inverse discrete Fourier transform. Compute the multi-dimensional inverse discrete Fourier transform. This function is a wrapper for :func:`pyfftw.interfaces.numpy_fft.ifftn`, with an interface similar to that of :func:`numpy.fft.ifftn`. Parameters ---------- a : array_like Input array (can be complex) s : sequence of ints, optional (default None) Shape of the output along each transformed axis (input is cropped or zero-padded to match). axes : sequence of ints, optional (default None) Axes over which to compute the inverse DFT. Returns ------- complex ndarray Inverse DFT of input array """ return pyfftw.interfaces.numpy_fft.ifftn( a, s=s, axes=axes, overwrite_input=False, planner_effort=pyfftw_planner_effort, threads=pyfftw_threads) def rfftn(a, s=None, axes=None): """Multi-dimensional discrete Fourier transform for real input. Compute the multi-dimensional discrete Fourier transform for real input. This function is a wrapper for :func:`pyfftw.interfaces.numpy_fft.rfftn`, with an interface similar to that of :func:`numpy.fft.rfftn`. Parameters ---------- a : array_like Input array (taken to be real) s : sequence of ints, optional (default None) Shape of the output along each transformed axis (input is cropped or zero-padded to match). axes : sequence of ints, optional (default None) Axes over which to compute the DFT. Returns ------- complex ndarray DFT of input array """ return pyfftw.interfaces.numpy_fft.rfftn( a, s=s, axes=axes, overwrite_input=False, planner_effort=pyfftw_planner_effort, threads=pyfftw_threads) def irfftn(a, s, axes=None): """Multi-dimensional inverse discrete Fourier transform for real input. Compute the inverse of the multi-dimensional discrete Fourier transform for real input. This function is a wrapper for :func:`pyfftw.interfaces.numpy_fft.irfftn`, with an interface similar to that of :func:`numpy.fft.irfftn`. Parameters ---------- a : array_like Input array s : sequence of ints Shape of the output along each transformed axis (input is cropped or zero-padded to match). This parameter is not optional because, unlike :func:`ifftn`, the output shape cannot be uniquely determined from the input shape. axes : sequence of ints, optional (default None) Axes over which to compute the inverse DFT. Returns ------- ndarray Inverse DFT of input array """ return pyfftw.interfaces.numpy_fft.irfftn( a, s=s, axes=axes, overwrite_input=False, planner_effort=pyfftw_planner_effort, threads=pyfftw_threads)
[docs] def dctii(x, axes=None): """Multi-dimensional DCT-II. Compute a multi-dimensional DCT-II over specified array axes. This function is implemented by calling the one-dimensional DCT-II :func:`scipy.fftpack.dct` with normalization mode 'ortho' for each of the specified axes. Parameters ---------- a : array_like Input array axes : sequence of ints, optional (default None) Axes over which to compute the DCT-II. Returns ------- ndarray DCT-II of input array """ if axes is None: axes = list(range(x.ndim)) for ax in axes: x = fftpack.dct(x, type=2, axis=ax, norm='ortho') return x
[docs] def idctii(x, axes=None): """Multi-dimensional inverse DCT-II. Compute a multi-dimensional inverse DCT-II over specified array axes. This function is implemented by calling the one-dimensional inverse DCT-II :func:`scipy.fftpack.idct` with normalization mode 'ortho' for each of the specified axes. Parameters ---------- a : array_like Input array axes : sequence of ints, optional (default None) Axes over which to compute the inverse DCT-II. Returns ------- ndarray Inverse DCT-II of input array """ if axes is None: axes = list(range(x.ndim)) for ax in axes[::-1]: x = fftpack.idct(x, type=2, axis=ax, norm='ortho') return x
[docs] def fftconv(a, b, axes=None, origin=None): """Multi-dimensional convolution via the Discrete Fourier Transform. Compute a multi-dimensional convolution via the Discrete Fourier Transform. Note that the output has a phase shift relative to the output of :func:`scipy.ndimage.convolve` with the default `origin` parameter. Parameters ---------- a : array_like Input array b : array_like Input array axes : sequence of ints or None optional (default None) Axes on which to perform convolution. The default of None selects all axes of `a` origin : sequence of ints or None optional (default None) Indices of centre of `a` filter. The default of None corresponds to a centre at 0 on all axes of `a` Returns ------- ndarray Convolution of input arrays, `a` and `b`, along specified `axes` """ if axes is None: axes = tuple(range(a.ndim)) if np.isrealobj(a) and np.isrealobj(b): fft = rfftn ifft = irfftn else: fft = fftn ifft = ifftn dims = np.maximum([a.shape[i] for i in axes], [b.shape[i] for i in axes]) af = fft(a, dims, axes) bf = fft(b, dims, axes) ab = ifft(af * bf, dims, axes) if origin is not None: ab = np.roll(ab, -np.array(origin), axis=axes) return ab
[docs] def fl2norm2(xf, axis=(0, 1)): r"""Compute the squared :math:`\ell_2` norm in the DFT domain. Compute the squared :math:`\ell_2` norm in the DFT domain, taking into account the unnormalised DFT scaling, i.e. given the DFT of a multi-dimensional array computed via :func:`fftn`, return the squared :math:`\ell_2` norm of the original array. Parameters ---------- xf : array_like Input array axis : sequence of ints, optional (default (0,1)) Axes on which the input is in the frequency domain Returns ------- float :math:`\|\mathbf{x}\|_2^2` where the input array is the result of applying :func:`fftn` to the specified axes of multi-dimensional array :math:`\mathbf{x}` """ xfs = xf.shape return (np.linalg.norm(xf)**2) / np.prod(np.array([xfs[k] for k in axis]))
[docs] def rfl2norm2(xf, xs, axis=(0, 1)): r"""Compute the squared :math:`\ell_2` norm in the real DFT domain. Compute the squared :math:`\ell_2` norm in the DFT domain, taking into account the unnormalised DFT scaling, i.e. given the DFT of a multi-dimensional array computed via :func:`rfftn`, return the squared :math:`\ell_2` norm of the original array. Parameters ---------- xf : array_like Input array xs : sequence of ints Shape of original array to which :func:`rfftn` was applied to obtain the input array axis : sequence of ints, optional (default (0,1)) Axes on which the input is in the frequency domain Returns ------- float :math:`\|\mathbf{x}\|_2^2` where the input array is the result of applying :func:`rfftn` to the specified axes of multi-dimensional array :math:`\mathbf{x}` """ scl = 1.0 / np.prod(np.array([xs[k] for k in axis])) slc0 = (slice(None),) * axis[-1] nrm0 = np.linalg.norm(xf[slc0 + (0,)]) idx1 = (xs[axis[-1]] + 1) // 2 nrm1 = np.linalg.norm(xf[slc0 + (slice(1, idx1),)]) if xs[axis[-1]] % 2 == 0: nrm2 = np.linalg.norm(xf[slc0 + (slice(-1, None),)]) else: nrm2 = 0.0 return scl*(nrm0**2 + 2.0*nrm1**2 + nrm2**2)
def empty_aligned_func(real=False): """Get a reference to :func:`empty_aligned` or :func:`rfftn_empty_aligned`. If `real` is True, return a reference to :func:`rfftn_empty_aligned`, otherwise return a reference to a wrapper function of :func:`empty_aligned` that has the same signature as :func:`rfftn_empty_aligned` (i.e. including a dummy `axes` parameter). Parameters ---------- real : bool, optional (default False) Flag indicating which function reference to return Returns ------- function Reference to selected function """ if real: return rfftn_empty_aligned else: def empty_aligned_wrapper(shape, axes, dtype, order='C', n=None): return empty_aligned(shape, dtype, order=order, n=n) return empty_aligned_wrapper def fftn_func(real=False): """Get a reference to :func:`fftn` or :func:`rfftn`. If `real` is True, return a reference to :func:`rfftn`, otherwise return a reference to :func:`fftn`. Parameters ---------- real : bool, optional (default False) Flag indicating which function reference to return Returns ------- function Reference to selected function """ if real: return rfftn else: return fftn def ifftn_func(real=False): """Get a reference to :func:`ifftn` or :func:`irfftn`. If `real` is True, return a reference to :func:`irfftn`, otherwise return a reference to :func:`ifftn`. Parameters ---------- real : bool, optional (default False) Flag indicating which function reference to return Returns ------- function Reference to selected function """ if real: return irfftn else: return ifftn def fl2norm2_func(real=False): """Get a reference to :func:`fl2norm2` or :func:`rfl2norm2`. If `real` is True, return a reference to :func:`rfl2norm2`, otherwise return a reference to a wrapper function of :func:`fl2norm2` that has the same signature as :func:`rfl2norm2` (i.e. including a dummy `xs` parameter). Parameters ---------- real : bool, optional (default False) Flag indicating which function reference to return Returns ------- function Reference to selected function """ if real: return rfl2norm2 else: def fl2norm2_wrapper(xf, xs, axis=(0, 1)): return fl2norm2(xf, axis=axis) return fl2norm2_wrapper if not have_pyfftw: __all__ = ['complex_dtype', 'real_dtype', 'byte_aligned', 'empty_aligned', 'rfftn_empty_aligned', 'fftn', 'ifftn', 'rfftn', 'irfftn', 'dctii', 'idctii', 'fftconv', 'fl2norm2', 'rfl2norm2'] def _aligned(array, dtype=None, n=None): if dtype is None: return array else: return array.astype(dtype) _aligned.__doc__ = byte_aligned.__doc__ byte_aligned = _aligned def _empty(shape, dtype, order='C', n=None): return np.empty(shape, dtype=dtype) _empty.__doc__ = empty_aligned.__doc__ empty_aligned = _empty def _rfftn_empty(shape, axes, dtype, order='C', n=None): ashp = list(shape) raxis = axes[-1] ashp[raxis] = ashp[raxis] // 2 + 1 cdtype = complex_dtype(dtype) return np.empty(ashp, dtype=cdtype) _rfftn_empty.__doc__ = rfftn_empty_aligned.__doc__ rfftn_empty_aligned = _rfftn_empty def _fftn(a, s=None, axes=None): return npfft.fftn(a, s, axes).astype(complex_dtype(a.dtype)) _fftn.__doc__ = fftn.__doc__ fftn = _fftn def _ifftn(a, s=None, axes=None): return npfft.ifftn(a, s, axes).astype(a.dtype) _ifftn.__doc__ = ifftn.__doc__ ifftn = _ifftn def _rfftn(a, s=None, axes=None): return npfft.rfftn(a, s, axes).astype(complex_dtype(a.dtype)) _rfftn.__doc__ = rfftn.__doc__ rfftn = _rfftn def _irfftn(a, s=None, axes=None): return npfft.irfftn(a, s, axes).astype(real_dtype(a.dtype)) _irfftn.__doc__ = irfftn.__doc__ irfftn = _irfftn