bijx.ContFlowRK4

class bijx.ContFlowRK4

Bases: Bijection

Continuous normalizing flow using fixed-step RK4 solver.

Wraps a vector field into the bijection obtained by solving the corresponding ODE with a fixed-step RK4 solver. The vector field function must return both the velocity dx/dt and the time derivative of the log density, which by the instantaneous change of variables is d(log p)/dt = -tr(∂v/∂x).

Integration uses a uniform time grid with fixed step size (t_end - t_start) / steps. Gradients are always computed by solving backward in time (adjoint sensitivity method). For more flexibility and more advanced solvers, consider ContFlowDiffrax.
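The fixed-step scheme can be pictured as RK4 applied to the augmented state (x, log_density). The sketch below is illustrative only: `rk4_flow` is a hypothetical helper, not part of bijx, and it obtains gradients by backpropagating through `jax.lax.scan` rather than by the adjoint method the class uses.

```python
import jax
import jax.numpy as jnp


def rk4_flow(vf, x, log_density, t_start=0.0, t_end=1.0, steps=20):
    """Hypothetical fixed-step RK4 on the augmented state (x, log_density).

    vf(t, x) must return (dx/dt, d(log_density)/dt). Gradients of this
    sketch come from backpropagating through the scan, not from an adjoint.
    """
    dt = (t_end - t_start) / steps  # uniform time grid

    def step(carry, i):
        x, lp = carry
        t = t_start + i * dt
        # The log-density derivative depends on x only, so it rides along
        # with the same four stage evaluations used for x.
        k1x, k1l = vf(t, x)
        k2x, k2l = vf(t + dt / 2, x + dt / 2 * k1x)
        k3x, k3l = vf(t + dt / 2, x + dt / 2 * k2x)
        k4x, k4l = vf(t + dt, x + dt * k3x)
        x = x + dt / 6 * (k1x + 2 * k2x + 2 * k3x + k4x)
        lp = lp + dt / 6 * (k1l + 2 * k2l + 2 * k3l + k4l)
        return (x, lp), None

    (x, log_density), _ = jax.lax.scan(step, (x, log_density), jnp.arange(steps))
    return x, log_density


def vector_field(t, x):
    # dx/dt = -x in d dimensions has divergence -d, so d(log p)/dt = +d
    return -x, jnp.full(x.shape[:-1], x.shape[-1], dtype=x.dtype)


x0 = jnp.ones((4, 3))
y, lp = rk4_flow(vector_field, x0, jnp.zeros(4), steps=50)
# Exact solution at t = 1: y = exp(-1) * x0 and lp = 3.0
```

Because the step size is fixed, cost and accuracy are set entirely by `steps`; there is no error control as in an adaptive solver.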

Parameters:
  • vf (Callable) – Vector field function with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).

  • t_start (float) – Integration start time (default 0).

  • t_end (float) – Integration end time (default 1).

  • steps (int) – Number of integration steps (default 20).
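A vector field with the required (t, x, **kwargs) -> (dx/dt, d(log_density)/dt) signature can be built from a plain velocity function by taking the exact Jacobian trace. The sketch below assumes x has shape (batch, d); `with_log_density` and `velocity` are hypothetical names chosen for illustration, not bijx functions.

```python
import jax
import jax.numpy as jnp


def with_log_density(velocity):
    """Hypothetical wrapper: velocity(t, x, **kwargs) -> dx/dt for one sample.

    Returns vf(t, x, **kwargs) -> (dx/dt, d(log p)/dt) for a batch of shape
    (batch, d), using d(log p)/dt = -tr(dv/dx) with a dense Jacobian.
    """
    def vf(t, x, **kwargs):
        def single(xi):
            # Jacobian of the velocity w.r.t. one sample, shape (d, d)
            jac = jax.jacfwd(lambda y: velocity(t, y, **kwargs))(xi)
            return velocity(t, xi, **kwargs), -jnp.trace(jac)

        # Map over the leading batch axis; t and kwargs are shared
        return jax.vmap(single)(x)

    return vf


A = jnp.array([[0.5, 1.0], [0.0, -2.0]])


def velocity(t, x):
    return A @ x  # linear field; its divergence is tr(A) = -1.5


vf = with_log_density(velocity)
v, dlogp = vf(0.0, jnp.ones((5, 2)))
# Every sample gets d(log p)/dt = -tr(A) = 1.5
```

The dense Jacobian costs d forward passes per sample, so for high-dimensional x a stochastic trace estimator (e.g. Hutchinson's) is the usual substitute.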

Example

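>>> import jax.numpy as jnp
>>> from bijx import ContFlowRK4
>>> def vector_field(t, x):
...     return -x, jnp.full_like(x[..., :1], x.shape[-1])  # dx/dt = -x, so d(log p)/dt = -tr(-I) = d
>>> flow = ContFlowRK4(vector_field, steps=50)
>>> x, log_density = jnp.ones((8, 3)), jnp.zeros((8, 1))
>>> y, log_density_out = flow.forward(x, log_density)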

__init__(vf, *, t_start=0, t_end=1, steps=20)
Parameters:
  • vf (Callable)

  • t_start (float)

  • t_end (float)

  • steps (int)

Methods

forward(x, log_density, **kwargs)

Apply forward transformation.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

solve_flow(x, log_density, *[, t_start, ...])

Solve the ODE flow using RK4 integration.

solve_flow(x, log_density, *, t_start=None, t_end=None, steps=None, **kwargs)

Solve the ODE flow using RK4 integration.

Parameters:
  • x – Initial state array.

  • log_density – Initial log density values.

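  • t_start – Integration start time; if given, overrides the value set in the constructor.

  • t_end – Integration end time; if given, overrides the value set in the constructor.

  • steps – Number of integration steps; if given, overrides the value set in the constructor.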

  • **kwargs – Additional arguments passed to vector field.

Returns:

Final state tuple (x_final, log_density_final).
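The override arguments make it possible to integrate over a sub-interval. The sketch below is a hypothetical stand-in for `solve_flow` (the `DEFAULTS` dict mirrors the constructor defaults; the real method also forwards **kwargs to the vector field, which is omitted here). It shows that, on a uniform grid, two half-interval solves with half the steps reproduce one full solve.

```python
import jax.numpy as jnp

DEFAULTS = dict(t_start=0.0, t_end=1.0, steps=20)  # mirrors the constructor defaults


def solve_flow(vf, x, log_density, *, t_start=None, t_end=None, steps=None):
    """Hypothetical stand-in: None falls back to DEFAULTS, then fixed-step RK4."""
    t0 = DEFAULTS["t_start"] if t_start is None else t_start
    t1 = DEFAULTS["t_end"] if t_end is None else t_end
    n = DEFAULTS["steps"] if steps is None else steps
    dt = (t1 - t0) / n

    def f(t, state):
        return vf(t, state[0])  # vf sees only x, not log_density

    state = (x, log_density)
    for i in range(n):
        t = t0 + i * dt
        k1 = f(t, state)
        k2 = f(t + dt / 2, tuple(s + dt / 2 * k for s, k in zip(state, k1)))
        k3 = f(t + dt / 2, tuple(s + dt / 2 * k for s, k in zip(state, k2)))
        k4 = f(t + dt, tuple(s + dt * k for s, k in zip(state, k3)))
        state = tuple(
            s + dt / 6 * (a + 2 * b + 2 * c + d)
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)
        )
    return state


def vf(t, x):
    # Time-dependent contraction dx/dt = -t x; divergence -t d, so d(log p)/dt = t d
    return -t * x, jnp.full(x.shape[:-1], t * x.shape[-1])


x0, lp0 = jnp.ones((3, 2)), jnp.zeros(3)
full = solve_flow(vf, x0, lp0)                        # defaults: [0, 1], 20 steps
half = solve_flow(vf, x0, lp0, t_end=0.5, steps=10)   # first half of the grid
both = solve_flow(vf, *half, t_start=0.5, steps=10)   # t_end falls back to 1.0
```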

forward(x, log_density, **kwargs)

Apply forward transformation.

Transforms input through the bijection and updates log-density according to the change of variables formula.

For convenience, Bijection() gives the default identity bijection.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (transformed_data, updated_log_density) where the log-density incorporates the log absolute determinant of the transformation Jacobian.
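For a linear field dx/dt = A x the time-T flow map is exp(A T), so the log-density update can be checked in closed form. The sketch below does not call bijx; it only illustrates the change of variables that forward() applies, using a standard-normal base density as an assumption. It shows that subtracting log|det J| gives the true density of y, and that this equals integrating d(log p)/dt = -tr(A) over time, since log|det exp(A T)| = tr(A) T.

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm
from jax.scipy.stats import multivariate_normal

A = jnp.array([[0.5, 1.0], [0.0, -2.0]])
T = 1.0
M = expm(A * T)  # exact time-T map of dx/dt = A x

# Samples and their log density under a standard normal base distribution
x = jnp.array([[0.3, -1.2], [1.0, 0.4]])
log_px = multivariate_normal.logpdf(x, jnp.zeros(2), jnp.eye(2))

y = x @ M.T  # forward map applied row-wise: y = M x

# Change of variables: log p_Y(y) = log p_X(x) - log|det M|
sign, logabsdet = jnp.linalg.slogdet(M)
log_py = log_px - logabsdet

# What the continuous flow accumulates: the integral of -tr(A) over [0, T]
log_py_cnf = log_px - jnp.trace(A) * T

# Ground truth: y ~ N(0, M M^T) when x ~ N(0, I)
log_py_true = multivariate_normal.logpdf(y, jnp.zeros(2), M @ M.T)
```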

reverse(x, log_density, **kwargs)

Apply reverse (inverse) transformation.

Transforms input through the inverse bijection and updates log-density accordingly.

Parameters:
  • x – Input data of any pytree structure.

  • log_density – Log density values corresponding to the input.

  • **kwargs – Additional transformation-specific arguments.

Returns:

Tuple of (inverse_transformed_data, updated_log_density) where the log-density change has the opposite sign compared to forward().
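One way to picture the reverse direction is to integrate the same ODE from t_end back to t_start (a negative step size). The sketch below assumes that approach and uses a hypothetical `rk4` helper rather than bijx itself. With a fixed-step solver the round trip is only approximately the identity: the discrete RK4 map is not exactly invertible, but the mismatch shrinks as O(dt^4).

```python
import jax.numpy as jnp


def rk4(vf, x, lp, t0, t1, steps):
    """Hypothetical fixed-step RK4 on (x, lp); t1 < t0 integrates backward."""
    dt = (t1 - t0) / steps
    for i in range(steps):
        t = t0 + i * dt
        k1x, k1l = vf(t, x)
        k2x, k2l = vf(t + dt / 2, x + dt / 2 * k1x)
        k3x, k3l = vf(t + dt / 2, x + dt / 2 * k2x)
        k4x, k4l = vf(t + dt, x + dt * k3x)
        x = x + dt / 6 * (k1x + 2 * k2x + 2 * k3x + k4x)
        lp = lp + dt / 6 * (k1l + 2 * k2l + 2 * k3l + k4l)
    return x, lp


def vf(t, x):
    # Nonlinear field v(x) = sin(x) elementwise; div v = sum(cos(x))
    return jnp.sin(x), -jnp.sum(jnp.cos(x), axis=-1)


x0 = jnp.array([[0.2, -0.7, 1.1], [0.5, 0.0, -1.3]])
lp0 = jnp.zeros(2)
y, lp_fwd = rk4(vf, x0, lp0, 0.0, 1.0, 20)          # forward: t from 0 to 1
x_back, lp_back = rk4(vf, y, lp_fwd, 1.0, 0.0, 20)  # reverse: endpoints swapped
# x_back is close to x0, and the reverse log-density change cancels the forward one
```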