bijx.ExpandDims
- class bijx.ExpandDims
Bases: MetaLayer
Expand tensor dimensions along the specified axis.
Inserts a singleton (size-1) dimension at position `axis`. Because this only reshapes the input, the log-density is passed through unchanged.
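The shape bookkeeping can be illustrated with plain JAX. The following is a minimal sketch, not bijx's implementation. It assumes that `forward` behaves like `jnp.expand_dims` and `reverse` like `jnp.squeeze` on the same axis, with the log-density returned untouched. The helper names `expand_dims_forward` and `expand_dims_reverse` are hypothetical.

```python
import jax.numpy as jnp


def expand_dims_forward(x, log_density, axis=-1):
    """Hypothetical stand-in for ExpandDims.forward: insert a size-1 axis.

    Assumption: only the shape changes, so the log-density is returned as-is.
    """
    return jnp.expand_dims(x, axis), log_density


def expand_dims_reverse(y, log_density, axis=-1):
    """Hypothetical stand-in for ExpandDims.reverse: drop the size-1 axis."""
    return jnp.squeeze(y, axis=axis), log_density


x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # shape (2, 2)
log_density = jnp.zeros(2)  # illustrative values only

y, ld_fwd = expand_dims_forward(x, log_density)
x_back, ld_rev = expand_dims_reverse(y, ld_fwd)

print(y.shape)  # (2, 2, 1)
print(bool(jnp.array_equal(x, x_back)))  # True
```

Squeezing the same `axis` that was expanded undoes the reshape exactly. This round trip is what makes the layer a bijection with zero log-determinant.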
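Example
>>> import jax.numpy as jnp
>>> from bijx import ExpandDims
>>> expand = ExpandDims(axis=-1)
>>> x = jnp.array([[1, 2], [3, 4]])  # shape (2, 2)
>>> log_density = jnp.zeros(())  # illustrative value; passed through unchanged
>>> y, log_density_out = expand.forward(x, log_density)
>>> y.shape
(2, 2, 1)
>>> # log_density_out equals log_density (no change in volume)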
- Parameters:
  - axis (int) – Position at which the new singleton dimension is inserted.
  - rngs – Included for compatibility; not used.
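A possible reading of `axis` is sketched below. It assumes that `ExpandDims` interprets `axis` the same way `jnp.expand_dims` does, i.e. as the index of the new axis in the *output* array, so negative values count from the end of the expanded shape. The sketch tabulates the resulting shapes for a `(2, 3)` input using plain JAX only.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 3))  # example input with shape (2, 3)

# Assumption: `axis` follows jnp.expand_dims semantics, i.e. it indexes the
# position of the new size-1 axis in the output array (ndim + 1 dimensions).
shapes = {axis: jnp.expand_dims(x, axis).shape for axis in (0, 1, 2, -1, -2, -3)}

for axis, shape in shapes.items():
    print(axis, shape)
# 0 (1, 2, 3)
# 1 (2, 1, 3)
# 2 (2, 3, 1)
# -1 (2, 3, 1)
# -2 (2, 1, 3)
# -3 (1, 2, 3)
```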
Methods
- forward(x, log_density) – Apply forward transformation.
- invert() – Create an inverted version of this bijection.
- reverse(x, log_density) – Apply reverse (inverse) transformation.