bijx.ContFlowCG
- class bijx.ContFlowCG
Bases: Bijection
Continuous normalizing flow using Crouch-Grossmann integration.
The state can be any pytree containing leaves of real/complex arrays or matrix group elements. See bijx.cg for more details.
- Parameters:
vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).
is_lie (Any) – Specification of which components are treated as matrix group elements.
t_start (float) – Integration start time.
t_end (float) – Integration end time.
steps (int) – Number of integration steps.
tableau (ButcherTableau) – Butcher tableau specifying the integration scheme.
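To make the `vf` signature concrete, the following is a minimal sketch for a flat (non-Lie) scalar state, integrated with plain explicit Euler steps rather than Crouch-Grossmann. The names `linear_vf` and `euler_flow` are hypothetical and not part of bijx, and the convention d(log_density)/dt = -divergence is an assumption that this page does not state.

```python
import math


def linear_vf(t, x, *, rate=0.5, **kwargs):
    """Hypothetical vector field dx/dt = rate * x on a scalar state.

    Returns (dx/dt, d(log_density)/dt). The second entry uses the assumed
    convention d(log_density)/dt = -divergence; in one dimension the
    divergence of rate * x is simply `rate`.
    """
    return rate * x, -rate


def euler_flow(vf, x, log_density, t_start=0.0, t_end=1.0, steps=20, **kwargs):
    """Integrate state and log-density together with explicit Euler steps."""
    dt = (t_end - t_start) / steps
    t = t_start
    for _ in range(steps):
        dx, dlogp = vf(t, x, **kwargs)  # extra keyword arguments reach vf
        x = x + dt * dx
        log_density = log_density + dt * dlogp
        t += dt
    return x, log_density


x_final, logp_final = euler_flow(linear_vf, 1.0, 0.0, steps=2000, rate=0.5)
```

With many steps the state approaches `exp(rate)` times its initial value, while the log-density decreases by `rate`, which is the behaviour expected of an expanding flow under the assumed sign convention.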
Note
The is_lie parameter determines which state components are treated as matrix group elements for the geometric integration. This should be a pytree of booleans with the same structure as the state.
Example
>>> # Vector field for SU(N) gauge field evolution
>>> flow = ContFlowCG(SomeGaugeVF(), is_lie=True, tableau=cg.CG3)
>>> U_final, log_det = flow.forward(U, log_density)
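As an illustration of a boolean pytree that mirrors the state, the sketch below builds a hypothetical mixed state and a matching mask using only `jax.tree_util`. The leaf names `links` and `phi` are invented for the example, and how `ContFlowCG` consumes the mask internally is not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical mixed state: one matrix-valued leaf and one flat array leaf.
state = {"links": jnp.eye(2, dtype=jnp.complex64), "phi": jnp.zeros(4)}

# Boolean mask with the same tree structure: True marks matrix group leaves.
is_lie = {"links": True, "phi": False}

# Both trees flatten to the same structure, one mask entry per state leaf.
same_structure = (
    jax.tree_util.tree_structure(is_lie) == jax.tree_util.tree_structure(state)
)

# Pair each mask entry with its leaf, e.g. to pick an update rule per leaf.
kinds = jax.tree_util.tree_map(
    lambda lie, leaf: "group" if lie else "flat", is_lie, state
)
```

Checking the tree structures up front is a cheap way to catch a mask that has drifted out of sync with the state before any integration is attempted.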
- __init__(vf, is_lie=True, *, t_start=0, t_end=1, steps=20, tableau=ButcherTableau(stages=3, a=((0, 0, 0), (0.75, 0, 0), (0.5509259259259259, 0.1574074074074074, 0)), b=(0.2549019607843137, -0.6666666666666666, 1.411764705882353), c=(0, 0.75, 0.7083333333333334)))
- Parameters:
vf (Callable)
is_lie (Any)
t_start (float)
t_end (float)
steps (int)
tableau (ButcherTableau)
Methods
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
solve_flow(x, log_density, *[, t_start, ...]) – Solve the ODE flow using Crouch-Grossmann integration.
- solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)
Solve the ODE flow using Crouch-Grossmann integration.
- Parameters:
x – Initial state array.
log_density – Initial log density values.
t_start – Override integration start time.
t_end – Override integration end time.
steps – Override number of integration steps.
**kwargs – Additional arguments passed to vector field.
- Returns:
Final state tuple (x_final, log_density_final).
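For readers unfamiliar with Crouch-Grossmann integration, here is a rough NumPy sketch of an explicit step for a matrix ODE of the form dY/dt = gen(t, Y) @ Y, using the coefficients of the default tableau from the `__init__` signature. It is not the bijx implementation: `expm_taylor`, `cg_step` and `solve` are hypothetical helpers, the left-multiplication convention is an assumption, and log-density tracking is omitted.

```python
import numpy as np

# Coefficients copied from the default tableau in the __init__ signature.
TAB_A = ((0, 0, 0), (0.75, 0, 0), (0.5509259259259259, 0.1574074074074074, 0))
TAB_B = (0.2549019607843137, -0.6666666666666666, 1.411764705882353)
TAB_C = (0, 0.75, 0.7083333333333334)


def expm_taylor(m, terms=20):
    """Truncated Taylor series for the matrix exponential (small norms only)."""
    result = np.eye(m.shape[0], dtype=m.dtype)
    term = np.eye(m.shape[0], dtype=m.dtype)
    for k in range(1, terms):
        term = term @ m / k
        result = result + term
    return result


def cg_step(gen, t, y, h):
    """One explicit Crouch-Grossmann step for dY/dt = gen(t, Y) @ Y."""
    ks = []
    for i in range(len(TAB_B)):
        yi = y
        for j in range(i):  # compose exponentials of the earlier stages
            yi = expm_taylor(h * TAB_A[i][j] * ks[j]) @ yi
        ks.append(gen(t + TAB_C[i] * h, yi))  # algebra element at stage i
    for b, k in zip(TAB_B, ks):  # final update, again as a product of exps
        y = expm_taylor(h * b * k) @ y
    return y


def solve(gen, y, t_start=0.0, t_end=1.0, steps=20):
    """Apply `steps` equal-sized Crouch-Grossmann steps."""
    h = (t_end - t_start) / steps
    for n in range(steps):
        y = cg_step(gen, t_start + n * h, y, h)
    return y


J = np.array([[0.0, -1.0], [1.0, 0.0]])  # generator of so(2)
y_final = solve(lambda t, y: t * J, np.eye(2))  # rotation by the angle 1/2
```

Because every update is a product of matrix exponentials of Lie algebra elements, the result stays on the group up to floating-point error, which is the main reason to prefer this family of schemes over a plain Runge-Kutta update for matrix group states. The test generator here depends only on `t`, so it does not exercise the stage states `yi`.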
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
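The relationship between forward() and reverse() can be illustrated with a toy scalar bijection. `AffineScalar` below is a hypothetical stand-in that does not subclass `Bijection`, and the choice to subtract log|det J| in the forward direction is an assumed convention; only the fact that the two directions use opposite signs is taken from the description above.

```python
import math


class AffineScalar:
    """Hypothetical stand-in for a bijection: y = scale * x + shift."""

    def __init__(self, scale, shift):
        self.scale = scale
        self.shift = shift

    def forward(self, x, log_density):
        # Assumed convention: subtract log|det J| when mapping forward.
        log_det = math.log(abs(self.scale))
        return self.scale * x + self.shift, log_density - log_det

    def reverse(self, y, log_density):
        # Opposite sign, so reverse undoes the forward log-density change.
        log_det = math.log(abs(self.scale))
        return (y - self.shift) / self.scale, log_density + log_det


bij = AffineScalar(scale=2.0, shift=1.0)
y, logp_y = bij.forward(0.5, -1.0)
x_back, logp_back = bij.reverse(y, logp_y)
```

A round trip through forward() and then reverse() should return both the original input and the original log-density, which makes it a convenient sanity check for any bijection.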