bijx.ApplyBijection
- class bijx.ApplyBijection
Bases: Bijection
Convenience base class for bijections with unified forward/reverse logic.
This class is useful when forward and reverse transformations share most of their implementation. Instead of duplicating code in separate forward() and reverse() methods, subclasses implement a single apply() method with a reverse parameter.
Example
>>> class MyBijection(ApplyBijection):
...     def apply(self, x, log_density, reverse=False, **kwargs):
...         # Shared transformation logic here; simple example
...         transformed_x = x - 1 if reverse else x + 1
...         # log_density doesn't change for this example
...         return transformed_x, log_density
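To see how one apply() method serves both directions, the sketch below pairs the example with a minimal stand-in for the base class. The stand-in `ApplyBijection` is an assumption for illustration. It simply routes forward() and reverse() into apply() with the matching reverse flag. The real bijx class may differ in detail. Plain Python floats stand in for the pytree inputs the real class accepts.

```python
class ApplyBijection:
    """Hypothetical stand-in for bijx.ApplyBijection (illustration only)."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        raise NotImplementedError("Subclasses must implement apply().")

    def forward(self, x, log_density, **kwargs):
        # Assumed dispatch: forward direction calls apply() with reverse=False.
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        # Assumed dispatch: reverse direction calls apply() with reverse=True.
        return self.apply(x, log_density, reverse=True, **kwargs)


class MyBijection(ApplyBijection):
    def apply(self, x, log_density, reverse=False, **kwargs):
        # Shift by +1 forward and -1 in reverse. A pure shift has a unit
        # Jacobian, so log_density passes through unchanged.
        transformed_x = x - 1 if reverse else x + 1
        return transformed_x, log_density


bij = MyBijection()
y, ld = bij.forward(1.0, 0.0)
x, ld_back = bij.reverse(y, ld)
print(y, ld, x, ld_back)  # 2.0 0.0 1.0 0.0
```

The subclass contains only the transformation logic. The direction-specific entry points come from the base class.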
- __init__(*args, **kwargs)
Methods
apply(x, log_density[, reverse]) – Unified transformation method.
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
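The methods list describes invert() only as creating an inverted version of the bijection. One plausible way to build such an inversion on top of apply() is a wrapper that flips the reverse flag before delegating. This is an assumption, not the documented bijx mechanism. The `Inverted` and `Shift` classes below are hypothetical, and the base class is again a minimal stand-in.

```python
class ApplyBijection:
    """Hypothetical stand-in for bijx.ApplyBijection (illustration only)."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        raise NotImplementedError("Subclasses must implement apply().")

    def forward(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=True, **kwargs)

    def invert(self):
        # Hypothetical: wrap self so that forward and reverse trade places.
        return Inverted(self)


class Inverted(ApplyBijection):
    """Hypothetical wrapper: its forward is the wrapped reverse, and vice versa."""

    def __init__(self, inner):
        self.inner = inner

    def apply(self, x, log_density, reverse=False, **kwargs):
        # Flip the direction flag before delegating to the wrapped bijection.
        return self.inner.apply(x, log_density, reverse=not reverse, **kwargs)


class Shift(ApplyBijection):
    """Hypothetical bijection: +1 forward, -1 in reverse (unit Jacobian)."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        return (x - 1 if reverse else x + 1), log_density


inv = Shift().invert()
print(inv.forward(5.0, 0.0))           # (4.0, 0.0)
print(inv.invert().forward(5.0, 0.0))  # (6.0, 0.0)
```

Inverting twice in this sketch wraps the bijection a second time instead of unwrapping it. The result behaves like the original, but a real implementation might return the inner bijection directly.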
- apply(x, log_density, reverse=False, **kwargs)
Unified transformation method.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
reverse – If True, apply reverse transformation; if False, forward.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density).
- Raises:
NotImplementedError – Must be implemented by subclasses.
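When the transformation changes volume, apply() is also where the log-density update belongs, and the reverse flag decides its sign. The sketch below assumes the density-tracking convention in which the forward pass subtracts log|det J| from log_density, following the change of variables formula, and the reverse pass adds it back. The `Scale` class and its `factor` parameter are hypothetical. The base class is a minimal stand-in, and plain floats stand in for pytree inputs.

```python
import math


class ApplyBijection:
    """Hypothetical stand-in for bijx.ApplyBijection (illustration only)."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        raise NotImplementedError("Must be implemented by subclasses.")

    def forward(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=True, **kwargs)


class Scale(ApplyBijection):
    """Hypothetical scalar bijection y = factor * x."""

    def __init__(self, factor):
        if factor == 0:
            raise ValueError("factor must be non-zero to be invertible")
        self.factor = factor

    def apply(self, x, log_density, reverse=False, **kwargs):
        # log|dy/dx| of the forward map; shared by both directions.
        log_abs_det = math.log(abs(self.factor))
        if reverse:
            # Inverse map x = y / factor: the log density gains log|factor|.
            return x / self.factor, log_density + log_abs_det
        # Forward map y = factor * x: the log density loses log|factor|.
        return x * self.factor, log_density - log_abs_det


bij = Scale(2.0)
y, ld = bij.forward(3.0, 0.0)
print(y, ld)  # 6.0 -0.6931471805599453
```

The Jacobian term is computed once and applied with opposite signs in the two branches. This is the kind of shared logic ApplyBijection is meant to avoid duplicating.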
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
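Written out, the change of variables formula takes the form below. This assumes that log_density tracks the log density of the samples and that y = f(x) is the forward map; the sign convention is an assumption, not stated explicitly above. For the identity bijection the Jacobian is the identity matrix, its log-determinant is zero, and log_density passes through unchanged.

```latex
y = f(x), \qquad
\log p_Y(y) = \log p_X(x) - \log\left|\det \frac{\partial f(x)}{\partial x}\right|
```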
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
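Because the reverse log-density change has the opposite sign, a forward pass followed by a reverse pass should recover both the original data and the original log density. A small round-trip check like the one below is one way to test a subclass. `Affine` and `roundtrip_error` are hypothetical names, and the base class is a stand-in assumed to dispatch to apply(). The convention that forward subtracts log|det J| is also an assumption.

```python
import math


class ApplyBijection:
    """Hypothetical stand-in for bijx.ApplyBijection (illustration only)."""

    def apply(self, x, log_density, reverse=False, **kwargs):
        raise NotImplementedError("Must be implemented by subclasses.")

    def forward(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=False, **kwargs)

    def reverse(self, x, log_density, **kwargs):
        return self.apply(x, log_density, reverse=True, **kwargs)


class Affine(ApplyBijection):
    """Hypothetical scalar bijection y = scale * x + shift (scale != 0)."""

    def __init__(self, scale, shift):
        self.scale, self.shift = scale, shift

    def apply(self, x, log_density, reverse=False, **kwargs):
        log_abs_det = math.log(abs(self.scale))
        if reverse:
            # Undo the shift, then the scale; add the Jacobian term back.
            return (x - self.shift) / self.scale, log_density + log_abs_det
        return self.scale * x + self.shift, log_density - log_abs_det


def roundtrip_error(bij, x, log_density):
    """Hypothetical helper: forward then reverse, return absolute mismatches."""
    y, ld_y = bij.forward(x, log_density)
    x_back, ld_back = bij.reverse(y, ld_y)
    return abs(x_back - x), abs(ld_back - log_density)


err_x, err_ld = roundtrip_error(Affine(-0.5, 3.0), 1.25, -2.0)
print(err_x, err_ld)  # both should be at or near zero
```

Both errors should be zero up to floating-point rounding. A noticeable log-density mismatch usually means the two branches of apply() disagree on the sign of the Jacobian term.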