bijx.SpectrumScaling
- class bijx.SpectrumScaling
Bases: ApplyBijection
Diagonal scaling transformation in Fourier space.
Applies element-wise scaling to the Fourier transform of real-valued fields, which amounts to a diagonal linear transformation in momentum space. This is particularly useful for free field theories and for spectral preconditioning.
Type: \(\mathbb{R}^{H \times W \times C} \to \mathbb{R}^{H \times W \times C}\)
Transform: \(\mathcal{F}^{-1}[s(\mathbf{k}) \mathcal{F}[\mathbf{x}]]\)
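The scaling factors \(s(\mathbf{k})\) define a momentum-dependent linear map. Its log-Jacobian is the sum of \(\log|s(\mathbf{k})|\) over all Fourier modes. Because the real FFT stores only half of the spectrum, each rFFT entry is weighted by the number of modes it represents (its multiplicity): 1 for self-conjugate modes, such as the zero and Nyquist frequencies along the last axis, and 2 otherwise.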
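The multiplicity weighting can be checked numerically. The sketch below is not bijx code: `rfft_multiplicities` and `log_jacobian` are hypothetical helpers for a single-channel 2D field, and the scaling is assumed symmetric under \(\mathbf{k} \to -\mathbf{k}\). It compares the weighted half-spectrum sum against the log-determinant of the dense Jacobian of \(\mathbf{x} \mapsto \mathcal{F}^{-1}[s(\mathbf{k}) \mathcal{F}[\mathbf{x}]]\) on a small lattice.

```python
import jax
import jax.numpy as jnp
import numpy as np

def rfft_multiplicities(spatial_shape):
    """Hypothetical helper: number of full-FFT modes each rfftn entry represents."""
    n = spatial_shape[-1]
    m = np.full(n // 2 + 1, 2.0)   # a mode k together with its conjugate partner -k
    m[0] = 1.0                     # last-axis frequency 0 is self-conjugate
    if n % 2 == 0:
        m[-1] = 1.0                # so is the Nyquist frequency n/2 (even n only)
    return np.broadcast_to(m, tuple(spatial_shape[:-1]) + (n // 2 + 1,))

def log_jacobian(scaling, spatial_shape):
    # Sum of log|s(k)| over the full FFT grid, written on the rfft half-grid
    weights = jnp.asarray(rfft_multiplicities(spatial_shape))
    return jnp.sum(weights * jnp.log(jnp.abs(scaling)))

# Brute-force check on a small 4x6 field: build the dense Jacobian of the linear map.
H, W = 4, 6
kx = jnp.fft.fftfreq(H)[:, None]
ky = jnp.fft.rfftfreq(W)[None, :]
scaling = jnp.exp(-(kx**2 + ky**2))   # symmetric under k -> -k (assumed)

def transform(x):
    return jnp.fft.irfftn(scaling * jnp.fft.rfftn(x), s=(H, W))

J = jax.jacfwd(transform)(jnp.zeros((H, W))).reshape(H * W, H * W)
_, dense_logdet = jnp.linalg.slogdet(J)
formula_logdet = log_jacobian(scaling, (H, W))
```

Since the map is linear, its Jacobian does not depend on the point it is evaluated at, so the zero field is as good as any.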
- Parameters:
scaling (Array | Variable) – Scaling factors with the same shape as the rFFT output. If not an nnx.Variable/nnx.Param, it is treated as a constant by default.
channel_dim (int) – Number of channel dimensions.
Note
The scaling array must have the same shape as the output of jnp.fft.rfftn to ensure proper broadcasting during the Fourier-space multiplication.
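For an \(H \times W\) field, `jnp.fft.rfftn` keeps all \(H\) frequencies on the leading axis but only \(W/2 + 1\) (rounded down) on the last one. The sketch below builds a momentum-dependent scaling directly on that half-grid with plain `jax.numpy`. The normalization \(k = 2\pi n / N\) is an assumption and may differ from what bijx's `fft_momenta` returns.

```python
import jax.numpy as jnp

H, W = 8, 8
# Lattice momenta on the rfftn grid (assumed convention k = 2*pi*n/N):
kx = 2 * jnp.pi * jnp.fft.fftfreq(H)[:, None]   # all H frequencies, shape (8, 1)
ky = 2 * jnp.pi * jnp.fft.rfftfreq(W)[None, :]  # only W // 2 + 1 = 5 frequencies, shape (1, 5)
scaling = jnp.exp(-0.1 * (kx**2 + ky**2))        # broadcasts to shape (8, 5)

phi = jnp.ones((H, W))
print(scaling.shape, jnp.fft.rfftn(phi).shape)   # (8, 5) (8, 5)
```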
Example
>>> # Create momentum-dependent scaling
>>> k = fft_momenta((8, 8))
>>> scaling = jnp.exp(-0.1 * jnp.sum(k**2, axis=-1))
>>> bijection = SpectrumScaling(scaling)
>>> y, log_det = bijection.forward(phi, log_density)
Methods
apply(x, log_density[, reverse]) – Unified transformation method.
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
scale(r[, reverse]) – Apply Fourier-space scaling transformation.
Attributes
- property scaling
- scale(r, reverse=False)
Apply Fourier-space scaling transformation.
Transforms the input to Fourier space with a real FFT, multiplies it by the scaling factors (or divides by them if reverse is True), and transforms back. Also computes the log-Jacobian contribution of the scaling factors.
- Parameters:
r – Input array to transform.
reverse – If True, apply inverse scaling (division).
- Returns:
Tuple of (transformed_array, log_jacobian_contribution).
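The steps of `scale` can be sketched as a standalone function. This is an illustration, not the bijx source: it takes the scaling explicitly instead of reading an attribute, handles a single unbatched channel, and encodes the multiplicity of each rFFT column along the last axis with `jnp.where`. Reverse mode divides by the scaling, so a forward and a reverse call should round-trip the field and produce log-Jacobians of opposite sign.

```python
import jax.numpy as jnp

def scale(r, scaling, reverse=False):
    """Sketch: Fourier-space scaling of a single-channel real field r.

    `scaling` must have the shape of jnp.fft.rfftn(r). Returns (output, log_jacobian).
    """
    factor = 1.0 / scaling if reverse else scaling   # reverse divides by the scaling
    out = jnp.fft.irfftn(factor * jnp.fft.rfftn(r), s=r.shape)
    # Multiplicity of each rfftn column along the last axis: 1 for k=0 and Nyquist, else 2
    n = r.shape[-1]
    j = jnp.arange(n // 2 + 1)
    weights = jnp.where((j == 0) | (2 * j == n), 1.0, 2.0)
    log_jac = jnp.sum(weights * jnp.log(jnp.abs(factor)))
    return out, log_jac

H, W = 8, 8
kx = jnp.fft.fftfreq(H)[:, None]
ky = jnp.fft.rfftfreq(W)[None, :]
scaling = jnp.exp(-(kx**2 + ky**2))   # symmetric under k -> -k

r = jnp.sin(jnp.arange(H * W, dtype=jnp.float32)).reshape(H, W)
y, lj_fwd = scale(r, scaling)
r_back, lj_rev = scale(y, scaling, reverse=True)
```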
- apply(x, log_density, reverse=False, **kwargs)
Unified transformation method.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
reverse – If True, apply reverse transformation; if False, forward.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density).
- Raises:
NotImplementedError – Must be implemented by subclasses.
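A minimal sketch of how a unified `apply` could serve both directions for this transformation. It assumes the standard change-of-variables convention, in which the returned log density equals the input log density minus the log-Jacobian; that bijx uses exactly this sign convention is an assumption here. The log-determinant is taken as a plain sum over the full Fourier grid, which matches the multiplicity-weighted half-grid sum when the scaling is symmetric under \(\mathbf{k} \to -\mathbf{k}\). Names and shapes are illustrative only.

```python
import jax.numpy as jnp

H, W = 6, 6
kx = jnp.fft.fftfreq(H)[:, None]
ky = jnp.fft.fftfreq(W)[None, :]
s_full = jnp.exp(-0.5 * (kx**2 + ky**2))   # symmetric: s(k) == s(-k)
s_rfft = s_full[:, : W // 2 + 1]            # the entries rfftn actually uses

# log|det| of x -> irfftn(s * rfftn(x)) equals the sum of log|s| over the full grid
log_jac_full = jnp.sum(jnp.log(jnp.abs(s_full)))

def apply(x, log_density, reverse=False):
    """Sketch of a unified forward/reverse method (assumed sign convention)."""
    factor = 1.0 / s_rfft if reverse else s_rfft
    y = jnp.fft.irfftn(factor * jnp.fft.rfftn(x), s=x.shape)
    log_jac = -log_jac_full if reverse else log_jac_full
    # Density of transformed samples: log p_y(y) = log p_x(x) - log|det J|
    return y, log_density - log_jac

x = jnp.arange(36.0).reshape(H, W)
y, ld = apply(x, jnp.array(0.0))
x_back, ld_back = apply(y, ld, reverse=True)
```

A forward call followed by a reverse call restores both the field and the log density, which is the consistency a single `reverse` flag is meant to guarantee.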