Crouch-Grossmann Integration

The bijx.cg module provides a JAX-native implementation of Crouch-Grossmann (CG) integrators for solving ordinary differential equations (ODEs) on matrix manifolds. The method extends to any product space, for example \(G \times \mathbb{R}^d\) where \(G\) is a matrix Lie group and \(\mathbb{R}^d\) represents Euclidean degrees of freedom, by applying Runge-Kutta steps with the same Butcher tableau to the real degrees of freedom.

Standard Runge-Kutta methods update the state additively and can therefore drift numerically away from the manifold. CG integrators instead update the state by multiplying with exponentials of Lie algebra elements (the Lie algebra being identified with the tangent space at the identity), which keeps the state on the Lie group up to floating-point round-off.

The implementation is fully differentiable via the adjoint sensitivity method, enabling end-to-end optimization in models involving ODE solutions, such as continuous normalizing flows and neural ODEs on manifolds.

The Crouch-Grossmann Method

The implementation accepts any PyTree as integration state, where each leaf is either a matrix Lie group element or an array of real degrees of freedom. As an example, consider ODEs on the product space \(G \times \mathbb{R}^d\), where we simultaneously evolve both Euclidean degrees of freedom \(x \in \mathbb{R}^d\) and matrix Lie group elements \(Y \in G\):

\[\frac{dx}{dt} = f(t, x, Y), \quad \frac{dY}{dt} = Z(t, x, Y) \cdot Y\]

where \(f: \mathbb{R} \times \mathbb{R}^d \times G \to \mathbb{R}^d\) is the vector field for the Euclidean components, and \(Z: \mathbb{R} \times \mathbb{R}^d \times G \to \mathfrak{g}\) is a Lie algebra-valued vector field that defines the infinitesimal generator of the group evolution.

The CG method applies a standard Runge-Kutta scheme to the Euclidean components while using exponential maps for the Lie group components to ensure they remain on the manifold.

For a given Butcher tableau with coefficients \(a_{i,j}\), \(b_i\), and \(c_i\), the method computes internal stages and updates as follows:

Internal Stages: For stages \(i = 1, \ldots, s\):

\[x_i^{(n)} = x_n + h \sum_{j=1}^{i-1} a_{i,j} f(t_j^{(n)}, x_j^{(n)}, Y_j^{(n)})\]

\[Y_i^{(n)} = \exp(h a_{i,i-1} Z_{i-1}^{(n)}) \cdots \exp(h a_{i,1} Z_1^{(n)}) Y_n\]

where \(Z_i^{(n)} = Z(t_n + c_i h, x_i^{(n)}, Y_i^{(n)}) \in \mathfrak{g}\) and \(t_i^{(n)} = t_n + c_i h\).

Final Update:

\[x_{n+1} = x_n + h \sum_{i=1}^s b_i f(t_i^{(n)}, x_i^{(n)}, Y_i^{(n)})\]

\[Y_{n+1} = \exp(h b_s Z_s^{(n)}) \cdots \exp(h b_1 Z_1^{(n)}) Y_n\]

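The key insight is that the Lie group component evolves through a product of matrix exponentials. Each factor \(\exp(h b_i Z_i^{(n)})\) lies in \(G\) because \(Z_i^{(n)} \in \mathfrak{g}\), and \(G\) is closed under multiplication, so \(Y_n \in G\) implies \(Y_{n+1} \in G\) (up to floating-point round-off in practice).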
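The stage and update formulas above can be sketched for a single Lie group component in plain NumPy. This is a minimal illustration under stated assumptions, not the bijx implementation: the helper names `expm_taylor` and `cg_step` are hypothetical, the matrix exponential is a truncated Taylor series that is only adequate for small step sizes, and the two-stage coefficients are the midpoint-type values \(a_{2,1} = 1/2\), \(b = (0, 1)\), \(c = (0, 1/2)\).

```python
import numpy as np

def expm_taylor(A, terms=20):
    """Truncated Taylor series of exp(A); adequate only for small ||A||."""
    result = np.eye(A.shape[0], dtype=A.dtype)
    term = np.eye(A.shape[0], dtype=A.dtype)
    for k in range(1, terms):
        term = term @ A / k
        result = result + term
    return result

def cg_step(Z, t, Y, h, a, b, c):
    """One explicit Crouch-Grossmann step for dY/dt = Z(t, Y) Y (hypothetical helper)."""
    stage_Z = []
    for i in range(len(b)):
        # Y_i = exp(h a_{i,i-1} Z_{i-1}) ... exp(h a_{i,1} Z_1) Y_n
        Y_i = Y
        for j in range(i):
            Y_i = expm_taylor(h * a[i][j] * stage_Z[j]) @ Y_i
        stage_Z.append(Z(t + c[i] * h, Y_i))
    # Y_{n+1} = exp(h b_s Z_s) ... exp(h b_1 Z_1) Y_n
    Y_next = Y
    for b_i, Z_i in zip(b, stage_Z):
        Y_next = expm_taylor(h * b_i * Z_i) @ Y_next
    return Y_next

# Generator of rotations about the z-axis, an element of so(3).
K_Z = np.array([[0.0, -1.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0]])

def Z(t, Y):
    # Angular velocity cos(t) about z; the exact rotation angle at t = 1 is sin(1).
    return np.cos(t) * K_Z

a, b, c = ((0.0, 0.0), (0.5, 0.0)), (0.0, 1.0), (0.0, 0.5)
h, n_steps = 0.1, 10
Y_final = np.eye(3)
for n in range(n_steps):
    Y_final = cg_step(Z, n * h, Y_final, h, a, b, c)

angle = np.arctan2(Y_final[1, 0], Y_final[0, 0])
orthogonality_error = np.abs(Y_final.T @ Y_final - np.eye(3)).max()
```

Because every factor is the exponential of a skew-symmetric matrix, `orthogonality_error` stays at round-off level, while `angle` approximates the exact value \(\sin(1)\) up to the discretization error of the two-stage scheme.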

Butcher Tableaux

The behavior of a Crouch-Grossmann integrator is completely determined by its Butcher tableau, characterized by the coefficient matrix \((a_{i,j})\), weight vector \((b_i)\), and node vector \((c_i)\). The tableau coefficients satisfy the consistency conditions:

\[c_i = \sum_{j=1}^{i-1} a_{i,j}, \quad \sum_{i=1}^s b_i = 1\]

The cg.ButcherTableau class stores these coefficients. For explicit methods, \(a_{i,j} = 0\) for \(j \geq i\), so each stage depends only on previously computed stages and no implicit equations need to be solved.

The module provides several predefined explicit tableaux:

  • cg.EULER: First-order method (\(s=1\), order 1)

  • cg.CG2: Second-order method (\(s=2\), order 2)

  • cg.CG3: Third-order method (\(s=3\), order 3)

Higher-order methods are more accurate for a given step size but require more vector field evaluations per step. The predefined tableaux can be inspected directly:

import jax
import jax.numpy as jnp
import bijx
from bijx import cg
bijx.utils.load_shapes_magic()
# example of first and second order methods
print("Euler Tableau (CG1):\n", cg.EULER)
print("\nCG2 Tableau:\n", cg.CG2)
Euler Tableau (CG1):
 ButcherTableau(stages=1, a=((0,),), b=(1,), c=(0,))

CG2 Tableau:
 ButcherTableau(stages=2, a=((0, 0), (0.5, 0)), b=(0, 1), c=(0, 0.5))
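
The printed coefficients can be checked against the consistency conditions by hand. The following sketch does this in plain Python, with the values copied from the output above; the function `is_consistent` is a hypothetical helper for illustration and not part of bijx.

```python
def is_consistent(a, b, c, tol=1e-12):
    """Check an explicit tableau: c_i = sum_{j<i} a_ij, sum_i b_i = 1, a_ij = 0 for j >= i."""
    s = len(b)
    nodes_ok = all(abs(c[i] - sum(a[i][:i])) <= tol for i in range(s))
    weights_ok = abs(sum(b) - 1.0) <= tol
    explicit = all(a[i][j] == 0 for i in range(s) for j in range(i, s))
    return nodes_ok and weights_ok and explicit

# Coefficients copied from the printed tableaux.
euler = dict(a=((0,),), b=(1,), c=(0,))
cg2 = dict(a=((0, 0), (0.5, 0)), b=(0, 1), c=(0, 0.5))

euler_ok = is_consistent(**euler)
cg2_ok = is_consistent(**cg2)

# A deliberately broken variant: the weights no longer sum to one.
broken_ok = is_consistent(a=((0, 0), (0.5, 0)), b=(0.5, 0.4), c=(0, 0.5))
```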

Defining a Vector Field

The solver expects a vector field function with signature vector_field(t, y, args). This function encodes the dynamics \(\frac{dy}{dt} = F(t, y)\) where the interpretation of \(F\) depends on the geometry of the state space.

Function Arguments:

  • t: Current time (scalar)

  • y: State PyTree (can contain both Euclidean arrays and Lie group matrices)

  • args: PyTree of parameters for the vector field

Return Value: The function must return a PyTree with the same structure as y, whose leaves are tangents of the appropriate type:

  • For Euclidean components (\(y_i \in \mathbb{R}^{d_i}\)): Return \(f_i(t, y) \in \mathbb{R}^{d_i}\)

  • For Lie group components (\(Y_j \in G_j\)): Return \(Z_j(t, y) \in \mathfrak{g}_j\) (Lie algebra element)

The primary function for solving ODEs is cg.crouch_grossmann():

crouch_grossmann(vector_field, y0, args, t0, t1, step_size, is_lie, tableau=cg.EULER)

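Here, is_lie is a Boolean PyTree indicating the geometry of each component: True for leaves that are matrix Lie group elements and False for Euclidean leaves. It must match the structure of y0, with every leaf replaced by a plain Python True or False (note: not arrays of booleans!).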
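
The structural requirement on is_lie can be illustrated with JAX's tree utilities. The sketch below is only an assumption about how such a requirement could be checked: `check_is_lie` is a hypothetical helper, and whether bijx performs an equivalent validation internally is not stated in this document.

```python
import jax
import jax.numpy as jnp

def check_is_lie(y0, is_lie):
    """Hypothetical check: same tree structure as y0, and plain Python bool leaves."""
    same_structure = (
        jax.tree_util.tree_structure(y0) == jax.tree_util.tree_structure(is_lie)
    )
    bool_leaves = all(
        isinstance(leaf, bool) for leaf in jax.tree_util.tree_leaves(is_lie)
    )
    return same_structure and bool_leaves

y0 = {'position': jnp.zeros(3), 'rotation': jnp.eye(3)}

ok = check_is_lie(y0, {'position': False, 'rotation': True})
# A boolean array is not a plain Python bool, so this is rejected.
bad_leaf = check_is_lie(y0, {'position': False, 'rotation': jnp.array(True)})
# A missing key changes the tree structure.
bad_structure = check_is_lie(y0, {'rotation': True})
```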

Example Mixed System on \(G \times \mathbb{R}^d\)

Consider a system where \(x \in \mathbb{R}^3\) represents position and \(R \in SO(3)\) represents orientation:

\[\frac{dx}{dt} = v(x, R), \quad \frac{dR}{dt} = \hat\omega(x, R) R\]

where \(\hat \omega: \mathbb{R}^3 \times SO(3) \to \mathfrak{so}(3)\) returns skew-symmetric matrices.

def mixed_vf(t, state, args):
    x, R = state['position'], state['rotation']

    # Velocity depends on both position and orientation
    v = -args['damping'] * x + args['coupling'] * (R @ jnp.array([1.0, 0.0, 0.0]))

    # Angular velocity (in so(3), represented as skew-symmetric matrix)
    omega_vec = args['spin_rate'] * jnp.array([0.0, 0.0, 1.0]) + 0.1 * x[0] * jnp.array([1.0, 0.0, 0.0])
    omega_skew = jnp.array(
        [[0, -omega_vec[2], omega_vec[1]],
         [omega_vec[2], 0, -omega_vec[0]],
         [-omega_vec[1], omega_vec[0], 0]])

    return {'position': v, 'rotation': omega_skew}
# Initial conditions
x0 = jnp.array([1.0, 0.5, 0.2])
R0 = jnp.eye(3)  # Identity rotation
state0 = {'position': x0, 'rotation': R0}

# Specify which components are Lie groups
is_lie = {'position': False, 'rotation': True}

args_mixed = {
    'damping': 0.1,
    'coupling': 0.5,
    'spin_rate': 2.0
}
# Solve the coupled system
state_final = cg.crouch_grossmann(
    mixed_vf, state0, args_mixed, 0.0, 1.0, 0.05, is_lie, tableau=cg.CG3
)
%shapes state_final
{'position': (3,), 'rotation': (3, 3)}
print("Final position:", state_final['position'].round(3))

# verify manifold preservation
print("Final rotation determinant:", jnp.linalg.det(state_final['rotation']).round(6))
print("Final rotation orthogonality check (R^T R):")
print((state_final['rotation'].T @ state_final['rotation']).round(4))
Final position: [1.115      0.79300004 0.19600001]
Final rotation determinant: 1.000002
Final rotation orthogonality check (R^T R):
[[ 1.  0. -0.]
 [ 0.  1. -0.]
 [-0. -0.  1.]]
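
For contrast with the orthogonality check above, the following NumPy sketch compares an additive explicit Euler update with an exponential update for a constant angular velocity. It is a standalone illustration that does not use bijx; the helpers `hat`, `exp_so3` (Rodrigues' formula) and `orthogonality_error` are hypothetical names introduced here.

```python
import numpy as np

def hat(w):
    """Map a vector in R^3 to the corresponding skew-symmetric matrix in so(3)."""
    return np.array([[0.0, -w[2], w[1]],
                     [w[2], 0.0, -w[0]],
                     [-w[1], w[0], 0.0]])

def exp_so3(w):
    """Rodrigues' formula for exp(hat(w))."""
    theta = np.linalg.norm(w)
    K = hat(w)
    if theta < 1e-12:
        return np.eye(3) + K  # first-order expansion near zero
    return (np.eye(3)
            + np.sin(theta) / theta * K
            + (1.0 - np.cos(theta)) / theta**2 * (K @ K))

def orthogonality_error(R):
    """Largest entry of |R^T R - I|."""
    return np.abs(R.T @ R - np.eye(3)).max()

omega = np.array([0.0, 0.0, 2.0])  # constant angular velocity about z
h, n_steps = 0.05, 20

R_euler = np.eye(3)
R_exp = np.eye(3)
for _ in range(n_steps):
    R_euler = R_euler + h * hat(omega) @ R_euler  # additive update, leaves SO(3)
    R_exp = exp_so3(h * omega) @ R_exp            # exponential update, stays in SO(3)

euler_drift = orthogonality_error(R_euler)
exp_drift = orthogonality_error(R_exp)
```

The additive update inflates \(R^T R\) by a factor \(1 + h^2 |\omega|^2\) per step in the rotation plane, so `euler_drift` grows with the number of steps, whereas `exp_drift` remains at round-off level.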

Example of gradient computation on SU(2)

# Vector field for dY/dt = V Y
def su2_vf(t, Y, args):
    # Start from a fixed Lie algebra element (generator 0), scaled by the speed
    v = bijx.lie.SU2_GEN[0] * args['speed']
    # Modulate it in time to make the vector field time dependent
    v = v * jnp.cos(t)
    return v
Y0 = jnp.eye(2, dtype=complex) # Start at the identity
is_lie = True  # just a single boolean (only one array)
# arguments (we'll take gradients w.r.t. this)
args_su2 = {'speed': 0.5}

t0, t1 = 0.0, 1.0
step_size = 0.1
def loss_fn(args_su2):
    Y_final = cg.crouch_grossmann(su2_vf, Y0, args_su2, t0, t1, step_size, is_lie, tableau=cg.CG3)
    # example loss based on final state
    return jnp.trace(Y_final).real

grad_fn = jax.grad(loss_fn)
grad_fn(args_su2)
{'speed': Array(-0.6873728, dtype=float32)}
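
The printed gradient can be cross-checked against a closed form. The vector field is \(\text{speed} \cdot \cos(t)\, T_0\) with a fixed generator \(T_0\), so its values at different times commute and the exact solution is \(Y(1) = \exp(\text{speed} \cdot \sin(1)\, T_0)\). Assuming \(T_0^2 = -I\) (for example \(T_0 = i\sigma_x\); this normalization of bijx.lie.SU2_GEN[0] is an assumption made here, not something stated in this document), the loss is \(\operatorname{Re}\operatorname{tr} Y(1) = 2\cos(\text{speed} \cdot \sin 1)\). A short sketch of the resulting derivative, with hypothetical helper names:

```python
import math

def exact_loss(speed):
    """Re tr Y(1) for dY/dt = speed*cos(t)*T0*Y, Y(0) = I, assuming T0 @ T0 == -I."""
    return 2.0 * math.cos(speed * math.sin(1.0))

def exact_grad(speed):
    """Derivative of exact_loss with respect to speed."""
    return -2.0 * math.sin(1.0) * math.sin(speed * math.sin(1.0))

speed = 0.5
analytic = exact_grad(speed)

# Central finite difference of the closed form as an independent check.
eps = 1e-6
finite_diff = (exact_loss(speed + eps) - exact_loss(speed - eps)) / (2 * eps)
```

Under this assumption the closed-form derivative agrees with the printed gradient to within 1e-4; a small remaining gap is to be expected from float32 arithmetic and the finite step size.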

CG Continuous Flows

A wrapper class, ContFlowCG, builds continuous flows on top of the integration method above. It follows the same interface as ContFlowDiffrax, which is used for flows without Lie group components.