bijx.ContFlowDiffrax
- class bijx.ContFlowDiffrax
Bases: Bijection
Continuous normalizing flow using diffrax ODE solver.
Wraps a vector field and turns it into the bijection defined by solving the corresponding ODE with the diffrax library. The vector field function must return both the velocity and the time derivative of the log density, as required by the instantaneous change of variables.
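For reference, the second output of the vector field is the rate in the standard instantaneous change-of-variables formula used by continuous normalizing flows. The formula below assumes that `log_density` tracks the log density of the transported sample, with f the velocity returned by `vf` and t_0, t_1 the start and end times; the exact sign convention bijx uses is not stated on this page, so treat it as an assumption.

```latex
\frac{\mathrm{d}x(t)}{\mathrm{d}t} = f\bigl(t, x(t)\bigr),
\qquad
\frac{\mathrm{d}\log p\bigl(x(t), t\bigr)}{\mathrm{d}t}
  = -\nabla_x \cdot f\bigl(t, x(t)\bigr)
  = -\operatorname{tr}\!\left(\frac{\partial f}{\partial x}\bigl(t, x(t)\bigr)\right),
\\[4pt]
\log p\bigl(x(t_1), t_1\bigr) = \log p\bigl(x(t_0), t_0\bigr)
  - \int_{t_0}^{t_1} \operatorname{tr}\!\left(\frac{\partial f}{\partial x}\bigl(t, x(t)\bigr)\right)\mathrm{d}t .
```

Integrating both equations jointly, as one augmented state (x, log p), yields the transformed sample and its log density from a single ODE solve.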
- Parameters:
vf (Module) – Vector field module with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).
config (DiffraxConfig) – Hyperparameters (start/end times, solver, etc.) passed to diffrax.
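A minimal sketch of a function with the expected (t, x, **kwargs) -> (dx/dt, d(log_density)/dt) signature, assuming a single unbatched sample and the sign convention d(log_density)/dt = -div f. It is a plain JAX function rather than an nnx.Module, and the linear field `A @ x` and the name `linear_vector_field` are hypothetical. The divergence is computed exactly from the full Jacobian with `jax.jacfwd`, which is only affordable in low dimensions.

```python
import jax
import jax.numpy as jnp

# Hypothetical example field: a fixed linear map f(t, x) = A @ x.
A = jnp.array([[-0.5, 1.0], [-1.0, -0.5]])


def velocity(t, x):
    """Velocity dx/dt for a single sample x of shape (d,)."""
    return A @ x


def linear_vector_field(t, x, **kwargs):
    """Return (dx/dt, d(log_density)/dt) for one sample.

    The log-density rate is minus the divergence of the velocity, i.e.
    minus the trace of its Jacobian with respect to x. jax.jacfwd builds
    the full d x d Jacobian, so this is only practical for small d.
    """
    dx = velocity(t, x)
    jac = jax.jacfwd(lambda y: velocity(t, y))(x)
    dlogp = -jnp.trace(jac)
    return dx, dlogp


x = jnp.array([1.0, 2.0])
dx, dlogp = linear_vector_field(0.0, x)
# dx == [1.5, -2.0]; dlogp == 1.0 because trace(A) == -1
```

Since the function only reads `A` and never writes to any state, it also respects the no-mutation requirement described in the note on this class.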
Example
>>> # Define vector field
>>> vf = SomeVectorField()
>>> config = DiffraxConfig(solver=diffrax.Tsit5(), dt=0.1)
>>> flow = ContFlowDiffrax(vf, config)
>>> y, log_det = flow.forward(x, log_density)
Important
The vector field should generally be a callable nnx.Module. However, it must not mutate its internal state variables (batch averages, call counters, etc.), as that is incompatible with the internal ODE solver.
- __init__(vf, config=DiffraxConfig(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None))
- Parameters:
vf (Module)
config (DiffraxConfig)
Methods
forward(x, log_density, **kwargs) – Solve the ODE flow and return the final state.
invert() – Create an inverted version of this bijection.
reverse(x, log_density, **kwargs) – Solve the ODE flow in reverse and return the final state.
solve_flow(x, log_density, *[, t_start, ...]) – Solve the ODE flow with optional parameter overrides.
- solve_flow(x, log_density, *, t_start=None, t_end=None, dt=None, saveat=None, **kwargs)
Solve the ODE flow with optional parameter overrides.
- Parameters:
x – Initial state array.
log_density – Initial log density values.
t_start (float | None) – Override integration start time.
t_end (float | None) – Override integration end time.
dt (float | None) – Override step size.
saveat (SaveAt | None) – Override save configuration.
**kwargs – Additional arguments passed to the vector field.
- Returns:
Diffrax solution object containing integration results.
- forward(x, log_density, **kwargs)
Solve the ODE flow and return the final state.
The same optional overrides as in solve_flow() (start/end times, step size) can be passed here. The saveat argument should generally not be modified, because this method assumes that the final state is computed and returns it. To obtain other saved states, use solve_flow() directly.
- reverse(x, log_density, **kwargs)
Solve the ODE flow in reverse and return the final state.
The same optional overrides as in solve_flow() (start/end times, step size) can be passed here. The saveat argument should generally not be modified, because this method assumes that the final state is computed and returns it. To obtain other saved states, use solve_flow() directly.