bijx.BinaryMask

class bijx.BinaryMask

Bases: Bijection

Binary mask for coupling layer split/merge operations.

This class provides a masking utility that supports both multiplication-based and indexing-based masking and is registered as a JAX pytree node.

The mask can be used in two main ways:
  1. As a masking operator: use mask * array or array[mask.indices()].
  2. As a bijection: it splits the input in the forward pass and merges the parts in the reverse pass.
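
A minimal sketch of the two operator modes in plain JAX, using a hand-built checkerboard pattern. The `checkerboard` helper is hypothetical and only stands in for a bijx mask. Multiplying by the mask keeps the array shape and zeroes the secondary entries, while indexing with the primary index arrays gathers only the primary entries into a flat vector:

```python
import jax.numpy as jnp

def checkerboard(shape, parity=0):
    # Hypothetical helper (not the bijx API): True where (i + j) % 2 == parity.
    i, j = jnp.indices(shape)
    return (i + j) % 2 == parity

mask = checkerboard((4, 4), parity=0)
x = jnp.arange(16.0).reshape(4, 4)

# Multiplication: keeps the full (4, 4) shape, secondary entries become 0.
masked = x * mask

# Indexing: the primary index arrays gather only the primary entries, shape (8,).
primary_idx = jnp.nonzero(mask)
gathered = x[primary_idx]
```

The multiplicative form is convenient when the array shape must be preserved, while the indexed form drops the masked-out entries entirely.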

Type: \(\mathbb{R}^n \to \mathbb{R}^{n_1} \times \mathbb{R}^{n_2}\) (forward), with \(n_1 + n_2 = n\)
Transform: Splits the input according to the binary mask pattern.
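
A short derivation of why the log density passes through unchanged, assuming index sets \(P\) (primary) and \(S\) (secondary) that partition the \(n\) event coordinates. The forward map only reorders coordinates, so its Jacobian is a permutation matrix with determinant \(\pm 1\):

```latex
f(x) = (x_P, x_S), \qquad |P| = n_1,\quad |S| = n_2,\quad n_1 + n_2 = n, \\
\left|\det \frac{\partial f(x)}{\partial x}\right| = 1
\quad\Longrightarrow\quad
\log p_{\mathrm{out}}\bigl(f(x)\bigr)
  = \log p_{\mathrm{in}}(x) - \log\left|\det \frac{\partial f(x)}{\partial x}\right|
  = \log p_{\mathrm{in}}(x).
```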

Parameters:
  • primary_indices (tuple[ndarray, ...]) – Tuple of arrays specifying primary (True) indices.

  • event_shape (tuple[int, ...]) – Shape of the event dimensions being masked.

  • masks (tuple[Array, Array] | None) – Optional precomputed boolean mask pair (primary, secondary).

  • secondary_indices (tuple[ndarray, ...] | None) – Optional secondary (False) indices.

The mask can be flipped with ~mask or flip(), which swaps the primary and secondary regions. Use from_boolean_mask() or from_indices() to construct a mask conveniently.

Example

>>> import jax.numpy as jnp
>>> from bijx import checker_mask
>>> # Create a 4x4 checkerboard mask
>>> mask = checker_mask((4, 4), parity=0)
>>> x = jnp.ones((4, 4))
>>> primary, secondary = mask.split(x)
>>> reconstructed = mask.merge(primary, secondary)
>>> # reconstructed == x

__init__(primary_indices, event_shape, masks=None, secondary_indices=None)
Parameters:
  • primary_indices (tuple[ndarray, ...])

  • event_shape (tuple[int, ...])

  • masks (tuple[Array, Array] | None)

  • secondary_indices (tuple[ndarray, ...] | None)

Methods

flip()

Create flipped mask with primary/secondary swapped.

forward(x, log_density)

Split input as bijection forward pass.

from_boolean_mask(mask)

Create mask from boolean array.

from_indices(indices, event_shape)

Create mask from index arrays.

indices([extra_channel_dims, batch_safe, ...])

Get indexing tuple for array access.

invert()

Create an inverted version of this bijection.

reverse(x, log_density)

Merge parts as bijection reverse pass.

Attributes

boolean_mask

Primary boolean mask array.

count_primary

Number of elements in the primary (True) mask region.

count_secondary

Number of elements in the secondary (False) mask region.

counts

Tuple of (primary_count, secondary_count).

event_size

Total number of elements in the event shape.

masks

primary_indices

secondary_indices

event_shape

masks: Const
primary_indices: Const
secondary_indices: Const
event_shape: tuple[int, ...]
property count_primary

Number of elements in the primary (True) mask region.

property count_secondary

Number of elements in the secondary (False) mask region.

property counts

Tuple of (primary_count, secondary_count).

property event_size

Total number of elements in the event shape.
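
Assuming the secondary region is the complement of the primary one, the count attributes satisfy count_primary + count_secondary == event_size. A plain-JAX sketch that computes the analogous quantities directly from a boolean array (an illustration, not the bijx implementation):

```python
import jax.numpy as jnp

# Hypothetical 4x4 checkerboard primary mask (True where i + j is even).
i, j = jnp.indices((4, 4))
primary = (i + j) % 2 == 0

count_primary = int(primary.sum())
count_secondary = int((~primary).sum())  # assumes secondary = complement of primary
event_size = primary.size                # product of the event shape

counts = (count_primary, count_secondary)
```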

classmethod from_indices(indices, event_shape)

Create mask from index arrays.

Parameters:
  • indices (tuple[ndarray, ...]) – Tuple of index arrays specifying primary mask positions.

  • event_shape (tuple[int, ...]) – Shape of the event dimensions being masked.

Returns:

New BinaryMask instance.
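
The index tuple passed to from_indices has one integer array per event axis, the same layout numpy.nonzero produces. The sketch below shows in plain JAX how such a tuple corresponds to a boolean pattern; the specific indices are made up for illustration, and this is not necessarily how bijx builds its masks internally:

```python
import jax.numpy as jnp

event_shape = (4, 4)
# Hypothetical primary positions: one integer array per event axis.
rows = jnp.array([0, 0, 1, 3])
cols = jnp.array([0, 2, 1, 3])
indices = (rows, cols)

# Scatter True into an all-False array to get the equivalent boolean mask.
primary_mask = jnp.zeros(event_shape, dtype=bool).at[indices].set(True)
secondary_mask = ~primary_mask
```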

classmethod from_boolean_mask(mask)

Create mask from boolean array.

Parameters:

mask (Array) – Boolean array where True indicates primary mask positions.

Returns:

New BinaryMask instance with the same pattern as the input mask.
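
Going from a boolean pattern to primary and secondary index tuples amounts to taking the nonzero positions of the mask and of its negation. A NumPy sketch of that correspondence, assuming row-major index order as numpy.nonzero returns it (not a claim about bijx internals):

```python
import numpy as np

# Hypothetical 2x3 boolean pattern; True marks primary positions.
mask = np.array([[True, False, True],
                 [False, True, False]])

primary_indices = np.nonzero(mask)     # (array([0, 0, 1]), array([0, 2, 1]))
secondary_indices = np.nonzero(~mask)  # (array([0, 1, 1]), array([1, 0, 2]))
```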

property boolean_mask

Primary boolean mask array.

indices(extra_channel_dims=0, batch_safe=True, primary=True)

Get indexing tuple for array access.

Parameters:
  • extra_channel_dims (int) – Number of trailing channel dimensions to preserve.

  • batch_safe (bool) – If True, include an Ellipsis so that leading batch dimensions are passed through.

  • primary (bool) – If True, return primary indices; otherwise secondary.

Returns:

Indexing tuple suitable for array subscripting.
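
The parameters suggest an indexing tuple laid out as an optional leading Ellipsis for batch axes, one index array per event axis, and one full slice per trailing channel axis. The index_tuple helper below is a hypothetical reconstruction of that layout, assuming batch axes lead and channel axes trail the event axes; the tuple bijx actually returns may differ in detail:

```python
import numpy as np
import jax.numpy as jnp

def index_tuple(idx, extra_channel_dims=0, batch_safe=True):
    """Hypothetical stand-in for BinaryMask.indices(), not the bijx code."""
    prefix = (Ellipsis,) if batch_safe else ()
    # One full slice per trailing channel axis keeps those axes intact.
    suffix = (slice(None),) * extra_channel_dims
    return prefix + tuple(idx) + suffix

mask = np.array([[True, False], [False, True]])
primary_idx = np.nonzero(mask)

# Batch of 3 events of shape (2, 2), each with 5 channels.
x = jnp.zeros((3, 2, 2, 5))
picked = x[index_tuple(primary_idx, extra_channel_dims=1)]  # shape (3, 2, 5)
```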

flip()

Create flipped mask with primary/secondary swapped.

Returns:

New BinaryMask with primary and secondary regions swapped.
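
Flipping leaves the partition itself unchanged and only swaps which half is called primary; in a coupling flow, alternating a mask with its flip lets each half be transformed in turn. A minimal sketch on a pair of boolean arrays, where the flip function is a stand-in and not the bijx method:

```python
import numpy as np

# Hypothetical (primary, secondary) mask pair for a 1D event of size 4.
primary = np.array([True, False, True, False])
secondary = ~primary

def flip(masks):
    """Swap the roles of the two regions, mirroring what flip() / ~mask describe."""
    p, s = masks
    return s, p

flipped = flip((primary, secondary))
x = np.array([10.0, 11.0, 12.0, 13.0])
```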

forward(x, log_density)

Split input as bijection forward pass.

When used as a bijection, the forward pass splits the input into primary and secondary parts according to the mask.

Parameters:
  • x – Input array to split.

  • log_density – Input log density, returned unchanged.

Returns:

Tuple of ((primary_part, secondary_part), unchanged_log_density).
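
A sketch of what a split forward pass looks like, assuming leading batch axes and primary/secondary parts flattened over the event axes (shape batch + (n1,) and batch + (n2,)). The forward function and the diagonal mask here are illustrative stand-ins, not the bijx implementation:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical 2x2 event mask: primary on the diagonal.
mask = np.array([[True, False], [False, True]])
primary_idx, secondary_idx = np.nonzero(mask), np.nonzero(~mask)

def forward(x, log_density):
    """Gather each region; the log density is returned as is."""
    primary = x[(Ellipsis, *primary_idx)]
    secondary = x[(Ellipsis, *secondary_idx)]
    # Splitting only permutes coordinates, so the log-Jacobian is zero.
    return (primary, secondary), log_density

x = jnp.arange(8.0).reshape(2, 2, 2)  # batch of 2 events
ld = jnp.zeros(2)
(p, s), ld_out = forward(x, ld)
```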

reverse(x, log_density)

Merge parts as bijection reverse pass.

When used as a bijection, the reverse pass merges the split parts back into an array with the original event shape.

Parameters:
  • x – Tuple of (primary_part, secondary_part) to merge.

  • log_density – Input log density, returned unchanged.

Returns:

Tuple of (merged_array, unchanged_log_density).
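
The matching merge can be sketched as a scatter of each part back to its positions, under the same assumptions (leading batch axes, parts flattened over the event axes). The reverse function and the diagonal mask are illustrative stand-ins, not the bijx implementation:

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical 2x2 event mask: primary on the diagonal.
mask = np.array([[True, False], [False, True]])
primary_idx, secondary_idx = np.nonzero(mask), np.nonzero(~mask)
event_shape = mask.shape

def reverse(parts, log_density):
    """Scatter each part into its region of a fresh array; log density unchanged."""
    primary, secondary = parts
    batch_shape = primary.shape[:-1]
    out = jnp.zeros(batch_shape + event_shape, dtype=primary.dtype)
    out = out.at[(Ellipsis, *primary_idx)].set(primary)
    out = out.at[(Ellipsis, *secondary_idx)].set(secondary)
    return out, log_density

primary = jnp.array([[0.0, 3.0], [4.0, 7.0]])
secondary = jnp.array([[1.0, 2.0], [5.0, 6.0]])
merged, ld = reverse((primary, secondary), jnp.zeros(2))
```

Because every event position belongs to exactly one region, the two scatters together overwrite the whole zero-initialized array, so the merge exactly inverts the split.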