bijx.nn.embeddings

Embedding layers for time and positional encoding in neural networks.

This module provides various embedding functions that map scalar inputs to high-dimensional feature vectors. These are particularly useful for continuous normalizing flows where time-dependent parameters need rich feature representations, and for positional encodings in attention mechanisms.

Functions

rescale_range(val, val_range)

Rescale input values to unit interval [0, 1].

Classes

Embedding

Base class for scalar-to-vector embedding functions.

KernelFourier

Fourier series embedding for smooth representations.

KernelGauss

Gaussian radial basis function embedding with learnable widths.

KernelLin

Piecewise linear interpolation embedding with sparse outputs.

KernelReduced

Dimensionality reduction wrapper for high-dimensional embeddings.

PositionalEmbedding

Sinusoidal positional embeddings from transformer architectures.

bijx.nn.embeddings.rescale_range(val, val_range)[source]

Rescale input values to unit interval [0, 1].

Parameters:
  • val – Input values to rescale.

  • val_range (tuple[float, float] | None) – Tuple of (min, max) values defining the input range. If None, returns input unchanged.

Returns:

Rescaled values mapped to [0, 1] interval.

Note

Values outside the specified range will be mapped outside [0, 1]. Consider clamping if strict bounds are required.
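
The behavior described above amounts to a one-line affine transform. A minimal numpy sketch of the documented semantics (the function name mirrors the API, but this is an illustrative reimplementation, not the bijx source):

```python
import numpy as np

def rescale_range(val, val_range=None):
    # Illustrative sketch of the documented behavior (not the bijx source):
    # map [min, max] linearly onto [0, 1]; pass input through unchanged
    # when no range is given.
    if val_range is None:
        return val
    lo, hi = val_range
    return (np.asarray(val) - lo) / (hi - lo)

print(rescale_range(5.0, (0.0, 10.0)))   # 0.5
print(rescale_range(12.0, (0.0, 10.0)))  # 1.2, i.e. outside [0, 1]
```

Note that, as the warning above says, out-of-range inputs land outside [0, 1] rather than being clamped.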

class bijx.nn.embeddings.Embedding[source]

Bases: Module

Base class for scalar-to-vector embedding functions.

Provides the foundation for all embedding layers that map scalar inputs to fixed-size feature vectors. Subclasses implement specific embedding strategies like Gaussian kernels, Fourier features, or positional encodings.

Parameters:
  • feature_count (int) – Dimensionality of the output feature vector.

  • rngs (Rngs | None) – Random number generator state for parameter initialization.

Note

This is an abstract base class. Use concrete subclasses like KernelGauss, KernelFourier, or PositionalEmbedding. Its main role is to ensure the feature_count attribute is set.

class bijx.nn.embeddings.KernelGauss[source]

Bases: Embedding

Gaussian radial basis function embedding with learnable widths.

Maps scalar inputs to feature vectors using Gaussian basis functions centered at evenly spaced positions.

The centers \(\mu_i\) are evenly spaced across the input range, and the width parameter \(\sigma\) can be learned during training for optimal feature representation. The width parameter is passed through softplus to ensure positivity.

Parameters:
  • feature_count (int) – Number of Gaussian basis functions.

  • val_range (tuple[float, float] | None) – Input value range for rescaling to [0, 1].

  • width_factor (float) – Initial width parameter (smaller = wider Gaussians).

  • adaptive_width (bool) – Whether to make width parameters trainable.

  • norm (bool) – Whether to normalize outputs to sum to 1 (probability-like).

  • one_width (bool) – If adaptive_width is True, whether to use a single shared width rather than per-Gaussian widths.

  • rngs (Rngs | None) – Random number generator state.

Example

>>> # Smooth time embedding for CNFs
>>> embed = KernelGauss(feature_count=21, adaptive_width=True, rngs=rngs)
>>> features = embed(0.1)  # Shape: (21,)

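The Gaussian-basis construction described above can be sketched in plain numpy. This is an illustrative reimplementation with a fixed width sigma (the actual layer makes the width trainable and passes it through softplus); gauss_embedding is a hypothetical name, and norm=True behavior is assumed:

```python
import numpy as np

def gauss_embedding(t, feature_count=21, sigma=0.05):
    # Illustrative sketch (not the bijx implementation): Gaussian bumps
    # centered at evenly spaced positions on [0, 1], normalized so the
    # features sum to 1 (probability-like outputs).
    centers = np.linspace(0.0, 1.0, feature_count)
    feats = np.exp(-0.5 * ((t - centers) / sigma) ** 2)
    return feats / feats.sum()

features = gauss_embedding(0.1)
print(features.shape)  # (21,)
```

The peak of the feature vector tracks the input: for t = 0.1 the largest response comes from the basis function centered at 0.1.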
class bijx.nn.embeddings.KernelLin[source]

Bases: Embedding

Piecewise linear interpolation embedding with sparse outputs.

Maps scalar inputs to sparse feature vectors using linear interpolation between adjacent basis positions. At most two adjacent features are non-zero for any input.

Parameters:
  • feature_count (int) – Number of interpolation basis positions.

  • val_range (tuple[float, float] | None) – Input value range for rescaling to [0, 1].

  • rngs (Rngs | None) – Random number generator state.

Example

>>> # Efficient sparse embedding
>>> embed = KernelLin(feature_count=11, rngs=rngs)
>>> features = embed(t)  # Only 2 non-zero values
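
The sparsity property above follows from using "hat" (triangular) basis functions at evenly spaced knots. A numpy sketch under that assumption (lin_embedding is a hypothetical name, not the bijx source):

```python
import numpy as np

def lin_embedding(t, feature_count=11):
    # Illustrative sketch (not the bijx implementation): triangular "hat"
    # functions at evenly spaced knots on [0, 1]. For any input, at most
    # the two knots bracketing it respond, and their weights sum to 1.
    knots = np.linspace(0.0, 1.0, feature_count)
    spacing = knots[1] - knots[0]
    return np.clip(1.0 - np.abs(t - knots) / spacing, 0.0, 1.0)

features = lin_embedding(0.23)
print(np.count_nonzero(features))  # 2
```

Because adjacent hats overlap linearly, the two non-zero weights form a convex combination, which is exactly linear interpolation between neighboring basis positions.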
class bijx.nn.embeddings.KernelFourier[source]

Bases: Embedding

Fourier series embedding for smooth representations.

Maps scalar inputs to feature vectors using truncated Fourier series with sine and cosine components. The embedding captures multiple frequency scales, allowing the network to represent both fine-grained and coarse temporal patterns effectively.

Parameters:
  • feature_count (int) – Number of Fourier terms in the expansion. For even counts, uses (feature_count-1)//2 frequency pairs plus constant.

  • val_range (tuple[float, float] | None) – Input value range for normalization.

  • rngs (Rngs | None) – Random number generator state.

Note

The constant term (1.0) is always included. Frequencies increase linearly: 1, 2, 3, … in units of 2π/period.

Example

>>> embed = KernelFourier(feature_count=21, rngs=rngs)
>>> time_features = embed(t)  # Captures multiple time scales

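The truncated expansion described in the note can be sketched as follows, assuming a unit period and the stated layout of one constant term plus (feature_count - 1) // 2 sine/cosine pairs (fourier_embedding is a hypothetical name, not the bijx source):

```python
import numpy as np

def fourier_embedding(t, feature_count=21):
    # Illustrative sketch (not the bijx implementation): constant term plus
    # sine/cosine pairs with linearly increasing frequencies 1, 2, 3, ...
    # in units of 2*pi over a unit period.
    n_pairs = (feature_count - 1) // 2
    freqs = 2 * np.pi * np.arange(1, n_pairs + 1)
    return np.concatenate([[1.0], np.sin(freqs * t), np.cos(freqs * t)])

features = fourier_embedding(0.3)
print(features.shape)  # (21,)
```

With feature_count=21 this gives 10 frequency pairs plus the constant, matching the stated counting.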
class bijx.nn.embeddings.KernelReduced[source]

Bases: Embedding

Dimensionality reduction wrapper for high-dimensional embeddings.

Applies a learned linear dimensionality reduction to another embedding layer. Note that this is simply a linear map, not strictly a “reduction”: the output dimensionality may even be chosen larger than that of the base embedding.

Parameters:
  • kernel (Module) – Base embedding layer to reduce.

  • feature_count (int) – Target dimensionality (typically smaller than kernel.feature_count).

  • init (Callable) – Initialization function for projection matrix (default: orthogonal).

  • rngs (Rngs | None) – Random number generator state.

Note

The projection matrix is normalized by the base kernel’s feature count to maintain reasonable magnitudes. Orthogonal initialization helps preserve independent features.

Example

>>> # Reduce 49-dimensional Fourier embedding to 20 dimensions
>>> base_embed = KernelFourier(49, rngs=rngs)
>>> reduced = KernelReduced(base_embed, 20, rngs=rngs)
>>> features = reduced(t)

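The projection and normalization described in the note can be sketched in numpy. Here the Fourier base embedding is re-derived from its documented math, and the fixed orthogonal matrix stands in for the trainable, orthogonally initialized projection (reduced_embedding is a hypothetical name, not the bijx source):

```python
import numpy as np

def fourier_embedding(t, feature_count):
    # Base embedding, sketched from the documented expansion.
    n_pairs = (feature_count - 1) // 2
    freqs = 2 * np.pi * np.arange(1, n_pairs + 1)
    return np.concatenate([[1.0], np.sin(freqs * t), np.cos(freqs * t)])

def reduced_embedding(t, proj, base_count=49):
    # Learned linear map applied to the base features, normalized by the
    # base feature count to keep output magnitudes reasonable.
    return proj @ fourier_embedding(t, base_count) / base_count

# Stand-in for an orthogonally initialized trainable projection matrix.
rng = np.random.default_rng(0)
proj = np.linalg.qr(rng.normal(size=(49, 20)))[0].T  # shape (20, 49)

features = reduced_embedding(0.5, proj)
print(features.shape)  # (20,)
```

Orthogonal rows keep the reduced features close to uncorrelated at initialization, which is the rationale given in the note above.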
class bijx.nn.embeddings.PositionalEmbedding[source]

Bases: Embedding

Sinusoidal positional embeddings from transformer architectures.

Uses multiple frequencies to create position-dependent representations.

Parameters:
  • feature_count (int) – Output feature dimensionality (must be even).

  • max_positions (float) – Maximum expected input value for frequency scaling. Controls the base wavelength; larger values create lower frequencies.

  • append_input (bool) – Whether to concatenate the raw input to the embedding.

  • rngs (Rngs | None) – Random number generator state.

Example

>>> embed = PositionalEmbedding(feature_count=64, max_positions=1000, rngs=rngs)
>>> pos_features = embed(position_indices)
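
Since the class follows the sinusoidal scheme from transformer architectures, the computation can be sketched as below: half the features are sines, half cosines, with wavelengths spaced geometrically up to roughly max_positions (positional_embedding is a hypothetical name and the exact frequency spacing is an assumption, not the bijx source):

```python
import numpy as np

def positional_embedding(pos, feature_count=64, max_positions=1000.0):
    # Illustrative sketch of transformer-style sinusoidal embeddings
    # (not the bijx implementation). feature_count must be even: the
    # frequencies decay geometrically from 1 down to 1 / max_positions.
    half = feature_count // 2
    freqs = np.exp(-np.log(max_positions) * np.arange(half) / (half - 1))
    angles = pos * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

features = positional_embedding(7.0)
print(features.shape)  # (64,)
```

Larger max_positions stretches the longest wavelength, which is why the parameter description above says it lowers the frequencies.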