Scalar field theory

import flax
import flax.nnx as nnx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax_autovmap import autovmap

import bijx

Free theory

# Parameters of the free theory: lattice extent and mass.
lat_shape = (32, 32)
mass = 0.5

# The scaling is assembled by hand here;
# note there is also bijx.FreeTheoryScaling (see MCMC section below).
ks = bijx.fourier.fft_momenta(lat_shape, lattice=True)
# Squared momentum: sum of squared components over the last axis of ks.
k_squared = jnp.sum(ks**2, axis=-1)
spectrum = 1 / (mass**2 + k_squared)

# note: spectrum here could also be a nnx.Param to make it trainable;
# or set it to the output of a (trainable) function of ks
scaling = bijx.SpectrumScaling(spectrum)

# Prior = independent normal noise on the lattice pushed through the scaling.
gaussian_noise = bijx.IndependentNormal(lat_shape, rngs=nnx.Rngs(sample=0))
free_theory_prior = bijx.Transformed(gaussian_noise, scaling)

# Draw one configuration (second return value is unused here) and display it.
x, _ = free_theory_prior.sample()
plt.imshow(x)
plt.axis('off')
plt.show()
../_images/f77f80e250ef1eb5760bc3cc8073a7e8ac89d16341374b5cba38afe2856cac56.png

Scalar \(\phi^4\) theory

In this section, the use of the ConvCNF network is demonstrated, along with other methods for the scalar theory. The architecture was originally introduced in [2207.00283], with code available on GitHub that uses the Haiku JAX library. In particular, the demonstration below shows how previously published trained parameters can be converted and used in this framework.

# "build" provides a constructor similar to the jaxnf version

# Random-number streams used for parameter initialisation (stream "params", seed 0).
rngs = nnx.Rngs(params=0)
# Vector field built for a (32, 32) shape; it is the module whose parameters
# are overwritten below and which is integrated by the flow in the sampling section.
# NOTE(review): the meaning of the empty tuple passed as second argument is not
# visible here -- confirm against the bijx.ConvVF.build signature.
vf = bijx.ConvVF.build((32, 32), (), rngs=rngs)

Transfer model parameters from previous version

# convolution conventions change between frameworks, need to relabel orbits

def reshuffle_orbits(kernel_params, kernel_shape):
    """Re-express folded kernel parameters in the orbit labelling of ``vf.conv``.

    The parameters are unfolded with the (flipped) orbit layout returned by
    ``bijx.nn.conv.kernel_d4`` for ``kernel_shape``, shifted along the leading
    axes, and folded again using ``vf.conv.orbits``.

    Args:
        kernel_params: Folded kernel parameters in the old labelling.
        kernel_shape: Shape passed to ``bijx.nn.conv.kernel_d4``.

    Returns:
        The result of ``bijx.nn.conv.fold_kernel`` for the shifted kernel.

    Note:
        Reads the module-level ``vf`` at call time.
    """
    # First return value is forwarded unchanged to fold_kernel below
    # (presumably the orbit count -- confirm against bijx.nn.conv.kernel_d4).
    orbit_count, source_orbits = bijx.nn.conv.kernel_d4(kernel_shape)
    # Reverse the orbit layout along its first two axes.
    flipped_orbits = source_orbits[::-1, ::-1]
    full_kernel = bijx.nn.conv.unfold_kernel(kernel_params, flipped_orbits)

    # Every axis except the last two is rolled: by -1 when its size is even,
    # by 0 (i.e. left as is) when its size is odd.
    leading_sizes = full_kernel.shape[:-2]
    for axis, size in enumerate(leading_sizes):
        shift = (size % 2) - 1
        full_kernel = jnp.roll(full_kernel, shift, axis=axis)

    return bijx.nn.conv.fold_kernel(full_kernel, vf.conv.orbits, orbit_count)
# transfer haiku -> this
file_path = '../../../data/single-L32.npz'  # set this to where parameters are stored
try:
    # NOTE(security): allow_pickle=True unpickles the stored object, which can
    # execute arbitrary code -- only point file_path at archives you trust.
    # The context manager closes the .npz archive as soon as the parameter
    # dictionary has been extracted, instead of leaving the file handle open.
    with np.load(file_path, allow_pickle=True) as archive:
        params_haiku = archive['params'].item()
except FileNotFoundError:
    # Only a missing parameter file is treated as "nothing to load"; errors
    # raised by the transfer itself are no longer caught by this handler.
    print('skipping parameter loading')
else:
    # All converted values live under the '~' entry of the haiku dictionary.
    haiku_weights = params_haiku['~']

    _, params = nnx.split(vf)

    params.replace_by_pure_dict({
        # kernel parameters need their orbits relabelled for this framework
        'conv': {'kernel_params': reshuffle_orbits(haiku_weights['w'], (32, 32))},
        # [None] adds a leading axis to the stored frequencies
        'feature_map': {'features': {0: {'phi_freq': haiku_weights['phi_freq'][None]}}},
        'feature_superposition': haiku_weights['freq_superpos'],
        'time_kernel': {'superposition': haiku_weights['time_superpos']},
    })

    # in place update
    nnx.update(vf, params)

Sampling

# Wrap the vector field in a continuous flow, configured with dt = 1/10.
solver_config = bijx.DiffraxConfig(dt=1/10)
flow = bijx.ContFlowDiffrax(vf, config=solver_config)

# Independent normal noise on the 32x32 lattice, transformed by the flow.
base_noise = bijx.IndependentNormal((32, 32), rngs=nnx.Rngs(sample=0))
sampler = bijx.Transformed(base_noise, flow)

# Draw a single configuration (empty batch shape) and display it;
# the second return value is not used here.
x, _ = sampler.sample(())
plt.imshow(x)
plt.axis('off')
plt.show()
../_images/a069869ef0350bb0386b299b65076cfc19668a31d6c2082c2e8d1913245e9235.png

MCMC sampling

As an example of how independent Metropolis-Hastings can be used to generate samples via a Markov chain, we consider a simple free theory in the following. In particular, we measure the acceptance rate of the Markov chain when a free theory of fixed mass is used as the proposal while the non-Gaussian \(\phi^4\) coupling \(\lambda\) of the target theory is increased.

# Random-number source for the Markov chain (each rngs() call below yields a key).
rngs = nnx.Rngs(43)

# Proposal parameters: a free theory of fixed mass on a 32x32 lattice.
lat_shape = (32, 32)
mass = 0.5

# Proposal distribution: independent normal noise passed through the
# ready-made free-theory scaling (built from mass**2 and the lattice shape).
proposal_noise = bijx.IndependentNormal(lat_shape, rngs=nnx.Rngs(sample=0))
sampler = bijx.Transformed(
    proposal_noise,
    bijx.FreeTheoryScaling(mass**2, lat_shape, half=False),
)
# for convenience, define our phi4 theory

@flax.struct.dataclass
class Phi4Theory:
    """Scalar phi^4 target theory exposing an action and a log-probability.

    Attributes:
        m2: Forwarded as the second argument of
            ``bijx.lattice.scalar.phi4_term`` (default -4).
        lam: Forwarded as the third argument of
            ``bijx.lattice.scalar.phi4_term`` (default 5.0).
    """

    m2: float = -4
    lam: float = 5.0

    @autovmap(None, 2)
    def action(self, phi):
        """Return the sum over all entries of ``phi4_term(phi, m2, lam)``.

        The ``autovmap(None, 2)`` decorator is applied with ``None`` for
        ``self`` and ``2`` for ``phi`` (presumably the rank of a single field
        configuration, with extra leading axes vectorised -- confirm against
        the jax_autovmap documentation).
        """
        # NOTE(review): the original code carried a disabled "/ 2" after the
        # sum; the normalisation in effect is the plain sum.
        local_terms = bijx.lattice.scalar.phi4_term(phi, self.m2, self.lam)
        return jnp.sum(local_terms)

    def log_prob(self, phi):
        """Return the negated action of ``phi``."""
        action_value = self.action(phi)
        return -action_value
@jax.jit
def imh_step(rng, state, mass, lam):
    """Advance the independent Metropolis-Hastings chain by one step.

    The target is a ``Phi4Theory`` with ``m2 = mass**2`` and coupling ``lam``;
    proposals come from the module-level ``sampler``.

    Args:
        rng: Random key passed to ``IMH.step``.
        state: Current chain state, as produced by ``imh_init`` or a
            previous ``imh_step``.
        mass: Mass whose square is used as ``m2`` of the target.
        lam: Coupling ``lam`` of the target.

    Returns:
        Whatever ``bijx.mcmc.IMH.step`` returns (unpacked by the caller as
        ``(state, info)``).
    """
    log_density = Phi4Theory(m2=mass**2, lam=lam).log_prob
    chain = bijx.mcmc.IMH(sampler, log_density)
    return chain.step(rng, state)

@jax.jit
def imh_init(rng, mass, lam):
    """Create the initial state of the independent Metropolis-Hastings chain.

    The target is a ``Phi4Theory`` with ``m2 = mass**2`` and coupling ``lam``;
    proposals come from the module-level ``sampler``.

    Args:
        rng: Random key passed to ``IMH.init``.
        mass: Mass whose square is used as ``m2`` of the target.
        lam: Coupling ``lam`` of the target.

    Returns:
        The value returned by ``bijx.mcmc.IMH.init``.
    """
    log_density = Phi4Theory(m2=mass**2, lam=lam).log_prob
    chain = bijx.mcmc.IMH(sampler, log_density)
    return chain.init(rng)
# Scan the coupling lambda and record the IMH acceptance rate for each value.
count = 5000  # number of Markov-chain steps per coupling
lams = np.linspace(0, 0.5, 8)
accept_rate = np.zeros(lams.shape)

for i, lam in enumerate(lams):
    # Fresh chain for every coupling value.
    state = imh_init(rngs(), mass, lam)

    accepted = 0
    for _ in range(count):
        state, info = imh_step(rngs(), state, mass, lam)
        accepted = accepted + info.is_accepted

    # Fraction of accepted proposals over the whole chain.
    accept_rate[i] = accepted / count

# Acceptance rate as a function of the coupling.
plt.plot(lams, accept_rate)
plt.title('independent MH MC with free theory proposal')
plt.xlabel('lambda')
plt.ylabel('acceptance rate')
plt.show()
../_images/44949c31febbf873207b27978785bccbd0e258e1eff671edd4be9494e176bbad.png