bijx.ConvVF
- class bijx.ConvVF
Bases: Module
Convolutional vector field for continuous normalizing flows, with symmetry preservation.
Implements a vector field for continuous normalizing flows that uses symmetric convolutions to preserve spatial structure. The convolution kernels are time-dependent and act on nonlinear feature transformations of the input field.
The vector field preserves discrete lattice symmetries (typically the D4 group of rotations and reflections) while still allowing complex, spatially structured transformations.
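One generic way to obtain a D4-invariant convolution kernel is to average an arbitrary kernel over its orbit under the four rotations and four reflections of the square. The sketch below does this in plain JAX to illustrate the symmetry constraint only; `d4_orbit` and `symmetrize_d4` are hypothetical helpers and are not necessarily how bijx's `kernel_d4` parametrizes kernels (it may, for instance, share weights per orbit directly).

```python
import jax
import jax.numpy as jnp


def d4_orbit(kernel):
    """Return the 8 images of a square 2D kernel under the dihedral group D4."""
    rotations = [jnp.rot90(kernel, k) for k in range(4)]
    # Reflecting each rotation (here: a left-right flip) gives the other 4 elements.
    reflections = [jnp.flip(r, axis=1) for r in rotations]
    return rotations + reflections


def symmetrize_d4(kernel):
    """Average a kernel over its D4 orbit; the result is invariant under all of D4."""
    return jnp.mean(jnp.stack(d4_orbit(kernel)), axis=0)


w = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
w_sym = symmetrize_d4(w)
# A D4-invariant 3x3 kernel has only three free values: centre, edges, corners.
```

Averaging is the simplest projection onto the invariant subspace; a parametrization that stores one weight per orbit reaches the same subspace with fewer parameters.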
- Parameters:
shape_info (ShapeInfo) – Shape information for spatial and channel dimensions.
conv (ConvSym) – Symmetric convolution layer with time-dependent parameters.
time_kernel (Module) – Time embedding module for kernel modulation.
feature_map (NonlinearFeatures) – Nonlinear feature transformation applied before convolution.
feature_superposition (Variable | None) – Optional feature dimensionality reduction.
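To make the roles of `feature_map` and `feature_superposition` concrete, the sketch below applies a pointwise nonlinear feature map to a lattice field and then reduces the feature axis with a linear map. The specific Fourier and polynomial forms, the frequency count, and the reduced size are hypothetical stand-ins; they are not the definitions of bijx's `FourierFeatures` or `PolynomialFeatures`.

```python
import jax
import jax.numpy as jnp


def fourier_features(phi, n_freq):
    """Hypothetical pointwise Fourier features sin(k*phi), cos(k*phi), k = 1..n_freq."""
    k = jnp.arange(1, n_freq + 1)
    angles = phi[..., None] * k  # (..., n_freq)
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


def polynomial_features(phi, powers):
    """Hypothetical pointwise monomials phi**p for each p in powers."""
    return jnp.stack([phi**p for p in powers], axis=-1)


def feature_map(phi, n_freq=4, powers=(1,)):
    """Concatenate both feature families along a trailing feature axis."""
    return jnp.concatenate(
        [fourier_features(phi, n_freq), polynomial_features(phi, powers)], axis=-1
    )


phi = jax.random.normal(jax.random.PRNGKey(0), (8, 8))  # scalar field on an 8x8 lattice
feats = feature_map(phi)  # (8, 8, 2 * 4 + 1) = (8, 8, 9)

# Optional "superposition": a linear map from 9 features down to 3 (hypothetical sizes).
superposition = jax.random.normal(jax.random.PRNGKey(1), (feats.shape[-1], 3))
reduced = feats @ superposition  # (8, 8, 3)
```

Because the feature map acts site by site, each site's features depend only on that site's field value, which keeps the Jacobian structure of the subsequent convolution simple.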
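Note
Most conveniently constructed using ConvVF.build().
Example
>>> # Build a conv CNF for a 2D lattice
>>> from functools import partial
>>> from flax import nnx
>>> from bijx import ConvVF
>>> from bijx.nn.features import FourierFeatures, PolynomialFeatures
>>> rngs = nnx.Rngs(0)
>>> cnf = ConvVF.build(
...     kernel_shape=(3, 3),
...     channel_shape=(),
...     features=(
...         partial(FourierFeatures, 49),
...         partial(PolynomialFeatures, (1,)),
...     ),
...     rngs=rngs,
... )
>>> velocity, div = cnf(t, phi)  # t: flow time, phi: field configuration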
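Assuming the `div` returned above is the divergence of `velocity` with respect to `phi`, it can be computed exactly for a convolution applied to pointwise features without forming the full Jacobian: only the centre kernel tap couples site x to itself, so the trace reduces to a sum over sites of the centre tap times the feature derivatives. The sketch below illustrates that idea in plain JAX with a hypothetical three-feature map and periodic boundaries, and checks it against an autodiff Jacobian trace. It is an illustration of the technique, not bijx's implementation.

```python
import jax
import jax.numpy as jnp


def features(phi):
    """Hypothetical pointwise feature map: (phi, sin phi, cos phi) -> (..., 3)."""
    return jnp.stack([phi, jnp.sin(phi), jnp.cos(phi)], axis=-1)


def velocity(phi, kernel):
    """v_x = sum_{d,f} kernel[d, f] * features(phi)[x + d, f] with periodic boundaries."""
    feats = features(phi)  # (L, L, F)
    ci, cj = kernel.shape[0] // 2, kernel.shape[1] // 2
    out = jnp.zeros(phi.shape)
    for i in range(kernel.shape[0]):
        for j in range(kernel.shape[1]):
            # Roll by minus the tap offset so that shifted[x] == feats[x + offset].
            shifted = jnp.roll(feats, shift=(ci - i, cj - j), axis=(0, 1))
            out = out + shifted @ kernel[i, j]
    return out


def divergence_centre_tap(phi, kernel):
    """sum_x dv_x/dphi_x; only the centre tap links v_x to phi_x (kernel smaller than lattice)."""
    ci, cj = kernel.shape[0] // 2, kernel.shape[1] // 2
    dfeats = jax.vmap(jax.jacfwd(features))(phi.reshape(-1))  # (L*L, F)
    return jnp.sum(dfeats @ kernel[ci, cj])


def divergence_autodiff(phi, kernel):
    """Reference: trace of the full (L*L, L*L) Jacobian, feasible only for tiny lattices."""
    jac = jax.jacfwd(lambda p: velocity(p, kernel).reshape(-1))(phi)  # (L*L, L, L)
    return jnp.trace(jac.reshape(phi.size, phi.size))


key_phi, key_k = jax.random.split(jax.random.PRNGKey(0))
phi = jax.random.normal(key_phi, (6, 6))
kernel = jax.random.normal(key_k, (3, 3, 3))  # (kernel height, kernel width, features)
div_fast = divergence_centre_tap(phi, kernel)
div_ref = divergence_autodiff(phi, kernel)
```

The centre-tap form costs one pass over the lattice, whereas the reference Jacobian grows quadratically with the number of sites.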
- __init__(*, shape_info, conv, time_kernel, feature_map, feature_superposition=None)
- Parameters:
shape_info (ShapeInfo)
conv (ConvSym)
time_kernel (Module)
feature_map (NonlinearFeatures)
feature_superposition (Variable | None)
Methods
build(kernel_shape[, channel_shape, ...]) – Build a ConvVF with default architecture choices.
- classmethod build(kernel_shape, channel_shape=(), *, symmetry=<function kernel_d4>, use_bias=False, time_kernel=KernelFourier(feature_count=21, val_range=None), time_kernel_reduced=20, features=(functools.partial(<class 'bijx.nn.features.FourierFeatures'>, 49), functools.partial(<class 'bijx.nn.features.PolynomialFeatures'>, (1,))), features_reduced=20, rngs)
Build a ConvVF with default architecture choices.
Constructs a complete convolutional vector field by assembling symmetric convolutions, time embeddings, and feature transformations with sensible defaults. In particular, it enlarges the internal parameters with an extra axis, which is later contracted with the time embedding to make those parameters time-dependent.
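The "extra axis contracted with the time embedding" can be pictured as follows: kernel parameters carry an additional axis whose length matches a reduced time embedding, and evaluating the kernel at time t is a weighted sum over that axis. In the sketch below, the Fourier-style embedding, its 21 features, the linear reduction to 4 dimensions, and placing the extra axis first are all hypothetical choices for illustration; what `KernelFourier` actually computes is not specified here.

```python
import jax
import jax.numpy as jnp


def fourier_time_embedding(t, n_freq):
    """Hypothetical embedding [1, sin(2*pi*k*t), cos(2*pi*k*t)] for k = 1..n_freq."""
    k = jnp.arange(1, n_freq + 1)
    phase = 2 * jnp.pi * k * t
    return jnp.concatenate([jnp.ones(1), jnp.sin(phase), jnp.cos(phase)])


n_freq, n_reduced = 10, 4
emb_dim = 2 * n_freq + 1  # 21 embedding features in this sketch

key_reduce, key_params = jax.random.split(jax.random.PRNGKey(0))
# Linear reduction of the embedding, analogous in spirit to time_kernel_reduced.
reduce = jax.random.normal(key_reduce, (emb_dim, n_reduced))
# Kernel parameters enlarged with an extra axis of size n_reduced (placed first here).
kernel_params = jax.random.normal(key_params, (n_reduced, 3, 3))


def kernel_at(t):
    """Contract the extra parameter axis with the reduced time embedding."""
    emb = fourier_time_embedding(t, n_freq) @ reduce  # (n_reduced,)
    return jnp.einsum("r,rij->ij", emb, kernel_params)  # (3, 3)
```

Since the kernel is linear in the embedding, any symmetry projection applied to `kernel_params` along the spatial axes is preserved for every t.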
- Parameters:
kernel_shape – Spatial shape of the convolution kernels (e.g., (3, 3)).
channel_shape (tuple[int, ...]) – Shape of the channel dimensions; defaults to scalar channels.
symmetry (Callable) – Symmetry group operation (default: D4 rotations and reflections).
use_bias (bool) – Whether to include bias terms in the convolutions.
time_kernel (Module) – Time embedding module for kernel modulation.
time_kernel_reduced – Dimensionality reduction for the time embeddings.
features (tuple[NonlinearFeatures, ...]) – Tuple of feature transformation classes to compose.
features_reduced (int | None) – Dimensionality reduction for the feature superposition.
rngs (Rngs) – Random number generator state for parameter initialization.
- Returns:
ConvVF instance to be used in continuous normalizing flows.
Example
>>> # Standard 2D lattice CNF with Fourier + linear features
>>> from flax import nnx
>>> from bijx import ConvVF
>>> rngs = nnx.Rngs(0)
>>> cnf = ConvVF.build(
...     kernel_shape=(3, 3),
...     channel_shape=(1,),  # single-channel (scalar) field
...     features_reduced=16,
...     rngs=rngs,
... )
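Assuming the instance is called as `velocity, div = cnf(t, phi)` (as in the class example) and that `div` is the divergence of the velocity, a continuous normalizing flow transports samples by integrating dphi/dt = v(t, phi) and accumulates the log-density change via d log p/dt = -div v. The sketch below uses a fixed-step Euler solver and a toy linear field in place of a built ConvVF; `integrate_cnf` and `toy_field` are hypothetical helpers, and how bijx itself integrates the flow is not covered here.

```python
import jax.numpy as jnp


def integrate_cnf(vector_field, phi0, t0=0.0, t1=1.0, steps=100):
    """Fixed-step Euler integration of dphi/dt = v and d(log p)/dt = -div v.

    Returns the transported field and the accumulated change in log-density.
    """
    dt = (t1 - t0) / steps
    phi, delta_logp = phi0, 0.0
    for n in range(steps):
        t = t0 + n * dt
        v, div = vector_field(t, phi)
        phi = phi + dt * v
        delta_logp = delta_logp - dt * div
    return phi, delta_logp


# Toy stand-in for a ConvVF: the linear field v = a * phi has divergence a * N (N = sites).
a = 0.3


def toy_field(t, phi):
    return a * phi, a * phi.size


phi0 = jnp.ones((4, 4))
phi1, delta_logp = integrate_cnf(toy_field, phi0)
# Exact solution: phi1 = exp(a) * phi0 and delta_logp = -a * 16 = -4.8.
```

In practice an adaptive ODE solver is usually preferred; Euler is used here only to keep the update rule visible.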