bijx.MetaLayer

class bijx.MetaLayer

Bases: Bijection

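Convenience constructor for bijections that preserve probability density, i.e. maps whose Jacobian has unit absolute determinant (such as axis permutations or transposes), so the log-density passes through unchanged.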

Example

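>>> import jax.numpy as jnp
>>> from bijx import MetaLayer
>>> x = jnp.ones((2, 3, 4))
>>> log_density = jnp.zeros(2)
>>> transpose = MetaLayer(
...     forward=lambda x: jnp.swapaxes(x, -1, -2),
...     reverse=lambda x: jnp.swapaxes(x, -1, -2),
... )
>>> y, log_density_out = transpose.forward(x, log_density)
>>> # log_density_out equals log_density; y has its last two axes swapped

Parameters: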
  • forward (Callable) – Map x -> y that does not change the density (unit absolute Jacobian determinant).

  • reverse (Callable) – Inverse of forward, mapping y -> x; it must likewise leave the density unchanged.

  • rngs – Accepted for API compatibility; unused.
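The contract described above (apply a user-supplied map, pass the log-density through untouched) can be sketched without bijx. The class `DensityPreservingSketch` below is a hypothetical stand-in written for illustration, not bijx's implementation, and it assumes `forward` and `reverse` are exact inverses with unit absolute Jacobian determinant.

```python
import numpy as np


class DensityPreservingSketch:
    """Hypothetical stand-in mirroring the MetaLayer contract (not bijx code).

    Assumption: forward/reverse are exact inverses with |det J| = 1,
    so the log-density is passed through untouched.
    """

    def __init__(self, forward, reverse):
        self._forward = forward
        self._reverse = reverse

    def forward(self, x, log_density):
        # Apply the user map; no Jacobian correction is added.
        return self._forward(x), log_density

    def reverse(self, x, log_density):
        # Apply the inverse map; the log-density is again unchanged.
        return self._reverse(x), log_density


def swap(a):
    # Swapping the last two axes is its own inverse.
    return np.swapaxes(a, -1, -2)


layer = DensityPreservingSketch(forward=swap, reverse=swap)

x = np.arange(24.0).reshape(2, 3, 4)
log_density = np.array([-1.5, -2.0])

y, ld_fwd = layer.forward(x, log_density)
x_back, ld_back = layer.reverse(y, ld_fwd)
print(y.shape)                                # (2, 4, 3)
print(np.array_equal(x_back, x))              # True
print(np.array_equal(ld_back, log_density))   # True
```

Because no Jacobian term is ever computed, correctness rests entirely on the caller supplying a genuinely density-preserving pair of maps.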

__init__(forward, reverse, *, rngs=None)
Parameters:
  • forward (Callable)

  • reverse (Callable)

Methods

forward(x, log_density)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density)

Apply reverse (inverse) transformation.
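Conceptually, the inverted bijection's forward is the original's reverse and vice versa. The sketch below illustrates that relationship with a hypothetical class `InvertibleSketch`; implementing `invert()` by swapping the two callables is an assumption made for illustration, not a description of how bijx does it.

```python
import numpy as np


class InvertibleSketch:
    """Hypothetical density-preserving layer with an invert() method.

    Assumption (not taken from bijx): inverting simply swaps the
    forward and reverse callables.
    """

    def __init__(self, forward, reverse):
        self._forward = forward
        self._reverse = reverse

    def forward(self, x, log_density):
        return self._forward(x), log_density

    def reverse(self, x, log_density):
        return self._reverse(x), log_density

    def invert(self):
        # The inverted layer's forward is this layer's reverse, and vice versa.
        return InvertibleSketch(forward=self._reverse, reverse=self._forward)


# A cyclic shift is not its own inverse, so the swap is observable.
shift = InvertibleSketch(
    forward=lambda a: np.roll(a, 1, axis=-1),
    reverse=lambda a: np.roll(a, -1, axis=-1),
)
inv = shift.invert()

x = np.array([[1.0, 2.0, 3.0]])
ld = np.array([0.0])
y, _ = shift.forward(x, ld)   # [[3., 1., 2.]]
z, _ = inv.forward(y, ld)     # back to [[1., 2., 3.]]
```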

forward(x, log_density)

Apply forward transformation.

Transforms the input through the bijection and updates the log-density according to the change-of-variables formula.

For convenience, Bijection() with no arguments gives the identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density), where the log-density incorporates the log absolute determinant of the transformation's Jacobian. For MetaLayer this term is zero, so the log-density is returned unchanged.
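Why a reordering of entries qualifies as density-preserving can be checked numerically. For a linear map y = P x the Jacobian is P itself, and for a permutation matrix |det P| = 1, so the log-determinant term in the change-of-variables formula vanishes. The sketch below uses NumPy's `slogdet`; the contrast with a scaling map shows a Jacobian term that would make a map unsuitable for MetaLayer.

```python
import numpy as np

# Change of variables for y = f(x):  log p_Y(y) = log p_X(x) - log|det J_f(x)|.
# For a linear map f(x) = P @ x the Jacobian is P itself.

n = 5
# Permutation matrix that reverses a length-n vector (same as x[::-1]).
P = np.eye(n)[::-1]

x = np.arange(n, dtype=float)
assert np.array_equal(P @ x, x[::-1])

sign, logabsdet = np.linalg.slogdet(P)
# log|det P| is 0 for any permutation, so the log-density correction vanishes.

# Contrast: scaling by 2 is not density-preserving; its Jacobian term is
# n * log(2), so it would not be a valid MetaLayer map.
_, logabsdet_scale = np.linalg.slogdet(2.0 * np.eye(n))
```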

reverse(x, log_density)

Apply reverse (inverse) transformation.

Transforms the input through the inverse bijection and updates the log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
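For MetaLayer both Jacobian terms are zero, so the sign flip between forward() and reverse() is invisible. A non-density-preserving map makes it concrete. The functions `scale_forward` and `scale_reverse` below are hypothetical helpers for illustration, and the sign convention follows the standard change-of-variables formula rather than anything confirmed about bijx internals.

```python
import numpy as np


def scale_forward(x, log_density, s):
    """Hypothetical elementwise scaling y = s * x (illustration only).

    Each event has x.shape[-1] components, so log|det J| = d * log|s|.
    """
    d = x.shape[-1]
    return s * x, log_density - d * np.log(abs(s))


def scale_reverse(y, log_density, s):
    """Inverse of scale_forward: the Jacobian term enters with opposite sign."""
    d = y.shape[-1]
    return y / s, log_density + d * np.log(abs(s))


x = np.array([[1.0, -2.0, 0.5], [3.0, 0.0, 4.0]])
ld = np.array([-3.0, -4.0])

y, ld_y = scale_forward(x, ld, s=2.0)
x_back, ld_back = scale_reverse(y, ld_y, s=2.0)

print(np.allclose(x_back, x), np.allclose(ld_back, ld))  # True True
```

The round trip restores both the data and the log-density because the two Jacobian contributions cancel exactly.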