bijx.SqueezeDims
- class bijx.SqueezeDims
Bases: MetaLayer
Remove a singleton dimension at the specified axis.
Removes the dimension of size 1 at the specified axis position. This is the inverse operation of ExpandDims.
Type: Shape(…, d_k, …, 1, …) → Shape(…, d_k, …)
Transform: Remove the dimension of size 1 at the specified axis; the log-density is passed through unchanged.
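The Type/Transform description above amounts to a shape-only bijection. Below is a minimal sketch of that pair of maps, not bijx's implementation: it assumes the layer behaves like `np.squeeze` / `np.expand_dims` (NumPy stands in for `jax.numpy`, which provides the same functions), and the names `squeeze_forward` and `squeeze_reverse` are hypothetical. Because squeezing only relabels coordinates, the Jacobian determinant is 1, so the log-density is returned unchanged.

```python
import numpy as np

# Hypothetical stand-ins for SqueezeDims.forward / .reverse (not bijx code).
# Assumption: the layer acts like np.squeeze / np.expand_dims on `axis`.

def squeeze_forward(x, log_density, axis=-1):
    """Shape(..., 1, ...) -> Shape(...): drop the size-1 axis."""
    y = np.squeeze(x, axis=axis)
    # Pure reshape: Jacobian determinant is 1, so log-density is unchanged.
    return y, log_density

def squeeze_reverse(y, log_density, axis=-1):
    """Inverse map: re-insert a size-1 axis at the same position."""
    return np.expand_dims(y, axis=axis), log_density

x = np.arange(6).reshape(3, 1, 2)  # singleton axis in the middle
ld = np.zeros(3)                   # hypothetical per-sample log-density
y, ld_out = squeeze_forward(x, ld, axis=1)
x_back, _ = squeeze_reverse(y, ld_out, axis=1)

print(y.shape)                     # (3, 2)
print(np.array_equal(x, x_back))   # True
```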
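Example
>>> import jax.numpy as jnp
>>> from bijx import SqueezeDims
>>> squeeze = SqueezeDims(axis=-1)
>>> x = jnp.array([[[1], [2]], [[3], [4]]])  # Shape (2, 2, 1)
>>> log_density = jnp.zeros(())  # placeholder; passed through unchanged
>>> y, log_density_out = squeeze.forward(x, log_density)
>>> # y has shape (2, 2); log_density_out equals log_density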
- Parameters:
axis (int) – Axis along which to squeeze dimensions.
rngs – Random number generators (unused).
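To make the `axis` parameter concrete, here is a small NumPy illustration of standard squeeze indexing. It assumes `SqueezeDims` follows the same convention as `np.squeeze` / `jnp.squeeze`; that is an assumption, since the parameter list above only states that `axis` is an `int`. A negative index counts from the end, and NumPy rejects an axis whose size is not 1 with a `ValueError`.

```python
import numpy as np

x = np.zeros((4, 1, 3))  # the size-1 axis sits at position 1

# For ndim == 3, axis=1 and axis=-2 name the same dimension.
a = np.squeeze(x, axis=1)
b = np.squeeze(x, axis=-2)
print(a.shape, b.shape)  # (4, 3) (4, 3)

# Squeezing an axis whose size is not 1 is an error, not a no-op.
raised = False
try:
    np.squeeze(x, axis=0)
except ValueError:
    raised = True
print(raised)  # True
```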
Methods
forward(x, log_density) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density) – Apply reverse (inverse) transformation.
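Since this layer is documented as the inverse of ExpandDims, `invert()` should yield something that behaves like an expand-dims layer on the same axis. One way such an inversion could work is a wrapper that swaps `forward` and `reverse`. The classes `SqueezeSketch` and `_Inverted` below are hypothetical illustrations, not bijx's actual classes, and NumPy again stands in for `jax.numpy`.

```python
import numpy as np

class SqueezeSketch:
    """Hypothetical stand-in for SqueezeDims (not bijx's implementation)."""

    def __init__(self, axis=-1):
        self.axis = axis

    def forward(self, x, log_density):
        # Drop the size-1 axis; log-density passes through unchanged.
        return np.squeeze(x, axis=self.axis), log_density

    def reverse(self, x, log_density):
        # Re-insert the size-1 axis at the same position.
        return np.expand_dims(x, axis=self.axis), log_density

    def invert(self):
        # Swap forward/reverse: the result acts like an expand-dims layer.
        return _Inverted(self)


class _Inverted:
    """Hypothetical wrapper that swaps forward and reverse of a bijection."""

    def __init__(self, base):
        self.base = base

    def forward(self, x, log_density):
        return self.base.reverse(x, log_density)

    def reverse(self, x, log_density):
        return self.base.forward(x, log_density)

    def invert(self):
        # Inverting twice returns the original layer.
        return self.base


squeeze = SqueezeSketch(axis=-1)
expand = squeeze.invert()

z, ld = expand.forward(np.ones((2, 2)), 0.0)
print(z.shape)                          # (2, 2, 1)
print(squeeze.forward(z, ld)[0].shape)  # (2, 2)
```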