bijx.nn.features

Nonlinear feature transformations for neural network layers.

This module provides nonlinear feature mappings for the vector fields of continuous normalizing flows.

For continuous normalizing flows, the divergence of the feature map is computed automatically using the vector-Jacobian product:

\[ \nabla \cdot \mathbf{f} = \text{tr}\left(\frac{\partial \mathbf{f}}{\partial \mathbf{x}}\right) \]

This enables efficient computation of log-density changes in normalizing flows, because the nonlinear features are applied “locally”, that is, independently to each input element.
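As a rough illustration of the formula above, the following standalone JAX sketch computes the divergence of a toy vector field as the trace of its Jacobian, once from the full Jacobian and once by stacking vector-Jacobian products against basis vectors. The names `toy_field`, `divergence_via_jacobian`, and `divergence_via_vjp` are hypothetical and not part of bijx; the library's internal implementation may differ.

```python
import jax
import jax.numpy as jnp


def toy_field(x):
    """Hypothetical vector field R^n -> R^n (not part of bijx)."""
    return jnp.sin(2.0 * x) + 0.5 * jnp.roll(x, 1)


def divergence_via_jacobian(f, x):
    # Trace of the full Jacobian df/dx.
    return jnp.trace(jax.jacfwd(f)(x))


def divergence_via_vjp(f, x):
    # One vector-Jacobian product per basis vector e_i gives row i of the
    # Jacobian; the trace then sums the diagonal entries df_i/dx_i.
    _, vjp_fn = jax.vjp(f, x)
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    rows = jax.vmap(lambda e: vjp_fn(e)[0])(basis)
    return jnp.trace(rows)


x = jnp.array([0.1, -0.4, 0.7])
div_jac = divergence_via_jacobian(toy_field, x)
div_vjp = divergence_via_vjp(toy_field, x)
```

For this toy field the shifted term contributes nothing to the diagonal, so both results equal the sum of `2 * cos(2 * x_i)`.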

Classes

ConcatFeatures

Concatenation of multiple feature maps.

DecayingFourierFeatures

Sinusoidal features damped by a learnable Gaussian envelope.

FourierFeatures

Sinusoidal Fourier feature transformation with learnable frequencies.

GaussianRBFFeatures

Gaussian radial basis features with learnable centers and widths.

NonlinearFeatures

Base class for nonlinear feature transformations with divergence computation.

PolynomialFeatures

Polynomial feature transformation with specified powers.

class bijx.nn.features.NonlinearFeatures[source]

Bases: Module

Base class for nonlinear feature transformations with divergence computation.

Provides the foundation for feature mappings that transform input data through nonlinear functions, optionally with learnable parameters. The base class automatically computes the divergence of the transformation using vector-Jacobian products.

Parameters:
  • out_channel_size (int) – Total number of output feature channels.

  • rngs (Rngs | None) – Random number generator state for parameter initialization.

Note

This is an abstract base class. Subclasses must implement apply_feature_map() to define the specific nonlinear transformation. The divergence computation is handled automatically by the base class.

apply_feature_map(inputs, **kwargs)[source]

Apply the nonlinear feature transformation.

This method must be implemented by subclasses to define the specific nonlinear mapping applied to input data.

Parameters:
  • inputs – Input data to transform.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Transformed feature representation.
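The subclassing pattern described here can be sketched without bijx as follows. Everything in this block is hypothetical: `FeatureSketch` and `TanhFeatures` are illustrative stand-ins, and the sketch uses a forward-mode Jacobian-vector product (rather than the vector-Jacobian product mentioned in this module) because, for a map acting on each input element independently, a single JVP with an all-ones tangent already yields the element-wise derivative.

```python
import jax
import jax.numpy as jnp


class FeatureSketch:
    """Hypothetical stand-in for an abstract feature base class (not bijx)."""

    def apply_feature_map(self, inputs, **kwargs):
        raise NotImplementedError

    def __call__(self, inputs, **kwargs):
        # Returns (features, d features / d input) for element-wise maps.
        fn = lambda x: self.apply_feature_map(x, **kwargs)
        return jax.jvp(fn, (inputs,), (jnp.ones_like(inputs),))


class TanhFeatures(FeatureSketch):
    """Hypothetical subclass: only the feature map itself is defined."""

    def apply_feature_map(self, inputs, **kwargs):
        scales = jnp.array([0.5, 1.0, 2.0])
        return jnp.tanh(inputs[..., None] * scales)


phi = jnp.array([[0.0], [1.0]])            # batch of 2, one input channel
features, dfeatures = TanhFeatures()(phi)  # both have shape (2, 1, 3)
```

The subclass never touches differentiation; the derivative comes from the shared `__call__`, which mirrors the division of labor the note above describes.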

class bijx.nn.features.FourierFeatures[source]

Bases: NonlinearFeatures

Sinusoidal Fourier feature transformation with learnable frequencies.

The frequencies \(\mathbf{\omega}_i\) are learned parameters initialized from a uniform distribution, allowing the network to adapt to the characteristic scales present in the data.

Parameters:
  • feature_count (int) – Number of sinusoidal features per input channel.

  • input_channels (int) – Number of input channels to transform.

  • freq_init (Callable) – Initializer for frequency parameters.

  • rngs (Rngs | None) – Random number generator state.

Note

The total output size is input_channels * feature_count.

Example

>>> features = FourierFeatures(16, input_channels=1, rngs=rngs)
>>> transformed, div_map = features(phi[..., None])

apply_feature_map(phi_lin, **kwargs)[source]

Apply sinusoidal feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments (unused).

Returns:

Sinusoidal features with shape (…, input_channels, feature_count).
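A minimal standalone sketch of a sinusoidal feature map with the documented output shape is shown below. It assumes the features are plain `sin(omega * phi)` terms and that frequencies are drawn uniformly from `[0, 5)`; both choices, and the name `fourier_feature_sketch`, are illustrative assumptions rather than the bijx implementation.

```python
import jax
import jax.numpy as jnp


def fourier_feature_sketch(phi, freqs):
    """sin(omega * phi) per input channel (illustrative, not bijx code).

    phi:   (..., input_channels)
    freqs: (input_channels, feature_count)
    """
    return jnp.sin(phi[..., None] * freqs)


key = jax.random.PRNGKey(0)
input_channels, feature_count = 1, 16
# Assumed initialization: uniform frequencies in [0, 5).
freqs = jax.random.uniform(key, (input_channels, feature_count), maxval=5.0)

phi = jnp.linspace(-1.0, 1.0, 8)           # eight scalar field values
features = fourier_feature_sketch(phi[..., None], freqs)
# Element-wise derivative: d sin(omega * phi) / d phi = omega * cos(omega * phi).
dfeatures = freqs * jnp.cos(phi[..., None, None] * freqs)
```

Here `features` and `dfeatures` both have shape `(8, 1, 16)`, matching the `(…, input_channels, feature_count)` layout.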

class bijx.nn.features.PolynomialFeatures[source]

Bases: NonlinearFeatures

Polynomial feature transformation with specified powers.

Transforms input data through polynomial basis functions of specified degrees.

The transformation applies each specified power element-wise to the input, creating a polynomial basis that can represent complex nonlinear relationships.

Parameters:
  • powers (list[int]) – List of polynomial powers to apply.

  • input_channels (int) – Number of input channels to transform.

  • rngs (Rngs | None) – Random number generator state.

Note

Powers should be non-negative integers. The power 0 gives constant features (all ones), power 1 gives identity, and higher powers provide increasingly nonlinear transformations.

Example

>>> # Polynomial features with linear and quadratic terms
>>> features = PolynomialFeatures([1, 2], input_channels=1, rngs=rngs)
>>> transformed, div_map = features(jnp.ones((1, 1)))

Important

Including powers greater than 1 can lead to numerical instability, because the resulting vector fields may not be globally Lipschitz continuous.
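A standard one-dimensional example (a textbook illustration, not taken from bijx) shows why this matters: a quadratic vector field can drive a trajectory to infinity in finite time,

\[ \frac{dx}{dt} = x^2, \quad x(0) = x_0 > 0 \quad \Longrightarrow \quad x(t) = \frac{x_0}{1 - x_0 t}, \]

which diverges as \(t \to 1/x_0\). A linear field \(dx/dt = x\), by contrast, only grows exponentially and stays finite for all \(t\).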

apply_feature_map(phi_lin, **kwargs)[source]

Apply polynomial feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments (unused).

Returns:

Polynomial features with shape (…, input_channels, len(powers)).
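The element-wise power map and its output shape can be sketched in plain JAX as follows. The function name `polynomial_feature_sketch` is hypothetical, and the block only illustrates the described behavior rather than reproducing bijx code.

```python
import jax.numpy as jnp


def polynomial_feature_sketch(phi, powers):
    """Raise each input element to each power (illustrative, not bijx code).

    phi: (..., input_channels) -> (..., input_channels, len(powers))
    """
    return phi[..., None] ** jnp.asarray(powers)


powers = [0, 1, 2]
phi = jnp.array([[2.0], [-3.0]])           # batch of 2, one input channel
features = polynomial_feature_sketch(phi, powers)
```

With these inputs the last axis holds the constant, identity, and squared terms for each element, so `features` has shape `(2, 1, 3)`.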

class bijx.nn.features.DecayingFourierFeatures[source]

Bases: NonlinearFeatures

Sinusoidal features damped by a learnable Gaussian envelope.

Each feature is

\[g_k(\phi) = \sin(\omega_k \phi) \, \exp\!\left(-\frac{\phi^2}{2 \ell^2}\right),\]
where \(\ell > 0\) is a learnable (per-input-channel) decay length.

The features match FourierFeatures near \(\phi=0\) (\(g_k'(0) = \omega_k\) exactly) but suppress the field at large \(|\phi|\), removing the periodic wrap-around that bare sine features inherit. Each feature is odd in \(\phi\), so a linear map without bias yields a \(\mathbb{Z}_2\)-equivariant vector field, just like FourierFeatures.

With log_decay_length_init large (default 3, so \(\ell \approx 20\)), the envelope is essentially flat across typical data ranges and the features start out indistinguishable from FourierFeatures. Training can then shrink \(\ell\) if a tighter envelope helps.

Parameters:
  • feature_count (int) – Number of sinusoidal features per input channel.

  • input_channels (int) – Number of input channels to transform.

  • freq_init (Callable) – Initializer for frequency parameters.

  • log_decay_length_init (float) – Initial value of \(\log \ell\) (per input channel).

  • rngs (Rngs | None) – Random number generator state.
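The properties stated for \(g_k\) can be checked numerically with a small standalone sketch. The function `decaying_fourier_sketch` and the frequency values are hypothetical; only the formula and the default \(\log \ell = 3\) come from the description above.

```python
import jax
import jax.numpy as jnp


def decaying_fourier_sketch(phi, omega, log_decay_length):
    """g_k(phi) = sin(omega_k * phi) * exp(-phi^2 / (2 * l^2)) for scalar phi."""
    ell = jnp.exp(log_decay_length)        # l > 0 by construction
    return jnp.sin(omega * phi) * jnp.exp(-phi**2 / (2.0 * ell**2))


omega = jnp.array([0.5, 1.0, 2.0])         # hypothetical frequencies
log_ell = 3.0                              # documented default, l = e^3 ≈ 20.1

# The envelope has zero slope at the origin, so g_k'(0) = omega_k.
slope_at_zero = jax.jacfwd(decaying_fourier_sketch)(jnp.array(0.0), omega, log_ell)

# Far from the origin the envelope suppresses every feature.
far = decaying_fourier_sketch(200.0, omega, log_ell)
```

Parameterizing the decay length through its logarithm keeps \(\ell\) positive under unconstrained gradient updates.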

apply_feature_map(phi_lin, **kwargs)[source]

Apply the damped sinusoidal feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Transformed feature representation.

class bijx.nn.features.GaussianRBFFeatures[source]

Bases: NonlinearFeatures

Gaussian radial basis features with learnable centers and widths.

Each feature is the antisymmetrized RBF

\[ g_k(\phi) = \exp\!\left(-\frac{(\phi - c_k)^2}{2 \sigma_k^2}\right) - \exp\!\left(-\frac{(\phi + c_k)^2}{2 \sigma_k^2}\right), \]

with \(c_k > 0\) (parameterized via softplus so it stays positive under training). The result is odd in \(\phi\) by construction, so a linear map without bias produces a \(\mathbb{Z}_2\)-equivariant vector field. This matches the inductive bias that \(\sin\) features have in ConvVF, while the features are also non-periodic, smooth, and decaying outside the data range.

Parameters:
  • feature_count (int) – Number of RBFs per input channel.

  • input_channels (int) – Number of input channels.

  • center_max (float) – Centers \(c_k\) are initialized evenly on \((0, c_{\max}]\). Default 3.0 covers typical \(\phi^4\) field values.

  • log_sigma_init (float) – Initial value of \(\log \sigma\). Default log(0.8).

  • rngs (Rngs | None) – Random number generator state.
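The antisymmetrized RBF and its odd symmetry can be sketched as follows. This is a standalone illustration under stated assumptions: the function name `antisym_rbf_sketch`, the weights, and the reading of "evenly on \((0, c_{\max}]\)" as the grid \(c_{\max}/K, 2c_{\max}/K, \dots, c_{\max}\) are all hypothetical and may not match bijx.

```python
import jax
import jax.numpy as jnp


def antisym_rbf_sketch(phi, raw_centers, log_sigma):
    """Antisymmetrized Gaussian RBFs for scalar or batched phi (not bijx code)."""
    centers = jax.nn.softplus(raw_centers)  # c_k > 0 for any raw value
    sigma = jnp.exp(log_sigma)
    phi = jnp.asarray(phi)[..., None]
    plus = jnp.exp(-((phi - centers) ** 2) / (2.0 * sigma**2))
    minus = jnp.exp(-((phi + centers) ** 2) / (2.0 * sigma**2))
    return plus - minus


feature_count, center_max = 4, 3.0
# Assumed grid: c_max/K, 2*c_max/K, ..., c_max.
target_centers = center_max * jnp.arange(1, feature_count + 1) / feature_count
# Inverse softplus, so that softplus(raw_centers) reproduces the targets.
raw_centers = jnp.log(jnp.expm1(target_centers))
log_sigma = jnp.log(0.8)                    # documented default

phi = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
features = antisym_rbf_sketch(phi, raw_centers, log_sigma)  # shape (5, 4)

# Odd features plus a bias-free linear map give an odd output.
weights = jnp.array([0.3, -1.2, 0.7, 0.1])  # hypothetical weights
field = features @ weights
```

Because `phi` is symmetric about zero, reversing `field` and flipping its sign reproduces it, which is the \(\mathbb{Z}_2\) equivariance described above.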

property centers

apply_feature_map(phi_lin, **kwargs)[source]

Apply the antisymmetrized Gaussian RBF feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Transformed feature representation.

class bijx.nn.features.ConcatFeatures[source]

Bases: NonlinearFeatures

Concatenation of multiple feature maps.

Combines multiple nonlinear feature maps by applying each transformation to the input and concatenating the results.

Parameters:
  • features (List[NonlinearFeatures]) – List of NonlinearFeatures instances to concatenate.

  • rngs (Rngs | None) – Random number generator state.

Note

The total output size is the sum of all component feature sizes. This approach allows combining complementary feature types (e.g., Fourier and polynomial features) for higher expressiveness.

Example

>>> fourier = FourierFeatures(49, input_channels=1, rngs=rngs)
>>> poly = PolynomialFeatures([1, 2], input_channels=1, rngs=rngs)
>>> combined = ConcatFeatures([fourier, poly], rngs=rngs)
>>> combined.out_channel_size == 49 + 2
True

apply_feature_map(phi_lin, **kwargs)[source]

Apply all component feature transformations and concatenate results.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments passed to all component transformations.

Returns:

Concatenated features from all component transformations.
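The concatenation itself can be sketched with plain functions standing in for the component feature maps. The helpers `sin_features`, `poly_features`, and `concat_features_sketch` are hypothetical, and joining along the last (feature) axis is an assumption about the layout rather than a statement about bijx internals.

```python
import jax.numpy as jnp


def sin_features(phi):
    # Hypothetical stand-in for a Fourier-type map with 49 frequencies.
    freqs = jnp.linspace(0.1, 4.9, 49)
    return jnp.sin(phi[..., None] * freqs)


def poly_features(phi):
    # Hypothetical stand-in for polynomial features with powers 1 and 2.
    return phi[..., None] ** jnp.array([1, 2])


def concat_features_sketch(feature_maps, phi):
    """Apply every map to the same input and join along the feature axis."""
    return jnp.concatenate([fm(phi) for fm in feature_maps], axis=-1)


phi = jnp.zeros((8, 1))                    # batch of 8, one input channel
combined = concat_features_sketch([sin_features, poly_features], phi)
```

The resulting feature axis has length 49 + 2 = 51, consistent with the `out_channel_size` check in the example above.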