bijx.odeint_rk4
- bijx.odeint_rk4(fun, y0, end_time, *args, step_size, start_time=0)
Fixed-step-size fourth-order Runge-Kutta (RK4) integrator with a custom adjoint.
Provides a lightweight RK4 integrator optimized for continuous normalizing flows. It includes a custom backward pass that uses the adjoint method for efficient gradient computation in neural ODE applications.
- Parameters:
fun – Function (t, y, *args) -> dy/dt giving the time derivative at the current position y and time t. The output must have the same shape and type as y0.
y0 – Initial value.
end_time – Final time of the integration.
*args – Additional arguments passed on to fun.
step_size – Step size for the fixed-grid solver.
start_time – Initial time of the integration.
- Returns:
Final value y after the integration, of the same shape and type as y0.
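To make the parameters above concrete, here is a minimal pure-Python sketch of a fixed-grid RK4 loop using the same calling convention: fun(t, y, *args), extra arguments passed positionally after end_time, a keyword-only step_size and an optional start_time. It is not bijx's implementation. The names rk4_fixed_step and decay are hypothetical, the sketch works on plain floats rather than JAX arrays, and it assumes the interval is split into ceil(|end_time - start_time| / step_size) equal steps. How bijx handles a step_size that does not divide the interval is not stated here.

```python
import math


def rk4_fixed_step(fun, y0, end_time, *args, step_size, start_time=0.0):
    """Hypothetical fixed-grid RK4 sketch mirroring odeint_rk4's signature."""
    span = end_time - start_time
    # Assumption: use ceil(|span| / step_size) equal steps so the grid ends
    # exactly at end_time; bijx may treat a non-dividing step_size differently.
    n_steps = max(1, math.ceil(abs(span) / step_size))
    h = span / n_steps  # signed step, never larger in magnitude than step_size
    t, y = start_time, y0
    for _ in range(n_steps):
        # Classic fourth-order Runge-Kutta stages; *args is forwarded to fun
        k1 = fun(t, y, *args)
        k2 = fun(t + h / 2, y + h / 2 * k1, *args)
        k3 = fun(t + h / 2, y + h / 2 * k2, *args)
        k4 = fun(t + h, y + h * k3, *args)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return y


def decay(t, y, rate):
    # Toy vector field dy/dt = -rate * y; rate arrives through *args
    return -rate * y


# rate = 2.0 is passed positionally after end_time, like odeint_rk4's *args
y_final = rk4_fixed_step(decay, 1.0, 1.0, 2.0, step_size=0.01)
print(y_final, math.exp(-2.0))  # should agree to well within 1e-8
```

Because extra arguments are collected by *args, step_size must be given by keyword. Otherwise it would be absorbed into args and forwarded to fun.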
Note
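The custom VJP uses the adjoint method: instead of backpropagating through each individual step of the forward pass, it integrates an adjoint system backwards in time to compute gradients. This matters for neural ODEs, where the forward pass can involve many steps and differentiating through all of them would be memory-intensive.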
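The adjoint idea in the Note can be illustrated on a scalar toy problem. The sketch below is a generic continuous-adjoint computation, not bijx's custom VJP. The helper names (rk4_vec, decay, augmented, loss_and_grads), the vector field f(t, y; theta) = -theta * y, and the loss L = y(T)^2 / 2 are assumptions, chosen so that the Jacobians can be written by hand. The forward pass keeps only y(T). The backward pass then integrates three quantities together from T back to 0: y, the adjoint a = dL/dy, and the parameter gradient g. Their dynamics are da/dt = -a * df/dy and dg/dt = -a * df/dtheta, so that a(0) = dL/dy0 and g(0) = dL/dtheta.

```python
import math


def rk4_vec(fun, y0, t0, t1, n_steps, *args):
    """Fixed-step RK4 on a list of floats; t1 < t0 integrates backwards."""
    h = (t1 - t0) / n_steps
    t, y = t0, list(y0)
    for _ in range(n_steps):
        k1 = fun(t, y, *args)
        k2 = fun(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k1)], *args)
        k3 = fun(t + h / 2, [yi + h / 2 * ki for yi, ki in zip(y, k2)], *args)
        k4 = fun(t + h, [yi + h * ki for yi, ki in zip(y, k3)], *args)
        y = [yi + h / 6 * (a + 2 * b + 2 * c + d)
             for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]
        t += h
    return y


def decay(t, y, theta):
    # Hypothetical toy field f(t, y; theta) = -theta * y
    return [-theta * y[0]]


def augmented(t, state, theta):
    # Dynamics of (y, a, g) for f = -theta * y, with df/dy = -theta, df/dtheta = -y:
    #   dy/dt = f,  da/dt = -a * df/dy = theta * a,  dg/dt = -a * df/dtheta = a * y
    y, a, g = state
    return [-theta * y, theta * a, a * y]


def loss_and_grads(y0, theta, T, n_steps=100):
    """L = y(T)**2 / 2; returns (L, dL/dy0, dL/dtheta) via the adjoint method."""
    (yT,) = rk4_vec(decay, [y0], 0.0, T, n_steps, theta)  # forward: keep only y(T)
    loss = 0.5 * yT ** 2
    aT = yT  # initial adjoint a(T) = dL/dy(T)
    # Reconstruct y backwards while integrating the adjoint and the gradient
    _, a0, g0 = rk4_vec(augmented, [yT, aT, 0.0], T, 0.0, n_steps, theta)
    return loss, a0, g0


loss, dy0, dtheta = loss_and_grads(1.5, 0.7, 2.0)
print(dy0, dtheta)
```

For this field the exact gradients are dL/dy0 = y0 exp(-2 theta T) and dL/dtheta = -T y0^2 exp(-2 theta T), and the sketch reproduces them up to integration error. Because y is reconstructed by integrating backwards, the gradient is only as accurate as that reverse integration, which can degrade for dynamics that are unstable in reverse time.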
Example
>>> from bijx import odeint_rk4
>>> def vector_field(t, y):
...     return -y  # Simple exponential decay
>>> y_final = odeint_rk4(vector_field, 1.0, 1.0, step_size=0.01)
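Because the step size is fixed, step_size is the only accuracy control. A convergence check on the example's vector field shows what it buys; the exact solution at t = 1 is exp(-1). The sketch below is a plain-Python stand-in for the solver. The function rk4_decay is hypothetical and not part of bijx, and the sketch assumes step_size divides end_time. Halving the step should shrink the error by roughly 2^4 = 16, as expected for a fourth-order method.

```python
import math


def rk4_decay(y0, end_time, step_size):
    """Fixed-step RK4 for dy/dt = -y (hypothetical stand-in, not bijx)."""
    n_steps = round(end_time / step_size)  # assumes step_size divides end_time
    h = end_time / n_steps
    y = y0
    for _ in range(n_steps):
        k1 = -y
        k2 = -(y + h / 2 * k1)
        k3 = -(y + h / 2 * k2)
        k4 = -(y + h * k3)
        y += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y


exact = math.exp(-1.0)
steps = [0.1, 0.05, 0.025]
errors = [abs(rk4_decay(1.0, 1.0, h) - exact) for h in steps]
# RK4 is fourth order: halving h should shrink the error by about 2**4 = 16
ratios = [errors[i] / errors[i + 1] for i in range(len(errors) - 1)]
print(ratios)
```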