Fourier Space

In many physics and machine learning applications, it’s advantageous to work with data in Fourier space. In particular, for real-valued data on a lattice (e.g. a discretized image, where the lattice is the pixel grid), the Fast Fourier Transform (FFT) gives a natural representation when translational symmetries are present. Bijx provides convenient ways to manage different representations of Fourier-space data through the fourier.FourierData class.
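To make the role of translational symmetry concrete: a periodic translation of the lattice only multiplies each Fourier coefficient by a phase, so translation-invariant quantities such as \(|\tilde{x}_k|\) are unchanged. The sketch below checks this with plain NumPy (not bijx) on an arbitrary 8×8 array; the shift (2, 3) is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(5)
x = rng.standard_normal((8, 8))
# Periodic translation on the lattice: shifted[i, j] = x[i - 2, j - 3]
shifted = np.roll(x, shift=(2, 3), axis=(0, 1))

X, Xs = np.fft.fftn(x), np.fft.fftn(shifted)

# Shift theorem: each mode k picks up the phase exp(-2 pi i (2 k0 + 3 k1) / 8)
kx, ky = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
phase = np.exp(-2j * np.pi * (kx * 2 + ky * 3) / 8)
print(np.allclose(Xs, phase * X))          # True

# Hence the magnitudes (e.g. the power spectrum) are translation invariant
print(np.allclose(np.abs(Xs), np.abs(X)))  # True
```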

A particular challenge, especially in the context of normalizing flows and bijections, is that not all complex-valued coefficients output by the FFT are independent, whereas a bijection must act on exactly as many real numbers as the original data contains. Denoting by \(\tilde{x}_k\) the FFT of a real-valued array \(x_i\), we have the reality constraint \(\tilde{x}_k = \tilde{x}_{-k}^*\), where momenta are understood modulo the lattice size. This symmetry implies that roughly half of the coefficients are redundant. Both numpy and jax.numpy provide real-input FFTs (rfftn) that eliminate some of this redundancy, but not all of it, because the output is kept as a single complex-valued array with a convenient rectangular shape. Manually keeping track of the remaining constraints is tedious and error-prone.
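The reality constraint is easy to verify numerically. This sketch uses plain NumPy's full complex FFT rather than bijx; along an axis of length \(n\), the index of \(-k\) is \((-k) \bmod n\).

```python
import numpy as np

# Arbitrary real-valued data on an 8x8 periodic lattice
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))

# Full complex FFT: 64 complex coefficients, i.e. 128 real numbers for 64 inputs
x_tilde = np.fft.fftn(x)
print(2 * x_tilde.size, "real numbers vs", x.size)  # 128 real numbers vs 64

# Grid index of -k along each axis: (-k) mod n
neg = [(-np.arange(n)) % n for n in x.shape]
x_tilde_neg = x_tilde[np.ix_(*neg)]

# Reality constraint: x~_k equals the complex conjugate of x~_{-k}
print(np.allclose(x_tilde, np.conj(x_tilde_neg)))  # True
```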

Container for Fourier Representations

The fourier.FourierData class acts as a container for data, along with metadata about its current representation. It allows seamless conversion between different representations without having to worry about the underlying symmetry constraints. Any spatial shape and dimension is supported, as well as batch and channel dimensions.

import jax.numpy as jnp
import numpy as np
from bijx.fourier import FourierData, FFTRep
# Let's create a sample 2D real-valued array (e.g., a grayscale image)
real_shape = (8, 8)
x_real = jnp.array(np.random.randn(*real_shape))

print(f"Original real-space data shape: {x_real.shape}")
print(f"Total degrees of freedom: {x_real.size}")
Original real-space data shape: (8, 8)
Total degrees of freedom: 64

The FFTRep enum defines the four possible representations that FourierData can manage and convert between:

  • real_space: The original, real-valued data.

  • rfft: The direct output of jnp.fft.rfftn. This is a complex array with a reduced last dimension, but it still contains redundant information due to the Hermitian symmetry.

  • comp_complex: The set of independent complex Fourier coefficients. This is the most compact complex representation.

  • comp_real: The independent degrees of freedom, packed into a single real-valued array. This is often the most convenient representation for use in neural networks.
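The sizes of these representations follow from the symmetry alone. As a sketch of the bookkeeping (our own, not bijx code; count_dof is a hypothetical helper), the count below treats the self-conjugate modes \(k = -k\) (mod the lattice size) separately: along an axis of length \(n\) these are \(k = 0\) and, for even \(n\), \(k = n/2\).

```python
import numpy as np

def count_dof(real_shape):
    """Hypothetical helper: sizes of each representation for a real lattice.

    real_space and comp_real count real numbers; rfft and comp_complex count
    complex entries.
    """
    n_total = int(np.prod(real_shape))
    # rfftn keeps n // 2 + 1 frequencies along the last axis only
    n_rfft = int(np.prod(real_shape[:-1])) * (real_shape[-1] // 2 + 1)
    # Self-conjugate modes (k == -k on every axis) have purely real coefficients
    n_self = int(np.prod([2 if n % 2 == 0 else 1 for n in real_shape]))
    # All other modes come in conjugate pairs; keep one member of each pair
    n_comp_complex = (n_total - n_self) // 2 + n_self
    return {
        "real_space": n_total,
        "rfft": n_rfft,
        "comp_complex": n_comp_complex,
        # real and imaginary parts, minus the imaginary parts that vanish
        "comp_real": 2 * n_comp_complex - n_self,
    }

print(count_dof((8, 8)))
# {'real_space': 64, 'rfft': 40, 'comp_complex': 34, 'comp_real': 64}
```

For any shape, comp_real has exactly as many entries as real_space, which is the property a bijection needs.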

# Create a FourierData object from our real-space data.
# real_shape is passed to infer the batch dimensions.
fd = FourierData.from_real(x_real, real_shape, channel_dim=0)

print(f"Initial representation: {fd.rep.name}")
print(f"Shape of data: {fd.data.shape}")
Initial representation: real_space
Shape of data: (8, 8)

rfftn array

The primary way to interact with a FourierData object is through its to() method. Let’s convert our data to the raw rfft representation.

# Convert to the raw rFFT output representation
fd_rfft = fd.to(FFTRep.rfft)

print(f"Representation after converting to rfft: {fd_rfft.rep.name}")
print(f"Shape of rFFT data: {fd_rfft.data.shape}")
Representation after converting to rfft: rfft
Shape of rFFT data: (8, 5)

Notice that the last dimension of the rfft output has size 8 // 2 + 1 = 5: rfftn keeps only the frequencies 0 through 4 along the last axis, because the remaining ones follow from the Hermitian symmetry. Nonetheless, since every entry is now complex, fd_rfft.data holds more real numbers than the original real array:

print(fd_rfft.data.dtype)
print(f"Total real values: {2 * fd_rfft.data.size} > {x_real.size}")
complex64
Total real values: 80 > 64
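The relation between rfftn and the full FFT can be checked directly. This NumPy sketch (jax.numpy.fft.rfftn follows the same convention) shows that rfftn is the full FFT truncated along the last axis, and that conjugate pairs survive in the columns whose last-axis frequency satisfies \(j = -j\) (mod 8), i.e. \(j = 0\) and \(j = 4\). These leftover pairs are the redundancy counted above.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((8, 8))

full = np.fft.fftn(x)   # shape (8, 8)
half = np.fft.rfftn(x)  # shape (8, 5)

# rfftn is the full FFT restricted to last-axis frequencies 0 .. 8 // 2
print(np.allclose(half, full[:, : 8 // 2 + 1]))  # True

# Leftover redundancy: in columns j = 0 and j = 4 (where -j = j mod 8),
# rows i and (-i) mod 8 are still complex conjugates of each other
rows = np.arange(8)
print(np.allclose(half[:, 0], np.conj(half[(-rows) % 8, 0])))  # True
print(np.allclose(half[:, 4], np.conj(half[(-rows) % 8, 4])))  # True
```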

Independent degrees of freedom

First, we can reduce to a set of independent complex values, comp_complex, by eliminating conjugate pairs. This set still contains a few entries that are constrained to be purely real (the self-conjugate modes, for which \(k = -k\) modulo the lattice size). To get a strictly independent set of real values, we can use the comp_real representation as demonstrated below.

# Convert to independent complex components
fd_comp_complex = fd.to(FFTRep.comp_complex)

print(f"Representation: {fd_comp_complex.rep.name}")
print(fd_comp_complex.data.dtype)
# Some entries (the self-conjugate modes) must be purely real, but they are
# stored as complex numbers so that everything fits in a single array
print(f"Total real values: {2 * fd_comp_complex.data.size}")

print(f"Number of independent complex coefficients: {fd_comp_complex.data.size}")
Representation: comp_complex
complex64
Total real values: 68
Number of independent complex coefficients: 34
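The purely real entries are the self-conjugate modes. The NumPy sketch below (self_conjugate is a hypothetical helper, not part of bijx) locates them on the 8×8 rfft grid, checks that their imaginary parts vanish, and confirms the bookkeeping 2 × 34 − 4 = 64.

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal((8, 8))
x_rfft = np.fft.rfftn(x)  # shape (8, 5)

def self_conjugate(n):
    """Frequencies k with k == -k (mod n) along an axis of length n."""
    return [0, n // 2] if n % 2 == 0 else [0]

# On the rfft grid the last-axis index equals the frequency (0 .. n // 2),
# so the self-conjugate modes are the Cartesian product over both axes.
modes = [(i, j) for i in self_conjugate(8) for j in self_conjugate(8)]
print(modes)  # [(0, 0), (0, 4), (4, 0), (4, 4)]

# Their coefficients are real up to floating-point round-off ...
print(np.allclose([x_rfft[m].imag for m in modes], 0.0))  # True
# ... so dropping their imaginary parts from the 34 complex coefficients
# leaves exactly 64 real numbers
print(2 * 34 - len(modes))  # 64
```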

Finally, the most compact representation, comp_real, packs all independent real and imaginary parts into a single real-valued array.

# Convert to independent real degrees of freedom
fd_comp_real = fd.to(FFTRep.comp_real)

print(f"Representation: {fd_comp_real.rep.name}")
print(f"Shape of independent real DoF data: {fd_comp_real.data.shape}")
print(f"Size of independent real DoF: {fd_comp_real.data.size}")

# Verify that the total degrees of freedom are conserved
fd_comp_real.data.size == x_real.size
Representation: comp_real
Shape of independent real DoF data: (64,)
Size of independent real DoF: 64
True
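How the independent numbers are ordered inside comp_real is an internal choice of bijx that this page does not specify. The sketch below shows one possible layout, which is an assumption: all real parts first, followed by the imaginary parts of the coefficients that are not self-conjugate. The helpers pack_comp_real and unpack_comp_real, and the positions of the four purely real modes, are hypothetical.

```python
import numpy as np

def pack_comp_real(z, is_real):
    """Hypothetical layout: real parts of all coefficients, then the imaginary
    parts of those that are not self-conjugate (is_real marks the latter)."""
    return np.concatenate([z.real, z[~is_real].imag])

def unpack_comp_real(v, is_real):
    """Inverse of pack_comp_real for the same is_real mask."""
    n = is_real.size
    z = v[:n].astype(complex)
    z[~is_real] += 1j * v[n:]
    return z

# 34 independent coefficients, 4 of them purely real (positions chosen arbitrarily)
rng = np.random.default_rng(3)
is_real = np.zeros(34, dtype=bool)
is_real[:4] = True
z = rng.standard_normal(34) + 1j * rng.standard_normal(34)
z[is_real] = z[is_real].real  # enforce zero imaginary part

v = pack_comp_real(z, is_real)
print(v.shape)  # (64,)
print(np.allclose(unpack_comp_real(v, is_real), z))  # True
```

Any fixed, invertible layout of this kind yields exactly 64 real numbers, the same count as the real-space array.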

Round-trip sanity check

# Let's do a full round-trip: real -> comp_real -> real
fd_comp_real = FourierData.from_real(x_real, real_shape, to=FFTRep.comp_real)
x_reconstructed = fd_comp_real.to(FFTRep.real_space).data

# Check if the reconstruction is close to the original
np.allclose(x_real, x_reconstructed, atol=1e-6)
True
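When inverting an rfft by hand with numpy or jax.numpy, one subtlety is that the rfft shape alone does not determine the real-space shape: last-axis lengths 2m − 2 and 2m − 1 both produce m rfft entries. Passing s to irfftn resolves this, which is one reason the real-space shape has to be kept alongside Fourier-space data. A plain NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(4)
x = rng.standard_normal((8, 7))  # odd-length last axis
X = np.fft.rfftn(x)              # shape (8, 4)

# Without s, irfftn assumes the last axis had length 2 * (4 - 1) = 6
print(np.fft.irfftn(X).shape)    # (8, 6)

# Passing the real-space shape recovers the original array
x_back = np.fft.irfftn(X, s=x.shape)
print(x_back.shape, np.allclose(x_back, x))  # (8, 7) True
```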

Fourier Metadata

While the FourierData class provides a high-level interface for most use cases, it may also be useful to access the underlying metadata for custom operations.

FourierMeta contains all the pre-computed indices and masks required to perform conversions between Fourier representations. This information is attached to every FourierData object.

Key attributes of FourierMeta include:

  • mr and mi: Boolean masks for the real and imaginary parts of the independent Fourier coefficients within the rfft output.

  • copy_from and copy_to: Index arrays that define the Hermitian symmetry relationships. They are used to reconstruct the full rfft array from the independent components.

  • ks_reduced: The squared momentum magnitudes \(|k|^2\) corresponding to the independent degrees of freedom.

The following example demonstrates how to create a FourierMeta object and inspect its properties.

from bijx.fourier import FourierMeta

# Create metadata for our shape
meta = FourierMeta.create(real_shape)

print("Shape of 'mr' (real part mask):", meta.mr.shape)
print("Number of independent real components:", meta.mr.sum())
print("Number of independent imaginary components:", meta.mi.sum())
print("Total degrees of freedom:", meta.mr.sum() + meta.mi.sum())

# These are the indices for reconstructing the dependent parts
# from the independent ones.
print(f"\nNumber of conjugate pairs to copy: {len(meta.copy_to)}")

# copy_to lists redundant rfft entries, whose values follow by conjugating the entries at copy_from
print("Example of a 'copy_to' index:", meta.copy_to[0])
print("Example of a 'copy_from' index:", meta.copy_from[0])
Shape of 'mr' (real part mask): (8, 5)
Number of independent real components: 34
Number of independent imaginary components: 30
Total degrees of freedom: 64

Number of conjugate pairs to copy: 6
Example of a 'copy_to' index: [5 0]
Example of a 'copy_from' index: [3 0]
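To see where numbers like these come from, the masks and copy indices can be re-derived from the symmetry with plain NumPy. This is a hypothetical sketch, not the bijx implementation: fourier_masks is our own name, and keeping the member of each conjugate pair with the smaller row-major index is an assumed convention. For an 8×8 lattice it reproduces the counts printed above, and the result is then used to rebuild the full rfft array from the 64 independent real numbers.

```python
import numpy as np

def fourier_masks(real_shape):
    """Hypothetical sketch: independence masks and copy indices on the rfft grid."""
    real_shape = tuple(real_shape)
    rfft_shape = real_shape[:-1] + (real_shape[-1] // 2 + 1,)
    idx = np.indices(rfft_shape)                        # (ndim, *rfft_shape)
    sizes = np.array(real_shape).reshape((-1,) + (1,) * len(real_shape))
    partner = (-idx) % sizes                            # grid index of -k
    # -k is stored in the rfft array only if its last-axis index is <= n // 2
    inside = partner[-1] <= real_shape[-1] // 2
    flat = np.ravel_multi_index(tuple(idx), real_shape)
    flat_partner = np.ravel_multi_index(tuple(partner), real_shape)
    self_conj = flat == flat_partner
    # Of each pair stored twice, keep the entry with the smaller flat index
    keep = ~inside | (flat <= flat_partner)
    mr = keep                        # independent real parts
    mi = keep & ~self_conj           # independent imaginary parts
    copy_to = np.argwhere(~keep)     # redundant entries, row-major order
    copy_from = partner[:, ~keep].T  # their conjugate partners, same order
    return mr, mi, copy_from, copy_to

mr, mi, copy_from, copy_to = fourier_masks((8, 8))
print(mr.shape, mr.sum(), mi.sum())            # (8, 5) 34 30
print(len(copy_to), copy_to[0], copy_from[0])  # 6 [5 0] [3 0]

# Round trip: rfft -> 64 independent reals -> rfft -> real space
rng = np.random.default_rng(6)
x = rng.standard_normal((8, 8))
X = np.fft.rfftn(x)
v = np.concatenate([X.real[mr], X.imag[mi]])  # 64 real numbers
re, im = np.zeros(mr.shape), np.zeros(mr.shape)
re[mr], im[mi] = v[: mr.sum()], v[mr.sum():]
Y = re + 1j * im
# Fill redundant entries with the conjugate of their partner
Y[tuple(copy_to.T)] = np.conj(Y[tuple(copy_from.T)])
print(v.size, np.allclose(Y, X), np.allclose(np.fft.irfftn(Y, s=x.shape), x))  # 64 True True
```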