Basic Bijections

Besides continuous flows, powerful transformations can be built out of closed-form one-dimensional (scalar) bijections. In particular, the next page discusses how they can be composed in coupling layers to form higher-dimensional bijections. The following sections review the implemented scalar bijections, as well as meta layers that only change the representation of the input (for example its shape) without changing its values or density.

For convenience, most scalar bijections inherit from ScalarBijection, which splits the forward/reverse transformation and the log-Jacobian computation into separate methods.

class ScalarBijection(Bijection):
    def fwd(self, x, **kwargs):        # Forward transformation x → y
        raise NotImplementedError()

    def rev(self, y, **kwargs):        # Reverse transformation y → x
        raise NotImplementedError()

    def log_jac(self, x, y, **kwargs): # Log |∂y/∂x|
        # both x and y are available; can use whichever is most convenient
        raise NotImplementedError()
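
As an illustration, a custom element-wise bijection only needs these three methods. The following is a minimal sketch of an inverse hyperbolic sine transform; it assumes that ScalarBijection is exposed as bijx.ScalarBijection and that the base class assembles forward/reverse from fwd, rev, and log_jac as described above.

import bijx
import jax.numpy as jnp

class ArcSinh(bijx.ScalarBijection):
    """Sketch of a custom scalar bijection: y = asinh(x)."""

    def fwd(self, x, **kwargs):
        return jnp.arcsinh(x)

    def rev(self, y, **kwargs):
        return jnp.sinh(y)

    def log_jac(self, x, y, **kwargs):
        # d/dx asinh(x) = 1 / sqrt(1 + x^2)
        return -0.5 * jnp.log1p(x**2)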

Broadcasting

All scalar transformations operate on the entries of an array individually. For convenience, classes that inherit from ScalarBijection (all bijections presented here except for the spline flows) allow the passed log_density to have a different shape than the input x. In that case, the shape difference is used to infer the event shape, and the log-Jacobian is summed over the event axes. Thus, the scalar bijections naturally extend to element-wise transformations of higher-dimensional arrays.

This broadcasting behavior extends to the parameters of the bijections themselves. By default, the scalar transformations here have scalar parameters or none at all. However, it is also possible to specify parameters that match the event shape of the input. Parameters and inputs are broadcast according to the usual numpy broadcasting behavior.
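
For example, a Scaling whose parameter matches the event shape acts element-wise per event, while the density change is accumulated per batch entry. The following sketch only uses the API shown below; the parameter values are arbitrary.

import bijx
import jax.numpy as jnp

# scale parameter matching an event shape of (3,)
scaling = bijx.Scaling(jnp.array([1.0, 2.0, 3.0]))

x = jnp.zeros((5, 3))          # batch shape (5,), event shape (3,)
log_density = jnp.zeros((5,))  # its shape fixes the inferred event axes
y, log_density = scaling.forward(x, log_density)  # y: (5, 3), log_density: (5,)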

Parameter Specification

  1. Shape Tuple: (D,) or () creates a new nnx.Param with default initialization

  2. Array Value: e.g. jnp.array([1.0, 2.0]) is wrapped in an nnx.Param by default

  3. Variable Instance: e.g. nnx.Param(value) or bijx.Const(value) for explicit control

import bijx
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)

# Method 1: Shape-based initialization
bijx.Scaling((2,), rngs=rngs)  # Creates trainable parameters

# Method 2: Value-based initialization
bijx.Scaling(jnp.array([1.0, 2.0]))  # Uses provided values

# Method 3: Variable-based (non-trainable constant)
bijx.Scaling(bijx.Const(jnp.array([1.0, 2.0])))
Scaling( # Const: 2 (8 B)
  scale=TransformedParameter(param=Const( # 2 (8 B)
    value=Array(shape=(2,), dtype=dtype('float32'))
  ), transform=None)
)

Because some parameters are restricted to certain ranges (e.g. the scale), the parametrized scalar bijections additionally have optional arguments of the form transform_parameter, which can be any callable that transforms the parameter value before it is used. By default, Scaling applies no transformation, transform_scale=None (only zero is unsafe; note, however, that AffineLinear does apply jnp.exp as in Real NVP). For stable training, this could be changed to softplus, for example:

bijx.Scaling(transform_scale=nnx.softplus, rngs=rngs)
Scaling( # Param: 1 (4 B)
  scale=TransformedParameter(param=Param( # 1 (4 B)
    value=Array(1., dtype=float32)
  ), transform=<PjitFunction of <function softplus at 0x1157d0ae0>>)
)

Scalar Bijections

Linear and Affine Transformations

Basic building blocks for normalizing flows:

  • AffineLinear: \((-\infty, \infty) → (-\infty, \infty)\) via \(\text{scale} \cdot x + \text{shift}\) with learnable parameters

  • Scaling: \((-\infty, \infty) → (-\infty, \infty)\) via \(\text{scale} \cdot x\)

  • Shift: \((-\infty, \infty) → (-\infty, \infty)\) via \(x + \text{shift}\)

These provide the most basic learnable transformations with simple Jacobians.
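
For instance, an AffineLinear layer with scalar parameters can be checked by mapping a point forward and back; the density change cancels in the round trip. This is a small sketch that assumes shift, like scale, can be given as a shape tuple.

import bijx
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)

# y = scale * x + shift with scalar learnable parameters
affine = bijx.AffineLinear(scale=(), shift=(), rngs=rngs)

x, log_p = jnp.array(0.7), jnp.array(0.0)
y, log_q = affine.forward(x, log_p)            # push x through, track density change
x_back, log_p_back = affine.reverse(y, log_q)  # invert; density change cancels
assert jnp.allclose(x_back, x)
assert jnp.allclose(log_p_back, log_p, atol=1e-5)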

Bounded/Unbounded Transforms

Bounded range transforms that map unbounded inputs to bounded intervals:

  • Sigmoid: \((-\infty, \infty) → [0, 1]\) via \(σ(x) = 1/(1 + e^{-x})\)

  • Tanh: \((-\infty, \infty) → [-1, 1]\) via \(\tanh(x)\)

  • GaussianCDF: \((-\infty, \infty) → [0, 1]\) via \(Φ((x-\mu)/\sigma)\) with learnable location and scale

Unbounding transforms for mapping bounded to unbounded:

  • Tan: \([0, 1] → (-\infty, \infty)\) via \(\tan(\pi(x - 0.5))\)

Note that each bijection can be inverted via .invert().
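
For example, inverting Sigmoid yields a logit-like bijection from [0, 1] back to the real line; applying it after Sigmoid recovers the input and cancels the density change. This is a minimal sketch assuming Sigmoid needs no constructor arguments.

import bijx
import jax.numpy as jnp

sigmoid = bijx.Sigmoid()
logit = sigmoid.invert()  # maps [0, 1] back to the real line

x, log_p = jnp.array(1.5), jnp.array(0.0)
y, log_q = sigmoid.forward(x, log_p)          # y = σ(x) in [0, 1]
x_back, log_p_back = logit.forward(y, log_q)  # forward of the inverse undoes Sigmoid
assert jnp.allclose(x_back, x)
assert jnp.allclose(log_p_back, log_p, atol=1e-5)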

Positive Domain Transforms

Transforms for mapping to positive reals:

  • Exponential: \((-\infty, \infty) → [0, \infty)\) via \(e^x\) (simple but can be numerically unstable)

  • SoftPlus: \((-\infty, \infty) → [0, \infty)\) via \(\log(1 + e^x)\) (numerically stable alternative)

  • Power: \([0, \infty) → [0, \infty)\) via \(x^p\) with learnable exponent \(p > 0\)
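
The stability remark can be checked directly for large inputs, where \(e^x\) overflows in single precision while softplus stays close to the identity. This is a quick sketch assuming Exponential and SoftPlus need no constructor arguments.

import bijx
import jax.numpy as jnp

x, log_p = jnp.array(100.0), jnp.array(0.0)
y_exp, _ = bijx.Exponential().forward(x, log_p)  # exp(100) overflows float32 to inf
y_sp, _ = bijx.SoftPlus().forward(x, log_p)      # softplus(100) ≈ 100, stays finite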

Other

  • Sinh: \((-\infty, \infty) → (-\infty, \infty)\) via \(\sinh(x)\)

  • BetaStretch: \([0, 1] → [0, 1]\) via \(\frac{x^{\alpha}}{x^{\alpha} + (1-x)^{\alpha}}\)

The latter provides smooth stretching of the unit interval with a learnable parameter \(\alpha\) controlling the degree and direction of stretching.
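
As a small sketch of how these pieces combine, a learnable bijection on the unit interval can be assembled by hand by chaining forward calls: Tan maps to the real line, AffineLinear transforms, and Sigmoid maps back to [0, 1]. This assumes Tan and Sigmoid need no constructor arguments; the next page discusses systematic composition via coupling layers.

import bijx
import jax.numpy as jnp
from flax import nnx

rngs = nnx.Rngs(0)

to_real = bijx.Tan()
affine = bijx.AffineLinear(scale=(), shift=(), rngs=rngs)
to_unit = bijx.Sigmoid()

x, log_p = jnp.array(0.25), jnp.array(0.0)
for bij in (to_real, affine, to_unit):
    x, log_p = bij.forward(x, log_p)
# x is again in [0, 1]; log_p has accumulated all log-Jacobian contributions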

Examples

The following examples exhibit different ways of using scalar bijections.

  • Parameters can be specified as shapes or values (arrays or nnx.Variables).

  • Broadcasting between parameters and inputs follows numpy rules, except that the log-Jacobian is summed over the inferred event axes when updating the log-density.

  • Parameters can be scalar or match (a subset of) the event shape of the input.

import bijx
import jax.numpy as jnp
import jax
from flax import nnx

# random number generator
rngs = nnx.Rngs(0)

layer = bijx.AffineLinear(
    # can specify parameters as shapes
    scale=(),
    # can also give specific values
    # must be arrays -- then wrapped in nnx.Param -- or any nnx.Variable
    shift=jnp.array(1.0),
    # then need to provide rngs for initialization
    rngs=rngs,
)

# shifts input, leaves density unchanged
layer.forward(1.0, 0.0)
(Array(2., dtype=float32), Array(0., dtype=float32))
# load %shapes magic, which tree-maps jnp.shape before showing output
bijx.utils.load_shapes_magic()

# both have batch shape (10, 7), x is interpreted to be scalar
x = jnp.zeros((10, 7))
log_density = jnp.zeros((10, 7))

# usual numpy broadcasting behavior
%shapes layer.forward(x, log_density)
((10, 7), (10, 7))
# x is inferred to be a vector, batch shape is (10,)
x = jnp.zeros((10, 7))
log_density = jnp.zeros((10,))

# transformation sums density change over event axes
%shapes layer.forward(x, log_density)
((10, 7), (10,))
# initialize to carry parameters of shape (7,)
layer = bijx.AffineLinear(
    scale=(7,),
    shift=jax.random.normal(rngs.next(), (7,)),
    rngs=rngs,
)

# parameters and event-shape of x are broadcast together
x, log_density = layer.forward(x, log_density)
x[0]  # first entry in batch
Array([-2.4424558 , -2.0356805 ,  0.20554423, -0.3535502 , -0.76197404,
       -1.1785518 , -1.1482196 ], dtype=float32)

Rational Quadratic Splines

MonotoneRQSpline: Implements monotonic rational quadratic splines following Durkan et al. (2019). This class behaves similarly to the scalar bijections above, in the sense that the transformation is applied per element. However, its parameters are always initialized randomly based on the event shape, because their array shapes also depend on the number of knots.

Meta layers

Meta layers are bijections that rearrange data without changing its density, so their log-Jacobian determinant is zero. They are useful when building complex architectures as chains of bijections with differing assumptions about the data (e.g. vector vs. image-like shapes, feature channels, etc.).

  • Reshape: Reshapes the event dimensions of an array.

  • ExpandDims: Adds a new dimension to an array.

  • SqueezeDims: Removes a singleton dimension from an array.

  • Partial: Wraps another bijection to fix a set of its keyword arguments, creating a specialized version.

It is also straightforward to create custom meta layers from two functions forward(x) and reverse(x) that depend only on the input, using MetaLayer(forward, reverse).

split = bijx.MetaLayer(
    # forward
    lambda x: jnp.split(x, 2, axis=-1),
    # reverse
    lambda y: jnp.concatenate(y, axis=-1),
)
x = jnp.ones((3, 4))

y, log_density = split.forward(x, jnp.zeros((3,)))
print(*y, sep="\n")
print(jnp.all(log_density == 0))
[[1. 1.]
 [1. 1.]
 [1. 1.]]
[[1. 1.]
 [1. 1.]
 [1. 1.]]
True
# gets back to original shape, x
split.reverse(y, jnp.zeros((3,)))[0]
Array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)