bijx.BinaryMask
- class bijx.BinaryMask
Bases: Bijection
Binary mask for coupling layer split/merge operations.
This class provides a flexible masking utility that supports both multiplication-based and indexing-based masking and is registered as a JAX pytree node.
The mask can be used in two main ways:
1. As a masking operator: use mask * array or array[mask.indices()].
2. As a bijection: splits the input in the forward pass and merges the parts in the reverse pass.
Type: \(\mathbb{R}^n \to \mathbb{R}^{n_1} \times \mathbb{R}^{n_2}\) (forward), with \(n_1 + n_2 = n\).
Transform: splits the input according to the binary mask pattern.
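To make the two masking styles concrete, here is a minimal NumPy sketch. It assumes a hand-built checkerboard in place of bijx's own constructors; the same expressions work unchanged with jax.numpy arrays.

```python
import numpy as np

# Illustrative 4x4 checkerboard built by hand (not a bijx constructor):
# True where row + col is even.
rows, cols = np.indices((4, 4))
mask = (rows + cols) % 2 == 0

x = np.arange(16.0).reshape(4, 4)

# Multiplication-style masking keeps the full event shape and zeroes
# the secondary (False) positions.
masked = x * mask

# Indexing-style masking gathers only the primary positions, in
# row-major order, into a vector of length mask.sum().
primary_idx = np.nonzero(mask)  # tuple with one index array per axis
gathered = x[primary_idx]

print(masked.shape, gathered.shape)  # (4, 4) (8,)
```

Multiplication is convenient when a layer must keep the full event shape. Indexing yields a compact vector that can be fed to a network expecting n_1 inputs.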
- Parameters:
primary_indices (tuple[ndarray, ...]) – Tuple of arrays specifying primary (True) indices.
event_shape (tuple[int, ...]) – Shape of the event dimensions being masked.
masks (tuple[Array, Array] | None) – Optional precomputed boolean mask pair (primary, secondary).
secondary_indices (tuple[ndarray, ...] | None) – Optional secondary (False) indices.
The mask can be flipped with ~mask or flip() to swap primary/secondary. Use from_boolean_mask() or from_indices() for convenient construction.
Example
>>> import jax.numpy as jnp
>>> # Create a checkerboard mask
>>> mask = checker_mask((4, 4), parity=0)
>>> x = jnp.ones((4, 4))
>>> primary, secondary = mask.split(x)
>>> reconstructed = mask.merge(primary, secondary)
>>> # reconstructed == x
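The split/merge round trip in the example can be imitated in plain NumPy, including a leading batch axis. In this sketch, checker, split and merge are hypothetical stand-ins, not bijx functions. In particular, the parity convention of checker is an assumption and may differ from checker_mask.

```python
import numpy as np

def checker(shape, parity=0):
    """Hypothetical stand-in for a checkerboard mask: True where the
    coordinate sum has the given parity (may differ from checker_mask)."""
    return np.indices(shape).sum(axis=0) % 2 == parity

def split(x, mask):
    # Gather primary and secondary entries over the trailing event axes;
    # the leading Ellipsis leaves any batch axes untouched.
    return x[(..., *np.nonzero(mask))], x[(..., *np.nonzero(~mask))]

def merge(primary, secondary, mask):
    # Scatter both parts back into an array with the full event shape.
    out = np.empty(primary.shape[:-1] + mask.shape, dtype=primary.dtype)
    out[(..., *np.nonzero(mask))] = primary
    out[(..., *np.nonzero(~mask))] = secondary
    return out

mask = checker((4, 4), parity=0)
x = np.random.default_rng(0).normal(size=(3, 4, 4))  # batch of 3 events
primary, secondary = split(x, mask)
print(primary.shape, secondary.shape)  # (3, 8) (3, 8)
assert np.array_equal(merge(primary, secondary, mask), x)
```

With jax.numpy, arrays are immutable, so the two scatter assignments in merge would be written with out.at[...].set(...) instead.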
- __init__(primary_indices, event_shape, masks=None, secondary_indices=None)
- Parameters:
primary_indices (tuple[ndarray, ...])
event_shape (tuple[int, ...])
masks (tuple[Array, Array] | None)
secondary_indices (tuple[ndarray, ...] | None)
Methods
flip() – Create flipped mask with primary/secondary swapped.
forward(x, log_density) – Split input as bijection forward pass.
from_boolean_mask(mask) – Create mask from boolean array.
from_indices(indices, event_shape) – Create mask from index arrays.
indices([extra_channel_dims, batch_safe, ...]) – Get indexing tuple for array access.
invert() – Create an inverted version of this bijection.
reverse(x, log_density) – Merge parts as bijection reverse pass.
Attributes
boolean_mask – Primary boolean mask array.
count_primary – Number of elements in the primary (True) mask region.
count_secondary – Number of elements in the secondary (False) mask region.
counts – Tuple of (primary_count, secondary_count).
event_size – Total number of elements in the event shape.
- event_shape: tuple[int, ...]
- property count_primary
Number of elements in the primary (True) mask region.
- property count_secondary
Number of elements in the secondary (False) mask region.
- property counts
Tuple of (primary_count, secondary_count).
- property event_size
Total number of elements in the event shape.
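The following sketch shows how the count properties relate to one another. It computes them directly from an illustrative boolean mask in NumPy rather than through bijx. It relies on the primary and secondary regions together covering the whole event.

```python
import math
import numpy as np

event_shape = (4, 4)
# Illustrative checkerboard; any boolean array of this shape would do.
mask = np.indices(event_shape).sum(axis=0) % 2 == 0

count_primary = int(mask.sum())       # number of True entries
count_secondary = int((~mask).sum())  # number of False entries
counts = (count_primary, count_secondary)
event_size = math.prod(event_shape)   # every entry in one event

# Primary and secondary regions partition the event.
assert sum(counts) == event_size
print(counts, event_size)  # (8, 8) 16
```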
- classmethod from_indices(indices, event_shape)
Create mask from index arrays.
- Parameters:
indices (tuple[ndarray, ...]) – Tuple of index arrays specifying primary mask positions.
event_shape (tuple[int, ...]) – Shape of the event dimensions being masked.
- Returns:
New BinaryMask instance.
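The index arrays passed to from_indices() follow the tuple-of-arrays convention: one array per event axis. The sketch below shows how such a tuple maps back onto a boolean pattern. It uses a hypothetical helper, boolean_from_indices, in plain NumPy; this is an assumption about the semantics, not the bijx internals.

```python
import numpy as np

def boolean_from_indices(indices, event_shape):
    """Hypothetical helper: rebuild a primary boolean mask from a tuple
    of index arrays, one array per event axis."""
    mask = np.zeros(event_shape, dtype=bool)
    mask[indices] = True
    return mask

# Primary positions (0, 0), (1, 2) and (3, 3) of a 4x4 event.
indices = (np.array([0, 1, 3]), np.array([0, 2, 3]))
mask = boolean_from_indices(indices, (4, 4))
print(int(mask.sum()))  # 3
```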
- classmethod from_boolean_mask(mask)
Create mask from boolean array.
- Parameters:
mask (Array) – Boolean array where True indicates primary mask positions.
- Returns:
New BinaryMask instance with the same pattern as the input mask.
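A plausible reading of from_boolean_mask() is that it derives primary and secondary index tuples from the array. Here is a NumPy sketch of that derivation. It is an assumption about what happens internally, not bijx code.

```python
import numpy as np

mask = np.array([[True, False],
                 [False, True]])

# np.nonzero returns one index array per axis, in row-major order:
# the tuple-of-arrays layout used for primary_indices/secondary_indices.
primary_indices = np.nonzero(mask)
secondary_indices = np.nonzero(~mask)

print([a.tolist() for a in primary_indices])    # [[0, 1], [0, 1]]
print([a.tolist() for a in secondary_indices])  # [[0, 1], [1, 0]]
```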
- property boolean_mask
Primary boolean mask array.
- indices(extra_channel_dims=0, batch_safe=True, primary=True)
Get indexing tuple for array access.
- Parameters:
extra_channel_dims (int) – Number of trailing channel dimensions to preserve.
batch_safe (bool) – If True, include ellipsis for batch dimensions.
primary (bool) – If True, return primary indices; otherwise secondary.
- Returns:
Indexing tuple suitable for array subscripting.
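The sketch below shows how the three parameters could shape the returned tuple. It uses a hypothetical function, indexing_tuple, written in NumPy from the parameter descriptions alone; the real method may differ in detail. A leading Ellipsis absorbs batch axes, and one full slice per channel axis keeps trailing channels intact.

```python
import numpy as np

def indexing_tuple(mask, extra_channel_dims=0, batch_safe=True, primary=True):
    """Hypothetical re-creation of the documented indices() behavior;
    not the bijx implementation."""
    idx = np.nonzero(mask if primary else ~mask)
    head = (Ellipsis,) if batch_safe else ()
    # One full slice per trailing channel axis keeps those axes intact.
    tail = (slice(None),) * extra_channel_dims
    return head + idx + tail

mask = np.indices((4, 4)).sum(axis=0) % 2 == 0
x = np.zeros((5, 4, 4, 2))  # (batch, height, width, channels)

print(x[indexing_tuple(mask, extra_channel_dims=1)].shape)  # (5, 8, 2)
print(x[0][indexing_tuple(mask, 1, batch_safe=False)].shape)  # (8, 2)
```

Without the Ellipsis, the index arrays would apply to the leading axes. That is only correct when the array has no batch dimensions.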
- flip()
Create flipped mask with primary/secondary swapped.
- Returns:
New BinaryMask with primary and secondary regions swapped.
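Coupling layers commonly alternate a mask with its flip so that every entry is transformed somewhere in the stack. This NumPy sketch uses boolean negation as a stand-in for flip() on an illustrative checkerboard.

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)
mask = np.indices(x.shape).sum(axis=0) % 2 == 0  # illustrative checkerboard
flipped = ~mask  # NumPy analogue of flipping: swap True and False

print(x[mask], x[flipped])  # [0. 2. 4.] [1. 3. 5.]

# Across a mask and its flip, every entry is primary exactly once,
# so alternating the two in coupling layers covers the whole event.
assert np.all(mask ^ flipped)
# Flipping twice restores the original pattern.
assert np.array_equal(~flipped, mask)
```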
- forward(x, log_density)
Split input as bijection forward pass.
When used as a bijection, the forward pass splits the input into primary and secondary parts according to the mask.
- Parameters:
x – Input array to split.
log_density – Input log density (unchanged).
- Returns:
Tuple of ((primary_part, secondary_part), unchanged_log_density).
- reverse(x, log_density)
Merge parts as bijection reverse pass.
When used as a bijection, the reverse pass merges the split parts back into the original array structure.
- Parameters:
x – Tuple of (primary_part, secondary_part) to merge.
log_density – Input log density (unchanged).
- Returns:
Tuple of (merged_array, unchanged_log_density).
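Both forward and reverse return log_density unchanged. One way to see why: splitting only reorders coordinates, so its Jacobian is a permutation matrix with \(|\det J| = 1\) and \(\log|\det J| = 0\). The small NumPy check below illustrates that reasoning on an illustrative 2x2 mask; it is not bijx code.

```python
import numpy as np

mask = np.indices((2, 2)).sum(axis=0) % 2 == 0  # illustrative 2x2 mask
flat = mask.ravel()

# Input coordinate that lands at each position of the split output
# (primary part first, then secondary part, both in row-major order).
perm = np.concatenate([np.flatnonzero(flat), np.flatnonzero(~flat)])
J = np.eye(flat.size)[perm]  # Jacobian of x.ravel() -> concat(primary, secondary)

x = np.arange(4.0).reshape(2, 2)
assert np.array_equal(J @ x.ravel(), np.concatenate([x[mask], x[~mask]]))

# A permutation matrix has |det J| = 1, so log|det J| = 0 and the
# log density passes through unchanged.
print(perm)  # [0 3 1 2]
assert np.isclose(abs(np.linalg.det(J)), 1.0)
```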