bijx.GeneralCouplingLayer

class bijx.GeneralCouplingLayer

Bases: Bijection

General coupling layer with flexible masking and bijection support.

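Implements the fundamental coupling layer transformation: the input is split into active and passive components, and the passive part conditions the transformation applied to the active part. Because the passive components pass through unchanged, the inverse can recompute the same bijection parameters from them, so the layer stays invertible even though the embedding network itself is not. The Jacobian is block-triangular, so only the active components contribute to the log-determinant. This enables complex, parameter-dependent transformations at a tractable cost.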
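The mechanism described above can be illustrated with a minimal affine coupling written in plain JAX. This is a sketch, not the bijx implementation: the linear "embedding" function, the affine active-part transform, and the sign convention (log_density tracks the density of the current sample, so forward subtracts log|det J|) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Boolean mask: True marks active components, False marks passive ones.
mask = np.array([True, False, True, False])
active_idx, passive_idx = np.flatnonzero(mask), np.flatnonzero(~mask)

def embedding(passive, weights):
    # Hypothetical stand-in for embedding_net: a fixed linear map from the
    # passive components to one (log_scale, shift) pair per active component.
    log_scale, shift = jnp.split(passive @ weights, 2, axis=-1)
    return jnp.tanh(log_scale), shift  # tanh keeps the scales moderate

def coupling_forward(x, log_density, weights):
    log_scale, shift = embedding(x[..., passive_idx], weights)
    y = x.at[..., active_idx].set(x[..., active_idx] * jnp.exp(log_scale) + shift)
    # Block-triangular Jacobian: only active components contribute to log|det J|.
    # Assumed convention: log_density tracks the density of the current sample.
    return y, log_density - log_scale.sum(axis=-1)

def coupling_reverse(y, log_density, weights):
    # The passive part is unchanged, so the same parameters are recomputed.
    log_scale, shift = embedding(y[..., passive_idx], weights)
    x = y.at[..., active_idx].set((y[..., active_idx] - shift) * jnp.exp(-log_scale))
    return x, log_density + log_scale.sum(axis=-1)

weights = jax.random.normal(jax.random.PRNGKey(0), (2, 4))  # 2 passive -> 2 x 2 params
x = jax.random.normal(jax.random.PRNGKey(1), (3, 4))         # batch of 3
y, ld = coupling_forward(x, jnp.zeros(3), weights)
x_rec, ld_rec = coupling_reverse(y, ld, weights)
```

Running the reverse pass on the forward output recovers the original input and log-density, even though the embedding function is never inverted.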

Key features:
  • Flexible masking via BinaryMask with split or multiplicative modes

  • Automatic parameter management through ModuleReconstructor

  • Support for arbitrary bijections with automatic vectorization

  • Configurable event rank, e.g. for scalar (elementwise) or vector-valued bijections

  • Log-density changes summed over the event axes, with broadcasting over batch dimensions
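To illustrate what "automatic vectorization" of an event-rank-0 bijection involves, the sketch below defines a scalar bijection for a single value and its parameters, then lifts it over active components and batch with nested jax.vmap, summing the per-component log-Jacobians. The scalar map `scalar_bijection` and the use of `jax.grad` for the derivative are hypothetical choices for this example; bijx's own bijections and vectorization logic may differ.

```python
import jax
import jax.numpy as jnp

def scalar_bijection(x, params):
    # Hypothetical monotone scalar map: y = x + softplus(a) * tanh(x) + b.
    # dy/dx = 1 + softplus(a) * (1 - tanh(x)**2) > 0, so it is invertible.
    a, b = params[0], params[1]
    return x + jax.nn.softplus(a) * jnp.tanh(x) + b

def scalar_forward(x, params):
    # Event rank 0: one input value in, one log|dy/dx| value out.
    y = scalar_bijection(x, params)
    log_jac = jnp.log(jnp.abs(jax.grad(scalar_bijection)(x, params)))
    return y, log_jac

# Vectorize the scalar bijection over active components, then over the batch.
vectorized = jax.vmap(jax.vmap(scalar_forward))

active = jax.random.normal(jax.random.PRNGKey(0), (3, 2))     # (batch, n_active)
params = jax.random.normal(jax.random.PRNGKey(1), (3, 2, 2))  # (batch, n_active, n_params)
y_active, log_jac = vectorized(active, params)                # both (batch, n_active)
log_det = log_jac.sum(axis=-1)                                # (batch,)
```

A vector-valued (event rank 1) bijection would instead return one log-Jacobian per vector, so one fewer axis would need summing.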

Parameters:
  • embedding_net (Module) – Neural network that maps passive components to bijection parameters. Must output parameters compatible with bijection_reconstructor.

  • mask (BinaryMask) – Binary mask defining active/passive split pattern.

  • bijection_reconstructor (ModuleReconstructor) – Template for reconstructing parameterized bijections.

  • bijection_event_rank (int) – Event rank of the underlying bijection (0 for scalar, 1 for vector).

  • split (bool) – If True, use indexing-based masking; if False, use multiplicative masking.

Note

When using multiplicative masking (split=False), log-density changes are automatically masked so that passive components do not contribute to the density computation. In this mode the embedding network must output parameters for every component, matching the total parameter size required by the bijection reconstructor over the full input, not just for the active part.
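The difference between the two masking modes can be sketched in plain JAX. Under the assumptions of this example (a hypothetical linear "network" `w`, an affine transform, and log_density tracking the density of the current sample), multiplicative masking computes parameters for all components, applies the bijection everywhere, and then uses the mask to keep passive values and drop their log-Jacobian terms; split masking gathers passive inputs and keeps only the active parameters. Both should give the same result.

```python
import jax
import jax.numpy as jnp
import numpy as np

mask = np.array([True, False, True, False])
m = mask.astype(np.float32)  # 1.0 on active components, 0.0 on passive ones
active_idx, passive_idx = np.flatnonzero(mask), np.flatnonzero(~mask)

x = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
w = jax.random.normal(jax.random.PRNGKey(1), (4, 8))  # hypothetical linear "network"

# Multiplicative mode: the network sees the full (masked) input and returns
# parameters for ALL 4 components, not only the active ones.
params = (x * (1 - m)) @ w
log_scale, shift = params[..., :4], params[..., 4:]
y_all = x * jnp.exp(log_scale) + shift
y_mult = m * y_all + (1 - m) * x         # passive values pass through
ld_mult = -(m * log_scale).sum(axis=-1)  # passive log-Jacobian terms masked out

# Split mode: gather the passive inputs, then keep only the active parameters.
# w[passive_idx] holds the rows that multiply non-zero inputs in (x * (1 - m)) @ w.
p = x[..., passive_idx] @ w[passive_idx]
ls_a, sh_a = p[..., active_idx], p[..., 4 + active_idx]
y_split = x.at[..., active_idx].set(x[..., active_idx] * jnp.exp(ls_a) + sh_a)
ld_split = -ls_a.sum(axis=-1)
```

Multiplicative masking avoids gather/scatter operations at the cost of computing (and discarding) parameters for passive components.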

Example

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> import bijx
>>> rngs = nnx.Rngs(0)
>>>
>>> # Create mask and bijection template
>>> mask = bijx.checker_mask((4,), parity=True)
>>> spline = bijx.MonotoneRQSpline(10, (), rngs=rngs)
>>> spline_template = bijx.ModuleReconstructor(spline)
>>>
>>> # Network producing parameters for active components
>>> param_net = bijx.nn.nets.MLP(
...     in_features=mask.count_secondary,
...     out_features=mask.count_primary * spline_template.params_total_size,
...     rngs=rngs
... )
>>>
>>> # Reshape to coupling layer parameter shape
>>> param_reshape = lambda p: p.reshape(*p.shape[:-1], mask.count_primary, -1)
>>>
>>> # Create coupling layer
>>> layer = bijx.GeneralCouplingLayer(
...     nnx.Sequential(param_net, param_reshape),
...     mask, spline_template, bijection_event_rank=0,
... )
>>>
>>> batch_size = 3
>>> x = jnp.ones((batch_size, 4))
>>> y, log_det = layer.forward(x, jnp.zeros((batch_size,)))

__init__(embedding_net, mask, bijection_reconstructor, bijection_event_rank=0, split=True)
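Parameters: as listed for the class above (embedding_net, mask, bijection_reconstructor, bijection_event_rank, split).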

Methods

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms input through the bijection and updates log-density according to the change of variables formula.

For convenience, instantiating the base class directly, Bijection(), gives the identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
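As a check on the change-of-variables update, the sketch below uses a hypothetical elementwise scaling y = 2x (not part of bijx) and assumes that log_density tracks the density of the current sample, so forward subtracts log|det J|. If x is standard normal, y = 2x is normal with standard deviation 2, and the updated log-density should equal that distribution's log-pdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def scale_forward(x, log_density, scale=2.0):
    # Hypothetical elementwise bijection y = scale * x (not part of bijx).
    # Assumed convention: log_density tracks the density of the current sample,
    # so the forward pass subtracts log|det J| = n * log|scale|.
    y = scale * x
    log_abs_det = x.shape[-1] * jnp.log(jnp.abs(scale))
    return y, log_density - log_abs_det

x = jnp.array([[0.3, -1.2], [2.0, 0.5]])
log_px = norm.logpdf(x).sum(axis=-1)               # x ~ N(0, 1) per component
y, log_py = scale_forward(x, log_px)
expected = norm.logpdf(y, scale=2.0).sum(axis=-1)  # y ~ N(0, 2**2) per component
```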

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms input through the inverse bijection and updates log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
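The sign relationship between forward() and reverse() can be checked with a nonlinear elementwise bijection. The sketch below uses a hypothetical pair y = sinh(x) and x = arcsinh(y), with the same assumed convention that log_density tracks the density of the current sample. Since 1 + sinh(x)^2 = cosh(x)^2, the reverse log-Jacobian is exactly the negative of the forward one, so a round trip restores the original log-density.

```python
import jax.numpy as jnp

# Hypothetical elementwise bijection y = sinh(x) with inverse x = arcsinh(y).
def forward(x, log_density):
    log_abs_det = jnp.log(jnp.cosh(x)).sum(axis=-1)  # dy/dx = cosh(x)
    return jnp.sinh(x), log_density - log_abs_det

def reverse(y, log_density):
    # dx/dy = 1 / sqrt(1 + y**2), so log|dx/dy| = -0.5 * log(1 + y**2).
    log_abs_det = (-0.5 * jnp.log1p(y**2)).sum(axis=-1)
    return jnp.arcsinh(y), log_density - log_abs_det

x = jnp.array([[0.1, -0.7, 1.5], [2.0, 0.0, -1.1]])
ld0 = jnp.zeros(2)
y, ld_fwd = forward(x, ld0)
x_back, ld_back = reverse(y, ld_fwd)
```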