bijx.ContFlowCG

class bijx.ContFlowCG

Bases: Bijection

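Continuous normalizing flow integrated with a Crouch-Grossmann scheme, a Runge-Kutta-type Lie group integrator that advances matrix group components by composing matrix exponentials, so they remain on the group.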

The state can be any pytree whose leaves are real or complex arrays or matrix group elements. See bijx.cg for more details.
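To make the integration scheme concrete, here is a minimal NumPy sketch of one Crouch-Grossmann step for a rotation-matrix state. It assumes the vector field returns a Lie algebra element K with dY/dt = K Y; that convention, and the helper names hat, expm_so3, cg_step and integrate, are hypothetical and not part of bijx. The coefficients are copied from the default tableau in the __init__ signature. Every stage and the final update are built by left-multiplying exponentials, so the result stays orthogonal up to rounding.

```python
import numpy as np

def hat(w):
    """Map a 3-vector to the corresponding skew-symmetric matrix in so(3)."""
    return np.array([[0.0, -w[2], w[1]],
                     [w[2], 0.0, -w[0]],
                     [-w[1], w[0], 0.0]])

def expm_so3(K):
    """Exponential of a skew-symmetric 3x3 matrix via Rodrigues' formula."""
    w = np.array([K[2, 1], K[0, 2], K[1, 0]])
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + K + 0.5 * (K @ K)  # series fallback near zero
    return (np.eye(3) + (np.sin(theta) / theta) * K
            + ((1.0 - np.cos(theta)) / theta**2) * (K @ K))

# Coefficients of the default tableau from the __init__ signature, as exact fractions
A = [[0.0, 0.0, 0.0],
     [3 / 4, 0.0, 0.0],
     [119 / 216, 17 / 108, 0.0]]
B = [13 / 51, -2 / 3, 24 / 17]
C = [0.0, 3 / 4, 17 / 24]

def cg_step(F, t, Y, h):
    """One Crouch-Grossmann step for dY/dt = F(t, Y) @ Y with F(t, Y) in so(3)."""
    K = []
    for i in range(len(B)):
        # Stage i: Y_i = exp(h a_{i,i-1} K_{i-1}) ... exp(h a_{i,0} K_0) Y
        Yi = Y
        for j in range(i):
            Yi = expm_so3(h * A[i][j] * K[j]) @ Yi
        K.append(F(t + C[i] * h, Yi))
    # Update: Y_new = exp(h b_{s-1} K_{s-1}) ... exp(h b_0 K_0) Y
    Y_new = Y
    for i in range(len(B)):
        Y_new = expm_so3(h * B[i] * K[i]) @ Y_new
    return Y_new

def integrate(F, Y0, t_start=0.0, t_end=1.0, steps=20):
    """Fixed-step integration from t_start to t_end."""
    h = (t_end - t_start) / steps
    Y = Y0
    for n in range(steps):
        Y = cg_step(F, t_start + n * h, Y, h)
    return Y

# Time- and state-dependent rotation rate (an arbitrary illustrative choice)
F = lambda t, Y: hat(np.array([np.cos(t), 0.5 * Y[0, 1], t]))
R = integrate(F, np.eye(3))
```

For a constant rate every exponential in a step commutes, so the product collapses to exp(hK) because the weights b sum to one; for a varying rate the scheme still stays on SO(3), which an additive Runge-Kutta update would not guarantee.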

Parameters:
  • vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).

  • is_lie (Any) – Specification of which components are treated as matrix group elements.

  • t_start (float) – Integration start time.

  • t_end (float) – Integration end time.

  • steps (int) – Number of integration steps.

  • tableau (ButcherTableau) – Butcher tableau specifying the integration scheme.
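A vector field with the documented signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt) might look like the following NumPy sketch for a Euclidean state. It assumes the common continuous-flow convention that the log density of a transported sample changes at rate minus the divergence of the velocity; whether bijx uses exactly this sign is an assumption here. tanh_vf and its W and b keyword arguments are hypothetical.

```python
import numpy as np

def tanh_vf(t, x, *, W, b):
    """Hypothetical vector field f(t, x) = tanh(W x + t b) plus its log-density rate."""
    z = W @ x + t * b
    dx_dt = np.tanh(z)
    # Divergence: sum_i df_i/dx_i = sum_i (1 - tanh(z_i)**2) * W_ii
    divergence = np.sum((1.0 - np.tanh(z) ** 2) * np.diag(W))
    # Instantaneous change of variables: d(log p)/dt = -divergence
    return dx_dt, -divergence

rng = np.random.default_rng(0)
W, b = rng.normal(size=(3, 3)), rng.normal(size=3)
x = rng.normal(size=3)
dx_dt, dlogp_dt = tanh_vf(0.5, x, W=W, b=b)  # extra parameters arrive as keyword arguments
```

The divergence is computed in closed form here; for fields without a cheap analytic Jacobian trace, it would have to be obtained another way.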

Note

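The is_lie parameter determines which state components are treated as matrix group elements during geometric integration; all other components are integrated as ordinary real or complex arrays. It should be a pytree of booleans with the same structure as the state. For a state consisting of a single array, as in the example below, a single bool such as the default True is sufficient.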
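The effect of the flag can be sketched with a first-order (Lie-Euler) update over a dict-shaped state. This is a hypothetical illustration, not bijx code: it assumes the velocity of a group leaf is a Lie algebra element applied by left multiplication, and it uses SO(2) so the exponential has a closed form. The names expm_so2 and lie_euler_update are made up for this sketch.

```python
import numpy as np

def expm_so2(K):
    """Exponential of a 2x2 skew-symmetric matrix [[0, -a], [a, 0]]: a rotation by a."""
    a = K[1, 0]
    return np.array([[np.cos(a), -np.sin(a)],
                     [np.sin(a), np.cos(a)]])

def lie_euler_update(state, velocity, is_lie, h):
    """Hypothetical first-order step that treats each leaf according to its is_lie flag."""
    assert state.keys() == is_lie.keys(), "is_lie must match the state structure"
    new_state = {}
    for key, x in state.items():
        v = velocity[key]
        if is_lie[key]:
            # Group leaf: move along the exponential map, so the result stays in SO(2)
            new_state[key] = expm_so2(h * v) @ x
        else:
            # Euclidean leaf: ordinary additive update
            new_state[key] = x + h * v
    return new_state

state = {"U": np.eye(2), "phi": np.array([1.0, 2.0])}
is_lie = {"U": True, "phi": False}
velocity = {"U": np.array([[0.0, -0.3], [0.3, 0.0]]), "phi": np.array([0.5, -1.0])}
new = lie_euler_update(state, velocity, is_lie, h=0.1)

# Treating U as an ordinary array instead would leave the group:
naive = state["U"] + 0.1 * velocity["U"]
```

The naive additive update gives naive.T @ naive = (1 + 0.0009) I, which is no longer orthogonal; marking the leaf with True avoids this drift.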

Example

>>> from bijx import ContFlowCG, cg
>>> # Vector field for SU(N) gauge field evolution; SomeGaugeVF is a
>>> # placeholder for a user-defined vector field, and U and log_density
>>> # are assumed to hold the initial gauge links and log densities.
>>> flow = ContFlowCG(SomeGaugeVF(), is_lie=True, tableau=cg.CG3)
>>> U_final, log_density_final = flow.forward(U, log_density)

__init__(vf, is_lie=True, *, t_start=0, t_end=1, steps=20, tableau=ButcherTableau(stages=3, a=((0, 0, 0), (0.75, 0, 0), (0.5509259259259259, 0.1574074074074074, 0)), b=(0.2549019607843137, -0.6666666666666666, 1.411764705882353), c=(0, 0.75, 0.7083333333333334)))

Methods

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

solve_flow(x, log_density, *[, t_start, ...])

Solve the ODE flow using Crouch-Grossmann integration.
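Conceptually, invert() returns an object whose forward and reverse are swapped. The pure-Python sketch below assumes nothing about bijx internals; ScaleBijection and Inverted are hypothetical classes, and the log-density bookkeeping follows the usual change-of-variables sign (subtract log|s| going forward).

```python
import math

class ScaleBijection:
    """Hypothetical 1D bijection y = s * x."""
    def __init__(self, s):
        self.s = s

    def forward(self, x, log_density):
        # Density of y = s * x picks up -log|s|
        return self.s * x, log_density - math.log(abs(self.s))

    def reverse(self, y, log_density):
        # Undo both the map and the log-density change
        return y / self.s, log_density + math.log(abs(self.s))

    def invert(self):
        return Inverted(self)

class Inverted:
    """Wrapper whose forward is the wrapped bijection's reverse, and vice versa."""
    def __init__(self, base):
        self.base = base

    def forward(self, x, log_density):
        return self.base.reverse(x, log_density)

    def reverse(self, y, log_density):
        return self.base.forward(y, log_density)

    def invert(self):
        return self.base

b = ScaleBijection(2.0)
inv = b.invert()
y, ld = b.forward(3.0, 0.0)
x, ld0 = inv.forward(y, ld)
```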

solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)

Solve the ODE flow using Crouch-Grossmann integration.

Parameters:
  • x – Initial state; any pytree of arrays and/or matrix group elements, as described for the class.

  • log_density – Initial log density values.

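  • t_start – Override for the integration start time; if None, the value set at construction is used.

  • t_end – Override for the integration end time; if None, the value set at construction is used.

  • steps – Override for the number of integration steps; if None, the value set at construction is used.

  • **kwargs – Additional keyword arguments passed to the vector field.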

Returns:

Final state tuple (x_final, log_density_final).
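For a Euclidean state, the fixed-step loop can be sketched as an explicit Runge-Kutta integration of the augmented system (x, log_density) with step size h = (t_end - t_start) / steps. This NumPy sketch is an assumption about the structure, not bijx's implementation. On purely Euclidean leaves every exponential reduces to a translation, so a Crouch-Grossmann step coincides with the classical explicit Runge-Kutta step for the same tableau. FixedStepFlow and decay_vf are hypothetical; the coefficients are those of the default tableau.

```python
import numpy as np

# Default tableau coefficients as listed in the __init__ signature (exact fractions)
A = np.array([[0.0, 0.0, 0.0],
              [3 / 4, 0.0, 0.0],
              [119 / 216, 17 / 108, 0.0]])
B = np.array([13 / 51, -2 / 3, 24 / 17])
C = np.array([0.0, 3 / 4, 17 / 24])

class FixedStepFlow:
    """Hypothetical flow storing default integration settings."""
    def __init__(self, vf, t_start=0.0, t_end=1.0, steps=20):
        self.vf, self.t_start, self.t_end, self.steps = vf, t_start, t_end, steps

    def solve_flow(self, x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs):
        # None means "use the value chosen at construction"
        t0 = self.t_start if t_start is None else t_start
        t1 = self.t_end if t_end is None else t_end
        n = self.steps if steps is None else steps
        h = (t1 - t0) / n
        for k in range(n):
            t = t0 + k * h
            kx, kl = [], []
            for i in range(len(B)):
                # Stage state: x + h * sum_{j<i} a_ij * kx_j (explicit scheme)
                xi = x + h * sum(A[i, j] * kx[j] for j in range(i))
                dx, dl = self.vf(t + C[i] * h, xi, **kwargs)
                kx.append(dx)
                kl.append(dl)
            x = x + h * sum(B[i] * kx[i] for i in range(len(B)))
            log_density = log_density + h * sum(B[i] * kl[i] for i in range(len(B)))
        return x, log_density

def decay_vf(t, x, rate=-1.0):
    # dx/dt = rate * x has divergence rate * x.size, so d(log p)/dt = -rate * x.size
    return rate * x, -rate * x.size

flow = FixedStepFlow(decay_vf)
x20, lp20 = flow.solve_flow(np.array([1.0]), 0.0)           # constructor defaults
x40, lp40 = flow.solve_flow(np.array([1.0]), 0.0, steps=40)  # override steps
x_half, _ = flow.solve_flow(np.array([1.0]), 0.0, t_end=0.5)  # override t_end
```

Here the exact solution is x(1) = e^(-1) with log density 1; doubling the step count shrinks the state error by roughly a factor of eight, as expected for a third-order tableau.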

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms the input through the bijection and updates the log density according to the change-of-variables formula.

Inherited from the base class: for convenience, Bijection() itself is the identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
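In formulas, the standard change-of-variables rule behind this update and its instantaneous form for continuous flows read as follows. Here f is the bijection and v the velocity returned by vf; the sign convention (log density of the transformed sample) is assumed rather than taken from the bijx source.

```latex
\log p_Y\big(f(x)\big) = \log p_X(x) - \log\left|\det \frac{\partial f}{\partial x}\right|,
\qquad
\frac{d}{dt}\log p\big(x(t)\big) = -\operatorname{tr}\frac{\partial v}{\partial x}\big(t, x(t)\big).
```

Integrating the second identity from t_start to t_end yields the first for the flow map, because by Jacobi's formula the log-determinant of its Jacobian grows at rate tr(∂v/∂x).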

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms the input through the inverse bijection and updates the log density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
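For a continuous flow, one natural way to realize the reverse transformation is to integrate the same ODE from t_end back to t_start, which undoes both the state update and the log-density change up to integration error. The sketch below uses classical fourth-order Runge-Kutta as a stand-in (not the Crouch-Grossmann scheme) on a hypothetical scalar field f(x) = sin(x), assuming the log-density rate is minus the divergence, -cos(x). The names vf and rk4_flow are made up for this illustration.

```python
import math

def vf(t, x):
    """Hypothetical scalar field f(x) = sin(x) and its log-density rate -f'(x)."""
    return math.sin(x), -math.cos(x)

def rk4_flow(x, log_density, t_start, t_end, steps):
    """Integrate the augmented ODE (x, log p) with classical RK4; h < 0 runs backward."""
    h = (t_end - t_start) / steps
    t = t_start
    for _ in range(steps):
        k1 = vf(t, x)
        k2 = vf(t + h / 2, x + h / 2 * k1[0])
        k3 = vf(t + h / 2, x + h / 2 * k2[0])
        k4 = vf(t + h, x + h * k3[0])
        x += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        log_density += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        t += h
    return x, log_density

x0, lp0 = 1.0, 0.0
y, lp_y = rk4_flow(x0, lp0, 0.0, 1.0, 20)          # forward: t_start -> t_end
x_back, lp_back = rk4_flow(y, lp_y, 1.0, 0.0, 20)  # reverse: t_end -> t_start
```

The forward pass changes the log density by -log(sin x(1) / sin x(0)), and the backward pass applies the opposite change, matching the sign relation stated for reverse().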