bijx.Chain
- class bijx.Chain(*bijections)
Bases: Bijection
Sequential composition of multiple bijections.
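Chains multiple bijections together into a single composite transformation. The forward pass applies the bijections in the order they were given; the reverse pass applies them in the opposite order, calling each bijection’s reverse() method. In both directions, the data and log density returned by one bijection become the input to the next.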
- Parameters:
*bijections (Bijection) – Variable number of bijections to chain together.
Example
>>> bij1 = SomeBijection()
>>> bij2 = SomeBijection()
>>> chain = Chain(bij1, bij2)
>>> # Forward: bij2.forward(*bij1.forward(x, ld))
>>> # Reverse: bij1.reverse(*bij2.reverse(y, ld))
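The composition rule can be illustrated with a small, self-contained sketch. ToyAffine and ToyChain are hypothetical names used only here; they are not bijx classes, and this is not bijx's implementation. The sketch assumes that log_density tracks the log density of the samples, so each forward step subtracts log|dy/dx|; the sign convention bijx actually uses is not stated on this page.

```python
import math


class ToyAffine:
    """Hypothetical 1-D bijection y = scale * x + shift (not part of bijx)."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def forward(self, x, log_density, **kwargs):
        # Assumed convention: the sample's log density drops by log|dy/dx|.
        y = self.scale * x + self.shift
        return y, log_density - math.log(abs(self.scale))

    def reverse(self, y, log_density, **kwargs):
        # Exact inverse of forward: undo the map and restore log|dy/dx|.
        x = (y - self.shift) / self.scale
        return x, log_density + math.log(abs(self.scale))


class ToyChain:
    """Hypothetical stand-in for the composition rule, not bijx's source."""

    def __init__(self, *bijections):
        self.bijections = bijections

    def forward(self, x, log_density, **kwargs):
        # Thread (x, log_density) through each bijection in the given order.
        for bij in self.bijections:
            x, log_density = bij.forward(x, log_density, **kwargs)
        return x, log_density

    def reverse(self, x, log_density, **kwargs):
        # Undo the chain: last bijection first, each via its reverse().
        for bij in reversed(self.bijections):
            x, log_density = bij.reverse(x, log_density, **kwargs)
        return x, log_density


chain = ToyChain(ToyAffine(2.0, 1.0), ToyAffine(3.0, -4.0))
y, ld = chain.forward(1.0, 0.0)   # 1 -> 3 -> 5, ld = -log 2 - log 3
x, ld0 = chain.reverse(y, ld)     # 5 -> 3 -> 1, ld back to roughly 0
```

Because each step only adjusts the running log density, the total change along the chain is the sum of the per-step terms (here -log 2 - log 3 = -log 6), and the reverse pass removes those terms again in the opposite order.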
Methods
forward(x, log_density, *[, arg_list]) – Apply all bijections in forward order.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, *[, arg_list]) – Apply all bijections in reverse order using their reverse() methods.
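invert() is described only as creating an inverted version of the bijection. For a chain, any inverse has to satisfy (f2 ∘ f1)⁻¹ = f1⁻¹ ∘ f2⁻¹. The sketch below shows that identity using a hypothetical representation in which each bijection is a (forward, reverse) pair of plain Python functions; scale, shift, apply_forward and invert_chain are illustrative names, not the bijx API, and this is not necessarily how bijx implements invert().

```python
import math

# Hypothetical toy representation (not the bijx API): a bijection is a
# (forward, reverse) pair of functions mapping (x, ld) -> (x, ld).
def scale(a):
    forward = lambda x, ld: (a * x, ld - math.log(abs(a)))
    reverse = lambda y, ld: (y / a, ld + math.log(abs(a)))
    return forward, reverse


def shift(b):
    return (lambda x, ld: (x + b, ld)), (lambda y, ld: (y - b, ld))


def apply_forward(chain, x, ld):
    # Run each pair's forward function in chain order.
    for forward, _ in chain:
        x, ld = forward(x, ld)
    return x, ld


def invert_chain(chain):
    # (f2 o f1)^-1 = f1^-1 o f2^-1: swap each pair and reverse the order.
    return [(reverse, forward) for forward, reverse in reversed(chain)]


chain = [scale(2.0), shift(1.0)]
y, ld = apply_forward(chain, 3.0, 0.0)              # 3 -> 6 -> 7
x, ld0 = apply_forward(invert_chain(chain), y, ld)  # 7 -> 6 -> 3
```

Running the inverted chain forward is the same as running the original chain in reverse, which is why an inverted chain needs no separate reverse logic of its own.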
- forward(x, log_density, *, arg_list=None, **kwargs)
Apply all bijections in forward order.
- Parameters:
x – Input data.
log_density – Input log density.
arg_list (list[dict] | None) – Optional list of argument dicts, one for each bijection.
**kwargs – Common arguments passed to all bijections.
- Returns:
Tuple of final transformed data and accumulated log density.
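This page does not specify how the per-bijection dicts in arg_list interact with the shared **kwargs. The following sketch assumes one dict per bijection, merged over the shared kwargs so that per-bijection entries take precedence; chain_forward, toy_shift and toy_scale are hypothetical, and the precedence rule is an assumption rather than documented bijx behavior.

```python
import math


def toy_shift(x, ld, *, offset=0.0, **_):
    # Hypothetical volume-preserving step: the log density is unchanged.
    return x + offset, ld


def toy_scale(x, ld, *, factor=1.0, **_):
    # Hypothetical scaling step: the log density drops by log|factor|.
    return x * factor, ld - math.log(abs(factor))


def chain_forward(bijections, x, log_density, *, arg_list=None, **kwargs):
    """Hypothetical sketch of threading arg_list and **kwargs through a chain."""
    if arg_list is None:
        # No per-bijection arguments: every step receives only the shared kwargs.
        arg_list = [{} for _ in bijections]
    if len(arg_list) != len(bijections):
        raise ValueError("arg_list needs exactly one dict per bijection")
    for bij, args in zip(bijections, arg_list):
        # Assumption: per-bijection entries override shared kwargs of the same name.
        x, log_density = bij(x, log_density, **{**kwargs, **args})
    return x, log_density


steps = [toy_shift, toy_scale]
# Shared kwargs reach every step; each step ignores keys it does not use.
shared = chain_forward(steps, 1.0, 0.0, offset=1.0, factor=2.0)  # (4.0, -log 2)
# Per-step dict: only toy_shift sees offset=5.0, which overrides offset=1.0.
mixed = chain_forward(
    steps, 1.0, 0.0, arg_list=[{"offset": 5.0}, {}], offset=1.0, factor=2.0
)  # (12.0, -log 2)
```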
- reverse(x, log_density, *, arg_list=None, **kwargs)
Apply all bijections in reverse order using their reverse() methods.
- Parameters:
x – Input data.
log_density – Input log density.
arg_list (list[dict] | None) – Optional list of argument dicts, one for each bijection.
**kwargs – Common arguments passed to all bijections.
- Returns:
Tuple of final inverse-transformed data and accumulated log density.
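The reverse pass has to walk the chain backwards because undoing a composition means undoing its last step first. A minimal sketch with two hypothetical, non-commuting maps (double and increment, with their inverses; log density omitted for brevity):

```python
def double(x):
    return 2.0 * x


def halve(y):
    # Inverse of double.
    return y / 2.0


def increment(x):
    return x + 1.0


def decrement(y):
    # Inverse of increment.
    return y - 1.0


x = 3.0
y = increment(double(x))      # forward: double first, then increment -> 7.0
undone = halve(decrement(y))  # reverse order: undo increment first -> 3.0
wrong = decrement(halve(y))   # undoing in forward order -> 2.5, not 3.0
```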