Coupling Layers

Coupling layers form the basis of many normalizing flows. Bijections, most commonly acting on scalar values, are applied to an active subset of variables, while the remaining passive variables are used to set the parameters of the transformation. Bijx provides several flexible frameworks for building coupling layers, covering both traditional affine coupling and more complex architectures.

To summarize, the idea behind coupling layers is to:

  1. Split the input into active and passive components

  2. Condition the transformation on the passive components

  3. Transform only the active components using parameters computed from passive ones

  4. Preserve invertibility by keeping passive components unchanged

Schematically, for input \(x = (x_a, x_p)\) where \(x_a\) are active and \(x_p\) are passive:

\[\begin{split} \begin{aligned} y_p &= x_p \\ y_a &= T(x_a; \theta(x_p)) \end{aligned} \end{split}\]

where \(T\) is an invertible transformation and \(\theta(\cdot)\) is a neural network computing its parameters. Most commonly, \(T\) is a scalar bijection that transforms each array component of \(x_a\) independently. Importantly, this gives a triangular Jacobian structure, which allows efficient computation of the density change.
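
Concretely, since the passive components are unchanged and each active component is transformed independently, the Jacobian is triangular with a diagonal active block, so the log-density change reduces to a sum of scalar terms:

\[ \log\left|\det \frac{\partial y}{\partial x}\right| = \sum_i \log\left|\frac{\partial T(x_a^i; \theta(x_p))}{\partial x_a^i}\right| \,. \]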

Bijx provides a general toolkit for this abstract approach: an arbitrary transformation \(T\) can be specified as a bijection module, the information about which parameters need to be computed is extracted automatically, and batch dimensions are handled if present. This is explained below and supports arbitrary choices of \(T\), including non-scalar bijections.

This approach can also be used for scalar bijections, although in those cases a direct approach is also straightforward. For example, for the typical case of real NVP the transformation \(T\) is a simple element-wise affine linear transformation:

\[ y_a^i = e^{s_i(x_p)} \cdot x_a^i + t_i(x_p) \,, \]

where \(s\) and \(t\) are neural networks (with components \(s_i\), \(t_i\)) yielding outputs of the same shape as \(x_a\), and the product is taken element-wise (often written more explicitly as the Hadamard product \(e^{s} \odot x_a\)).
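
For this affine map each diagonal derivative is \(e^{s_i(x_p)}\), so the log-density change is simply

\[ \log\left|\det \frac{\partial y}{\partial x}\right| = \sum_i s_i(x_p) \,. \]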

There are basically three levels of generality, as reflected in the next sections (after the first one on masking):

  • Simple broadcasting: Use standard broadcasting between parameters and input arrays, manually setting up the coupling layer to match the bijection's parameters.

  • Automatic parameter extraction: Automatic extraction of parameters from a template bijection.

  • Automatic vectorization: Automatic vectorization over batch dimensions if the underlying bijection does not support broadcasting.

In principle the last, most powerful abstraction can be used in all cases.

Masking

To split the input into active and passive degrees of freedom, we have to define a masking pattern. A general binary mask is implemented as BinaryMask, which supports two kinds of splitting:

  • Indexing, that is splitting \(x = (x_a, x_p)\) explicitly.

  • Multiplicative, that is setting \(x_a = \text{mask} \odot x\) and \(x_p = (1 - \text{mask}) \odot x\).

The latter is especially convenient if we want to preserve the shape and locality information of the input, such as in convolutional networks.

import bijx
import jax.numpy as jnp

# the simplest way is to initialize from boolean mask
mask = bijx.BinaryMask.from_boolean_mask(jnp.array([True, False]))
x = jnp.eye(2)
y = mask.split(x)

type(y), len(y), y[0].shape, y[1].shape
(tuple, 2, (2, 1), (2, 1))
mask.merge(*y) == x
Array([[ True,  True],
       [ True,  True]], dtype=bool)
# masking by multiplication
mask * x
Array([[1., 0.],
       [0., 0.]], dtype=float32)
# invert mask with mask.flip() or ~mask
~mask * x
Array([[0., 0.],
       [0., 1.]], dtype=float32)
# common checkerboard mask
mask = bijx.checker_mask(shape=(3, 3), parity=True)

# can recover underlying boolean mask
mask.boolean_mask
Array([[ True, False,  True],
       [False,  True, False],
       [ True, False,  True]], dtype=bool)
# note that indexing always flattens the event shape
x = jnp.arange(3 * 3).reshape(3, 3)
x[mask.indices()]
Array([0, 2, 4, 6, 8], dtype=int32)

To be compatible with jax.jit tracing, all shapes need to be known at compile time. In particular, that means we cannot index with the boolean mask directly. That is why bijx.BinaryMask internally stores the indices of the “active” and “passive”, or “primary” and “secondary” components, which are used by .split(x).

mask.primary_indices
Const( # 10 (80 B)
  value=(array([0, 0, 1, 2, 2]), array([0, 2, 1, 0, 2]))
)
# the indexing tuple that includes a batch ellipsis can also be explicitly obtained
mask.indices(primary=False)
(Ellipsis, array([0, 1, 1, 2]), array([1, 0, 2, 1]))
# and for convenience "channel" dimensions can be added
mask.indices(primary=False, extra_channel_dims=1)
(Ellipsis, array([0, 1, 1, 2]), array([1, 0, 2, 1]), slice(None, None, None))

Simple broadcasting

All scalar bijections that are subclasses of ScalarBijection can automatically broadcast over parameter and input shape batch axes. This is in particular the case for AffineLinear which implements the simple affine linear transformation \(y = s \cdot x + t\).
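
As a minimal illustration of this broadcasting (a sketch following the same call pattern as the multiplicative-masking example further below), parameters carrying a leading batch axis can be passed to AffineLinear directly:

# scale and shift parameters with a leading batch axis of 5 over event shape (3,)
s = jnp.full((5, 3), 0.1)
t = jnp.zeros((5, 3))
bij = bijx.AffineLinear(s, t)

x = jnp.ones((5, 3))
# pass a per-entry log-density to track the change element-wise
y, ld_change = bij.forward(x, jnp.zeros((5, 3)))
# y and ld_change then both have shape (5, 3)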

It is then straightforward to implement a simple coupling layer as follows.

# flax nnx provides the neural network modules and RNG handling used below
from flax import nnx
rngs = nnx.Rngs(0)

Because the forward and reverse methods would look essentially the same, we can use bijx.ApplyBijection as the base class, which routes both forward and reverse through a single apply method that takes reverse as an additional boolean keyword argument. This could also easily be implemented manually, but it saves a few lines of code.

class AffineCouplingLayer(bijx.ApplyBijection):
    def __init__(self, mask: bijx.BinaryMask):
        self.mask = mask
        self.net = bijx.nn.nets.MLP(
            in_features=mask.count_primary,
            out_features=2 * mask.count_primary,
            rngs=rngs,
        )

    def apply(self, x, log_density, reverse=False, **kwargs):
        active, passive = self.mask.split(x)

        params = self.net(passive)
        # split the network output into scale and shift,
        # each of shape (..., count_primary)
        s, t = jnp.split(params, 2, axis=-1)

        # Note that active and passive have shape (..., count_primary),
        # here (..., 1), which is also the shape of s and t because
        # mask.split does not remove the final axis.
        # Therefore, broadcasting will work as expected.

        bijection = bijx.AffineLinear(s, t, transform_scale=jnp.exp)

        method = bijection.reverse if reverse else bijection.forward
        active, log_density = method(active, log_density)

        x = self.mask.merge(active, passive)
        return x, log_density
bijx.utils.load_shapes_magic()

flow = AffineCouplingLayer(bijx.checker_mask((2,), True))

# apply layer with batch=(5,) and event_shape=(2,)
%shapes flow.forward(jnp.ones((5, 2)), jnp.zeros((5,)))
((5, 2), (5,))

These layers, with varying parity/masking patterns, can then be chained to build more complex flows.
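
For example, a sketch of such a chain using the AffineCouplingLayer defined above (the same pattern is used with the spline couplings at the end of this page):

# stack a few affine couplings, flipping the mask so every entry gets transformed
mask = bijx.checker_mask((2,), True)

layers = []
for _ in range(4):
    layers.append(AffineCouplingLayer(mask))
    mask = ~mask

affine_flow = bijx.Chain(*layers)

# shapes are preserved: batch=(5,) and event_shape=(2,)
%shapes affine_flow.forward(jnp.ones((5, 2)), jnp.zeros((5,)))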

An analogous construction with multiplicative masking looks similar, except that care must be taken not to include the log-density change of the masked (passive) array entries. Below is a sketch of how this can be done.

x = jnp.ones((10, 3))
log_density = jnp.zeros(10,)

mask = bijx.checker_mask((3,), True)

# the parameters will now typically have the same shape as x,
# including the passive entries
s, t = jnp.ones((3,)), jnp.ones((3,))
bij = bijx.AffineLinear(s, t)

# we need to get the log-density change for each entry so we can mask out passive
y, ld_change = bij.forward(
    x,
    # append the event shape
    jnp.zeros(log_density.shape + (3,)),
)

# only sum over active entries
log_density += jnp.sum(ld_change, where=mask.boolean_mask, axis=-1)

The above pattern can be applied to any scalar bijection, not just the affine linear one used here. Note, however, that we had to hard-code the number of parameters required by the transformation (2 per active entry above) into the network definition.

A method to automatically extract the number of necessary parameters is provided by bijx.ModuleReconstructor, which then also extends to arbitrary coupling layers as explained in the next sections.

Automatic parameter extraction

For more sophisticated coupling layers beyond affine transformations, bijx provides ModuleReconstructor which facilitates parameter sharing and “reconstructing” a bijection given a set of parameters (in one of various representations).

flow = bijx.MonotoneRQSpline(10, (3,), rngs=rngs)

template = bijx.bijections.coupling.ModuleReconstructor(flow)

The template extracts the parameter information from the provided module, and parameters can then be provided back in various formats.

# can provide single 1d array of parameters of this size
template.params_total_size
np.int64(87)
# can also provide dictionary which matches this structure
template.params_dict
{'heights': ShapedArray(float32[3,10]),
 'slopes': ShapedArray(float32[3,9]),
 'widths': ShapedArray(float32[3,10])}
# or "leaves" of the corresponding pytree
template.params_leaves
[ShapedArray(float32[3,10]),
 ShapedArray(float32[3,9]),
 ShapedArray(float32[3,10])]
# some other convenience attributes are also available
template.params_dtypes
[dtype('float32'), dtype('float32'), dtype('float32')]
template.params_shapes
[(3, 10), (3, 9), (3, 10)]
template.params_shape_dict
{'heights': (3, 10), 'slopes': (3, 9), 'widths': (3, 10)}
# caveat: need to be careful if parameters are complex (not fully supported)
template.has_complex_params
False
import jax

# dummy array of parameters (would usually be/depend on the output of some NN)
params_array = jnp.zeros((template.params_total_size,))

# example inputs
x = jnp.ones((10, 3)) / 2
lp = jnp.zeros((10,))

# reconstruct flow from contiguous array of parameters
flow = template.from_params(params_array)

%shapes flow.forward(x, lp)
((10, 3), (10,))
# can also reconstruct from other representations

flow = template.from_params({
    key: jnp.zeros(shape)
    for key, shape in template.params_shape_dict.items()
})

flow = template.from_params([
    jnp.zeros(shape)
    for shape in template.params_shapes
])

Automatic vectorization

The spline implementation in bijx, like the other scalar bijections, supports broadcasting over batch indices that are added in front of parameter shapes matching the event shape. In fact, from the perspective of a scalar bijection, such batched parameters are indistinguishable from parameters whose shape simply happens to match the event shape.

However, general bijections may not support replacing their internal parameters with a batch of parameters that carry additional indices (fundamentally, this is also true of the spline flows; they just apply jax.vmap internally to handle it automatically). In this case, an additional argument autovmap=True can be passed to from_params, which returns an object that behaves almost like the original module/bijection, except that function calls are automatically jax.vmap’ed over the batch dimensions.

# first, demonstration that the spline flow works with a batch of parameters
params_array = jnp.zeros((10, template.params_total_size))

# example inputs
x = jnp.ones((10, 3)) / 2
lp = jnp.zeros((10,))

# reconstruct flow from now batched contiguous array of parameters
flow = template.from_params(params_array)

%shapes flow.forward(x, lp)
((10, 3), (10,))
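
This batched reconstruction is all that is needed to wire a network into a coupling layer by hand. Below is a sketch (the class name SplineCoupling is purely illustrative, not part of bijx) combining ApplyBijection, a BinaryMask and the template from above; the mask is chosen so that the active part has 3 entries, matching the spline's (3,) event shape:

class SplineCoupling(bijx.ApplyBijection):
    def __init__(self, mask, rngs):
        self.mask = mask
        count_active, count_passive = mask.counts
        # network outputs one flat parameter vector per sample
        self.net = bijx.nn.nets.MLP(
            in_features=count_passive,
            out_features=template.params_total_size,
            rngs=rngs,
        )

    def apply(self, x, log_density, reverse=False, **kwargs):
        active, passive = self.mask.split(x)
        # reconstruct the spline from the (batched) network output
        spline = template.from_params(self.net(passive))
        method = spline.reverse if reverse else spline.forward
        active, log_density = method(active, log_density)
        return self.mask.merge(active, passive), log_density

coupling = SplineCoupling(bijx.checker_mask((6,), True), rngs=rngs)
%shapes coupling.forward(jnp.ones((10, 6)) / 2, jnp.zeros((10,)))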

For demonstration purposes, let us define a version of spline flow that does not support broadcasting.

class NoBatchSpline(bijx.Bijection):

    def __init__(self, knots, *, rngs):
        self.knots = knots

        # in contrast to the above, here also FIX event_shape to be scalar, i.e. ()
        self.spline = bijx.MonotoneRQSpline(knots, (), rngs=rngs)

    def forward(self, x, log_density):
        assert self.spline.widths.value.shape == (self.knots,)
        return self.spline.forward(x, log_density)

    def reverse(self, x, log_density):
        assert self.spline.widths.value.shape == (self.knots,)
        return self.spline.reverse(x, log_density)
flow = NoBatchSpline(10, rngs=rngs)

# works on scalar values
flow.forward(0.5, 0.0)
(Array(0.5008062, dtype=float32), Array(-0.00292323, dtype=float32))

It also still works on batched inputs. The problem (constructed here on purpose, but representing a realistic scenario) is batched parameters, although the final solution would also handle bijections that do not support any batch dimensions in the inputs.

%shapes flow.forward(jnp.ones((10, 3)), jnp.zeros((10,)))
((10, 3), (10,))
template = bijx.ModuleReconstructor(flow)
params_array = jnp.zeros((10, template.params_total_size))

x = jnp.ones((10, 3)) / 2
lp = jnp.zeros((10,))

flow = template.from_params(params_array)

try:
    flow.forward(x, lp)
except AssertionError:
    print('Expected assertion error about wrong parameter shape')
Expected assertion error about wrong parameter shape
# note we now have an additional "3" because the flow uses scalar event shape, i.e. ()
params_array = jnp.zeros((10, 3, template.params_total_size))

x = jnp.ones((10, 3)) / 2
lp = jnp.zeros((10,))

# now acts almost like the original flow, except it actually wraps around it
# and automatically applies vmap to function calls
flow = template.from_params(params_array, autovmap=True)
type(flow)
bijx.bijections.coupling.AutoVmapReconstructor
# parameters can be accessed (now has batch index)
flow.spline.widths.shape
(10, 3, 10)
# This now almost works, but there is another complication:
try:
    flow.forward(x, lp, input_ranks=(0, 0))
except ValueError as e:
    print(e)
vmap got inconsistent sizes for array axes to be mapped:
  * most axes (2 of them) had size 3, e.g. axis 0 of argument all_args of type float32[3,29];
  * one axis had size 10: axis 0 of args[1][1] of type float32[10]

We have to be careful that the log-density matches the shape of x. The call above fails because input_ranks=(0, 0) declares both arguments to have scalar event shape, so the remaining shapes of x and lp, (10, 3) and (10,), imply inconsistent batch shapes for vmap. We can avoid this by passing a log-density that matches the shape of x:

# input_ranks defaults to (0, 0), but can be modified if the underlying bijection is not scalar.
%shapes flow.forward(x, jnp.zeros_like(x), input_ranks=(0, 0))
((10, 3), (10, 3))

General coupling layer

GeneralCouplingLayer provides even more convenience and also takes care of the annoyance above: we no longer have to extend the log-density to match the shape of x, since it is clear that the log-density change should be broadcast and summed over the event shape of x.

from einops import rearrange
knots = 11
spline_template = bijx.ModuleReconstructor(
    # again, define bijection to act on scalars, event_shapes=()
    bijx.MonotoneRQSpline(knots, (), rngs=nnx.Rngs(0))
)

# function to construct a single coupling layer
def spline_coupling_layer(mask, width, depth, rngs):
    # mask contains information about active/passive features
    count_active, count_passive = mask.counts
    # template knows how many parameters are needed
    param_count = spline_template.params_total_size

    # define network that maps passive features to parameters
    resnet = bijx.nn.nets.ResNet(
        count_passive, count_active * param_count, width, depth,
        final_kernel_init=nnx.initializers.normal(),
        final_bias_init=nnx.initializers.zeros,
        rngs=rngs,
    )

    # the parameters in the end must have shape (..., count_active, param_count)
    def reshape_params(p):
        return rearrange(p, '... (t b) -> ... t b', t=count_active)

    param_net = nnx.Sequential(
        resnet,
        reshape_params,
    )

    return bijx.GeneralCouplingLayer(
        param_net,
        mask,
        spline_template,
        # defaults to 0 (scalar bijections)
        bijection_event_rank=0,
        # split=True uses explicit splitting; multiplicative masking is also supported
        split=True,
    )
mask = bijx.BinaryMask.from_boolean_mask(jnp.array([True, False]))

layers = []

# chain 5 coupling layers
for _ in range(5):
    layers.append(spline_coupling_layer(mask, 32, 2, rngs=rngs))
    # invert mask after each layer
    mask = ~mask

flow = bijx.Chain(*layers)
x = jnp.ones((10, 2))
lp = jnp.zeros((10,))

%shapes flow.forward(x, lp)
((10, 2), (10,))
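
As a final sanity check, we can run the chain forward and then in reverse; this is a quick sketch using inputs in the unit interval, as in the earlier spline examples:

x = jnp.full((10, 2), 0.3)
lp = jnp.zeros((10,))

# forward followed by reverse should recover the original input
y, lp_fwd = flow.forward(x, lp)
x_back, _ = flow.reverse(y, lp_fwd)

jnp.allclose(x_back, x, atol=1e-5)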