bijx.DiagonalNormal

class bijx.DiagonalNormal

Bases: ArrayDistribution

Multivariate normal distribution with diagonal covariance matrix.

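Implements a multivariate Gaussian distribution with diagonal covariance, allowing a separate mean and variance for each dimension. Because the covariance is diagonal, the density factorizes into independent univariate terms, so evaluation costs O(dim) per point and needs no matrix factorization. This makes it simpler and cheaper than the full MultivariateNormal when off-diagonal correlations are not needed.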
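The statement that the diagonal case is a special case of the full Gaussian can be checked numerically. The sketch below uses plain JAX (`jax.scipy.stats`) rather than bijx, and the array values are arbitrary illustrations: the diagonal log density is a sum of univariate terms, while the full-covariance route goes through a Cholesky factorization of `jnp.diag(scales**2)`.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])  # standard deviations, one per dimension
x = jnp.array([[0.0, 0.0], [1.0, 2.0], [2.0, -1.0]])  # 3 points in 2D

# Diagonal case: the density factorizes, so per-dimension log densities add up.
diag_logp = norm.logpdf(x, loc=mean, scale=scales).sum(axis=-1)

# Full-covariance case with cov = diag(scales**2): internally uses a Cholesky
# factorization, which costs O(dim^3) for a general covariance matrix.
full_logp = multivariate_normal.logpdf(x, mean, jnp.diag(scales**2))

# Both routes give the same log densities up to floating-point error.
```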

The log density is computed as:

\[ \log p(\mathbf{x}) = -\frac{1}{2}\sum_{i=1}^d \left( \frac{(x_i - \mu_i)^2}{\sigma_i^2} + \log(2\pi\sigma_i^2) \right) \]
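The formula translates term by term into array code. This sketch is written against plain JAX rather than bijx; `diag_normal_log_density` is a hypothetical helper name, and the batch shape `(4, 5)` only illustrates how leading axes are preserved while the last (event) axis is summed out.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def diag_normal_log_density(x, mean, scales):
    """Hypothetical helper: log density of a diagonal Gaussian.

    x has shape (..., dim); mean and scales have shape (dim,).
    Returns shape (...,): the sum over i = 1..d runs over the last axis.
    """
    var = scales**2  # sigma_i^2
    per_dim = (x - mean) ** 2 / var + jnp.log(2 * jnp.pi * var)
    return -0.5 * jnp.sum(per_dim, axis=-1)


mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])
x = jnp.zeros((4, 5, 2))  # batch shape (4, 5), event size 2

logp = diag_normal_log_density(x, mean, scales)
# Cross-check against a sum of univariate normal log densities.
reference = norm.logpdf(x, mean, scales).sum(axis=-1)
```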

Parameters:
  • mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector of shape (dim,), or a shape such as (dim,) from which the mean is initialized.

  • scales (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Standard deviation vector, shape (dim,).

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters (default: nnx.Param).

  • epsilon (float) – Small regularization constant to ensure numerical stability (default: 1e-10).

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from bijx import DiagonalNormal
>>> # 2D diagonal normal with different means and variances
>>> mean = jnp.array([1.0, 2.0])
>>> scales = jnp.array([0.5, 1.5])
>>> dist = DiagonalNormal(mean, scales)
>>> rng = jax.random.key(0)
>>> samples, log_p = dist.sample(batch_shape=(100,), rng=rng)
>>> assert samples.shape == (100, 2)
>>> assert log_p.shape == (100,)

The parameters can also be initialized from shapes, in which case giving only the mean is sufficient:

>>> dist = DiagonalNormal((3,), rngs=nnx.Rngs(0))  # 3D multivariate normal

__init__(mean, scales=None, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)

Initialize diagonal multivariate normal distribution.

Parameters:
  • mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector specification.

  • scales (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Standard deviation vector specification.

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters.

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Methods

density(x, **kwargs)

Evaluate probability density at given points.

get_batch_shape(x)

Extract batch dimensions from an array sample.

given_dim(dim, *[, rngs, var_cls, epsilon])

Create diagonal normal with given dimensionality.

given_variances(mean, variances, *[, rngs, ...])

Create diagonal normal with given variances.

log_density(x)

Compute log probability density at given points.

sample([batch_shape, rng])

Generate samples from the distribution.

Attributes

cov

Covariance matrix (diagonal).

dim

Dimensionality of the distribution.

event_axes

Axis indices corresponding to event dimensions.

event_dim

Number of event dimensions.

event_size

Total number of elements in the event shape.

scales

Standard deviation vector.

variances

Variance vector (scales squared).
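For this distribution an event is a single vector of length dim, so the batch shape of a sample is everything except the last axis. The sketch below illustrates that bookkeeping in plain JAX; `batch_shape_of` is a hypothetical stand-in for get_batch_shape, and representing event_axes with negative indices is an assumption, not something this page specifies.

```python
import math

import jax.numpy as jnp


def batch_shape_of(x, event_dim=1):
    """Hypothetical helper: drop the trailing `event_dim` axes of x.shape."""
    # x.ndim - event_dim (rather than -event_dim) also handles event_dim == 0.
    return x.shape[: x.ndim - event_dim]


dim = 3
x = jnp.ones((8, 16, dim))  # batch shape (8, 16), one event axis of length dim

event_dim = 1  # a diagonal-normal event is a single vector
event_axes = tuple(range(-event_dim, 0))  # (-1,) under the assumed convention
event_size = math.prod(x.shape[x.ndim - event_dim:])  # elements per event

batch = batch_shape_of(x, event_dim)
```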

classmethod given_dim(dim, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)

Create diagonal normal with given dimensionality.

Parameters:
  • dim (int) – Dimensionality of the distribution.

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters.

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Returns:

DiagonalNormal instance.

classmethod given_variances(mean, variances, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)

Create diagonal normal with given variances.

Note: If variances is an instance of nnx.Variable, its value is cloned but its type is preserved (nnx.Param, Const, etc.).

Parameters:
  • mean (Variable | Array | ndarray | Sequence[int | Any]) – Mean vector specification.

  • variances (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Variance vector.

  • rngs – Optional random number generator state.

  • var_cls – Variable class for parameters.

  • epsilon (float) – Small regularization constant to ensure numerical stability.

Returns:

DiagonalNormal instance.

property dim

Dimensionality of the distribution.

property variances

Variance vector (scales squared).

property cov

Covariance matrix (diagonal).
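The dim, variances, and cov properties are simple functions of the standard deviation vector. A plain-JAX sketch of those relationships follows; the values are arbitrary, and the epsilon regularization used by bijx is deliberately not modeled here.

```python
import jax.numpy as jnp

scales = jnp.array([0.5, 1.5, 2.0])  # standard deviations, shape (dim,)

dim = scales.shape[0]      # dimensionality of the distribution
variances = scales**2      # variance vector: scales squared
cov = jnp.diag(variances)  # (dim, dim) covariance, zero off the diagonal
```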

property scales

Standard deviation vector.

log_density(x)

Compute log probability density at given points.

Parameters:

x – Points at which to evaluate density, shape (…, dim).

Returns:

Log density values with batch dimensions matching input.

sample(batch_shape=(), rng=None)

Generate samples from the distribution.

Parameters:
  • batch_shape – Shape of batch dimensions for vectorized sampling.

  • rng – Random key for sampling, or None to use internal rngs.

Returns:

Tuple of (samples, log_densities) where samples have shape (*batch_shape, dim) and log_densities have shape batch_shape.
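A common way to sample a diagonal Gaussian is the reparameterization x = mean + scales * eps with eps drawn from a standard normal. The sketch below assumes that approach and mirrors the documented return convention in plain JAX; `sample_diag_normal` is a hypothetical helper, not the bijx implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def sample_diag_normal(key, mean, scales, batch_shape=()):
    """Hypothetical helper: reparameterized samples plus their log densities.

    Returns (samples, log_densities) with shapes (*batch_shape, dim)
    and batch_shape respectively.
    """
    eps = jax.random.normal(key, (*batch_shape, mean.shape[-1]))  # N(0, I)
    samples = mean + scales * eps  # shift and scale each dimension
    log_p = norm.logpdf(samples, mean, scales).sum(axis=-1)  # sum over event axis
    return samples, log_p


mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])
samples, log_p = sample_diag_normal(
    jax.random.key(0), mean, scales, batch_shape=(10_000,)
)
```

Writing the sample as a deterministic function of eps keeps it differentiable with respect to mean and scales, which is the usual reason for this design in flow-based models.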