bijx.nn.embeddings
Embedding layers for time and positional encoding in neural networks.
This module provides various embedding functions that map scalar inputs to high-dimensional feature vectors. These are particularly useful for continuous normalizing flows where time-dependent parameters need rich feature representations, and for positional encodings in attention mechanisms.
Functions
- rescale_range – Rescale input values to unit interval [0, 1].
Classes
- Embedding – Base class for scalar-to-vector embedding functions.
- KernelFourier – Fourier series embedding for smooth representations.
- KernelGauss – Gaussian radial basis function embedding with learnable widths.
- KernelLin – Piecewise linear interpolation embedding with sparse outputs.
- KernelReduced – Dimensionality reduction wrapper for high-dimensional embeddings.
- PositionalEmbedding – Sinusoidal positional embeddings from transformer architectures.
- bijx.nn.embeddings.rescale_range(val, val_range)[source]
Rescale input values to unit interval [0, 1].
- Parameters:
  - val – Input values to rescale.
  - val_range (tuple[float, float] | None) – Tuple of (min, max) values defining the input range. If None, returns input unchanged.
- Returns:
  Rescaled values mapped to the [0, 1] interval.
Note
Values outside the specified range will be mapped outside [0, 1]. Consider clamping if strict bounds are required.
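For reference, the presumed computation is the standard affine rescaling; a minimal sketch, not necessarily bijx's exact implementation:

import jax.numpy as jnp

def rescale_range_sketch(val, val_range):
    # Assumed behavior: map [min, max] affinely onto [0, 1];
    # None means the input passes through unchanged.
    if val_range is None:
        return val
    lo, hi = val_range
    return (jnp.asarray(val) - lo) / (hi - lo)

rescale_range_sketch(5.0, (0.0, 10.0))   # 0.5
rescale_range_sketch(15.0, (0.0, 10.0))  # 1.5, outside [0, 1] as the note warns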
- class bijx.nn.embeddings.Embedding[source]
Bases: Module
Base class for scalar-to-vector embedding functions.
Provides the foundation for all embedding layers that map scalar inputs to fixed-size feature vectors. Subclasses implement specific embedding strategies like Gaussian kernels, Fourier features, or positional encodings.
- Parameters:
  - feature_count (int) – Dimensionality of the output feature vector.
  - rngs (Rngs | None) – Random number generator state for parameter initialization.
Note
This is an abstract base class. Use concrete subclasses like KernelGauss, KernelFourier, or PositionalEmbedding. Its main function is to ensure the feature_count parameter is set.
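As an illustration of the subclassing pattern, here is a hypothetical polynomial embedding (not part of bijx). It assumes the base class stores feature_count and that subclasses implement __call__ mapping a scalar to a (feature_count,) vector:

import jax.numpy as jnp
from bijx.nn.embeddings import Embedding

class KernelPoly(Embedding):
    # Hypothetical subclass: monomial features 1, t, t**2, ...,
    # t**(feature_count - 1).
    def __call__(self, t):
        powers = jnp.arange(self.feature_count)
        return jnp.asarray(t) ** powers

# Usage, assuming the base constructor accepts these arguments:
# embed = KernelPoly(feature_count=8, rngs=rngs); embed(0.5)  # shape (8,)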
- class bijx.nn.embeddings.KernelGauss[source]
Bases: Embedding
Gaussian radial basis function embedding with learnable widths.
Maps scalar inputs to feature vectors using Gaussian basis functions centered at evenly spaced positions.
The centers \(\mu_i\) are evenly spaced across the input range, and the width parameter \(\sigma\) can be learned during training for optimal feature representation. The width parameter is passed through softplus to ensure positivity.
- Parameters:
  - feature_count (int) – Number of Gaussian basis functions.
  - val_range (tuple[float, float] | None) – Input value range for rescaling to [0, 1].
  - width_factor (float) – Initial width parameter (smaller = wider Gaussians).
  - adaptive_width (bool) – Whether to make width parameters trainable.
  - norm (bool) – Whether to normalize outputs to sum to 1 (probability-like).
  - one_width (bool) – If adaptive, whether to use a single shared width vs per-Gaussian widths.
  - rngs (Rngs | None) – Random number generator state.
Example
>>> # Smooth time embedding for CNFs
>>> embed = KernelGauss(feature_count=21, adaptive_width=True, rngs=rngs)
>>> features = embed(0.1)  # Shape: (21,)
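The per-feature computation is presumably of this form; a sketch only, since the exact width scaling and parameter handling in bijx may differ:

import jax
import jax.numpy as jnp

def gauss_features(t, feature_count=21, raw_width=0.0, norm=True):
    # Hypothetical re-implementation: Gaussian bumps at evenly
    # spaced centers on [0, 1].
    mu = jnp.linspace(0.0, 1.0, feature_count)
    sigma = jax.nn.softplus(raw_width)  # softplus keeps the width positive; exact scaling assumed
    feats = jnp.exp(-0.5 * ((t - mu) / sigma) ** 2)
    return feats / feats.sum() if norm else feats

gauss_features(0.1).shape  # (21,)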
- class bijx.nn.embeddings.KernelLin[source]
Bases: Embedding
Piecewise linear interpolation embedding with sparse outputs.
Maps scalar inputs to sparse feature vectors using linear interpolation between adjacent basis positions. At most two adjacent features are non-zero for any input.
- Parameters:
  - feature_count (int) – Number of interpolation basis positions.
  - val_range (tuple[float, float] | None) – Input value range for rescaling to [0, 1].
  - rngs (Rngs | None) – Random number generator state.
Example
>>> # Efficient sparse embedding
>>> embed = KernelLin(feature_count=11, rngs=rngs)
>>> features = embed(t)  # At most 2 non-zero values
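The sparsity pattern follows from hat (triangle) basis functions on a uniform grid; a minimal sketch under that assumption:

import jax.numpy as jnp

def lin_features(t, feature_count=11):
    # Hypothetical re-implementation: triangle kernels on a uniform
    # grid over [0, 1]; at most two adjacent entries are non-zero.
    grid = jnp.linspace(0.0, 1.0, feature_count)
    spacing = 1.0 / (feature_count - 1)
    return jnp.clip(1.0 - jnp.abs(t - grid) / spacing, 0.0, 1.0)

lin_features(0.55)  # two adjacent entries of 0.5, all others zero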
- class bijx.nn.embeddings.KernelFourier[source]
Bases: Embedding
Fourier series embedding for smooth representations.
Maps scalar inputs to feature vectors using truncated Fourier series with sine and cosine components. The embedding captures multiple frequency scales, allowing the network to represent both fine-grained and coarse temporal patterns effectively.
- Parameters:
  - feature_count (int) – Number of Fourier terms in the expansion. For even counts, uses (feature_count-1)//2 frequency pairs plus constant.
  - val_range (tuple[float, float] | None) – Input value range for normalization.
  - rngs (Rngs | None) – Random number generator state.
Note
The constant term (1.0) is always included. Frequencies increase linearly: 1, 2, 3, … in units of 2π/period.
Example
>>> embed = KernelFourier(feature_count=21, rngs=rngs)
>>> time_features = embed(t)  # Captures multiple time scales
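Combining the note above (constant term plus linearly increasing frequencies), the expansion is presumably as follows; a sketch for an odd feature_count, not the library's exact ordering:

import jax.numpy as jnp

def fourier_features(t, feature_count=21):
    # Hypothetical re-implementation: constant term plus K sin/cos
    # pairs at frequencies 1, 2, ..., K in units of 2*pi / period.
    k = jnp.arange(1, (feature_count - 1) // 2 + 1)
    angles = 2 * jnp.pi * k * t
    return jnp.concatenate([jnp.ones(1), jnp.sin(angles), jnp.cos(angles)])

fourier_features(0.3).shape  # (21,): 1 constant + 10 sine + 10 cosine terms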
- class bijx.nn.embeddings.KernelReduced[source]
Bases: Embedding
Dimensionality reduction wrapper for high-dimensional embeddings.
Applies learned linear dimensionality reduction to another embedding layer. Note that this is simply a linear map, not strictly a "reduction"; the output dimensionality can even be chosen larger than that of the base embedding.
- Parameters:
  - kernel (Module) – Base embedding layer to reduce.
  - feature_count (int) – Target dimensionality (typically smaller than kernel.feature_count).
  - init (Callable) – Initialization function for the projection matrix (default: orthogonal).
  - rngs (Rngs | None) – Random number generator state.
Note
The projection matrix is normalized by the base kernel’s feature count to maintain reasonable magnitudes. Orthogonal initialization helps preserve independent features.
Example
>>> # Reduce 49-dimensional Fourier embedding to 20 dimensions
>>> base_embed = KernelFourier(49, rngs=rngs)
>>> reduced = KernelReduced(base_embed, 20, rngs=rngs)
>>> features = reduced(t)
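Conceptually, the wrapper applies a learned matrix to the base features, normalized by the base feature count as described in the note; a sketch of that projection step:

import jax
import jax.numpy as jnp

def reduced_features(base_feats, proj):
    # Hypothetical re-implementation: project, then normalize by the
    # base feature count to keep magnitudes reasonable.
    return proj @ base_feats / base_feats.shape[-1]

# Orthogonal init (the documented default) helps keep projected
# features independent.
key = jax.random.PRNGKey(0)
proj = jax.nn.initializers.orthogonal()(key, (20, 49))  # (out, in)
reduced_features(jnp.ones(49), proj).shape  # (20,)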
- class bijx.nn.embeddings.PositionalEmbedding[source]
Bases: Embedding
Sinusoidal positional embeddings from transformer architectures.
Uses multiple frequencies to create position-dependent representations.
- Parameters:
  - feature_count (int) – Output feature dimensionality (must be even).
  - max_positions (float) – Maximum expected input value for frequency scaling. Controls the base wavelength; larger values create lower frequencies.
  - append_input (bool) – Whether to concatenate the raw input to the embedding.
  - rngs (Rngs | None) – Random number generator state.
Example
>>> embed = PositionalEmbedding(feature_count=64, max_positions=1000, rngs=rngs)
>>> pos_features = embed(position_indices)
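For reference, the classic transformer formulation this presumably follows; a sketch, as bijx's exact frequency layout and feature ordering may differ:

import jax.numpy as jnp

def sinusoidal_embedding(pos, feature_count=64, max_positions=1000.0):
    # Classic scheme: geometric frequency ladder; larger max_positions
    # lowers the smallest frequency (longer base wavelength).
    half = feature_count // 2
    inv_freq = max_positions ** (-jnp.arange(half) / half)
    angles = jnp.asarray(pos)[..., None] * inv_freq
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

sinusoidal_embedding(jnp.arange(5.0)).shape  # (5, 64)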