bijx.MultivariateNormal

class bijx.MultivariateNormal

Bases: ArrayDistribution

Multivariate normal distribution with Cholesky parametrization.

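Implements a multivariate Gaussian distribution using a Cholesky parametrization for numerical stability. The covariance matrix is represented by its lower-triangular Cholesky factor \(\mathbf{L}\), so \(\boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^T\) is always symmetric positive semi-definite (and positive definite whenever the diagonal of \(\mathbf{L}\) is nonzero). The factor also makes sampling cheap (\(\mathbf{x} = \boldsymbol{\mu} + \mathbf{L}\mathbf{z}\) with \(\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\)) and density evaluation cheap (a triangular solve instead of a matrix inverse).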

The log density is computed as:

\[ \log p(\mathbf{x}) = -\frac{1}{2}(\mathbf{x} - \boldsymbol{\mu})^T \mathbf{L}^{-T}\mathbf{L}^{-1}(\mathbf{x} - \boldsymbol{\mu}) - \frac{d}{2}\log(2\pi) - \sum_{i=1}^d \log L_{ii} \]

where \(\mathbf{L}\) is the Cholesky factor such that \(\boldsymbol{\Sigma} = \mathbf{L}\mathbf{L}^T\).
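To make the formula concrete, the hypothetical helper mvn_log_density below evaluates it term by term for a single point. It is an illustrative sketch, not bijx's implementation, and assumes L is lower triangular with a strictly positive diagonal. The quadratic form is obtained from one triangular solve, and the result is cross-checked against jax.scipy.stats.multivariate_normal.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def mvn_log_density(x, mean, L):
    """log N(x | mean, L L^T) for a single point x of shape (d,).

    Hypothetical helper; assumes L is lower triangular with a positive diagonal.
    """
    d = mean.shape[-1]
    # z = L^{-1} (x - mu), computed with a triangular solve instead of an inverse
    z = solve_triangular(L, x - mean, lower=True)
    # (x - mu)^T L^{-T} L^{-1} (x - mu) = z^T z
    quad = jnp.sum(z**2)
    # log det(Sigma) = 2 * sum_i log L_ii, so half of it is sum_i log L_ii
    half_log_det = jnp.sum(jnp.log(jnp.diag(L)))
    return -0.5 * quad - 0.5 * d * jnp.log(2 * jnp.pi) - half_log_det


mean = jnp.array([1.0, 2.0])
cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])
L = jnp.linalg.cholesky(cov)
x = jnp.array([0.5, 1.5])

logp = mvn_log_density(x, mean, L)
# Cross-check against JAX's reference implementation
ref = multivariate_normal.logpdf(x, mean, cov)
```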

Parameters:
  • mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector, shape (dim,) or scalar dim.

  • cholesky (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Cholesky factor vector (packed lower triangular), shape (dim*(dim+1)//2,).

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters (default: nnx.Param).
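The packed cholesky vector stores only the dim*(dim+1)//2 lower-triangular entries of the factor. A minimal sketch of turning such a vector back into the matrix L and the covariance L L^T follows. It assumes the entries are ordered like jnp.tril_indices(dim), that is, row by row, which may not match bijx's internal layout; unpack_lower is a hypothetical helper.

```python
import math

import jax.numpy as jnp


def unpack_lower(packed):
    """Rebuild a lower-triangular (d, d) matrix from its packed entries.

    ASSUMPTION: entries follow the row-major order of jnp.tril_indices(d);
    bijx may store them differently.
    """
    n = packed.shape[-1]
    # Invert n = d (d + 1) / 2 to recover the dimensionality d
    d = (math.isqrt(8 * n + 1) - 1) // 2
    assert d * (d + 1) // 2 == n, "packed length is not a triangular number"
    rows, cols = jnp.tril_indices(d)
    return jnp.zeros((d, d), packed.dtype).at[rows, cols].set(packed)


packed = jnp.array([1.0, 0.5, 2.0])  # d = 2: entries L00, L10, L11
L = unpack_lower(packed)
cov = L @ L.T  # Sigma = L L^T
```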

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from bijx import MultivariateNormal
>>> rng = jax.random.key(0)
>>> # 2D multivariate normal
>>> mean = jnp.array([1.0, 2.0])
>>> cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])
>>> dist = MultivariateNormal.given_cov(mean, cov)
>>> samples, log_p = dist.sample(batch_shape=(100,), rng=rng)
>>> assert samples.shape == (100, 2)
>>> assert log_p.shape == (100,)

The parameters can also be specified by shape instead of by value, and the mean alone is sufficient:

>>> dist = MultivariateNormal((3,), rngs=nnx.Rngs(0))  # 3D multivariate normal

__init__(mean, cholesky=None, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)

Initialize multivariate normal distribution.

Parameters:
  • mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector specification.

  • cholesky (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Cholesky factor specification (packed lower triangular).

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters.

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Methods

density(x, **kwargs)

Evaluate probability density at given points.

get_batch_shape(x)

Extract batch dimensions from an array sample.

given_cov(mean, cov, *[, rngs, var_cls, epsilon])

Create multivariate normal with given mean and covariance.

given_dim(dim, *, rngs[, var_cls, epsilon])

Create multivariate normal with given dimensionality.

log_density(x)

Compute log probability density at given points.

sample([batch_shape, rng])

Generate samples from the distribution.

Attributes

cov

Reconstruct covariance matrix from Cholesky factor.

dim

Dimensionality of the distribution.

event_axes

Axis indices corresponding to event dimensions.

event_dim

Number of event dimensions.

event_size

Total number of elements in the event shape.
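For this distribution the event shape is (dim,), so the attributes above presumably reduce to simple shape arithmetic. The sketch below reimplements them for illustration only, assuming that event_axes are the trailing axes and that get_batch_shape strips those axes from a sample; none of it is taken from bijx's source.

```python
import math

import jax.numpy as jnp

# HYPOTHETICAL sketch of shape bookkeeping for an event shape of (dim,);
# the names mirror the attributes above but are not bijx's implementation.
dim = 3
event_shape = (dim,)
event_dim = len(event_shape)              # number of event dimensions -> 1
event_size = math.prod(event_shape)       # elements per event -> dim
event_axes = tuple(range(-event_dim, 0))  # trailing axes holding the event -> (-1,)


def get_batch_shape(x):
    """Treat every axis in front of the event axes as a batch axis."""
    return x.shape[: x.ndim - event_dim]


x = jnp.zeros((4, 5, dim))
batch_shape = get_batch_shape(x)  # (4, 5)
```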

property dim

Dimensionality of the distribution.

property cov

Reconstruct covariance matrix from Cholesky factor.

classmethod given_dim(dim, *, rngs, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)

Create multivariate normal with given dimensionality.

Parameters:
  • dim (int) – Dimensionality of the distribution.

  • rngs (Rngs) – Random number generator state.

  • var_cls – Variable class for parameters.

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Returns:

MultivariateNormal instance.

classmethod given_cov(mean, cov, *, rngs=None, var_cls=<class 'bijx.utils.Const'>, epsilon=1e-10)

Create multivariate normal with given mean and covariance.

Parameters:
  • mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector.

  • cov (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Covariance matrix.

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters (default: Const).

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Returns:

MultivariateNormal instance.
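Conceptually, building the distribution from a covariance matrix means factoring it and keeping only the lower-triangular entries. The hypothetical helper cov_to_packed_cholesky below sketches that step; it assumes the packed order follows jnp.tril_indices(d) and ignores how epsilon is applied, so it mirrors the idea behind given_cov rather than its actual code.

```python
import jax.numpy as jnp


def cov_to_packed_cholesky(cov):
    """Factor Sigma = L L^T and return the packed lower-triangular entries of L.

    ASSUMPTION: packed order follows jnp.tril_indices(d); bijx may differ.
    """
    d = cov.shape[-1]
    L = jnp.linalg.cholesky(cov)  # lower-triangular factor with positive diagonal
    rows, cols = jnp.tril_indices(d)
    return L[rows, cols]  # shape (d * (d + 1) // 2,)


cov = jnp.array([[2.0, 0.5], [0.5, 1.0]])
packed = cov_to_packed_cholesky(cov)
```

Note that jnp.linalg.cholesky does not raise on a matrix that is not positive definite; the result contains NaNs instead, so it can be worth validating cov before constructing the distribution.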

log_density(x)

Compute log probability density at given points.

Parameters:

x – Points at which to evaluate density, shape (…, dim).

Returns:

Log density values of shape (…,), i.e. the batch dimensions of x (every axis except the last).
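Since x may carry any number of leading batch axes, one way to evaluate the density is to flatten those axes, perform a single triangular solve, and restore the batch shape. The hypothetical batched_log_density below sketches this approach; it assumes L is lower triangular with a positive diagonal and is not bijx's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def batched_log_density(x, mean, L):
    """log N(x | mean, L L^T) for x of shape (..., d); returns shape (...,)."""
    d = mean.shape[-1]
    batch_shape = x.shape[:-1]
    # Flatten all leading axes; each column of the right-hand side is one point
    centered = (x - mean).reshape(-1, d)
    z = solve_triangular(L, centered.T, lower=True)  # shape (d, n)
    quad = jnp.sum(z**2, axis=0)
    log_norm = 0.5 * d * jnp.log(2 * jnp.pi) + jnp.sum(jnp.log(jnp.diag(L)))
    # Restore the original batch axes
    return (-0.5 * quad - log_norm).reshape(batch_shape)


mean = jnp.array([1.0, 2.0])
L = jnp.linalg.cholesky(jnp.array([[2.0, 0.5], [0.5, 1.0]]))
x = jax.random.normal(jax.random.key(0), (4, 5, 2))
logp = batched_log_density(x, mean, L)  # shape (4, 5)
```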

sample(batch_shape=(), rng=None)

Generate samples from the distribution.

Parameters:
  • batch_shape – Shape of batch dimensions for vectorized sampling.

  • rng – Random key for sampling, or None to use internal rngs.

Returns:

Tuple of (samples, log_densities) where samples have shape (*batch_shape, dim) and log_densities have shape batch_shape.
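Returning the log density alongside each sample is cheap under the Cholesky parametrization: with x = mu + L z, the quantity L^{-1}(x - mu) is just z, so no solve is needed. The hypothetical sample_mvn below sketches this reparametrization with an explicit JAX key; it assumes L is lower triangular with a positive diagonal and does not model how bijx falls back to its internal rngs when rng is None.

```python
import jax
import jax.numpy as jnp


def sample_mvn(rng, mean, L, batch_shape=()):
    """Return (samples, log_densities) with shapes (*batch_shape, d) and batch_shape.

    Hypothetical helper; assumes L is lower triangular with a positive diagonal.
    """
    d = mean.shape[-1]
    # Standard normal noise z ~ N(0, I), one row per sample
    z = jax.random.normal(rng, (*batch_shape, d))
    # Reparametrization x = mu + L z, written for row vectors as z @ L^T
    x = mean + z @ L.T
    # L^{-1} (x - mu) = z, so the quadratic form is |z|^2 and no solve is needed
    log_p = (
        -0.5 * jnp.sum(z**2, axis=-1)
        - 0.5 * d * jnp.log(2 * jnp.pi)
        - jnp.sum(jnp.log(jnp.diag(L)))
    )
    return x, log_p


mean = jnp.array([1.0, 2.0])
L = jnp.linalg.cholesky(jnp.array([[2.0, 0.5], [0.5, 1.0]]))
samples, log_p = sample_mvn(jax.random.key(0), mean, L, batch_shape=(100,))
```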