bijx.GaussianMixture
- class bijx.GaussianMixture
Bases: Distribution
Gaussian mixture model.
This is a convenience wrapper around MixtureStack of either diagonal or general multivariate normal distributions.
- __init__(means, covariances=None, weights=None, *, rngs=None, var_cls=flax.nnx.variablelib.Param, epsilon=1e-10)
- Parameters:
means (Variable | Array | ndarray | Sequence[int | Any])
covariances (Variable | Array | ndarray | Sequence[int | Any] | None)
weights (Variable | Array | ndarray | Sequence[int | Any] | None)
epsilon (float)
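For reference, a Gaussian mixture with K components has density p(x) = sum_k w_k N(x; mu_k, Sigma_k). The NumPy sketch below is not bijx code. It only illustrates what the constructor arguments describe, and it assumes some layouts that this page does not document:

- means have shape (K, d);
- a diagonal covariance is given as per-component variances of shape (K, d);
- a general covariance has shape (K, d, d);
- weights have shape (K,) and are rescaled here to sum to 1.

It checks that the diagonal and general parameterizations give the same density when the full covariance matrices happen to be diagonal.

```python
import numpy as np


def mvn_pdf(x, mean, cov):
    """Density of N(mean, cov) at a single point x of shape (d,); cov is (d, d)."""
    d = mean.shape[-1]
    diff = x - mean
    # Solve cov @ z = diff rather than forming cov^{-1} explicitly.
    maha = diff @ np.linalg.solve(cov, diff)
    norm = np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return np.exp(-0.5 * maha) / norm


def diag_mvn_pdf(x, mean, var):
    """Density of N(mean, diag(var)) at x; var holds the d per-dimension variances."""
    return np.prod(np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var))


def mixture_pdf(x, means, covs, weights, component_pdf):
    """p(x) = sum_k w_k * component_pdf(x, mu_k, cov_k)."""
    w = weights / weights.sum()  # assumption: weights are rescaled to sum to 1
    return sum(w_k * component_pdf(x, m_k, c_k) for w_k, m_k, c_k in zip(w, means, covs))


# Hypothetical two-component mixture in d = 2.
means = np.array([[0.0, 0.0], [3.0, -1.0]])            # (K, d)
variances = np.array([[1.0, 0.5], [2.0, 1.0]])         # diagonal case: (K, d)
full_covs = np.stack([np.diag(v) for v in variances])  # general case: (K, d, d)
weights = np.array([0.3, 0.7])                         # (K,)

x = np.array([1.0, 0.0])
p_diag = mixture_pdf(x, means, variances, weights, diag_mvn_pdf)
p_full = mixture_pdf(x, means, full_covs, weights, mvn_pdf)
print(p_diag, p_full)  # the two parameterizations agree for diagonal covariances
```

The diagonal form needs only d numbers per component instead of d*d, and it avoids a linear solve and a determinant. That cost difference is the usual reason to offer both variants.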
Methods
- density(x, **kwargs): Evaluate probability density at given points.
- get_batch_shape(x): Extract batch dimensions from a sample.
- log_density(x): Evaluate log probability density at given points.
- sample([batch_shape, rng]): Generate samples from the distribution.
Attributes
- covs
- means
- weights
- get_batch_shape(x)
Extract batch dimensions from a sample.
- Parameters:
x – A sample from this distribution.
- Returns:
Tuple representing the batch dimensions of the sample.
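get_batch_shape strips the event dimensions from a sample and keeps the leading batch dimensions. The sketch below assumes the event dimensions are the trailing axes, consistent with the (*batch_shape, *event_shape) layout documented for sample. The explicit event_shape argument is hypothetical: the bijx method takes only x and presumably reads the event shape from the distribution itself.

```python
import numpy as np


def get_batch_shape(x, event_shape):
    """Hypothetical stand-in for Distribution.get_batch_shape.

    Assumes x has layout (*batch_shape, *event_shape) and returns batch_shape.
    """
    event_shape = tuple(event_shape)
    n_batch = x.ndim - len(event_shape)
    # The trailing axes must match the event shape exactly.
    if n_batch < 0 or x.shape[n_batch:] != event_shape:
        raise ValueError(f"shape {x.shape} does not end with event shape {event_shape}")
    return x.shape[:n_batch]


# Assuming a mixture over d = 2 dimensional vectors has event shape (2,).
print(get_batch_shape(np.zeros((4, 5, 2)), (2,)))  # (4, 5)
print(get_batch_shape(np.zeros(2), (2,)))          # ()
```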
- property covs
- property means
- property weights
- log_density(x)
Evaluate log probability density at given points.
- Parameters:
x – Points at which to evaluate density, with event dimensions matching the distribution’s event shape.
- Returns:
Log density values with batch dimensions matching input.
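For a mixture, the log density combines the component terms as log sum_k exp(log w_k + log N(x; mu_k, Sigma_k)). Evaluating this with the log-sum-exp trick avoids the underflow that exponentiating each term would cause at points far from every component.

The NumPy sketch below covers only the diagonal-covariance case. It assumes normalized weights and inputs laid out as (*batch_shape, d). It illustrates the computation and is not taken from bijx.

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Stable log(sum(exp(a))) along `axis`: subtract the max before exponentiating."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def diag_gmm_log_density(x, means, variances, weights):
    """log p(x) for a diagonal-covariance Gaussian mixture.

    x: (*batch_shape, d); means, variances: (K, d); weights: (K,), summing to 1.
    Returns an array of shape batch_shape.
    """
    diff = x[..., None, :] - means  # broadcast against components: (*batch_shape, K, d)
    # Per-component log N(x; mu_k, diag(var_k)), summed over the d dimensions.
    log_comp = -0.5 * np.sum(diff**2 / variances + np.log(2 * np.pi * variances), axis=-1)
    return logsumexp(np.log(weights) + log_comp, axis=-1)  # (*batch_shape,)


means = np.array([[0.0, 0.0], [3.0, -1.0]])
variances = np.array([[1.0, 0.5], [2.0, 1.0]])
weights = np.array([0.3, 0.7])

x = np.zeros((3, 4, 2))  # batch_shape (3, 4), event shape (2,)
print(diag_gmm_log_density(x, means, variances, weights).shape)  # (3, 4)

# Far from both components the naive sum underflows to 0, but the log density stays finite.
far = np.array([100.0, 100.0])
print(np.isfinite(diag_gmm_log_density(far, means, variances, weights)))  # True
```

The corresponding density is simply the exponential of this value, which is why libraries typically implement the log form and derive density from it.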
- sample(batch_shape=(), rng=None)
Generate samples from the distribution.
- Parameters:
batch_shape – Shape of batch dimensions for vectorized sampling.
rng – Random key for sampling, or None to use internal rngs.
- Returns:
Tuple of (samples, log_densities), where samples have shape (*batch_shape, *event_shape) and log_densities have shape batch_shape.
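Sampling from a mixture is commonly done ancestrally. First draw a component index k from the categorical distribution given by the weights, then draw x from that component's normal distribution. The sketch below does this for the diagonal case. Matching the documented return convention, it also returns the log density of each sample.

It uses a NumPy Generator where bijx takes a JAX random key (or falls back to its internal rngs), so the rng handling is only a stand-in. It also assumes means and variances of shape (K, d) and normalized weights of shape (K,). It is not the bijx implementation.

```python
import numpy as np


def diag_gmm_sample(rng, batch_shape, means, variances, weights):
    """Ancestral sampling from a diagonal-covariance Gaussian mixture.

    Returns (samples, log_densities) with shapes (*batch_shape, d) and batch_shape.
    """
    batch_shape = tuple(batch_shape)
    # Step 1: component index k ~ Categorical(weights), one per batch element.
    k = rng.choice(len(weights), size=batch_shape, p=weights)
    # Step 2: x = mu_k + sqrt(var_k) * eps with eps ~ N(0, I).
    eps = rng.standard_normal(batch_shape + means.shape[-1:])
    samples = means[k] + np.sqrt(variances[k]) * eps

    # Log density of each sample under the full mixture (not just component k),
    # computed with a max-shifted log-sum-exp over components.
    diff = samples[..., None, :] - means
    log_comp = -0.5 * np.sum(diff**2 / variances + np.log(2 * np.pi * variances), axis=-1)
    a = np.log(weights) + log_comp
    m = a.max(axis=-1, keepdims=True)
    log_densities = m[..., 0] + np.log(np.exp(a - m).sum(axis=-1))
    return samples, log_densities


means = np.array([[0.0, 0.0], [3.0, -1.0]])
variances = np.array([[1.0, 0.5], [2.0, 1.0]])
weights = np.array([0.3, 0.7])
rng = np.random.default_rng(0)

samples, log_densities = diag_gmm_sample(rng, (5, 3), means, variances, weights)
print(samples.shape, log_densities.shape)  # (5, 3, 2) (5, 3)
```

The sketch scores each sample under the full mixture rather than the component it came from, because p(x) marginalizes over the unobserved component index.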