API Reference

This section provides comprehensive documentation for all bijx components, organized by functionality.

Main components

Normalizing flows with JAX/NNX, specializing in physics applications.

It provides modular bijection primitives, flexible distributions, continuous flows, and tools for lattice field theory, including matrix Lie group operations.

Example

>>> import bijx
>>> from flax import nnx
>>> rngs = nnx.Rngs(0)
>>>
>>> # Create a simple normalizing flow
>>> base_dist = bijx.IndependentNormal(event_shape=(10,), rngs=rngs)
>>> bijection = bijx.Chain(
...     bijx.AffineLinear(rngs=rngs),
...     bijx.Tanh(),
...     bijx.AffineLinear(rngs=rngs)
... )
>>>
>>> # Sample and evaluate densities
>>> x, log_p = base_dist.sample(batch_shape=(100,))
>>> y, log_q = bijection.forward(x, log_p)

Core Classes and Base Types

Bijection

Base class for all bijective transformations.

Distribution

Base class for all probability distributions.

ArrayDistribution

Base class for distributions over multi-dimensional arrays.

ApplyBijection

Convenience base class for bijections with unified forward/reverse logic.

Distributions

IndependentNormal

Independent standard normal distribution over arrays.

IndependentUniform

Independent uniform distribution over arrays on [0, 1].

MultivariateNormal

Multivariate normal distribution with Cholesky parametrization.
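The Cholesky parametrization makes the log-density cheap to evaluate: with Σ = L Lᵀ, solve L z = x − μ and use log det Σ = 2 Σᵢ log Lᵢᵢ. The NumPy sketch below only illustrates that formula; it is not bijx's implementation, and the helper name `mvn_log_prob` is hypothetical.

```python
import numpy as np

def mvn_log_prob(x, mean, chol):
    """Log-density of N(mean, chol @ chol.T) at x (hypothetical helper).

    chol is a lower-triangular Cholesky factor with positive diagonal.
    """
    d = mean.shape[-1]
    # Whiten the residual: solve chol @ z = x - mean.
    z = np.linalg.solve(chol, x - mean)
    # log det(Sigma) = 2 * sum(log diag(chol)) for Sigma = chol @ chol.T.
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (z @ z + log_det + d * np.log(2.0 * np.pi))

mean = np.array([1.0, -2.0])
chol = np.array([[2.0, 0.0], [0.5, 1.0]])
x = np.array([0.3, 0.7])
print(mvn_log_prob(x, mean, chol))
```

Storing L rather than Σ also keeps Σ positive definite by construction when the diagonal of L is kept positive.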

DiagonalNormal

Multivariate normal distribution with diagonal covariance matrix.

MixtureStack

Mixture of distributions of the same kind.

GaussianMixture

Gaussian mixture model.
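A mixture's log-density is the log-sum-exp of the weighted component log-densities, log p(x) = log Σₖ πₖ N(x; μₖ, σₖ²). A minimal one-dimensional NumPy sketch, assuming a plain weights/means/standard-deviations parametrization (the helper `gmm_log_prob` is hypothetical, not bijx's API):

```python
import numpy as np

def gmm_log_prob(x, weights, means, stds):
    """Log-density of a 1D Gaussian mixture (hypothetical helper)."""
    x = np.asarray(x, dtype=float)[..., None]  # add a component axis
    comp = (-0.5 * ((x - means) / stds) ** 2
            - np.log(stds) - 0.5 * np.log(2.0 * np.pi))
    a = np.log(weights) + comp
    # Numerically stable log-sum-exp over the component axis.
    m = np.max(a, axis=-1, keepdims=True)
    return np.squeeze(m, -1) + np.log(np.sum(np.exp(a - m), axis=-1))

x = np.array([-1.0, 0.0, 2.5])
print(gmm_log_prob(x, np.array([0.3, 0.7]), np.array([-1.0, 2.0]), np.array([0.5, 1.0])))
```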

Sampling and Transforms

Transformed

Distribution obtained by applying a bijection to a base distribution.

BufferedSampler

Distribution wrapper that caches samples for efficient use in MCMC.

Bijection Composition and Meta-bijections

Chain

Sequential composition of multiple bijections.
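Composition works because each step updates the log-density by the change-of-variables rule, log q(y) = log p(x) − log|det ∂y/∂x|, so the log-Jacobians of the individual steps simply add up. The NumPy sketch below assumes, as the example at the top suggests, that forward takes and returns a log-density; the classes `ToyAffine`, `ToyTanh`, and `ToyChain` are illustrative stand-ins, not bijx classes.

```python
import numpy as np

class ToyAffine:
    """Element-wise y = scale * x + shift (illustrative, not bijx.AffineLinear)."""

    def __init__(self, scale, shift):
        self.scale, self.shift = scale, shift

    def _log_jac(self, x):
        # log|dy/dx| for each element, summed over the event (last) axis.
        return np.sum(np.broadcast_to(np.log(np.abs(self.scale)), x.shape), axis=-1)

    def forward(self, x, log_density):
        return self.scale * x + self.shift, log_density - self._log_jac(x)

    def reverse(self, y, log_density):
        return (y - self.shift) / self.scale, log_density + self._log_jac(y)


class ToyTanh:
    """Element-wise y = tanh(x), with dy/dx = 1 - tanh(x)**2."""

    def forward(self, x, log_density):
        y = np.tanh(x)
        return y, log_density - np.sum(np.log1p(-y ** 2), axis=-1)

    def reverse(self, y, log_density):
        return np.arctanh(y), log_density + np.sum(np.log1p(-y ** 2), axis=-1)


class ToyChain:
    """Apply bijections in order; undo them in reverse order."""

    def __init__(self, *bijections):
        self.bijections = bijections

    def forward(self, x, log_density):
        for b in self.bijections:
            x, log_density = b.forward(x, log_density)
        return x, log_density

    def reverse(self, y, log_density):
        for b in reversed(self.bijections):
            y, log_density = b.reverse(y, log_density)
        return y, log_density


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 10))
log_p = np.sum(-0.5 * x ** 2 - 0.5 * np.log(2 * np.pi), axis=-1)
flow = ToyChain(ToyAffine(0.5, 0.1), ToyTanh(), ToyAffine(2.0, -1.0))
y, log_q = flow.forward(x, log_p)
```

Running the chain in reverse recovers both the original samples and their base log-density.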

ScanChain

Compilation-efficient chain of identical bijections built on JAX's scan.

Inverse

Inverted bijection that swaps forward and reverse directions.

CondInverse

Conditionally inverted bijection based on a runtime boolean flag.

Frozen

Wrapper that excludes its internal parameters from training.

“Meta” bijections that do not change the log-density.

MetaLayer

Convenient constructor for bijections that preserve probability density.

ExpandDims

Expand tensor dimensions along a specified axis.

SqueezeDims

Remove singleton dimensions along a specified axis.

Reshape

Reshape tensor event dimensions while preserving batch dimensions.

Partial

Bijection wrapper that fixes keyword arguments.

General Coupling and Masking

GeneralCouplingLayer

General coupling layer with flexible masking and bijection support.
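A coupling layer stays invertible because the entries selected by the mask pass through unchanged and the remaining entries are transformed with parameters computed only from them, which makes the Jacobian triangular. The sketch below assumes an affine (RealNVP-style) element transform and a hypothetical conditioner `toy_net`; bijx's GeneralCouplingLayer accepts other bijections, and its exact mask semantics may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 6)) / 6.0  # hypothetical fixed "network" weights

def toy_net(frozen):
    """Stand-in conditioner; any function of the frozen entries keeps the layer invertible."""
    h = np.tanh(frozen @ W)
    return h, 0.5 * h  # (log_scale, shift)

def coupling_forward(x, log_density, mask, net):
    frozen = mask * x  # entries with mask == 1 pass through unchanged
    log_scale, shift = net(frozen)
    log_scale, shift = (1 - mask) * log_scale, (1 - mask) * shift
    y = frozen + (1 - mask) * (x * np.exp(log_scale) + shift)
    # Triangular Jacobian: log|det| is the sum of the active log-scales.
    return y, log_density - np.sum(log_scale, axis=-1)

def coupling_reverse(y, log_density, mask, net):
    frozen = mask * y  # equals mask * x, so the conditioner sees the same input
    log_scale, shift = net(frozen)
    log_scale, shift = (1 - mask) * log_scale, (1 - mask) * shift
    x = frozen + (1 - mask) * (y - shift) * np.exp(-log_scale)
    return x, log_density + np.sum(log_scale, axis=-1)

mask = np.array([1, 0, 1, 0, 1, 0], dtype=float)
x = rng.normal(size=(4, 6))
y, log_q = coupling_forward(x, np.zeros(4), mask, toy_net)
```

Stacking such layers with alternating masks lets every entry be transformed at some point.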

BinaryMask

Binary mask for coupling layer split/merge operations.

checker_mask(shape, parity)

Create a checkerboard-pattern binary mask.
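A checkerboard mask alternates between neighbouring lattice sites, so each site's nearest neighbours all lie in the opposite partition. A minimal NumPy sketch; the convention that `parity` selects sites whose index sum has that parity is an assumption, and `checkerboard_sketch` is a hypothetical name, not `bijx.checker_mask`.

```python
import numpy as np

def checkerboard_sketch(shape, parity):
    """Boolean checkerboard: True where (sum of indices) % 2 == parity (assumed convention)."""
    index_sum = np.indices(shape).sum(axis=0)
    return index_sum % 2 == int(parity)

mask = checkerboard_sketch((4, 4), parity=0)
print(mask.astype(int))
# [[1 0 1 0]
#  [0 1 0 1]
#  [1 0 1 0]
#  [0 1 0 1]]
```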

ModuleReconstructor

Parameter management utility for dynamically parameterizing modules.

Spline Bijections

MonotoneRQSpline

Monotonic rational quadratic spline bijection.

rational_quadratic_spline(inputs, ...[, ...])

Apply monotonic rational quadratic spline transformation.
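The monotone rational-quadratic spline of Durkan et al. (Neural Spline Flows, 2019) interpolates between knots with ratios of quadratics whose values and derivatives match at every knot, so it is smooth, strictly increasing, and has an analytic log-derivative. The NumPy sketch below follows those published formulas on a fixed interval [−B, B] with identity tails; the parameter layout (softmax widths/heights, softplus interior derivatives, boundary derivatives fixed to 1) is an assumption and need not match `bijx.rational_quadratic_spline`.

```python
import numpy as np

def rq_spline_sketch(x, widths, heights, derivs, bound=3.0):
    """Monotone rational-quadratic spline on [-bound, bound], identity outside.

    widths, heights: positive bin sizes summing to 2 * bound (K bins each).
    derivs: positive derivatives at the K - 1 interior knots; the boundary
    derivatives are fixed to 1 so the spline joins the identity tails.
    Returns (y, log|dy/dx|).
    """
    xk = np.concatenate([[-bound], -bound + np.cumsum(widths)])
    yk = np.concatenate([[-bound], -bound + np.cumsum(heights)])
    d = np.concatenate([[1.0], derivs, [1.0]])
    x = np.asarray(x, dtype=float)
    inside = np.abs(x) < bound
    xc = np.clip(x, -bound, bound)
    # Bin index k with xk[k] <= x < xk[k + 1].
    k = np.clip(np.searchsorted(xk, xc, side="right") - 1, 0, len(widths) - 1)
    w, h = widths[k], heights[k]
    s = h / w  # slope of the straight line through the bin corners
    xi = (xc - xk[k]) / w
    t = xi * (1 - xi)
    denom = s + (d[k + 1] + d[k] - 2 * s) * t
    y = yk[k] + h * (s * xi ** 2 + d[k] * t) / denom
    dydx = s ** 2 * (d[k + 1] * xi ** 2 + 2 * s * t + d[k] * (1 - xi) ** 2) / denom ** 2
    return np.where(inside, y, x), np.where(inside, np.log(dydx), 0.0)

# Unconstrained parameters mapped to valid ones (hypothetical choice of maps).
rng = np.random.default_rng(1)
K = 5
raw_w, raw_h, raw_d = rng.normal(size=K), rng.normal(size=K), rng.normal(size=K - 1)
widths = 6.0 * np.exp(raw_w) / np.exp(raw_w).sum()   # softmax, scaled to 2 * bound
heights = 6.0 * np.exp(raw_h) / np.exp(raw_h).sum()
derivs = np.log1p(np.exp(raw_d))                     # softplus keeps derivatives positive

y, log_det = rq_spline_sketch(np.array([-2.1, -0.3, 0.7, 2.4]), widths, heights, derivs)
```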

Continuous Flows

ContFlowCG

Continuous normalizing flow using Crouch-Grossmann integration.

ContFlowDiffrax

Continuous normalizing flow using diffrax ODE solver.

ContFlowRK4

Continuous normalizing flow using fixed-step RK4 solver.

ConvVF

Convolutional continuous normalizing flow with symmetry preservation.

AutoJacVF

Automatic Jacobian computation for vector fields.
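A continuous flow evolves samples with dx/dt = f(x, t) and tracks the density through the instantaneous change of variables, d log p/dt = −tr(∂f/∂x), which is why the Jacobian trace of the vector field is needed. In JAX that trace could be taken from `jax.jacfwd`; the NumPy sketch below hand-codes it for the hypothetical field f(x) = −x³ and integrates the augmented state with fixed-step RK4. None of these names are bijx APIs.

```python
import numpy as np

def vector_field(x, t):
    """Example field f(x) = -x**3 applied element-wise (hypothetical choice)."""
    return -x ** 3

def divergence(x, t):
    """tr(df/dx), written by hand here instead of computed by autodiff."""
    return np.sum(-3.0 * x ** 2, axis=-1)

def augmented_rhs(state, t):
    # State = [x, log_density]; d(log_density)/dt = -div f.
    x = state[..., :-1]
    return np.concatenate([vector_field(x, t), -divergence(x, t)[..., None]], axis=-1)

def cnf_forward(x, log_density, t1=1.0, n_steps=200):
    """Integrate samples and their log-density from t=0 to t1 with classic RK4."""
    state = np.concatenate([x, log_density[..., None]], axis=-1)
    dt = t1 / n_steps
    for i in range(n_steps):
        t = i * dt
        k1 = augmented_rhs(state, t)
        k2 = augmented_rhs(state + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = augmented_rhs(state + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = augmented_rhs(state + dt * k3, t + dt)
        state = state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return state[..., :-1], state[..., -1]

rng = np.random.default_rng(0)
x0 = rng.normal(size=(8, 3))
x1, lp1 = cnf_forward(x0, np.zeros(8))
```

For this field the exact solution is x(t) = x₀/√(1 + 2x₀²t), so the log-density gain is 1.5 Σ log(1 + 2x₀²t), which the integrator should reproduce.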

One-dimensional Bijections

ScalarBijection

Base class for element-wise one-dimensional bijections.

AffineLinear

Learnable affine transformation.

Scaling

Scaling transformation.

Shift

Shift transformation.

BetaStretch

Beta-inspired stretching on the unit interval.

Exponential

Exponential transform to positive reals.

GaussianCDF

Bijection via Gaussian CDF with learnable location and scale.

Power

Power transformation for positive values.

Sigmoid

Sigmoid normalization transform.

Sinh

Hyperbolic sine transformation.

SoftPlus

SoftPlus transform.

Tan

Tangent-based unbounded transform.

Tanh

Hyperbolic tangent bounded transform.

Fourier and Physics-specific Bijections

ToFourierData

Bijection for converting between real and Fourier data representations.

FreeTheoryScaling

Scaling bijection mapping white noise to free field theory.

SpectrumScaling

Diagonal scaling transformation in Fourier space.
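Multiplying every Fourier mode of a periodic field by a positive factor s(k) is a linear map that the discrete Fourier transform diagonalizes, so log|det| = Σₖ log s(k); if s(k) = s(−k), a real field stays real. Choosing s(k) = (m² + p̂²)^(−1/2) with lattice momentum p̂² = 4 sin²(πk/N) turns white noise into a free-field sample with covariance (m² − Δ)⁻¹, which is one reading of what FreeTheoryScaling describes. The 1D lattice, the mass value, and the helper name `spectrum_scaling` are illustrative assumptions.

```python
import numpy as np

def spectrum_scaling(x, log_density, scale):
    """Scale each Fourier mode of a real periodic 1D field by scale[k] > 0.

    scale must satisfy scale[k] == scale[-k] so that the output stays real.
    """
    y = np.fft.ifft(scale * np.fft.fft(x, axis=-1), axis=-1).real
    # Diagonal in the Fourier basis: log|det| = sum_k log scale[k].
    return y, log_density - np.sum(np.log(scale))

N, m = 8, 0.5
k = np.arange(N)
# Lattice momentum squared 4 sin^2(pi k / N) is symmetric under k -> N - k.
scale = 1.0 / np.sqrt(m ** 2 + 4.0 * np.sin(np.pi * k / N) ** 2)
noise = np.random.default_rng(0).normal(size=(5, N))
phi, log_q = spectrum_scaling(noise, np.zeros(5), scale)
```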

ODE Solvers

DiffraxConfig

Configuration for diffrax ODE solving in continuous normalizing flows.

odeint_rk4(fun, y0, end_time, *args, step_size)

Fixed-step-size fourth-order Runge-Kutta (RK4) integrator with a custom adjoint.
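The classic RK4 step combines four slope evaluations per step, y ← y + h/6 (k₁ + 2k₂ + 2k₃ + k₄). The NumPy sketch below mirrors the argument order of the documented signature, but its semantics (integration from t = 0, a `fun(y, t, *args)` calling convention, and a step count rounded up so the last step lands on `end_time`) are assumptions; it also omits the custom adjoint, which matters only for differentiating through the solver in JAX.

```python
import numpy as np

def odeint_rk4_sketch(fun, y0, end_time, *args, step_size):
    """Integrate dy/dt = fun(y, t, *args) from t=0 to end_time with classic RK4."""
    n_steps = max(1, int(np.ceil(abs(end_time) / step_size)))
    dt = end_time / n_steps  # shrink the step so the last one lands on end_time
    y, t = np.asarray(y0, dtype=float), 0.0
    for _ in range(n_steps):
        k1 = fun(y, t, *args)
        k2 = fun(y + 0.5 * dt * k1, t + 0.5 * dt, *args)
        k3 = fun(y + 0.5 * dt * k2, t + 0.5 * dt, *args)
        k4 = fun(y + dt * k3, t + dt, *args)
        y = y + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
    return y

# Exponential decay dy/dt = -rate * y, with the rate passed through *args.
y_end = odeint_rk4_sketch(lambda y, t, rate: -rate * y, np.array([1.0, 2.0]), 1.5, 0.7, step_size=0.01)
```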

MCMC

These tools mimic the BlackJAX API; the main difference is that samples and their proposal densities are generated simultaneously.

IMH

Independent Metropolis-Hastings sampler.
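Independent Metropolis-Hastings draws each proposal from a fixed distribution q, independently of the current state, and accepts it with probability min(1, w′/w), where w = p(x)/q(x) is the importance weight. Because a flow returns a sample together with log q, no separate density evaluation is needed. A plain-NumPy sketch of the algorithm, not the bijx or BlackJAX API; `imh_sketch` and `normal_proposal` are hypothetical names.

```python
import numpy as np

def imh_sketch(sample_with_log_q, target_log_p, n_steps, rng):
    """Independent Metropolis-Hastings with a proposal returning (x, log q(x))."""
    x, log_q = sample_with_log_q(rng)
    log_w = target_log_p(x) - log_q  # log importance weight of the current state
    chain, n_accept = [], 0
    for _ in range(n_steps):
        x_new, log_q_new = sample_with_log_q(rng)
        log_w_new = target_log_p(x_new) - log_q_new
        # Accept with probability min(1, w_new / w_current).
        if np.log(rng.uniform()) < log_w_new - log_w:
            x, log_w = x_new, log_w_new
            n_accept += 1
        chain.append(x)
    return np.array(chain), n_accept / n_steps

def normal_proposal(rng, scale=1.5):
    """Stand-in for a flow: sample N(0, scale^2) and return its log-density too."""
    x = scale * rng.normal()
    return x, -0.5 * (x / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

# The target only needs to be known up to a constant.
chain, acceptance = imh_sketch(normal_proposal, lambda x: -0.5 * x ** 2, 20000, np.random.default_rng(0))
```

When the proposal matches the target exactly, every weight is equal and every proposal is accepted.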

IMHState

State for Independent Metropolis-Hastings sampler.

IMHInfo

Information about the IMH sampling step.

Utilities

Const

Mark a variable as constant during training.

FrozenFilter

Filter that matches constant variables and anything in a 'frozen' path.

ShapeInfo

Comprehensive shape information manager for array operations.

default_wrap(x[, cls, init_fn, init_cls, rngs])

Flexibly wrap parameter specifications into nnx.Variable instances.

effective_sample_size(target_ld, sample_ld)

Compute the effective sample size from importance weights.
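With log-weights log wᵢ = target_ld − sample_ld, a common effective-sample-size estimate is (Σwᵢ)² / (N Σwᵢ²), which lies in (0, 1]; it equals 1 when all weights agree and is unaffected by an unknown normalizing constant of the target. Whether bijx reports this normalized fraction or the unnormalized count is an assumption here, and `ess_sketch` is a hypothetical name.

```python
import numpy as np

def ess_sketch(target_ld, sample_ld):
    """Normalized ESS from log-densities of the target and the sampler."""
    log_w = target_ld - sample_ld
    w = np.exp(log_w - np.max(log_w))  # subtract the max to avoid overflow
    return np.sum(w) ** 2 / (len(w) * np.sum(w ** 2))

print(ess_sketch(np.array([0.0, -1000.0, -1000.0, -1000.0]), np.zeros(4)))  # one weight dominates
```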

moving_average(x[, window])

Compute the moving average of a 1D array.
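A moving average is a convolution with a flat kernel of length `window`. The sketch assumes "valid" edge handling, which returns len(x) − window + 1 values; bijx's edge convention may differ, and `moving_average_sketch` is a hypothetical name.

```python
import numpy as np

def moving_average_sketch(x, window):
    """Mean over each length-`window` slice of a 1D array ("valid" mode)."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="valid")

print(moving_average_sketch(np.arange(5.0), window=2))  # [0.5 1.5 2.5 3.5]
```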

noise_model(rng, model[, scale, noise_fn])

Add random noise to model parameters for testing or regularization.

reverse_dkl(target_ld, sample_ld)

Estimate reverse Kullback-Leibler divergence.
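For samples xᵢ ~ q, the reverse divergence KL(q‖p) = E_q[log q − log p] is estimated by the sample mean of sample_ld − target_ld. If target_ld is known only up to an additive constant log Z, the estimate is shifted by −log Z; this sketch does not correct for that, and whether bijx does is not assumed. The name `reverse_dkl_sketch` is hypothetical.

```python
import numpy as np

def reverse_dkl_sketch(target_ld, sample_ld):
    """Monte Carlo estimate of KL(q || p) from log-densities at samples x_i ~ q."""
    return np.mean(sample_ld - target_ld)

rng = np.random.default_rng(0)
x = rng.normal(size=200_000)                                  # samples from q = N(0, 1)
sample_ld = -0.5 * x ** 2 - 0.5 * np.log(2 * np.pi)
target_ld = -0.5 * (x - 1.0) ** 2 - 0.5 * np.log(2 * np.pi)   # p = N(1, 1)
estimate = reverse_dkl_sketch(target_ld, sample_ld)           # exact KL is 0.5
```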

load_shapes_magic()

Load IPython magic command for inspecting JAX pytree shapes.

Submodules

Core submodules provide tools for lattice field theory, Fourier transformations, and more.

fourier

Fourier transform utilities for lattice field theory and physics applications.

lie

Lie group operations and automatic differentiation tools.

cg

Crouch-Grossmann integration methods for Lie group ordinary differential equations.

lattice

Methods for manipulating lattice field configurations.

lattice.gauge

Gauge field theory utilities.

lattice.scalar

Scalar field theory utilities.

For interfacing with flowjax, a dedicated submodule is available. Because flowjax is not an explicit dependency of bijx, that submodule has to be imported explicitly.

bijx.nn provides building blocks for neural networks and prototyping.

nn.conv

Symmetric convolutions with group invariance for lattice field theory.

nn.embeddings

Embedding layers for time and positional encoding in neural networks.

nn.features

Nonlinear feature transformations for neural network layers.

nn.nets

Complete neural network architectures for convenience and prototyping.