bijx.Bijection
- class bijx.Bijection
Bases: Module
Base class for all bijective transformations.
A bijection represents an invertible mapping between two spaces, equipped with forward and reverse transformations that properly track log-density changes according to the change of variables formula.
Note
This is an abstract base class. Subclasses must override forward() and reverse() methods to implement the actual transformation logic.
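As an illustration of the subclassing pattern described in the note, here is a minimal sketch in plain Python. It does not import bijx: the `Bijection` class below is a hypothetical stand-in, `Scale` is a made-up example subclass, and the sign convention (forward subtracts the log absolute Jacobian determinant) is an assumption rather than something this page states.

```python
import math


class Bijection:
    """Hypothetical stand-in for the base class; not the bijx implementation."""

    def forward(self, x, log_density, **kwargs):
        # Identity map: the Jacobian determinant is 1, so log_density is unchanged.
        return x, log_density

    def reverse(self, x, log_density, **kwargs):
        return x, log_density


class Scale(Bijection):
    """Hypothetical subclass implementing y = a * x for a nonzero scalar a."""

    def __init__(self, a):
        self.a = a
        self.log_abs_a = math.log(abs(a))

    def forward(self, x, log_density, **kwargs):
        # Assumed convention: the density at y equals the density at x divided by |a|.
        return self.a * x, log_density - self.log_abs_a

    def reverse(self, y, log_density, **kwargs):
        # Inverse map, with the log-density change of opposite sign.
        return y / self.a, log_density + self.log_abs_a


y, log_density = Scale(2.0).forward(3.0, 0.0)
```

A real subclass would operate on arrays or pytrees rather than Python floats; scalars are used here only to keep the sketch short.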
- __init__(*args, **kwargs)
Methods
apply(x, log_density[, reverse]) – Apply the bijection in either direction.
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
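The update described above can be written out as a worked formula. Assume the forward pass applies a differentiable invertible map $f$ with $y = f(x)$, and that `log_density` holds $\log p_X(x)$ on input. The symbols $f$, $p_X$, and $p_Y$ are introduced here for illustration, and the sign (forward subtracts the log-Jacobian) is an assumed convention; the text above only says that the log absolute determinant is incorporated.

```latex
\log p_Y(y) = \log p_X(x) - \log\left|\det \frac{\partial f(x)}{\partial x}\right|,
\qquad y = f(x)
```

For an elementwise scaling $y = a x$ in $d$ dimensions, this reduces to $\log p_Y(y) = \log p_X(x) - d \log|a|$.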
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
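The opposite-sign relationship means that a forward pass followed by a reverse pass should return both the original data and the original log-density. The sketch below checks this for a one-dimensional exponential map. The functions `exp_forward` and `exp_reverse` are hypothetical and not part of bijx, and the convention that the forward pass subtracts the log-Jacobian is an assumption.

```python
import math


def exp_forward(x, log_density):
    # y = exp(x); dy/dx = exp(x), so log|dy/dx| = x.
    return math.exp(x), log_density - x


def exp_reverse(y, log_density):
    # x = log(y); the same log-Jacobian term is added back with the opposite sign.
    x = math.log(y)
    return x, log_density + x


x0, ld0 = 0.5, -1.2
y, ld1 = exp_forward(x0, ld0)   # forward changes the log-density by -x0
x1, ld2 = exp_reverse(y, ld1)   # reverse undoes both the map and the change
```

Up to floating-point rounding, `x1` matches `x0` and `ld2` matches `ld0`. A round-trip check like this is a cheap sanity test for any new pair of forward() and reverse() implementations.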
- invert()
Create an inverted version of this bijection.
- Returns:
A new bijection whose forward() and reverse() methods are swapped. See Inverse.
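One way to picture the swap is a thin wrapper that delegates each direction to the opposite method of the wrapped object. This is a minimal sketch under that assumption, not the library's Inverse class: `Inverted` and `Doubling` are hypothetical names, and the forward sign convention is assumed.

```python
import math


class Doubling:
    """Hypothetical bijection y = 2 * x, used only for illustration."""

    def forward(self, x, log_density, **kwargs):
        return 2.0 * x, log_density - math.log(2.0)

    def reverse(self, y, log_density, **kwargs):
        return y / 2.0, log_density + math.log(2.0)


class Inverted:
    """Hypothetical wrapper that swaps forward() and reverse()."""

    def __init__(self, bijection):
        self.bijection = bijection

    def forward(self, x, log_density, **kwargs):
        # Forward of the inverted bijection is reverse of the wrapped one.
        return self.bijection.reverse(x, log_density, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.bijection.forward(x, log_density, **kwargs)

    def invert(self):
        # Inverting twice gives back the wrapped bijection.
        return self.bijection


halving = Inverted(Doubling())
x, log_density = halving.forward(3.0, 0.0)
```

Here `halving.forward` halves its input and adds log 2 to the log-density, which is exactly what `Doubling.reverse` does.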
- apply(x, log_density, reverse=False, **kwargs)
Apply the bijection in either direction.
Dispatches to forward() or reverse() based on reverse. If reverse is a Python bool, the dispatch is static (only the chosen branch is traced); otherwise jax.lax.cond is used so that a traced boolean array is supported, at the cost of tracing both branches.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
reverse – If True, apply reverse transformation; if False, forward. May be a Python bool or a JAX boolean scalar.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density).
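The static versus traced dispatch described for apply() can be sketched with free functions. This is an illustrative reconstruction, not the bijx source: `forward_fn`, `reverse_fn`, and this standalone `apply` are hypothetical, and the doubling map and its sign convention are assumptions. Only `jax.lax.cond` and `jax.jit` are real JAX calls.

```python
import math

import jax
import jax.numpy as jnp

LOG_2 = math.log(2.0)


def forward_fn(x, log_density):
    # Hypothetical map y = 2 * x with log|det J| = log 2 for a scalar input.
    return 2.0 * x, log_density - LOG_2


def reverse_fn(y, log_density):
    return y / 2.0, log_density + LOG_2


def apply(x, log_density, reverse=False):
    if isinstance(reverse, bool):
        # Static dispatch: only the chosen branch is ever traced.
        return reverse_fn(x, log_density) if reverse else forward_fn(x, log_density)
    # Traced boolean: both branches are traced, and one is selected at run time.
    return jax.lax.cond(reverse, reverse_fn, forward_fn, x, log_density)


x = jnp.asarray(3.0)
log_density = jnp.asarray(0.0)

y, ld_static = apply(x, log_density, reverse=False)
z, ld_traced = jax.jit(apply)(x, log_density, jnp.asarray(True))
```

With `jax.lax.cond`, both branches must return outputs of identical structure, shape, and dtype. A forward/reverse pair satisfies this when the bijection preserves the shape of its input.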