bijx.Chain

class bijx.Chain

Bases: Bijection

Sequential composition of multiple bijections.

Chains multiple bijections into one composite transformation. The forward pass applies the bijections in the order they were given; the reverse pass applies them in the opposite order, calling each bijection's reverse() method. Both passes thread the log density through every bijection.
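The same composition can be written as equations. This assumes the standard change-of-variables convention, in which each forward step subtracts the log-absolute Jacobian determinant from the tracked log density ℓ. This page does not state bijx's sign convention, so treat the signs as an assumption.

```latex
\begin{aligned}
f &= f_n \circ \cdots \circ f_2 \circ f_1,
\qquad f^{-1} = f_1^{-1} \circ f_2^{-1} \circ \cdots \circ f_n^{-1}, \\
x_i &= f_i(x_{i-1}),
\qquad \ell_i = \ell_{i-1} - \log\left|\det \frac{\partial f_i}{\partial x_{i-1}}\right|,
\qquad i = 1, \dots, n, \\
\ell_n &= \ell_0 - \sum_{i=1}^{n} \log\left|\det \frac{\partial f_i}{\partial x_{i-1}}\right|.
\end{aligned}
```

The reverse pass runs i = n, ..., 1. At each step it recovers x_{i-1} = f_i^{-1}(x_i) and adds the corresponding log-determinant term back, so a full round trip returns (x_0, ℓ_0).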

Parameters:

*bijections (Bijection) – Any number of bijections to chain, listed in the order the forward pass applies them.

Example

>>> bij1 = SomeBijection()
>>> bij2 = SomeBijection()
>>> chain = Chain(bij1, bij2)
>>> # Forward: bij2.forward(*bij1.forward(x, ld))
>>> # Reverse: bij1.reverse(*bij2.reverse(y, ld))
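The example's comments can be checked with a small runnable sketch. `Affine` and `MiniChain` are hypothetical stand-ins, not bijx API. The toy bijection assumes the convention where forward() subtracts log|a| from the log density. Because forward() and reverse() return a tuple, nested calls unpack it with `*`.

```python
import math


class Affine:
    """Hypothetical toy bijection y = a * x + b (not part of bijx)."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def forward(self, x, ld):
        # Assumed convention: the log density drops by log|dy/dx| = log|a|.
        return self.a * x + self.b, ld - math.log(abs(self.a))

    def reverse(self, y, ld):
        return (y - self.b) / self.a, ld + math.log(abs(self.a))


class MiniChain:
    """Hypothetical stand-in for bijx.Chain, without arg_list support."""

    def __init__(self, *bijections):
        self.bijections = bijections

    def forward(self, x, ld):
        # Apply bijections first to last.
        for bij in self.bijections:
            x, ld = bij.forward(x, ld)
        return x, ld

    def reverse(self, y, ld):
        # Undo bijections last to first, each via its own reverse().
        for bij in reversed(self.bijections):
            y, ld = bij.reverse(y, ld)
        return y, ld


bij1, bij2 = Affine(2.0, 1.0), Affine(3.0, -4.0)
chain = MiniChain(bij1, bij2)
x, ld = 5.0, 0.0

y, ld_y = chain.forward(x, ld)
# The chain matches the nested calls from the example comments.
assert (y, ld_y) == bij2.forward(*bij1.forward(x, ld))
assert chain.reverse(y, ld_y) == bij1.reverse(*bij2.reverse(y, ld_y))

print(y)  # 29.0
print(MiniChain(bij2, bij1).forward(x, ld)[0])  # 23.0 (order matters)
x_back, ld_back = chain.reverse(y, ld_y)
print(x_back)  # 5.0
```

Swapping the two bijections changes the result, which is why the argument order passed to the chain matters.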

__init__(*bijections)
Parameters:

bijections (Bijection)

Methods

forward(x, log_density, *[, arg_list])

Apply all bijections in forward order.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, *[, arg_list])

Apply all bijections in reverse order using their reverse() methods.
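The summary only says that invert() creates an inverted version of the bijection. The sketch below shows what that means for a chain. It assumes an inverted bijection simply swaps forward() and reverse(); `Inverted`, `Affine`, and `MiniChain` are hypothetical classes, not bijx's implementation. Under that assumption, the inverted chain's forward pass equals the original reverse pass. It is also the same as chaining the inverted bijections in the opposite order.

```python
import math


class Affine:
    """Hypothetical toy bijection y = a * x + b (not part of bijx)."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def forward(self, x, ld):
        return self.a * x + self.b, ld - math.log(abs(self.a))

    def reverse(self, y, ld):
        return (y - self.b) / self.a, ld + math.log(abs(self.a))


class Inverted:
    """Hypothetical wrapper that swaps forward() and reverse()."""

    def __init__(self, bijection):
        self.bijection = bijection

    def forward(self, x, ld):
        return self.bijection.reverse(x, ld)

    def reverse(self, y, ld):
        return self.bijection.forward(y, ld)


class MiniChain:
    """Hypothetical stand-in for bijx.Chain."""

    def __init__(self, *bijections):
        self.bijections = bijections

    def forward(self, x, ld):
        for bij in self.bijections:
            x, ld = bij.forward(x, ld)
        return x, ld

    def reverse(self, y, ld):
        for bij in reversed(self.bijections):
            y, ld = bij.reverse(y, ld)
        return y, ld

    def invert(self):
        # One possible strategy: wrap the whole chain.
        return Inverted(self)


a, b = Affine(2.0, 1.0), Affine(4.0, 0.0)
chain = MiniChain(a, b)
inv = chain.invert()

# Inverse of a chain == chain of the inverses in the opposite order.
alt = MiniChain(Inverted(b), Inverted(a))
assert inv.forward(12.0, 0.0) == chain.reverse(12.0, 0.0)
assert inv.forward(12.0, 0.0) == alt.forward(12.0, 0.0)
print(inv.forward(12.0, 0.0)[0])  # 1.0
```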

forward(x, log_density, *, arg_list=None, **kwargs)

Apply all bijections in forward order.

Parameters:
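  • x – Input data, passed to the first bijection.

  • log_density – Log density of x, updated by each bijection in turn.

  • arg_list (list[dict] | None) – Optional list of keyword-argument dicts, one per bijection, in chain order.

  • **kwargs – Keyword arguments passed to every bijection in the chain.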

Returns:

Tuple of the data output by the last bijection and the log density accumulated over all bijections.
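The sketch below shows one way arg_list and **kwargs could combine. This is an assumption, not bijx's code. Here, entry i of arg_list is merged with the common keyword arguments and passed only to the i-th bijection. `Recorder` and `chain_forward` are hypothetical names. If the same key appears in both places, Python raises TypeError in this sketch; this page does not say how bijx handles that case.

```python
class Recorder:
    """Hypothetical identity bijection that records the kwargs it receives."""

    def __init__(self):
        self.seen = None

    def forward(self, x, log_density, **kwargs):
        self.seen = kwargs
        return x, log_density


def chain_forward(bijections, x, log_density, *, arg_list=None, **kwargs):
    # Assumed dispatch: one dict per bijection, merged with the common kwargs.
    if arg_list is None:
        arg_list = [{} for _ in bijections]
    if len(arg_list) != len(bijections):
        raise ValueError("arg_list needs exactly one dict per bijection")
    for bij, args in zip(bijections, arg_list):
        # A key present in both kwargs and args raises TypeError here.
        x, log_density = bij.forward(x, log_density, **kwargs, **args)
    return x, log_density


a, b = Recorder(), Recorder()
chain_forward([a, b], 1.0, 0.0, arg_list=[{"t": 0.5}, {}], train=True)
print(a.seen)  # {'train': True, 't': 0.5}
print(b.seen)  # {'train': True}
```

The length check is a design choice of this sketch. A bare zip() would silently skip bijections past the end of a short arg_list.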

reverse(x, log_density, *, arg_list=None, **kwargs)

Apply all bijections in reverse order using their reverse() methods.

Parameters:
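  • x – Input data in the chain's output space, passed first to the last bijection's reverse().

  • log_density – Log density of x, updated by each bijection's reverse() in turn.

  • arg_list (list[dict] | None) – Optional list of keyword-argument dicts, one per bijection.

  • **kwargs – Keyword arguments passed to every bijection in the chain.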

Returns:

Tuple of the data after every bijection has been undone and the log density accumulated over all bijections.