bijx.SpectrumScaling

class bijx.SpectrumScaling

Bases: ApplyBijection

Diagonal scaling transformation in Fourier space.

Applies element-wise scaling to the Fourier transform of real-valued fields, i.e. a transformation that is diagonal in momentum space. This is particularly useful for free field theories, whose covariance is diagonal in momentum space, and for spectral preconditioning.

Type: \(\mathbb{R}^{H \times W \times C} \to \mathbb{R}^{H \times W \times C}\)

Transform: \(\mathcal{F}^{-1}[s(\mathbf{k}) \mathcal{F}[\mathbf{x}]]\)
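The transform can be illustrated with plain jax.numpy FFTs, without bijx. The sketch below assumes a single scalar field of shape (8, 8) and a scaling that depends only on \(|\mathbf{k}|^2\), so that \(s(\mathbf{k}) = s(-\mathbf{k})\). It checks that the real-FFT form irfftn(s · rfftn(x)) agrees with the full complex form \(\mathcal{F}^{-1}[s(\mathbf{k}) \mathcal{F}[\mathbf{x}]]\). The helper `momenta_sq` is hypothetical.

```python
import jax
import jax.numpy as jnp


def momenta_sq(shape, real=False):
    """Hypothetical helper: |k|^2 on the full FFT grid, or on the rFFT grid if real=True."""
    ks = [2 * jnp.pi * jnp.fft.fftfreq(n) for n in shape]
    if real:
        # The rFFT keeps only the non-negative frequencies of the last axis.
        ks[-1] = 2 * jnp.pi * jnp.fft.rfftfreq(shape[-1])
    grids = jnp.meshgrid(*ks, indexing="ij")
    return sum(g**2 for g in grids)


shape = (8, 8)
x = jax.random.normal(jax.random.PRNGKey(0), shape)

# Symmetric scaling s(k) = exp(-0.1 |k|^2) on both grids.
s_full = jnp.exp(-0.1 * momenta_sq(shape))             # shape (8, 8)
s_half = jnp.exp(-0.1 * momenta_sq(shape, real=True))  # shape (8, 5), like rfftn(x)

# F^{-1}[s(k) F[x]] with the full complex FFT ...
y_full = jnp.fft.ifftn(s_full * jnp.fft.fftn(x)).real
# ... and the same map using only the stored half-spectrum.
y_half = jnp.fft.irfftn(s_half * jnp.fft.rfftn(x), s=shape)

print(jnp.max(jnp.abs(y_full - y_half)))  # float32 round-off only
```

The symmetry assumption matters: if \(s(\mathbf{k}) \neq s(-\mathbf{k})\), the scaled spectrum is no longer that of a real field and the two forms differ.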

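The scaling factors act as momentum-dependent multipliers. Because the map is diagonal in Fourier space, its log-Jacobian is \(\sum_{\mathbf{k}} m(\mathbf{k}) \log|s(\mathbf{k})|\), where the sum runs over the rFFT grid and the FFT multiplicity \(m(\mathbf{k})\) counts how many full-spectrum modes each stored entry represents. This accounts for the conjugate symmetry of real FFTs, which store only about half of the spectrum.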
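To see where the multiplicities come from: along the last axis, rFFT entries with index 0, or the Nyquist index \(n/2\) for even \(n\), are full-spectrum modes in their own right and count once; every other stored entry also stands for its unstored partner \(-\mathbf{k}\) and counts twice. The sketch below is an illustration under that reasoning, not bijx's implementation. The helpers `momenta_sq`, `rfft_multiplicity` and `log_jacobian` are hypothetical, and the scaling is assumed real with \(s(\mathbf{k}) = s(-\mathbf{k})\). It compares the multiplicity-weighted sum with the log-determinant of a dense Jacobian from jax.jacfwd.

```python
import jax
import jax.numpy as jnp


def momenta_sq(shape):
    """Hypothetical helper: |k|^2 on the rFFT grid of a real field."""
    ks = [2 * jnp.pi * jnp.fft.fftfreq(n) for n in shape[:-1]]
    ks.append(2 * jnp.pi * jnp.fft.rfftfreq(shape[-1]))
    return sum(g**2 for g in jnp.meshgrid(*ks, indexing="ij"))


def rfft_multiplicity(shape):
    """Hypothetical helper: number of full-spectrum modes per rFFT entry."""
    n = shape[-1]
    m = jnp.full(n // 2 + 1, 2.0).at[0].set(1.0)  # last-axis index 0 counts once
    if n % 2 == 0:
        m = m.at[-1].set(1.0)  # Nyquist entry counts once for even n
    return jnp.broadcast_to(m, (*shape[:-1], n // 2 + 1))


def log_jacobian(scaling, shape):
    """Multiplicity-weighted sum of log|s(k)| over the rFFT grid."""
    return jnp.sum(rfft_multiplicity(shape) * jnp.log(jnp.abs(scaling)))


shape = (6, 5)  # odd last axis: no Nyquist entry
scaling = jnp.exp(-0.1 * momenta_sq(shape))


def flat_map(v):
    """The real-FFT scaling written as a linear map R^30 -> R^30."""
    x = v.reshape(shape)
    return jnp.fft.irfftn(scaling * jnp.fft.rfftn(x), s=shape).ravel()


jac = jax.jacfwd(flat_map)(jnp.zeros(6 * 5))  # the map is linear, so any point works
sign, dense_logdet = jnp.linalg.slogdet(jac)
print(dense_logdet, log_jacobian(scaling, shape))  # agree up to round-off
```

The multiplicities of each rFFT grid sum to the number of real degrees of freedom (e.g. 64 for an 8 × 8 field), which is a quick sanity check.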

Parameters:
  • scaling (Array | Variable) – Scaling factors \(s(\mathbf{k})\), with the same shape as the rFFT output. If not an nnx.Variable or nnx.Param, the array is treated as a constant by default.

  • channel_dim (int) – Number of channel dimensions.
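One plausible reading of channel_dim is the number of trailing axes that hold channels rather than lattice directions, so that the FFT runs only over the leading spatial axes. The sketch below assumes that reading, which this page does not confirm, and uses a hypothetical helper `spatial_axes`. It shows the resulting rFFT shape for an H × W × C field and a per-channel scaling of that shape.

```python
import jax
import jax.numpy as jnp


def spatial_axes(ndim, channel_dim):
    """Hypothetical helper: FFT axes if the last `channel_dim` axes are channels."""
    return tuple(range(ndim - channel_dim))


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 2))  # H x W x C field
axes = spatial_axes(x.ndim, channel_dim=1)               # (0, 1): H and W only
xk = jnp.fft.rfftn(x, axes=axes)                         # shape (8, 5, 2)

# Per-channel scaling with the rFFT shape: keep channel 0, double channel 1.
ones = jnp.ones(xk.shape[:-1])
scaling = jnp.stack([ones, 2.0 * ones], axis=-1)
y = jnp.fft.irfftn(scaling * xk, s=tuple(x.shape[a] for a in axes), axes=axes)

print(axes, xk.shape, y.shape)  # (0, 1) (8, 5, 2) (8, 8, 2)
```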

Note

The scaling array must have the same shape as the output of jnp.fft.rfftn to ensure proper broadcasting during the Fourier-space multiplication.
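For a real input, jnp.fft.rfftn keeps only n // 2 + 1 entries along the last transformed axis, so the scaling array is smaller than the field; an (8, 8) array, for instance, would not broadcast against the (8, 5) spectrum of an 8 × 8 field. The sketch below uses plain jax.numpy (not bijx) to print these shapes for an even and an odd lattice, and shows why the inverse transform needs the original shape via s=: without it, irfftn assumes an even last-axis length.

```python
import jax.numpy as jnp

for shape in [(8, 8), (6, 5)]:
    x = jnp.arange(shape[0] * shape[1], dtype=jnp.float32).reshape(shape)
    xk = jnp.fft.rfftn(x)
    guessed = jnp.fft.irfftn(xk)            # last axis inferred as 2 * (m - 1)
    restored = jnp.fft.irfftn(xk, s=shape)  # last axis taken from `shape`
    print(shape, xk.shape, guessed.shape, restored.shape)

# (8, 8) (8, 5) (8, 8) (8, 8)
# (6, 5) (6, 3) (6, 4) (6, 5)
```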

Example

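>>> import jax
>>> import jax.numpy as jnp
>>> # Create momentum-dependent scaling
>>> k = fft_momenta((8, 8))
>>> scaling = jnp.exp(-0.1 * jnp.sum(k**2, axis=-1))
>>> bijection = SpectrumScaling(scaling)
>>> # Example inputs: a random 8x8 field and a scalar log density
>>> phi = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
>>> log_density = jnp.zeros(())
>>> y, log_density_y = bijection.forward(phi, log_density)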

__init__(scaling, channel_dim=0)
Parameters:
  • scaling (Array | Variable)

  • channel_dim (int)

Methods

apply(x, log_density[, reverse])

Unified transformation method.

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

scale(r[, reverse])

Apply Fourier-space scaling transformation.

Attributes

property scaling

scale(r, reverse=False)

Apply Fourier-space scaling transformation.

Transforms the input to Fourier space with a real FFT, multiplies by the scaling factors (or divides by them if reverse is True), and transforms back. Also computes the log-Jacobian contribution of the scaling factors.

Parameters:
  • r – Input array to transform.

  • reverse – If True, apply inverse scaling (division).

Returns:

Tuple of (transformed_array, log_jacobian_contribution).
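A minimal sketch of what scale computes, written as a free function on a scalar field with no channel axes. The name `scale_sketch` and its explicit `scaling` argument are hypothetical. Negating the log-Jacobian under reverse=True follows from the inverse map dividing by \(s(\mathbf{k})\); this page does not say whether bijx applies that sign inside scale or later in apply.

```python
import jax
import jax.numpy as jnp


def scale_sketch(r, scaling, reverse=False):
    """Hypothetical stand-in for SpectrumScaling.scale on a scalar field.

    Returns (transformed_array, log_jacobian) for x -> irfftn(s * rfftn(x)),
    or for the inverse map (division by s) when reverse=True.
    """
    n = r.shape[-1]
    # Full-spectrum multiplicity of each rFFT entry along the last axis.
    mult = jnp.full(n // 2 + 1, 2.0).at[0].set(1.0)
    if n % 2 == 0:
        mult = mult.at[-1].set(1.0)
    log_jac = jnp.sum(mult * jnp.log(jnp.abs(scaling)))

    factor = 1.0 / scaling if reverse else scaling
    out = jnp.fft.irfftn(factor * jnp.fft.rfftn(r), s=r.shape)
    return out, (-log_jac if reverse else log_jac)


r = jax.random.normal(jax.random.PRNGKey(1), (8, 8))
kx = 2 * jnp.pi * jnp.fft.fftfreq(8)[:, None]
ky = 2 * jnp.pi * jnp.fft.rfftfreq(8)[None, :]
scaling = 1.0 / jnp.sqrt(kx**2 + ky**2 + 1.0)  # example symmetric, positive profile

y, lj = scale_sketch(r, scaling)
r_back, lj_back = scale_sketch(y, scaling, reverse=True)
print(jnp.max(jnp.abs(r_back - r)), lj + lj_back)  # both ~0
```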

apply(x, log_density, reverse=False, **kwargs)

Unified transformation method.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • reverse – If True, apply reverse transformation; if False, forward.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density).

Raises:

NotImplementedError – Must be implemented by subclasses.
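For a bijection, apply typically pairs the transformed data with the change-of-variables update of the log density, \(\log p_Y(\mathbf{y}) = \log p_X(\mathbf{x}) - \log|\det J|\). The sketch below assumes that convention (this page does not state bijx's sign convention) and uses the hypothetical name `apply_sketch`. It checks that a forward pass followed by a reverse pass restores both the field and the log density.

```python
import jax
import jax.numpy as jnp


def apply_sketch(x, log_density, scaling, reverse=False):
    """Hypothetical apply(): scale in Fourier space, then update the log density.

    Assumes log_density holds log p(x) and follows log p(y) = log p(x) - log|det J|.
    """
    n = x.shape[-1]
    mult = jnp.full(n // 2 + 1, 2.0).at[0].set(1.0)
    if n % 2 == 0:
        mult = mult.at[-1].set(1.0)
    log_jac = jnp.sum(mult * jnp.log(jnp.abs(scaling)))  # log|det J| of the forward map
    if reverse:
        y = jnp.fft.irfftn(jnp.fft.rfftn(x) / scaling, s=x.shape)
        return y, log_density + log_jac  # the inverse map has log|det| = -log_jac
    y = jnp.fft.irfftn(jnp.fft.rfftn(x) * scaling, s=x.shape)
    return y, log_density - log_jac


x = jax.random.normal(jax.random.PRNGKey(2), (8, 8))
scaling = jnp.full((8, 5), 0.5)  # constant scaling, so y = x / 2
log_p = jnp.zeros(())

y, log_p_y = apply_sketch(x, log_p, scaling)
x_back, log_p_back = apply_sketch(y, log_p_y, scaling, reverse=True)
print(log_p_y)  # 64 * log 2 ≈ 44.36: halving 64 values raises the density
```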