Basic Bijections¶
Besides continuous flows, powerful transformations can be built out of closed-form one-dimensional (scalar) bijections. In particular, the next page discusses how these can be composed in coupling layers to form higher-dimensional bijections. The following sections review the implemented scalar bijections, as well as meta layers that only change the representation of the input (e.g. its shape) without changing values or density.
For convenience, most scalar bijections inherit from ScalarBijection, which splits the forward/reverse input transformations and the log-Jacobian computation into separate methods:
class ScalarBijection(Bijection):
    def fwd(self, x, **kwargs):  # forward transformation x → y
        raise NotImplementedError()

    def rev(self, y, **kwargs):  # reverse transformation y → x
        raise NotImplementedError()

    def log_jac(self, x, y, **kwargs):  # log |∂y/∂x|
        # both x and y are available; use whichever is most convenient
        raise NotImplementedError()
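For illustration, a custom scalar bijection only needs these three methods. Here is a minimal sketch of an arcsinh transform (assuming ScalarBijection is exposed at the top level of bijx; the base class then provides forward/reverse with the density bookkeeping described below):
import jax.numpy as jnp
import bijx

class Asinh(bijx.ScalarBijection):
    # y = arcsinh(x): a smooth bijection of the real line onto itself
    def fwd(self, x, **kwargs):
        return jnp.arcsinh(x)

    def rev(self, y, **kwargs):
        return jnp.sinh(y)

    def log_jac(self, x, y, **kwargs):
        # d/dx arcsinh(x) = 1 / sqrt(1 + x^2)
        return -0.5 * jnp.log1p(x**2)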
Broadcasting¶
All scalar transformations operate on the entries of an array individually.
For convenience, classes that inherit from ScalarBijection (all of those presented here except for spline flows) allow the passed log_density to have a different shape than the input x.
If so, the shape difference is used to infer the event shape, and the log-Jacobian is summed over those axes.
Thus, the scalar bijections naturally extend to element-wise higher-dimensional transformations.
This broadcasting behavior extends to the parameters of the bijections themselves. By default, the scalar transformations here have scalar parameters or none at all. However, it is also possible to specify parameters that match the event shape of the input. Parameters and inputs are broadcast according to the usual numpy broadcasting behavior.
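For instance, a per-component scaling broadcasts against a batch of vectors while the log-density keeps its batch shape (a minimal sketch; fuller examples follow below):
import bijx
import jax.numpy as jnp
from flax import nnx

scaling = bijx.Scaling((3,), rngs=nnx.Rngs(0))  # trainable scale with event shape (3,)
x = jnp.ones((5, 3))           # batch of 5 three-vectors
log_density = jnp.zeros((5,))  # batch shape (5,), so event shape (3,) is inferred
y, ld = scaling.forward(x, log_density)  # y has shape (5, 3), ld has shape (5,)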
Parameter Specification¶
Shape tuple: (D,) or () creates a new nnx.Param with default initialization.
Array value: e.g. jnp.array([1.0, 2.0]) will by default be wrapped with nnx.Param.
Variable instance: e.g. nnx.Param(value) or bijx.Const(value) for explicit control.
import bijx
import jax.numpy as jnp
from flax import nnx
rngs = nnx.Rngs(0)
# Method 1: Shape-based initialization
bijx.Scaling((2,), rngs=rngs) # Creates trainable parameters
# Method 2: Value-based initialization
bijx.Scaling(jnp.array([1.0, 2.0])) # Uses provided values
# Method 3: Variable-based (non-trainable constant)
bijx.Scaling(bijx.Const(jnp.array([1.0, 2.0])))
Scaling( # Const: 2 (8 B)
  scale=TransformedParameter(param=Const( # 2 (8 B)
    value=Array(shape=(2,), dtype=dtype('float32'))
  ), transform=None)
)
Because some parameters are restricted to certain ranges (e.g. the scaling must avoid zero), the parametrized scalar bijections additionally have optional arguments of the kind transform_parameter, which can be any callable that transforms the parameter value before it is used.
Scaling by default does not apply a transformation, transform_scale=None (only zero is unsafe; note however that AffineLinear does apply jnp.exp, as in real NVP).
For stable training, this could be changed to softplus, for example:
bijx.Scaling(transform_scale=nnx.softplus, rngs=rngs)
Scaling( # Param: 1 (4 B)
  scale=TransformedParameter(param=Param( # 1 (4 B)
    value=Array(1., dtype=float32)
  ), transform=<PjitFunction of <function softplus at 0x1157d0ae0>>)
)
Scalar Bijections¶
Linear and Affine Transformations¶
Basic building blocks for normalizing flows:
AffineLinear: \([-\infty, \infty] \to [-\infty, \infty]\) via \(\text{scale} \cdot x + \text{shift}\) with learnable parameters
Scaling: \([-\infty, \infty] \to [-\infty, \infty]\) via \(\text{scale} \cdot x\)
Shift: \([-\infty, \infty] \to [-\infty, \infty]\) via \(x + \text{shift}\)
These provide the most basic learnable transformations with simple Jacobians.
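For example, a fixed Scaling constructed from a value (per the parameter-specification rules above) multiplies its input and shifts the log-density by the log-Jacobian:
scaling = bijx.Scaling(jnp.array(2.0))
y, ld = scaling.forward(jnp.array(3.0), jnp.array(0.0))
# y == 6.0; ld changes by the Jacobian term log|scale| = log 2,
# with the sign fixed by the library's density convention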
Bounded/Unbounded Transforms¶
Bounded range transforms that map unbounded inputs to bounded intervals:
Sigmoid: \([-\infty, \infty] \to [0, 1]\) via \(\sigma(x) = 1/(1 + e^{-x})\)
Tanh: \([-\infty, \infty] \to [-1, 1]\) via \(\tanh(x)\)
GaussianCDF: \([-\infty, \infty] \to [0, 1]\) via \(\Phi((x-\mu)/\sigma)\) with learnable location and scale
Unbounding transforms for mapping bounded to unbounded:
Tan: \([0, 1] \to [-\infty, \infty]\) via \(\tan(\pi(x - 0.5))\)
Note that each bijection can be inverted via .invert().
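For instance (a quick sketch, assuming Sigmoid takes no required constructor arguments), inverting Sigmoid gives a logit-style map from \([0, 1]\) back to the real line:
sigmoid = bijx.Sigmoid()
logit = sigmoid.invert()  # maps [0, 1] → (-∞, ∞)
y, ld = sigmoid.forward(jnp.array(0.3), jnp.array(0.0))
x, ld = logit.forward(y, ld)
# x recovers 0.3 and ld returns to 0.0: the two log-Jacobian terms cancel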
Positive Domain Transforms¶
Transforms for mapping to positive reals:
Exponential: \([-\infty, \infty] \to [0, \infty]\) via \(e^x\) (simple but can be numerically unstable)
SoftPlus: \([-\infty, \infty] \to [0, \infty]\) via \(\log(1 + e^x)\) (numerically stable alternative)
Power: \([0, \infty] \to [0, \infty]\) via \(x^p\) with learnable exponent \(p > 0\)
Other¶
Sinh: \([-\infty, \infty] \to [-\infty, \infty]\) via \(\sinh(x)\)
BetaStretch: \([0, 1] \to [0, 1]\) via \(\frac{x^{\alpha}}{x^{\alpha} + (1-x)^{\alpha}}\)
The latter provides smooth stretching of the unit interval with a learnable parameter \(\alpha\) controlling the degree and direction of stretching.
Examples¶
The following examples exhibit different ways of using scalar bijections.
Parameters can be specified as shapes or values (arrays or nnx.Variables).
Broadcasting between parameters and inputs follows numpy rules, except that the inferred event shape is summed over in the log-density.
Parameters can be scalar or match (a subset of) the event shape of the input.
import bijx
import jax.numpy as jnp
import jax
from flax import nnx
# random number generator
rngs = nnx.Rngs(0)
layer = bijx.AffineLinear(
    # can specify parameters as shapes
    scale=(),
    # can also give specific values;
    # must be arrays -- then wrapped in nnx.Param -- or any nnx.Variable
    shift=jnp.array(1.0),
    # then rngs must be provided for initialization
    rngs=rngs,
)
# shifts input, leaves density unchanged
layer.forward(1.0, 0.0)
(Array(2., dtype=float32), Array(0., dtype=float32))
# load %shapes magic, which tree-maps jnp.shape before showing output
bijx.utils.load_shapes_magic()
# both have batch shape (10, 7), x is interpreted to be scalar
x = jnp.zeros((10, 7))
log_density = jnp.zeros((10, 7))
# usual numpy broadcasting behavior
%shapes layer.forward(x, log_density)
((10, 7), (10, 7))
# x is inferred to be a vector, batch shape is (10,)
x = jnp.zeros((10, 7))
log_density = jnp.zeros((10,))
# transformation sums density change over event axes
%shapes layer.forward(x, log_density)
((10, 7), (10,))
# initialize to carry parameters of shape (7,)
layer = bijx.AffineLinear(
    scale=(7,),
    shift=jax.random.normal(rngs.next(), (7,)),
    rngs=rngs,
)
# parameters and event-shape of x are broadcast together
x, log_density = layer.forward(x, log_density)
x[0] # first entry in batch
Array([-2.4424558 , -2.0356805 ,  0.20554423, -0.3535502 , -0.76197404,
       -1.1785518 , -1.1482196 ], dtype=float32)
Rational Quadratic Splines¶
MonotoneRQSpline: implements monotonic rational quadratic splines following Durkan et al. (2019).
This class behaves similarly to the scalar bijections above, in the sense that the transformation is applied per element.
However, the parameters here are always randomly initialized based on the event shape, because their array shapes also depend on the number of knots.
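A hedged sketch of usage, reusing the rngs from above (the keyword names knots and event_shape are assumptions based on the description here; the exact constructor signature is in the API reference):
spline = bijx.MonotoneRQSpline(knots=8, event_shape=(2,), rngs=rngs)
x = jnp.full((10, 2), 0.5)
y, ld = spline.forward(x, jnp.zeros((10,)))  # element-wise spline, summed over the event axis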
Meta layers¶
Meta layers are bijections that rearrange data without changing its density. This results in a log-Jacobian determinant of zero. They are useful when creating complex architectures as chains of bijections that have differing assumptions (e.g. vectorial vs image-like shape, feature channels, etc.).
Reshape: Reshapes the event dimensions of an array.
ExpandDims: Adds a new dimension to an array.
SqueezeDims: Removes a singleton dimension from an array.
Partial: Wraps another bijection to fix a set of its keyword arguments, creating a specialized version.
It is also straightforward to create custom layers from two functions, forward(x) and reverse(x), that only depend on the input, using MetaLayer as MetaLayer(forward, reverse).
split = bijx.MetaLayer(
    # forward
    lambda x: jnp.split(x, 2, axis=-1),
    # reverse
    lambda y: jnp.concatenate(y, axis=-1),
)
x = jnp.ones((3, 4))
y, log_density = split.forward(x, jnp.zeros((3,)))
print(*y, sep="\n")
print(jnp.all(log_density == 0))
[[1. 1.]
 [1. 1.]
 [1. 1.]]
[[1. 1.]
 [1. 1.]
 [1. 1.]]
True
# gets back to original shape, x
split.reverse(y, jnp.zeros((3,)))[0]
array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]])