bijx.ApplyBijection

class bijx.ApplyBijection

Bases: Bijection

Convenience base class for bijections with unified forward/reverse logic.

This class is useful when forward and reverse transformations share most of their implementation. Instead of duplicating code in separate forward() and reverse() methods, subclasses implement a single apply() method with a reverse parameter.
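The delegation this implies can be sketched in plain Python without importing bijx. The class names `SketchBijection`, `SketchApplyBijection` and `Shift` are hypothetical stand-ins. The sketch assumes that `forward()` and `reverse()` simply call `apply()` with `reverse=False` and `reverse=True`:

```python
class SketchBijection:
    """Hypothetical stand-in for a bijection base class (not the bijx API)."""

    def forward(self, x, log_density, **kwargs):
        raise NotImplementedError

    def reverse(self, x, log_density, **kwargs):
        raise NotImplementedError


class SketchApplyBijection(SketchBijection):
    """Assumed pattern: both directions route through a single apply()."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        raise NotImplementedError("Subclasses must implement apply().")

    def forward(self, x, log_density, **kwargs):
        # Forward direction: delegate with reverse=False.
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        # Reverse direction: delegate with reverse=True.
        return self.apply(x, log_density, reverse=True, **kwargs)


class Shift(SketchApplyBijection):
    """Hypothetical translation by a constant; unit Jacobian, so log_density is unchanged."""

    def __init__(self, shift):
        self.shift = shift

    def apply(self, x, log_density, reverse=False, **kwargs):
        y = x - self.shift if reverse else x + self.shift
        return y, log_density


bij = Shift(1.0)
y, ld = bij.forward(2.0, -0.5)   # (3.0, -0.5)
x, ld_back = bij.reverse(y, ld)  # (2.0, -0.5)
```

A subclass only has to write `apply()`, and both public directions come for free.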

Example

>>> from bijx import ApplyBijection
>>> class MyBijection(ApplyBijection):
...     def apply(self, x, log_density, reverse=False, **kwargs):
...         # Shared transformation logic: a shift by 1 in either direction
...         transformed_x = x - 1 if reverse else x + 1
...         # A shift has unit Jacobian, so log_density is unchanged
...         return transformed_x, log_density

__init__(*args, **kwargs)

Methods

apply(x, log_density[, reverse])

Unified transformation method.

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

apply(x, log_density, reverse=False, **kwargs)

Unified transformation method.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • reverse – If True, apply the reverse transformation; if False (the default), apply the forward transformation.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density).

Raises:

NotImplementedError – Raised by the base implementation; subclasses must override apply().
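The following sketch shows how a single `apply()` can share logic between directions when the log-density actually changes. It uses a hypothetical scalar scaling map y = x·e^s and does not import bijx. It assumes that `log_density` holds the log density of the input, so the pass subtracts the log-Jacobian of whichever map it applies:

```python
import math


class Scale:
    """Hypothetical scalar bijection y = x * exp(log_scale) (not part of bijx)."""

    def __init__(self, log_scale):
        self.log_scale = log_scale

    def apply(self, x, log_density, reverse=False, **kwargs):
        # Shared logic: the reverse pass uses the negated log-scale.
        sign = -1.0 if reverse else 1.0
        s = sign * self.log_scale
        # log|dy/dx| of the map actually applied is s; subtract it
        # (assumed convention: log_density is the log density of the input).
        return x * math.exp(s), log_density - s

    def forward(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=True, **kwargs)


bij = Scale(math.log(2.0))
y, ld = bij.forward(3.0, 0.0)         # y is about 6.0, ld == -log(2)
x_back, ld_back = bij.reverse(y, ld)  # about (3.0, 0.0)
```

Flipping the sign of one shared quantity handles both directions. This avoids duplicating code in separate forward() and reverse() methods.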

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms input through the bijection and updates log-density according to the change of variables formula.

For convenience, instantiating the base class directly, Bijection(), gives the identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density), where the log-density has been updated by the log absolute determinant of the transformation's Jacobian.
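Here is the standard change-of-variables formula behind this update. It assumes that `log_density` holds the log density of the samples themselves, so forward() subtracts the log-Jacobian term and reverse() adds it back:

```latex
y = f(x), \qquad
\log p_Y(y) = \log p_X(x) - \log\left|\det \frac{\partial f(x)}{\partial x}\right|,
\qquad
\log p_X(x) = \log p_Y(y) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|
```

If an implementation accumulated log-determinants instead of log densities, the signs would be reversed. Confirm the convention against an existing bijection before relying on either sign.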

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms input through the inverse bijection and updates log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density), where the log-density change has the opposite sign to that in forward().
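Because the two log-density changes have opposite signs, a round trip makes a simple consistency check for any `apply()` implementation. reverse() applied to the output of forward() should recover both the input and its log density. This self-contained sketch uses a hypothetical affine map y = e^s·x + t rather than any bijx class. The helper `round_trip_ok` is likewise illustrative:

```python
import math


class Affine:
    """Hypothetical affine bijection y = exp(log_scale) * x + shift (not part of bijx)."""

    def __init__(self, log_scale, shift):
        self.log_scale = log_scale
        self.shift = shift

    def apply(self, x, log_density, reverse=False, **kwargs):
        if reverse:
            # Inverse map; the log-Jacobian term is added back.
            x_out = (x - self.shift) * math.exp(-self.log_scale)
            return x_out, log_density + self.log_scale
        # Forward map; log|dy/dx| = log_scale is subtracted.
        y = x * math.exp(self.log_scale) + self.shift
        return y, log_density - self.log_scale

    def forward(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=True, **kwargs)


def round_trip_ok(bij, x, log_density, tol=1e-9):
    """Return True if reverse(forward(x)) recovers x and log_density within tol."""
    y, ld = bij.forward(x, log_density)
    x_back, ld_back = bij.reverse(y, ld)
    return (math.isclose(x_back, x, abs_tol=tol)
            and math.isclose(ld_back, log_density, abs_tol=tol))


print(round_trip_ok(Affine(0.3, -1.2), 0.7, -2.0))  # True
```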

invert()

Create an inverted version of this bijection.

Returns:

A new bijection whose forward() and reverse() methods are swapped. See Inverse.
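The swap can be pictured with a small wrapper. This is a hypothetical sketch of the idea, not bijx's Inverse class. `SketchInverse` and the `Shift` bijection are assumptions made for illustration, and having a double inversion return the original object is this sketch's own design choice:

```python
class Shift:
    """Hypothetical bijection y = x + shift with unit Jacobian."""

    def __init__(self, shift):
        self.shift = shift

    def forward(self, x, log_density, **kwargs):
        return x + self.shift, log_density

    def reverse(self, x, log_density, **kwargs):
        return x - self.shift, log_density

    def invert(self):
        return SketchInverse(self)


class SketchInverse:
    """Wraps a bijection and swaps its forward and reverse directions."""

    def __init__(self, bijection):
        self.bijection = bijection

    def forward(self, x, log_density, **kwargs):
        return self.bijection.reverse(x, log_density, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.bijection.forward(x, log_density, **kwargs)

    def invert(self):
        # In this sketch, inverting twice returns the wrapped bijection.
        return self.bijection


inv = Shift(2.0).invert()
print(inv.forward(5.0, 0.0))  # (3.0, 0.0)
```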