bijx.SqueezeDims

class bijx.SqueezeDims

Bases: MetaLayer

Remove the singleton dimension at a specified axis.

Removes the dimension of size 1 at the specified axis position. This is the inverse operation of ExpandDims.

Type: Shape(…, d_k, …, 1, …) → Shape(…, d_k, …)

Transform: Remove the dimension of size 1 at the specified axis
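The shape mapping above can be sketched with NumPy. This assumes the layer is a pure reshape: removing a size-1 axis does not change any values, so the Jacobian determinant is 1 and the log-density passes through unchanged. The helpers `squeeze_forward` and `squeeze_reverse` are hypothetical and not part of bijx. `jax.numpy.squeeze` and `jax.numpy.expand_dims` follow the same axis semantics as the NumPy calls used here.

```python
import numpy as np


def squeeze_forward(x, log_density, axis=-1):
    # Drop the size-1 axis. Values are only reshaped, so log|det J| = 0
    # and the log-density is returned as-is.
    return np.squeeze(x, axis=axis), log_density


def squeeze_reverse(y, log_density, axis=-1):
    # Reinsert a size-1 axis at the same position (the ExpandDims direction).
    return np.expand_dims(y, axis=axis), log_density


x = np.arange(4.0).reshape(2, 2, 1)
log_density = np.zeros(2)  # placeholder log-density values

y, ld = squeeze_forward(x, log_density)
x_back, ld_back = squeeze_reverse(y, ld)
print(y.shape, x_back.shape)  # (2, 2) (2, 2, 1)
```

Because `np.expand_dims` interprets a negative axis relative to the output rank, using the same `axis` value in both directions restores the original shape.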

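Example

>>> import jax.numpy as jnp
>>> from bijx import SqueezeDims
>>> squeeze = SqueezeDims(axis=-1)
>>> x = jnp.array([[[1], [2]], [[3], [4]]])  # Shape (2, 2, 1)
>>> log_density = jnp.zeros(2)  # placeholder log-density values
>>> y, log_density_out = squeeze.forward(x, log_density)
>>> # y has shape (2, 2); the log-density is returned unchanged
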
Parameters:
  • axis (int) – Axis of the size-1 dimension to remove (default: -1).

  • rngs – Random number generators (unused).
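The following NumPy sketch shows how the `axis` parameter selects the dimension. A negative value counts from the end. The code assumes that SqueezeDims follows standard squeeze semantics. The source does not say whether bijx raises the same error as NumPy when the chosen axis does not have size 1.

```python
import numpy as np

x = np.zeros((3, 1, 4))

# axis=1 and axis=-2 refer to the same (size-1) dimension.
a = np.squeeze(x, axis=1)
b = np.squeeze(x, axis=-2)
print(a.shape, b.shape)  # (3, 4) (3, 4)

# The default axis=-1 does not fit this input: the last dimension has
# size 4, and NumPy raises ValueError when asked to squeeze it.
try:
    np.squeeze(x, axis=-1)
    squeezed_last = True
except ValueError:
    squeezed_last = False
print(squeezed_last)  # False
```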

__init__(axis=-1, *, rngs=None)
Parameters:

axis (int)

Methods

forward(x, log_density)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density)

Apply reverse (inverse) transformation.
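The next sketch shows how the three methods relate. It assumes that inverting a bijection swaps its forward and reverse directions, so an inverted squeeze behaves like an expand. `ToySqueeze` and `ToyInverted` are hypothetical stand-ins written for illustration. They do not reproduce how bijx implements `invert()`.

```python
import numpy as np


class ToySqueeze:
    """Hypothetical stand-in for SqueezeDims (not the bijx implementation)."""

    def __init__(self, axis=-1):
        self.axis = axis

    def forward(self, x, log_density):
        # Remove the size-1 axis; the log-density is unchanged.
        return np.squeeze(x, axis=self.axis), log_density

    def reverse(self, x, log_density):
        # Reinsert the size-1 axis; the log-density is unchanged.
        return np.expand_dims(x, axis=self.axis), log_density

    def invert(self):
        return ToyInverted(self)


class ToyInverted:
    """Hypothetical wrapper that swaps forward and reverse of a bijection."""

    def __init__(self, bijection):
        self.bijection = bijection

    def forward(self, x, log_density):
        return self.bijection.reverse(x, log_density)

    def reverse(self, x, log_density):
        return self.bijection.forward(x, log_density)

    def invert(self):
        # Inverting twice returns the original bijection.
        return self.bijection


squeeze = ToySqueeze(axis=-1)
expand = squeeze.invert()  # acts like an ExpandDims on axis -1
x = np.ones((2, 2, 1))
ld = np.zeros(2)

y, ld_y = squeeze.forward(x, ld)
z, ld_z = expand.forward(y, ld_y)
print(y.shape, z.shape)  # (2, 2) (2, 2, 1)
```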