bijx.ArrayDistribution

class bijx.ArrayDistribution

Bases: Distribution

Base class for distributions over multi-dimensional arrays.

Extends Distribution to distributions whose support consists of arrays with a fixed event shape, and provides utilities for distinguishing event dimensions from batch dimensions and for related shape manipulation.

The event shape is the shape of a single sample. Samples are stored as arrays of shape batch_shape + event_shape, where the leading batch dimensions index independent samples and allow operations to be vectorized over them.
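Because the event dimensions always trail the batch dimensions, any batch shape can be collapsed into a single leading axis for a vectorized computation and restored afterwards. The NumPy sketch below illustrates that convention; the helpers `flatten_batch` and `unflatten_batch` are hypothetical and not part of bijx.

```python
import numpy as np


def flatten_batch(x, event_shape):
    """Collapse all leading (batch) axes of x into one axis.

    Hypothetical helper: assumes x.shape == batch_shape + event_shape.
    """
    event_shape = tuple(event_shape)
    n_event = len(event_shape)
    batch_shape = x.shape[: x.ndim - n_event]
    assert x.shape[x.ndim - n_event:] == event_shape, "trailing axes must match event_shape"
    return x.reshape((-1,) + event_shape), batch_shape


def unflatten_batch(x_flat, batch_shape):
    """Restore the original batch axes after a flat, vectorized computation."""
    return x_flat.reshape(tuple(batch_shape) + x_flat.shape[1:])


# A 4 x 25 grid of 8 x 8 samples: batch_shape=(4, 25), event_shape=(8, 8).
x = np.zeros((4, 25, 8, 8))
x_flat, batch_shape = flatten_batch(x, event_shape=(8, 8))
print(x_flat.shape)    # (100, 8, 8)
print(batch_shape)     # (4, 25)
restored = unflatten_batch(x_flat, batch_shape)
print(restored.shape)  # (4, 25, 8, 8)
```

A single unbatched sample (batch shape `()`) also round-trips: it becomes a batch of one and is reshaped back to the bare event shape.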

Parameters:
  • event_shape (tuple[int, ...]) – Shape of individual samples (event dimensions).

  • rngs (Rngs | None) – Optional random number generator state.

Example

>>> # 2D distribution (e.g., for images or lattice fields)
>>> dist = SomeArrayDistribution(event_shape=(32, 32))
>>> samples, log_p = dist.sample(batch_shape=(100,))  # 100 samples
>>> assert samples.shape == (100, 32, 32)  # batch + event
>>> assert log_p.shape == (100,)  # batch only
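`SomeArrayDistribution` in the example is a placeholder for a concrete subclass. The standalone NumPy class below, `IIDNormalSketch`, is hypothetical (it is neither a bijx class nor a subclass of `ArrayDistribution`) and only mimics the interface used above: `sample` returns samples of shape `batch_shape + event_shape` together with a log-density of shape `batch_shape`, obtained by summing per-element log-densities over the event axes. Passing a NumPy `Generator` as `rng` is an assumption made to keep the sketch self-contained.

```python
import math
import numpy as np


class IIDNormalSketch:
    """Hypothetical stand-in: independent standard normal over every event element."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    @property
    def event_axes(self):
        # Event axes are the trailing axes, written as negative indices.
        return tuple(range(-len(self.event_shape), 0))

    def log_density(self, x):
        # Per-element standard normal log-density, summed over the event axes,
        # leaving one value per batch element.
        elementwise = -0.5 * x**2 - 0.5 * math.log(2 * math.pi)
        return elementwise.sum(axis=self.event_axes)

    def sample(self, batch_shape=(), rng=None):
        # rng: a NumPy Generator here (assumption); the documented class
        # accepts an Rngs object instead.
        rng = np.random.default_rng() if rng is None else rng
        x = rng.standard_normal(tuple(batch_shape) + self.event_shape)
        return x, self.log_density(x)


dist = IIDNormalSketch(event_shape=(32, 32))
samples, log_p = dist.sample(batch_shape=(100,), rng=np.random.default_rng(0))
assert samples.shape == (100, 32, 32)  # batch + event
assert log_p.shape == (100,)           # batch only
```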
__init__(event_shape, rngs=None)
Parameters:
  • event_shape (tuple[int, ...])

  • rngs (Rngs | None)

Methods

density(x, **kwargs)

Evaluate probability density at given points.

get_batch_shape(x)

Extract batch dimensions from an array sample.

log_density(x, **kwargs)

Evaluate log probability density at given points.

sample([batch_shape, rng])

Generate samples from the distribution.
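For high-dimensional events the density itself is often not representable in floating point even when the log-density is an ordinary number, which is why log_density is usually the method to work with. The sketch below assumes, as is conventional, that density equals the exponential of log_density, and uses an i.i.d. standard normal over a (32, 32) event purely as an illustration.

```python
import math
import numpy as np

event_shape = (32, 32)
x = np.zeros(event_shape)  # the mode of a standard normal, its highest-density point

# Standard normal log-density summed over all 1024 event elements.
log_p = float(np.sum(-0.5 * x**2 - 0.5 * math.log(2 * math.pi)))
p = math.exp(log_p)  # assumed relation: density = exp(log_density)

print(log_p)  # about -941.0
print(p)      # 0.0: underflows float64, even at the mode
```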

Attributes

event_axes

Axis indices corresponding to event dimensions.

event_dim

Number of event dimensions.

event_size

Total number of elements in the event shape.

property event_dim

Number of event dimensions.

property event_size

Total number of elements in the event shape.

property event_axes

Axis indices corresponding to event dimensions.
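All three properties can be derived from event_shape. The sketch below is a guess at their semantics, not bijx's code; the class name `EventShapeInfo` is hypothetical, and representing event_axes as negative indices is an assumption, chosen because negative indices address the trailing event axes regardless of how many batch axes precede them.

```python
import math
import numpy as np


class EventShapeInfo:
    """Hypothetical helper mirroring event_dim, event_size and event_axes."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    @property
    def event_dim(self):
        return len(self.event_shape)

    @property
    def event_size(self):
        # math.prod(()) == 1, so a scalar event has size 1.
        return math.prod(self.event_shape)

    @property
    def event_axes(self):
        # Assumption: trailing axes expressed as negative indices.
        return tuple(range(-self.event_dim, 0))


info = EventShapeInfo((32, 32))
print(info.event_dim, info.event_size, info.event_axes)  # 2 1024 (-2, -1)

# Negative axes reduce over the event part for any number of batch axes.
x = np.ones((5, 3, 32, 32))
print(x.sum(axis=info.event_axes).shape)  # (5, 3)
```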

get_batch_shape(x)

Return the batch shape of x, i.e. the leading dimensions of x that precede the event dimensions.

Parameters:
  • x (Any)

Return type:
  tuple[int, ...]
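Following the batch + event layout in the class example, the batch shape is the shape of x with the trailing event dimensions removed. A minimal sketch as a free function with a hypothetical name, not bijx's implementation:

```python
import numpy as np


def get_batch_shape_sketch(x, event_shape):
    """Return the leading (batch) part of x.shape, assuming x.shape == batch + event."""
    event_dim = len(event_shape)
    # Slice with x.ndim - event_dim rather than -event_dim: for event_dim == 0,
    # x.shape[:-0] would be () instead of the full shape.
    return x.shape[: x.ndim - event_dim]


print(get_batch_shape_sketch(np.zeros((100, 32, 32)), (32, 32)))  # (100,)
print(get_batch_shape_sketch(np.zeros((32, 32)), (32, 32)))       # ()
print(get_batch_shape_sketch(np.zeros((4, 7)), ()))               # (4, 7)
```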