Analytic Bijections

This tutorial demonstrates how to use the analytic bijections defined in bijx, first in 1D and then as building blocks of new radial flows in 2D.

from flax import nnx
import jax
import jax.numpy as jnp
import optax
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import bijx

# Shared random-number stream container, seeded with 0; it is passed to the
# 2D model constructors further below.
rngs = nnx.Rngs(0)

1D Example

We can learn one-dimensional distributions by stacking sufficiently many scalar bijections. First, we’ll define a synthetic target distribution as an example.

@nnx.dataclass
class PolyTarget:
    """Synthetic 1D target defined through an unnormalized log density.

    The log density is the sum of three terms:
    ``a1 * sin(w1 * x) * exp(-g1 * x**2)``, ``a2 * cos(w2 * x)`` and
    ``a4 * x**4``. With the default negative ``a4`` the quartic term makes
    the density decay for large ``|x|``.
    """

    # Oscillatory Gaussian component coefficients
    a1: nnx.Data[float] = 1.0  # Amplitude of sin term
    w1: nnx.Data[float] = 5.0  # Frequency of sin
    g1: nnx.Data[float] = 5.0  # Gaussian decay rate

    # Cosine component
    a2: nnx.Data[float] = 2.0  # Amplitude of cos term
    w2: nnx.Data[float] = 10.0  # Frequency of cos

    # Polynomial component
    a4: nnx.Data[float] = -0.2  # Coefficient of the quartic term

    def __call__(self, x):
        """Compute the unnormalized log density.

        Args:
            x: Scalar, array of shape ``(batch,)``, or array of shape
                ``(..., 1)`` with at least two axes.

        Returns:
            Log density values. A trailing singleton axis is removed only
            when ``x`` has two or more axes; otherwise the shape of ``x``
            is kept.
        """
        x = jnp.asarray(x)

        # Handle both (batch,) and (batch, 1) shapes. The trailing axis is
        # only dropped when a separate batch axis precedes it: a 0-d input
        # has no last axis to inspect, and a (1,)-shaped input is a batch of
        # one sample whose batch axis must be kept.
        x_flat = x.squeeze(-1) if x.ndim > 1 and x.shape[-1] == 1 else x

        log_p = (
            self.a1 * jnp.sin(self.w1 * x_flat) * jnp.exp(-self.g1 * x_flat**2)
            + self.a2 * jnp.cos(self.w2 * x_flat)
            + self.a4 * x_flat**4
        )
        return log_p

# Instantiate the target with its default coefficients.
poly_target = PolyTarget()
plt.figure(figsize=(6, 3))
plt.title('Target density')
x = jnp.linspace(-4, 4, 200)
# The target returns a log density, so exponentiate before plotting.
plt.plot(x, jnp.exp(poly_target(x)))
plt.ylabel('Unnormalized density')
plt.show()
../_images/a43f0b5a5b98aacb0a44b3737ab0e825c28a058db772859d10149303e22b06a0.png

Model

Instead of a for-loop over multiple bijections, we use bijx.ScanChain, which chains the internal bijection using jax’s scan. The internal bijection is expected to have parameters with a leading batch dimension, which is achieved by applying vmap over a number of random keys.

# Number of scalar bijections that are chained together.
stack_size = 64


@nnx.split_rngs(splits=stack_size)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_bijection_stack(rngs):
    """Build a scan-chained stack of ``stack_size`` scalar bijections.

    The RNG streams are split ``stack_size`` ways and the constructor is
    vmapped over them (axis 0 in, axis 0 out), so a single call creates the
    whole stack.

    Other scalar bijections can be swapped in for ``CubicConjugation``,
    e.g. ``bijx.SinhConjugation(rngs=rngs)`` or
    ``bijx.CubicRational(rngs=rngs)``.

    Args:
        rngs: ``nnx.Rngs`` used to initialize the bijection.

    Returns:
        A ``bijx.ScanChain`` wrapping the stacked bijection.
    """
    return bijx.ScanChain(bijx.CubicConjugation(rngs=rngs))
# Stack of scalar bijections, initialized from a fresh RNG seeded with 0.
bijection = create_bijection_stack(nnx.Rngs(0))
# Flow model: an independent normal prior with scalar event shape ``()``
# combined with the bijection stack.
model = bijx.Transformed(
    bijx.IndependentNormal((), rngs=nnx.Rngs(0)),
    bijection,
)
# Plot the model density before any training.
plt.figure(figsize=(6, 3))
plt.title('Density at initialization')
x = jnp.linspace(-4, 4, 200)
# log_density returns the model's log density, so exponentiate for plotting.
plt.plot(x, jnp.exp(model.log_density(x)))
# This curve is the flow model's own density, not the unnormalized target,
# so the axis is labelled accordingly.
plt.ylabel('Density')
plt.show()
../_images/2e00e2cfc0f1531ecd237529166a95e5326ebb23876df47a47fea638bf25f4c4.png

Training

# simple training setup
# Number of model samples drawn per optimization step.
batch_size = 256

# Adam with learning rate 1e-3; only variables of type nnx.Param are optimized.
optimizer = nnx.Optimizer(
    tx=optax.adam(1e-3),
    model=model,
    wrt=nnx.Param,
)

def loss_fn(model):
    """Estimate the reverse KL objective from model samples.

    Draws ``batch_size`` samples from the model together with their model
    log densities, evaluates the unnormalized target log density on them,
    and averages the difference.

    Args:
        model: Object whose ``sample`` method returns ``(samples, log_q)``.

    Returns:
        Scalar mean of ``log_q - log_p`` over the batch.
    """
    samples, log_q = model.sample((batch_size,))
    log_p = poly_target(samples)
    # Mean of (log q - log p) equals the negated mean of (log p - log q).
    return jnp.mean(log_q - log_p)

@nnx.jit
def update_step(model, optimizer):
    """Perform one jitted optimization step.

    Args:
        model: Model passed to ``loss_fn`` and updated by the optimizer.
        optimizer: ``nnx.Optimizer`` that applies the gradients.

    Returns:
        The loss value computed before the update.
    """
    value_and_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(model)
    # The optimizer applies the gradients to the model it is given; the
    # return value of ``update`` is not used.
    optimizer.update(grads=grads, model=model)
    return loss
# Number of optimization steps.
steps = 2000
# Pre-allocated loss history; NaN marks entries that have not been written.
losses = np.full(steps, np.nan)

for i in tqdm(range(steps)):
    # Storing into the NumPy array converts the returned loss to a host value.
    losses[i] = update_step(model, optimizer)

# Plot the loss curve over training.
plt.figure(figsize=(6, 3))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()
100%|██████████| 2000/2000 [00:05<00:00, 341.69it/s]
[<matplotlib.lines.Line2D at 0x1178bf0e0>]
../_images/077d448c75bf4259923322e9ff27a4761538bb40796e94e77ba8e19477f5f91f.png

Evaluation

# Compare the learned density (left axis) with the unnormalized target
# (right axis). The two curves get separate y-axes because the target is
# not normalized.
plt.figure(figsize=(6, 3))
x = jnp.linspace(-4, 4, 200)
(learned_line,) = plt.plot(
    x, jnp.exp(model.log_density(x)), label='Learned density'
)
plt.ylabel('Learned density')
plt.twinx()
(target_line,) = plt.plot(
    x, jnp.exp(poly_target(x)), '--', label='Target density', color='C1'
)
plt.ylabel('Unnormalized density')
# A bare plt.legend() only collects artists from the current (twin) axes and
# would omit the learned density, so both handles are passed explicitly.
plt.legend(handles=[learned_line, target_line])
plt.show()
../_images/3894440823660c5669aee91a9df9c788079180f667aaefed30e849e7b26fe641.png

Radial flows

First we define a spiral distribution in 2D, this time by defining how to sample from it.

Target

@nnx.dataclass
class SpiralData:
    """Sampler for noisy points along a 2D spiral.

    Only sampling is provided; the class defines no density method.
    """

    n_turns: nnx.Data[float] = 2.5  # Spiral turns: the curve parameter spans [0, n_turns * 2 * pi] = [0, 5 * pi]
    radius_scale: nnx.Data[float] = 1.0 / 20.0  # Multiplier turning the curve parameter into a radius
    noise_scale: nnx.Data[float] = 0.02  # Standard deviation of the Gaussian noise added to each point
    is_normalized: bool = False  # No density, so not normalized

    @property
    def event_shape(self) -> tuple[int, ...]:
        """Shape of a single sample: a point in the plane."""
        return (2,)

    def sample(self, batch_shape: tuple[int, ...] = (), *, rng):
        """Draw noisy spiral points.

        Args:
            batch_shape: Leading shape of the returned batch.
            rng: JAX random key; it is split once, for the curve parameter
                and for the noise.

        Returns:
            Array of shape ``(*batch_shape, 2)``.
        """
        # Sample a flat batch first and restore the batch shape at the end.
        count = int(np.prod(batch_shape))
        key_t, key_noise = jax.random.split(rng)

        # Curve parameter, uniform on [0, n_turns * 2 * pi].
        t = jax.random.uniform(
            key_t, (count,), minval=0, maxval=self.n_turns * 2 * jnp.pi
        )

        # Points on the spiral: the radius grows linearly with t.
        curve = jnp.stack(
            [
                t * jnp.cos(t) * self.radius_scale,
                t * jnp.sin(t) * self.radius_scale,
            ],
            axis=1,
        )

        # Isotropic Gaussian jitter around the curve.
        jitter = jax.random.normal(key_noise, (count, 2)) * self.noise_scale

        return (curve + jitter).reshape(*batch_shape, 2)

# Instantiate the sampler with default settings and draw 1000 points
# using a fixed key for reproducibility.
spiral_data = SpiralData()
samples = spiral_data.sample((1000,), rng=jax.random.key(0))
plt.figure(figsize=(3, 3))
# Scatter the x and y coordinates of the samples.
plt.scatter(samples[:, 0], samples[:, 1], s=1)
plt.show()
../_images/15cfbe6da62dc0f4e98a87972a9d5c6fa4bf5eaf7919c2a52c9c4fbdd01aca9e.png

Model architecture

Now we want to transform radii with respect to some trainable center. Thus, we must make sure that the scalar bijection maps rays onto rays (i.e., preserves the origin).

# Number of scalar bijections that are chained together.
stack_size = 64


@nnx.split_rngs(splits=stack_size)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_bijection_stack(rngs):
    """Build a scan-chained stack of scalar bijections wrapped for radial use.

    The RNG streams are split ``stack_size`` ways and the constructor is
    vmapped over them (axis 0 in, axis 0 out). The resulting chain is wrapped
    in ``bijx.RayTransform`` (presumably so that it acts on radii and keeps
    the origin fixed; confirm against the bijx documentation).

    Other scalar bijections can be swapped in for ``CubicConjugation``,
    e.g. ``bijx.SinhConjugation(rngs=rngs)`` or
    ``bijx.CubicRational(rngs=rngs)``.

    Args:
        rngs: ``nnx.Rngs`` used to initialize the bijection.

    Returns:
        A ``bijx.RayTransform`` around the ``bijx.ScanChain`` stack.
    """
    chain = bijx.ScanChain(bijx.CubicConjugation(rngs=rngs))
    return bijx.RayTransform(chain)

We will use a simple Fourier-series based coupling layer to make the radial transformation depend on the angle.

class FourierCondNet(nnx.Module):
    """Conditioning network mapping a 2D direction to bijection parameters.

    The direction is converted to its polar angle, embedded with Fourier
    features, and passed through one or two bias-free linear layers.
    """

    def __init__(
        self,
        output_dim: int,
        fourier_count: int = 3,
        hidden_dim: int | None = None,
        *,
        rngs: nnx.Rngs,
    ):
        """Set up the Fourier embedding and the linear layers.

        Args:
            output_dim: Number of parameters produced per input direction.
            fourier_count: Number of Fourier features; must be odd.
            hidden_dim: Width of an optional intermediate linear layer. When
                ``None`` the features feed the output layer directly. No
                nonlinearity is applied between the two layers.
            rngs: Random streams used to initialize the linear layers.
        """
        assert fourier_count % 2 == 1, 'Fourier count must be odd'
        self.kernel = bijx.nn.embeddings.KernelFourier(
            fourier_count, val_range=(-jnp.pi, jnp.pi)
        )
        self.output_dim = output_dim

        # Both linear maps share these settings: no bias and small random
        # initial weights.
        linear_kwargs = dict(
            use_bias=False,
            rngs=rngs,
            kernel_init=nnx.initializers.normal(0.01),
        )

        readout_in = fourier_count
        if hidden_dim is not None:
            # The optional layer is created first, so it draws from ``rngs``
            # before the output layer does.
            self.hidden = nnx.Linear(fourier_count, hidden_dim, **linear_kwargs)
            readout_in = hidden_dim

        self.superpos = nnx.Linear(readout_in, output_dim, **linear_kwargs)

    def __call__(self, x_hat):
        """Map unit vectors to bijection parameters.

        Args:
            x_hat: Unit vectors of shape (..., 2).

        Returns:
            Parameters of shape (..., output_dim).
        """
        angle = jnp.atan2(x_hat[..., 1], x_hat[..., 0])
        feats = self.kernel(angle)

        # Regularize feature magnitudes: divide each harmonic by its
        # frequency index so that high frequencies are suppressed.
        # NOTE(review): assumes the kernel orders its features as
        # [sin terms, cos terms, constant]; confirm against KernelFourier.
        n_harmonics = (self.kernel.feature_count - 1) // 2
        harmonics = jnp.arange(1, n_harmonics + 1)
        damping = jnp.concatenate([
            harmonics,  # for sin terms
            harmonics,  # for cos terms
            jnp.array([1.0]),  # for constant term (left unscaled)
        ])
        feats = feats / damping

        # The intermediate layer exists only if ``hidden_dim`` was given.
        if hasattr(self, 'hidden'):
            feats = self.hidden(feats)
        return self.superpos(feats)
def create_radial_layer(rngs):
    """Assemble an angle-conditioned radial layer.

    Builds the scalar bijection stack, a ``FourierCondNet`` sized to emit
    one value per parameter of that stack, and combines both in a
    ``bijx.RadialConditional``.

    Args:
        rngs: ``nnx.Rngs`` shared by all sub-module constructors.

    Returns:
        The constructed ``bijx.RadialConditional`` layer.
    """
    scalar_stack = create_bijection_stack(rngs)

    # The conditioning network has to output as many values as the scalar
    # bijection stack has parameters in total.
    n_params = bijx.ModuleReconstructor(scalar_stack).params_total_size
    angle_net = FourierCondNet(output_dim=n_params, fourier_count=7, rngs=rngs)

    return bijx.RadialConditional(
        scalar_bijection=scalar_stack,
        cond_net=angle_net,
        center=jnp.array([-0.5, -1.]),
        scale=jnp.zeros(2),
        rngs=rngs
    )
# 2D flow model: an independent normal prior with event shape (2,) combined
# with the radial layer. Both draw from the shared ``rngs``.
model = bijx.Transformed(
    bijx.IndependentNormal((2,), rngs=rngs),
    create_radial_layer(rngs)
)

Training

# simple training setup
# Number of spiral samples drawn per optimization step.
batch_size = 512

# Adam with learning rate 1e-2; only variables of type nnx.Param are optimized.
optimizer = nnx.Optimizer(
    tx=optax.adam(1e-2),
    model=model,
    wrt=nnx.Param,
)

def loss_fn(model):
    """Negative mean log-likelihood of a fresh spiral batch under the model.

    Args:
        model: Model providing ``rngs`` (called to obtain the sampling key)
            and ``log_density``.

    Returns:
        Scalar negative mean of the model log density over the batch.
    """
    # A new key is taken from the model's RNG streams on every call.
    data_key = model.rngs()
    batch = spiral_data.sample((batch_size,), rng=data_key)
    return -jnp.mean(model.log_density(batch))

@nnx.jit
def update_step(model, optimizer):
    """Perform one jitted optimization step.

    Args:
        model: Model passed to ``loss_fn`` and updated by the optimizer.
        optimizer: ``nnx.Optimizer`` that applies the gradients.

    Returns:
        The loss value computed before the update.
    """
    value_and_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(model)
    # The optimizer applies the gradients to the model it is given; the
    # return value of ``update`` is not used.
    optimizer.update(grads=grads, model=model)
    return loss
# Number of optimization steps.
steps = 5000
# Pre-allocated loss history; NaN marks entries that have not been written.
losses = np.full(steps, np.nan)

for i in tqdm(range(steps)):
    # Storing into the NumPy array converts the returned loss to a host value.
    losses[i] = update_step(model, optimizer)

# Plot the loss curve over training.
plt.figure(figsize=(6, 3))
plt.plot(losses)
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()
100%|██████████| 5000/5000 [00:33<00:00, 148.29it/s]
../_images/dc96b3683ca7457f095f5a3de58da652b3ca43e66e432c68b20a7b4320227c03.png

Evaluation

# Draw 1000 samples from the trained model; the second return value
# is not needed for plotting.
x, _ = model.sample((1000,))
plt.figure(figsize=(3, 3))
plt.title('Samples from learned model')
# Scatter the x and y coordinates of the samples.
plt.scatter(x[:, 0], x[:, 1], s=1)
plt.show()
../_images/09fe0777ce3153789bd4beec64d2f7679997b2b152b91aac1793fbbd59f4cabb.png