bijx.Transformed

class bijx.Transformed

Bases: Distribution

Distribution obtained by applying a bijection to a base distribution.

Implements the pushforward distribution \(p_Y(y) = p_X(f^{-1}(y)) \left|\det J_{f^{-1}}(y)\right|\), where \(Y = f(X)\) and \(f\) is the bijection.
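As a worked instance of the formula, take \(X \sim \mathcal{N}(0, 1)\) and \(f\) the sigmoid, as in the Example. Then \(f^{-1}(y) = \log\frac{y}{1-y}\) and \(\left|\det J_{f^{-1}}(y)\right| = \frac{1}{y(1-y)}\). The sketch below is plain scalar Python, not bijx code; the helper names are hypothetical. It evaluates this pushforward density and checks numerically that it integrates to 1 over \((0, 1)\), which is what the Jacobian factor guarantees.

```python
import math


def std_normal_log_pdf(x: float) -> float:
    """Log density of the standard normal N(0, 1)."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def logit(y: float) -> float:
    """Inverse of the sigmoid: f^{-1}(y) = log(y / (1 - y))."""
    return math.log(y) - math.log1p(-y)


def pushforward_log_pdf(y: float) -> float:
    """log p_Y(y) for Y = sigmoid(X), X ~ N(0, 1).

    log|det J_{f^{-1}}(y)| = log|d logit(y) / dy| = -log(y) - log(1 - y).
    """
    return std_normal_log_pdf(logit(y)) - math.log(y) - math.log1p(-y)


# Midpoint-rule integral of p_Y over (0, 1); it should be very close to 1.
n = 100_000
total = sum(math.exp(pushforward_log_pdf((i + 0.5) / n)) for i in range(n)) / n
print(total)
```

Dropping the Jacobian term would leave an unnormalized function, so this check is a quick way to catch sign or direction mistakes in a log-determinant.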

The transformed distribution supports both sampling (by transforming samples drawn from the base distribution) and density evaluation (by mapping points back through the inverse bijection and applying the change-of-variables formula).
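The two code paths must agree: the log density attached to a sample during sampling should equal the log density evaluated afterwards at that sample. Below is a minimal scalar sketch of both paths for a standard normal base and a sigmoid bijection. The functions `sample`, `log_density`, `sigmoid_forward` and `sigmoid_inverse` are hypothetical stand-ins written for illustration. They are not bijx's API and assume each bijection direction returns its log-absolute-derivative explicitly.

```python
import math
import random


def normal_log_pdf(x: float) -> float:
    """Log density of the standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def sigmoid_forward(x: float) -> tuple[float, float]:
    """Return y = f(x) and log|f'(x)| for the sigmoid, where f'(x) = y(1 - y)."""
    y = 1.0 / (1.0 + math.exp(-x))
    return y, math.log(y) + math.log1p(-y)


def sigmoid_inverse(y: float) -> tuple[float, float]:
    """Return x = f^{-1}(y) and log|(f^{-1})'(y)| = -log(y) - log(1 - y)."""
    return math.log(y) - math.log1p(-y), -(math.log(y) + math.log1p(-y))


def sample(rng: random.Random) -> tuple[float, float]:
    """Sampling path: draw from the base, push forward, subtract log|f'(x)|."""
    x = rng.gauss(0.0, 1.0)
    y, log_det_fwd = sigmoid_forward(x)
    return y, normal_log_pdf(x) - log_det_fwd


def log_density(y: float) -> float:
    """Density path: pull back through f^{-1}, add log|(f^{-1})'(y)|."""
    x, log_det_inv = sigmoid_inverse(y)
    return normal_log_pdf(x) + log_det_inv


rng = random.Random(0)
pairs = [sample(rng) for _ in range(1000)]
# Any difference between the two paths comes only from floating-point rounding.
max_err = max(abs(lp - log_density(y)) for y, lp in pairs)
print(max_err)
```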

Example

>>> import jax
>>> import bijx
>>> key = jax.random.key(0)
>>> prior = bijx.IndependentNormal(event_shape=(2,))
>>> bijection = bijx.Sigmoid()
>>> transformed = bijx.Transformed(prior, bijection)
>>> samples, log_density = transformed.sample((100,), rng=key)

Parameters:
  • prior (Distribution) – Base distribution to transform.

  • bijection (Bijection) – Bijection to apply to samples from the prior.

__init__(prior, bijection)

Methods

density(x, **kwargs)

Evaluate probability density at given points.

get_batch_shape(x)

Extract batch dimensions from a sample.

log_density(x, **kwargs)

Evaluate log density of the transformed distribution.

sample([batch_shape, rng])

Sample from the transformed distribution.

sample(batch_shape=(), rng=None, **kwargs)

Sample from the transformed distribution.

Generates samples by first sampling from the base distribution, then applying the forward bijection transformation.

Parameters:
  • batch_shape (tuple[int, ...]) – Shape of batch dimensions for samples.

  • rng (Array | None) – Random key for sampling.

  • **kwargs – Additional keyword arguments passed to the bijection.

Return type:

tuple[Any, Array]

Returns:

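Tuple of (samples, log_density), where samples are the base samples mapped through the bijection and log_density is their log density under the transformed distribution, i.e. the base log density corrected by the log-absolute-determinant of the forward Jacobian.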
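For an affine bijection the corrected log density can be checked against a closed form: the pushforward of \(\mathcal{N}(0, 1)\) under \(y = a x + b\) is \(\mathcal{N}(b, a^2)\), and \(\log\left|\det J_f\right| = \log|a|\). The sketch below mimics the sampling path in plain Python with the hypothetical constants `A` and `B` and a hypothetical helper `sample_transformed`. It is not bijx code. It confirms that the returned log densities match the closed form.

```python
import math
import random

A, B = 2.0, 1.0  # hypothetical affine bijection y = A * x + B


def normal_log_pdf(x: float, mean: float = 0.0, std: float = 1.0) -> float:
    """Log density of N(mean, std**2)."""
    z = (x - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def sample_transformed(rng: random.Random, n: int) -> tuple[list[float], list[float]]:
    """Base samples, forward map, and Jacobian-corrected log densities."""
    samples, log_densities = [], []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        y = A * x + B
        log_det_fwd = math.log(abs(A))  # log|f'(x)|, constant for an affine map
        samples.append(y)
        log_densities.append(normal_log_pdf(x) - log_det_fwd)
    return samples, log_densities


samples, log_density = sample_transformed(random.Random(0), 5)
# Compare each returned log density with the closed-form N(B, A**2) value.
for y, lp in zip(samples, log_density):
    print(f"{lp:.12f}  {normal_log_pdf(y, B, abs(A)):.12f}")
```

Returning the log density together with the samples avoids a second pass through the inverse bijection when both quantities are needed.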

log_density(x, **kwargs)

Evaluate log density of the transformed distribution.

Applies the change-of-variables formula: the input is mapped back through the inverse bijection, the base distribution's log density is evaluated at the resulting point, and the log-absolute-determinant of the inverse Jacobian is added.

Parameters:
  • x (Any) – Points at which to evaluate the log density.

  • **kwargs – Additional keyword arguments passed to the bijection.

Return type:

Array

Returns:

Log density values at the input points.
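The inverse-transform path can be illustrated with \(f(x) = e^x\) and a standard normal base, whose pushforward is the standard log-normal distribution. Here \(f^{-1}(y) = \log y\) and \(\log\left|\det J_{f^{-1}}(y)\right| = -\log y\). This is a minimal scalar sketch with hypothetical helper names, not bijx's implementation. It evaluates the log density via the inverse and compares it with the closed-form log-normal density.

```python
import math


def normal_log_pdf(x: float) -> float:
    """Log density of the standard normal base distribution."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def exp_inverse(y: float) -> tuple[float, float]:
    """For f(x) = exp(x): return f^{-1}(y) = log(y) and log|(f^{-1})'(y)| = -log(y)."""
    x = math.log(y)
    return x, -x


def transformed_log_density(y: float) -> float:
    """Pull y back to base space and add the inverse log-determinant."""
    x, log_det_inv = exp_inverse(y)
    return normal_log_pdf(x) + log_det_inv


def lognormal_log_pdf(y: float) -> float:
    """Closed-form standard log-normal log density, for comparison."""
    return -math.log(y) - 0.5 * math.log(y) ** 2 - 0.5 * math.log(2.0 * math.pi)


points = [0.1, 0.5, 1.0, 2.0, 10.0]
for y in points:
    print(y, transformed_log_density(y), lognormal_log_pdf(y))
```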