bijx.ContFlowCG

class bijx.ContFlowCG

Bases: Bijection

Continuous normalizing flow using Crouch-Grossmann integration.

The state can be any pytree whose leaves are real or complex arrays or matrix group elements. See bijx.cg for details of the integrator.

Parameters:
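  • vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).

  • x_type (Any) – Type of the state x. Defaults to Unitary(..., special=True); see the __init__ signature for the full default.

  • t_start (float) – Integration start time (default 0).

  • t_end (float) – Integration end time (default 1).

  • steps (int) – Number of integration steps (default 20).

  • tableau (ButcherTableau) – Butcher tableau specifying the integration scheme. The default is the two-stage tableau of the explicit midpoint rule shown in the __init__ signature.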
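To make the vf contract concrete, here is a minimal sketch of a vector field for a plain real-vector state. The name linear_vf and the matrix A are hypothetical, and we assume log_density tracks the density of the moving sample, so that by the instantaneous change-of-variables formula its rate of change is minus the divergence of the velocity (for a linear field dx/dt = A x, that is -tr(A)).

```python
import numpy as np

# Hypothetical linear vector field dx/dt = A @ x for a real 2-vector state.
A = np.array([[-1.0, 0.5],
              [0.0, -0.5]])

def linear_vf(t, x, **kwargs):
    """Follows the (t, x, **kwargs) -> (dx/dt, d(log_density)/dt) contract."""
    dx_dt = A @ x
    # Assumed convention: d(log p)/dt = -div(f), and div(A x) = tr(A).
    dlogp_dt = -np.trace(A)
    return dx_dt, dlogp_dt

dx, dlogp = linear_vf(0.0, np.array([1.0, 2.0]))
# dx == [0.0, -1.0], dlogp == 1.5
```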

Example

>>> from bijx import ContFlowCG, cg
>>> # Vector field for SU(N) gauge field evolution (SomeGaugeVF is a placeholder)
>>> flow = ContFlowCG(SomeGaugeVF(), tableau=cg.CG3)
>>> U_final, log_density_final = flow.forward(U, log_density)

__init__(vf, x_type=Unitary(right_invariant=True, transport_adjoint=False, project_stage=False, project_step=False, special=True), *, t_start=0, t_end=1, steps=20, tableau=ButcherTableau(stages=2, a=((0.0, 0.0), (0.5, 0.0)), b=(0.0, 1.0), c=(0.0, 0.5)))
Parameters:
  • vf (Callable)

  • x_type (Any)

  • t_start (float)

  • t_end (float)

  • steps (int)

  • tableau (ButcherTableau)
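The default tableau in the signature above (a21 = 0.5, b = (0, 1), c = (0, 0.5)) is that of the explicit midpoint rule. The sketch below illustrates how a Crouch-Grossmann scheme uses such a tableau on SU(3): each stage and the final update compose matrix exponentials of Lie-algebra increments instead of adding them, so the state stays in the group. It is plain NumPy, not bijx's implementation; it assumes the left-multiplied form dY/dt = F(t, Y) Y with F valued in su(N), uses a hypothetical vector field F, and omits log-density tracking.

```python
import numpy as np

def expm_antihermitian(X):
    """exp(X) for anti-Hermitian X, via eigh of the Hermitian matrix -iX."""
    w, V = np.linalg.eigh(-1j * X)
    return (V * np.exp(1j * w)) @ V.conj().T

def project_su(Z):
    """Traceless anti-Hermitian part of Z, i.e. its projection onto su(N)."""
    A = 0.5 * (Z - Z.conj().T)
    n = Z.shape[0]
    return A - np.trace(A) / n * np.eye(n)

def cg_step(F, t, Y, h, a, b, c):
    """One explicit Crouch-Grossmann step for dY/dt = F(t, Y) @ Y (assumed form)."""
    K = []
    for i in range(len(b)):
        Yi = Y
        for j in range(i):
            # Left-multiply by exp(h a_ij K_j); the first stage ends up rightmost.
            Yi = expm_antihermitian(h * a[i][j] * K[j]) @ Yi
        K.append(F(t + c[i] * h, Yi))
    Y_new = Y
    for j in range(len(b)):
        Y_new = expm_antihermitian(h * b[j] * K[j]) @ Y_new
    return Y_new

# Default tableau from the __init__ signature (explicit midpoint).
a = ((0.0, 0.0), (0.5, 0.0))
b = (0.0, 1.0)
c = (0.0, 0.5)

rng = np.random.default_rng(0)
M = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))

def F(t, Y):
    # Hypothetical vector field with values in su(3).
    return project_su(M @ Y)

Y = np.eye(3, dtype=complex)
steps, t_start, t_end = 20, 0.0, 1.0
h = (t_end - t_start) / steps
for n in range(steps):
    Y = cg_step(F, t_start + n * h, Y, h, a, b, c)
# Y remains unitary with determinant 1, up to rounding.
```

Replacing the exponential updates with additive ones would give an ordinary explicit Runge-Kutta method, whose iterates generally drift off the unitary group.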

Methods

  • forward(x, log_density, **kwargs) – Apply forward transformation.

  • invert() – Create an inverted version of this bijection.

  • reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.

  • solve_flow(x, log_density, *[, t_start, ...]) – Solve the ODE flow using Crouch-Grossmann integration.

solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)

Solve the ODE flow using Crouch-Grossmann integration.

Parameters:
  • x – Initial state: any pytree of real or complex arrays or matrix group elements.

  • log_density – Initial log density values.

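  • t_start – If given, overrides the integration start time set at construction.

  • t_end – If given, overrides the integration end time set at construction.

  • steps – If given, overrides the number of integration steps set at construction.

  • **kwargs – Additional arguments passed to the vector field.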

Returns:

Final state tuple (x_final, log_density_final).
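For a state that lives in an ordinary vector space, the group exponential is plain addition, so a Crouch-Grossmann step reduces to an explicit Runge-Kutta step. Under that simplification, the following sketch shows the bookkeeping that solve_flow describes: fixed steps from t_start to t_end, with the log-density integrated alongside the state using the same tableau. solve_flow_sketch and diag_vf are hypothetical names, and unlike solve_flow the sketch takes concrete defaults instead of None.

```python
import numpy as np

def solve_flow_sketch(vf, x, log_density, t_start=0.0, t_end=1.0, steps=20,
                      a=((0.0, 0.0), (0.5, 0.0)), b=(0.0, 1.0), c=(0.0, 0.5),
                      **kwargs):
    """Explicit Runge-Kutta integration of (x, log_density) for a vector-space state."""
    h = (t_end - t_start) / steps
    for n in range(steps):
        t = t_start + n * h
        kx, kl = [], []
        for i in range(len(b)):
            # Stage state built from earlier stage velocities (empty sum is 0).
            xi = x + h * sum(a[i][j] * kx[j] for j in range(i))
            dx, dl = vf(t + c[i] * h, xi, **kwargs)
            kx.append(dx)
            kl.append(dl)
        x = x + h * sum(bi * k for bi, k in zip(b, kx))
        log_density = log_density + h * sum(bi * k for bi, k in zip(b, kl))
    return x, log_density

# Hypothetical diagonal linear field dx/dt = D * x with divergence sum(D).
D = np.array([-1.0, 0.5])

def diag_vf(t, x):
    return D * x, -np.sum(D)

x_final, logp_final = solve_flow_sketch(diag_vf, np.ones(2), 0.0)
```

Because the divergence is constant here, the log-density changes by exactly -sum(D) (t_end - t_start) = 0.5, while the state approximates exp(D) up to the midpoint rule's discretization error.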

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms the input through the bijection and updates the log-density according to the change-of-variables formula.

For convenience, Bijection() gives the default identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density), where the updated log-density is the input log-density minus the log absolute determinant of the transformation Jacobian: log p(y) = log p(x) - log|det J(x)|.
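The change-of-variables update can be checked on a simple affine map, whose Jacobian determinant is known in closed form. The sketch below is not part of bijx: affine_forward is a hypothetical function that only mirrors the (x, log_density) -> (y, log_density) convention. Pushing standard-normal samples through y = 2x + 1 should reproduce the log-density of N(1, 2²).

```python
import numpy as np

def log_normal(x, mean=0.0, std=1.0):
    """Log-density of a univariate normal distribution."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)

def affine_forward(x, log_density, scale=2.0, shift=1.0):
    """Hypothetical bijection y = scale * x + shift with log-density update."""
    y = scale * x + shift
    # log p_Y(y) = log p_X(x) - log|det J|, and here J = scale.
    return y, log_density - np.log(abs(scale))

x = np.array([-0.3, 0.0, 1.2])
y, log_py = affine_forward(x, log_normal(x))
# log_py matches the analytic log-density of N(1, 2^2) evaluated at y.
```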

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms the input through the inverse bijection and updates the log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density), where the log-density change has the opposite sign to that of forward().
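For a continuous flow, one natural way to realize the reverse transformation is to integrate the same ODE backward in time, from t_end to t_start; we assume this here purely for illustration, and midpoint_flow and diag_vf are hypothetical. The sketch shows that the log-density change flips sign and that the round trip recovers the input only up to discretization error, because the explicit midpoint rule is not exactly time-reversible.

```python
import numpy as np

def midpoint_flow(vf, x, log_density, t_start, t_end, steps):
    """Explicit midpoint integration of (x, log_density); t_end < t_start runs backward."""
    h = (t_end - t_start) / steps
    for n in range(steps):
        t = t_start + n * h
        dx1, _ = vf(t, x)
        dx2, dl2 = vf(t + 0.5 * h, x + 0.5 * h * dx1)
        x = x + h * dx2
        log_density = log_density + h * dl2
    return x, log_density

# Hypothetical diagonal linear field with constant divergence sum(D).
D = np.array([-1.0, 0.5])

def diag_vf(t, x):
    return D * x, -np.sum(D)

x0, logp0 = np.array([1.0, 2.0]), 0.0
y, logp_y = midpoint_flow(diag_vf, x0, logp0, 0.0, 1.0, 20)          # forward in time
x_back, logp_back = midpoint_flow(diag_vf, y, logp_y, 1.0, 0.0, 20)  # backward in time
# logp_y - logp0 is about +0.5, logp_back - logp_y about -0.5; x_back is close to x0.
```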