Matrix Lie Group Operations

The bijx.lie module provides tools for numerical operations on Lie groups, with a particular focus on automatic differentiation. These are useful in physics applications where parameters or variables are elements of a Lie group, such as SU(N) or SO(N).

This document outlines the core functionalities, including handling Lie algebra elements, sampling from the Haar measure, and computing gradients of scalar functions defined on Lie groups.

# setup
import jax
import jax.numpy as jnp
import numpy as np
import bijx
from bijx import lie
from flax import nnx

# Seeded Flax NNX random-number stream; each call rngs() returns a fresh JAX PRNG key
rngs = nnx.Rngs(42)

# register the %shapes IPython magic used below
bijx.utils.load_shapes_magic()

Lie Algebra Generators

For convenience, the module provides standard generators for common Lie groups. They form a basis of the tangent space at the identity, i.e. of the Lie algebra.

  • lie.U1_GEN: Generator for U(1). For compatibility with the other generator arrays, it is stored as an array of shape (1, 1, 1) = (basis, n, n) with the single entry \(2i\).

  • lie.SU2_GEN: Pauli matrices (multiplied by \(i\)) for SU(2).

  • lie.SU3_GEN: Gell-Mann matrices (multiplied by \(i\)) for SU(3).

print(f"U(1):  {lie.U1_GEN.shape=}, {lie.U1_GEN.dtype=}")
print(f"SU(2): {lie.SU2_GEN.shape=}, {lie.SU2_GEN.dtype=}")
print(f"SU(3): {lie.SU3_GEN.shape=}, {lie.SU3_GEN.dtype=}")
U(1):  lie.U1_GEN.shape=(1, 1, 1), lie.U1_GEN.dtype=dtype('complex64')
SU(2): lie.SU2_GEN.shape=(3, 2, 2), lie.SU2_GEN.dtype=dtype('complex64')
SU(3): lie.SU3_GEN.shape=(8, 3, 3), lie.SU3_GEN.dtype=dtype('complex64')
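As a sketch of the (basis, n, n) layout, the su(3) basis described above (Gell-Mann matrices multiplied by \(i\)) can be built in plain NumPy. This is an illustration, not bijx's source; the helper gell_mann() is hypothetical.

```python
import numpy as np


def gell_mann():
    """Return the eight Gell-Mann matrices, shape (8, 3, 3), in the standard order."""
    lam = np.zeros((8, 3, 3), dtype=complex)
    # (index of symmetric matrix, index of antisymmetric matrix, row, col)
    for k_sym, k_asym, i, j in [(0, 1, 0, 1), (3, 4, 0, 2), (5, 6, 1, 2)]:
        lam[k_sym, i, j] = lam[k_sym, j, i] = 1
        lam[k_asym, i, j] = -1j
        lam[k_asym, j, i] = 1j
    lam[2] = np.diag([1, -1, 0])
    lam[7] = np.diag([1, 1, -2]) / np.sqrt(3)
    return lam


# Generators in the (basis, n, n) layout, multiplied by i as described above
su3_gen = 1j * gell_mann()
print(su3_gen.shape)  # (8, 3, 3)

# Every generator is anti-hermitian and traceless, i.e. an element of su(3)
assert np.allclose(su3_gen, -np.conj(np.swapaxes(su3_gen, -1, -2)))
assert np.allclose(np.trace(su3_gen, axis1=-2, axis2=-1), 0)
```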

Basic Operations

# Example of scalar product between two SU(2) generators
g1 = lie.SU2_GEN[0]
g2 = lie.SU2_GEN[1]

sp = lie.scalar_prod(g1, g2)
print(f"Scalar product of first two SU(2) generators: {sp:.2f}")

# The basis is orthonormal
sp_self = lie.scalar_prod(g1, g1)
print(f"Scalar product of a generator with itself: {sp_self:.2f}")
Scalar product of first two SU(2) generators: 0.00+0.00j
Scalar product of a generator with itself: 1.00+0.00j
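The output shows that the SU(2) generators are orthonormal with respect to lie.scalar_prod. A minimal NumPy sketch of one inner product with this property, assuming the generators are exactly \(i\sigma_a\) and using the hypothetical choice \(\langle A, B\rangle = \tfrac12 \mathrm{tr}(A^\dagger B)\); bijx's own normalization is not spelled out here and may differ:

```python
import numpy as np

# Pauli matrices multiplied by i, in the (basis, n, n) layout
sigma = np.array([[[0, 1], [1, 0]],
                  [[0, -1j], [1j, 0]],
                  [[1, 0], [0, -1]]], dtype=complex)
su2_gen = 1j * sigma


def scalar_prod(a, b):
    """Hypothetical inner product <A, B> = 1/2 tr(A^dagger B) on su(n)."""
    return 0.5 * np.trace(a.conj().T @ b)


# Gram matrix of all pairs of generators: the identity if the basis is orthonormal
gram = np.array([[scalar_prod(a, b) for b in su2_gen] for a in su2_gen])
print(gram.real)
```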

Sampling from the Haar Measure

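Sampling uniformly from a compact Lie group, i.e. from its normalized Haar measure, gives a natural prior distribution that respects all symmetries of the group: the Haar measure is invariant under both left and right multiplication.

Both an explicit sampling function, lie.sample_haar(), and a general distribution class, lie.HaarDistribution, are implemented.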

# a 5 x 5 grid of SU(2) matrices
dist = lie.HaarDistribution(n=2, base_shape=(5, 5))

%shapes dist.sample(batch_shape=(7,), rng=rngs())
((7, 5, 5, 2, 2), (7,))

For convenience, a lattice constructor is also provided. It uses (*lat_shape, len(lat_shape)) as the “base” shape, since a gauge field associates len(lat_shape) group elements with each lattice site: one link for each lattice direction.

# use SU(3) so that the lattice dimension (2) cannot be confused with n=3 of SU(n)
dist = lie.HaarDistribution.periodic_gauge_lattice(n=3, lat_shape=(5, 5))

# Shape is (*batch_shape, *lattice_shape, lattice_dim, n, n)
%shapes dist.sample(batch_shape=(7,), rng=rngs())
((7, 5, 5, 2, 3, 3), (7,))
# Verify that a matrix sampled with lie.sample_haar() is indeed in SU(2)
su2_matrix = lie.sample_haar(rngs(), n=2)

determinant = jnp.linalg.det(su2_matrix)
# adjoint implements A^† also for batched arrays
identity = lie.adjoint(su2_matrix) @ su2_matrix

print(f"Determinant: {determinant:.1f}")
print(f"U†U:\n{identity.round(2)}")
Determinant: 1.0-0.0j
U†U:
[[ 1.+0.j -0.+0.j]
 [-0.-0.j  1.+0.j]]
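The internals of lie.sample_haar() are not shown here. One standard way to draw Haar-distributed matrices (Mezzadri's QR method) is sketched below in plain NumPy; it illustrates the idea and is not necessarily the algorithm bijx uses. The helper sample_haar_su() is hypothetical. Dividing by an \(n\)-th root of \(\det Q\) maps Haar-distributed U(n) matrices to Haar-distributed SU(n) matrices, because that operation commutes with left multiplication by SU(n).

```python
import numpy as np


def sample_haar_su(n, rng, size=()):
    """Sketch: Haar-random SU(n) matrices via QR of complex Gaussian matrices."""
    shape = (*size, n, n)
    z = (rng.standard_normal(shape) + 1j * rng.standard_normal(shape)) / np.sqrt(2)
    q, r = np.linalg.qr(z)
    # fix the phase ambiguity of QR so that q is Haar-distributed on U(n)
    d = np.diagonal(r, axis1=-2, axis2=-1)
    q = q * (d / np.abs(d))[..., None, :]
    # remove the overall U(1) phase by dividing by an n-th root of det(q)
    phase = np.angle(np.linalg.det(q))
    factor = np.exp(-1j * phase / n)
    return q * np.reshape(factor, np.shape(factor) + (1, 1))


rng = np.random.default_rng(0)
u = sample_haar_su(2, rng, size=(7,))
print(u.shape)  # (7, 2, 2)
```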

Automatic Differentiation

Another purpose of the module is to enable automatic differentiation of functions whose inputs are Lie group elements. The gradient of a scalar function \(f: G \to \mathbb{R}\) at a point \(U \in G\) is an element of the tangent space \(T_U(G)\), which can be identified with the Lie algebra \(\mathfrak{g}\).

The primary function is lie.grad(). It computes the gradient of a function with respect to a Lie group element. The gradient is returned as an element of the Lie algebra itself.

There are two ways to specify the basis:

  • As an array, explicitly listing the basis elements.

  • As a projection: the gradient is first taken in the ambient space of complex matrices and then projected onto the tangent space of the group. This is the default, which projects onto the Lie algebra of SU(N).

# Define a simple scalar function on SU(2)
def loss_fn(U):
  # Returns the real part of the trace
  return jnp.trace(U).real

u = lie.sample_haar(rngs(), n=2)

# The 'algebra' argument specifies the Lie algebra basis.
grad_fn = lie.grad(loss_fn, algebra=lie.SU2_GEN)
grad_u = grad_fn(u)
%shapes grad_u

# The gradient is a Lie algebra element (skew-hermitian and traceless)
print(f"\nTrace of gradient: {jnp.trace(grad_u):.2f}")
print(f"Adjoint plus self:\n{(lie.adjoint(grad_u) + grad_u).round(2)}")
(2, 2)

Trace of gradient: 0.00+0.00j
Adjoint plus self:
[[0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j]]
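A plain NumPy sketch of the two ways of specifying the basis listed above, for \(f(U) = \mathrm{Re}\,\mathrm{tr}\,U\) on SU(2). It assumes the orthonormal basis \(T_a = i\sigma_a\) with respect to the hypothetical inner product \(\langle A, B\rangle = \tfrac12 \mathrm{Re}\,\mathrm{tr}(A^\dagger B)\) and writes tangent vectors at \(U\) as \(VU\); the helper names are illustrative and bijx's normalization may differ. Both routes give the same Lie algebra element.

```python
import numpy as np

# Assumed conventions for this sketch (bijx may normalize differently):
# basis T_a = i * sigma_a and inner product <A, B> = 1/2 Re tr(A^dagger B),
# under which the T_a are orthonormal.
SIGMA = np.array([[[0, 1], [1, 0]],
                  [[0, -1j], [1j, 0]],
                  [[1, 0], [0, -1]]], dtype=complex)
GENS = 1j * SIGMA


def re_trace(u):
    return np.real(np.trace(u))


def directional(f, u, v, eps=1e-6):
    """Central difference of t -> f((1 + t v) u) at t = 0.

    The curve (1 + t v) u has the same velocity v u at t = 0 as exp(t v) u,
    so this approximates d/dt f(exp(t v) u) at t = 0.
    """
    eye = np.eye(u.shape[0])
    return (f((eye + eps * v) @ u) - f((eye - eps * v) @ u)) / (2 * eps)


def grad_from_basis(f, u, gens):
    """Option 1: explicit basis; the a-th component is the derivative along T_a."""
    return sum(directional(f, u, t) * t for t in gens)


def project_su(m):
    """Orthogonal projection onto traceless anti-hermitian matrices."""
    a = 0.5 * (m - m.conj().T)
    return a - np.trace(a) / m.shape[0] * np.eye(m.shape[0])


def grad_from_projection(euclid_grad, u):
    """Option 2: project the ambient (Euclidean) gradient E onto su(n).

    df(V u) = Re tr(E^dagger V u) = Re tr((E u^dagger)^dagger V); the factor 2
    compensates the 1/2 in the assumed inner product.
    """
    return 2 * project_su(euclid_grad @ u.conj().T)


# A fixed SU(2) element built from a unit quaternion (a0, a1, a2, a3)
a = np.array([0.3, -0.5, 0.1, 0.8])
a /= np.linalg.norm(a)
u = a[0] * np.eye(2) + 1j * np.einsum("k,kij->ij", a[1:], SIGMA)

g1 = grad_from_basis(re_trace, u, GENS)
g2 = grad_from_projection(np.eye(2), u)  # Euclidean gradient of Re tr(u) is the identity
print(np.allclose(g1, g2))
```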

Directional Derivative

One can also compute a directional derivative along a specific direction in the tangent space using lie.curve_grad(). This computes

\[ \left.\frac{d}{dt}\right|_{t=0} f(e^{t V} U) \,, \]

where \(V\) is a direction in the tangent space (an element of the Lie algebra). Note that this fixes a convention of left versus right multiplication: \(V\) is transported from the Lie algebra to the tangent space at \(U\) by left multiplication, \(V U\). The opposite convention, \(U V'\), is obtained by conjugating with \(U\): \(V U = U (U^\dagger V U)\), so \(V' = U^\dagger V U\).
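The relation between the two conventions can be checked numerically. The NumPy sketch below is independent of bijx (the eigendecomposition-based exponential and the finite-difference step are choices made for this illustration). It differentiates \(f(U) = \mathrm{Re}\,\mathrm{tr}\,U\) along \(e^{tV}U\) and along \(U e^{tV'}\) with \(V' = U^\dagger V U\); both agree with the closed form \(\mathrm{Re}\,\mathrm{tr}(VU)\).

```python
import numpy as np


def expm_antihermitian(v):
    """exp(v) for anti-hermitian v, via the eigendecomposition of the hermitian -i v."""
    lam, q = np.linalg.eigh(-1j * v)
    return (q * np.exp(1j * lam)) @ q.conj().T


def curve_deriv_left(f, u, v, eps=1e-5):
    """Central difference of t -> f(exp(t v) u) at t = 0 (left convention)."""
    return (f(expm_antihermitian(eps * v) @ u) - f(expm_antihermitian(-eps * v) @ u)) / (2 * eps)


def curve_deriv_right(f, u, w, eps=1e-5):
    """Central difference of t -> f(u exp(t w)) at t = 0 (right convention)."""
    return (f(u @ expm_antihermitian(eps * w)) - f(u @ expm_antihermitian(-eps * w))) / (2 * eps)


def re_trace(u):
    return np.real(np.trace(u))


sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_y = np.array([[0, -1j], [1j, 0]], dtype=complex)
v = 0.5j * sigma_x  # a direction in su(2)
u = expm_antihermitian(0.7j * sigma_y + 0.2j * sigma_x)  # a fixed SU(2) element

left = curve_deriv_left(re_trace, u, v)
# the same curve in the right convention uses the conjugated direction u^dagger v u
right = curve_deriv_right(re_trace, u, u.conj().T @ v @ u)
print(left, right, re_trace(v @ u))
```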

# choose some direction (could be any superposition of generators)
direction = lie.SU2_GEN[0] * 0.5

# Compute directional derivative of loss_fn
dir_grad_fn = lie.curve_grad(loss_fn, direction=direction)
dir_grad_val = dir_grad_fn(u)

# This should be equal to the scalar product of the full gradient and the direction
manual_dir_grad = lie.scalar_prod(direction, grad_u)

print(f"Directional derivative from curve_grad: {dir_grad_val:.4f}")
print(f"Scalar product of full grad and direction: {manual_dir_grad:.4f}")
Directional derivative from curve_grad: -0.9601
Scalar product of full grad and direction: -0.9601+0.0000j

The function lie.path_grad2() works similarly but computes both the first and the second path derivative. Finally, lie.path_div() computes the divergence of a function using path derivatives: it takes a full basis of generators and assumes the function returns the components of a vector with respect to that same basis.
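Conceptually, the second path derivative is \(\left.\frac{d^2}{dt^2}\right|_{t=0} f(e^{tV}U)\). A finite-difference sketch in NumPy (an illustration only; bijx uses automatic differentiation, and the helper path_derivs() is hypothetical) recovers both path derivatives for \(f(U) = \mathrm{Re}\,\mathrm{tr}\,U\), where they equal \(\mathrm{Re}\,\mathrm{tr}(VU)\) and \(\mathrm{Re}\,\mathrm{tr}(V^2U)\):

```python
import numpy as np


def expm_antihermitian(v):
    """exp(v) for anti-hermitian v, via the eigendecomposition of the hermitian -i v."""
    lam, q = np.linalg.eigh(-1j * v)
    return (q * np.exp(1j * lam)) @ q.conj().T


def path_derivs(f, u, v, eps=1e-4):
    """First and second central differences of t -> f(exp(t v) u) at t = 0."""
    f_plus = f(expm_antihermitian(eps * v) @ u)
    f_zero = f(u)
    f_minus = f(expm_antihermitian(-eps * v) @ u)
    return (f_plus - f_minus) / (2 * eps), (f_plus - 2 * f_zero + f_minus) / eps**2


def re_trace(u):
    return np.real(np.trace(u))


sigma_x = np.array([[0, 1], [1, 0]], dtype=complex)
sigma_z = np.array([[1, 0], [0, -1]], dtype=complex)
v = 0.5j * sigma_x  # a direction in su(2)
u = expm_antihermitian(0.9j * sigma_z + 0.4j * sigma_x)  # a fixed SU(2) element

first, second = path_derivs(re_trace, u, v)
# closed forms for f(U) = Re tr U: Re tr(V U) and Re tr(V^2 U)
print(first, re_trace(v @ u))
print(second, re_trace(v @ v @ u))
```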

Value, Gradient, and Divergence

For some applications, especially in the context of flows, one needs not only the gradient but also the divergence of the gradient (the Laplacian). The function lie.value_grad_divergence() computes the value, the gradient, and the Laplacian of a scalar function in a single call, using two backward differentiation passes.

val, grad, div = lie.value_grad_divergence(loss_fn, u, lie.SU2_GEN)
%shapes val, grad, div
((), (2, 2), ())
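A finite-difference sketch of the same three quantities, independent of bijx (which uses automatic differentiation instead). Here the gradient is represented by its components \(F_a(U) = \left.\frac{d}{dt}\right|_{t=0} f(e^{tT_a}U)\) rather than as a \(2 \times 2\) matrix, and the divergence is \(\sum_a \left.\frac{d}{dt}\right|_{t=0} F_a(e^{tT_a}U)\). It again assumes the orthonormal basis \(T_a = i\sigma_a\); helper names are hypothetical. For \(f(U) = \mathrm{Re}\,\mathrm{tr}\,U\), the identity \(\sum_a T_a^2 = -3 \cdot 1\) gives the closed form \(\Delta f = -3f\), which the sketch reproduces.

```python
import numpy as np

# Assumed orthonormal basis T_a = i * sigma_a (w.r.t. <A, B> = 1/2 tr(A^dagger B))
SIGMA = np.array([[[0, 1], [1, 0]],
                  [[0, -1j], [1j, 0]],
                  [[1, 0], [0, -1]]], dtype=complex)
GENS = 1j * SIGMA


def expm_antihermitian(v):
    """exp(v) for anti-hermitian v, via the eigendecomposition of the hermitian -i v."""
    lam, q = np.linalg.eigh(-1j * v)
    return (q * np.exp(1j * lam)) @ q.conj().T


def deriv(f, u, v, eps=1e-4):
    """Central difference of t -> f(exp(t v) u) at t = 0."""
    return (f(expm_antihermitian(eps * v) @ u) - f(expm_antihermitian(-eps * v) @ u)) / (2 * eps)


def grad_components(f, gens):
    """Vector field U -> (d/dt f(exp(t T_a) U))_a, i.e. the gradient in components."""
    return lambda u: np.array([deriv(f, u, t) for t in gens])


def divergence(field, u, gens):
    """sum_a d/dt F_a(exp(t T_a) U) at t = 0, with F given in the same basis."""
    return sum(deriv(lambda w, a=a: field(w)[a], u, t) for a, t in enumerate(gens))


def re_trace(u):
    return np.real(np.trace(u))


# A fixed SU(2) element from a unit quaternion (a0, a1, a2, a3)
a = np.array([0.3, -0.5, 0.1, 0.8])
a /= np.linalg.norm(a)
u = a[0] * np.eye(2) + 1j * np.einsum("k,kij->ij", a[1:], SIGMA)

val = re_trace(u)
grad = grad_components(re_trace, GENS)(u)
lap = divergence(grad_components(re_trace, GENS), u, GENS)
# for f = Re tr U the closed form of the Laplacian is -3 * val
print(val, grad, lap)
```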

Lattice Operations

The bijx.lattice submodule provides utilities for manipulating gauge fields on a (periodic) lattice.

bijx.lattice.gauge.wilson_action() computes the Wilson action of a lattice gauge configuration, and bijx.lattice.gauge.wilson_log_prob() computes the corresponding log probability, i.e. minus the action.
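For reference, a plain NumPy sketch of a Wilson action on a periodic lattice, using the layout (*lat_shape, d, n, n) from above. The overall convention \(S = -\frac{\beta}{N} \sum_x \sum_{\mu<\nu} \mathrm{Re}\,\mathrm{tr}\, P_{\mu\nu}(x)\) with plaquettes \(P_{\mu\nu}(x) = U_\mu(x) U_\nu(x+\hat\mu) U_\mu(x+\hat\nu)^\dagger U_\nu(x)^\dagger\) is an assumption made here; the other common form \(\beta \sum (1 - \mathrm{Re}\,\mathrm{tr}\,P/N)\) differs only by a configuration-independent constant, and bijx's exact normalization is not spelled out in this document.

```python
import numpy as np


def dag(m):
    """Conjugate transpose over the last two axes (works for batched arrays)."""
    return np.conj(np.swapaxes(m, -1, -2))


def plaquettes(lat):
    """All plaquettes P_{mu nu}(x), mu < nu, on a periodic lattice.

    lat has shape (*lat_shape, d, n, n); lat[..., mu, :, :] is the link U_mu(x).
    """
    d = lat.shape[-3]
    out = []
    for mu in range(d):
        for nu in range(mu + 1, d):
            u_mu, u_nu = lat[..., mu, :, :], lat[..., nu, :, :]
            u_nu_xmu = np.roll(u_nu, -1, axis=mu)  # U_nu(x + mu)
            u_mu_xnu = np.roll(u_mu, -1, axis=nu)  # U_mu(x + nu)
            out.append(u_mu @ u_nu_xmu @ dag(u_mu_xnu) @ dag(u_nu))
    return np.stack(out)


def wilson_action(lat, beta):
    """Assumed convention: S = -(beta / n) sum_x sum_{mu<nu} Re tr P_{mu nu}(x)."""
    n = lat.shape[-1]
    return -beta / n * np.real(np.trace(plaquettes(lat), axis1=-2, axis2=-1)).sum()


# Usage: a "cold" configuration (all links = identity) on a 5 x 5 lattice, SU(3)
cold = np.broadcast_to(np.eye(3, dtype=complex), (5, 5, 2, 3, 3))
print(wilson_action(cold, 3.0))  # -beta * (number of plaquettes) = -75.0
```

Links are only indexed by their base site here, so translations reduce to np.roll over the lattice axes, and the action is also invariant under gauge transformations \(U_\mu(x) \to \Omega(x) U_\mu(x) \Omega(x+\hat\mu)^\dagger\).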

import bijx.lattice.gauge as glat

dist = lie.HaarDistribution.periodic_gauge_lattice(n=3, lat_shape=(5, 5))
lat, _ = dist.sample((), rng=rngs())

glat.wilson_action(lat, 3.0)
Array(3.8373678, dtype=float32)
# also supports batched inputs
%shapes glat.wilson_action(lat[None], 3.0)
# multiple couplings
%shapes glat.wilson_action(lat, jnp.array([3.0, 4.0]))
(1,)
(2,)
act = glat.wilson_action(lat, 3.0)

# translation is a symmetry of Wilson action
lat_tr = glat.roll_lattice(lat, (1, 2))
assert jnp.allclose(act, glat.wilson_action(lat_tr, 3.0))
# lat_tr itself differs from lat (translation is not the identity)
assert not jnp.allclose(lat, lat_tr)

# so is rotation (specify the two axes spanning the plane of rotation)
lat_rot = glat.rotate_lat(lat, 0, 1)
assert jnp.allclose(act, glat.wilson_action(lat_rot, 3.0))

# so is a flip (reflection) along an axis
lat_flip = glat.flip_axis(lat, 0)
assert jnp.allclose(act, glat.wilson_action(lat_flip, 3.0))

# and so is swapping two axes
lat_swap = glat.swap_axes(lat, 0, 1)
assert jnp.allclose(act, glat.wilson_action(lat_swap, 3.0))
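A flip of a gauge field is more than an array flip: reflecting \(x_\mu \to -x_\mu\) reverses the orientation of the links along the flipped axis, so those links must be daggered and shifted by one site, \(U'_\mu(x) = U_\mu(Rx - \hat\mu)^\dagger\), while the other links are only reindexed, \(U'_\nu(x) = U_\nu(Rx)\). The NumPy sketch below illustrates this standard construction; it is not bijx's flip_axis source, the helper flip_lattice() is hypothetical, and it carries its own Wilson action under the assumed convention \(S = -\frac{\beta}{N}\sum \mathrm{Re}\,\mathrm{tr}\,P\).

```python
import numpy as np


def dag(m):
    """Conjugate transpose over the last two axes."""
    return np.conj(np.swapaxes(m, -1, -2))


def wilson_action(lat, beta):
    """Assumed convention S = -(beta/n) sum Re tr P_{mu nu}(x), mu < nu, periodic lattice."""
    d, n = lat.shape[-3], lat.shape[-1]
    s = 0.0
    for mu in range(d):
        for nu in range(mu + 1, d):
            p = (lat[..., mu, :, :] @ np.roll(lat[..., nu, :, :], -1, axis=mu)
                 @ dag(np.roll(lat[..., mu, :, :], -1, axis=nu)) @ dag(lat[..., nu, :, :]))
            s += np.real(np.trace(p, axis1=-2, axis2=-1)).sum()
    return -beta / n * s


def flip_lattice(lat, axis):
    """Reflect x_axis -> -x_axis (mod L) for a gauge field of shape (*lat_shape, d, n, n)."""
    # np.roll(np.flip(a), 1)[i] = a[-i mod L]: links are read at the reflected site Rx
    out = np.roll(np.flip(lat, axis=axis), 1, axis=axis)
    # np.flip(a)[i] = a[-i - 1 mod L]: links along the flipped axis come from Rx - axis, daggered
    out[..., axis, :, :] = dag(np.flip(lat[..., axis, :, :], axis=axis))
    return out


rng = np.random.default_rng(1)
z = rng.standard_normal((4, 5, 2, 3, 3)) + 1j * rng.standard_normal((4, 5, 2, 3, 3))
lat, _ = np.linalg.qr(z)  # random unitary links (enough to test the symmetry)
print(wilson_action(lat, 2.0), wilson_action(flip_lattice(lat, 0), 2.0))
```

A naive np.flip of the whole array generally changes the action, because the resulting "plaquettes" are no longer closed loops.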