bijx.DiagonalNormal
- class bijx.DiagonalNormal
Bases: ArrayDistribution
Multivariate normal distribution with diagonal covariance matrix.
Implements a multivariate Gaussian distribution with diagonal covariance, allowing different means and variances for each dimension. This is simpler and more efficient than the full MultivariateNormal when off-diagonal correlations are not needed.
The log density is computed as:
\[ \log p(\mathbf{x}) = -\frac{1}{2}\sum_{i=1}^d \left( \frac{(x_i - \mu_i)^2}{\sigma_i^2} + \log(2\pi\sigma_i^2) \right) \]
where \(d\) is the dimensionality, \(\mu_i\) are the means, and \(\sigma_i\) are the standard deviations (scales).
- Parameters:
mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector, shape (dim,) or scalar dim.
scales (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Standard deviation vector, shape (dim,).
rngs – Optional random number generator state.
var_cls – Variable class for parameters (default: nnx.Param).
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from bijx import DiagonalNormal
>>> rng = jax.random.key(0)
>>> # 2D diagonal normal with different means and variances
>>> mean = jnp.array([1.0, 2.0])
>>> scales = jnp.array([0.5, 1.5])
>>> dist = DiagonalNormal(mean, scales)
>>> samples, log_p = dist.sample(batch_shape=(100,), rng=rng)
>>> assert samples.shape == (100, 2)
>>> assert log_p.shape == (100,)
The parameters can also be instantiated from shapes, in which case mean alone is sufficient:
>>> from flax import nnx
>>> dist = DiagonalNormal((3,), rngs=nnx.Rngs(0))  # 3D multivariate normal
- __init__(mean, scales=None, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)
Initialize diagonal multivariate normal distribution.
- Parameters:
mean (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Mean vector specification.
scales (Union[Variable, Array, ndarray, Sequence[Union[int, Any]], None]) – Standard deviation vector specification.
rngs – Optional random number generator state.
var_cls – Variable class for parameters.
epsilon (float) – Small regularization constant to ensure numerical stability.
Methods
density(x, **kwargs) – Evaluate probability density at given points.
get_batch_shape(x) – Extract batch dimensions from an array sample.
given_dim(dim, *[, rngs, var_cls, epsilon]) – Create diagonal normal with given dimensionality.
given_variances(mean, variances, *[, rngs, ...]) – Create diagonal normal with given variances.
log_density(x) – Compute log probability density at given points.
sample([batch_shape, rng]) – Generate samples from the distribution.
Attributes
cov – Covariance matrix (diagonal).
dim – Dimensionality of the distribution.
event_axes – Axis indices corresponding to event dimensions.
event_dim – Number of event dimensions.
event_size – Total number of elements in the event shape.
variances – Variance vector (scales squared).
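For DiagonalNormal the event shape is (dim,), so a sample of shape (*batch_shape, dim) has one event dimension and everything before it is batch. The following is a minimal sketch of how event_dim, event_axes, event_size and get_batch_shape plausibly relate, assuming the event dimensions are the trailing axes and that event_axes is expressed as negative indices (both are assumptions; the source does not specify the index convention). event_bookkeeping is a hypothetical helper, not part of bijx.

```python
import math


def event_bookkeeping(x_shape, event_shape):
    """Hypothetical helper (not part of bijx): split a shape into batch/event parts.

    Assumes the event dimensions are the trailing len(event_shape) axes,
    consistent with the documented sample shape (*batch_shape, dim).
    """
    event_dim = len(event_shape)                  # number of event dimensions
    event_axes = tuple(range(-event_dim, 0))      # trailing axes, negative indices (assumed)
    event_size = math.prod(event_shape)           # total elements in one event
    batch_shape = tuple(x_shape[: len(x_shape) - event_dim])  # leading axes are batch
    return batch_shape, event_dim, event_axes, event_size


# A batch of 100 samples from a 2D DiagonalNormal has event shape (2,).
batch_shape, event_dim, event_axes, event_size = event_bookkeeping((100, 2), (2,))
```

With these conventions, a single unbatched sample of shape (dim,) simply has an empty batch shape.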
- classmethod given_dim(dim, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)
Create diagonal normal with given dimensionality.
- Parameters:
dim (int) – Dimensionality of the distribution.
rngs – Optional random number generator state.
var_cls – Variable class for parameters.
epsilon (float) – Small regularization constant to ensure numerical stability.
- Returns:
DiagonalNormal instance.
- classmethod given_variances(mean, variances, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)
Create diagonal normal with given variances.
Note: If variances is an instance of nnx.Variable, its value is cloned but the type is preserved (nnx.Param, Const, etc.).
- Parameters:
mean (Variable | Array | ndarray | Sequence[int | Any]) – Mean vector.
variances (Union[Variable, Array, ndarray, Sequence[Union[int, Any]]]) – Variance vector.
rngs – Optional random number generator state.
var_cls – Variable class for parameters.
epsilon (float) – Small regularization constant to ensure numerical stability.
- Returns:
DiagonalNormal instance.
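Because the variances property is documented as the scales squared, specifying variances instead of scales amounts to an elementwise square root. The sketch below shows only that relationship in plain JAX; scales_from_variances is a hypothetical helper, and since the source does not say where epsilon enters, it is deliberately left out.

```python
import jax.numpy as jnp


def scales_from_variances(variances):
    """Hypothetical helper (not part of bijx): standard deviations from variances.

    `variances` is documented as scales squared, so the inverse map is an
    elementwise square root. Where bijx applies `epsilon` is not documented,
    so this sketch ignores it.
    """
    return jnp.sqrt(jnp.asarray(variances))


variances = jnp.array([0.25, 2.25])
scales = scales_from_variances(variances)

# Round trip: squaring the scales recovers the variance vector.
assert jnp.allclose(scales ** 2, variances)
```

Under that assumption, given_variances(mean, v) is expected to describe the same distribution as DiagonalNormal(mean, jnp.sqrt(v)), up to how epsilon is applied.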
- property dim
Dimensionality of the distribution.
- property variances
Variance vector (scales squared).
- property cov
Covariance matrix (diagonal).
- property scales
Standard deviation vector.
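The cov property is a diagonal matrix built from the variances, which is why a diagonal normal is just a product of independent 1D Gaussians. As a hedged illustration that uses jax.scipy.stats rather than bijx, the sketch below checks that summing 1D normal log-densities gives the same result as a full multivariate normal whose covariance is diag(scales²).

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])

variances = scales ** 2        # plays the role of the `variances` property
cov = jnp.diag(variances)      # plays the role of the `cov` property

x = jnp.array([[0.0, 0.0], [1.0, 2.0], [2.0, -1.0]])  # 3 points, dim = 2

# Diagonal case: independent 1D Gaussians, summed over the event axis.
log_p_diag = norm.logpdf(x, mean, scales).sum(axis=-1)

# General case: full multivariate normal with the same (diagonal) covariance.
log_p_full = multivariate_normal.logpdf(x, mean, cov)
```

The two agree, but the full version factorizes a dim × dim covariance (cubic in dim), while the diagonal version needs only elementwise operations, which is the efficiency gain the class description refers to.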
- log_density(x)
Compute log probability density at given points.
- Parameters:
x – Points at which to evaluate density, shape (…, dim).
- Returns:
Log density values with batch dimensions matching input.
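A minimal sketch of the log-density formula from the class description, written in plain JAX for inputs of shape (…, dim): summing the per-dimension terms over the last axis leaves exactly the batch dimensions of the input. diagonal_normal_log_density is a hypothetical stand-in, not bijx's implementation, and treating density as exp(log_density) is an assumption (the source only says it evaluates the probability density).

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def diagonal_normal_log_density(x, mean, scales):
    """Hypothetical sketch of the documented log density, x of shape (..., dim).

    Computes -1/2 * sum_i [(x_i - mu_i)^2 / sigma_i^2 + log(2 pi sigma_i^2)].
    """
    variances = scales ** 2
    terms = (x - mean) ** 2 / variances + jnp.log(2 * jnp.pi * variances)
    return -0.5 * jnp.sum(terms, axis=-1)  # reduce the event axis only


mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])
x = jnp.zeros((4, 3, 2))  # batch shape (4, 3), dim = 2

log_p = diagonal_normal_log_density(x, mean, scales)
density = jnp.exp(log_p)  # assumed relationship between `density` and `log_density`

# Cross-check against independent 1D normals summed over the event axis.
reference = norm.logpdf(x, mean, scales).sum(axis=-1)
```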
- sample(batch_shape=(), rng=None)
Generate samples from the distribution.
- Parameters:
batch_shape – Shape of batch dimensions for vectorized sampling.
rng – Random key for sampling, or None to use internal rngs.
- Returns:
Tuple of (samples, log_densities), where samples have shape (*batch_shape, dim) and log_densities have shape batch_shape.
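A common way to sample a diagonal Gaussian is the reparameterization x = μ + σ ⊙ ε with ε ~ N(0, I). The sketch below follows that approach and mirrors the documented return shapes; whether bijx's sample works exactly this way is an assumption, and sample_diagonal_normal is a hypothetical helper. Since (x_i − μ_i)/σ_i = ε_i, the log density reuses the noise directly instead of recomputing the residuals.

```python
import jax
import jax.numpy as jnp


def sample_diagonal_normal(rng, mean, scales, batch_shape=()):
    """Hypothetical sketch of reparameterized sampling (not bijx's implementation).

    Returns (samples, log_densities) with shapes (*batch_shape, dim) and
    batch_shape, matching the documented return value of `sample`.
    """
    batch_shape = tuple(batch_shape)
    eps = jax.random.normal(rng, batch_shape + mean.shape)  # standard normal noise
    samples = mean + scales * eps                            # shift and scale each dimension
    # (x - mu)^2 / sigma^2 equals eps^2, so the quadratic term uses eps directly.
    log_p = -0.5 * jnp.sum(eps ** 2 + jnp.log(2 * jnp.pi * scales ** 2), axis=-1)
    return samples, log_p


mean = jnp.array([1.0, 2.0])
scales = jnp.array([0.5, 1.5])
samples, log_p = sample_diagonal_normal(jax.random.key(0), mean, scales, batch_shape=(100,))
```

Returning the log density alongside the samples avoids a second pass over the data, which is convenient when the samples feed directly into a normalizing-flow loss.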