bijx.ContFlowDiffrax

class bijx.ContFlowDiffrax

Bases: Bijection

Continuous normalizing flow using the diffrax ODE solver.

Wraps a vector field and turns it into the bijection defined by solving the corresponding ODE with the diffrax library. The vector field must return both the velocity dx/dt and the time derivative of the log density, d(log_density)/dt, which drives the instantaneous change of variables.
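The instantaneous change of variables states that d(log p)/dt = -tr(∂f/∂x) along a trajectory of dx/dt = f(t, x). A minimal sketch of a vector field in the (dx/dt, d(log_density)/dt) form, assuming a single unbatched sample x of shape (d,); a plain function stands in for the nnx.Module, and `with_exact_divergence` is a hypothetical helper, not part of bijx.

```python
import jax
import jax.numpy as jnp


def with_exact_divergence(field):
    """Hypothetical helper: turn field(t, x) -> dx/dt into the
    (dx/dt, d(log_density)/dt) pair described above.

    Uses d(log p)/dt = -tr(df/dx), with the Jacobian computed exactly
    (affordable only for small dimensions d).
    """
    def vf(t, x):
        dx = field(t, x)
        jac = jax.jacfwd(lambda y: field(t, y))(x)  # (d, d) Jacobian w.r.t. x
        return dx, -jnp.trace(jac)

    return vf


A = jnp.array([[-0.5, 1.0], [0.0, 0.3]], dtype=jnp.float32)
vf = with_exact_divergence(lambda t, x: A @ x)
dx, dlogp = vf(0.0, jnp.array([1.0, 2.0], dtype=jnp.float32))
# For a linear field the Jacobian is A, so dx = [1.5, 0.6] and dlogp = -tr(A) = 0.2
# (up to float32 rounding).
```

The exact trace costs one Jacobian per evaluation; for high-dimensional x, stochastic trace estimators (e.g. Hutchinson's) are a common substitute.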

Parameters:
  • vf (Module) – Vector field module with signature (t, x, **kwargs) -> (dx/dt, d(log_density)/dt).

  • config (DiffraxConfig) – Hyperparameters (start and end times, step size, solver, etc.) passed to diffrax.

Example

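>>> import diffrax
>>> # Define the vector field (SomeVectorField is a placeholder for your own nnx.Module)
>>> vf = SomeVectorField()
>>> config = DiffraxConfig(solver=diffrax.Tsit5(), dt=0.1)
>>> flow = ContFlowDiffrax(vf, config)
>>> # x: input samples, log_density: their log densities
>>> y, log_density_y = flow.forward(x, log_density)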

Important

The vector field should generally be a callable nnx.Module. However, it must not mutate its internal state variables (batch averages, call counters, etc.), as such mutation is incompatible with the internal ODE solver.
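One way to see why call counting and similar side effects break down: diffrax runs its solver steps inside JAX control flow, so the vector field is traced rather than executed step by step. A minimal sketch, assuming a hand-rolled Euler loop in `jax.lax.fori_loop` as a stand-in for the diffrax solver loop; the dictionary counter stands in for a mutable module variable, and `counting_field` and `euler_solve` are hypothetical names.

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}  # Python-side counter standing in for a mutable module variable


def counting_field(t, x):
    """Hypothetical field dx/dt = -x that also 'counts' its calls (a side effect)."""
    calls["n"] += 1  # runs while JAX traces the function, not once per solver step
    return -x


def euler_solve(field, x, t0, t1, n_steps):
    """Fixed-step Euler inside lax.fori_loop, mimicking a compiled solver loop."""
    dt = (t1 - t0) / n_steps

    def step(i, x):
        return x + dt * field(t0 + i * dt, x)

    return jax.lax.fori_loop(0, n_steps, step, x)


x1 = euler_solve(counting_field, jnp.ones(3), 0.0, 1.0, 100)
print(calls["n"])  # far fewer than the 100 steps taken (typically the body is traced once)
```

The numerical result is still correct; only the side effect is wrong, which is why such bugs are easy to miss.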

__init__(vf, config=DiffraxConfig(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None))

Methods

forward(x, log_density, **kwargs)

Solve the ODE flow and return the final state.

invert()

Create an inverted version of this bijection.

reverse(x, log_density, **kwargs)

Solve the ODE flow in reverse and return the final state.

solve_flow(x, log_density, *[, t_start, ...])

Solve the ODE flow with optional parameter overrides.

solve_flow(x, log_density, *, t_start=None, t_end=None, dt=None, saveat=None, **kwargs)

Solve the ODE flow with optional parameter overrides.

Parameters:
  • x – Initial state array.

  • log_density – Initial log density values.

  • t_start (float | None) – Override integration start time.

  • t_end (float | None) – Override integration end time.

  • dt (float | None) – Override step size.

  • saveat (SaveAt | None) – Override save configuration.

  • **kwargs – Additional arguments passed to vector field.

Returns:

Diffrax solution object containing integration results.
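The override parameters above can be pictured with a small pure-Python sketch, assuming that an argument left at None falls back to the value stored in the config (what "override" suggests, though the exact mechanism is not documented here). `MiniConfig` and `resolve_overrides` are hypothetical names that only mirror three DiffraxConfig fields and their defaults; they are not part of bijx.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class MiniConfig:
    """Hypothetical stand-in mirroring three DiffraxConfig fields and their defaults."""
    t_start: float = 0.0
    t_end: float = 1.0
    dt: float = 0.05


def resolve_overrides(config, t_start=None, t_end=None, dt=None):
    """Return a copy of config in which each non-None override replaces the stored value.

    Comparing against None (rather than truthiness) keeps overrides such as t_start=0.0.
    """
    overrides = {"t_start": t_start, "t_end": t_end, "dt": dt}
    return replace(config, **{k: v for k, v in overrides.items() if v is not None})


cfg = MiniConfig(t_start=0.5)
print(resolve_overrides(cfg, dt=0.01))      # MiniConfig(t_start=0.5, t_end=1.0, dt=0.01)
print(resolve_overrides(cfg, t_start=0.0))  # MiniConfig(t_start=0.0, t_end=1.0, dt=0.05)
```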

forward(x, log_density, **kwargs)

Solve the ODE flow and return the final state.

The same optional overrides as in solve_flow() (start/end times, step size) can be passed here. The saveat argument should not be modified: this method assumes the solution contains the final state, which it returns. To obtain intermediate states, call solve_flow() directly.

reverse(x, log_density, **kwargs)

Solve the ODE flow in reverse and return the final state.

The same optional overrides as in solve_flow() (start/end times, step size) can be passed here. The saveat argument should not be modified: this method assumes the solution contains the final state, which it returns. To obtain intermediate states, call solve_flow() directly.
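forward and reverse integrate the same joint ODE for (x, log_density) in opposite time directions, so reverse should undo forward up to solver error. A minimal sketch of that round trip, assuming a hand-rolled fixed-step RK4 integrator as a stand-in for diffrax and a hypothetical time-dependent field; `field`, `vf` and `rk4_solve` are illustrative names, not bijx API.

```python
import jax
import jax.numpy as jnp

W = jnp.array([[0.5, -1.0], [1.0, 0.2]], dtype=jnp.float32)


def field(t, x):
    """Hypothetical time-dependent, nonlinear field dx/dt."""
    return (1.0 + t) * jnp.tanh(W @ x)


def vf(t, x):
    """Return (dx/dt, d(log p)/dt) with d(log p)/dt = -tr(df/dx)."""
    jac = jax.jacfwd(lambda y: field(t, y))(x)
    return field(t, x), -jnp.trace(jac)


def add_scaled(state, deriv, h):
    """state + h * deriv, applied leaf-wise to the (x, log p) tuple."""
    return jax.tree_util.tree_map(lambda s, d: s + h * d, state, deriv)


def rk4_solve(x, logp, t0, t1, n_steps=40):
    """Fixed-step RK4 on (x, log p); t1 < t0 integrates backward in time."""
    dt = (t1 - t0) / n_steps

    def step(i, s):
        t = t0 + i * dt
        k1 = vf(t, s[0])
        k2 = vf(t + dt / 2, add_scaled(s, k1, dt / 2)[0])
        k3 = vf(t + dt / 2, add_scaled(s, k2, dt / 2)[0])
        k4 = vf(t + dt, add_scaled(s, k3, dt)[0])
        return jax.tree_util.tree_map(
            lambda a, b, c, d, e: a + dt / 6 * (b + 2 * c + 2 * d + e), s, k1, k2, k3, k4
        )

    return jax.lax.fori_loop(0, n_steps, step, (x, logp))


x0 = jnp.array([0.3, -1.2], dtype=jnp.float32)
logp0 = jnp.zeros((), dtype=jnp.float32)
x1, logp1 = rk4_solve(x0, logp0, 0.0, 1.0)          # "forward": t_start -> t_end
x_back, logp_back = rk4_solve(x1, logp1, 1.0, 0.0)  # "reverse": t_end -> t_start
```

Here the field has positive divergence, so logp1 ends up below logp0, and the reverse pass restores both x0 and logp0 to within the integration error.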