bijx.ExpandDims

class bijx.ExpandDims

Bases: MetaLayer

Expand tensor dimensions along specified axis.

Adds a singleton (size-1) dimension at the specified axis position. The log-density is passed through unchanged.

Example

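>>> import jax.numpy as jnp
>>> from bijx import ExpandDims
>>> expand = ExpandDims(axis=-1)
>>> x = jnp.array([[1, 2], [3, 4]])  # Shape (2, 2)
>>> log_density = jnp.zeros(2)  # placeholder values; passed through unchanged
>>> y, log_det = expand.forward(x, log_density)
>>> # y has shape (2, 2, 1); log_det equals log_density (unchanged)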

Parameters:
  • axis (int) – Axis along which to expand dimensions.

  • rngs – Included for compatibility; not used.
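To illustrate how the axis argument determines where the new dimension lands, here is a minimal sketch using plain jnp.expand_dims. It assumes the layer's shape change matches that function, which this page does not state explicitly.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)

# jnp.expand_dims is used as a stand-in for the layer's shape change
# (an assumption, not confirmed by the bijx documentation).
last = jnp.expand_dims(x, axis=-1)   # new axis appended at the end
first = jnp.expand_dims(x, axis=0)   # new axis inserted at the front
middle = jnp.expand_dims(x, axis=1)  # new axis between the existing two

print(last.shape)    # (2, 2, 1)
print(first.shape)   # (1, 2, 2)
print(middle.shape)  # (2, 1, 2)
```

Negative values count from the end of the output shape, so axis=-1 always appends a trailing dimension.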

__init__(axis=-1, *, rngs=None)
Parameters:

axis (int)

Methods

forward(x, log_density)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density)

Apply reverse (inverse) transformation.
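
The forward/reverse pair can be sketched in plain JAX as follows. The functions expand_forward and expand_reverse are hypothetical stand-ins, not part of bijx, and the sketch assumes that reverse simply removes the singleton axis that forward added. The log-density is returned untouched in both directions, since the transformation only changes shape.

```python
import jax.numpy as jnp


def expand_forward(x, log_density, axis=-1):
    # Hypothetical stand-in for ExpandDims.forward:
    # insert a size-1 axis and pass log_density through unchanged.
    return jnp.expand_dims(x, axis=axis), log_density


def expand_reverse(y, log_density, axis=-1):
    # Hypothetical stand-in for ExpandDims.reverse (assumed behaviour):
    # drop the size-1 axis again, log_density unchanged.
    return jnp.squeeze(y, axis=axis), log_density


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
log_density = jnp.zeros(2)               # placeholder values

y, ld = expand_forward(x, log_density)   # y has shape (2, 2, 1)
x_back, ld_back = expand_reverse(y, ld)  # x_back has shape (2, 2)
```

Applying reverse after forward recovers the original array, which is the round-trip property expected of a bijection.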