bijx.GeneralCouplingLayer
- class bijx.GeneralCouplingLayer
Bases: Bijection
General coupling layer with flexible masking and bijection support.
Implements the fundamental coupling-layer transformation: the input is split into active and passive components, and the passive part conditions the transformation applied to the active part. Because the passive part passes through unchanged, the same parameters can be recomputed from it during inversion, so the layer stays invertible while still allowing expressive, input-dependent transformations.
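To make the split-and-condition structure concrete, here is a minimal affine coupling layer in plain JAX. It is a sketch under stated assumptions, not bijx's implementation: the conditioner is a hypothetical fixed linear map standing in for embedding_net, fixed index arrays stand in for a BinaryMask, and the bijection is an elementwise affine map rather than a spline. Since the passive entries are untouched, coupling_reverse recomputes exactly the parameters that coupling_forward used.

```python
import jax
import jax.numpy as jnp

# Hypothetical conditioner (stand-in for embedding_net): maps the passive
# entries to a (shift, log_scale) pair for every active entry.
def conditioner(x_passive, weights):
    params = x_passive @ weights                  # (..., 2 * n_active)
    shift, log_scale = jnp.split(params, 2, axis=-1)
    return shift, log_scale

def coupling_forward(x, active_idx, passive_idx, weights):
    x_active, x_passive = x[..., active_idx], x[..., passive_idx]
    shift, log_scale = conditioner(x_passive, weights)
    y_active = x_active * jnp.exp(log_scale) + shift
    y = x.at[..., active_idx].set(y_active)       # passive part passes through
    return y, jnp.sum(log_scale, axis=-1)         # log|det J| per sample

def coupling_reverse(y, active_idx, passive_idx, weights):
    # Passive entries are unchanged, so the same parameters are recovered
    y_active, y_passive = y[..., active_idx], y[..., passive_idx]
    shift, log_scale = conditioner(y_passive, weights)
    x_active = (y_active - shift) * jnp.exp(-log_scale)
    x = y.at[..., active_idx].set(x_active)
    return x, -jnp.sum(log_scale, axis=-1)

active_idx = jnp.array([0, 2])                    # checkerboard-like split
passive_idx = jnp.array([1, 3])
weights = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (2, 4))
x = jax.random.normal(jax.random.PRNGKey(1), (3, 4))

y, log_det_fwd = coupling_forward(x, active_idx, passive_idx, weights)
x_rec, log_det_rev = coupling_reverse(y, active_idx, passive_idx, weights)
```

The forward and reverse log-determinants cancel, which is the property a coupling layer relies on for cheap density evaluation in both directions.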
- Key features:
Flexible masking via BinaryMask, with split (indexing-based) or multiplicative modes
Automatic parameter management through ModuleReconstructor
Support for arbitrary bijections, with automatic vectorization
Configurable event rank, for scalar (rank 0) or, e.g., vector (rank 1) bijections
Correct log-density computation, with broadcasting and summation as needed
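The two masking modes listed above should produce the same output for the same conditioner. The sketch below contrasts them in plain JAX; the boolean mask, the weight matrix W, and the tanh-bounded log-scale transform are all hypothetical stand-ins, not bijx internals. In multiplicative mode the conditioner sees the input with active entries zeroed, computes parameters for every component, and the result is kept only where the mask is 1.

```python
import numpy as np
import jax.numpy as jnp

mask = np.array([True, False, True, False])   # True = active (primary)
active_idx = np.flatnonzero(mask)
passive_idx = np.flatnonzero(~mask)
W = jnp.arange(16.0).reshape(4, 4) / 16.0     # hypothetical conditioner weights

def split_mode(x):
    # Indexing-based: gather the parts, transform only the active entries
    x_active, x_passive = x[..., active_idx], x[..., passive_idx]
    log_scale = jnp.tanh(x_passive @ W[np.ix_(passive_idx, active_idx)])
    y_active = x_active * jnp.exp(log_scale)
    return x.at[..., active_idx].set(y_active)

def multiplicative_mode(x):
    # Multiplicative: zero the active entries, compute full-shape parameters,
    # transform every entry, then keep the result only where mask is 1
    m = jnp.asarray(mask, dtype=x.dtype)
    log_scale = jnp.tanh((x * (1 - m)) @ W)
    y_all = x * jnp.exp(log_scale)
    return m * y_all + (1 - m) * x

x = jnp.linspace(-1.0, 1.0, 12).reshape(3, 4)
```

Split mode only computes the parameters it needs; multiplicative mode keeps full array shapes (which can be convenient, e.g. for convolutional conditioners) at the cost of computing and discarding parameters for passive entries.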
- Parameters:
embedding_net (Module) – Neural network that maps passive components to bijection parameters. Its output must be compatible with bijection_reconstructor.
mask (BinaryMask) – Binary mask defining the active/passive split pattern.
bijection_reconstructor (ModuleReconstructor) – Template for reconstructing parameterized bijections.
bijection_event_rank (int) – Event rank of the underlying bijection (0 for scalar, 1 for vector). Defaults to 0.
split (bool) – If True, use indexing-based masking; if False, use multiplicative masking. Defaults to True.
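With bijection_event_rank=0 the bijection acts on single scalars, so it must be mapped over every active component, each with its own parameter vector, and the per-component log-derivatives summed. A minimal sketch of that idea, assuming jax.vmap as the vectorization mechanism and a hypothetical scalar affine bijection; the parameter layout (batch, n_active, n_params) mirrors the reshape in the example below, but this is not bijx's code.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar (event rank 0) bijection: y = x * exp(a) + b,
# parameterized by a length-2 vector (a, b).
def scalar_affine(x, params):
    a, b = params[0], params[1]
    return x * jnp.exp(a) + b, a              # (y, log|dy/dx|)

def apply_rank0(x_active, params):
    # x_active: (batch, n_active); params: (batch, n_active, n_params).
    # Inner vmap maps over active components, outer vmap over the batch.
    f = jax.vmap(jax.vmap(scalar_affine))
    y, log_jac = f(x_active, params)
    # Each scalar contributes its own log-derivative: sum over the event axis
    return y, jnp.sum(log_jac, axis=-1)

key_x, key_p = jax.random.split(jax.random.PRNGKey(0))
x_active = jax.random.normal(key_x, (3, 2))          # batch of 3, 2 active entries
params = 0.1 * jax.random.normal(key_p, (3, 2, 2))   # (batch, n_active, n_params)
y, log_det = apply_rank0(x_active, params)
```

For a rank-1 (vector) bijection the per-sample log-determinant would come from the bijection itself, so no summation over components would be needed.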
Note
When using multiplicative masking (split=False), log-density changes are automatically masked so that passive components do not contribute to the density computation. In this mode the embedding network's output shape must match the total parameter size the bijection reconstructor requires for the full input, not just for the active part.
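The masking of log-density changes described in this note can be illustrated with a few lines of JAX. In this sketch (hypothetical values, elementwise affine scaling assumed), parameters exist for every component, including passive ones, as the note requires; multiplying the per-element log-derivatives by the mask before summing removes the passive contributions, whereas an unmasked sum would wrongly include them.

```python
import jax.numpy as jnp

mask = jnp.array([1.0, 0.0, 1.0, 0.0])           # 1 = active, 0 = passive
# Hypothetical full-shape parameters: one log-scale per component,
# including passive components whose values must be ignored
log_scale = jnp.array([[0.5, -2.0,  0.1, 3.0],
                       [0.0,  1.0, -0.3, 2.0]])
x = jnp.ones((2, 4))

# Transform everywhere, keep the result only on active components
y = mask * (x * jnp.exp(log_scale)) + (1 - mask) * x

# Masked log-density change: passive entries are unchanged, so they add 0
log_det = jnp.sum(mask * log_scale, axis=-1)
# An unmasked sum would also count the -2.0, 3.0, 1.0 and 2.0 above
naive = jnp.sum(log_scale, axis=-1)
```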
Example
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> import bijx
>>> rngs = nnx.Rngs(0)
>>>
>>> # Create mask and bijection template
>>> mask = bijx.checker_mask((4,), parity=True)
>>> spline = bijx.MonotoneRQSpline(10, (), rngs=rngs)
>>> spline_template = bijx.ModuleReconstructor(spline)
>>>
>>> # Network producing parameters for active components
>>> param_net = bijx.nn.nets.MLP(
...     in_features=mask.count_secondary,
...     out_features=mask.count_primary * spline_template.params_total_size,
...     rngs=rngs
... )
>>>
>>> # Reshape to coupling layer parameter shape
>>> param_reshape = lambda p: p.reshape(*p.shape[:-1], mask.count_primary, -1)
>>>
>>> # Create coupling layer
>>> layer = bijx.GeneralCouplingLayer(
...     nnx.Sequential(param_net, param_reshape),
...     mask, spline_template, bijection_event_rank=0,
... )
>>>
>>> batch_size = 3
>>> x = jnp.ones((batch_size, 4))
>>> # forward returns the transformed data and the updated log-density
>>> y, log_density = layer.forward(x, jnp.zeros((batch_size,)))
- __init__(embedding_net, mask, bijection_reconstructor, bijection_event_rank=0, split=True)
- Parameters:
embedding_net (Module)
mask (BinaryMask)
bijection_reconstructor (ModuleReconstructor)
bijection_event_rank (int)
split (bool)
Methods
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
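The relationship between these three methods can be sketched with toy classes: an inverted bijection simply swaps forward and reverse. Everything here is hypothetical (ScaleBijection and Inverted are not bijx classes), and the sign convention, where forward subtracts log|det J| from a log-density that tracks the samples, is an assumption made for illustration.

```python
import jax.numpy as jnp

class ScaleBijection:
    """Hypothetical toy bijection y = c * x on the last axis (not part of bijx)."""

    def __init__(self, c):
        self.c = c

    def forward(self, x, log_density):
        # Assumed convention: log p(y) = log p(x) - log|det J|, with
        # log|det J| = n * log|c| for n components on the last axis
        log_det = x.shape[-1] * jnp.log(jnp.abs(self.c))
        return self.c * x, log_density - log_det

    def reverse(self, y, log_density):
        log_det = y.shape[-1] * jnp.log(jnp.abs(self.c))
        return y / self.c, log_density + log_det

    def invert(self):
        return Inverted(self)

class Inverted:
    """Hypothetical wrapper that swaps forward and reverse of a bijection."""

    def __init__(self, base):
        self.base = base

    def forward(self, x, log_density):
        return self.base.reverse(x, log_density)

    def reverse(self, x, log_density):
        return self.base.forward(x, log_density)

    def invert(self):
        return self.base                      # inverting twice gives the original

b = ScaleBijection(2.0)
inv = b.invert()
x = jnp.ones((3, 4))
ld0 = jnp.zeros((3,))
y, ld = b.forward(x, ld0)
x_back, ld_back = inv.forward(y, ld)          # undoes b.forward
```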
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() on its own gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
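For a coupling layer the log absolute Jacobian determinant is cheap, because the Jacobian is block-triangular: passive outputs depend only on passive inputs (identity block), so the determinant is the product of the active diagonal entries. The sketch below checks this numerically for a hypothetical affine coupling on a single 4-component sample (W and the index arrays are illustrative assumptions, not bijx internals), comparing the summed log-scales with jnp.linalg.slogdet of the dense Jacobian from jax.jacfwd.

```python
import jax
import jax.numpy as jnp

active_idx = jnp.array([0, 2])
passive_idx = jnp.array([1, 3])
W = jnp.array([[0.3, -0.2, 0.5, 0.1],
               [-0.4, 0.2, 0.0, 0.6]])   # hypothetical conditioner weights

def coupling(x):
    # Single sample of shape (4,): affine map of the active entries,
    # with shift and log-scale computed from the passive entries
    shift, log_scale = jnp.split(x[passive_idx] @ W, 2)
    y = x.at[active_idx].set(x[active_idx] * jnp.exp(log_scale) + shift)
    return y, jnp.sum(log_scale)          # cheap log|det J|

x = jnp.array([0.5, -1.0, 2.0, 0.25])
y, log_det = coupling(x)

# Dense 4x4 Jacobian for comparison; its log|det| should equal log_det
jac = jax.jacfwd(lambda v: coupling(v)[0])(x)
sign, logabsdet = jnp.linalg.slogdet(jac)
```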
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().