bijx.GaussianMixture

class bijx.GaussianMixture

Bases: Distribution

Gaussian mixture model.

This is a convenience wrapper around a MixtureStack whose components are either diagonal or general (full-covariance) multivariate normal distributions.
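The difference between the two component types can be illustrated with a small NumPy sketch. It is not bijx code: the helper names `diag_normal_logpdf` and `full_normal_logpdf` are hypothetical, and how GaussianMixture itself chooses between the two forms is not stated on this page. The sketch only shows that a diagonal component is the special case of a general one whose covariance matrix has zeros off the diagonal.

```python
import numpy as np

def diag_normal_logpdf(x, mean, var):
    """Log density of N(mean, diag(var)) at x; x, mean, var have shape (d,)."""
    d = x.shape[-1]
    return -0.5 * (np.sum((x - mean) ** 2 / var)
                   + np.sum(np.log(var))
                   + d * np.log(2 * np.pi))

def full_normal_logpdf(x, mean, cov):
    """Log density of N(mean, cov) at x; cov has shape (d, d)."""
    d = x.shape[-1]
    chol = np.linalg.cholesky(cov)                 # cov = chol @ chol.T
    z = np.linalg.solve(chol, x - mean)            # whitened residual
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))  # log |cov|
    return -0.5 * (z @ z + log_det + d * np.log(2 * np.pi))

x = np.array([0.3, -1.2])
mean = np.array([0.0, 1.0])
var = np.array([0.5, 2.0])

lp_diag = diag_normal_logpdf(x, mean, var)
lp_full = full_normal_logpdf(x, mean, np.diag(var))  # same distribution, full form
```

The diagonal form needs only d variances per component and avoids the Cholesky factorization, which is why a mixture with many components often uses it when correlations between dimensions are not needed.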

__init__(means, covariances=None, weights=None, *, rngs=None, var_cls=<class 'flax.nnx.variablelib.Param'>, epsilon=1e-10)
Parameters:
  • means (Variable | Array | ndarray | Sequence[int | Any])

  • covariances (Variable | Array | ndarray | Sequence[int | Any] | None)

  • weights (Variable | Array | ndarray | Sequence[int | Any] | None)

  • epsilon (float)

Methods

density(x, **kwargs)

Evaluate probability density at given points.

get_batch_shape(x)

Extract batch dimensions from a sample.

log_density(x)

Evaluate log probability density at given points.

sample([batch_shape, rng])

Generate samples from the distribution.

Attributes

covs

means

weights

get_batch_shape(x)

Extract batch dimensions from a sample.

Parameters:

x – A sample from this distribution.

Returns:

Tuple representing the batch dimensions of the sample.
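As a rough illustration of what "batch dimensions" means here, the sketch below strips the trailing event dimensions from an array's shape. The helper `batch_shape_of` and its explicit `event_shape` argument are hypothetical; the method documented above takes only `x` and presumably uses the distribution's own event shape.

```python
import numpy as np

def batch_shape_of(x, event_shape):
    """Return the leading dimensions of x that precede the event dimensions."""
    n_event = len(event_shape)
    if n_event and x.shape[-n_event:] != tuple(event_shape):
        raise ValueError(f"trailing dims {x.shape[-n_event:]} != event shape {event_shape}")
    return x.shape[: x.ndim - n_event]

x = np.zeros((4, 5, 3))                      # 4 x 5 batch of 3-dimensional events
batch = batch_shape_of(x, event_shape=(3,))  # (4, 5)
single = batch_shape_of(np.zeros(3), (3,))   # () for one unbatched sample
```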

property covs

property means

property weights

log_density(x)

Evaluate log probability density at given points.

Parameters:
  • x – Points at which to evaluate density, with event dimensions matching the distribution’s event shape.

  • **kwargs – Additional distribution-specific evaluation arguments.

Returns:

Log density values with batch dimensions matching input.
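For reference, the log density of a mixture with weights w_k and component densities N(x; mu_k, Sigma_k) is log sum_k w_k N(x; mu_k, Sigma_k), usually evaluated with the log-sum-exp trick for numerical stability. The NumPy sketch below shows this for diagonal components. It is an illustration under stated assumptions, not the bijx implementation: the function name `mixture_log_density` is hypothetical, and the weights are assumed to be already normalized probabilities.

```python
import numpy as np

def mixture_log_density(x, means, variances, weights):
    """Log density of a diagonal Gaussian mixture.

    x: (..., d); means, variances: (k, d); weights: (k,), summing to 1.
    Returns an array whose shape is the batch shape of x.
    """
    d = x.shape[-1]
    diff = x[..., None, :] - means                         # (..., k, d)
    comp = -0.5 * (np.sum(diff ** 2 / variances, axis=-1)  # per-component log density
                   + np.sum(np.log(variances), axis=-1)
                   + d * np.log(2 * np.pi))                # (..., k)
    a = comp + np.log(weights)                             # add log mixture weights
    a_max = np.max(a, axis=-1, keepdims=True)              # log-sum-exp over components
    return a_max[..., 0] + np.log(np.sum(np.exp(a - a_max), axis=-1))

means = np.array([[-2.0, 0.0], [2.0, 0.0]])
variances = np.ones((2, 2))
weights = np.array([0.3, 0.7])

x = np.zeros((4, 5, 2))                                    # batch shape (4, 5), event shape (2,)
log_p = mixture_log_density(x, means, variances, weights)  # shape (4, 5)
```

Subtracting the per-point maximum before exponentiating keeps the sum finite even when every component assigns a very small density to a point.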

sample(batch_shape=(), rng=None)

Generate samples from the distribution.

Parameters:
  • batch_shape – Shape of batch dimensions for vectorized sampling.

  • rng – Random key for sampling, or None to use internal rngs.

  • **kwargs – Additional distribution-specific sampling arguments.

Returns:

Tuple of (samples, log_densities) where samples have shape (*batch_shape, *event_shape) and log_densities have shape batch_shape.
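The return convention can be sketched with ancestral sampling: pick a component index per batch element, draw from that component, then evaluate the mixture log density at the drawn points. The NumPy code below is a stand-in under stated assumptions, not the bijx implementation: `sample_mixture` is a hypothetical name, a NumPy `Generator` replaces the JAX random key or internal rngs, components are diagonal, and the weights are assumed to be normalized.

```python
import numpy as np

def sample_mixture(rng, means, variances, weights, batch_shape=()):
    """Sample a diagonal Gaussian mixture and return (samples, log_densities).

    means, variances: (k, d); weights: (k,), non-negative and summing to 1.
    samples has shape (*batch_shape, d); log_densities has shape batch_shape.
    """
    k, d = means.shape
    # Step 1: choose a component for every batch element.
    idx = rng.choice(k, size=batch_shape, p=weights)
    # Step 2: draw from the chosen diagonal Gaussian.
    noise = rng.standard_normal(size=(*batch_shape, d))
    samples = means[idx] + np.sqrt(variances[idx]) * noise
    # Step 3: evaluate the mixture log density at the samples (log-sum-exp).
    diff = samples[..., None, :] - means
    comp = -0.5 * (np.sum(diff ** 2 / variances, axis=-1)
                   + np.sum(np.log(variances), axis=-1)
                   + d * np.log(2 * np.pi))
    a = comp + np.log(weights)
    a_max = np.max(a, axis=-1, keepdims=True)
    log_p = a_max[..., 0] + np.log(np.sum(np.exp(a - a_max), axis=-1))
    return samples, log_p

rng = np.random.default_rng(0)
means = np.array([[-2.0, 0.0], [2.0, 0.0]])
variances = np.ones((2, 2))
weights = np.array([0.3, 0.7])

samples, log_p = sample_mixture(rng, means, variances, weights, batch_shape=(3, 4))
```

Note that the returned log density is that of the full mixture, not of the component that happened to generate each sample.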