bijx.ArrayDistribution
- class bijx.ArrayDistribution
Bases: Distribution
Base class for distributions over multi-dimensional arrays.
Extends the base Distribution class to distributions whose support consists of arrays with a fixed event shape. Provides utilities for separating event dimensions from batch dimensions and for manipulating shapes.
The event shape is the shape of a single sample. Any leading batch dimensions index multiple samples, so operations can be vectorized over them.
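To make the batch/event split concrete, the sketch below separates an array's shape into leading batch dimensions and trailing event dimensions. It assumes that event dimensions are always the trailing axes (consistent with the example further down, where samples have shape batch + event). The helper name `split_batch_event` is hypothetical and is not part of bijx.

```python
import numpy as np


def split_batch_event(x, event_shape):
    """Split x.shape into (batch_shape, event_shape).

    Hypothetical helper: assumes the event dimensions are the trailing
    axes of x and every axis before them is a batch axis.
    """
    event_shape = tuple(event_shape)
    event_dim = len(event_shape)
    shape = np.shape(x)
    split = len(shape) - event_dim
    # The trailing axes must match the declared event shape exactly.
    if split < 0 or shape[split:] != event_shape:
        raise ValueError(f"trailing dims of {shape} do not match event_shape {event_shape}")
    return shape[:split], event_shape


# 100 samples of a 32x32 field: batch (100,), event (32, 32).
x = np.zeros((100, 32, 32))
batch_shape, ev_shape = split_batch_event(x, (32, 32))
```

Slicing by `len(shape) - event_dim` rather than `shape[-event_dim:]` keeps the helper correct for scalar events (`event_shape=()`), where every axis is a batch axis.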
- Parameters:
event_shape (tuple[int, ...]) – Shape of individual samples (event dimensions).
rngs (Rngs | None) – Optional random number generator state.
Example
>>> # 2D distribution (e.g., for images or lattice fields)
>>> dist = SomeArrayDistribution(event_shape=(32, 32))
>>> samples, log_p = dist.sample(batch_shape=(100,))  # 100 samples
>>> assert samples.shape == (100, 32, 32)  # batch + event
>>> assert log_p.shape == (100,)  # batch only
- __init__(event_shape, rngs=None)
- Parameters:
event_shape (tuple[int, ...])
rngs (Rngs | None)
Methods
density(x, **kwargs) – Evaluate probability density at given points.
Extract batch dimensions from an array sample.
log_density(x, **kwargs) – Evaluate log probability density at given points.
sample([batch_shape, rng]) – Generate samples from the distribution.
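The following is a minimal stand-alone sketch of a concrete array distribution with independent standard-normal entries, written to illustrate how `sample`, `log_density` and `density` relate to the event and batch shapes. It is not the bijx implementation: the class name `StandardNormalArray` is hypothetical, it does not subclass bijx, and it uses a NumPy `Generator` in place of the `Rngs` state.

```python
import numpy as np


class StandardNormalArray:
    """Hypothetical sketch: i.i.d. N(0, 1) entries over a fixed event shape.

    Mirrors the method names listed above, but uses NumPy's Generator
    instead of flax Rngs and is not the bijx implementation.
    """

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    def log_density(self, x):
        x = np.asarray(x)
        event_dim = len(self.event_shape)
        # Sum per-element log-densities over the trailing event axes only,
        # leaving one value per batch element.
        axes = tuple(range(x.ndim - event_dim, x.ndim))
        return np.sum(-0.5 * x**2 - 0.5 * np.log(2 * np.pi), axis=axes)

    def density(self, x):
        # The density is the exponential of the log-density.
        return np.exp(self.log_density(x))

    def sample(self, batch_shape=(), rng=None):
        rng = np.random.default_rng() if rng is None else rng
        # Sample shape is batch dimensions followed by event dimensions.
        x = rng.standard_normal(tuple(batch_shape) + self.event_shape)
        # Return samples together with their log-density, one per batch element.
        return x, self.log_density(x)


dist = StandardNormalArray(event_shape=(32, 32))
samples, log_p = dist.sample(batch_shape=(100,), rng=np.random.default_rng(0))
```

Reducing only over the event axes is what makes `log_p` carry the batch shape alone, matching the `(100,)` shape in the class example.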
Attributes
event_axes – Axis indices corresponding to event dimensions.
event_dim – Number of event dimensions.
event_size – Total number of elements in the event shape.
- property event_dim
Number of event dimensions.
- property event_size
Total number of elements in the event shape.
- property event_axes
Axis indices corresponding to event dimensions.
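All three properties can be derived from `event_shape`. The sketch below shows one plausible derivation; it assumes `event_axes` uses negative (trailing) axis indices, so the same tuple works whatever the number of leading batch dimensions. bijx may use a different convention, and the function name `event_properties` is hypothetical.

```python
import math

import numpy as np


def event_properties(event_shape):
    """Derive event_dim, event_size and event_axes from an event shape.

    Assumption: event_axes is expressed as negative indices counted from
    the end, so it is valid for any number of leading batch dimensions.
    """
    event_shape = tuple(event_shape)
    event_dim = len(event_shape)               # number of event dimensions
    event_size = math.prod(event_shape)        # total elements per sample
    event_axes = tuple(range(-event_dim, 0))   # trailing axes, e.g. (-2, -1)
    return event_dim, event_size, event_axes


event_dim, event_size, event_axes = event_properties((32, 32))
# (2, 1024, (-2, -1))

# Reducing over event_axes yields one value per sample, for any batch shape.
x = np.ones((100, 32, 32))
per_sample = x.sum(axis=event_axes)  # shape (100,)
```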