# Source code for shampoo.fftutils
"""
This module is a wrapper around ``pyfftw``'s API for fast Fourier transforms.
"""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import numbers

import numpy as np
import pyfftw
from numpy.compat import integer_types
# Turn on pyfftw's interface-level cache at import time.
# NOTE(review): the code in this module builds its transforms with
# ``pyfftw.builders``, not ``pyfftw.interfaces``; the cache presumably only
# benefits ``pyfftw.interfaces`` callers elsewhere -- confirm it is needed.
pyfftw.interfaces.cache.enable()

# Names exported by ``from shampoo.fftutils import *``.
__all__ = [
    'FFT',
    'fftshift',
]
class FFT(object):
    """
    Convenience wrapper around ``pyfftw.builders.fft2`` and
    ``pyfftw.builders.ifft2``.

    Forward and inverse FFTW objects are built once, in the constructor, for a
    fixed array shape and precision; :meth:`fft2` and :meth:`ifft2` copy their
    argument into the corresponding object's input array and execute it.
    """
    def __init__(self, shape, float_precision, complex_precision, threads=2):
        """
        Parameters
        ----------
        shape : tuple
            Shape of the arrays which you will take the Fourier transforms of.
        float_precision : `~numpy.dtype` or dtype-like
            Real dtype of the buffer the forward transform is built from,
            e.g. ``np.float64`` or ``np.dtype('float32')``.
        complex_precision : `~numpy.dtype` or dtype-like
            Complex dtype of the buffer the inverse transform is built from,
            e.g. ``np.complex128``.
        threads : int, optional
            This FFT implementation uses multithreading, with
            two threads by default.
        """
        # ``np.dtype`` accepts scalar types (``np.float64``), dtype instances
        # and dtype strings alike; ``.__name__`` only exists on the scalar
        # types, so passing a ``numpy.dtype`` used to raise AttributeError.
        float_dtype = np.dtype(float_precision)
        complex_dtype = np.dtype(complex_precision)

        # Allocate byte-aligned buffers that the transforms below are built from.
        self.buffer_float = pyfftw.empty_aligned(shape, dtype=float_dtype)
        self.buffer_complex = pyfftw.empty_aligned(shape, dtype=complex_dtype)
        self._fft2 = pyfftw.builders.fft2(self.buffer_float, threads=threads)
        self._ifft2 = pyfftw.builders.ifft2(self.buffer_complex,
                                            threads=threads)

    def fft2(self, array):
        """
        2D Fourier transform.

        Parameters
        ----------
        array : `~numpy.ndarray` (real)
            Input array. It is copied (via slice assignment) into the forward
            transform's input array, so it must be assignable to an array of
            the ``shape`` given to the constructor.

        Returns
        -------
        ft_array : `~numpy.ndarray` (complex)
            Fourier transform of the input array. pyfftw's ``FFTW.__call__``
            returns the transform object's own output array, so this result
            is overwritten by the next call; copy it if it must be kept.
        """
        self._fft2.input_array[:] = array
        return self._fft2()

    def ifft2(self, array):
        """
        Inverse 2D Fourier transform.

        Parameters
        ----------
        array : `~numpy.ndarray`
            Input array. It is copied (via slice assignment) into the inverse
            transform's input array, so it must be assignable to an array of
            the ``shape`` given to the constructor.

        Returns
        -------
        ift_array : `~numpy.ndarray`
            Inverse Fourier transform of input array. As with :meth:`fft2`,
            this is the transform object's own output array and is
            overwritten by the next call; copy it if it must be kept.
        """
        self._ifft2.input_array[:] = array
        return self._ifft2()
def fftshift(x, additional_shift=None, axes=None):
    """
    Shift the zero-frequency component to the center of the spectrum, or with
    some additional offset from the center.

    This is a more generic fork of `~numpy.fft.fftshift`, which doesn't support
    additional shifts. With ``additional_shift=None`` the result equals
    ``numpy.fft.fftshift(x, axes)``.

    Parameters
    ----------
    x : array_like
        Input array.
    additional_shift : sequence of int, optional
        Extra shift (in samples) along each axis of ``axes``, in the same
        order, applied on top of the usual centring shift of ``n // 2``.
        For example, a 2D array with ``axes=None`` takes the shifts for
        axes 0 and 1. Positive values move data towards higher indices.
        Shifts wrap around cyclically, so any integer is valid. Shifts
        and axes are paired with ``zip``, so surplus entries on either
        side are ignored. Default is None, meaning no extra shift on any
        axis.
    axes : int or shape tuple, optional
        Axes over which to shift. Default is None, which shifts all axes.

    Returns
    -------
    y : `~numpy.ndarray`
        The shifted array.
    """
    tmp = np.asarray(x)
    ndim = tmp.ndim
    if axes is None:
        axes = list(range(ndim))
    elif isinstance(axes, numbers.Integral):
        # ``numbers.Integral`` also covers numpy integer scalars.
        axes = (axes,)
    # Use a zero extra shift for *every* requested axis so the result matches
    # ``numpy.fft.fftshift``; a fixed two-element default would silently leave
    # any axis beyond the second unshifted.
    if additional_shift is None:
        additional_shift = [0] * len(axes)
    y = tmp
    for k, extra_shift in zip(axes, additional_shift):
        n = tmp.shape[k]
        if n == 0:
            # Nothing to reorder along an empty axis.
            continue
        # Index that ends up first along axis ``k``. Reducing modulo ``n``
        # keeps it in ``[0, n)`` for any shift, so the index list below is
        # always a permutation of ``range(n)`` (never longer, never out of
        # bounds).
        p2 = ((n + 1) // 2 - extra_shift) % n
        mylist = np.concatenate((np.arange(p2, n), np.arange(0, p2)))
        y = np.take(y, mylist, k)
    return y