bijx.ContFlowRK4
- class bijx.ContFlowRK4(vf, *, t_start=0, t_end=1, steps=20)
Bases: Bijection
Continuous normalizing flow using a fixed-step RK4 solver.
Wraps a vector field and turns it into the bijection defined by solving the corresponding ODE with a fixed-step RK4 solver. The vector field function must return both the velocity and the time derivative of the log-density, as required by the instantaneous change of variables.
The integration uses a uniform time grid with a fixed step size. Gradients are always computed by solving backward in time (adjoint sensitivity). Consider ContFlowDiffrax for more flexibility and more advanced solvers.
- Parameters:
vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).
t_start (float) – Integration start time.
t_end (float) – Integration end time.
steps (int) – Number of integration steps.
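The fixed-grid scheme described above can be sketched in plain JAX as classic RK4 applied to the augmented state (x, log_density). This is a minimal sketch, not bijx's implementation: `rk4_flow` and `linear_vf` are hypothetical names, the log-density is assumed to have the batch shape `x.shape[:-1]`, and gradients through this loop would come from ordinary autodiff rather than the adjoint method the class uses.

```python
import jax
import jax.numpy as jnp


def rk4_flow(vf, x, log_density, t_start=0.0, t_end=1.0, steps=20, **kwargs):
    """Hypothetical helper: integrate (x, log_density) with classic RK4 on a uniform grid."""
    dt = (t_end - t_start) / steps

    def axpy(state, k, scale):
        # state + scale * k, applied leaf-wise to the (x, log_density) tuple
        return jax.tree_util.tree_map(lambda s, d: s + scale * d, state, k)

    def step(i, state):
        t = t_start + i * dt
        # vf only depends on x; it returns the tuple (dx/dt, d(log_density)/dt)
        k1 = vf(t, state[0], **kwargs)
        k2 = vf(t + dt / 2, axpy(state, k1, dt / 2)[0], **kwargs)
        k3 = vf(t + dt / 2, axpy(state, k2, dt / 2)[0], **kwargs)
        k4 = vf(t + dt, axpy(state, k3, dt)[0], **kwargs)
        # weighted RK4 update of both x and log_density
        return jax.tree_util.tree_map(
            lambda s, a, b, c, d: s + dt / 6 * (a + 2 * b + 2 * c + d),
            state, k1, k2, k3, k4,
        )

    return jax.lax.fori_loop(0, steps, step, (x, log_density))


# Usage: linear flow dx/dt = -x, whose divergence is -dim, so d(log_density)/dt = +dim
def linear_vf(t, x):
    return -x, jnp.full(x.shape[:-1], float(x.shape[-1]), dtype=x.dtype)


x0 = jnp.ones((4, 3))
x1, lp1 = rk4_flow(linear_vf, x0, jnp.zeros(4), steps=50)
# x1 is close to exp(-1) * x0 and lp1 is close to 3.0 for every sample
```

Carrying log_density inside the same RK4 state is what makes the log-density update consistent with the trajectory of x at every stage.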
Example
>>> def vector_field(t, x):
...     # Linear flow dx/dt = -x: divergence is -dim, so d(log_density)/dt = +dim
...     return -x, jnp.full_like(x[..., :1], x.shape[-1])
>>> flow = ContFlowRK4(vector_field, steps=50)
>>> y, log_density_out = flow.forward(x, log_density)
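When the divergence of a vector field is not known in closed form, the log-density term can be derived from the instantaneous change of variables, d(log p)/dt = -tr(∂v/∂x). The sketch below wraps a plain velocity function into the (dx/dt, d(log_density)/dt) signature using a dense Jacobian per sample. `with_exact_divergence` is a hypothetical helper, not part of bijx; it assumes the last axis of x is the event dimension and that log_density has shape `x.shape[:-1]`. The exact trace costs one Jacobian per sample, so it is only practical for small dimensions.

```python
import jax
import jax.numpy as jnp


def with_exact_divergence(velocity):
    """Hypothetical wrapper: velocity(t, x) -> dx/dt becomes vf(t, x) -> (dx/dt, d(log_density)/dt)."""

    def single(t, x_single):
        # x_single has shape (dim,); jacfwd gives the (dim, dim) Jacobian of the velocity
        jac = jax.jacfwd(lambda y: velocity(t, y))(x_single)
        return velocity(t, x_single), -jnp.trace(jac)

    def vf(t, x):
        # flatten all leading batch axes, vmap over samples, then restore the shapes
        flat = x.reshape(-1, x.shape[-1])
        v, dlogp = jax.vmap(lambda xs: single(t, xs))(flat)
        return v.reshape(x.shape), dlogp.reshape(x.shape[:-1])

    return vf


# Usage: the linear field -x has Jacobian -I, so d(log_density)/dt = dim = 3
vf = with_exact_divergence(lambda t, x: -x)
v, dlogp = vf(0.0, jnp.ones((4, 3)))
```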
- __init__(vf, *, t_start=0, t_end=1, steps=20)
- Parameters:
vf (Callable)
t_start (float)
t_end (float)
steps (int)
Methods
forward(x, log_density, **kwargs) – Apply forward transformation.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Apply reverse (inverse) transformation.
solve_flow(x, log_density, *[, t_start, ...]) – Solve the ODE flow using RK4 integration.
- solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)
Solve the ODE flow using RK4 integration.
- Parameters:
x – Initial state array.
log_density – Initial log density values.
t_start – Override integration start time.
t_end – Override integration end time.
steps – Override number of integration steps.
**kwargs – Additional arguments passed to vector field.
- Returns:
Final state tuple (x_final, log_density_final).
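The `steps` override trades cost for accuracy. For a smooth vector field, RK4's global error scales as O(h^4) in the step size h = (t_end - t_start) / steps, so doubling `steps` should shrink the error by roughly 2^4 = 16. The hypothetical `rk4_decay` below checks this on the scalar ODE dx/dt = -x, whose exact solution at t = 1 is e^-1; it is plain Python for illustration, not a call into bijx.

```python
import math


def rk4_decay(x0, t_end, steps):
    """Hypothetical helper: fixed-step RK4 for dx/dt = -x on [0, t_end]."""
    f = lambda x: -x
    dt = t_end / steps
    x = x0
    for _ in range(steps):
        k1 = f(x)
        k2 = f(x + dt / 2 * k1)
        k3 = f(x + dt / 2 * k2)
        k4 = f(x + dt * k3)
        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x


exact = math.exp(-1.0)
err_10 = abs(rk4_decay(1.0, 1.0, 10) - exact)
err_20 = abs(rk4_decay(1.0, 1.0, 20) - exact)
ratio = err_10 / err_20  # roughly 16 = 2**4, as expected for a fourth-order method
```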
- forward(x, log_density, **kwargs)
Apply forward transformation.
Transforms input through the bijection and updates log-density according to the change of variables formula.
For convenience, Bijection() gives the default identity bijection.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
- reverse(x, log_density, **kwargs)
Apply reverse (inverse) transformation.
Transforms input through the inverse bijection and updates log-density accordingly.
- Parameters:
x – Input data of any pytree structure.
log_density – Log density values corresponding to the input.
**kwargs – Additional transformation-specific arguments.
- Returns:
Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
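For an ODE-defined bijection, forward and reverse correspond to integrating the same vector field in opposite time directions. The sketch below assumes that reverse() integrates from t_end back to t_start (the text above does not spell this out) and uses a hypothetical fixed-step `rk4` helper with the elementwise field v(x) = tanh(x). Running forward and then backward recovers x up to RK4 error, and the reverse log-density change cancels the forward one, matching the opposite-sign statement above.

```python
import jax.numpy as jnp


def tanh_vf(t, x):
    # Autonomous elementwise field v(x) = tanh(x); its Jacobian is diagonal,
    # so d(log_density)/dt = -sum_i (1 - tanh(x_i)**2)
    v = jnp.tanh(x)
    return v, -jnp.sum(1.0 - v**2, axis=-1)


def rk4(x, log_density, t_start, t_end, steps):
    """Hypothetical helper: fixed-step RK4 from t_start to t_end (dt < 0 runs backwards)."""
    dt = (t_end - t_start) / steps
    for i in range(steps):
        t = t_start + i * dt
        v1, l1 = tanh_vf(t, x)
        v2, l2 = tanh_vf(t + dt / 2, x + dt / 2 * v1)
        v3, l3 = tanh_vf(t + dt / 2, x + dt / 2 * v2)
        v4, l4 = tanh_vf(t + dt, x + dt * v3)
        x = x + dt / 6 * (v1 + 2 * v2 + 2 * v3 + v4)
        log_density = log_density + dt / 6 * (l1 + 2 * l2 + 2 * l3 + l4)
    return x, log_density


x0 = jnp.array([[0.5, -1.0], [2.0, 0.1]])
lp0 = jnp.zeros(2)
y, lp_fwd = rk4(x0, lp0, 0.0, 1.0, 50)          # forward: t_start -> t_end
x_back, lp_back = rk4(y, lp_fwd, 1.0, 0.0, 50)  # reverse: t_end -> t_start
# x_back recovers x0, and the reverse log-density change cancels the forward one
```

Because tanh(x) spreads points apart (positive divergence), the forward pass lowers the log-density, and the backward pass raises it by the same amount.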