bijx.Reshape

class bijx.Reshape

Bases: MetaLayer

Reshape tensor event dimensions while preserving batch dimensions.

Reshapes the event portion of tensors from from_shape to to_shape, leaving all batch dimensions untouched. from_shape and to_shape must contain the same total number of elements.
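A minimal sketch of the event-only reshape in plain JAX. It assumes the event axes are the trailing len(from_shape) axes of the input and that the row-major (C) order used by default in jnp.reshape is what the layer relies on. reshape_event is a hypothetical helper for illustration, not part of bijx.

```python
import math

import jax.numpy as jnp


def reshape_event(x, from_shape, to_shape):
    """Reshape only the trailing event axes of x (hypothetical helper).

    Assumes the event axes are the last len(from_shape) axes of x; all
    leading axes are treated as batch axes and left untouched.
    """
    from_shape, to_shape = tuple(from_shape), tuple(to_shape)
    # The total number of event elements must match on both sides.
    if math.prod(from_shape) != math.prod(to_shape):
        raise ValueError(f"cannot reshape event {from_shape} into {to_shape}")
    n_event = len(from_shape)
    event_shape = x.shape[x.ndim - n_event:]
    if event_shape != from_shape:
        raise ValueError(f"expected event shape {from_shape}, got {event_shape}")
    batch_shape = x.shape[: x.ndim - n_event]
    # jnp.reshape uses row-major (C) order by default, so the element order
    # inside each event is preserved; only the axis layout changes.
    return jnp.reshape(x, batch_shape + to_shape)
```

For example, with x of shape (2, 3, 4, 4), reshape_event(x, (4, 4), (16,)) has shape (2, 3, 16) and element [b0, b1, 4 * i + j] equals x[b0, b1, i, j]. Any number of leading batch axes works because only the trailing axes are inspected.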

Type: Batch + from_shape → Batch + to_shape
Transform: Reshape event dimensions only
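Why the log density passes through unchanged: if both shapes are flattened in the same (e.g. row-major) order, a reshape R neither moves nor rescales any element, so its Jacobian on the flattened event is the identity. A short derivation, writing N for the common number of event elements and using the standard change-of-variables formula for y = R(x):

```latex
\begin{aligned}
y &= R(x), \qquad \operatorname{vec}(y) = \operatorname{vec}(x) \in \mathbb{R}^{N},\\
J &= \frac{\partial \operatorname{vec}(y)}{\partial \operatorname{vec}(x)} = I_N,\\
\log\lvert\det J\rvert &= \log 1 = 0,\\
\log p_Y(y) &= \log p_X(x) - \log\lvert\det J\rvert = \log p_X(x).
\end{aligned}
```

Whichever sign convention a layer uses for the density update, the correction term is zero.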

Key features:
  • Batch dimensions are automatically preserved

  • Event shape compatibility is validated

  • Bidirectional: both shapes are stored, so reverse maps to_shape back to from_shape

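Example

>>> import jax.numpy as jnp
>>> from bijx import Reshape
>>> reshape = Reshape(from_shape=(4, 4), to_shape=(16,))
>>> x = jnp.ones((3, 4, 4))  # batch size 3, event shape (4, 4)
>>> log_density = jnp.zeros(3)  # assumed: one value per batch element
>>> y, log_density_out = reshape.forward(x, log_density)
>>> y.shape
(3, 16)
>>> # log_density_out equals log_density: a reshape has log|det J| = 0

Parameters: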
  • from_shape (tuple[int, ...]) – Original event shape to reshape from.

  • to_shape (tuple[int, ...]) – Target event shape to reshape to.

  • rngs – Included for compatibility; not used.

__init__(from_shape, to_shape, *, rngs=None)
Parameters:
  • from_shape (tuple[int, ...])

  • to_shape (tuple[int, ...])

Methods

  • forward(x, log_density) – Apply the forward transformation.

  • invert() – Create an inverted version of this bijection.

  • reverse(x, log_density) – Apply the reverse (inverse) transformation.
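To show how the three methods relate, here is a minimal stand-in class, assuming forward and reverse return (output, log_density) as in the example above and that invert simply swaps the stored shapes. ReshapeSketch is hypothetical and is not bijx's implementation (the real class derives from MetaLayer and also accepts rngs).

```python
import math

import jax.numpy as jnp


class ReshapeSketch:
    """Hypothetical stand-in mirroring Reshape's methods (not bijx code)."""

    def __init__(self, from_shape, to_shape):
        self.from_shape = tuple(from_shape)
        self.to_shape = tuple(to_shape)
        # Validate once, up front: both event shapes must hold the same
        # number of elements.
        if math.prod(self.from_shape) != math.prod(self.to_shape):
            raise ValueError(
                f"incompatible event shapes {self.from_shape} -> {self.to_shape}"
            )

    @staticmethod
    def _reshape_event(x, old, new):
        n_event = len(old)
        if x.shape[x.ndim - n_event:] != old:
            raise ValueError(f"expected trailing event shape {old}, got {x.shape}")
        # Leading axes are batch axes and are kept as-is.
        return jnp.reshape(x, x.shape[: x.ndim - n_event] + new)

    def forward(self, x, log_density):
        # Unit Jacobian, so the log density is returned unchanged.
        return self._reshape_event(x, self.from_shape, self.to_shape), log_density

    def reverse(self, x, log_density):
        return self._reshape_event(x, self.to_shape, self.from_shape), log_density

    def invert(self):
        # Swapping the stored shapes turns forward into reverse and vice versa.
        return ReshapeSketch(self.to_shape, self.from_shape)
```

Storing both shapes is what makes the layer bidirectional: reverse needs no information beyond what was given at construction, and the compatibility check runs once rather than on every call.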