bijx.MetaLayer
- class bijx.MetaLayer
Bases: Bijection
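Convenience constructor for bijections that preserve probability density, i.e. maps whose Jacobian has unit absolute determinant, so that the log-density passes through unchanged.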
Example
>>> flip = MetaLayer(
...     forward=lambda x: x[..., ::-1],
...     reverse=lambda x: x[..., ::-1],
... )
>>> y, new_log_density = flip.forward(x, log_density)
>>> # new_log_density equals log_density; y has its last axis reversed
- Parameters:
forward (Callable) – Map x -> y that does not change the density.
reverse (Callable) – Map y -> x (the inverse of forward) that does not change the density.
rngs – Included for compatibility; not used.
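A minimal sketch of the behavior described above, using a NumPy stand-in rather than the bijx implementation: the hypothetical class MetaLayerSketch applies the supplied map to x and passes log_density through untouched. The actual bijx class may differ in its details.

```python
import numpy as np


class MetaLayerSketch:
    """Hypothetical stand-in for bijx.MetaLayer (not the library code)."""

    def __init__(self, forward, reverse, rngs=None):
        # rngs mirrors the documented parameter; like the original, it is unused.
        self._forward = forward
        self._reverse = reverse

    def forward(self, x, log_density):
        # Density-preserving map: log|det J| = 0, so log_density is unchanged.
        return self._forward(x), log_density

    def reverse(self, x, log_density):
        # Inverse map, again with no log-density correction.
        return self._reverse(x), log_density


flip = MetaLayerSketch(
    forward=lambda x: x[..., ::-1],
    reverse=lambda x: x[..., ::-1],
)

x = np.arange(6.0).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
log_density = np.array([-1.5, -2.0])    # one value per sample

y, ld_y = flip.forward(x, log_density)
x_back, ld_back = flip.reverse(y, ld_y)
```

Here y is [[2, 1, 0], [5, 4, 3]], ld_y is identical to log_density, and applying reverse recovers x exactly.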
Methods
forward(x, log_density) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density) – Apply reverse (inverse) transformation.
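For a density-preserving layer, the invert() entry above can be illustrated by swapping the two maps. This is a hypothetical sketch: MetaLayerSketch and its invert method are illustrative names, and bijx may implement inversion differently (for example through a generic wrapper around any Bijection). A cyclic shift is used because, unlike a flip, it is not its own inverse, so the swap is visible.

```python
import numpy as np


class MetaLayerSketch:
    """Hypothetical density-preserving layer; not the bijx implementation."""

    def __init__(self, forward, reverse):
        self._forward = forward
        self._reverse = reverse

    def forward(self, x, log_density):
        return self._forward(x), log_density

    def reverse(self, x, log_density):
        return self._reverse(x), log_density

    def invert(self):
        # The inverse layer simply exchanges the forward and reverse maps.
        return MetaLayerSketch(forward=self._reverse, reverse=self._forward)


shift = MetaLayerSketch(
    forward=lambda x: np.roll(x, 1, axis=-1),   # [a, b, c] -> [c, a, b]
    reverse=lambda x: np.roll(x, -1, axis=-1),  # [c, a, b] -> [a, b, c]
)
unshift = shift.invert()

x = np.array([1.0, 2.0, 3.0])
log_density = np.array(-0.5)

y, _ = shift.forward(x, log_density)           # [3, 1, 2]
z, ld_z = unshift.forward(y, log_density)      # undoes the shift
```

Because both maps only permute entries, the inverted layer also leaves log_density untouched.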
- forward(x, log_density)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
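For reference, assuming log_density tracks the log-density of the current samples, the change of variables formula for an invertible map y = f(x) reads as follows. The second form is the direction used by reverse(). For a MetaLayer the map is assumed to satisfy |det(∂f/∂x)| = 1 (as for a permutation of entries), so the Jacobian term vanishes in both directions.

```latex
\log p_Y(y) = \log p_X(x) - \log\left|\det \frac{\partial f(x)}{\partial x}\right|,
\qquad
\log p_X(x) = \log p_Y(y) + \log\left|\det \frac{\partial f(x)}{\partial x}\right|,
\quad x = f^{-1}(y).
```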
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
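To see why the flip from the class example leaves log_density unchanged, one can compute its Jacobian explicitly. A sketch in NumPy, assuming a single vector of length n: reversing the entries is a linear map whose Jacobian is the anti-diagonal permutation matrix, so log|det J| = 0. A hypothetical elementwise scaling by 2 is shown for contrast; that map changes the density and would not qualify for MetaLayer.

```python
import numpy as np

n = 4

# Jacobian of x -> x[::-1]: the rows of the identity in reverse order.
jac_flip = np.eye(n)[::-1]
sign_flip, logabsdet_flip = np.linalg.slogdet(jac_flip)

# Jacobian of x -> 2 * x: a hypothetical map that is NOT density-preserving.
jac_scale = 2.0 * np.eye(n)
sign_scale, logabsdet_scale = np.linalg.slogdet(jac_scale)
```

logabsdet_flip is 0, since the flip only permutes coordinates, while logabsdet_scale equals n·log 2, a term a forward pass would have to fold into log_density.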
- reverse(x, log_density)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
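The opposite-sign relationship between forward() and reverse() can be checked with a small round trip. This sketch assumes the sign convention of the change of variables formula (forward subtracts log|det J|, reverse adds it) and uses a hypothetical elementwise scaling bijection, ScaleSketch, because a MetaLayer contributes zero in both directions and would hide the sign.

```python
import numpy as np


class ScaleSketch:
    """Hypothetical bijection y = s * x on the last axis (not bijx code)."""

    def __init__(self, s):
        self.s = s

    def forward(self, x, log_density):
        # log|det J| = n * log|s| for n scaled coordinates.
        log_jac = x.shape[-1] * np.log(abs(self.s))
        return self.s * x, log_density - log_jac

    def reverse(self, y, log_density):
        # Inverse map; the same log-Jacobian enters with the opposite sign.
        log_jac = y.shape[-1] * np.log(abs(self.s))
        return y / self.s, log_density + log_jac


layer = ScaleSketch(3.0)
x = np.array([[0.5, -1.0], [2.0, 4.0]])
log_density = np.array([-1.0, -3.0])

y, ld_y = layer.forward(x, log_density)
x_back, ld_back = layer.reverse(y, ld_y)
```

The round trip restores both x and log_density. For a MetaLayer the log_jac term is identically zero, so both directions return log_density as is.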