bijx.ScanChain

class bijx.ScanChain

Bases: Bijection

Compilation-efficient chain of identical bijections using JAX scan.

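This bijection applies the same bijection architecture several times in sequence, with different parameters for each application. It does not unroll a Python loop over the layers. Instead, it runs jax.lax.scan over the stack of bijections, so the layer body is traced only once, which keeps JAX compilation cheap as the number of layers grows.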
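The scan-over-a-stack pattern can be sketched in plain JAX. This is not bijx's implementation. The elementwise affine layer, the tuple of stacked parameters, and the names affine_forward and scan_chain_forward are hypothetical. The sign convention, where forward subtracts log|det J| from log_density, is also an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise affine layer: y = exp(log_scale) * x + shift.
# Its log|det J| is the sum of log_scale over the feature axis.
def affine_forward(params, x, log_density):
    log_scale, shift = params
    y = jnp.exp(log_scale) * x + shift
    # Assumed convention: log_density tracks the density of the samples,
    # so the forward pass subtracts log|det J|.
    return y, log_density - jnp.sum(log_scale)

def scan_chain_forward(stacked_params, x, log_density):
    # Every leaf of stacked_params has a leading "scan batch" axis of size
    # n_layers; lax.scan hands one layer's slice to `step` per iteration.
    def step(carry, layer_params):
        x, ld = carry
        return affine_forward(layer_params, x, ld), None

    (y, ld), _ = jax.lax.scan(step, (x, log_density), stacked_params)
    return y, ld

n_layers, dim = 4, 3
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
stacked = (
    0.1 * jax.random.normal(k1, (n_layers, dim)),  # log_scale, one row per layer
    jax.random.normal(k2, (n_layers, dim)),        # shift, one row per layer
)
x = jnp.ones(dim)
y, ld = scan_chain_forward(stacked, x, jnp.zeros(()))
```

Because lax.scan traces `step` once, the traced program does not grow with n_layers. An unrolled Python loop over the layers would lose this advantage.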

Parameters:

stack (Bijection) – Stack of bijections to scan over. This is a single bijection whose internal parameters all carry a leading “scan batch” dimension.

Note

The stack should contain parameters for multiple instances of the same bijection architecture. The forward pass scans over these instances in order, and the reverse pass scans over them in reverse order. The scan index is the leading dimension of all internal parameters.
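The reverse-order rule can be illustrated with a small plain-JAX sketch. It is not bijx's code. The affine layer and the helper run_chain are hypothetical, and the log-density sign convention is assumed. jax.lax.scan's reverse=True flag visits the stacked layers from last to first, which is what inverting a composition requires.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise affine layer (illustration only, not bijx's API).
def affine_forward(params, x, ld):
    log_scale, shift = params
    return jnp.exp(log_scale) * x + shift, ld - jnp.sum(log_scale)

def affine_reverse(params, y, ld):
    log_scale, shift = params
    # Exact inverse; the log-density change has the opposite sign.
    return (y - shift) * jnp.exp(-log_scale), ld + jnp.sum(log_scale)

def run_chain(layer_fn, stacked, z, ld, reverse=False):
    # Scan over the leading "scan batch" axis of every parameter leaf.
    def step(carry, params):
        return layer_fn(params, *carry), None
    (z, ld), _ = jax.lax.scan(step, (z, ld), stacked, reverse=reverse)
    return z, ld

n_layers, dim = 5, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
stacked = (0.3 * jax.random.normal(k1, (n_layers, dim)),  # log_scale
           jax.random.normal(k2, (n_layers, dim)))        # shift
x = jax.random.normal(k3, (dim,))
ld0 = jnp.array(-1.5, dtype=jnp.float32)

y, ld = run_chain(affine_forward, stacked, x, ld0)
# Reverse pass: reverse=True visits the layers last-to-first.
x_back, ld_back = run_chain(affine_reverse, stacked, y, ld, reverse=True)
# Undoing the layers first-to-last does not recover x in general,
# because affine maps with different shifts do not commute.
x_wrong, _ = run_chain(affine_reverse, stacked, y, ld, reverse=False)
```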

__init__(stack)
Parameters:

stack (Bijection)

Methods

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(y, log_density, **kwargs)

Apply reverse (inverse) transformation.

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms input through the bijection and updates log-density according to the change of variables formula.

For convenience, Bijection() gives the default identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
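The change-of-variables bookkeeping can be checked numerically. For a chain, the per-layer log-determinants accumulated by the scan should sum to log|det| of the Jacobian of the whole composed map. This sketch is not taken from bijx. It assumes hypothetical dense layers y = W @ x + b stacked along a leading axis, and it assumes that forward subtracts log|det J| from log_density.

```python
import jax
import jax.numpy as jnp

# Hypothetical invertible linear layer y = W @ x + b (not bijx's API).
def linear_forward(params, x, ld):
    W, b = params
    _, logabsdet = jnp.linalg.slogdet(W)
    # Assumed sign convention: forward subtracts log|det J|.
    return W @ x + b, ld - logabsdet

def chain_forward(stacked, x, ld):
    def step(carry, params):
        return linear_forward(params, *carry), None
    (y, ld), _ = jax.lax.scan(step, (x, ld), stacked)
    return y, ld

n_layers, dim = 3, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(2), 3)
# Near-identity matrices, so every layer is invertible.
W = jnp.eye(dim) + 0.1 * jax.random.normal(k1, (n_layers, dim, dim))
b = jax.random.normal(k2, (n_layers, dim))
x = jax.random.normal(k3, (dim,))

y, ld = chain_forward((W, b), x, jnp.zeros(()))

# Reference: log|det| of the Jacobian of the full composed map.
jac = jax.jacfwd(lambda v: chain_forward((W, b), v, jnp.zeros(()))[0])(x)
_, ref_logabsdet = jnp.linalg.slogdet(jac)
```

Summing per-layer terms inside the scan avoids forming the full Jacobian. The jacfwd reference is only practical as a test on small dimensions.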

reverse(y, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms input through the inverse bijection and updates log-density accordingly.

Parameters:
  • y – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
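For a chain of K stacked layers, the opposite-sign relation can be written out. The formulas below assume that log_density tracks the density of the current samples, so that forward subtracts the log-determinant. This convention is an assumption and is not stated above. Write x_0 = x, x_k = f_k(x_{k-1}), and y = x_K, with forward applying f_1 through f_K and reverse undoing them from f_K back to f_1:

```latex
\log p_Y(y) = \log p_X(x) - \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k}{\partial x_{k-1}}(x_{k-1})\right|,
\qquad
\log p_X(x) = \log p_Y(y) + \sum_{k=1}^{K} \log\left|\det \frac{\partial f_k}{\partial x_{k-1}}(x_{k-1})\right|
```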