bijx.Reshape
- class bijx.Reshape
Bases: MetaLayer
Reshape tensor event dimensions while preserving batch dimensions.
Reshapes the event portion of tensors from one shape to another, preserving all batch dimensions. The total number of event elements must be unchanged, i.e. the product of from_shape must equal the product of to_shape.
Type: Batch + from_shape → Batch + to_shape
Transform: Reshape event dimensions only
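A minimal NumPy sketch of this mapping, assuming the event dimensions are the trailing len(from_shape) axes and everything before them is batch. `reshape_event` is a hypothetical helper for illustration, not part of bijx; jax.numpy arrays accept the same `reshape` call.

```python
import numpy as np


def reshape_event(x, log_density, from_shape, to_shape):
    """Hypothetical helper: Batch + from_shape -> Batch + to_shape.

    Illustrates the idea only; it is not the bijx implementation.
    """
    n_event = len(from_shape)
    # Leading axes (everything except the event axes) are batch and kept as-is.
    batch_shape = x.shape[: x.ndim - n_event]
    y = x.reshape(batch_shape + tuple(to_shape))
    # A reshape only re-indexes elements, so the log density passes through.
    return y, log_density
```

Calling it with the shapes swapped undoes the reshape, which is the bidirectional behaviour the class provides through forward and reverse.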
- Key features:
Batch dimensions are automatically preserved
Event shape compatibility is validated
Bidirectional reshaping: both shapes are stored, so reverse restores from_shape
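One plausible form of the compatibility check, written as a hypothetical `check_reshape` helper. The exact checks bijx performs, and whether they run at construction or at call time, are assumptions here.

```python
import math


def check_reshape(x_shape, from_shape, to_shape):
    """Hypothetical validation; returns the batch shape or raises ValueError."""
    from_shape, to_shape = tuple(from_shape), tuple(to_shape)
    # A reshape is only a bijection if the element count is preserved.
    if math.prod(from_shape) != math.prod(to_shape):
        raise ValueError(
            f"cannot reshape event {from_shape} ({math.prod(from_shape)} elements) "
            f"to {to_shape} ({math.prod(to_shape)} elements)"
        )
    x_shape = tuple(x_shape)
    n_event = len(from_shape)
    # The trailing axes of the input must match from_shape exactly.
    if n_event > len(x_shape) or x_shape[len(x_shape) - n_event:] != from_shape:
        raise ValueError(f"input shape {x_shape} does not end with {from_shape}")
    return x_shape[: len(x_shape) - n_event]
```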
Example
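>>> import jax.numpy as jnp
>>> from bijx import Reshape
>>> reshape = Reshape(from_shape=(4, 4), to_shape=(16,))
>>> x = jnp.ones((3, 4, 4))  # batch size 3, event shape (4, 4)
>>> log_density = jnp.zeros(3)  # one log-density value per batch element
>>> y, log_density_out = reshape.forward(x, log_density)
>>> # y has shape (3, 16); the log density is returned unchanged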
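Why the log density is unchanged follows from the standard change-of-variables formula. Viewed on the flattened event vector of length n, a row-major reshape leaves every element in place, so the map is the identity and its Jacobian is the n × n identity matrix (any other element ordering would be a permutation, whose determinant also has absolute value 1):

```latex
\log p_Y(y) = \log p_X(x) - \log\left|\det \frac{\partial y}{\partial x}\right|
            = \log p_X(x) - \log\left|\det I_n\right|
            = \log p_X(x),
\qquad n = \prod_i \mathrm{from\_shape}_i = \prod_j \mathrm{to\_shape}_j
```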
- Parameters:
from_shape (tuple[int, ...]) – Original event shape to reshape from.
to_shape (tuple[int, ...]) – Target event shape to reshape to.
rngs – Included for compatibility; not used.
- __init__(from_shape, to_shape, *, rngs=None)
- Parameters:
from_shape (tuple[int, ...])
to_shape (tuple[int, ...])
Methods
forward(x, log_density) – Apply the forward transformation (Batch + from_shape → Batch + to_shape).
invert() – Create an inverted version of this bijection.
reverse(x, log_density) – Apply the reverse (inverse) transformation (Batch + to_shape → Batch + from_shape).
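The three methods are tied together: reverse undoes forward, and an inverted layer runs the mapping in the opposite direction. The toy class below (hypothetical `ToyReshape`, NumPy-based) illustrates that contract. Its `invert` simply swaps the two shapes, which is an assumption and may differ from how bijx constructs an inverted bijection.

```python
import numpy as np


class ToyReshape:
    """Hypothetical stand-in mimicking the forward/reverse/invert interface."""

    def __init__(self, from_shape, to_shape):
        self.from_shape = tuple(from_shape)
        self.to_shape = tuple(to_shape)

    @staticmethod
    def _reshape_event(x, old, new):
        # Keep the leading batch axes, replace the trailing event axes.
        batch_shape = x.shape[: x.ndim - len(old)]
        return x.reshape(batch_shape + new)

    def forward(self, x, log_density):
        # Batch + from_shape -> Batch + to_shape; log density untouched.
        return self._reshape_event(x, self.from_shape, self.to_shape), log_density

    def reverse(self, x, log_density):
        # Batch + to_shape -> Batch + from_shape; log density untouched.
        return self._reshape_event(x, self.to_shape, self.from_shape), log_density

    def invert(self):
        # In this toy, inverting just exchanges the two shapes.
        return ToyReshape(self.to_shape, self.from_shape)
```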