API Reference
This section provides comprehensive documentation for all bijx components, organized by functionality.
Main components
Normalizing flows with JAX/NNX, specializing in physics applications.
Provides modular bijection primitives, flexible distributions, and tools for lattice field theory including matrix Lie group operations and continuous flows.
Example
>>> import bijx
>>> from flax import nnx
>>> rngs = nnx.Rngs(0)
>>>
>>> # Create a simple normalizing flow
>>> base_dist = bijx.IndependentNormal(event_shape=(10,), rngs=rngs)
>>> bijection = bijx.Chain(
...     bijx.AffineLinear(rngs=rngs),
...     bijx.Tanh(),
...     bijx.AffineLinear(rngs=rngs),
... )
>>>
>>> # Sample from the base distribution, then push samples and log-densities through the flow
>>> x, log_p = base_dist.sample(batch_shape=(100,))
>>> y, log_q = bijection.forward(x, log_p)
Core Classes and Base Types
Base class for all bijective transformations.
Base class for all probability distributions.
Base class for distributions over multi-dimensional arrays.
Convenience base class for bijections with unified forward/reverse logic.
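The following minimal pure-JAX sketch makes the forward/reverse convention concrete. It is not bijx's implementation. It assumes, as the naming in the example above suggests, that `forward(x, log_density)` returns the transformed sample together with its log-density. Under that assumption the log-determinant of the Jacobian is subtracted on the way forward and added back on the way in reverse. The class `ScaleShift` is hypothetical.

```python
import jax
import jax.numpy as jnp


class ScaleShift:
    """Hypothetical elementwise bijection y = x * exp(log_scale) + shift."""

    def __init__(self, log_scale, shift):
        self.log_scale = jnp.asarray(log_scale)
        self.shift = jnp.asarray(shift)

    def _log_det(self, x):
        # Elementwise map: log|det J| is the sum of log|dy/dx| over the event axis.
        return jnp.sum(jnp.broadcast_to(self.log_scale, x.shape), axis=-1)

    def forward(self, x, log_density):
        y = x * jnp.exp(self.log_scale) + self.shift
        # Change of variables: log q(y) = log p(x) - log|det J_f(x)|.
        return y, log_density - self._log_det(x)

    def reverse(self, y, log_density):
        x = (y - self.shift) * jnp.exp(-self.log_scale)
        return x, log_density + self._log_det(y)
```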
Distributions
Independent standard normal distribution over arrays.
Independent uniform distribution over arrays on [0, 1].
Multivariate normal distribution with Cholesky parametrization.
Multivariate normal distribution with diagonal covariance matrix.
Mixture of distributions of the same kind.
Gaussian mixture model.
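A mixture's log-density combines component log-densities with a log-sum-exp over components, which avoids underflow. The sketch below assumes diagonal-covariance Gaussian components and softmax-normalized mixture logits. The functions are hypothetical illustrations, not the bijx classes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def diag_normal_log_prob(x, mean, log_std):
    # log N(x; mean, diag(exp(log_std)**2)), summed over the event axis.
    z = (x - mean) * jnp.exp(-log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2 * jnp.pi), axis=-1)


def gmm_log_prob(x, logits, means, log_stds):
    # x: (..., d); logits: (k,); means, log_stds: (k, d).
    # log p(x) = logsumexp_k [log w_k + log N_k(x)], computed stably.
    log_w = jax.nn.log_softmax(logits)
    comp = diag_normal_log_prob(x[..., None, :], means, log_stds)  # (..., k)
    return logsumexp(log_w + comp, axis=-1)
```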
Sampling and Transforms
Distribution obtained by applying a bijection to a base distribution.
Distribution wrapper that caches samples for efficient use in MCMC.
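A distribution obtained by applying a bijection supports two operations. It is sampled by pushing base samples forward, and its density at an arbitrary point is evaluated by running the bijection in reverse. The plain-JAX sketch below, which is not the bijx class, does both for a standard normal base and `tanh`. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def sample_pushforward(key, n):
    # Draw from the base (standard normal), push through y = tanh(x),
    # and return samples together with their log-densities.
    x = jax.random.normal(key, (n,))
    y = jnp.tanh(x)
    log_q = norm.logpdf(x) - jnp.log1p(-y**2)  # log|dy/dx| = log(1 - tanh(x)^2)
    return y, log_q


def log_density_pushforward(y):
    # Density of an arbitrary point: invert, evaluate the base density,
    # and correct with the Jacobian of the inverse map.
    x = jnp.arctanh(y)
    return norm.logpdf(x) - jnp.log1p(-y**2)
```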
Bijection Composition and Meta-bijections
Sequential composition of multiple bijections.
Chain of identical bijections using jax.lax.scan, for efficient JAX compilation.
Inverted bijection that swaps forward and reverse directions.
Conditionally inverted bijection based on a runtime boolean flag.
Wrapper to screen internal parameters from training.
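A scan-based chain traces the layer body once rather than once per layer, so compile time does not grow with depth. The sketch below illustrates the idea and assumes per-layer parameters stacked along a leading axis. `affine_layer` and `scan_chain_forward` are hypothetical helpers, not bijx API.

```python
import jax
import jax.numpy as jnp


def affine_layer(params, x, log_density):
    # One elementwise affine layer: y = x * exp(log_scale) + shift.
    log_scale, shift = params
    y = x * jnp.exp(log_scale) + shift
    return y, log_density - jnp.sum(log_scale)


def scan_chain_forward(stacked_params, x, log_density):
    # Apply the same layer L times; each leaf of stacked_params has shape (L, ...).
    def body(carry, params):
        x, lp = carry
        return affine_layer(params, x, lp), None

    (y, lp), _ = jax.lax.scan(body, (x, log_density), stacked_params)
    return y, lp
```

The result matches an explicit Python loop over the layers, but only one copy of the layer is compiled.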
“Meta” bijections, which do not change the log-density:
Convenient constructor for bijections that preserve probability density.
Expand tensor dimensions along a specified axis.
Remove singleton dimensions along a specified axis.
Reshape tensor event dimensions while preserving batch dimensions.
Bijection wrapper that fixes keyword arguments.
General Coupling and Masking
General coupling layer with flexible masking and bijection support.
Binary mask for coupling layer split/merge operations.
Create a checkerboard-pattern binary mask.
Parameter management utility for dynamically parameterizing modules.
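A coupling layer keeps one part of the input fixed and uses it to parameterize an invertible update of the rest. The reverse pass can therefore recompute exactly the same parameters. The sketch below is a hypothetical illustration, not bijx's coupling API. It uses a checkerboard mask applied with `jnp.where` instead of explicit split/merge, an affine update, and a toy conditioner standing in for a neural network.

```python
import jax
import jax.numpy as jnp


def checkerboard_mask(shape):
    # True on sites whose coordinate sum is even, False on the others.
    return jnp.sum(jnp.indices(shape), axis=0) % 2 == 0


def _coupling_params(mask, conditioner, v):
    # The conditioner only sees the frozen (masked) sites.
    log_scale, shift = conditioner(jnp.where(mask, v, 0.0))
    # Frozen sites get the identity transform.
    return jnp.where(mask, 0.0, log_scale), jnp.where(mask, 0.0, shift)


def affine_coupling_forward(x, log_density, mask, conditioner):
    log_scale, shift = _coupling_params(mask, conditioner, x)
    axes = tuple(range(-mask.ndim, 0))
    return x * jnp.exp(log_scale) + shift, log_density - jnp.sum(log_scale, axis=axes)


def affine_coupling_reverse(y, log_density, mask, conditioner):
    # Frozen sites of y equal those of x, so the parameters are recovered exactly.
    log_scale, shift = _coupling_params(mask, conditioner, y)
    axes = tuple(range(-mask.ndim, 0))
    return (y - shift) * jnp.exp(-log_scale), log_density + jnp.sum(log_scale, axis=axes)


def toy_conditioner(x_frozen):
    # Hypothetical stand-in for a neural network: combines neighbouring sites.
    nb = jnp.roll(x_frozen, 1, axis=-1) + jnp.roll(x_frozen, -1, axis=-1)
    return 0.1 * jnp.tanh(nb), 0.5 * nb
```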
Spline Bijections
Monotonic rational quadratic spline bijection.
Apply a monotonic rational quadratic spline transformation.
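For reference, the monotonic rational-quadratic spline commonly used in neural spline flows maps each bin with a ratio of quadratics. The value and slope of this map match the given knots and knot derivatives. The scalar sketch below assumes the knots and positive derivatives are supplied directly; in a flow they would come from a network. It also does not handle inputs outside the knot range. It is not bijx's implementation.

```python
import jax
import jax.numpy as jnp


def rq_spline_forward(x, knot_x, knot_y, derivs):
    """Monotonic rational-quadratic spline for scalar x.

    knot_x, knot_y: increasing arrays of K+1 knot positions; derivs: K+1 positive slopes.
    Assumes knot_x[0] <= x <= knot_x[-1]; tails are not handled in this sketch.
    Returns (y, log dy/dx).
    """
    k = jnp.clip(jnp.searchsorted(knot_x, x, side="right") - 1, 0, knot_x.shape[0] - 2)
    x0, x1 = knot_x[k], knot_x[k + 1]
    y0, y1 = knot_y[k], knot_y[k + 1]
    d0, d1 = derivs[k], derivs[k + 1]
    w, h = x1 - x0, y1 - y0
    s = h / w                      # bin slope
    xi = (x - x0) / w              # position within the bin, in [0, 1]
    t = xi * (1 - xi)
    denom = s + (d0 + d1 - 2 * s) * t
    y = y0 + h * (s * xi**2 + d0 * t) / denom
    dydx = s**2 * (d1 * xi**2 + 2 * s * t + d0 * (1 - xi) ** 2) / denom**2
    return y, jnp.log(dydx)
```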
Continuous Flows
Continuous normalizing flow using Crouch-Grossmann integration.
Continuous normalizing flow using diffrax ODE solver.
Continuous normalizing flow using fixed-step RK4 solver.
Convolutional continuous normalizing flow with symmetry preservation.
Automatic Jacobian computation for vector fields.
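In a continuous flow, the log-density obeys the instantaneous change-of-variables formula d log p / dt = -tr(df/dx). It can therefore be integrated alongside the sample. The minimal sketch below uses fixed-step RK4 and an exact Jacobian trace from `jax.jacfwd`. The linear vector field is a hypothetical example with a known solution. The sketch omits the custom adjoint as well as the Crouch-Grossmann and diffrax integrators listed above.

```python
import jax
import jax.numpy as jnp


def vector_field(t, x):
    # Hypothetical time-dependent linear field f(t, x) = -(1 + t) x.
    return -(1.0 + t) * x


def augmented_rhs(t, state):
    # Integrate (x, delta log p) jointly: d log p / dt = -tr(df/dx).
    x, _ = state
    jac = jax.jacfwd(lambda v: vector_field(t, v))(x)
    return vector_field(t, x), -jnp.trace(jac)


def rk4_integrate(rhs, state, t0, t1, n_steps):
    # Classic fixed-step fourth-order Runge-Kutta over a pytree state.
    dt = (t1 - t0) / n_steps
    add = lambda s, k, c: jax.tree_util.tree_map(lambda a, b: a + c * b, s, k)

    def step(state, i):
        t = t0 + i * dt
        k1 = rhs(t, state)
        k2 = rhs(t + dt / 2, add(state, k1, dt / 2))
        k3 = rhs(t + dt / 2, add(state, k2, dt / 2))
        k4 = rhs(t + dt, add(state, k3, dt))
        incr = jax.tree_util.tree_map(
            lambda a, b, c, d: (a + 2 * b + 2 * c + d) / 6, k1, k2, k3, k4
        )
        return add(state, incr, dt), None

    state, _ = jax.lax.scan(step, state, jnp.arange(n_steps))
    return state
```

For this field the exact solution is x(1) = x(0) exp(-3/2), and the log-density change is 3/2 times the dimension.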
One-dimensional Bijections
Base class for element-wise one-dimensional bijections.
Learnable affine transformation.
Scaling transformation.
Shift transformation.
Beta-inspired stretching on the unit interval.
Exponential transform to the positive reals.
Bijection via Gaussian CDF with learnable location and scale.
Power transformation for positive values.
Sigmoid normalization transform.
Hyperbolic sine transformation.
SoftPlus transform.
Tangent-based unbounded transform.
Hyperbolic tangent bounded transform.
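For an element-wise bijection, the log-Jacobian determinant is the sum of log|f'(x_i)| over elements. The sketch below computes this term with autodiff. It then compares the result against numerically stable closed forms for a few of the transforms listed above. It uses plain `jax.numpy` functions rather than the bijx classes and omits their learnable parameters.

```python
import jax
import jax.numpy as jnp


def log_abs_derivative(f):
    # Per-element log|f'(x)| via autodiff; summing it gives log|det J|.
    return jax.vmap(lambda x: jnp.log(jnp.abs(jax.grad(f)(x))))


# (function, stable closed form of log|f'(x)|)
closed_forms = {
    "tanh": (jnp.tanh, lambda x: 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))),
    "sigmoid": (jax.nn.sigmoid, lambda x: -jax.nn.softplus(-x) - jax.nn.softplus(x)),
    "softplus": (jax.nn.softplus, lambda x: -jax.nn.softplus(-x)),
    "sinh": (jnp.sinh, lambda x: jnp.log(jnp.cosh(x))),
    "exp": (jnp.exp, lambda x: x),
}
```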
Fourier and Physics-specific Bijections
Bijection for converting between real and Fourier data representations.
Scaling bijection mapping white noise to free field theory.
Diagonal scaling transformation in Fourier space.
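The following sketch shows one way to map white noise to a free field. If the lattice action is S = 1/2 sum phi (-Laplacian + m^2) phi, the covariance is diagonal in Fourier space with eigenvalues 1/(khat^2 + m^2). Scaling each Fourier mode of the noise by (khat^2 + m^2)^(-1/2) therefore produces free-field samples. The sketch assumes a periodic hypercubic lattice with the nearest-neighbour Laplacian and unit lattice spacing. The function names are hypothetical, and bijx's bijection may use a different normalization.

```python
import jax
import jax.numpy as jnp


def lattice_momentum_sq(shape):
    # khat^2 = sum_mu 4 sin^2(k_mu / 2), eigenvalues of the periodic lattice -Laplacian.
    ks = jnp.meshgrid(*[2 * jnp.pi * jnp.fft.fftfreq(n) for n in shape], indexing="ij")
    return sum(4 * jnp.sin(k / 2) ** 2 for k in ks)


def white_noise_to_free_field(z, mass):
    # Scale each Fourier mode of z by 1/sqrt(khat^2 + m^2).
    # The map is linear, so log|det J| = -1/2 * sum_k log(khat^2 + m^2), independent of z.
    lam = lattice_momentum_sq(z.shape) + mass**2
    phi = jnp.fft.ifftn(jnp.fft.fftn(z) / jnp.sqrt(lam)).real
    log_det = -0.5 * jnp.sum(jnp.log(lam))
    return phi, log_det
```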
ODE Solvers
Configuration for diffrax ODE solving in continuous normalizing flows.
Fixed-step Runge-Kutta implementation with custom adjoint.
MCMC
These tools mimic the BlackJAX API, with one main difference: samples and their proposal densities are generated simultaneously.
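Because a flow returns each proposal together with its log-density, independent Metropolis-Hastings needs no extra proposal-density evaluations. The acceptance probability is min(1, w'/w), with importance weights w = p/q. The minimal plain-JAX sketch below illustrates that idea. It does not follow the BlackJAX or bijx API, and `independent_mh` is hypothetical.

```python
import jax
import jax.numpy as jnp


def independent_mh(key, proposals, log_q, log_p):
    """Independent Metropolis-Hastings over a pre-drawn batch of proposals.

    proposals: (n, ...) flow samples; log_q: (n,) their proposal log-densities;
    log_p: (n,) unnormalized target log-densities at the same points.
    Returns the chain (n - 1 states after the initial one) and acceptance flags.
    """
    log_w = log_p - log_q  # log importance weights
    us = jax.random.uniform(key, log_w.shape)

    def step(carry, inputs):
        x, lw = carry
        x_new, lw_new, u = inputs
        # Accept with probability min(1, w_new / w_current).
        accept = jnp.log(u) < lw_new - lw
        x = jnp.where(accept, x_new, x)
        lw = jnp.where(accept, lw_new, lw)
        return (x, lw), (x, accept)

    init = (proposals[0], log_w[0])
    _, (chain, accepted) = jax.lax.scan(step, init, (proposals[1:], log_w[1:], us[1:]))
    return chain, accepted
```

With a perfect proposal (q equal to p) every step is accepted and the chain reproduces the proposals.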
Utilities
Mark a variable as constant during training.
Filter that matches constant variables and anything in a 'frozen' path.
Comprehensive shape information manager for array operations.
Flexibly wrap parameter specifications into nnx.Variable instances.
Compute effective sample size from importance weights.
Compute moving average of a 1D array.
Add random noise to model parameters for testing or regularization.
Estimate reverse Kullback-Leibler divergence.
Load an IPython magic command for inspecting JAX pytree shapes.
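The sketches below, in plain JAX, show the quantities behind the effective-sample-size and reverse-KL utilities, plus a moving average. They illustrate the conventions stated in the comments and are not the bijx functions. In particular, the ESS here is normalized to lie in (0, 1], which may differ from bijx's normalization.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def effective_sample_size(log_w):
    # Normalized ESS = (sum w)^2 / (n * sum w^2), computed from log-weights for stability.
    n = log_w.shape[0]
    return jnp.exp(2 * logsumexp(log_w) - logsumexp(2 * log_w)) / n


def reverse_kl(log_q, log_p):
    # Monte Carlo estimate of KL(q || p) = E_q[log q - log p] from samples of q.
    # If log_p is unnormalized, the estimate is shifted by +log Z.
    return jnp.mean(log_q - log_p)


def moving_average(x, window):
    # Simple moving average of a 1D array (fully overlapping windows only).
    return jnp.convolve(x, jnp.ones(window) / window, mode="valid")
```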
Submodules
Core submodules provide tools for lattice field theory, Fourier transformations, and more.
Fourier transform utilities for lattice field theory and physics applications.
Lie group operations and automatic differentiation tools.
Crouch-Grossmann integration methods for Lie group ordinary differential equations.
Methods for manipulating lattice field configurations.
Gauge field theory utilities.
Scalar field theory utilities.
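As a small illustration of matrix Lie group operations, the exponential map takes an element of the Lie algebra su(2) to a special unitary matrix. The sketch below uses `jax.scipy.linalg.expm` and a Pauli-matrix basis. It is a generic example, not the API of bijx's Lie group submodule, and the names `PAULI` and `su2_exp` are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Pauli matrices; i * sigma_a / 2 span su(2) (anti-Hermitian, traceless).
PAULI = jnp.array(
    [[[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]],
    dtype=jnp.complex64,
)


def su2_exp(omega):
    # Map real coefficients omega (3,) to SU(2) via U = expm(i * omega . sigma / 2).
    A = 0.5j * jnp.einsum("a,aij->ij", omega.astype(jnp.complex64), PAULI)
    return expm(A)
```

Because (n . sigma)^2 = I for a unit vector n, the result agrees with the closed form cos(theta/2) I + i sin(theta/2) n . sigma, where theta = |omega|.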
For interfacing with flowjax, the following submodule can be used. Since flowjax is not an explicit dependency of bijx, this submodule has to be imported explicitly.
bijx.nn provides building blocks for neural networks and prototyping.
Symmetric convolutions with group invariance for lattice field theory.
Embedding layers for time and positional encoding in neural networks.
Nonlinear feature transformations for neural network layers.
Complete neural network architectures for convenience and prototyping.
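A common choice for time embeddings, for example in the vector fields of continuous flows, is a bank of sinusoidal features at geometrically spaced frequencies. The sketch below is a hypothetical helper showing that construction for a scalar time. bijx's embedding layers may use different frequencies or learnable parameters.

```python
import jax.numpy as jnp


def sinusoidal_embedding(t, dim, max_period=10000.0):
    # Map a scalar time t to dim features: sin and cos at dim // 2 frequencies
    # spaced geometrically between 1 and 1 / max_period (dim assumed even).
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    angles = t * freqs
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)])
```

For a batch of times, wrap the helper in `jax.vmap`.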