bijx.DiffraxConfig

class bijx.DiffraxConfig

Bases: object

Configuration for diffrax ODE solving in continuous normalizing flows.

Bundles the parameters needed for diffrax-based ODE integration, including the solver, step size control and adjoint method. Selected parameters can be overridden at runtime with optional_override, which returns a new config.

Parameters:
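  • solver (AbstractSolver) – Diffrax solver instance (default: Tsit5(), an explicit Runge-Kutta method with an embedded error estimate; it only adapts its step size when combined with an adaptive stepsize_controller).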

  • t_start (float) – Integration start time.

  • t_end (float) – Integration end time.

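  • dt (float) – Step size for integration (default: 0.05). With an adaptive step size controller, this is only the initial step size.

  • saveat (SaveAt) – Which time points to save (default: only the solution at t_end).

  • stepsize_controller (AbstractStepSizeController) – Step size control strategy (default: ConstantStepSize(), i.e. fixed steps of size dt; an adaptive controller such as diffrax.PIDController enables error-controlled steps).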

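  • adjoint (AbstractAdjoint) – Method used to compute gradients through the solve (default: RecursiveCheckpointAdjoint(), i.e. backpropagation through the solver steps with checkpointing).

  • event (Event | None) – Optional event that terminates integration early when its condition is met.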

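  • max_steps (int | None) – Maximum number of integration steps allowed (default: 4096).

  • throw (bool) – Whether to raise an error if integration fails, e.g. when max_steps is reached. If False, the failure is instead reported in the result field of the returned solution.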

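  • solver_state (Any | None) – Initial internal state of the solver, e.g. carried over from a previous solve.

  • controller_state (Any | None) – Initial state of the step size controller.

  • made_jump (bool | None) – Whether a discontinuous jump was made at the start time (relevant when resuming from a previous solve).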
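
With the default ConstantStepSize controller, the solver advances from t_start to t_end in fixed steps of dt, so the defaults amount to about (1.0 - 0.0) / 0.05 = 20 steps, far below max_steps=4096. The plain-Python sketch below illustrates how t_start, t_end, dt, max_steps and throw interact. It is a hypothetical toy, not diffrax or bijx code: the state is a scalar, classical RK4 stands in for Tsit5, and a boolean success flag stands in for diffrax's result codes.

```python
import math
from typing import Callable, Tuple


def rk4_step(f: Callable[[float, float], float], t: float, y: float, h: float) -> float:
    """One classical fourth-order Runge-Kutta step for a scalar ODE dy/dt = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def integrate_constant_step(f, y0, t_start=0.0, t_end=1.0, dt=0.05,
                            max_steps=4096, throw=True) -> Tuple[float, int, bool]:
    """Toy fixed-step solve (hypothetical); returns (y at stop time, steps taken, success)."""
    y, t, steps = y0, t_start, 0
    while t < t_end:
        if max_steps is not None and steps >= max_steps:
            if throw:
                raise RuntimeError(f"max_steps={max_steps} reached at t={t}")
            return y, steps, False  # report the failure instead of raising
        # Step to the next grid point t_start + k * dt, clipped to t_end.
        t_next = min(t_start + (steps + 1) * dt, t_end)
        if t_end - t_next < 1e-9 * dt:  # absorb floating-point round-off at the end
            t_next = t_end
        y = rk4_step(f, t, y, t_next - t)
        t = t_next
        steps += 1
    return y, steps, True


# dy/dt = y with y(0) = 1, so the exact value at t_end = 1 is e.
y1, n, ok = integrate_constant_step(lambda t, y: y, 1.0)
print(n, ok, abs(y1 - math.e) < 1e-6)  # 20 True True
```

Lowering max_steps to 10 stops the toy at t = 0.5: with throw=True it raises, with throw=False it returns the partial state together with a failure flag.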

Note

For more information, see [diffrax’s documentation](https://docs.kidger.site/diffrax/).

Example

>>> import diffrax
>>> from bijx import DiffraxConfig
>>> config = DiffraxConfig(
...     solver=diffrax.Dopri5(),
...     dt=0.1,
...     adjoint=diffrax.RecursiveCheckpointAdjoint()
... )

__init__(solver=Tsit5(), t_start=0.0, t_end=1.0, dt=0.05, saveat=SaveAt(subs=SubSaveAt(t1=True)), stepsize_controller=ConstantStepSize(), adjoint=RecursiveCheckpointAdjoint(), event=None, max_steps=4096, throw=True, solver_state=None, controller_state=None, made_jump=None)

Return type:

None

Methods

optional_override(*[, t_start, t_end, dt, ...])

Create a new config with optionally overridden parameters.

replace(**updates)

Returns a new object replacing the specified fields with new values.

solve(terms, y0, args)

Solve the ODE using the configured diffrax solver.

Attributes

solver: AbstractSolver = Tsit5()
t_start: float = 0.0
t_end: float = 1.0
dt: float = 0.05
saveat: SaveAt = SaveAt(subs=SubSaveAt(t1=True))
stepsize_controller: AbstractStepSizeController = ConstantStepSize()
adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint()
event: Event | None = None
max_steps: int | None = 4096
throw: bool = True
solver_state: Any | None = None
controller_state: Any | None = None
made_jump: bool | None = None

optional_override(*, t_start=None, t_end=None, dt=None, saveat=None, solver_state=None, controller_state=None)

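Create a new config with optionally overridden parameters. Arguments left as None keep their current values; all other fields, such as solver and adjoint, are copied unchanged.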

Parameters:
  • t_start (float | None) – Override integration start time.

  • t_end (float | None) – Override integration end time.

  • dt (float | None) – Override step size.

  • saveat (SaveAt | None) – Override save configuration.

  • solver_state (Any | None) – Override solver internal state.

  • controller_state (Any | None) – Override step size controller state.

Returns:

A new DiffraxConfig with the specified parameters overridden; the original config is left unchanged.
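
The override semantics can be sketched with a frozen dataclass, assuming (as the None defaults suggest) that only arguments that are not None are applied. ToyConfig is a hypothetical stand-in holding a subset of DiffraxConfig's fields; it is not bijx's implementation.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in with a subset of DiffraxConfig's fields."""
    t_start: float = 0.0
    t_end: float = 1.0
    dt: float = 0.05
    solver_state: Optional[Any] = None

    def optional_override(self, *, t_start=None, t_end=None, dt=None, solver_state=None):
        # Keep only the arguments that were actually given (not None).
        candidates = {"t_start": t_start, "t_end": t_end, "dt": dt,
                      "solver_state": solver_state}
        updates = {k: v for k, v in candidates.items() if v is not None}
        # dataclasses.replace returns a new instance; self is not modified.
        return replace(self, **updates)


base = ToyConfig()
# Swap the integration limits, e.g. to run the flow in the opposite direction.
backward = base.optional_override(t_start=1.0, t_end=0.0, dt=-0.05)
print(backward)  # ToyConfig(t_start=1.0, t_end=0.0, dt=-0.05, solver_state=None)
print(base.optional_override() == base)  # True
```

The `is not None` test matters: a truthiness check would silently ignore an override such as t_start=0.0. Under this assumed semantics, resetting a field to None requires replace rather than optional_override.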

solve(terms, y0, args)

Solve the ODE using the configured diffrax solver and settings.

Parameters:
  • terms – Diffrax term(s) defining the vector field, e.g. a diffrax.ODETerm.

  • y0 – Initial condition.

  • args – Additional arguments passed to the vector field.

Returns:

A diffrax.Solution object containing the integration results (e.g. the saved times ts and states ys).
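
The field names and defaults of DiffraxConfig coincide with the arguments of diffrax.diffeqsolve, which suggests that solve forwards them more or less directly, with t_start, t_end and dt playing the roles of t0, t1 and dt0. The source does not state this, so the sketch below is an assumption. ToyConfig, solve_sketch and fake_diffeqsolve are hypothetical names, and the solver field holds a placeholder string; with real diffrax objects in the fields one would pass diffrax.diffeqsolve instead of the fake.

```python
from dataclasses import dataclass, fields
from typing import Any, Callable, Optional

# Assumed mapping from config field names to diffrax.diffeqsolve parameter names.
RENAMED = {"t_start": "t0", "t_end": "t1", "dt": "dt0"}


@dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in for DiffraxConfig with placeholder values."""
    solver: Any = "Tsit5()"  # placeholder string, not a real solver
    t_start: float = 0.0
    t_end: float = 1.0
    dt: float = 0.05
    max_steps: Optional[int] = 4096
    throw: bool = True


def solve_sketch(config, terms, y0, args, diffeqsolve: Callable[..., Any]):
    """Forward every config field to a diffeqsolve-like callable."""
    kwargs = {RENAMED.get(f.name, f.name): getattr(config, f.name) for f in fields(config)}
    # diffeqsolve takes terms, solver, t0, t1, dt0, y0, args positionally;
    # the remaining options (saveat, max_steps, throw, ...) go in as keywords.
    solver = kwargs.pop("solver")
    t0, t1, dt0 = kwargs.pop("t0"), kwargs.pop("t1"), kwargs.pop("dt0")
    return diffeqsolve(terms, solver, t0, t1, dt0, y0, args, **kwargs)


def fake_diffeqsolve(*positional, **keywords):
    """Records its arguments so the forwarding can be inspected."""
    return positional, keywords


pos, kw = solve_sketch(ToyConfig(), "terms", 1.0, None, fake_diffeqsolve)
print(pos)  # ('terms', 'Tsit5()', 0.0, 1.0, 0.05, 1.0, None)
print(kw)   # {'max_steps': 4096, 'throw': True}
```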

replace(**updates)

Returns a new object replacing the specified fields with new values.