bijx.Transformed
- class bijx.Transformed
Bases: Distribution
Distribution obtained by applying a bijection to a base distribution.
Implements the pushforward distribution \(p_Y(y) = p_X(f^{-1}(y)) \left|\det J_{f^{-1}}(y)\right|\), where \(Y = f(X)\) and \(f\) is the bijection.
The transformed distribution supports both sampling (by transforming samples from the base distribution) and density evaluation (by inverse transforming and applying the change of variables formula).
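Because \(J_{f^{-1}}(y) = J_f(x)^{-1}\) at \(x = f^{-1}(y)\) (inverse function theorem), the same density can equivalently be written with the forward Jacobian, which is convenient when \(x\) is already at hand, as it is during sampling. In log form, as a standard restatement of the formula above rather than a claim about how bijx arranges the computation internally:

\[
\log p_Y(y) = \log p_X(x) + \log\left|\det J_{f^{-1}}(y)\right| = \log p_X(x) - \log\left|\det J_f(x)\right|, \qquad x = f^{-1}(y).
\]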
Important: log_density assumes that the bijection does not change the shape or pytree structure of its input, because get_batch_shape defaults to the prior's get_batch_shape to determine the batch shape. If the bijection does change the shape, subclass Transformed and provide a manual implementation of get_batch_shape.
Example
>>> import jax
>>> import bijx
>>> key = jax.random.key(0)  # random key for sampling
>>> prior = bijx.IndependentNormal(event_shape=(2,))
>>> bijection = bijx.Sigmoid()
>>> transformed = bijx.Transformed(prior, bijection)
>>> samples, log_density = transformed.sample((100,), rng=key)
- Parameters:
prior (Distribution) – Base distribution to transform.
bijection (Bijection) – Bijection to apply to samples from the prior.
- __init__(prior, bijection)
- Parameters:
prior (Distribution)
bijection (Bijection)
Methods
density(x, **kwargs) – Evaluate probability density at given points.
get_batch_shape(x) – Extract batch dimensions from a sample.
log_density(x, **kwargs) – Evaluate log density of the transformed distribution.
sample([batch_shape, rng]) – Sample from the transformed distribution.
- sample(batch_shape=(), rng=None, **kwargs)
Sample from the transformed distribution.
Generates samples by first sampling from the base distribution, then applying the forward bijection transformation.
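A minimal, framework-free sketch of this sampling path, assuming a one-dimensional standard normal base distribution and the logistic sigmoid as the bijection. Plain Python floats stand in for bijx's batched JAX arrays, and sample_transformed is a hypothetical helper, not part of bijx:

```python
import math
import random


def std_normal_log_density(x: float) -> float:
    # log N(x; 0, 1)
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def sigmoid(x: float) -> float:
    # Numerically stable logistic sigmoid (the assumed bijection f)
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_transformed(n: int, rng: random.Random) -> tuple[list[float], list[float]]:
    samples: list[float] = []
    log_densities: list[float] = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)  # 1. sample from the base distribution
        y = sigmoid(x)           # 2. push the sample through the bijection
        # log|f'(x)| for the sigmoid: log(sigmoid(x) * (1 - sigmoid(x))), computed stably
        log_abs_det = -abs(x) - 2.0 * math.log1p(math.exp(-abs(x)))
        samples.append(y)
        # 3. Jacobian correction: subtract the forward log-determinant
        log_densities.append(std_normal_log_density(x) - log_abs_det)
    return samples, log_densities
```

For example, samples, log_q = sample_transformed(100, random.Random(0)) returns 100 points in (0, 1) together with their log densities. The forward log-determinant is subtracted because the sampler holds \(x\), not \(y\).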
- Parameters:
batch_shape (tuple[int, ...]) – Shape of batch dimensions for samples.
rng (Array | None) – Random key for sampling.
**kwargs – Additional arguments passed to the bijection.
- Return type:
tuple[Any, Array]
- Returns:
Tuple of (samples, log_density), where samples are the transformed base samples and log_density includes the Jacobian correction.
- get_batch_shape(x)
Extract batch dimensions from a sample.
This method defaults to using the prior’s get_batch_shape method to determine the batch shape.
- Parameters:
x (Any) – A transformed sample.
- Return type:
tuple[int, ...]
- Returns:
Tuple representing the batch shape of the sample.
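The shape assumption behind this default can be illustrated with a hypothetical helper (not a bijx function) that, like a prior with a fixed event_shape, treats the trailing dimensions of a sample as the event and the remaining leading dimensions as the batch:

```python
def batch_shape_from_event(shape: tuple[int, ...], event_shape: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: strip the trailing event dimensions to get the batch shape."""
    n_event = len(event_shape)
    n_batch = len(shape) - n_event
    # The sample's trailing dimensions must match the event shape exactly
    if n_batch < 0 or tuple(shape[n_batch:]) != tuple(event_shape):
        raise ValueError(f"shape {shape} does not end with event_shape {event_shape}")
    return tuple(shape[:n_batch])
```

With event_shape=(2,), a transformed sample of shape (100, 2) yields batch shape (100,). If the bijection mapped each event from shape (2,) to (3,), the prior's event_shape would no longer describe the transformed sample, which is why such a bijection calls for an overridden get_batch_shape. Raising an error on a mismatch is this sketch's own choice.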
- log_density(x, **kwargs)
Evaluate log density of the transformed distribution.
Applies the change of variables formula by inverse transforming the input and computing the base distribution density with Jacobian correction.
- Parameters:
x (Any) – Points at which to evaluate log density.
**kwargs – Additional arguments passed to the bijection.
- Return type:
Array
- Returns:
Log density values at the input points.
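The inverse-direction computation can be sketched for the same assumed setup (one-dimensional standard normal prior, sigmoid bijection), in plain Python rather than JAX; transformed_log_density is illustrative only and not a bijx function:

```python
import math


def std_normal_log_density(x: float) -> float:
    # log N(x; 0, 1)
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def logit(y: float) -> float:
    # Inverse of the sigmoid bijection, defined for 0 < y < 1
    return math.log(y) - math.log1p(-y)


def transformed_log_density(y: float) -> float:
    x = logit(y)  # 1. inverse-transform the input
    # log|d logit / dy| = log(1 / (y * (1 - y))) = -log(y) - log(1 - y)
    log_det_inv = -math.log(y) - math.log1p(-y)
    # 2. base log density plus the inverse-Jacobian term
    return std_normal_log_density(x) + log_det_inv
```

For example, at y = 0.5 the inverse gives x = 0 and the Jacobian factor is 4, so the density is \(4/\sqrt{2\pi} \approx 1.596\).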