bijx.BufferedSampler
- class bijx.BufferedSampler
Bases: Distribution
Distribution wrapper that caches samples for efficient use in MCMC.
Maintains an internal buffer of pre-computed samples, so that single samples are drawn from the cache rather than generated one at a time by the underlying distribution.
The buffer is refilled automatically when exhausted. Only single-sample requests (batch_shape=()) use the buffer; batch requests are forwarded directly to the underlying distribution.
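The buffering behavior described above can be sketched in plain Python. This is an illustrative stand-in and not bijx's implementation: `ToyNormal` and `ToyBufferedSampler` are hypothetical names, and the standard library `random` module is assumed in place of JAX random keys.

```python
import math
import random


class ToyNormal:
    """Hypothetical stand-in for an expensive base distribution (standard normal)."""

    def __init__(self):
        self.calls = 0  # counts how often the base sampler is invoked

    def log_density(self, x):
        return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)

    def sample(self, batch_shape=(), rng=None):
        self.calls += 1
        if rng is None:
            rng = random.Random()
        if batch_shape == ():
            x = rng.gauss(0.0, 1.0)
            return x, self.log_density(x)
        # Only one batch dimension is supported in this sketch.
        xs = [rng.gauss(0.0, 1.0) for _ in range(batch_shape[0])]
        return xs, [self.log_density(x) for x in xs]


class ToyBufferedSampler:
    """Illustrative buffer: serve single samples from a pre-drawn batch."""

    def __init__(self, dist, buffer_size):
        self.dist = dist
        self.buffer_size = buffer_size
        self._samples, self._log_densities = [], []
        self._index = 0

    def sample(self, batch_shape=(), rng=None):
        if batch_shape != ():
            # Batch requests bypass the buffer entirely.
            return self.dist.sample(batch_shape, rng)
        if self._index >= len(self._samples):
            # Buffer empty or exhausted: draw buffer_size samples in one call.
            self._samples, self._log_densities = self.dist.sample(
                (self.buffer_size,), rng
            )
            self._index = 0
        i = self._index
        self._index += 1
        return self._samples[i], self._log_densities[i]


base = ToyNormal()
buffered = ToyBufferedSampler(base, buffer_size=4)
rng = random.Random(0)
draws = [buffered.sample(rng=rng) for _ in range(5)]
```

In this sketch, five single-sample requests with a buffer of four trigger two calls to the base distribution: one initial fill and one refill on the fifth request.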
Note
This class maintains internal state and should be used carefully in JAX transformations. The buffer state will be updated during sampling operations. This is compatible with flax.nnx’s state management.
Example
>>> expensive_dist = SomeExpensiveDistribution()
>>> buffered = bijx.BufferedSampler(expensive_dist, buffer_size=1000)
>>> sample1 = buffered.sample()  # Fills buffer if empty
>>> sample2 = buffered.sample()  # Uses cached sample
- __init__(dist, buffer_size)
Initialize buffered sampler.
- Parameters:
dist (Distribution) – Base distribution to sample from.
buffer_size (int) – Number of samples to cache in buffer.
Methods
density(x, **kwargs): Evaluate probability density at given points.
get_batch_shape(x): Extract batch dimensions from a sample.
log_density(x, **kwargs): Evaluate log density using the underlying distribution.
sample([batch_shape, rng]): Sample from the buffered distribution.
- sample(batch_shape=(), rng=None, **kwargs)
Sample from the buffered distribution.
For single samples (batch_shape=()), returns cached samples from the buffer, refilling when necessary. For batch requests, forwards directly to the underlying distribution.
- Parameters:
batch_shape (tuple[int, ...]) – Shape of batch dimensions. Only () uses buffering.
rng (RngKey | None) – Random key for sampling (used when refilling buffer).
**kwargs – Additional arguments passed to underlying distribution.
- Return type:
tuple[Any, Array]
- Returns:
Tuple of (sample, log_density) from buffer or underlying distribution.
Note
The buffer is refilled when exhausted, which updates internal state. This may cause issues with JAX transformations that expect pure functions.
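For contrast with the stateful behavior noted above, the following sketch shows the general pure-function pattern in which buffer state is passed in and returned explicitly. It is not how bijx or flax.nnx handle state; `BufferState`, `refill`, and `pure_sample` are hypothetical names, and an integer seed with the standard library `random` module is assumed in place of a JAX random key.

```python
import random
from typing import NamedTuple, Tuple


class BufferState(NamedTuple):
    """Hypothetical explicit buffer state: cached values plus a read position."""

    values: Tuple[float, ...]
    index: int


def refill(seed: int, buffer_size: int) -> BufferState:
    # Deterministic given the seed, so refilling is itself a pure function.
    rng = random.Random(seed)
    return BufferState(tuple(rng.gauss(0.0, 1.0) for _ in range(buffer_size)), 0)


def pure_sample(state: BufferState, seed: int, buffer_size: int):
    """Return (new_state, value) without mutating the input state."""
    if state.index >= len(state.values):
        # The seed is only consumed when the buffer is empty or exhausted.
        state = refill(seed, buffer_size)
    value = state.values[state.index]
    return BufferState(state.values, state.index + 1), value


empty = BufferState((), 0)
state1, a = pure_sample(empty, seed=0, buffer_size=3)
state2, b = pure_sample(state1, seed=0, buffer_size=3)
# Replaying the same input state reproduces the same output.
_, a_again = pure_sample(empty, seed=0, buffer_size=3)
```

Because the caller threads the state through each call, identical inputs always give identical outputs. In real use the caller would supply a fresh seed for each refill; reusing the same seed would repeat the buffer contents.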
- log_density(x, **kwargs)
Evaluate log density using the underlying distribution.
- Parameters:
x (Any) – Points at which to evaluate log density.
**kwargs – Additional arguments passed to underlying distribution.
- Return type:
Array
- Returns:
Log density values from the underlying distribution.