bijx.ContFlowCG
- class bijx.ContFlowCG
Bases: Bijection
Continuous normalizing flow using Crouch-Grossmann integration.
The state can be any pytree whose leaves are real or complex arrays or matrix group elements. See bijx.cg for more details.
- Parameters:
vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).
t_start (float) – Integration start time.
t_end (float) – Integration end time.
steps (int) – Number of integration steps.
tableau (ButcherTableau) – Butcher tableau specifying the integration scheme.
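A minimal sketch of a function with the vf signature, for the simplest case of a real Euclidean state x of shape (n,). It assumes the standard continuous change-of-variables convention, in which d(log_density)/dt is minus the divergence of the velocity; the field `velocity` is hypothetical, and bijx may use its own sign or divergence conventions (see bijx.cg).

```python
import jax
import jax.numpy as jnp


def velocity(t, x):
    """Hypothetical time-dependent linear contraction on R^n."""
    return -(1.0 + t) * x


def vf(t, x, **kwargs):
    """Return (dx/dt, d(log_density)/dt) for a Euclidean state of shape (n,)."""
    dx_dt = velocity(t, x)
    # Assumed convention: d(log_density)/dt = -div v = -tr(dv/dx).
    jac = jax.jacfwd(velocity, argnums=1)(t, x)
    return dx_dt, -jnp.trace(jac)


dx, dlogp = vf(0.5, jnp.ones(4))
```

A contracting field concentrates probability mass, so the log density of each sample increases along the flow, which is why dlogp is positive here. Building the full Jacobian costs O(n^2) memory; for large states a stochastic trace estimator is a common alternative.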
Example
>>> # Vector field for SU(N) gauge field evolution
>>> flow = ContFlowCG(SomeGaugeVF(), tableau=cg.CG3)
>>> U_final, log_density_final = flow.forward(U, log_density)
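The Crouch-Grossmann idea can be sketched in plain JAX for a single matrix group element: each Runge-Kutta increment is replaced by a matrix exponential of a Lie-algebra generator, so the state stays in the group. This is a generic left-multiplying CG scheme for dU/dt = X(t, U) U using the default tableau of __init__; `cg_step`, `su_n_vf` and `integrate_cg` are hypothetical names, the log-density update is omitted, and bijx's own conventions (for example the right_invariant and transport_adjoint options of Unitary) may differ.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Default tableau from __init__ (explicit midpoint rule).
A = ((0.0, 0.0), (0.5, 0.0))
B = (0.0, 1.0)
C = (0.0, 0.5)


def cg_step(vf, t, U, h):
    """One Crouch-Grossmann step for dU/dt = X(t, U) @ U with X in the Lie algebra."""
    Xs = []
    for i in range(len(B)):
        # Stage point: exponentials of earlier generators, applied left of U in order.
        Ui = U
        for j in range(i):
            Ui = expm(h * A[i][j] * Xs[j]) @ Ui
        Xs.append(vf(t + C[i] * h, Ui))
    # Update: compose exp(h * b_i * X_i) in stage order instead of summing increments.
    U_next = U
    for i in range(len(B)):
        U_next = expm(h * B[i] * Xs[i]) @ U_next
    return U_next


def su_n_vf(t, U):
    """Hypothetical generator: traceless anti-Hermitian part of t * U @ K."""
    n = U.shape[0]
    K = (jnp.arange(n * n).reshape(n, n) + 1j * jnp.eye(n)).astype(U.dtype)
    W = t * (U @ K)
    X = 0.5 * (W - W.conj().T)
    return X - jnp.trace(X) / n * jnp.eye(n, dtype=U.dtype)


def integrate_cg(vf, U, t_start=0.0, t_end=1.0, steps=20):
    """Uniform fixed-step integration from t_start to t_end."""
    h = (t_end - t_start) / steps
    for k in range(steps):
        U = cg_step(vf, t_start + k * h, U, h)
    return U


U0 = jnp.eye(3, dtype=jnp.complex64)
U1 = integrate_cg(su_n_vf, U0)
```

Because every factor is the exponential of a traceless anti-Hermitian matrix, U1 remains in SU(3) up to floating-point error, which an additive Runge-Kutta update would not guarantee.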
- __init__(vf, x_type=Unitary(right_invariant=True, transport_adjoint=False, project_stage=False, project_step=False, special=True), *, t_start=0, t_end=1, steps=20, tableau=ButcherTableau(stages=2, a=((0.0, 0.0), (0.5, 0.0)), b=(0.0, 1.0), c=(0.0, 0.5)))
- Parameters:
vf (Callable)
x_type (Any)
t_start (float)
t_end (float)
steps (int)
tableau (ButcherTableau)
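Reading a, b and c off the default tableau argument gives the two-stage explicit midpoint rule. In Butcher notation, and for a Euclidean state where Crouch-Grossmann reduces to ordinary Runge-Kutta, one step of size h (assumed uniform, h = (t_end - t_start)/steps) reads:

```latex
\begin{array}{c|cc}
0 & 0 & 0 \\
\tfrac{1}{2} & \tfrac{1}{2} & 0 \\
\hline
 & 0 & 1
\end{array}
\qquad
x_{n+1} = x_n + h\, f\!\left(t_n + \tfrac{h}{2},\; x_n + \tfrac{h}{2}\, f(t_n, x_n)\right)
```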
Methods
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
solve_flow(x, log_density, *[, t_start, ...]) – Solve the ODE flow using Crouch-Grossmann integration.
- solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)
Solve the ODE flow using Crouch-Grossmann integration.
- Parameters:
x – Initial state (an array or, more generally, a pytree as described in the class docstring).
log_density – Initial log density values.
t_start – Override integration start time.
t_end – Override integration end time.
steps – Override number of integration steps.
**kwargs – Additional arguments passed to vector field.
- Returns:
Tuple (x_final, log_density_final) of the final state and log density.
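The role of the t_start, t_end and steps overrides can be illustrated with a fixed-step explicit Runge-Kutta loop that carries the log density alongside the state. This is a Euclidean sketch, not bijx's implementation: it assumes uniform steps h = (t_end - t_start) / steps, uses the default midpoint tableau, and `rk_solve` and `linear_vf` are hypothetical.

```python
import jax.numpy as jnp

# Default tableau from __init__ (explicit midpoint rule).
A = ((0.0, 0.0), (0.5, 0.0))
B = (0.0, 1.0)
C = (0.0, 0.5)


def rk_solve(vf, x, log_density, *, t_start=0.0, t_end=1.0, steps=20, **kwargs):
    """Fixed-step explicit Runge-Kutta on (x, log_density) for a Euclidean x."""
    h = (t_end - t_start) / steps
    for n in range(steps):
        t = t_start + n * h
        ks = []  # stage outputs: (dx/dt, d(log_density)/dt)
        for i in range(len(B)):
            xi = x + h * sum(A[i][j] * ks[j][0] for j in range(i))
            ks.append(vf(t + C[i] * h, xi, **kwargs))
        x = x + h * sum(B[i] * ks[i][0] for i in range(len(B)))
        log_density = log_density + h * sum(B[i] * ks[i][1] for i in range(len(B)))
    return x, log_density


def linear_vf(t, x, rate=0.5):
    """Hypothetical field dx/dt = rate * x; its divergence is rate * x.size."""
    return rate * x, -rate * x.size


x0 = jnp.ones(3)
x1, lp1 = rk_solve(linear_vf, x0, 0.0, rate=0.5)
x2, lp2 = rk_solve(linear_vf, x0, 0.0, t_end=2.0, steps=40, rate=0.5)
```

For dx/dt = r x the exact solution is x(t) = x(0) e^{rt} and the log density decreases by r n t, so both calls can be checked analytically; the second call mirrors overriding the integration range and step count at call time, while extra keyword arguments (here rate) are forwarded to the vector field.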
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
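Under the usual convention that log_density tracks the log density of the sample (an assumption about the sign bijx uses), the update applied for a map y = f(x) is the change of variables formula on the left; for a continuous flow with velocity v, the log-determinant becomes a time integral of the divergence, as on the right.

```latex
\log q(y) = \log p(x) - \log\left|\det \frac{\partial y}{\partial x}\right|,
\qquad
\log q\big(x(t_1)\big) = \log p\big(x(t_0)\big) - \int_{t_0}^{t_1} \operatorname{tr}\frac{\partial v}{\partial x}\big(t, x(t)\big)\,\mathrm{d}t
```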
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
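The sign relationship between forward() and reverse() can be checked on a simple hypothetical bijection, an elementwise affine map y = scale * x + shift (not part of bijx). The sketch assumes forward subtracts log|det J| from log_density and reverse adds it back, so a round trip restores both the data and the log density.

```python
import jax.numpy as jnp


def affine_forward(x, log_density, scale=2.0, shift=1.0):
    """Hypothetical elementwise map y = scale * x + shift."""
    y = scale * x + shift
    log_det = x.size * jnp.log(jnp.abs(scale))  # log|det J| of a diagonal Jacobian
    return y, log_density - log_det


def affine_reverse(y, log_density, scale=2.0, shift=1.0):
    """Inverse map; the log-determinant enters with the opposite sign."""
    x = (y - shift) / scale
    log_det = y.size * jnp.log(jnp.abs(scale))
    return x, log_density + log_det


x = jnp.array([0.5, -1.0, 2.0])
y, lp = affine_forward(x, 0.0)
x_back, lp_back = affine_reverse(y, lp)
```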