bijx.lie

Lie group operations and automatic differentiation tools.

This module provides specialized tools for working with matrix Lie groups, particularly for applications in lattice field theory and gauge theories. It focuses on efficient automatic differentiation with respect to group elements, Haar measure sampling, and gradient computations on manifolds.

Key functionality:
  • Haar measure sampling for SU(N) groups

  • Automatic differentiation with respect to matrix group elements

  • Lie algebra projections and tangent space operations

  • Specialized gradient operators for group-valued functions

  • Matrix chain contractions and traces

The implementation is optimized for SU(N) groups commonly used in lattice gauge theory, but should also generalize to other matrix groups such as O(N).

Functions

adjoint(arr)

Compute conjugate transpose (adjoint) of matrix.

compute_haar_density(eigenvalue_angles)

Compute Haar measure density for SU(N) matrices from eigenvalue angles.

construct_su_matrix_from_eigenvalues(rng, ...)

Construct SU(N) matrices from given eigenvalue angles using random eigenvectors.

contract(*factors[, trace, ...])

Contract chain of matrices with Einstein summation.

create_eigenvalue_grid(n[, grid_points])

Create a uniform grid in eigenvalue angle coordinates for SU(N) visualization.

curve_grad(fun, direction[, argnum, ...])

Compute directional derivative along group geodesic.

evaluate_density_on_eigenvalue_grid(...[, ...])

Evaluate a density function on SU(N) using eigenvalue angle coordinates.

grad(fn[, argnum, return_value, has_aux, ...])

Compute gradient with respect to matrix Lie group element.

path_div(fun, gens, us)

Compute divergence of a function.

path_grad(fun, gens, us)

Compute gradient with respect to multiple matrix group inputs.

path_grad2(fun, gens, us)

Compute first and second derivative with respect to (each) matrix input.

sample_haar(rng[, n, batch_shape])

Sample SU(N) matrices uniformly according to Haar measure.

scalar_prod(a, b)

Compute scalar product between Lie algebra elements.

skew_traceless_cot(a, u)

Project cotangent vector to SU(N) Lie algebra.

value_grad_divergence(fn, u, gens)

Compute the gradient and Laplacian (i.e. divergence of grad).

Classes

HaarDistribution

Distribution of SU(N) matrices under Haar measure.

bijx.lie.contract(*factors, trace=False, return_einsum_indices=False)

Contract chain of matrices with Einstein summation.

Performs matrix multiplication chain \(A_1 A_2 \cdots A_n\) with automatic broadcasting over batch dimensions. The contraction follows left-to-right order with proper index management for arbitrary numbers of factors.

Key features:
  • Handles arbitrary number of matrix factors

  • Automatic broadcasting over leading (batch) dimensions

  • Optional trace computation for closed loops

  • Can return einsum indices for debugging/inspection

Parameters:
  • factors – Sequence of arrays representing matrices to contract. Each must have at least 2 dimensions (matrix dimensions).

  • trace – Whether to trace the result (connect first and last indices).

  • return_einsum_indices – Whether to return einsum index strings.

Returns:

Contracted result, or tuple (result, in_indices, out_indices) if return_einsum_indices=True.
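
The left-to-right index bookkeeping described above can be sketched with jnp.einsum. This is a minimal illustration, not bijx's implementation: chain_einsum_indices and chain_contract are hypothetical helpers, and the index letters bijx returns with return_einsum_indices=True may differ. Neighbouring factors share one summed index, "..." carries the broadcast batch dimensions, and trace=True ties the last column index back to the first row index.

```python
import string

import jax.numpy as jnp


def chain_einsum_indices(num_factors, trace=False):
    """Hypothetical helper: einsum subscripts for A_1 A_2 ... A_n (up to 25 factors)."""
    letters = string.ascii_lowercase
    # Factor k uses row index letters[k] and column index letters[k + 1],
    # so each adjacent pair of factors shares (and sums over) one index.
    ins = ["..." + letters[k] + letters[k + 1] for k in range(num_factors)]
    if trace:
        # Close the loop: the last column index becomes the first row index.
        ins[-1] = "..." + letters[num_factors - 1] + letters[0]
        out = "..."
    else:
        out = "..." + letters[0] + letters[num_factors]
    return ",".join(ins), out


def chain_contract(*factors, trace=False):
    """Hypothetical helper: contract the chain with broadcasting over batch dims."""
    ins, out = chain_einsum_indices(len(factors), trace=trace)
    return jnp.einsum(f"{ins}->{out}", *factors)
```

For two factors this produces "...ab,...bc->...ac", and with trace=True "...ab,...ba->...".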

Example

>>> A = jnp.ones((3, 4, 4))
>>> B = jnp.ones((3, 4, 4))
>>> C = contract(A, B)  # Shape (3, 4, 4)
>>> trace_AB = contract(A, B, trace=True)  # Shape (3,)

bijx.lie.scalar_prod(a, b)

Compute scalar product between Lie algebra elements.

Implements the standard scalar product on the Lie algebra: \(\langle A, B \rangle = \frac{1}{2} \text{tr}(A^\dagger B)\)

Parameters:
  • a – First Lie algebra element.

  • b – Second Lie algebra element.

Returns:

Real scalar product value.

Note

For skew-Hermitian matrices \(A, B\), this gives a real result and defines a positive definite inner product on the Lie algebra.
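
A direct transcription of the formula above, assuming the real part is taken explicitly (for skew-Hermitian inputs the imaginary part vanishes anyway). The name algebra_inner is hypothetical. As a check, the basis \(i\sigma_k\) built from the Pauli matrices is orthonormal under this product, since \(\frac{1}{2}\text{tr}(\sigma_j \sigma_k) = \delta_{jk}\).

```python
import jax.numpy as jnp


def algebra_inner(a, b):
    """<A, B> = 1/2 Re tr(A^dagger B), broadcasting over leading dimensions."""
    a_dag = jnp.conj(jnp.swapaxes(a, -1, -2))
    return 0.5 * jnp.real(jnp.trace(a_dag @ b, axis1=-2, axis2=-1))


# Pauli matrices; i * sigma_k are traceless skew-Hermitian, i.e. elements of su(2).
PAULI = jnp.array(
    [[[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]], dtype=jnp.complex64
)
su2_basis = 1j * PAULI
# 3x3 Gram matrix of the basis under the scalar product.
gram = algebra_inner(su2_basis[:, None], su2_basis[None, :])
```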

bijx.lie.adjoint(arr)

Compute conjugate transpose (adjoint) of matrix.

Returns the Hermitian conjugate \(A^\dagger = (A^T)^*\) of the input matrix.

Parameters:

arr – Input matrix array.

Returns:

Conjugate transpose of the input.

bijx.lie.sample_haar(rng, n=2, batch_shape=())

Sample SU(N) matrices uniformly according to Haar measure.

Generates random SU(N) matrices distributed according to the unique left- and right-invariant Haar measure on the group. This is the standard uniform distribution for compact Lie groups.

Parameters:
  • rng – Random key for sampling.

  • n – Matrix size N of the group SU(N) (default: 2, i.e. SU(2)).

  • batch_shape – Shape of batch dimensions for multiple samples.

Returns:

SU(N) matrices of shape batch_shape + (n, n).
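
One standard way to draw Haar-distributed SU(N) matrices is the QR decomposition of a complex Gaussian matrix with a phase correction (Mezzadri's method), followed by dividing out an N-th root of the determinant. This sketch of that well-known technique is not necessarily how sample_haar is implemented, and haar_su_n is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def haar_su_n(key, n=2, batch_shape=()):
    """Sketch: Haar-random SU(n) matrices via QR of a complex Gaussian matrix."""
    k_re, k_im = jax.random.split(key)
    shape = tuple(batch_shape) + (n, n)
    z = (jax.random.normal(k_re, shape) + 1j * jax.random.normal(k_im, shape)) / jnp.sqrt(2.0)
    q, r = jnp.linalg.qr(z)
    # Fix the phase ambiguity of QR so that q is Haar-distributed on U(n):
    # multiply column j of q by the phase of r_jj.
    d = jnp.diagonal(r, axis1=-2, axis2=-1)
    q = q * (d / jnp.abs(d))[..., None, :]
    # Divide out an n-th root of the determinant to land in SU(n).
    det = jnp.linalg.det(q)
    return q * jnp.exp(-1j * jnp.angle(det) / n)[..., None, None]
```

Rescaling by the phase \(e^{-i\arg\det(Q)/N}\) preserves left invariance, because multiplying by any \(V \in SU(N)\) leaves the determinant, and hence the rescaling factor, unchanged.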

Example

>>> key = jax.random.key(42)
>>> su2_matrix = sample_haar(key, n=2)  # Single SU(2) matrix
>>> su3_batch = sample_haar(key, n=3, batch_shape=(10,))  # 10 SU(3) matrices

class bijx.lie.HaarDistribution

Bases: ArrayDistribution

Distribution of SU(N) matrices under Haar measure.

Implements the uniform distribution on the compact Lie group SU(N) according to the normalized Haar measure. This is the unique left- and right-invariant probability measure on the group.

The distribution can handle additional base shape dimensions for applications like lattice gauge theory where matrices are assigned to lattice sites and links.

Parameters:
  • n – Matrix size N of the group SU(N).

  • base_shape – Additional shape dimensions (e.g., lattice structure).

  • rngs – Random number generators for sampling.

Example

>>> # SU(2) matrices on a 4x4 lattice with 2 link directions (one per lattice axis)
>>> haar_dist = HaarDistribution.periodic_gauge_lattice(
...     n=2, lat_shape=(4, 4), rngs=rngs
... )
>>> samples, _ = haar_dist.sample((100,), rng=rngs.next())

classmethod periodic_gauge_lattice(n, lat_shape, rngs=None)

Create Haar distribution for periodic gauge lattice.

Convenience constructor for lattice gauge theories with periodic boundary conditions. Creates SU(N) matrices for each lattice site and spatial direction.

Parameters:
  • n – Matrix size N of the gauge group SU(N).

  • lat_shape – Shape of the spatial lattice.

  • rngs – Random number generators.

Returns:

HaarDistribution with event shape lat_shape + (ndim, n, n) where ndim = len(lat_shape) is the number of spatial dimensions.
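
The event-shape rule above is simple enough to state in code; gauge_event_shape and num_links are hypothetical helpers used only to illustrate it. For the 4x4 example lattice this gives 2 link directions and 32 SU(2) link matrices per sample.

```python
import math


def gauge_event_shape(n, lat_shape):
    """Event shape lat_shape + (ndim, n, n) with ndim = len(lat_shape)."""
    ndim = len(lat_shape)  # one link direction per lattice dimension
    return tuple(lat_shape) + (ndim, n, n)


def num_links(lat_shape):
    """Number of link matrices: one per lattice site and direction."""
    return math.prod(lat_shape) * len(lat_shape)
```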

sample(batch_shape, rng=None, **kwargs)

Sample SU(N) matrices from Haar measure.

Parameters:
  • batch_shape – Shape of batch dimensions.

  • rng – Random key for sampling.

  • **kwargs – Additional arguments (unused).

Returns:

Tuple of (samples, log_density) where log_density is zero since Haar measure is the uniform distribution.

log_density(x, **kwargs)

Evaluate log density under Haar measure.

Parameters:
  • x – SU(N) matrices to evaluate.

  • **kwargs – Additional arguments (unused).

Returns:

Zero log density (uniform distribution on compact group).

Note

The Haar measure is uniform, so the log density is constant (zero after normalization).

bijx.lie.compute_haar_density(eigenvalue_angles)

Compute Haar measure density for SU(N) matrices from eigenvalue angles.

For SU(N) matrices, the Haar measure density in eigenvalue coordinates is proportional to the squared modulus of the Vandermonde determinant (Weyl integration formula):

\[ \rho(\theta_1, \ldots, \theta_{n-1}) \propto \prod_{1 \le j < k \le n} \left| e^{i\theta_j} - e^{i\theta_k} \right|^2 \]

where \(\theta_n = -\sum_{j=1}^{n-1} \theta_j\) ensures \(\det(U) = 1\).

This density accounts for the non-trivial geometry of the group when parameterized by eigenvalue angles, making it essential for proper visualization and integration on SU(N).

Parameters:

eigenvalue_angles – Array of shape (…, n-1) containing the eigenvalue angles \(\theta_1, \ldots, \theta_{n-1}\) for SU(n) matrices.

Returns:

Haar density values of shape (…,) corresponding to each set of angles.
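
The product formula can be evaluated directly by appending \(\theta_n\) and taking all pairs \(j < k\). This sketch (vandermonde_density is a hypothetical name) returns the raw, unnormalized product. By the Weyl integration formula its average over the angle torus \([-\pi, \pi)^{n-1}\) equals \(n!\), so dividing by \(n!\) yields a density whose integral equals the parameter-space volume \((2\pi)^{n-1}\); whether compute_haar_density applies exactly this factor is an assumption here.

```python
import jax.numpy as jnp


def vandermonde_density(angles):
    """Unnormalized |Delta|^2 for SU(n) eigenvalue angles of shape (..., n-1)."""
    last = -jnp.sum(angles, axis=-1, keepdims=True)   # theta_n fixes det(U) = 1
    theta = jnp.concatenate([angles, last], axis=-1)  # (..., n)
    z = jnp.exp(1j * theta)
    j, k = jnp.triu_indices(theta.shape[-1], k=1)     # all index pairs j < k
    return jnp.prod(jnp.abs(z[..., j] - z[..., k]) ** 2, axis=-1)
```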

Example

>>> # SU(2) case - single angle
>>> angles = jnp.array([0.5, 1.0, -0.3])  # Shape (3,)
>>> density = compute_haar_density(angles[..., None])  # Shape (3,)
>>>
>>> # SU(3) case - two angles
>>> angles = jnp.array([[0.1, 0.2], [0.5, -0.1]])  # Shape (2, 2)
>>> density = compute_haar_density(angles)  # Shape (2,)

Note

  • The density is normalized such that \(\int \rho(\theta)\, d\theta\) equals the volume of the parameter space

  • For visualization, divide by total volume to get probability density

  • For SU(2), this reduces to \(\left|2i \sin\theta\right|^2 = 4\sin^2\theta\)

bijx.lie.construct_su_matrix_from_eigenvalues(rng, eigenvalue_angles)

Construct SU(N) matrices from given eigenvalue angles using random eigenvectors.

This function creates SU(N) matrices with a specified eigenvalue structure by:
  1. Generating random unitary eigenvector matrices \(V\) via Haar sampling

  2. Constructing diagonal matrices \(D\) from the eigenvalue angles

  3. Performing the similarity transformation \(U = V D V^{\dagger}\)

The eigenvalues are \(e^{i\theta_1}, \ldots, e^{i\theta_{n-1}}, e^{-i\sum\theta_j}\) to ensure \(\det(U) = 1\).

Parameters:
  • rng – JAX random key for sampling eigenvectors.

  • eigenvalue_angles – Array of shape (…, n-1) containing eigenvalue angles. For SU(2), a single angle \(\theta\) gives eigenvalues \(e^{\pm i\theta}\); for SU(3), two angles \((\theta_1, \theta_2)\) give eigenvalues \(e^{i\theta_1}, e^{i\theta_2}, e^{-i\theta_1-i\theta_2}\).

Returns:

SU(N) matrices of shape (…, n, n) with the specified eigenvalue structure but random eigenvector orientations (uniformly distributed via Haar measure).
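
The three steps above can be sketched as follows, assuming eigenvectors drawn from the Haar measure on U(N); the extra phases of U(N) commute with \(D\) and do not change \(U\). su_from_angles is a hypothetical name, and the QR-based sampler is a swapped-in standard technique rather than bijx's own routine.

```python
import jax
import jax.numpy as jnp


def su_from_angles(key, angles):
    """Sketch: U = V D V^dagger with D = diag(exp(i theta)) and V Haar-random."""
    last = -jnp.sum(angles, axis=-1, keepdims=True)                  # det(U) = 1
    phases = jnp.exp(1j * jnp.concatenate([angles, last], axis=-1))  # (..., n)
    n = phases.shape[-1]
    k_re, k_im = jax.random.split(key)
    shape = phases.shape[:-1] + (n, n)
    z = jax.random.normal(k_re, shape) + 1j * jax.random.normal(k_im, shape)
    q, r = jnp.linalg.qr(z)
    d = jnp.diagonal(r, axis1=-2, axis2=-1)
    v = q * (d / jnp.abs(d))[..., None, :]  # Haar-distributed unitary eigenvectors
    v_dag = jnp.conj(jnp.swapaxes(v, -1, -2))
    # V D scales column j of V by phases[j]; then multiply by V^dagger.
    return (v * phases[..., None, :]) @ v_dag
```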

Example

>>> key = jax.random.key(42)
>>>
>>> # SU(2) matrices with specific eigenvalue angles
>>> angles = jnp.linspace(-jnp.pi, jnp.pi, 100)[..., None]  # Shape (100, 1)
>>> matrices = construct_su_matrix_from_eigenvalues(key, angles)  # (100, 2, 2)
>>>
>>> # SU(3) matrices on a 2D grid of angles
>>> theta1, theta2 = jnp.mgrid[-1:1:50j, -1:1:50j]
>>> angles = jnp.stack([theta1, theta2], axis=-1)  # Shape (50, 50, 2)
>>> matrices = construct_su_matrix_from_eigenvalues(key, angles)
>>> matrices.shape
(50, 50, 3, 3)

Note

  • This is essential for visualizing densities on SU(N) in eigenvalue coordinates

  • The eigenvectors are sampled uniformly, giving the correct Haar measure

  • Each call with the same rng and angles gives identical results

bijx.lie.create_eigenvalue_grid(n, grid_points=200)

Create a uniform grid in eigenvalue angle coordinates for SU(N) visualization.

Generates a regular grid of eigenvalue angles suitable for visualizing probability densities on SU(N) groups. The grid covers the fundamental domain of eigenvalue angles with appropriate boundary handling.

For SU(N), there are N-1 independent angles; the N-th angle is fixed to minus their sum so that det(U) = 1.

Parameters:
  • n – Matrix size N of the group SU(N) (e.g., n=2 for SU(2), n=3 for SU(3)).

  • grid_points – Number of grid points along each dimension.

Returns:

Array of shape (grid_points,) * (n-1) + (n-1,), i.e. n-1 grid axes followed by one axis holding the angles. For n=2: shape (grid_points, 1). For n=3: shape (grid_points, grid_points, 2).
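
A minimal sketch of such a grid, assuming an equally spaced axis over \([-\pi, \pi)\) with an optional shift; the offset parameter stands in for the ε mentioned in the Note below, and its actual value in bijx is not specified here. angle_grid is a hypothetical name.

```python
import jax.numpy as jnp


def angle_grid(n, grid_points=200, offset=0.0):
    """Regular grid over [-pi, pi)^(n-1), shifted by a hypothetical offset."""
    axis = jnp.linspace(-jnp.pi, jnp.pi, grid_points, endpoint=False) + offset
    # One copy of the axis per independent angle; "ij" keeps axis order.
    mesh = jnp.meshgrid(*([axis] * (n - 1)), indexing="ij")
    return jnp.stack(mesh, axis=-1)  # (grid_points,) * (n-1) + (n-1,)
```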

Example

>>> # SU(2) case: single angle from -π to π
>>> angles = create_eigenvalue_grid(n=2, grid_points=100)
>>> angles.shape
(100, 1)
>>> # SU(3) case: two angles, each from -π to π
>>> angles = create_eigenvalue_grid(n=3, grid_points=50)
>>> angles.shape
(50, 50, 2)
>>> # Use with other functions
>>> haar_density = compute_haar_density(angles)
>>> matrices = construct_su_matrix_from_eigenvalues(key, angles)

Note

  • Grid extends from -π + ε to π + ε to avoid boundary singularities

  • Volume element for integration: (2π/grid_points)^(n-1)

  • For n ≥ 4, visualization becomes impractical due to dimensionality

bijx.lie.evaluate_density_on_eigenvalue_grid(density_fn, n, grid_points=200, rng=None, normalize=True, normalization_domain='weyl_chamber')

Evaluate a density function on SU(N) using eigenvalue angle coordinates.

This function evaluates any scalar function on SU(N) matrices by:
  1. Creating a grid in eigenvalue angle space

  2. Constructing SU(N) matrices from these angles

  3. Evaluating the density function on these matrices

  4. Applying the correct Haar measure for integration

This is essential for visualizing and analyzing probability densities that arise in physics applications like lattice gauge theory.

Parameters:
  • density_fn – Function that takes SU(N) matrices (…, n, n) and returns log-densities or densities of shape (…,).

  • n – Matrix size N of the group SU(N).

  • grid_points – Number of grid points along each eigenvalue dimension.

  • rng – Random key for constructing matrices. If None, uses key 0.

  • normalize – Whether to normalize the density to integrate to 1.

  • normalization_domain (str) – Domain of integration. Can be “torus” or “weyl_chamber”. Only affects normalization.

Returns:

(angles, density_values, haar_weights) where:
  • angles: Eigenvalue angle grid of shape (grid_points^(n-1), n-1)

  • density_values: Function values at each grid point

  • haar_weights: Haar measure weights for proper integration

Return type:

tuple

Example

>>> def target_density(U):
...     # Example: density proportional to Re[tr(U²)]
...     return jnp.real(jnp.trace(U @ U, axis1=-2, axis2=-1))
>>>
>>> key = jax.random.key(42)
>>> angles, density, weights = evaluate_density_on_eigenvalue_grid(
...     target_density, n=2, grid_points=100, rng=key
... )
>>>
>>> # For plotting SU(2) density
>>> import matplotlib.pyplot as plt
>>> _ = plt.plot(angles.squeeze(), density)
>>> _ = plt.xlabel('Eigenvalue angle')
>>> _ = plt.ylabel('Density')
>>> plt.close()

Note

  • Returned arrays are flattened for easy iteration/plotting

  • For n=2: angles shape (grid_points, 1), others shape (grid_points,)

  • For n=3: angles shape (grid_points², 2), others shape (grid_points²,)

  • Volume element for integration: (2π/grid_points)^(n-1) * haar_weights
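
The integration rule in the Note can be illustrated for class functions (functions that depend only on eigenvalues), where a diagonal representative \(\mathrm{diag}(e^{i\theta_1}, \ldots, e^{i\theta_n})\) suffices and no random eigenvectors are needed. The sketch below, with the hypothetical name haar_expectation_class_fn, weights each grid point by \(|\Delta|^2 / n!\), multiplies by the volume element (2π/grid_points)^(n-1), and divides by the torus volume \((2\pi)^{n-1}\); it does not reproduce bijx's "weyl_chamber" normalization. A known result gives a check: \(\mathbb{E}\,|\text{tr}\, U|^2 = 1\) under Haar measure on SU(N) for N ≥ 2.

```python
import math

import jax.numpy as jnp


def haar_expectation_class_fn(f, n, grid_points=64):
    """E_Haar[f(U)] on SU(n) for a class function f, by eigenvalue-grid quadrature."""
    axis = jnp.linspace(-jnp.pi, jnp.pi, grid_points, endpoint=False)
    mesh = jnp.meshgrid(*([axis] * (n - 1)), indexing="ij")
    angles = jnp.stack(mesh, axis=-1).reshape(-1, n - 1)  # flattened grid
    theta = jnp.concatenate([angles, -angles.sum(-1, keepdims=True)], axis=-1)
    z = jnp.exp(1j * theta)
    j, k = jnp.triu_indices(n, k=1)
    haar_weights = jnp.prod(jnp.abs(z[:, j] - z[:, k]) ** 2, axis=-1) / math.factorial(n)
    # Class functions only see eigenvalues, so a diagonal representative suffices.
    u_diag = z[:, :, None] * jnp.eye(n)
    cell = (2 * jnp.pi / grid_points) ** (n - 1)  # volume element per grid point
    return jnp.sum(f(u_diag) * haar_weights) * cell / (2 * jnp.pi) ** (n - 1)
```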

bijx.lie.skew_traceless_cot(a, u)

Project cotangent vector to SU(N) Lie algebra.

Transforms the cotangent vector from JAX’s backward pass into an element of the SU(N) Lie algebra (traceless skew-Hermitian matrices). This is the natural projection for SU(N) groups.

Mathematical operation:
  1. Compute \(A^\dagger\), the conjugate transpose of the cotangent \(A\)

  2. Transport to the identity: \(B = U A^\dagger\)

  3. Take the skew-Hermitian part: \(B \leftarrow B - B^\dagger\)

  4. Remove the trace: \(B \leftarrow B - \frac{\text{tr}(B)}{n} I\)
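
Steps 1 to 4 translate directly into array operations. This sketch (project_to_su_n is a hypothetical name) follows the listed steps literally, with no factor of 1/2 in the skew-Hermitian part; whether bijx includes such a factor is not stated above. The output is traceless and skew-Hermitian for any input.

```python
import jax.numpy as jnp


def dagger(x):
    """Conjugate transpose over the last two axes."""
    return jnp.conj(jnp.swapaxes(x, -1, -2))


def project_to_su_n(a, u):
    """Map a cotangent a at group element u into traceless skew-Hermitian matrices."""
    b = u @ dagger(a)   # steps 1-2: B = U A^dagger
    b = b - dagger(b)   # step 3: skew-Hermitian part (no factor 1/2 here)
    n = b.shape[-1]
    tr = jnp.trace(b, axis1=-2, axis2=-1)[..., None, None]
    return b - tr / n * jnp.eye(n, dtype=b.dtype)  # step 4: remove the trace
```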

Parameters:
  • a – Cotangent vector from automatic differentiation.

  • u – Group element at which cotangent is evaluated.

Returns:

Element of SU(N) Lie algebra (traceless skew-Hermitian matrix).

Note

This is more efficient than explicit projection using generators as it avoids computing scalar products with each basis element.

bijx.lie.grad(fn, argnum=0, return_value=False, has_aux=False, algebra=skew_traceless_cot)

Compute gradient with respect to matrix Lie group element.

Computes the Riemannian gradient \(\nabla_g f(g)\) where \(g\) is a matrix Lie group element and \(f\) is a scalar-valued function. The gradient lies in the tangent space \(T_g G\), which is isomorphic to the Lie algebra.

The algebra parameter controls how the cotangent vector from automatic differentiation is projected to the tangent space:

  • Function (a, u) -> v: Custom projection implementation

  • Array of generators: Projection via scalar products with basis elements

  • Default: Efficient SU(N) projection without explicit generators

Parameters:
  • fn – Scalar-valued function to differentiate.

  • argnum – Argument position of the group element.

  • return_value – Whether to return function value along with gradient.

  • has_aux – Whether fn returns auxiliary outputs.

  • algebra – Projection method (function) or generator basis (array).

Returns:

Function computing gradient, or (value, gradient) if return_value=True.
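
The "array of generators" option can be understood through algebra coordinates: perturb the group element as \((1 + \sum_a \omega_a T_a)\, U\) and differentiate with respect to the real coefficients \(\omega_a\) at zero, which needs only one backward pass. This sketch assumes a left perturbation and uses the hypothetical name algebra_grad_components; since only a first derivative at \(\omega = 0\) is taken, the first-order factor \(1 + \omega \cdot T\) gives the same result as \(\exp(\omega \cdot T)\).

```python
import jax
import jax.numpy as jnp


def algebra_grad_components(f, u, gens):
    """Components c_a = d/dw f((1 + w T_a) U) at w = 0, one per generator T_a."""
    n = u.shape[-1]
    eye = jnp.eye(n, dtype=gens.dtype)

    def shifted(omega):
        x = jnp.tensordot(omega.astype(gens.dtype), gens, axes=1)  # sum_a w_a T_a
        return f((eye + x) @ u)  # first order in omega suffices for one derivative

    return jax.grad(shifted)(jnp.zeros(gens.shape[0]))
```

Because the \(i\sigma_a\) are orthonormal under \(\langle A, B \rangle = \frac{1}{2}\text{tr}(A^\dagger B)\), the algebra element \(\sum_a c_a\, i\sigma_a\) is the corresponding gradient in su(2).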

Example

>>> def potential(U):
...     return jnp.real(jnp.trace(U))
>>> grad_potential = grad(potential)
>>> U = sample_haar(key, n=2)
>>> gradient = grad_potential(U)  # Element of su(2) algebra

bijx.lie.value_grad_divergence(fn, u, gens)

Compute the gradient and the Laplacian (i.e. the divergence of the gradient).

This uses two backward passes and an explicit basis of tangent vectors at the identity (gens).

The function fn is assumed to return a scalar.
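
The two-pass structure can be sketched with nested autodiff in algebra coordinates: an inner jax.grad gives the components \(c_a = X_a f\), and an outer jax.jacrev differentiates them again, so the trace of the resulting matrix is \(\sum_a X_a X_a f\). This assumes gens is orthonormal under the scalar product and uses left perturbations; value_grad_laplacian is a hypothetical name. The first-order factor \(1 + \omega \cdot T\) is enough because each pass takes a single derivative at \(\omega = 0\), and nesting still captures the mixed term \(T_a T_b\).

```python
import jax
import jax.numpy as jnp


def value_grad_laplacian(f, u, gens):
    """Value, generator components of the gradient, and Laplacian sum_a X_a X_a f."""
    k, n = gens.shape[0], gens.shape[-1]
    eye = jnp.eye(n, dtype=gens.dtype)

    def move(omega, v):
        # (1 + sum_a omega_a T_a) V agrees with exp(omega . T) V to first order.
        return (eye + jnp.tensordot(omega.astype(gens.dtype), gens, axes=1)) @ v

    def grad_at(v):
        # First backward pass: c_a(V) = X_a f(V).
        return jax.grad(lambda w: f(move(w, v)))(jnp.zeros(k))

    # Second backward pass: jac[a, b] = X_b c_a(U); its trace is the Laplacian.
    jac = jax.jacrev(lambda w: grad_at(move(w, u)))(jnp.zeros(k))
    return f(u), grad_at(u), jnp.trace(jac)
```

For \(f(U) = \mathrm{Re}\,\mathrm{tr}(U)\) on SU(2) with \(T_a = i\sigma_a\), the Laplacian is \(-3\,\mathrm{Re}\,\mathrm{tr}(U)\), since \(\sum_a T_a^2 = -3I\).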

bijx.lie.curve_grad(fun, direction, argnum=0, has_aux=False, return_value=False, left=False)

Compute directional derivative along group geodesic.

Computes the directional derivative \(\frac{d}{dt} f(\ldots, g \exp(tX), \ldots)\Big|_{t=0}\), where \(X\) is the given direction, for the default right action; with left=True the curve is \(\exp(tX)\, g\) instead.

This gives the rate of change of the function along the geodesic in the group generated by the specified Lie algebra direction.

Key features:
  • Supports left or right group action (exp(tX)g vs g exp(tX))

  • Uses custom JVP for efficient differentiation

  • Can return function value simultaneously

Parameters:
  • fun – Function to differentiate.

  • direction – Lie algebra element specifying the direction.

  • argnum – Position of group argument (must be positional).

  • has_aux – Whether fun has auxiliary outputs.

  • return_value – Whether to return function value with derivative.

  • left – Whether to use left group action (default: right action).

Returns:

Function computing directional derivative with same signature as fun.

Note

For computing full gradients, the grad() function is more efficient as it avoids separate computations for each direction.
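
A forward-mode sketch of the directional derivative uses jax.jvp on the curve \(t \mapsto f(g\, c(t))\) (or \(f(c(t)\, g)\) for the left action). Here \(c(t) = 1 + tX\), which agrees with \(\exp(tX)\) to first order and therefore gives the same derivative at \(t = 0\); directional_derivative is a hypothetical name, and bijx's custom JVP may be organized differently.

```python
import jax
import jax.numpy as jnp


def directional_derivative(f, u, x, left=False):
    """(f(U), d/dt f(U exp(tX)) at t = 0), or with exp(tX) U when left=True."""
    eye = jnp.eye(u.shape[-1], dtype=u.dtype)

    def along_curve(t):
        step = eye + t * x  # agrees with exp(tX) to first order in t
        return f(step @ u if left else u @ step)

    value, deriv = jax.jvp(along_curve, (0.0,), (1.0,))
    return value, deriv
```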

Example

>>> direction = 1j * jnp.array([[0, 1], [1, 0]])  # i*sigma_x, a traceless skew-Hermitian su(2) element
>>> directional_grad = curve_grad(potential, direction)
>>> derivative = directional_grad(U)

bijx.lie.path_grad(fun, gens, us)

Compute gradient with respect to multiple matrix group inputs.

Computes gradients of a function with respect to each matrix input in the PyTree us. This is useful for functions that depend on multiple group elements simultaneously.

Parameters:
  • fun – Function to differentiate.

  • gens – Generator basis for the Lie algebra.

  • us – PyTree of matrix group elements.

Returns:

Tuple of (function_value, gradient_tree) where gradient_tree has the same structure as us but with additional generator dimension.
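
For several group inputs, the same algebra-coordinate idea applies leaf by leaf: give every matrix in the PyTree its own coefficient vector and take a single value_and_grad. This sketch assumes left perturbations and returns components of shape batch + (number of generators,) per leaf; pytree_algebra_grad is a hypothetical name, and its output layout is not guaranteed to match path_grad.

```python
import jax
import jax.numpy as jnp


def pytree_algebra_grad(f, gens, us):
    """(f(us), tree of generator components of the derivative for each matrix leaf)."""
    eye = jnp.eye(gens.shape[-1], dtype=gens.dtype)

    def shifted(omegas):
        # Perturb every leaf U as (1 + sum_a w_a T_a) U with its own coefficients w.
        moved = jax.tree_util.tree_map(
            lambda w, u: (eye + jnp.tensordot(w.astype(gens.dtype), gens, axes=1)) @ u,
            omegas,
            us,
        )
        return f(moved)

    zeros = jax.tree_util.tree_map(lambda u: jnp.zeros(u.shape[:-2] + gens.shape[:1]), us)
    return jax.value_and_grad(shifted)(zeros)
```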

Example

>>> gens = SU2_GEN  # su(2) generators
>>> U1, U2 = sample_haar(rngs(), 2, (2,))
>>> def action(us):
...     U1, U2 = us
...     return jnp.real(jnp.trace(U1 @ U2))
>>> value, grads = path_grad(action, gens, [U1, U2])

bijx.lie.path_grad2(fun, gens, us)

Compute first and second derivative with respect to (each) matrix input.

bijx.lie.path_div(fun, gens, us)

Compute divergence of a function.

The function is assumed to return a vector field, given by its components with respect to the generator basis gens.
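
With the vector field given by its components \(v_a\) in the generator basis, the divergence can be sketched as \(\sum_a X_a v_a\), where \(X_a\) differentiates along \(T_a\). This assumes an orthonormal basis and left perturbations (on a compact group the invariant vector fields are themselves divergence-free, so no correction term appears); algebra_divergence is a hypothetical name. For \(v_a(U) = \mathrm{Re}\,\mathrm{tr}(T_a U)\) with \(T_a = i\sigma_a\), the result is \(-3\,\mathrm{Re}\,\mathrm{tr}(U)\).

```python
import jax
import jax.numpy as jnp


def algebra_divergence(v, u, gens):
    """sum_a X_a v_a at U, for a vector field v(U) -> (k,) of generator components."""
    k, n = gens.shape[0], gens.shape[-1]
    eye = jnp.eye(n, dtype=gens.dtype)

    def components(omega):
        # Evaluate the field at (1 + sum_b omega_b T_b) U; first order suffices.
        return v((eye + jnp.tensordot(omega.astype(gens.dtype), gens, axes=1)) @ u)

    jac = jax.jacfwd(components)(jnp.zeros(k))  # jac[a, b] = X_b v_a
    return jnp.trace(jac)
```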