bijx.CondInverse

class bijx.CondInverse

Bases: Bijection

Conditionally inverted bijection based on a runtime boolean flag.

This bijection wraps another bijection and inverts it depending on a boolean parameter. Importantly, the boolean value does not have to be known at compile time. Because either direction may therefore be selected at run time, forward and reverse of the wrapped bijection must not change array shapes.

Parameters:
  • bijection (Bijection) – The underlying bijection to wrap.

  • invert (bool) – If True, swap forward/reverse directions.
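
The shape restriction can be illustrated with a minimal JAX sketch. It does not use bijx and is not its implementation: the functions `exp_forward`, `exp_reverse`, and `cond_forward` are hypothetical stand-ins, and the sign convention (subtracting the log-Jacobian in the forward direction) is an assumption. The point is only that a flag traced under `jax.jit` cannot drive a Python `if`, so a construct such as `jax.lax.cond` is needed, and that requires both branches to return arrays of identical shape and dtype.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a wrapped bijection: y = exp(x).
def exp_forward(x, log_density):
    # log|dy/dx| = x; subtracted from the log-density (assumed convention).
    return jnp.exp(x), log_density - x


def exp_reverse(y, log_density):
    x = jnp.log(y)
    # The reverse direction applies the opposite log-density change.
    return x, log_density + x


@jax.jit
def cond_forward(invert, x, log_density):
    # `invert` is a traced value inside jit, so the branch is chosen at
    # run time. Both branches must return matching shapes and dtypes.
    return jax.lax.cond(invert, exp_reverse, exp_forward, x, log_density)


x = jnp.array([1.0, 2.0])
ld = jnp.zeros(2)

y, ld_y = cond_forward(False, x, ld)      # flag off: acts as forward
x2, ld_x2 = cond_forward(True, y, ld_y)   # flag on: acts as reverse
```

With the flag switched on, the second call undoes the first, so `x2` matches `x` and `ld_x2` returns to zero up to floating-point error.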

__init__(bijection, invert=True)

Methods

  • forward(x, log_density, **kwargs) – Apply forward transformation.

  • invert() – Create an inverted version of this bijection.

  • reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms input through the bijection and updates log-density according to the change of variables formula.

For convenience, Bijection() gives the default identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
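
As a worked illustration of the change of variables formula, the sketch below pushes a standard normal sample through a scaling map and compares the updated log-density with the closed-form density of the result. It is plain Python, not bijx code; `scale_forward` and the scale factor `a` are hypothetical, and it assumes the log-density tracks the density of the transformed sample, so that the forward direction subtracts log|det J|.

```python
import math


def scale_forward(x, log_density, a=2.0):
    # y = a * x, so |dy/dx| = |a| and log p_y(y) = log p_x(x) - log|a|.
    y = a * x
    return y, log_density - math.log(abs(a))


a = 2.0
x = 0.5

# Log-density of a standard normal evaluated at x.
log_p_x = -0.5 * x**2 - 0.5 * math.log(2 * math.pi)

y, log_p_y = scale_forward(x, log_p_x, a)

# If x ~ N(0, 1), then y = a * x ~ N(0, a^2); its log-density in closed form:
expected = -0.5 * (y / a) ** 2 - math.log(a) - 0.5 * math.log(2 * math.pi)
```

Here `log_p_y` agrees with `expected`, which is the sense in which the returned log-density incorporates the Jacobian term.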

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms input through the inverse bijection and updates log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
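
The opposite sign means a forward pass followed by a reverse pass leaves both the data and the log-density unchanged. A minimal plain-Python sketch, using a hypothetical exponential map rather than any bijx class and assuming the forward direction subtracts the log-Jacobian:

```python
import math


def forward(x, log_density):
    # y = exp(x); log|dy/dx| = x.
    return math.exp(x), log_density - x


def reverse(y, log_density):
    x = math.log(y)
    # Opposite sign compared to forward(): add the same log-Jacobian back.
    return x, log_density + x


x0, ld0 = 0.3, -1.2
y, ld_y = forward(x0, ld0)
x1, ld1 = reverse(y, ld_y)   # round trip recovers (x0, ld0) up to rounding
```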

invert()

Create an inverted version of this bijection.

Returns:

New bijection whose forward() and reverse() methods are swapped. See Inverse.
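
What swapping forward() and reverse() means can be sketched with a small wrapper. This is an illustration only, not the Inverse class referenced above: `ExpBijection` and `SwappedBijection` are hypothetical names, and the sign convention for the log-density is assumed.

```python
import math


class ExpBijection:
    """Toy bijection y = exp(x), used only for illustration."""

    def forward(self, x, log_density, **kwargs):
        return math.exp(x), log_density - x

    def reverse(self, y, log_density, **kwargs):
        x = math.log(y)
        return x, log_density + x

    def invert(self):
        return SwappedBijection(self)


class SwappedBijection:
    """Hypothetical wrapper: forward() calls the inner reverse() and vice versa."""

    def __init__(self, inner):
        self.inner = inner

    def forward(self, x, log_density, **kwargs):
        return self.inner.reverse(x, log_density, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.inner.forward(x, log_density, **kwargs)

    def invert(self):
        # Inverting twice gives back the original bijection.
        return self.inner


bij = ExpBijection()
inv = bij.invert()
```

In this sketch `inv.forward(...)` returns exactly what `bij.reverse(...)` returns, and `inv.invert()` hands back `bij` itself.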