bijx.nn.features

Nonlinear feature transformations for neural network layers.

This module provides nonlinear feature mappings for the vector fields of continuous normalizing flows.

For continuous normalizing flows, the divergence of the feature map, i.e. the trace of its Jacobian, is computed automatically using vector-Jacobian products:

\[ \nabla \cdot \mathbf{f} = \text{tr}\left(\frac{\partial \mathbf{f}}{\partial \mathbf{x}}\right) \]

Because the nonlinear features are applied “locally”, this divergence is cheap to evaluate, which makes computing log-density changes in normalizing flows efficient.
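As an illustration of the identity above, the following is a minimal sketch in plain JAX, not the bijx implementation. The function `vector_field` is a hypothetical stand-in for a learned vector field; the divergence is assembled from one vector-Jacobian product per basis vector and compared against the trace of the full Jacobian.

```python
import jax
import jax.numpy as jnp


def vector_field(x):
    # Hypothetical stand-in for a learned vector field f: R^n -> R^n.
    return jnp.sin(x) + 0.5 * jnp.roll(x, 1) ** 2


def divergence_vjp(f, x):
    # One VJP per basis vector e_i yields row i of the Jacobian;
    # the trace of the stacked rows is the divergence.
    _, pullback = jax.vjp(f, x)
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    rows = jax.vmap(lambda e: pullback(e)[0])(basis)
    return jnp.trace(rows)


x = jnp.array([0.3, -1.2, 0.7])
div_vjp = divergence_vjp(vector_field, x)

# Reference value: trace of the explicitly materialized Jacobian.
div_ref = jnp.trace(jax.jacfwd(vector_field)(x))
```

When a map acts element-wise, its Jacobian is diagonal and a single product with a vector of ones recovers all diagonal entries at once, which is presumably the sense in which locally applied features keep this computation cheap.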

Classes

ConcatFeatures

Concatenation of multiple feature maps.

FourierFeatures

Sinusoidal Fourier feature transformation with learnable frequencies.

NonlinearFeatures

Base class for nonlinear feature transformations with divergence computation.

PolynomialFeatures

Polynomial feature transformation with specified powers.

class bijx.nn.features.NonlinearFeatures

Bases: Module

Base class for nonlinear feature transformations with divergence computation.

Provides the foundation for feature mappings that transform input data through learned nonlinear functions. Automatically computes the divergence of the transformation using vector-Jacobian products.

Parameters:
  • out_channel_size (int) – Total number of output feature channels.

  • rngs (Rngs | None) – Random number generator state for parameter initialization.

Note

This is an abstract base class. Subclasses must implement apply_feature_map() to define the specific nonlinear transformation. The divergence computation is handled automatically by the base class.

apply_feature_map(inputs, **kwargs)

Apply the nonlinear feature transformation.

This method must be implemented by subclasses to define the specific nonlinear mapping applied to input data.

Parameters:
  • inputs – Input data to transform.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Transformed feature representation.

class bijx.nn.features.FourierFeatures

Bases: NonlinearFeatures

Sinusoidal Fourier feature transformation with learnable frequencies.

The frequencies \(\mathbf{\omega}_i\) are learned parameters initialized from a uniform distribution, allowing the network to adapt to the characteristic scales present in the data.

Parameters:
  • feature_count (int) – Number of sinusoidal features per input channel.

  • input_channels (int) – Number of input channels to transform.

  • freq_init (Callable) – Initializer for frequency parameters.

  • rngs (Rngs | None) – Random number generator state.

Note

The total output size is input_channels × feature_count.

Example

>>> features = FourierFeatures(16, input_channels=1, rngs=rngs)
>>> transformed, div = features(phi[..., None], local_coupling)

apply_feature_map(phi_lin, **kwargs)

Apply sinusoidal feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments (unused).

Returns:

Sinusoidal features with shape (…, input_channels, feature_count).
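The shape convention can be sketched in plain JAX. This is an illustrative stand-in, not the bijx implementation: it assumes the features are `sin(freq * x)` with one frequency per (channel, feature) pair and that frequencies are drawn uniformly from [0, 1); the name `fourier_features` is hypothetical.

```python
import jax
import jax.numpy as jnp

feature_count, input_channels = 4, 2

# Assumed initialization: frequencies drawn uniformly from [0, 1).
key = jax.random.PRNGKey(0)
freqs = jax.random.uniform(key, (input_channels, feature_count))


def fourier_features(phi_lin, freqs):
    # phi_lin: (..., input_channels) -> (..., input_channels, feature_count)
    return jnp.sin(phi_lin[..., None] * freqs)


phi_lin = jnp.linspace(-1.0, 1.0, 10).reshape(5, input_channels)
feats = fourier_features(phi_lin, freqs)

# Flattening the last two axes gives input_channels * feature_count values.
flat = feats.reshape(feats.shape[0], -1)
```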

class bijx.nn.features.PolynomialFeatures

Bases: NonlinearFeatures

Polynomial feature transformation with specified powers.

Applies each specified power element-wise to the input, creating a polynomial basis that can represent complex nonlinear relationships.

Parameters:
  • powers (list[int]) – List of polynomial powers to apply.

  • input_channels (int) – Number of input channels to transform.

  • rngs (Rngs | None) – Random number generator state.

Note

Powers should be non-negative integers. Power 0 gives constant features (all ones), power 1 gives the identity, and higher powers provide increasingly nonlinear transformations.

Example

>>> # Polynomial features with linear and quadratic terms
>>> features = PolynomialFeatures([1, 2], input_channels=1, rngs=rngs)
>>> transformed, div = features(jnp.ones((1, 1)), jnp.ones((1, 2)))

Important

Including powers other than 0 and 1 can lead to numerical instability, as the resulting vector fields may not be globally Lipschitz continuous.
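The instability can be seen in a one-dimensional example, given here as an illustration rather than as a statement about any particular bijx model. A purely quadratic vector field has the closed-form flow

\[ \dot{x} = x^2, \quad x(0) = x_0 > 0 \quad \Longrightarrow \quad x(t) = \frac{x_0}{1 - x_0 t}, \]

which diverges as \(t \to 1/x_0\). The blow-up time shrinks as the initial value grows, so integrating over a fixed time interval can overflow for sufficiently large inputs.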

apply_feature_map(phi_lin, **kwargs)

Apply polynomial feature transformation.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments (unused).

Returns:

Polynomial features with shape (…, input_channels, len(powers)).
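A minimal sketch of this mapping in plain JAX, assuming the powers are applied by broadcasting over a new trailing axis. The helper `polynomial_features` is hypothetical and not part of bijx.

```python
import jax.numpy as jnp

powers = jnp.array([0, 1, 2])


def polynomial_features(phi_lin, powers):
    # phi_lin: (..., input_channels) -> (..., input_channels, len(powers))
    return phi_lin[..., None] ** powers


phi_lin = jnp.array([[2.0], [3.0]])  # batch of 2, one input channel
feats = polynomial_features(phi_lin, powers)
```

Each input value is expanded into its constant, linear, and quadratic term along the last axis.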

class bijx.nn.features.ConcatFeatures

Bases: NonlinearFeatures

Concatenation of multiple feature maps.

Combines multiple nonlinear feature maps by applying each transformation to the input and concatenating the results.

Parameters:
  • features (List[NonlinearFeatures]) – List of NonlinearFeatures instances whose outputs are concatenated.

  • rngs (Rngs | None) – Random number generator state.

Note

The total output size is the sum of all component feature sizes. This allows combining complementary feature types (e.g., Fourier and polynomial features) for greater expressiveness.

Example

>>> fourier = FourierFeatures(49, input_channels=1, rngs=rngs)
>>> poly = PolynomialFeatures([1, 2], input_channels=1, rngs=rngs)
>>> combined = ConcatFeatures([fourier, poly], rngs=rngs)
>>> combined.out_channel_size == 49 + 2
True

apply_feature_map(phi_lin, **kwargs)

Apply all component feature transformations and concatenate results.

Parameters:
  • phi_lin – Input data to transform.

  • **kwargs – Additional arguments passed to all component transformations.

Returns:

Concatenated features from all component transformations.
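The concatenation can be sketched in plain JAX. This assumes the component outputs share all leading axes and are joined along the last (feature) axis; `concat_features`, `sin_map`, and `poly_map` are hypothetical stand-ins for the bijx classes.

```python
import jax.numpy as jnp


def sin_map(x):
    # Stand-in for Fourier features with three fixed frequencies.
    freqs = jnp.arange(1.0, 4.0)
    return jnp.sin(x[..., None] * freqs)


def poly_map(x):
    # Stand-in for polynomial features with powers 1 and 2.
    return x[..., None] ** jnp.array([1, 2])


def concat_features(phi_lin, feature_maps):
    # Each map returns (..., input_channels, n_i); join along the last axis.
    return jnp.concatenate([fm(phi_lin) for fm in feature_maps], axis=-1)


phi_lin = jnp.ones((5, 1))  # batch of 5, one input channel
combined = concat_features(phi_lin, [sin_map, poly_map])
```

With one input channel, the last axis has size 3 + 2, mirroring how the component sizes add up in `out_channel_size`.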